# ai-server/samples/models/python-bert-2/model.py

import json
import typing
import torch
import transformers
class Model:
    """Embedding sample backed by a pretrained Hugging Face encoder.

    The checkpoint named by ``NAME`` is loaded once at construction time.
    Each call to :meth:`infer` streams back a single JSON document holding
    the encoder's ``last_hidden_state`` for the prompt.
    """

    # Hugging Face Hub identifier of the checkpoint loaded by ``__init__``.
    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"

    def __init__(self) -> None:
        """Load the pretrained encoder and its matching tokenizer from ``NAME``."""
        self.model = transformers.AutoModel.from_pretrained(self.NAME)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)

    def _max_length(self) -> typing.Optional[int]:
        """Return the longest token sequence the model can embed, or None if unknown."""
        # The tokenizer's own limit is often a huge placeholder when the
        # checkpoint ships no tokenizer config, so the model's positional
        # table size is consulted as well; the tighter of the two wins.
        candidates = (
            getattr(getattr(self.model, "config", None), "max_position_embeddings", None),
            getattr(self.tokenizer, "model_max_length", None),
        )
        known = [value for value in candidates if isinstance(value, int)]
        return min(known) if known else None

    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
        """Embed *prompt* and yield the result as UTF-8 encoded JSON.

        Args:
            prompt: Text to encode. Input longer than the model's maximum
                sequence length is truncated rather than failing.

        Yields:
            Exactly one ``bytes`` chunk of the form ``{"data": [...]}``, where
            ``data`` is ``last_hidden_state.tolist()``. This is presumably a
            ``(1, seq_len, hidden_size)`` nested list for a single prompt;
            confirm against callers.
        """
        max_length = self._max_length()
        # Without truncation, prompts longer than the positional-embedding
        # table make the forward pass fail instead of returning embeddings.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=max_length is not None,
            max_length=max_length,
        )
        # NOTE: the forward pass runs synchronously inside this coroutine,
        # so it blocks the event loop for its duration.
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state
        yield json.dumps({
            "data": embeddings.tolist()
        }).encode("utf-8")