mirror of https://git.isriupjv.fr/ISRI/ai-server
synced 2025-04-24 18:18:11 +02:00
added support for additional, more user-friendly interfaces and simplified part of the application loading process
This commit is contained in:
parent
1a49aa3779
commit
f647c960dd
20 changed files with 353 additions and 107 deletions
@@ -2,6 +2,7 @@
 fastapi
 uvicorn
 pydantic
+gradio
 python-multipart
 
 # AI
@@ -2,10 +2,14 @@
   "type": "python",
   "tags": ["dummy"],
   "file": "model.py",
+  "interface": "chat",
 
-  "output_type": "video/mp4",
+  "summary": "Echo model",
+  "description": "The most basic example model, simply echo the input",
 
   "inputs": {
-    "file": {"type": "file"}
-  }
+    "messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
+  },
+
+  "output_type": "text/markdown"
 }
@@ -7,5 +7,5 @@ def load(model) -> None:
 def unload(model) -> None:
     pass
 
-async def infer(model, file) -> typing.AsyncIterator[bytes]:
-    yield await file.read()
+async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
+    yield messages[-1]["content"].encode("utf-8")
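The example model now takes an OpenAI-style message list instead of a file upload. A minimal client sketch, assuming the server from this commit runs on localhost:8000 and the echo model is exposed under the name dummy (the real name comes from its directory in the model library, and how FastAPI binds the messages body parameter depends on the generated signature):

import requests

# POST {api_base}/infer, i.e. /models/<name>/infer with the "/models" base
# configured in the entry-point diff below
response = requests.post(
    "http://localhost:8000/models/dummy/infer",
    json=[{"role": "user", "content": "hello"}],  # the "messages" input (list[dict])
    stream=True,
)
for chunk in response.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"))  # the echo model yields the last message back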
@@ -16,7 +16,7 @@ def unload(model) -> None:
     model.model = None
     model.tokenizer = None
 
-def infer(model, prompt: str) -> typing.Iterator[bytes]:
+async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
     inputs = model.tokenizer(prompt, return_tensors="pt")
 
     with torch.no_grad():

@@ -16,7 +16,7 @@ def unload(model) -> None:
     model.model = None
     model.tokenizer = None
 
-def infer(model, prompt: str) -> typing.Iterator[bytes]:
+async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
     inputs = model.tokenizer(prompt, return_tensors="pt")
 
     with torch.no_grad():
@@ -1,3 +1,3 @@
 from . import api
 from . import model
-from . import manager
+from . import registry
@@ -1,15 +1,25 @@
 import os
 
-from source import manager, model, api
+from source import registry, model, api
+from source.api import interface
 
 # create a fastapi application
 application = api.Application()
 
 
-# create the model controller
-model_controller = manager.ModelManager(application, os.environ["MODEL_LIBRARY"])
-model_controller.register_model_type("python", model.PythonModel)
-model_controller.reload()
+# create the interface registry
+interface_registry = registry.InterfaceRegistry()
+interface_registry.register_type("chat", interface.ChatInterface)
+
+# create the model registry
+model_registry = registry.ModelRegistry(os.environ["MODEL_LIBRARY"], "/models", interface_registry)
+model_registry.register_type("python", model.PythonModel)
+model_registry.reload_models()
+
+# add the model registry routes to the fastapi
+model_registry.mount(application)
 
 # serve the application
 application.serve("0.0.0.0", 8000)
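For orientation, a sketch of the resulting HTTP surface, assuming the application runs locally as configured above (model names depend on the library content; dummy matches the example model earlier in this diff):

import requests

base = "http://localhost:8000/models"        # the api_base passed to ModelRegistry above

print(requests.get(base).json())             # list of model names
print(requests.get(f"{base}/dummy").json())  # model information (name, output_type, tags)
# POST {base}/<name>/infer      runs an inference (see the earlier example)
# GET  {base}/<name>/interface  serves the gradio interface, when "interface" is configured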
@@ -8,7 +8,8 @@ class Application(fastapi.FastAPI):
     def __init__(self):
         super().__init__(
             title=meta.name,
-            description=meta.description
+            description=meta.description,
+            redoc_url=None,
         )
 
     def serve(self, host: str = "0.0.0.0", port: int = 8080):
@@ -1 +1,3 @@
+from . import interface
+
 from .Application import Application
source/api/interface/ChatInterface.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+import textwrap
+
+import gradio
+
+from source import meta
+from source.api.interface import base
+from source.model.base import BaseModel
+
+
+class ChatInterface(base.BaseInterface):
+    """
+    An interface for Chat-like models.
+    Use the OpenAI convention (list of dict with roles and content)
+    """
+
+    def __init__(self, model: "BaseModel"):
+        # Function to send and receive chat messages
+        super().__init__(model)
+
+    async def send_message(self, user_message, old_messages: list[dict], system_message: str):
+        # normalize the user message (the type can be wrong, especially when "edited")
+        if isinstance(user_message, str):
+            user_message: dict = {"files": [], "text": user_message}
+
+        # copy the history to avoid modifying it
+        messages: list[dict] = old_messages.copy()
+
+        # add the system instruction
+        if system_message:
+            messages.insert(0, {"role": "system", "content": system_message})
+
+        # add the user message
+        # NOTE: gradio.ChatInterface adds our message and the assistant message
+        # TODO(Faraphel): add support for files
+        messages.append({
+            "role": "user",
+            "content": user_message["text"],
+        })
+
+        # infer the message through the model
+        chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
+        assistant_message: str = b"".join(chunks).decode("utf-8")
+
+        # send back the messages, clear the user prompt, disable the system prompt
+        return assistant_message
+
+    def get_gradio_application(self):
+        # create a gradio interface
+        with gradio.Blocks(analytics_enabled=False) as application:
+            # header
+            gradio.Markdown(textwrap.dedent(f"""
+                # {meta.name}
+                ## {self.model.name}
+            """))
+
+            # additional settings
+            with gradio.Accordion("Advanced Settings") as advanced_settings:
+                system_prompt = gradio.Textbox(
+                    label="System prompt",
+                    placeholder="You are an expert in C++...",
+                    lines=2,
+                )
+
+            # chat interface
+            gradio.ChatInterface(
+                fn=self.send_message,
+                type="messages",
+                multimodal=True,
+                editable=True,
+                save_history=True,
+                additional_inputs=[system_prompt],
+                additional_inputs_accordion=advanced_settings,
+            )
+
+        return application
source/api/interface/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
+from . import base
+
+from .ChatInterface import ChatInterface
source/api/interface/base/BaseInterface.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+import abc
+
+import fastapi
+import gradio
+
+import source
+
+
+class BaseInterface(abc.ABC):
+    def __init__(self, model: "source.model.base.BaseModel"):
+        self.model = model
+
+    @property
+    def route(self) -> str:
+        """
+        The route to the interface
+        :return: the route to the interface
+        """
+
+        return f"{self.model.api_base}/interface"
+
+    @abc.abstractmethod
+    def get_gradio_application(self) -> gradio.Blocks:
+        """
+        Get a gradio application
+        :return: a gradio application
+        """
+
+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Mount the interface on an application
+        :param application: the application to mount the interface on
+        :param path: the path where to mount the application
+        """
+
+        gradio.mount_gradio_app(
+            application,
+            self.get_gradio_application(),
+            self.route
+        )
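Any new interface type follows the same pattern: subclass BaseInterface, build a gradio Blocks application, and register it under a name. A hypothetical completion-style interface, not part of this commit, assuming a model whose only input is prompt (like the transformer examples above):

import gradio

from source.api.interface import base


class CompletionInterface(base.BaseInterface):
    # hypothetical: a single prompt -> single completion interface
    async def complete(self, prompt: str) -> str:
        # collect the streamed chunks, as ChatInterface.send_message does
        chunks = [chunk async for chunk in await self.model.infer(prompt=prompt)]
        return b"".join(chunks).decode("utf-8")

    def get_gradio_application(self) -> gradio.Blocks:
        with gradio.Blocks(analytics_enabled=False) as application:
            prompt = gradio.Textbox(label="Prompt")
            output = gradio.Markdown()
            generate = gradio.Button("Generate")
            generate.click(self.complete, inputs=[prompt], outputs=[output])
        return application


# registered like the chat interface in main's wiring:
# interface_registry.register_type("completion", CompletionInterface)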
source/api/interface/base/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from .BaseInterface import BaseInterface
@@ -1 +0,0 @@
-from .ModelManager import ModelManager
@@ -9,8 +9,8 @@ from pathlib import Path
 import fastapi
 
 from source import utils
-from source.manager import ModelManager
 from source.model import base
+from source.registry import ModelRegistry
 from source.utils.fastapi import UploadFileFix
@@ -19,21 +19,22 @@ class PythonModel(base.BaseModel):
     A model running a custom python model.
     """
 
-    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
-        super().__init__(manager, configuration, path)
+    def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
+        super().__init__(registry, configuration, path)
 
-        ## Configuration
-        # get the name of the file containing the model code
-        file = configuration.get("file")
-        if file is None:
-            raise ValueError("Field 'file' is missing from the configuration")
+        # get the parameters of the model
+        self.parameters = utils.parameters.load(configuration.get("inputs", {}))
 
         # install custom requirements
         requirements = configuration.get("requirements", [])
         if len(requirements) > 0:
             subprocess.run([sys.executable, "-m", "pip", "install", *requirements])
 
+        # get the name of the file containing the model code
+        file = configuration.get("file")
+        if file is None:
+            raise ValueError("Field 'file' is missing from the configuration")
+
         # create the module specification
         module_spec = importlib.util.spec_from_file_location(
             f"model-{uuid.uuid4()}",
@@ -44,10 +45,17 @@ class PythonModel(base.BaseModel):
         # load the module
         module_spec.loader.exec_module(self.module)
 
-        ## Api
+    def _load(self) -> None:
+        return self.module.load(self)
 
-        # load the inputs data into the inference function signature (used by FastAPI)
-        parameters = utils.parameters.load(configuration.get("inputs", {}))
+    def _unload(self) -> None:
+        return self.module.unload(self)
+
+    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
+        return self.module.infer(self, **kwargs)
+
+    def _mount(self, application: fastapi.FastAPI):
+        # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
 
         # create an endpoint wrapping the inference inside a fastapi call
         async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
@@ -61,7 +69,7 @@ class PythonModel(base.BaseModel):
             }
 
             return fastapi.responses.StreamingResponse(
-                content=await self.infer(**kwargs),
+                content=await self.registry.infer_model(self, **kwargs),
                 media_type=self.output_type,
                 headers={
                     # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
@@ -69,27 +77,25 @@ class PythonModel(base.BaseModel):
                 }
             )
 
-        infer_api.__signature__ = inspect.Signature(parameters=parameters)
+        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
 
+        # format the description
+        description_sections: list[str] = []
+        if self.description is not None:
+            description_sections.append(self.description)
+        if self.interface is not None:
+            description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")
+
         # add the inference endpoint on the API
-        self.manager.application.add_api_route(
-            f"/models/{self.name}/infer",
+        application.add_api_route(
+            f"{self.api_base}/infer",
             infer_api,
             methods=["POST"],
             tags=self.tags,
-            # summary=...,
-            # description=...,
+            summary=self.summary,
+            description="<br>".join(description_sections),
             response_class=fastapi.responses.StreamingResponse,
             responses={
                 200: {"content": {self.output_type: {}}}
             },
         )
-
-    def _load(self) -> None:
-        return self.module.load(self)
-
-    def _unload(self) -> None:
-        return self.module.unload(self)
-
-    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.Iterator[bytes]:
-        return self.module.infer(self, **kwargs)
@@ -3,7 +3,9 @@ import gc
 import typing
 from pathlib import Path
 
-from source.manager import ModelManager
+import fastapi
+
+from source.registry import ModelRegistry
 
 
 class BaseModel(abc.ABC):
@@ -11,21 +13,43 @@ class BaseModel(abc.ABC):
     Represent a model.
     """
 
-    def __init__(self, manager: ModelManager, configuration: dict[str, typing.Any], path: Path):
+    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
+        # the model registry
+        self.registry = registry
+
+        # get the documentation of the model
+        self.summary = configuration.get("summary")
+        self.description = configuration.get("description")
+
         # the environment directory of the model
         self.path = path
-        # the model manager
-        self.manager = manager
         # the mimetype of the model responses
         self.output_type: str = configuration.get("output_type", "application/json")
         # get the tags of the model
         self.tags = configuration.get("tags", [])
+
+        # get the selected interface of the model
+        interface_name: typing.Optional[str] = configuration.get("interface", None)
+        self.interface = (
+            self.registry.interface_registry.interface_types[interface_name](self)
+            if interface_name is not None else None
+        )
+
+        # is the model currently loaded
         self._loaded = False
 
     def __repr__(self):
         return f"<{self.__class__.__name__}: {self.name}>"
 
+    @property
+    def api_base(self) -> str:
+        """
+        Base for the API routes
+        :return: the base for the API routes
+        """
+
+        return f"{self.registry.api_base}/{self.name}"
+
     @property
     def name(self):
         """
@@ -44,6 +68,7 @@ class BaseModel(abc.ABC):
         return {
             "name": self.name,
             "output_type": self.output_type,
+            "tags": self.tags
         }
 
     def load(self) -> None:
@@ -51,22 +76,13 @@ class BaseModel(abc.ABC):
         Load the model within the model manager
         """
 
-        # if we are already loaded, stop
+        # if the model is already loaded, skip
         if self._loaded:
             return
 
-        # check if we are the current loaded model
-        if self.manager.current_loaded_model is not self:
-            # unload the previous model
-            if self.manager.current_loaded_model is not None:
-                self.manager.current_loaded_model.unload()
-
-        # model specific loading
+        # load the model depending on the implementation
         self._load()
 
-        # declare ourselves as the currently loaded model
-        self.manager.current_loaded_model = self
-
         # mark the model as loaded
         self._loaded = True
@@ -86,11 +102,7 @@ class BaseModel(abc.ABC):
         if not self._loaded:
             return
 
-        # if we were the currently loaded model of the manager, demote ourselves
-        if self.manager.current_loaded_model is self:
-            self.manager.current_loaded_model = None
-
-        # model specific unloading part
+        # unload the model depending on the implementation
         self._unload()
 
         # force the garbage collector to clean the memory
|
||||||
Do not call manually, use `unload` instead.
|
Do not call manually, use `unload` instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
|
async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
|
||||||
"""
|
"""
|
||||||
Infer our payload through the model within the model manager
|
Infer our payload through the model within the model manager
|
||||||
:return: the response of the model
|
:return: the response of the model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async with self.manager.inference_lock:
|
# make sure we are loaded before an inference
|
||||||
# make sure we are loaded before an inference
|
self.load()
|
||||||
self.load()
|
|
||||||
|
|
||||||
# model specific inference part
|
# model specific inference part
|
||||||
return self._infer(**kwargs)
|
return await self._infer(**kwargs)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
|
async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
|
||||||
"""
|
"""
|
||||||
Infer our payload through the model
|
Infer our payload through the model
|
||||||
:return: the response of the model
|
:return: the response of the model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def mount(self, application: fastapi.FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
Add the model to the api
|
||||||
|
:param application: the fastapi application
|
||||||
|
"""
|
||||||
|
|
||||||
|
# mount the interface if selected
|
||||||
|
if self.interface is not None:
|
||||||
|
self.interface.mount(application)
|
||||||
|
|
||||||
|
# implementation specific mount
|
||||||
|
self._mount(application)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _mount(self, application: fastapi.FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
Add the model to the api
|
||||||
|
Do not call manually, use `unload` instead.
|
||||||
|
:param application: the fastapi application
|
||||||
|
"""
|
||||||
|
|
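The split between the public methods (load/unload/infer/mount) and their underscore-prefixed hooks defines the contract for model types. A minimal hypothetical model type against that contract, not part of this commit (PythonModel above is the real implementation):

import typing

import fastapi

from source.model import base


class StaticModel(base.BaseModel):
    # hypothetical: always answers with a fixed text, to show the hook contract
    def _load(self) -> None:
        self.answer = b"hello from a static model"

    def _unload(self) -> None:
        self.answer = None

    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
        async def stream() -> typing.AsyncIterator[bytes]:
            yield self.answer
        return stream()

    def _mount(self, application: fastapi.FastAPI) -> None:
        # nothing beyond the interface mounted by BaseModel.mount
        pass


# registered like the python type in the entry-point wiring:
# model_registry.register_type("static", StaticModel)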
source/registry/InterfaceRegistry.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+import typing
+
+from source.api.interface import base
+
+
+class InterfaceRegistry:
+    """
+    The interface registry
+    Store the list of the interfaces available
+    """
+
+    def __init__(self):
+        self.interface_types: dict[str, typing.Type[base.BaseInterface]] = {}
+
+    def register_type(self, name: str, interface_type: typing.Type[base.BaseInterface]):
+        self.interface_types[name] = interface_type
@@ -7,64 +7,80 @@ from pathlib import Path
 import fastapi
 
-from source import model, api
+from source.model.base import BaseModel
+from source.registry import InterfaceRegistry
 
 
-class ModelManager:
+class ModelRegistry:
     """
-    The model manager
+    The model registry
     Load the list of models available, ensure that only one model is loaded at the same time.
     """
 
-    def __init__(self, application: api.Application, model_library: os.PathLike | str):
-        self.application: api.Application = application
+    def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
         self.model_library: Path = Path(model_library)
+        self.interface_registry = interface_registry
+        self._api_base = api_base
 
         # the model types
-        self.model_types: dict[str, typing.Type[model.base.BaseModel]] = {}
+        self.model_types: dict[str, typing.Type[BaseModel]] = {}
         # the models
-        self.models: dict[str, model.base.BaseModel] = {}
+        self.models: dict[str, BaseModel] = {}
 
         # the currently loaded model
         # TODO(Faraphel): load more than one model at a time ?
         #  would require a way more complex manager to handle memory issue
         #  having two calculations at the same time might not be worth it either
-        self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
+        self.current_loaded_model: typing.Optional[BaseModel] = None
 
         # lock to avoid concurrent inference and concurrent model loading and unloading
         self.inference_lock = asyncio.Lock()
 
-        @self.application.get("/models")
-        async def get_models() -> list[str]:
-            """
-            Get the list of models available
-            :return: the list of models available
-            """
-
-            # list the models found
-            return list(self.models.keys())
-
-        @self.application.get("/models/{model_name}")
-        async def get_model(model_name: str) -> dict:
-            """
-            Get information about a specific model
-            :param model_name: the name of the model
-            :return: the information about the corresponding model
-            """
-
-            # get the corresponding model
-            model = self.models.get(model_name)
-            if model is None:
-                raise fastapi.HTTPException(status_code=404, detail="Model not found")
-
-            # return the model information
-            return model.get_information()
-
-    def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
+    @property
+    def api_base(self) -> str:
+        """
+        Base for the api routes
+        :return: the base for the api routes
+        """
+
+        return self._api_base
+
+    def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
         self.model_types[name] = model_type
 
-    def reload(self):
+    async def load_model(self, model: "BaseModel"):
+        # lock to avoid concurrent loading
+        async with self.inference_lock:
+            # if there is another currently loaded model, unload it
+            if self.current_loaded_model is not None and self.current_loaded_model is not model:
+                await self.unload_model(self.current_loaded_model)
+
+            # load the model
+            model.load()
+
+            # mark the model as the currently loaded model of the manager
+            self.current_loaded_model = model
+
+    async def unload_model(self, model: "BaseModel"):
+        # lock to avoid concurrent unloading
+        async with self.inference_lock:
+            # if we were the currently loaded model of the manager, demote ourselves
+            if self.current_loaded_model is model:
+                self.current_loaded_model = None
+
+            # model specific unloading part
+            model.unload()
+
+    async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
+        # lock to avoid concurrent inference
+        async with self.inference_lock:
+            return await model.infer(**kwargs)
+
+    def reload_models(self) -> None:
+        """
+        Reload the list of models available
+        """
 
         # reset the model list
         for model in self.models.values():
             model.unload()
@@ -97,3 +113,39 @@ class ModelManager:
 
         # load the model
         self.models[model_name] = model_type(self, model_configuration, model_path)
+
+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Mount the models endpoints into a fastapi application
+        :param application: the fastapi application
+        """
+
+        @application.get(self.api_base)
+        async def get_models() -> list[str]:
+            """
+            Get the list of models available
+            :return: the list of models available
+            """
+
+            # list the models found
+            return list(self.models.keys())
+
+        @application.get(f"{self.api_base}/{{model_name}}")
+        async def get_model(model_name: str) -> dict:
+            """
+            Get information about a specific model
+            :param model_name: the name of the model
+            :return: the information about the corresponding model
+            """
+
+            # get the corresponding model
+            model = self.models.get(model_name)
+            if model is None:
+                raise fastapi.HTTPException(status_code=404, detail="Model not found")
+
+            # return the model information
+            return model.get_information()
+
+        # mount all the models in the registry
+        for model_name, model in self.models.items():
+            model.mount(application)
source/registry/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
+from .ModelRegistry import ModelRegistry
+from .InterfaceRegistry import InterfaceRegistry
@@ -10,12 +10,14 @@ types: dict[str, type] = {
     "float": float,
     "str": str,
     "bytes": bytes,
-    "list": list,
-    "tuple": tuple,
-    "set": set,
     "dict": dict,
     "datetime": datetime,
     "file": UploadFile,
+
+    # TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ?
+    "list[dict]": list[dict],
+    # "tuple": tuple,
+    # "set": set,
 }
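This types table maps the "type" strings from a model's "inputs" configuration to the Python annotations used to build the FastAPI signature of the inference endpoint. The diff does not show utils.parameters.load itself; a rough guess at its behaviour, inferred only from how PythonModel feeds the result to inspect.Signature (every detail below is an assumption):

import inspect
import json
from datetime import datetime

from fastapi import UploadFile

# subset of the mapping from the hunk above
types: dict[str, type] = {
    "float": float, "str": str, "bytes": bytes, "dict": dict,
    "datetime": datetime, "file": UploadFile, "list[dict]": list[dict],
}


def load(inputs: dict[str, dict]) -> list[inspect.Parameter]:
    # hypothetical reimplementation: one keyword-only parameter per declared
    # input, with the JSON-encoded "default" decoded when present
    parameters = []
    for name, options in inputs.items():
        parameters.append(inspect.Parameter(
            name,
            inspect.Parameter.KEYWORD_ONLY,
            annotation=types[options["type"]],
            default=json.loads(options["default"]) if "default" in options else inspect.Parameter.empty,
        ))
    return parameters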