mirror of
https://git.isriupjv.fr/ISRI/ai-server
synced 2025-04-24 10:08:11 +02:00
replaced the previous venv system with a conda one, allowing for better dependency management
This commit is contained in:
parent
8bf28e4c48
commit
0034c7b31a
17 changed files with 313 additions and 230 deletions
2 .gitignore vendored
@@ -2,4 +2,4 @@
 .idea/
 
 # Environment
-venv/
+env/
Dockerfile
@@ -1,17 +1,17 @@
-FROM python:3.12
+FROM continuumio/miniconda3
 
 # copy the application
 WORKDIR /app
 COPY ./ ./
 
 # install the dependencies
-RUN pip3 install -r ./requirements.txt
+RUN conda env create -f environment.yml
 
 # expose the API port
 EXPOSE 8000
 
 # environment variables
-ENV MODEL_DIRECTORY=/models/
+ENV MODEL_LIBRARY=/models/
 
 # run the server
 CMD ["python3", "-m", "source"]
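For illustration only, a `docker-compose.yml` consistent with this Dockerfile and with the volumes described in the README below might look like the following sketch. The service name, the named volume, and the host path for the HuggingFace cache are assumptions, not the project's actual compose file:

```yaml
services:
  ai-server:
    build: .
    ports:
      - "8000:8000"                            # the API port exposed by the Dockerfile
    environment:
      MODEL_LIBRARY: /models/                  # matches the ENV set in the Dockerfile
    volumes:
      - models:/models/                        # volume holding the served models
      - ./huggingface:/root/.huggingface/hub   # bind mount caching downloaded internal models

volumes:
  models:
```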
100 README.md
@@ -2,14 +2,108 @@
 A server that can serve AI models with an API and an authentication system
 
-# Usage
+## Usage
 
-## Docker
+The ai-server project requires a conda environment. You can use Docker to deploy it easily.
+
+### Docker
+
+You can easily use docker-compose to run the project.
+Simply go into the project directory and run:
 
-# Environment Variables
+```bash
+docker compose up -d
+```
+
+#### Environment Variables
+
+The project uses special environment variables:
 
 | Name | Description |
 |-----------------|-------------------------------------------|
 | MODEL_DIRECTORY | the directory where the models are stored |
+
+#### Volumes
+
+The project might store data inside multiple volumes:
+
+| Path | Type | Required | Description |
+|------|------|----------|-------------|
+| /models/ | Volume | True | The place where the models are stored |
+| /root/.huggingface/hub | Bind | False | The place where the internal models are cached. Avoids redownloading huge amounts of data at every inference |
+
+## Models
+
+A model is an object that can be loaded and perform inference.
+
+It is stored inside a directory and must always contain a `config.json` file.
+
+### Configuration
+
+This is a JSON-structured file with basic information about the model.
+
+It describes:
+- its type (see below)
+- its tags
+- its interface
+- its inputs
+- its output mimetype
+
+And other properties depending on the model type.
+
+#### Types
+
+There is, for now, only a single type of model: the Python model.
+
+##### Python Model
+
+A Python model is isolated in a `conda` environment.
+
+To be considered a Python model, you need these three files:
+
+| File | Description |
+|-------------|-------------------------------|
+| config.json | The configuration file |
+| env | The conda virtual environment |
+| model.py | The model file |
+
+###### Configuration
+
+Additional fields might be found in the configuration:
+
+...
+
+###### Virtual Environment
+
+You can create a conda virtual environment with:
+
+`conda create --prefix ./env/ python=<version>`
+
+You can install your requirements inside it with `pip3` or `conda`.
+
+###### Internal Model
+
+You need to create a `model.py` file containing a class named `Model`.
+This class must be exposed with the Pyro5 library for interoperability with our main environment.
+
+Here is an example of an internal model:
+
+```python
+import Pyro5.api
+
+import torch
+import transformers
+
+
+@Pyro5.api.expose
+class Model:
+    def __init__(self):
+        # load the model
+        self.pipe = transformers.pipeline(...)
+
+    def infer(self, messages: list[dict]) -> bytes:
+        with torch.no_grad():
+            outputs = self.pipe(messages)
+
+        return outputs[0]["generated_text"][-1]
+```
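As an illustration of the configuration fields listed in the README above, a model `config.json` might look like the following sketch. The field names `summary`, `description`, `inputs` and `output_type` are taken from the example models removed elsewhere in this commit; the exact schema may differ:

```json
{
    "type": "python",
    "tags": ["example"],
    "interface": "chat",

    "summary": "Example model",
    "description": "A short example configuration",

    "inputs": {
        "messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
    },

    "output_type": "text/markdown"
}
```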
14 environment.yml Normal file
@@ -0,0 +1,14 @@
+channels:
+  - defaults
+dependencies:
+  - python=3.12
+  - numpy
+  - pillow
+  - pip
+  - pip:
+      - fastapi
+      - uvicorn
+      - pydantic
+      - gradio
+      - python-multipart
+      - Pyro5
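For local development, the server environment could presumably be created from this file along these lines; the environment name is an assumption, since the file does not define one:

```bash
# create the server environment from environment.yml and give it a name
conda env create -f environment.yml -n ai-server
conda activate ai-server

# run the server, as the Dockerfile does
python3 -m source
```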
requirements.txt
@@ -1,13 +0,0 @@
-# web
-fastapi
-uvicorn
-pydantic
-gradio
-python-multipart
-
-# data manipulation
-pillow
-numpy
-
-# AI
-accelerate
@@ -1,15 +0,0 @@
-{
-    "type": "python",
-    "tags": ["dummy"],
-    "file": "model.py",
-    "interface": "chat",
-
-    "summary": "Echo model",
-    "description": "The most basic example model, simply echo the input",
-
-    "inputs": {
-        "messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
-    },
-
-    "output_type": "text/markdown"
-}
@@ -1,6 +0,0 @@
-import typing
-
-
-class Model:
-    async def infer(self, messages: list[dict]) -> typing.AsyncIterator[bytes]:
-        yield messages[-1]["content"].encode("utf-8")
@@ -1,15 +0,0 @@
-{
-    "type": "python",
-    "file": "model.py",
-
-    "inputs": {
-        "prompt": {"type": "str"}
-    },
-
-    "requirements": [
-        "transformers",
-        "torch",
-        "torchvision",
-        "torchaudio"
-    ]
-}
@@ -1,25 +0,0 @@
-import json
-import typing
-
-import torch
-import transformers
-
-
-class Model:
-    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
-
-    def __init__(self) -> None:
-        self.model = transformers.AutoModel.from_pretrained(self.NAME)
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)
-
-    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
-        inputs = self.tokenizer(prompt, return_tensors="pt")
-
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-
-        embeddings = outputs.last_hidden_state
-
-        yield json.dumps({
-            "data": embeddings.tolist()
-        }).encode("utf-8")
@@ -1,4 +0,0 @@
-{
-    "type": "python",
-    "file": "model.py"
-}
@@ -1,25 +0,0 @@
-import json
-import typing
-
-import torch
-import transformers
-
-
-class Model:
-    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
-
-    def __init__(self) -> None:
-        self.model = transformers.AutoModel.from_pretrained(self.NAME)
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)
-
-    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
-        inputs = self.tokenizer(prompt, return_tensors="pt")
-
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-
-        embeddings = outputs.last_hidden_state
-
-        yield json.dumps({
-            "data": embeddings.tolist()
-        }).encode("utf-8")
@@ -3,6 +3,7 @@ import os
 from source import registry, model, api
 from source.api import interface
 
 
 # create a fastapi application
 application = api.Application()
@@ -30,19 +30,19 @@ class ChatInterface(base.BaseInterface):
             messages.insert(0, {"role": "system", "content": system_message})
 
         # add the user message
-        # NOTE: gradio.ChatInterface add our message and the assistant message
-        # TODO(Faraphel): add support for files
+        # NOTE: gradio.ChatInterface add our message and the assistant message automatically
+        # TODO(Faraphel): add support for files - directory use user_message ? apparently, field "image" is supported.
+        # check "https://huggingface.co/docs/transformers/main_classes/pipelines" at "ImageTextToTextPipeline"
+        # TODO(Faraphel): add a "MultimodalChatInterface" to support images
         messages.append({
             "role": "user",
             "content": user_message["text"],
         })
 
         # infer the message through the model
-        chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
-        assistant_message: str = b"".join(chunks).decode("utf-8")
-
-        # send back the messages, clear the user prompt, disable the system prompt
-        return assistant_message
+        async for chunk in self.model.infer(messages=messages):
+            yield chunk.decode("utf-8")
 
     def get_application(self):
         # create a gradio interface
@@ -65,7 +65,7 @@ class ChatInterface(base.BaseInterface):
         gradio.ChatInterface(
             fn=self.send_message,
             type="messages",
-            multimodal=True,
+            multimodal=False,  # TODO(Faraphel): should handle at least image and text files
             editable=True,
             save_history=True,
             additional_inputs=[system_prompt],
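For context on the change above: gradio treats a chat function that is an async generator as a streaming handler, re-rendering the reply with every yielded value. A minimal standalone sketch follows; accumulating the partial text and the fake token list are assumptions for illustration, not taken from the project code:

```python
import asyncio

import gradio


async def send_message(message, history):
    # stream the reply: each yielded value replaces the currently displayed assistant message
    text = ""
    for token in ["Hel", "lo", " ", "world"]:
        await asyncio.sleep(0.1)  # simulate chunks arriving from a model
        text += token
        yield text


gradio.ChatInterface(fn=send_message, type="messages").launch()
```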
@@ -1,18 +1,22 @@
-import importlib.util
 import subprocess
-import sys
+import tempfile
+import time
 import typing
-import uuid
-import inspect
+import textwrap
+import os
+import signal
 from pathlib import Path
 
-import fastapi
+import Pyro5
+import Pyro5.api
 
-from source import utils
 from source.model import base
 from source.registry import ModelRegistry
 from source.utils.fastapi import UploadFileFix
 
+# enable serpent to represent bytes directly
+Pyro5.config.SERPENT_BYTES_REPR = True
+
 
 class PythonModel(base.BaseModel):
     """
@@ -22,85 +26,114 @@ class PythonModel(base.BaseModel):
     def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
         super().__init__(registry, configuration, path)
 
-        # get the parameters of the model
-        self.parameters = utils.parameters.load(configuration.get("inputs", {}))
-
-        # install custom requirements
-        requirements = configuration.get("requirements", [])
-        if len(requirements) > 0:
-            subprocess.run([sys.executable, "-m", "pip", "install", *requirements])
-
-        # get the name of the file containing the model code
-        file = configuration.get("file")
-        if file is None:
-            raise ValueError("Field 'file' is missing from the configuration")
-
-        # create the module specification
-        module_spec = importlib.util.spec_from_file_location(
-            f"model-{uuid.uuid4()}",
-            self.path / file
-        )
-        # get the module
-        module = importlib.util.module_from_spec(module_spec)
-        # load the module
-        module_spec.loader.exec_module(module)
-
-        # create the internal model from the class defined in the module
-        self._model_type = module.Model
-        self._model: typing.Optional[module.Model] = None
+        # get the environment
+        self.environment = self.path / "env"
+        if not self.environment.exists():
+            raise Exception("The model is missing an environment")
+
+        # prepare the process that will hold the environment python interpreter
+        self._storage: typing.Optional[tempfile.TemporaryDirectory]
+        self._process: typing.Optional[subprocess.Popen]
+        self._model: typing.Optional[Pyro5.api.Proxy]
 
     async def _load(self) -> None:
-        self._model = self._model_type()
+        # create a temporary space for the unix socket
+        self._storage: tempfile.TemporaryDirectory = tempfile.TemporaryDirectory()
+        socket_file = Path(self._storage.name) / 'socket.unix'
+
+        # create a process inside the conda environment
+        self._process = subprocess.Popen(
+            [
+                "conda", "run",  # run a command within conda
+                "--prefix", self.environment.relative_to(self.path),  # use the model environment
+                "python3", "-c",  # run a python command
+
+                textwrap.dedent(f"""
+                    # make sure that Pyro5 is installed for communication
+                    import sys
+                    import subprocess
+                    subprocess.run(["python3", "-m", "pip", "install", "Pyro5"])
+
+                    import os
+                    import Pyro5
+                    import Pyro5.api
+                    import model
+
+                    # allow Pyro5 to return bytes objects directly
+                    Pyro5.config.SERPENT_BYTES_REPR = True
+
+                    # helper to check if a process is still alive
+                    def is_pid_alive(pid: int) -> bool:
+                        try:
+                            # do nothing if the process is alive, raise an exception if it does not exists
+                            os.kill(pid, 0)
+                        except OSError:
+                            return False
+                        else:
+                            return True
+
+                    # create a pyro daemon
+                    daemon = Pyro5.api.Daemon(unixsocket={str(socket_file)!r})
+                    # export our model through it
+                    daemon.register(Pyro5.api.expose(model.Model), objectId="model")
+                    # handle requests
+                    # stop the process if the manager is no longer alive
+                    daemon.requestLoop(lambda: is_pid_alive({os.getpid()}))
+                """)
+            ],
+
+            cwd=self.path,  # use the model directory as the working directory
+            start_new_session=True,  # put the process in a new group to avoid killing ourselves when we unload the process
+        )
+
+        # wait for the process to be initialized properly
+        while True:
+            # check if the process is still alive
+            if self._process.poll() is not None:
+                # if the process stopped, raise an error (it shall stay alive until the unloading)
+                raise Exception("Could not load the model.")
+
+            # if the socket file have been created, the program is running successfully
+            if socket_file.exists():
+                break
+
+            time.sleep(0.5)
+
+        # get the proxy model object from the environment
+        self._model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}")
 
     async def _unload(self) -> None:
-        self._model = None
+        # clear the proxy object
+        self._model._pyroRelease()  # NOQA
+        del self._model
+        # stop the environment process
+        os.killpg(os.getpgid(self._process.pid), signal.SIGTERM)
+        self._process.wait()
+        del self._process
+        # clear the storage
+        self._storage.cleanup()
+        del self._storage
 
     async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
-        return self._model.infer(**kwargs)
-
-    def _mount(self, application: fastapi.FastAPI):
-        # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
-
-        # create an endpoint wrapping the inference inside a fastapi call
-        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
-            # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
-            # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
-            # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
-            # curiosity that may change in the future
-            kwargs = {
-                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
-                for key, value in kwargs.items()
-            }
-
-            return fastapi.responses.StreamingResponse(
-                content=await self.infer(**kwargs),
-                media_type=self.output_type,
-                headers={
-                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
-                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
-                }
-            )
-
-        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
-
-        # format the description
-        description_sections: list[str] = []
-        if self.description is not None:
-            description_sections.append(f"# Description\n{self.description}")
-        if self.interface is not None:
-            description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
-
-        description: str = "\n".join(description_sections)
-
-        # add the inference endpoint on the API
-        application.add_api_route(
-            f"{self.api_base}/infer",
-            infer_api,
-            methods=["POST"],
-            tags=self.tags,
-            summary=self.summary,
-            description=description,
-            response_class=fastapi.responses.StreamingResponse,
-            responses={
-                200: {"content": {self.output_type: {}}}
-            },
-        )
+        # Pyro5 is not capable of receiving an "UploadFile" object, so save it to a file and send the path instead
+        with tempfile.TemporaryDirectory() as working_directory:
+            for key, value in kwargs.items():
+                # check if this the argument is a file
+                if not isinstance(value, UploadFileFix):
+                    continue
+
+                # copy the uploaded file to our working directory
+                path = Path(working_directory) / value.filename
+                with open(path, "wb") as file:
+                    while content := await value.read(1024*1024):
+                        file.write(content)
+
+                # replace the argument
+                kwargs[key] = str(path)
+
+            # run the inference
+            for chunk in self._model.infer(**kwargs):
+                yield chunk
+
+    # TODO(Faraphel): if the FastAPI close, it seem like it wait for conda to finish (or the async tasks ?)
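To make the loading mechanism above easier to follow in isolation, here is a minimal self-contained sketch of the same Pyro5-over-unix-socket pattern, with the daemon running in a thread instead of a separate conda process. The `EchoModel` class and the example message are illustrative assumptions, not the project's code:

```python
import tempfile
import threading
from pathlib import Path

import Pyro5
import Pyro5.api

# let serpent return bytes objects directly, as the loader above configures
Pyro5.config.SERPENT_BYTES_REPR = True


class EchoModel:
    def infer(self, messages: list) -> bytes:
        # echo the last message back as bytes
        return messages[-1]["content"].encode("utf-8")


with tempfile.TemporaryDirectory() as storage:
    socket_file = Path(storage) / "socket.unix"

    # server side: expose the model on a unix socket (the real loader does this in the conda subprocess)
    daemon = Pyro5.api.Daemon(unixsocket=str(socket_file))
    daemon.register(Pyro5.api.expose(EchoModel), objectId="model")
    threading.Thread(target=daemon.requestLoop, daemon=True).start()

    # client side: connect a proxy to the exposed object, as PythonModel._load does
    model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}")
    print(model.infer([{"role": "user", "content": "hello"}]))  # b'hello'
    model._pyroRelease()
    daemon.shutdown()
```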
@@ -1,12 +1,16 @@
 import abc
 import asyncio
 import gc
+import tempfile
 import typing
 from pathlib import Path
 
 import fastapi
+import inspect
 
+from source import utils
 from source.registry import ModelRegistry
+from source.utils.fastapi import UploadFileFix
 
 
 class BaseModel(abc.ABC):
@@ -24,6 +28,9 @@ class BaseModel(abc.ABC):
         # the environment directory of the model
         self.path = path
+        # get the parameters of the model
+        self.inputs = configuration.get("inputs", {})
+        self.parameters = utils.parameters.load(self.inputs)
         # the mimetype of the model responses
         self.output_type: str = configuration.get("output_type", "application/json")
         # get the tags of the model
@@ -71,8 +78,12 @@ class BaseModel(abc.ABC):
         return {
             "name": self.name,
+            "summary": self.summary,
+            "description": self.description,
+            "inputs": self.inputs,
             "output_type": self.output_type,
-            "tags": self.tags
+            "tags": self.tags,
+            "interface": self.interface,
         }
 
     async def load(self) -> None:
@@ -145,7 +156,8 @@ class BaseModel(abc.ABC):
         await self.load()
 
         # model specific inference part
-        return await self._infer(**kwargs)
+        async for chunk in self._infer(**kwargs):
+            yield chunk
 
     @abc.abstractmethod
     async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
@@ -164,13 +176,50 @@ class BaseModel(abc.ABC):
         if self.interface is not None:
             self.interface.mount(application)
 
-        # implementation specific mount
-        self._mount(application)
+        # create an endpoint wrapping the inference inside a fastapi call
+        # the arguments will be loaded from the configuration files. Use kwargs for the definition
+        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
+            # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
+            # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own
+            # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
+            # curiosity that may change in the future
+            kwargs = {
+                key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
+                for key, value in kwargs.items()
+            }
 
-    @abc.abstractmethod
-    def _mount(self, application: fastapi.FastAPI) -> None:
-        """
-        Add the model to the api
-        Do not call manually, use `unload` instead.
-        :param application: the fastapi application
-        """
+            # return a streaming response around the inference call
+            return fastapi.responses.StreamingResponse(
+                content=self.infer(**kwargs),
+                media_type=self.output_type,
+                headers={
+                    # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
+                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
+                }
+            )
+
+        # update the signature of the function to use the configuration parameters
+        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
+
+        # format the description
+        description_sections: list[str] = []
+        if self.description is not None:
+            description_sections.append(f"# Description\n{self.description}")
+        if self.interface is not None:
+            description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
+
+        description: str = "\n".join(description_sections)
+
+        # add the inference endpoint on the API
+        application.add_api_route(
+            f"{self.api_base}/infer",
+            infer_api,
+            methods=["POST"],
+            tags=self.tags,
+            summary=self.summary,
+            description=description,
+            response_class=fastapi.responses.StreamingResponse,
+            responses={
+                200: {"content": {self.output_type: {}}}
+            },
+        )
@@ -60,7 +60,12 @@ class ModelRegistry:
         # load all the models in the library
         for model_path in self.model_library.iterdir():
+            # get the model name
             model_name: str = model_path.name
+            if model_name.startswith("."):
+                # ignore model starting with a dot
+                continue
+
             model_configuration_path: Path = model_path / "config.json"
 
             # check if the configuration file exists
@@ -68,8 +73,11 @@ class ModelRegistry:
                 warnings.warn(f"Model {model_name!r} is missing a config.json file.")
                 continue
 
-            # load the configuration file
-            model_configuration = json.loads(model_configuration_path.read_text())
+            try:
+                # load the configuration file
+                model_configuration = json.loads(model_configuration_path.read_text())
+            except json.decoder.JSONDecodeError:
+                raise Exception(f"Model {model_name!r}'s configuration is invalid. See above.")
 
             # get the model type for this model
             model_type_name: str = model_configuration.get("type")
@@ -1,24 +1,5 @@
 import inspect
-from datetime import datetime
-
-from fastapi import UploadFile
-
-
-# the list of types and their name that can be used by the API
-types: dict[str, type] = {
-    "bool": bool,
-    "int": int,
-    "float": float,
-    "str": str,
-    "bytes": bytes,
-    "dict": dict,
-    "datetime": datetime,
-    "file": UploadFile,
-
-    # TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ?
-    "list[dict]": list[dict],
-    # "tuple": tuple,
-    # "set": set,
-}
+import fastapi
 
 
 def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
@@ -31,7 +12,6 @@ def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
     >>> parameters_definition = {
     ...     "boolean": {"type": "bool", "default": False},
     ...     "list": {"type": "list", "default": [1, 2, 3]},
-    ...     "datetime": {"type": "datetime"},
     ...     "file": {"type": "file"},
     ... }
     >>> parameters = load_parameters(parameters_definition)
@@ -40,12 +20,19 @@ def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
     parameters: list[inspect.Parameter] = []
 
     for name, definition in parameters_definition.items():
+        # preprocess the type
+        match definition["type"]:
+            case "file":
+                # shortcut for uploading a file
+                definition["type"] = fastapi.UploadFile
+
         # deserialize the parameter
         parameter = inspect.Parameter(
            name,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            default=definition.get("default", inspect.Parameter.empty),
-           annotation=types[definition["type"]],
+           annotation=definition["type"],
        )
        parameters.append(parameter)
 