"""Model-server wrapper around the TinyBERT encoder.

Exposes the load/unload/infer lifecycle hooks expected by the serving
framework: ``load`` attaches the pretrained weights and tokenizer to the
framework-provided ``model`` holder, ``unload`` drops them, and ``infer``
streams the encoder's hidden states back as a single JSON-encoded chunk.
"""

import json
import typing

import torch
import transformers

# Hugging Face hub id of the pretrained checkpoint to serve.
MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"


def load(model) -> None:
    """Download (or read from cache) the encoder and tokenizer and attach
    them to *model*.

    ``model`` is the serving framework's holder object; this hook mutates
    it in place rather than returning anything.
    """
    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)


def unload(model) -> None:
    """Drop the references set by :func:`load` so the weights and
    tokenizer can be garbage-collected."""
    model.model = None
    model.tokenizer = None


async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
    """Encode *prompt* and yield the last hidden states as JSON bytes.

    Yields exactly one chunk: a UTF-8 encoded JSON object of the form
    ``{"data": [...]}`` where the value is the nested-list rendering of
    the encoder's ``last_hidden_state`` tensor.
    """
    inputs = model.tokenizer(prompt, return_tensors="pt")
    # Inference only — disable autograd graph construction to save memory.
    with torch.no_grad():
        outputs = model.model(**inputs)
    embeddings = outputs.last_hidden_state
    yield json.dumps({"data": embeddings.tolist()}).encode("utf-8")