import asyncio
import json
import os
import typing
import warnings
from pathlib import Path

import fastapi

from source.model.base import BaseModel
from source.registry import InterfaceRegistry


class ModelRegistry:
    """
    The model registry

    Load the list of available models and ensure that only one model is loaded at a time.
    """

    def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
        self.model_library: Path = Path(model_library)
        self.interface_registry = interface_registry
        self._api_base = api_base

        # the model types
        self.model_types: dict[str, typing.Type[BaseModel]] = {}
        # the models
        self.models: dict[str, BaseModel] = {}

        # the currently loaded model
        # TODO(Faraphel): load more than one model at a time?
        #  This would require a much more complex manager to handle memory issues,
        #  and having two calculations run at the same time might not be worth it either.
        self.current_loaded_model: typing.Optional[BaseModel] = None

        # lock to control access to model inference
        self.infer_lock = asyncio.Lock()

    @property
    def api_base(self) -> str:
        """
        Base for the api routes
        :return: the base for the api routes
        """

        return self._api_base

    def register_type(self, name: str, model_type: "typing.Type[BaseModel]") -> None:
        """
        Register a model type under the given name
        :param name: the name of the model type
        :param model_type: the class implementing this model type
        """

        self.model_types[name] = model_type

    def reload_models(self) -> None:
        """
        Reload the list of models available
        """

        # reset the model list
        for model in self.models.values():
            model.unload()
        self.models.clear()

        # load all the models in the library
        for model_path in self.model_library.iterdir():
            # get the model name
            model_name: str = model_path.name
            if model_name.startswith("."):
                # ignore models starting with a dot
                continue

            model_configuration_path: Path = model_path / "config.json"

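            # Illustrative config.json layout (an assumption for documentation purposes:
            # only the "type" key is read by the registry itself, any other keys are
            # handed as-is to the model class selected below):
            #   {
            #       "type": "<registered model type name>"
            #   }
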
            # check if the configuration file exists
            if not model_configuration_path.exists():
                warnings.warn(f"Model {model_name!r} is missing a config.json file.")
                continue

            try:
                # load the configuration file
                model_configuration = json.loads(model_configuration_path.read_text())
            except json.decoder.JSONDecodeError as error:
                raise Exception(f"Model {model_name!r}'s configuration is invalid. See above.") from error

            # get the model type for this model
            model_type_name: typing.Optional[str] = model_configuration.get("type")
            if model_type_name is None:
                warnings.warn(f"Field 'type' missing from the configuration file of model {model_name!r}.")
                continue

            # get the class of this model type
            model_type = self.model_types.get(model_type_name)
            if model_type is None:
                warnings.warn(f"Model type {model_type_name!r} does not exist. Has it been registered?")
                continue

            # load the model
            self.models[model_name] = model_type(self, model_configuration, model_path)

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Mount the model endpoints into a fastapi application
        :param application: the fastapi application
        """

        @application.get(self.api_base)
        async def get_models() -> list[str]:
            """
            Get the list of models available
            :return: the list of models available
            """

            # list the models found
            return list(self.models.keys())

        @application.get(f"{self.api_base}/{{model_name}}")
        async def get_model(model_name: str) -> dict:
            """
            Get information about a specific model
            :param model_name: the name of the model
            :return: the information about the corresponding model
            """

            # get the corresponding model
            model = self.models.get(model_name)
            if model is None:
                raise fastapi.HTTPException(status_code=404, detail="Model not found")

            # return the model information
            return model.get_information()

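        # Example requests against these endpoints (illustrative; the paths assume the
        # registry was created with api_base="/models", which this module does not dictate):
        #   GET /models          -> ["some-model", ...]
        #   GET /models/<name>   -> whatever model.get_information() returns, or 404
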
        # mount all the models in the registry
        for model in self.models.values():
            model.mount(application)
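

# Usage sketch (illustrative only, not part of the original module): how the registry
# is typically wired into a FastAPI application. The library path, the api base and the
# registered model class below are placeholder assumptions; the real values come from
# the server setup and from the project's BaseModel subclasses.
#
#   application = fastapi.FastAPI()
#   interface_registry = InterfaceRegistry(...)          # project-specific setup
#
#   registry = ModelRegistry("./models", "/models", interface_registry)
#   registry.register_type("python", SomePythonModel)    # a BaseModel subclass
#   registry.reload_models()
#   registry.mount(application)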