added support for additional more user-friendly interfaces, improved some part of the application loading process to make it a bit simpler

This commit is contained in:
faraphel 2025-01-12 12:52:19 +01:00
parent 1a49aa3779
commit f647c960dd
20 changed files with 353 additions and 107 deletions

View file

@ -2,6 +2,7 @@
fastapi
uvicorn
pydantic
gradio
python-multipart
# AI

View file

@ -2,10 +2,14 @@
"type": "python",
"tags": ["dummy"],
"file": "model.py",
"interface": "chat",
"output_type": "video/mp4",
"summary": "Echo model",
"description": "The most basic example model, simply echo the input",
"inputs": {
"file": {"type": "file"}
}
"messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"}
},
"output_type": "text/markdown"
}

View file

@ -7,5 +7,5 @@ def load(model) -> None:
def unload(model) -> None:
pass
async def infer(model, file) -> typing.AsyncIterator[bytes]:
yield await file.read()
async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
yield messages[-1]["content"].encode("utf-8")

View file

@ -16,7 +16,7 @@ def unload(model) -> None:
model.model = None
model.tokenizer = None
def infer(model, prompt: str) -> typing.Iterator[bytes]:
async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
inputs = model.tokenizer(prompt, return_tensors="pt")
with torch.no_grad():

View file

@ -16,7 +16,7 @@ def unload(model) -> None:
model.model = None
model.tokenizer = None
def infer(model, prompt: str) -> typing.Iterator[bytes]:
async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
inputs = model.tokenizer(prompt, return_tensors="pt")
with torch.no_grad():

View file

@ -1,3 +1,3 @@
from . import api
from . import model
from . import manager
from . import registry

View file

@ -1,15 +1,25 @@
import os
from source import manager, model, api
from source import registry, model, api
from source.api import interface
# create a fastapi application
application = api.Application()
# create the model controller
model_controller = manager.ModelManager(application, os.environ["MODEL_LIBRARY"])
model_controller.register_model_type("python", model.PythonModel)
model_controller.reload()
# create the interface registry
interface_registry = registry.InterfaceRegistry()
interface_registry.register_type("chat", interface.ChatInterface)
# create the model registry
model_registry = registry.ModelRegistry(os.environ["MODEL_LIBRARY"], "/models", interface_registry)
model_registry.register_type("python", model.PythonModel)
model_registry.reload_models()
# add the model registry routes to the fastapi
model_registry.mount(application)
# serve the application
application.serve("0.0.0.0", 8000)

View file

@ -8,7 +8,8 @@ class Application(fastapi.FastAPI):
def __init__(self):
super().__init__(
title=meta.name,
description=meta.description
description=meta.description,
redoc_url=None,
)
def serve(self, host: str = "0.0.0.0", port: int = 8080):

View file

@ -1 +1,3 @@
from . import interface
from .Application import Application

View file

@ -0,0 +1,75 @@
import textwrap
import gradio
from source import meta
from source.api.interface import base
from source.model.base import BaseModel
class ChatInterface(base.BaseInterface):
"""
An interface for Chat-like models.
Use the OpenAI convention (list of dict with roles and content)
"""
def __init__(self, model: "BaseModel"):
# Function to send and receive chat messages
super().__init__(model)
async def send_message(self, user_message, old_messages: list[dict], system_message: str):
# normalize the user message (the type can be wrong, especially when "edited")
if isinstance(user_message, str):
user_message: dict = {"files": [], "text": user_message}
# copy the history to avoid modifying it
messages: list[dict] = old_messages.copy()
# add the system instruction
if system_message:
messages.insert(0, {"role": "system", "content": system_message})
# add the user message
# NOTE: gradio.ChatInterface add our message and the assistant message
# TODO(Faraphel): add support for files
messages.append({
"role": "user",
"content": user_message["text"],
})
# infer the message through the model
chunks = [chunk async for chunk in await self.model.infer(messages=messages)]
assistant_message: str = b"".join(chunks).decode("utf-8")
# send back the messages, clear the user prompt, disable the system prompt
return assistant_message
def get_gradio_application(self):
# create a gradio interface
with gradio.Blocks(analytics_enabled=False) as application:
# header
gradio.Markdown(textwrap.dedent(f"""
# {meta.name}
## {self.model.name}
"""))
# additional settings
with gradio.Accordion("Advanced Settings") as advanced_settings:
system_prompt = gradio.Textbox(
label="System prompt",
placeholder="You are an expert in C++...",
lines=2,
)
# chat interface
gradio.ChatInterface(
fn=self.send_message,
type="messages",
multimodal=True,
editable=True,
save_history=True,
additional_inputs=[system_prompt],
additional_inputs_accordion=advanced_settings,
)
return application

