mirror of
https://git.isriupjv.fr/ISRI/ai-server
synced 2025-04-24 18:18:11 +02:00
Models are now loaded in separate endpoints for the inputs to be easier to recognise
29 lines
680 B
Python
29 lines
680 B
Python
import json
|
|
import typing
|
|
|
|
import torch
|
|
import transformers
|
|
|
|
|
|
# HuggingFace Hub identifier of the encoder served by this module.
MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
|
|
|
|
|
|
def load(model) -> None:
    """Attach the TinyBERT encoder and its tokenizer to *model*.

    Downloads (or reads from the local HF cache) the weights for
    ``MODEL_NAME`` and stores them on the holder object as
    ``model.tokenizer`` and ``model.model``.
    """
    # The two loads are independent; tokenizer first so a tokenizer
    # failure aborts before the heavier weight download.
    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
|
|
|
|
def unload(model) -> None:
    """Drop the references to the model and tokenizer so the weights
    can be garbage-collected."""
    for attr in ("model", "tokenizer"):
        setattr(model, attr, None)
|
|
|
|
def infer(model, prompt: str) -> typing.Iterator[bytes]:
    """Encode *prompt* and yield its token embeddings as a single
    UTF-8 JSON payload.

    The prompt is tokenized to PyTorch tensors, run through the encoder
    with gradients disabled, and the final hidden states are serialized
    as ``{"data": [...]}`` in one ``bytes`` chunk.
    """
    encoded = model.tokenizer(prompt, return_tensors="pt")

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        result = model.model(**encoded)

    payload = {"data": result.last_hidden_state.tolist()}
    yield json.dumps(payload).encode("utf-8")
|