Mirror of https://git.isriupjv.fr/ISRI/ai-server (synced 2025-04-24 10:08:11 +02:00)
Added support for input parameters recognised by the API. Models are now exposed through separate endpoints so that their inputs are easier to recognise.
parent 900c58ffcb
commit 7bd84c8570
17 changed files with 163 additions and 128 deletions
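In practice, each model is now exposed on its own inference endpoint, with parameters generated from the `inputs` section of its configuration. A minimal client session might look like the sketch below; the host, the model name taken from the sample directory, and the use of a plain HTTP client are assumptions based on the routes added in this diff:

    import requests

    # list the available models, then call one model's dedicated endpoint
    models = requests.get("http://localhost:8000/models").json()
    response = requests.post(
        "http://localhost:8000/models/dummy/infer",  # "dummy": the sample model below
        stream=True,
    )
    for chunk in response.iter_content(chunk_size=None):
        print(chunk)  # b'{"hello": "world!"}'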
requirements.txt
@@ -2,6 +2,7 @@
 fastapi
 uvicorn
 pydantic
+python-multipart

 # AI
 accelerate

samples/models/dummy/config.json
@@ -1,3 +1,6 @@
 {
-    "type": "dummy"
+    "type": "python",
+    "file": "model.py",
+
+    "inputs": {}
 }

samples/models/dummy/model.py (new file, +12)
@@ -0,0 +1,12 @@
+import json
+import typing
+
+
+def load(model) -> None:
+    pass
+
+def unload(model) -> None:
+    pass
+
+def infer(model) -> typing.Iterator[bytes]:
+    yield json.dumps({"hello": "world!"}).encode("utf-8")

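Because a "python" model is a plain module, its contract can be exercised without the server. A quick sanity check of this dummy module, loaded the same way `PythonModel` loads it further down in this diff (the module name and the `SimpleNamespace` stand-in are illustrative):

    import importlib.util
    import types

    # load the sample module from its file, as PythonModel does below
    spec = importlib.util.spec_from_file_location("dummy_model", "samples/models/dummy/model.py")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    model = types.SimpleNamespace()  # stand-in for the manager's model object
    module.load(model)
    print(b"".join(module.infer(model)))  # b'{"hello": "world!"}'
    module.unload(model)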
@@ -2,6 +2,10 @@
     "type": "python",
     "file": "model.py",

+    "inputs": {
+        "prompt": {"type": "str"}
+    },
+
     "requirements": [
         "transformers",
         "torch",

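The names allowed under "type" in each input map to the type table introduced in source/utils/parameters.py at the end of this diff. Purely as an illustration, a configuration mixing several of the recognised types might look like this (the parameter names and defaults are made up):

    {
        "type": "python",
        "file": "model.py",

        "inputs": {
            "prompt": {"type": "str"},
            "temperature": {"type": "float", "default": 1.0},
            "attachment": {"type": "file"}
        }
    }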
@@ -1,4 +1,5 @@
 import json
+import typing

 import torch
 import transformers
@@ -7,22 +8,22 @@ import transformers
 MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"


-def load(model):
+def load(model) -> None:
     model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
     model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

-def unload(model):
+def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, payload: dict) -> str:
-    inputs = model.tokenizer(payload["prompt"], return_tensors="pt")
+def infer(model, prompt: str) -> typing.Iterator[bytes]:
+    inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():
         outputs = model.model(**inputs)

     embeddings = outputs.last_hidden_state

-    return json.dumps({
+    yield json.dumps({
         "data": embeddings.tolist()
-    })
+    }).encode("utf-8")

@@ -1,4 +1,5 @@
 import json
+import typing

 import torch
 import transformers
@@ -7,22 +8,22 @@ import transformers
 MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"


-def load(model):
+def load(model) -> None:
     model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
     model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

-def unload(model):
+def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, payload: dict) -> str:
-    inputs = model.tokenizer(payload["prompt"], return_tensors="pt")
+def infer(model, prompt: str) -> typing.Iterator[bytes]:
+    inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():
         outputs = model.model(**inputs)

     embeddings = outputs.last_hidden_state

-    return json.dumps({
+    yield json.dumps({
         "data": embeddings.tolist()
-    })
+    }).encode("utf-8")

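The change of `infer` from returning a `str` to yielding `bytes` matches the `StreamingResponse` wrapper added in `PythonModel` below: FastAPI consumes an iterator of byte chunks, so a model may now emit its output incrementally rather than as one blob. A hypothetical multi-chunk variant, just to make the contract concrete:

    import typing

    def infer(model, prompt: str) -> typing.Iterator[bytes]:
        # hypothetical: stream one chunk per word instead of a single blob
        for word in prompt.split():
            yield (word + " ").encode("utf-8")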
@@ -7,13 +7,9 @@ application = api.Application()


 # create the model controller
-model_controller = manager.ModelManager(os.environ["MODEL_LIBRARY"])
-model_controller.register_model_type("dummy", model.DummyModel)
+model_controller = manager.ModelManager(application, os.environ["MODEL_LIBRARY"])
 model_controller.register_model_type("python", model.PythonModel)
 model_controller.reload()

-api.route.models.load(application, model_controller)
-
-
 # serve the application
 application.serve("0.0.0.0", 8000)

source/api/__init__.py
@@ -1,3 +1 @@
-from . import route
-
 from .Application import Application

source/api/route/__init__.py (deleted)
@@ -1 +0,0 @@
-from . import models

source/api/route/models.py (deleted)
@@ -1,74 +0,0 @@
-import sys
-import traceback
-
-import fastapi
-import pydantic
-
-from source.api import Application
-from source import manager
-
-
-class InferenceRequest(pydantic.BaseModel):
-    """
-    Represent a request made when inferring a model
-    """
-
-    request: dict
-
-
-def load(application: Application, model_manager: manager.ModelManager):
-    @application.get("/models")
-    async def get_models() -> list[str]:
-        """
-        Get the list of models available
-        :return: the list of models available
-        """
-
-        # reload the model list
-        model_manager.reload()
-        # list the models found
-        return list(model_manager.models.keys())
-
-    @application.get("/models/{model_name}")
-    async def get_model(model_name: str) -> dict:
-        """
-        Get information about a specific model
-        :param model_name: the name of the model
-        :return: the information about the corresponding model
-        """
-
-        # get the corresponding model
-        model = model_manager.models.get(model_name)
-        if model is None:
-            raise fastapi.HTTPException(status_code=404, detail="Model not found")
-
-        # return the model information
-        return model.get_information()
-
-
-    @application.post("/models/{model_name}/infer")
-    async def infer_model(model_name: str, request: InferenceRequest) -> fastapi.Response:
-        """
-        Run an inference through the selected model
-        :param model_name: the name of the model
-        :param request: the data to infer to the model
-        :return: the model response
-        """
-
-        # get the corresponding model
-        model = model_manager.models.get(model_name)
-        if model is None:
-            raise fastapi.HTTPException(status_code=404, detail="Model not found")
-
-        # infer the data through the model
-        try:
-            response = model.infer(request.request)
-        except Exception:
-            print(traceback.format_exc(), file=sys.stderr)
-            raise fastapi.HTTPException(status_code=500, detail="An error occurred while inferring the model.")
-
-        # pack the model response into a fastapi response
-        return fastapi.Response(
-            content=response,
-            media_type=model.response_mimetype,
-        )

@@ -4,11 +4,14 @@ import typing
 import warnings
 from pathlib import Path

-from source import model
+import fastapi
+
+from source import model, api


 class ModelManager:
-    def __init__(self, model_library: os.PathLike | str):
+    def __init__(self, application: api.Application, model_library: os.PathLike | str):
+        self.application: api.Application = application
         self.model_library: Path = Path(model_library)

         # the model types
@@ -20,16 +23,43 @@ class ModelManager:
         # TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue
         self.current_loaded_model: typing.Optional[model.base.BaseModel] = None

-    def register_model_type(self, name: str, model_type: typing.Type[model.base.BaseModel]):
+        @self.application.get("/models")
+        async def get_models() -> list[str]:
+            """
+            Get the list of models available
+            :return: the list of models available
+            """
+
+            # list the models found
+            return list(self.models.keys())
+
+        @self.application.get("/models/{model_name}")
+        async def get_model(model_name: str) -> dict:
+            """
+            Get information about a specific model
+            :param model_name: the name of the model
+            :return: the information about the corresponding model
+            """
+
+            # get the corresponding model
+            model = self.models.get(model_name)
+            if model is None:
+                raise fastapi.HTTPException(status_code=404, detail="Model not found")
+
+            # return the model information
+            return model.get_information()
+
+
+    def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
         self.model_types[name] = model_type

     def reload(self):
         # reset the model list
         for model in self.models.values():
             model.unload()
         self.models.clear()

         # load all the models in the library
         for model_path in self.model_library.iterdir():
             model_name: str = model_path.name
             model_configuration_path: Path = model_path / "config.json"

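Moving these endpoints into `ModelManager.__init__` (replacing the deleted `source/api/route/models.py` earlier in this diff) relies on the handlers being closures over `self`, registered at construction time, so no module-level manager is needed. The pattern in isolation, as a minimal sketch with illustrative names:

    import fastapi

    class Service:
        def __init__(self, application: fastapi.FastAPI):
            self.items: dict[str, str] = {"a": "1"}

            # the handler closes over `self` and is registered on construction
            @application.get("/items")
            async def get_items() -> list[str]:
                return list(self.items.keys())

    app = fastapi.FastAPI()
    service = Service(app)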
source/model/DummyModel.py (deleted)
@@ -1,19 +0,0 @@
-import json
-
-from source.model import base
-
-
-class DummyModel(base.BaseModel):
-    """
-    A dummy model, mainly used to test the API and the manager.
-    simply send back the request made to it.
-    """
-
-    def _load(self) -> None:
-        pass
-
-    def _unload(self) -> None:
-        pass
-
-    def _infer(self, payload: dict) -> str | bytes:
-        return json.dumps(payload)

source/model/PythonModel.py
@@ -1,9 +1,14 @@
 import importlib.util
 import subprocess
 import sys
+import typing
 import uuid
+import inspect
 from pathlib import Path

+import fastapi
+
+from source import utils
 from source.manager import ModelManager
 from source.model import base

@@ -16,6 +21,8 @@ class PythonModel(base.BaseModel):
     def __init__(self, manager: ModelManager, configuration: dict, path: Path):
         super().__init__(manager, configuration, path)

+        ## Configuration
+
         # get the name of the file containing the model code
         file = configuration.get("file")
         if file is None:
@@ -36,11 +43,28 @@ class PythonModel(base.BaseModel):
         # load the module
         module_spec.loader.exec_module(self.module)

+        ## Api
+
+        # load the inputs data into the inference function signature (used by FastAPI)
+        parameters = utils.parameters.load(configuration.get("inputs", {}))
+
+        # create an endpoint wrapping the inference inside a fastapi call
+        async def infer_api(*args, **kwargs):
+            return fastapi.responses.StreamingResponse(
+                content=self.infer(*args, **kwargs),
+                media_type=self.output_type,
+            )
+
+        infer_api.__signature__ = inspect.Signature(parameters=parameters)
+
+        # add the inference endpoint on the API
+        self.manager.application.add_api_route(f"/models/{self.name}/infer", infer_api, methods=["POST"])
+
     def _load(self) -> None:
         return self.module.load(self)

     def _unload(self) -> None:
         return self.module.unload(self)

-    def _infer(self, payload: dict) -> str | bytes:
-        return self.module.infer(self, payload)
+    def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
+        return self.module.infer(self, *args, **kwargs)

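The key trick here is the assignment to `infer_api.__signature__`: FastAPI introspects a route handler through `inspect.signature()`, which honours a manually assigned `__signature__`, so a generic `*args, **kwargs` wrapper can advertise per-model parameters for validation and OpenAPI docs. A standalone sketch of the mechanism (route path and parameter names are illustrative):

    import inspect

    import fastapi

    app = fastapi.FastAPI()

    async def infer_api(**kwargs):
        # generic handler; the real parameters are declared via __signature__
        return kwargs

    infer_api.__signature__ = inspect.Signature(parameters=[
        inspect.Parameter("prompt", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
        inspect.Parameter("temperature", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=float, default=1.0),
    ])

    # FastAPI now validates and documents `prompt` and `temperature` for this route
    app.add_api_route("/models/demo/infer", infer_api, methods=["POST"])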
source/model/__init__.py
@@ -1,4 +1,3 @@
 from . import base

-from .DummyModel import DummyModel
 from .PythonModel import PythonModel

source/model/base.py
@@ -1,7 +1,9 @@
 import abc
 import gc
+import typing
 from pathlib import Path

+from source import api
 from source.manager import ModelManager


@@ -10,13 +12,13 @@ class BaseModel(abc.ABC):
     Represent a model.
     """

-    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
+    def __init__(self, manager: ModelManager, configuration: dict[str, typing.Any], path: Path):
         # the environment directory of the model
         self.path = path
         # the model manager
         self.manager = manager
         # the mimetype of the model responses
-        self.response_mimetype: str = configuration.get("response_mimetype", "application/json")
+        self.output_type: str = configuration.get("output_type", "application/json")

         self._loaded = False

@@ -101,13 +103,11 @@ class BaseModel(abc.ABC):
         """
         Unload the model
         Do not call manually, use `unload` instead.
         :return:
         """

-    def infer(self, payload: dict) -> str | bytes:
+    def infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
         """
         Infer our payload through the model within the model manager
-        :param payload: the payload to give to the model
         :return: the response of the model
         """
-
@@ -115,12 +115,11 @@ class BaseModel(abc.ABC):
             self.load()

         # model specific inference part
-        return self._infer(payload)
+        return self._infer(*args, **kwargs)

     @abc.abstractmethod
-    def _infer(self, payload: dict) -> str | bytes:
+    def _infer(self, *args, **kwargs) -> typing.Iterator[bytes]:
         """
         Infer our payload through the model
-        :param payload: the payload to give to the model
         :return: the response of the model
         """

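With this change, a custom model type only has to implement the three hooks against the new generator-based contract. A minimal subclass sketch (the `EchoModel` name is hypothetical and not part of this commit; it assumes the repository is importable):

    import json
    import typing

    from source.model import base


    class EchoModel(base.BaseModel):
        """Hypothetical model type that echoes its keyword arguments."""

        def _load(self) -> None:
            pass

        def _unload(self) -> None:
            pass

        def _infer(self, **kwargs) -> typing.Iterator[bytes]:
            yield json.dumps(kwargs).encode("utf-8")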
source/utils/__init__.py (new file, +1)
@@ -0,0 +1 @@
+from . import parameters

source/utils/parameters.py (new file, +54)
@@ -0,0 +1,54 @@
+import inspect
+from datetime import datetime
+
+import fastapi
+
+
+# the list of types and their name that can be used by the API
+types: dict[str, type] = {
+    "bool": bool,
+    "int": int,
+    "float": float,
+    "str": str,
+    "bytes": bytes,
+    "list": list,
+    "tuple": tuple,
+    "set": set,
+    "dict": dict,
+    "datetime": datetime,
+    "file": fastapi.UploadFile,
+}
+
+
+def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
+    """
+    Load a list of python function parameters from their definitions.
+    :param parameters_definition: the definitions of the parameters
+    :return: the python function parameters
+
+    Examples:
+        >>> parameters_definition = {
+        ...     "boolean": {"type": "bool", "default": False},
+        ...     "list": {"type": "list", "default": [1, 2, 3]},
+        ...     "datetime": {"type": "datetime"},
+        ...     "file": {"type": "file"},
+        ... }
+        >>> parameters = load(parameters_definition)
+    """
+
+    parameters: list[inspect.Parameter] = []
+
+    for name, definition in parameters_definition.items():
+        # deserialize the parameter
+        parameter = inspect.Parameter(
+            name,
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+            default=definition.get("default", inspect.Parameter.empty),
+            annotation=types[definition["type"]],
+        )
+        parameters.append(parameter)
+
+    # sort the parameters so that non-default arguments always end up before default ones
+    parameters.sort(key=lambda parameter: parameter.default is inspect.Parameter.empty, reverse=True)
+
+    return parameters
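A quick usage sketch of this helper (assuming the package layout above is importable): the sort at the end guarantees that required parameters precede defaulted ones, which is a precondition for building a valid `inspect.Signature`.

    import inspect

    from source.utils import parameters

    params = parameters.load({
        "verbose": {"type": "bool", "default": False},
        "prompt": {"type": "str"},
    })

    # required parameters are sorted first, so a valid signature can be built
    print(inspect.Signature(parameters=params))  # (prompt: str, verbose: bool = False)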