View file

@ -0,0 +1,3 @@
from . import base
from .ChatInterface import ChatInterface

View file

@ -0,0 +1,40 @@
import abc
import fastapi
import gradio
import source
class BaseInterface(abc.ABC):
def __init__(self, model: "source.model.base.BaseModel"):
self.model = model
@property
def route(self) -> str:
"""
The route to the interface
:return: the route to the interface
"""
return f"{self.model.api_base}/interface"
@abc.abstractmethod
def get_gradio_application(self) -> gradio.Blocks:
"""
Get a gradio application
:return: a gradio application
"""
def mount(self, application: fastapi.FastAPI) -> None:
"""
Mount the interface on an application
:param application: the application to mount the interface on
:param path: the path where to mount the application
"""
gradio.mount_gradio_app(
application,
self.get_gradio_application(),
self.route
)

View file

@ -0,0 +1 @@
from .BaseInterface import BaseInterface

View file

@ -1 +0,0 @@
from .ModelManager import ModelManager

View file

@ -9,8 +9,8 @@ from pathlib import Path
import fastapi
from source import utils
from source.manager import ModelManager
from source.model import base
from source.registry import ModelRegistry
from source.utils.fastapi import UploadFileFix
@ -19,21 +19,22 @@ class PythonModel(base.BaseModel):
A model running a custom python model.
"""
def __init__(self, manager: ModelManager, configuration: dict, path: Path):
super().__init__(manager, configuration, path)
def __init__(self, registry: ModelRegistry, configuration: dict, path: Path):
super().__init__(registry, configuration, path)
## Configuration
# get the name of the file containing the model code
file = configuration.get("file")
if file is None:
raise ValueError("Field 'file' is missing from the configuration")
# get the parameters of the model
self.parameters = utils.parameters.load(configuration.get("inputs", {}))
# install custom requirements
requirements = configuration.get("requirements", [])
if len(requirements) > 0:
subprocess.run([sys.executable, "-m", "pip", "install", *requirements])
# get the name of the file containing the model code
file = configuration.get("file")
if file is None:
raise ValueError("Field 'file' is missing from the configuration")
# create the module specification
module_spec = importlib.util.spec_from_file_location(
f"model-{uuid.uuid4()}",
@ -44,10 +45,17 @@ class PythonModel(base.BaseModel):
# load the module
module_spec.loader.exec_module(self.module)
## Api
def _load(self) -> None:
return self.module.load(self)
# load the inputs data into the inference function signature (used by FastAPI)
parameters = utils.parameters.load(configuration.get("inputs", {}))
def _unload(self) -> None:
return self.module.unload(self)
async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
return self.module.infer(self, **kwargs)
def _mount(self, application: fastapi.FastAPI):
# TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
# create an endpoint wrapping the inference inside a fastapi call
async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
@ -61,7 +69,7 @@ class PythonModel(base.BaseModel):
}
return fastapi.responses.StreamingResponse(
content=await self.infer(**kwargs),
content=await self.registry.infer_model(self, **kwargs),
media_type=self.output_type,
headers={
# if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
@ -69,27 +77,25 @@ class PythonModel(base.BaseModel):
}
)
infer_api.__signature__ = inspect.Signature(parameters=parameters)
infer_api.__signature__ = inspect.Signature(parameters=self.parameters)
# format the description
description_sections: list[str] = []
if self.description is not None:
description_sections.append(self.description)
if self.interface is not None:
description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")
# add the inference endpoint on the API
self.manager.application.add_api_route(
f"/models/{self.name}/infer",
application.add_api_route(
f"{self.api_base}/infer",
infer_api,
methods=["POST"],
tags=self.tags,
# summary=...,
# description=...,
summary=self.summary,
description="<br>".join(description_sections),
response_class=fastapi.responses.StreamingResponse,
responses={
200: {"content": {self.output_type: {}}}
},
)
def _load(self) -> None:
return self.module.load(self)
def _unload(self) -> None:
return self.module.unload(self)
def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.Iterator[bytes]:
return self.module.infer(self, **kwargs)

View file

@ -3,7 +3,9 @@ import gc
import typing
from pathlib import Path
from source.manager import ModelManager
import fastapi
from source.registry import ModelRegistry
class BaseModel(abc.ABC):
@ -11,21 +13,43 @@ class BaseModel(abc.ABC):
Represent a model.
"""
def __init__(self, manager: ModelManager, configuration: dict[str, typing.Any], path: Path):
def __init__(self, registry: ModelRegistry, configuration: dict[str, typing.Any], path: Path):
# the model registry
self.registry = registry
# get the documentation of the model
self.summary = configuration.get("summary")
self.description = configuration.get("description")
# the environment directory of the model
self.path = path
# the model manager
self.manager = manager
# the mimetype of the model responses
self.output_type: str = configuration.get("output_type", "application/json")
# get the tags of the model
self.tags = configuration.get("tags", [])
# get the selected interface of the model
interface_name: typing.Optional[str] = configuration.get("interface", None)
self.interface = (
self.registry.interface_registry.interface_types[interface_name](self)
if interface_name is not None else None
)
# is the model currently loaded
self._loaded = False
def __repr__(self):
return f"<{self.__class__.__name__}: {self.name}>"
@property
def api_base(self) -> str:
"""
Base for the API routes
:return: the base for the API routes
"""
return f"{self.registry.api_base}/{self.name}"
@property
def name(self):
"""
@ -44,6 +68,7 @@ class BaseModel(abc.ABC):
return {
"name": self.name,
"output_type": self.output_type,
"tags": self.tags
}
def load(self) -> None:
@ -51,22 +76,13 @@ class BaseModel(abc.ABC):
Load the model within the model manager
"""
# if we are already loaded, stop
# if the model is already loaded, skip
if self._loaded:
return
# check if we are the current loaded model
if self.manager.current_loaded_model is not self:
# unload the previous model
if self.manager.current_loaded_model is not None:
self.manager.current_loaded_model.unload()
# model specific loading
# load the model depending on the implementation
self._load()
# declare ourselves as the currently loaded model
self.manager.current_loaded_model = self
# mark the model as loaded
self._loaded = True
@ -86,11 +102,7 @@ class BaseModel(abc.ABC):
if not self._loaded:
return
# if we were the currently loaded model of the manager, demote ourselves
if self.manager.current_loaded_model is self:
self.manager.current_loaded_model = None
# model specific unloading part
# unload the model depending on the implementation
self._unload()
# force the garbage collector to clean the memory
@ -106,22 +118,42 @@ class BaseModel(abc.ABC):
Do not call manually, use `unload` instead.
"""
async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
async def infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
"""
Infer our payload through the model within the model manager
:return: the response of the model
"""
async with self.manager.inference_lock:
# make sure we are loaded before an inference
self.load()
# model specific inference part
return self._infer(**kwargs)
return await self._infer(**kwargs)
@abc.abstractmethod
def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
"""
Infer our payload through the model
:return: the response of the model
"""
def mount(self, application: fastapi.FastAPI) -> None:
"""
Add the model to the api
:param application: the fastapi application
"""
# mount the interface if selected
if self.interface is not None:
self.interface.mount(application)
# implementation specific mount
self._mount(application)
@abc.abstractmethod
def _mount(self, application: fastapi.FastAPI) -> None:
"""
Add the model to the api
Do not call manually, use `unload` instead.
:param application: the fastapi application
"""

View file

@ -0,0 +1,16 @@
import typing
from source.api.interface import base
class InterfaceRegistry:
"""
The interface registry
Store the list of other interface available
"""
def __init__(self):
self.interface_types: dict[str, typing.Type[base.BaseInterface]] = {}
def register_type(self, name: str, interface_type: typing.Type[base.BaseInterface]):
self.interface_types[name] = interface_type

View file

@ -7,64 +7,80 @@ from pathlib import Path
import fastapi
from source import model, api
from source.model.base import BaseModel
from source.registry import InterfaceRegistry
class ModelManager:
class ModelRegistry:
"""
The model manager
The model registry
Load the list of models available, ensure that only one model is loaded at the same time.
"""
def __init__(self, application: api.Application, model_library: os.PathLike | str):
self.application: api.Application = application
def __init__(self, model_library: os.PathLike | str, api_base: str, interface_registry: InterfaceRegistry):
self.model_library: Path = Path(model_library)
self.interface_registry = interface_registry
self._api_base = api_base
# the model types
self.model_types: dict[str, typing.Type[model.base.BaseModel]] = {}
self.model_types: dict[str, typing.Type[BaseModel]] = {}
# the models
self.models: dict[str, model.base.BaseModel] = {}
self.models: dict[str, BaseModel] = {}
# the currently loaded model
# TODO(Faraphel): load more than one model at a time ?
# would require a way more complex manager to handle memory issue
# having two calculations at the same time might not be worth it either
self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
self.current_loaded_model: typing.Optional[BaseModel] = None
# lock to avoid concurrent inference and concurrent model loading and unloading
self.inference_lock = asyncio.Lock()
@self.application.get("/models")
async def get_models() -> list[str]:
@property
def api_base(self) -> str:
"""
Get the list of models available
:return: the list of models available
Base for the api routes
:return: the base for the api routes
"""
# list the models found
return list(self.models.keys())
return self._api_base
@self.application.get("/models/{model_name}")
async def get_model(model_name: str) -> dict:
"""
Get information about a specific model
:param model_name: the name of the model
:return: the information about the corresponding model
"""
# get the corresponding model
model = self.models.get(model_name)
if model is None:
raise fastapi.HTTPException(status_code=404, detail="Model not found")
# return the model information
return model.get_information()
def register_model_type(self, name: str, model_type: "typing.Type[model.base.BaseModel]"):
def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
self.model_types[name] = model_type
def reload(self):
async def load_model(self, model: "BaseModel"):
# lock to avoid concurrent loading
async with self.inference_lock:
# if there is another currently loaded model, unload it
if self.current_loaded_model is not None and self.current_loaded_model is not model:
await self.unload_model(self.current_loaded_model)
# load the model
model.load()
# mark the model as the currently loaded model of the manager
self.current_loaded_model = model
async def unload_model(self, model: "BaseModel"):
# lock to avoid concurrent unloading
async with self.inference_lock:
# if we were the currently loaded model of the manager, demote ourselves
if self.current_loaded_model is model:
self.current_loaded_model = None
# model specific unloading part
model.unload()
async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
# lock to avoid concurrent inference
async with self.inference_lock:
return await model.infer(**kwargs)
def reload_models(self) -> None:
"""
Reload the list of models available
"""
# reset the model list
for model in self.models.values():
model.unload()
@ -97,3 +113,39 @@ class ModelManager:
# load the model
self.models[model_name] = model_type(self, model_configuration, model_path)
def mount(self, application: fastapi.FastAPI) -> None:
"""
Mount the models endpoints into a fastapi application
:param application: the fastapi application
"""
@application.get(self.api_base)
async def get_models() -> list[str]:
"""
Get the list of models available
:return: the list of models available
"""
# list the models found
return list(self.models.keys())
@application.get(f"{self.api_base}/{{model_name}}")
async def get_model(model_name: str) -> dict:
"""
Get information about a specific model
:param model_name: the name of the model
:return: the information about the corresponding model
"""
# get the corresponding model
model = self.models.get(model_name)
if model is None:
raise fastapi.HTTPException(status_code=404, detail="Model not found")
# return the model information
return model.get_information()
# mount all the models in the registry
for model_name, model in self.models.items():
model.mount(application)

View file

@ -0,0 +1,2 @@
from .ModelRegistry import ModelRegistry
from .InterfaceRegistry import InterfaceRegistry

View file

@ -10,12 +10,14 @@ types: dict[str, type] = {
"float": float,
"str": str,
"bytes": bytes,
"list": list,
"tuple": tuple,
"set": set,
"dict": dict,
"datetime": datetime,
"file": UploadFile,
# TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ?
"list[dict]": list[dict],
# "tuple": tuple,
# "set": set,
}