added support for additional, more user-friendly interfaces and simplified parts of the application loading process

This commit is contained in:
faraphel 2025-01-12 12:52:19 +01:00
parent 1a49aa3779
commit f647c960dd
20 changed files with 353 additions and 107 deletions
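Concretely, loading is now funnelled through the new ModelRegistry introduced below instead of each model juggling the shared "currently loaded" state itself. A minimal sketch of the new flow, using only the registry methods added in this commit (the "echo" model name is hypothetical, it depends on the model library content):

async def demo(model_registry) -> None:
    # pick one of the models discovered by reload_models(); the name is an assumption
    model = model_registry.models["echo"]

    # the registry serialises loading so that only one model sits in memory at a time
    await model_registry.load_model(model)

    # inference also goes through the registry, which holds the inference lock
    stream = await model_registry.infer_model(
        model,
        messages=[{"role": "user", "content": "who are you ?"}],
    )
    async for chunk in stream:
        print(chunk.decode("utf-8"))

    await model_registry.unload_model(model)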

View file

@@ -2,6 +2,7 @@
 fastapi
 uvicorn
 pydantic
+gradio
 python-multipart

 # AI

View file

@@ -2,10 +2,14 @@
     "type": "python",
     "tags": ["dummy"],
     "file": "model.py",
-    "output_type": "video/mp4",
+    "interface": "chat",
+    "summary": "Echo model",
+    "description": "The most basic example model, simply echo the input",
     "inputs": {
-        "file": {"type": "file"}
-    }
+        "messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
+    },
+    "output_type": "text/markdown"
 }
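With this change the dummy model speaks the OpenAI-style message convention and answers as markdown instead of echoing a file back. A hedged example of calling its inference endpoint once the registry is mounted under /models (the host, port and "dummy" model name are assumptions, and the exact body encoding depends on how utils.parameters.load builds the FastAPI signature for a list[dict] input):

import requests

# assumed: server on localhost:8000, registry mounted at "/models", model directory named "dummy"
response = requests.post(
    "http://localhost:8000/models/dummy/infer",
    json={"messages": [{"role": "user", "content": "who are you ?"}]},
)

print(response.headers.get("content-type"))  # expected "text/markdown", as declared in "output_type"
print(response.text)                         # expected: the content of the last user message

The dedicated chat interface selected by "interface": "chat" should then be reachable at /models/dummy/interface.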

View file

@@ -7,5 +7,5 @@ def load(model) -> None:
 def unload(model) -> None:
     pass

-async def infer(model, file) -> typing.AsyncIterator[bytes]:
-    yield await file.read()
+async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
+    yield messages[-1]["content"].encode("utf-8")

View file

@@ -16,7 +16,7 @@ def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, prompt: str) -> typing.Iterator[bytes]:
+async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
     inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():

View file

@@ -16,7 +16,7 @@ def unload(model) -> None:
     model.model = None
     model.tokenizer = None

-def infer(model, prompt: str) -> typing.Iterator[bytes]:
+async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
     inputs = model.tokenizer(prompt, return_tensors="pt")

     with torch.no_grad():

View file

@@ -1,3 +1,3 @@
 from . import api
 from . import model
-from . import manager
+from . import registry

View file

@@ -1,15 +1,25 @@
 import os

-from source import manager, model, api
+from source import registry, model, api
+from source.api import interface

 # create a fastapi application
 application = api.Application()

-# create the model controller
-model_controller = manager.ModelManager(application, os.environ["MODEL_LIBRARY"])
-model_controller.register_model_type("python", model.PythonModel)
-model_controller.reload()
+# create the interface registry
+interface_registry = registry.InterfaceRegistry()
+interface_registry.register_type("chat", interface.ChatInterface)
+
+# create the model registry
+model_registry = registry.ModelRegistry(os.environ["MODEL_LIBRARY"], "/models", interface_registry)
+model_registry.register_type("python", model.PythonModel)
+model_registry.reload_models()
+
+# add the model registry routes to the fastapi
+model_registry.mount(application)

 # serve the application
 application.serve("0.0.0.0", 8000)

View file

@@ -8,7 +8,8 @@ class Application(fastapi.FastAPI):
     def __init__(self):
         super().__init__(
             title=meta.name,
-            description=meta.description
+            description=meta.description,
+            redoc_url=None,
         )

     def serve(self, host: str = "0.0.0.0", port: int = 8080):

View file

@@ -1 +1,3 @@
+from . import interface
+
 from .Application import Application

View file

@@ -0,0 +1,75 @@
import textwrap

import gradio

from source import meta
from source.api.interface import base
from source.model.base import BaseModel


class ChatInterface(base.BaseInterface):
    """
    An interface for Chat-like models.
    Use the OpenAI convention (list of dict with roles and content)
    """

    def __init__(self, model: "BaseModel"):
        # Function to send and receive chat messages
        super().__init__(model)

    async def send_message(self, user_message, old_messages: list[dict], system_message: str):
        # normalize the user message (the type can be wrong, especially when "edited")
        if isinstance(user_message, str):
            user_message: dict = {"files": [], "text": user_message}

        # copy the history to avoid modifying it
        messages: list[dict] = old_messages.copy()

        # add the system instruction
        if system_message:
            messages.insert(0, {"role": "system", "content": system_message})

        # add the user message
        # NOTE: gradio.ChatInterface add our message and the assistant message
        # TODO(Faraphel): add support for files
        messages.append({
            "role": "user",
            "content": user_message["text"],
        })

        # infer the message through the model
        chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
        assistant_message: str = b"".join(chunks).decode("utf-8")

        # send back the messages, clear the user prompt, disable the system prompt
        return assistant_message

    def get_gradio_application(self):
        # create a gradio interface
        with gradio.Blocks(analytics_enabled=False) as application:
            # header
            gradio.Markdown(textwrap.dedent(f"""
                # {meta.name}
                ## {self.model.name}
            """))

            # additional settings
            with gradio.Accordion("Advanced Settings") as advanced_settings:
                system_prompt = gradio.Textbox(
                    label="System prompt",
                    placeholder="You are an expert in C++...",
                    lines=2,
                )

            # chat interface
            gradio.ChatInterface(
                fn=self.send_message,
                type="messages",
                multimodal=True,
                editable=True,
                save_history=True,
                additional_inputs=[system_prompt],
                additional_inputs_accordion=advanced_settings,
            )

        return application

View file

@@ -0,0 +1,3 @@
from . import base

from .ChatInterface import ChatInterface

View file

@@ -0,0 +1,40 @@
import abc

import fastapi
import gradio

import source


class BaseInterface(abc.ABC):
    def __init__(self, model: "source.model.base.BaseModel"):
        self.model = model

    @property
    def route(self) -> str:
        """
        The route to the interface
        :return: the route to the interface
        """

        return f"{self.model.api_base}/interface"

    @abc.abstractmethod
    def get_gradio_application(self) -> gradio.Blocks:
        """
        Get a gradio application
        :return: a gradio application
        """

    def mount(self, application: fastapi.FastAPI) -> None:
        """
        Mount the interface on an application
        :param application: the application to mount the interface on
        """

        gradio.mount_gradio_app(
            application,
            self.get_gradio_application(),
            self.route
        )
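Beyond the chat variant added in this commit, any other interface only has to subclass BaseInterface, implement get_gradio_application and be registered under a name. A hypothetical sketch (CompletionInterface and its layout are invented here; only the extension points come from the code above):

import gradio

from source.api.interface import base


class CompletionInterface(base.BaseInterface):
    # hypothetical single-prompt interface, shown only to illustrate the extension point

    def get_gradio_application(self) -> gradio.Blocks:
        with gradio.Blocks(analytics_enabled=False) as application:
            prompt = gradio.Textbox(label="Prompt")
            output = gradio.Markdown()

            async def complete(text: str) -> str:
                # same inference pattern as ChatInterface: the model returns an async iterator of bytes
                chunks = [chunk async for chunk in await self.model.infer(prompt=text)]
                return b"".join(chunks).decode("utf-8")

            prompt.submit(complete, inputs=prompt, outputs=output)

        return application

# registration would then mirror the chat interface:
# interface_registry.register_type("completion", CompletionInterface)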

View file

@@ -0,0 +1 @@
from .BaseInterface import BaseInterface

View file

@@ -1 +0,0 @@
-from .ModelManager import ModelManager

View file

@@ -9,8 +9,8 @@ from pathlib import Path
 import fastapi

 from source import utils
-from source.manager import ModelManager
 from source.model import base
+from source.registry import ModelRegistry
 from source.utils.fastapi import UploadFileFix
@@ -19,21 +19,22 @@ class PythonModel(base.BaseModel):
     A model running a custom python model.
     """

-    def __init__(self, manager: ModelManager, configuration: dict, path: Path):
-        super().__init__(manager, configuration, path)
+    def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
+        super().__init__(registry, configuration, path)

-        ## Configuration
+        # get the parameters of the model
+        self.parameters = utils.parameters.load(configuration.get("inputs", {}))
+
+        # get the name of the file containing the model code
+        file = configuration.get("file")
+        if file is None:
+            raise ValueError("Field 'file' is missing from the configuration")

         # install custom requirements
         requirements = configuration.get("requirements", [])
         if len(requirements) > 0:
             subprocess.run([sys.executable, "-m", "pip", "install", *requirements])

-        # get the name of the file containing the model code
-        file = configuration.get("file")
-        if file is None:
-            raise ValueError("Field 'file' is missing from the configuration")
-
         # create the module specification
         module_spec = importlib.util.spec_from_file_location(
             f"model-{uuid.uuid4()}",
@@ -44,10 +45,17 @@ class PythonModel(base.BaseModel):
         # load the module
         module_spec.loader.exec_module(self.module)

-        ## Api
+    def _load(self) -> None:
+        return self.module.load(self)

-        # load the inputs data into the inference function signature (used by FastAPI)
-        parameters = utils.parameters.load(configuration.get("inputs", {}))
+    def _unload(self) -> None:
+        return self.module.unload(self)
+
+    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
+        return self.module.infer(self, **kwargs)

+    def _mount(self, application: fastapi.FastAPI):
+        # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
         # create an endpoint wrapping the inference inside a fastapi call
         async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
@@ -61,7 +69,7 @@ class PythonModel(base.BaseModel):
                 }

             return fastapi.responses.StreamingResponse(
-                content=await self.infer(**kwargs),
+                content=await self.registry.infer_model(self, **kwargs),
                 media_type=self.output_type,
                 headers={
                     # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
@@ -69,27 +77,25 @@ class PythonModel(base.BaseModel):
                 }
             )

-        infer_api.__signature__ = inspect.Signature(parameters=parameters)
+        infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
+
+        # format the description
+        description_sections: list[str] = []
+        if self.description is not None:
+            description_sections.append(self.description)
+        if self.interface is not None:
+            description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")

         # add the inference endpoint on the API
-        self.manager.application.add_api_route(
-            f"/models/{self.name}/infer",
+        application.add_api_route(
+            f"{self.api_base}/infer",
             infer_api,
             methods=["POST"],
             tags=self.tags,
-            # summary=...,
-            # description=...,
+            summary=self.summary,
+            description="<br>".join(description_sections),
             response_class=fastapi.responses.StreamingResponse,
             responses={
                 200: {"content": {self.output_type: {}}}
             },
         )
-
-    def _load(self) -> None:
-        return self.module.load(self)
-
-    def _unload(self) -> None:
-        return self.module.unload(self)
-
-    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.Iterator[bytes]:
-        return self.module.infer(self, **kwargs)

View file

@@ -3,7 +3,9 @@ import gc
 import typing
 from pathlib import Path

-from source.manager import ModelManager
+import fastapi
+
+from source.registry import ModelRegistry


 class BaseModel(abc.ABC):
@@ -11,21 +13,43 @@ class BaseModel(abc.ABC):
     Represent a model.
     """

-    def __init__(self, manager: ModelManager, configuration: dict[str, typing.Any], path: Path):
+    def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
+        # the model registry
+        self.registry = registry
+
+        # get the documentation of the model
+        self.summary = configuration.get("summary")
+        self.description = configuration.get("description")
+
         # the environment directory of the model
         self.path = path

-        # the model manager
-        self.manager = manager
-
         # the mimetype of the model responses
         self.output_type: str = configuration.get("output_type", "application/json")

         # get the tags of the model
         self.tags = configuration.get("tags", [])

+        # get the selected interface of the model
+        interface_name: typing.Optional[str] = configuration.get("interface", None)
+        self.interface = (
+            self.registry.interface_registry.interface_types[interface_name](self)
+            if interface_name is not None else None
+        )
+
+        # is the model currently loaded
         self._loaded = False

     def __repr__(self):
         return f"<{self.__class__.__name__}: {self.name}>"

+    @property
+    def api_base(self) -> str:
+        """
+        Base for the API routes
+        :return: the base for the API routes
+        """
+
+        return f"{self.registry.api_base}/{self.name}"
+
     @property
     def name(self):
         """
@@ -44,6 +68,7 @@ class BaseModel(abc.ABC):
         return {
             "name": self.name,
             "output_type": self.output_type,
+            "tags": self.tags
         }

     def load(self) -> None:
@@ -51,22 +76,13 @@ class BaseModel(abc.ABC):
         Load the model within the model manager
         """

-        # if we are already loaded, stop
+        # if the model is already loaded, skip
         if self._loaded:
             return

-        # check if we are the current loaded model
-        if self.manager.current_loaded_model is not self:
-            # unload the previous model
-            if self.manager.current_loaded_model is not None:
-                self.manager.current_loaded_model.unload()
-
-        # model specific loading
+        # load the model depending on the implementation
         self._load()

-        # declare ourselves as the currently loaded model
-        self.manager.current_loaded_model = self
-
         # mark the model as loaded
         self._loaded = True
@@ -86,11 +102,7 @@ class BaseModel(abc.ABC):
         if not self._loaded:
             return

-        # if we were the currently loaded model of the manager, demote ourselves
-        if self.manager.current_loaded_model is self:
-            self.manager.current_loaded_model = None
-
-        # model specific unloading part
+        # unload the model depending on the implementation
         self._unload()

         # force the garbage collector to clean the memory
@@ -106,22 +118,42 @@ class BaseModel(abc.ABC):
         Do not call manually, use `unload` instead.
         """

-    async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
+    async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model within the model manager
         :return: the response of the model
         """

-        async with self.manager.inference_lock:
-            # make sure we are loaded before an inference
-            self.load()
+        # make sure we are loaded before an inference
+        self.load()

-            # model specific inference part
-            return self._infer(**kwargs)
+        # model specific inference part
+        return await self._infer(**kwargs)

     @abc.abstractmethod
-    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
+    async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model
         :return: the response of the model
         """

+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Add the model to the api
+        :param application: the fastapi application
+        """
+
+        # mount the interface if selected
+        if self.interface is not None:
+            self.interface.mount(application)
+
+        # implementation specific mount
+        self._mount(application)
+
+    @abc.abstractmethod
+    def _mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Add the model to the api
+        Do not call manually, use `mount` instead.
+        :param application: the fastapi application
+        """

View file

@@ -0,0 +1,16 @@
import typing

from source.api.interface import base


class InterfaceRegistry:
    """
    The interface registry
    Store the list of available interfaces
    """

    def __init__(self):
        self.interface_types: dict[str, typing.Type[base.BaseInterface]] = {}

    def register_type(self, name: str, interface_type: typing.Type[base.BaseInterface]):
        self.interface_types[name] = interface_type

View file

@@ -7,64 +7,80 @@ from pathlib import Path
 import fastapi

-from source import model, api
+from source.model.base import BaseModel
+from source.registry import InterfaceRegistry


-class ModelManager:
+class ModelRegistry:
     """
-    The model manager
+    The model registry

     Load the list of models available, ensure that only one model is loaded at the same time.
     """

-    def __init__(self, application: api.Application, model_library: os.PathLike | str):
-        self.application: api.Application = application
+    def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
         self.model_library: Path = Path(model_library)
+        self.interface_registry = interface_registry
+        self._api_base = api_base

         # the model types
-        self.model_types: dict[str, typing.Type[model.base.BaseModel]] = {}
+        self.model_types: dict[str, typing.Type[BaseModel]] = {}
         # the models
-        self.models: dict[str, model.base.BaseModel] = {}
+        self.models: dict[str, BaseModel] = {}

         # the currently loaded model
         # TODO(Faraphel): load more than one model at a time ?
         # would require a way more complex manager to handle memory issue
         # having two calculations at the same time might not be worth it either
-        self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
+        self.current_loaded_model: typing.Optional[BaseModel] = None

         # lock to avoid concurrent inference and concurrent model loading and unloading
         self.inference_lock = asyncio.Lock()

-        @self.application.get("/models")
-        async def get_models() -> list[str]:
-            """
-            Get the list of models available
-            :return: the list of models available
-            """
-
-            # list the models found
-            return list(self.models.keys())
-
-        @self.application.get("/models/{model_name}")
-        async def get_model(model_name: str) -> dict:
-            """
-            Get information about a specific model
-            :param model_name: the name of the model
-            :return: the information about the corresponding model
-            """
-
-            # get the corresponding model
-            model = self.models.get(model_name)
-            if model is None:
-                raise fastapi.HTTPException(status_code=404, detail="Model not found")
-
-            # return the model information
-            return model.get_information()
-
-    def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
+    @property
+    def api_base(self) -> str:
+        """
+        Base for the api routes
+        :return: the base for the api routes
+        """
+
+        return self._api_base
+
+    def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
         self.model_types[name] = model_type

-    def reload(self):
+    async def load_model(self, model: "BaseModel"):
+        # lock to avoid concurrent loading
+        async with self.inference_lock:
+            # if there is another currently loaded model, unload it
+            if self.current_loaded_model is not None and self.current_loaded_model is not model:
+                await self.unload_model(self.current_loaded_model)
+
+            # load the model
+            model.load()
+            # mark the model as the currently loaded model of the manager
+            self.current_loaded_model = model
+
+    async def unload_model(self, model: "BaseModel"):
+        # lock to avoid concurrent unloading
+        async with self.inference_lock:
+            # if we were the currently loaded model of the manager, demote ourselves
+            if self.current_loaded_model is model:
+                self.current_loaded_model = None
+
+            # model specific unloading part
+            model.unload()
+
+    async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
+        # lock to avoid concurrent inference
+        async with self.inference_lock:
+            return await model.infer(**kwargs)
+
+    def reload_models(self) -> None:
+        """
+        Reload the list of models available
+        """
+
         # reset the model list
         for model in self.models.values():
             model.unload()
@@ -97,3 +113,39 @@ class ModelManager:
             # load the model
             self.models[model_name] = model_type(self, model_configuration, model_path)
+
+    def mount(self, application: fastapi.FastAPI) -> None:
+        """
+        Mount the models endpoints into a fastapi application
+        :param application: the fastapi application
+        """
+
+        @application.get(self.api_base)
+        async def get_models() -> list[str]:
+            """
+            Get the list of models available
+            :return: the list of models available
+            """
+
+            # list the models found
+            return list(self.models.keys())
+
+        @application.get(f"{self.api_base}/{{model_name}}")
+        async def get_model(model_name: str) -> dict:
+            """
+            Get information about a specific model
+            :param model_name: the name of the model
+            :return: the information about the corresponding model
+            """
+
+            # get the corresponding model
+            model = self.models.get(model_name)
+            if model is None:
+                raise fastapi.HTTPException(status_code=404, detail="Model not found")
+
+            # return the model information
+            return model.get_information()
+
+        # mount all the models in the registry
+        for model_name, model in self.models.items():
+            model.mount(application)

View file

@@ -0,0 +1,2 @@
from .ModelRegistry import ModelRegistry
from .InterfaceRegistry import InterfaceRegistry

View file

@@ -10,12 +10,14 @@ types: dict[str, type] = {
     "float": float,
     "str": str,
     "bytes": bytes,
-    "list": list,
-    "tuple": tuple,
-    "set": set,
     "dict": dict,
     "datetime": datetime,
     "file": UploadFile,
+
+    # TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ?
+    "list[dict]": list[dict],
+    # "tuple": tuple,
+    # "set": set,
 }
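This table is what utils.parameters.load (not part of this diff, presumably under source/utils) uses to turn a model's "inputs" block into the inspect.Parameter list assigned to infer_api.__signature__. A rough sketch of what such a loader could look like, offered purely as an assumption about its behaviour:

import inspect
import json


def load(inputs: dict[str, dict]) -> list[inspect.Parameter]:
    # hypothetical reconstruction: build keyword-only parameters from an "inputs" config block
    parameters: list[inspect.Parameter] = []

    for name, options in inputs.items():
        # resolve the declared type name ("str", "file", "list[dict]", ...) through the table above
        annotation = types[options["type"]]

        # defaults are stored as strings in the configuration; decode JSON-like ones when possible
        default = options.get("default", inspect.Parameter.empty)
        if isinstance(default, str) and annotation in (list[dict], dict):
            default = json.loads(default)

        parameters.append(inspect.Parameter(
            name,
            inspect.Parameter.KEYWORD_ONLY,
            annotation=annotation,
            default=default,
        ))

    return parameters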