mirror of
https://git.isriupjv.fr/ISRI/ai-server
synced 2025-04-24 01:58:12 +02:00
replaced the previous venv system with a conda one, allowing for better dependency management
This commit is contained in:
parent 8bf28e4c48
commit 0034c7b31a
17 changed files with 313 additions and 230 deletions
2 .gitignore (vendored)
@@ -2,4 +2,4 @@
.idea/

# Environment
venv/
env/
Dockerfile
@@ -1,17 +1,17 @@
FROM python:3.12
FROM continuumio/miniconda3

# copy the application
WORKDIR /app
COPY ./ ./

# install the dependencies
RUN pip3 install -r ./requirements.txt
RUN conda env create -f environment.yml

# expose the API port
EXPOSE 8000

# environment variables
ENV MODEL_DIRECTORY=/models/
ENV MODEL_LIBRARY=/models/

# run the server
CMD ["python3", "-m", "source"]
100 README.md
@@ -2,14 +2,108 @@

A server that can serve AI models with an API and an authentication system

# Usage
## Usage

## Docker
The ai-server project requires a conda environment. You can use Docker to deploy it easily.

### Docker

You can easily use docker-compose to run the project.
Simply go into the project directory and run:

# Environment Variables
```bash
docker compose up -d
```

#### Environment Variables

The project uses special environment variables:

| Name | Description |
|-----------------|-------------------------------------------|
| MODEL_DIRECTORY | the directory where the models are stored |

#### Volumes

The project may store data inside multiple volumes:

| Path | Type | Required | Description |
|------|------|----------|-------------|
| /models/ | Volume | True | The place where the models are stored |
| /root/.huggingface/hub | Bind | False | The place where the internal models are cached. Avoids re-downloading huge amounts of data at every inference |
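
For reference, a minimal `docker-compose.yml` wiring these variables and volumes together could look like the sketch below; the service name, host paths and port mapping are assumptions, not part of the project:

```yaml
services:
  ai-server:                      # assumed service name
    build: .                      # build the image from the repository Dockerfile
    ports:
      - "8000:8000"               # the API port exposed by the Dockerfile
    environment:
      MODEL_DIRECTORY: /models/   # already set in the Dockerfile, repeated here for clarity
    volumes:
      - models:/models/                            # required volume holding the models
      - ./huggingface-hub:/root/.huggingface/hub   # optional bind for the Hugging Face cache

volumes:
  models:
```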

## Models

A model is an object that can be loaded and used for inference.

It is stored inside a directory and must always contain a `config.json` file.

### Configuration

This is a JSON-structured file with basic information about the model.

It describes:
- its type (see below)
- its tags
- its interface
- its inputs
- its output mimetype

And other properties depending on the model type.
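
For illustration, here is what such a configuration can look like; the values are taken from the echo example model that this commit removes further below:

```json
{
    "type": "python",
    "tags": ["dummy"],
    "file": "model.py",
    "interface": "chat",

    "summary": "Echo model",
    "description": "The most basic example model, simply echo the input",

    "inputs": {
        "messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
    },

    "output_type": "text/markdown"
}
```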

#### Types

There is currently only a single type of model: the Python model.

##### Python Model

A Python model is isolated in a `conda` environment.

To be considered a Python model, a model directory needs these three files (a typical layout is sketched after the table):

| File | Description |
|-------------|-------------------------------|
| config.json | The configuration file |
| env | The conda virtual environment |
| model.py | The model file |
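
A hypothetical model directory inside the `/models/` volume could therefore look like this (the model name `my-model` is only an example):

```
/models/
└── my-model/
    ├── config.json
    ├── env/
    └── model.py
```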

###### Configuration

Additional fields may be found in the configuration:

...

###### Virtual Environment

You can create the conda virtual environment with:

`conda create --prefix ./env/ python=<version>`

You can install your requirements inside it with `pip3` or `conda`.
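
For example, a possible setup run from inside the model directory (the Python version and package names are only placeholders for your model's actual requirements):

```bash
# create the environment next to the model files
conda create --prefix ./env/ python=3.12

# install the model requirements inside it
conda run --prefix ./env/ pip3 install transformers torch
```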

###### Internal Model

You need to create a `model.py` file containing a class named `Model`.
This class must be exposed with the Pyro5 library for interoperability with our main environment.

Here is an example of an internal model:

```python
import Pyro5.api

import torch
import transformers


@Pyro5.api.expose
class Model:
    def __init__(self):
        # load the model
        self.pipe = transformers.pipeline(...)

    def infer(self, messages: list[dict]) -> bytes:
        with torch.no_grad():
            outputs = self.pipe(messages)

        return outputs[0]["generated_text"][-1]
```
14 environment.yml (new file)
@@ -0,0 +1,14 @@
channels:
  - defaults
dependencies:
  - python=3.12
  - numpy
  - pillow
  - pip
  - pip:
      - fastapi
      - uvicorn
      - pydantic
      - gradio
      - python-multipart
      - Pyro5
requirements.txt (deleted)
@@ -1,13 +0,0 @@
# web
fastapi
uvicorn
pydantic
gradio
python-multipart

# data manipulation
pillow
numpy

# AI
accelerate
@@ -1,15 +0,0 @@
{
    "type": "python",
    "tags": ["dummy"],
    "file": "model.py",
    "interface": "chat",

    "summary": "Echo model",
    "description": "The most basic example model, simply echo the input",

    "inputs": {
        "messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
    },

    "output_type": "text/markdown"
}
@@ -1,6 +0,0 @@
import typing


class Model:
    async def infer(self, messages: list[dict]) -> typing.AsyncIterator[bytes]:
        yield messages[-1]["content"].encode("utf-8")
@@ -1,15 +0,0 @@
{
    "type": "python",
    "file": "model.py",

    "inputs": {
        "prompt": {"type": "str"}
    },

    "requirements": [
        "transformers",
        "torch",
        "torchvision",
        "torchaudio"
    ]
}
@@ -1,25 +0,0 @@
import json
import typing

import torch
import transformers


class Model:
    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"

    def __init__(self) -> None:
        self.model = transformers.AutoModel.from_pretrained(self.NAME)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)

    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
        inputs = self.tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs)

        embeddings = outputs.last_hidden_state

        yield json.dumps({
            "data": embeddings.tolist()
        }).encode("utf-8")
@@ -1,4 +0,0 @@
{
    "type": "python",
    "file": "model.py"
}
@@ -1,25 +0,0 @@
import json
import typing

import torch
import transformers


class Model:
    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"

    def __init__(self) -> None:
        self.model = transformers.AutoModel.from_pretrained(self.NAME)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)

    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
        inputs = self.tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs)

        embeddings = outputs.last_hidden_state

        yield json.dumps({
            "data": embeddings.tolist()
        }).encode("utf-8")
@@ -3,6 +3,7 @@ import os
from source import registry, model, api
from source.api import interface


# create a fastapi application
application = api.Application()
@@ -30,19 +30,19 @@ class ChatInterface(base.BaseInterface):
messages.insert(0, {"role": "system", "content": system_message})

# add the user message
# NOTE: gradio.ChatInterface add our message and the assistant message
# TODO(Faraphel): add support for files
# NOTE: gradio.ChatInterface add our message and the assistant message automatically
# TODO(Faraphel): add support for files - directly use user_message ? apparently, field "image" is supported.
# check "https://huggingface.co/docs/transformers/main_classes/pipelines" at "ImageTextToTextPipeline"

# TODO(Faraphel): add a "MultimodalChatInterface" to support images
messages.append({
"role": "user",
"content": user_message["text"],
})

# infer the message through the model
chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
assistant_message: str = b"".join(chunks).decode("utf-8")

# send back the messages, clear the user prompt, disable the system prompt
return assistant_message
async for chunk in self.model.infer(messages=messages):
yield chunk.decode("utf-8")

def get_application(self):
# create a gradio interface

@@ -65,7 +65,7 @@
gradio.ChatInterface(
fn=self.send_message,
type="messages",
multimodal=True,
multimodal=False, # TODO(Faraphel): should handle at least image and text files
editable=True,
save_history=True,
additional_inputs=[system_prompt],
@@ -1,18 +1,22 @@
import importlib.util
import subprocess
import sys
import tempfile
import time
import typing
import uuid
import inspect
import textwrap
import os
import signal
from pathlib import Path

import fastapi
import Pyro5
import Pyro5.api

from source import utils
from source.model import base
from source.registry import ModelRegistry
from source.utils.fastapi import UploadFileFix

# enable serpent to represent bytes directly
Pyro5.config.SERPENT_BYTES_REPR = True


class PythonModel(base.BaseModel):
"""

@@ -22,85 +26,114 @@
def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
super().__init__(registry, configuration, path)

# get the parameters of the model
self.parameters = utils.parameters.load(configuration.get("inputs", {}))
# get the environment
self.environment = self.path / "env"
if not self.environment.exists():
raise Exception("The model is missing an environment")

# install custom requirements
requirements = configuration.get("requirements", [])
if len(requirements) > 0:
subprocess.run([sys.executable, "-m", "pip", "install", *requirements])

# get the name of the file containing the model code
file = configuration.get("file")
if file is None:
raise ValueError("Field 'file' is missing from the configuration")

# create the module specification
module_spec = importlib.util.spec_from_file_location(
f"model-{uuid.uuid4()}",
self.path / file
)
# get the module
module = importlib.util.module_from_spec(module_spec)
# load the module
module_spec.loader.exec_module(module)

# create the internal model from the class defined in the module
self._model_type = module.Model
self._model: typing.Optional[module.Model] = None
# prepare the process that will hold the environment python interpreter
self._storage: typing.Optional[tempfile.TemporaryDirectory]
self._process: typing.Optional[subprocess.Popen]
self._model: typing.Optional[Pyro5.api.Proxy]

async def _load(self) -> None:
self._model = self._model_type()
# create a temporary space for the unix socket
self._storage: tempfile.TemporaryDirectory = tempfile.TemporaryDirectory()
socket_file = Path(self._storage.name) / 'socket.unix'

# create a process inside the conda environment
self._process = subprocess.Popen(
[
"conda", "run",  # run a command within conda
"--prefix", self.environment.relative_to(self.path),  # use the model environment
"python3", "-c",  # run a python command

textwrap.dedent(f"""
# make sure that Pyro5 is installed for communication
import sys
import subprocess
subprocess.run(["python3", "-m", "pip", "install", "Pyro5"])

import os
import Pyro5
import Pyro5.api
import model

# allow Pyro5 to return bytes objects directly
Pyro5.config.SERPENT_BYTES_REPR = True

# helper to check if a process is still alive
def is_pid_alive(pid: int) -> bool:
try:
# do nothing if the process is alive, raise an exception if it does not exist
os.kill(pid, 0)
except OSError:
return False
else:
return True

# create a pyro daemon
daemon = Pyro5.api.Daemon(unixsocket={str(socket_file)!r})
# export our model through it
daemon.register(Pyro5.api.expose(model.Model), objectId="model")
# handle requests
# stop the process if the manager is no longer alive
daemon.requestLoop(lambda: is_pid_alive({os.getpid()}))
""")
],

cwd=self.path,  # use the model directory as the working directory
start_new_session=True,  # put the process in a new group to avoid killing ourselves when we unload the process
)

# wait for the process to be initialized properly
while True:
# check if the process is still alive
if self._process.poll() is not None:
# if the process stopped, raise an error (it shall stay alive until the unloading)
raise Exception("Could not load the model.")

# if the socket file has been created, the program is running successfully
if socket_file.exists():
break

time.sleep(0.5)

# get the proxy model object from the environment
self._model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}")

async def _unload(self) -> None:
self._model = None
# clear the proxy object
self._model._pyroRelease()  # NOQA
del self._model
# stop the environment process
os.killpg(os.getpgid(self._process.pid), signal.SIGTERM)
self._process.wait()
del self._process
# clear the storage
self._storage.cleanup()
del self._storage

async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
return self._model.infer(**kwargs)
# Pyro5 is not capable of receiving an "UploadFile" object, so save it to a file and send the path instead
with tempfile.TemporaryDirectory() as working_directory:
for key, value in kwargs.items():
# check if this argument is a file
if not isinstance(value, UploadFileFix):
continue

def _mount(self, application: fastapi.FastAPI):
# TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
# copy the uploaded file to our working directory
path = Path(working_directory) / value.filename
with open(path, "wb") as file:
while content := await value.read(1024*1024):
file.write(content)

# create an endpoint wrapping the inference inside a fastapi call
async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
# NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
# NOTE: perform a naive type(value).__name__ == "type_name" because fastapi does not use its own
# fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
# curiosity that may change in the future
kwargs = {
key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
for key, value in kwargs.items()
}
# replace the argument
kwargs[key] = str(path)

return fastapi.responses.StreamingResponse(
content=await self.infer(**kwargs),
media_type=self.output_type,
headers={
# if the data is not text-like, mark it as an attachment to avoid display issues with Swagger UI
"content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
}
)
# run the inference
for chunk in self._model.infer(**kwargs):
yield chunk

infer_api.__signature__ = inspect.Signature(parameters=self.parameters)

# format the description
description_sections: list[str] = []
if self.description is not None:
description_sections.append(f"# Description\n{self.description}")
if self.interface is not None:
description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")

description: str = "\n".join(description_sections)

# add the inference endpoint on the API
application.add_api_route(
f"{self.api_base}/infer",
infer_api,
methods=["POST"],
tags=self.tags,
summary=self.summary,
description=description,
response_class=fastapi.responses.StreamingResponse,
responses={
200: {"content": {self.output_type: {}}}
},
)
# TODO(Faraphel): if the FastAPI closes, it seems like it waits for conda to finish (or the async tasks ?)
@@ -1,12 +1,16 @@
import abc
import asyncio
import gc
import tempfile
import typing
from pathlib import Path

import fastapi
import inspect

from source import utils
from source.registry import ModelRegistry
from source.utils.fastapi import UploadFileFix


class BaseModel(abc.ABC):

@@ -24,6 +28,9 @@

# the environment directory of the model
self.path = path
# get the parameters of the model
self.inputs = configuration.get("inputs", {})
self.parameters = utils.parameters.load(self.inputs)
# the mimetype of the model responses
self.output_type: str = configuration.get("output_type", "application/json")
# get the tags of the model

@@ -71,8 +78,12 @@

return {
"name": self.name,
"summary": self.summary,
"description": self.description,
"inputs": self.inputs,
"output_type": self.output_type,
"tags": self.tags
"tags": self.tags,
"interface": self.interface,
}

async def load(self) -> None:

@@ -145,7 +156,8 @@
await self.load()

# model specific inference part
return await self._infer(**kwargs)
async for chunk in self._infer(**kwargs):
yield chunk

@abc.abstractmethod
async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:

@@ -164,13 +176,50 @@
if self.interface is not None:
self.interface.mount(application)

# implementation specific mount
self._mount(application)
# create an endpoint wrapping the inference inside a fastapi call
# the arguments will be loaded from the configuration files. Use kwargs for the definition
async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
# NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
# NOTE: perform a naive type(value).__name__ == "type_name" because fastapi does not use its own
# fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
# curiosity that may change in the future
kwargs = {
key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value
for key, value in kwargs.items()
}

@abc.abstractmethod
def _mount(self, application: fastapi.FastAPI) -> None:
"""
Add the model to the api
Do not call manually, use `unload` instead.
:param application: the fastapi application
"""
# return a streaming response around the inference call
return fastapi.responses.StreamingResponse(
content=self.infer(**kwargs),
media_type=self.output_type,
headers={
# if the data is not text-like, mark it as an attachment to avoid display issues with Swagger UI
"content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
}
)

# update the signature of the function to use the configuration parameters
infer_api.__signature__ = inspect.Signature(parameters=self.parameters)

# format the description
description_sections: list[str] = []
if self.description is not None:
description_sections.append(f"# Description\n{self.description}")
if self.interface is not None:
description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")

description: str = "\n".join(description_sections)

# add the inference endpoint on the API
application.add_api_route(
f"{self.api_base}/infer",
infer_api,
methods=["POST"],
tags=self.tags,
summary=self.summary,
description=description,
response_class=fastapi.responses.StreamingResponse,
responses={
200: {"content": {self.output_type: {}}}
},
)
@@ -60,7 +60,12 @@

# load all the models in the library
for model_path in self.model_library.iterdir():
# get the model name
model_name: str = model_path.name
if model_name.startswith("."):
# ignore model starting with a dot
continue

model_configuration_path: Path = model_path / "config.json"

# check if the configuration file exists

@@ -68,8 +73,11 @@
warnings.warn(f"Model {model_name!r} is missing a config.json file.")
continue

try:
# load the configuration file
model_configuration = json.loads(model_configuration_path.read_text())
except json.decoder.JSONDecodeError:
raise Exception(f"Model {model_name!r}'s configuration is invalid. See above.")

# get the model type for this model
model_type_name: str = model_configuration.get("type")
@@ -1,24 +1,5 @@
import inspect
from datetime import datetime

from fastapi import UploadFile

# the list of types and their name that can be used by the API
types: dict[str, type] = {
"bool": bool,
"int": int,
"float": float,
"str": str,
"bytes": bytes,
"dict": dict,
"datetime": datetime,
"file": UploadFile,

# TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ?
"list[dict]": list[dict],
# "tuple": tuple,
# "set": set,
}
import fastapi


def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:

@@ -31,7 +12,6 @@ def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
>>> parameters_definition = {
...     "boolean": {"type": "bool", "default": False},
...     "list": {"type": "list", "default": [1, 2, 3]},
...     "datetime": {"type": "datetime"},
...     "file": {"type": "file"},
... }
>>> parameters = load_parameters(parameters_definition)

@@ -40,12 +20,19 @@ def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]:
parameters: list[inspect.Parameter] = []

for name, definition in parameters_definition.items():
# preprocess the type
match definition["type"]:
case "file":
# shortcut for uploading a file
definition["type"] = fastapi.UploadFile


# deserialize the parameter
parameter = inspect.Parameter(
name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=definition.get("default", inspect.Parameter.empty),
annotation=types[definition["type"]],
annotation=definition["type"],
)
parameters.append(parameter)