mirror of
https://git.isriupjv.fr/ISRI/ai-server
synced 2025-04-24 10:08:11 +02:00
176 lines
4.9 KiB
Python
176 lines
4.9 KiB
Python
import abc
|
|
import asyncio
|
|
import gc
|
|
import typing
|
|
from pathlib import Path
|
|
|
|
import fastapi
|
|
|
|
from source.registry import ModelRegistry
|
|
|
|
|
|
class BaseModel(abc.ABC):
    """
    Represent a model.

    Abstract base class: it keeps the metadata read from the model
    configuration, tracks whether the model is loaded, and wraps the
    implementation hooks (`_load`, `_unload`, `_infer`, `_mount`) with the
    bookkeeping done against the model registry.
    """

    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
        """
        Create the model from its configuration

        :param registry: the model registry this model belongs to
        :param configuration: the configuration of the model; the keys read here
            ("summary", "description", "output_type", "tags", "interface") are all optional
        :param path: the environment directory of the model, its last component is the model name
        """

        # the model registry
        self.registry = registry

        # get the documentation of the model (None when not configured)
        self.summary = configuration.get("summary")
        self.description = configuration.get("description")

        # the environment directory of the model
        self.path = path
        # the mimetype of the model responses
        self.output_type: str = configuration.get("output_type", "application/json")
        # get the tags of the model
        self.tags = configuration.get("tags", [])

        # get the selected interface of the model
        interface_name: typing.Optional[str] = configuration.get("interface", None)
        if interface_name is None:
            # no interface selected for this model
            self.interface = None
        else:
            # look the interface type up by name and instantiate it for this model
            # NOTE(review): presumably `interface_types` is a mapping, so an unknown name raises KeyError - confirm
            self.interface = self.registry.interface_registry.interface_types[interface_name](self)

        # is the model currently loaded
        self._loaded = False

        # lock to avoid loading and unloading at the same time
        self.load_lock = asyncio.Lock()

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.name}>"

    @property
    def api_base(self) -> str:
        """
        Base for the API routes
        :return: the base for the API routes, i.e. the registry base followed by the model name
        """

        return f"{self.registry.api_base}/{self.name}"

    @property
    def name(self):
        """
        Get the name of the model
        :return: the name of the model, i.e. the last component of its directory
        """

        return self.path.name

    def get_information(self):
        """
        Get information about the model
        :return: a dictionary with the name, the output type and the tags of the model
        """

        return {
            "name": self.name,
            "output_type": self.output_type,
            "tags": self.tags
        }

    async def load(self) -> None:
        """
        Load the model within the model manager

        Does nothing if the model is already loaded. Otherwise the model
        currently registered as loaded (if any) is unloaded first, then this
        model is loaded and registered as the loaded one.
        """

        async with self.load_lock:
            # if the model is already loaded, skip
            if self._loaded:
                return

            # unload the currently loaded model if any
            previous = self.registry.current_loaded_model
            # never wait on ourselves: `unload` takes `load_lock`, which we are
            # already holding and which is not reentrant, so it would hang forever
            if previous is not None and previous is not self:
                await previous.unload()

            # load the model depending on the implementation
            await self._load()

            # mark the model as loaded
            self._loaded = True
            # mark the model as the registry loaded model
            self.registry.current_loaded_model = self

    @abc.abstractmethod
    async def _load(self):
        """
        Load the model
        Do not call manually, use `load` instead.
        """

    async def unload(self) -> None:
        """
        Unload the model within the model manager

        Does nothing if the model is not loaded.
        """

        async with self.load_lock:
            # if we are not already loaded, stop
            if not self._loaded:
                return

            # unload the model depending on the implementation
            await self._unload()

            # force the garbage collector to clean the memory
            gc.collect()

            # mark the model as unloaded
            self._loaded = False

            # if we are the registry current loaded model, remove this status
            if self.registry.current_loaded_model is self:
                self.registry.current_loaded_model = None

    @abc.abstractmethod
    async def _unload(self):
        """
        Unload the model
        Do not call manually, use `unload` instead.
        """

    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer our payload through the model within the model manager

        :param kwargs: the inference parameters, forwarded untouched to `_infer`
        :return: the response of the model
        """

        async with self.registry.infer_lock:
            # ensure that the model is loaded
            await self.load()

            # model specific inference part
            # NOTE(review): the lock is released as soon as `_infer` returns; if the
            # returned iterator is consumed lazily, that happens outside the lock - confirm
            return await self._infer(**kwargs)

    @abc.abstractmethod
    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        """
        Infer our payload through the model
        Do not call manually, use `infer` instead.
        :return: the response of the model
        """

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Add the model to the api
        :param application: the fastapi application
        """

        # mount the interface if selected
        if self.interface is not None:
            self.interface.mount(application)

        # implementation specific mount
        self._mount(application)

    @abc.abstractmethod
    def _mount(self, application: fastapi.FastAPI) -> None:
        """
        Add the model to the api
        Do not call manually, use `mount` instead.
        :param application: the fastapi application
        """