diff --git a/.gitignore b/.gitignore index 6863c6e..e4f66b0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ .idea/ # Environment -venv/ +env/ diff --git a/Dockerfile b/Dockerfile index 582bda9..8d55a01 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,17 +1,17 @@ -FROM python:3.12 +FROM continuumio/miniconda3 # copy the application WORKDIR /app COPY ./ ./ # install the dependencies -RUN pip3 install -r ./requirements.txt +RUN conda env create -f environment.yml # expose the API port EXPOSE 8000 # environment variables -ENV MODEL_DIRECTORY=/models/ +ENV MODEL_LIBRARY=/models/ # run the server CMD ["python3", "-m", "source"] diff --git a/README.md b/README.md index 4ad7096..ba515be 100644 --- a/README.md +++ b/README.md @@ -2,14 +2,108 @@ A server that can serve AI models with an API and an authentication system -# Usage +## Usage -## Docker +The ai-server project requires a conda environment. You can use Docker to deploy it easily. +### Docker +You can easily use Docker Compose to run the project. +Simply go into the project directory and run: -# Environment Variables +```bash +docker compose up -d +``` + +#### Environment Variables + +The project uses the following environment variables: | Name | Description | |-----------------|-------------------------------------------| | MODEL_DIRECTORY | the directory where the models are stored | + +#### Volumes + +The project may store data in multiple volumes: + +| Path | Type | Required | Description | |------|------|----------|-------------| | /models/ | Volume | True | The place where the models are stored | | /root/.huggingface/hub | Bind | False | The place where the internal models are cached. Avoids re-downloading huge amounts of data at every inference | + +## Models + +A model is an object that can be loaded and used for inference. + +It is stored inside a directory and must always contain a `config.json` file. + +### Configuration + +This is a JSON-structured file with basic information about the model. + +It describes: +- its type (see below) +- its tags +- its interface +- its inputs +- its output mimetype + +And other properties depending on the model type. + +#### Types + +There is for now only a single type of model: the Python model + +##### Python Model + +A Python model is isolated in a `conda` environment. + +To be considered a Python model, you need these three files: + +| File | Description | |-------------|-------------------------------| | config.json | The configuration file | | env | The conda virtual environment | | model.py | The model file | + +###### Configuration + +Additional fields might be found in the configuration: + +... + +###### Virtual Environment + +You can create the conda virtual environment with: + +`conda create --prefix ./env/ python=` + +You can then install your requirements inside it with `pip3` or `conda`. + +###### Internal Model + +You need to create a `model.py` file containing a class named `Model`. +This class must be exposed with the Pyro5 library for interoperability with our main environment. + +Here is an example of an internal model : + +```python +import Pyro5.api + +import torch +import transformers + + +@Pyro5.api.expose +class Model: + def __init__(self): + # load the model + self.pipe = transformers.pipeline(...)
+ + def infer(self, messages: list[dict]) -> bytes: + with torch.no_grad(): + outputs = self.pipe(messages) + + return outputs[0]["generated_text"][-1] +``` diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..7aa0823 --- /dev/null +++ b/environment.yml @@ -0,0 +1,14 @@ +channels: + - defaults +dependencies: + - python=3.12 + - numpy + - pillow + - pip + - pip: + - fastapi + - uvicorn + - pydantic + - gradio + - python-multipart + - Pyro5 diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 00110d9..0000000 --- a/requirements.txt +++ /dev/null @@ -1,13 +0,0 @@ -# web -fastapi -uvicorn -pydantic -gradio -python-multipart - -# data manipulation -pillow -numpy - -# AI -accelerate diff --git a/samples/models/dummy/config.json b/samples/models/dummy/config.json deleted file mode 100644 index bba7ae4..0000000 --- a/samples/models/dummy/config.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "type": "python", - "tags": ["dummy"], - "file": "model.py", - "interface": "chat", - - "summary": "Echo model", - "description": "The most basic example model, simply echo the input", - - "inputs": { - "messages": {"type": "list[dict]", "default": "[{\"role\": \"user\", \"content\": \"who are you ?\"}]"} - }, - - "output_type": "text/markdown" -} diff --git a/samples/models/dummy/model.py b/samples/models/dummy/model.py deleted file mode 100644 index f41489c..0000000 --- a/samples/models/dummy/model.py +++ /dev/null @@ -1,6 +0,0 @@ -import typing - - -class Model: - async def infer(self, messages: list[dict]) -> typing.AsyncIterator[bytes]: - yield messages[-1]["content"].encode("utf-8") diff --git a/samples/models/python-bert-1/config.json b/samples/models/python-bert-1/config.json deleted file mode 100644 index 5177187..0000000 --- a/samples/models/python-bert-1/config.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "type": "python", - "file": "model.py", - - "inputs": { - "prompt": {"type": "str"} - }, - - "requirements": [ - "transformers", - "torch", - "torchvision", - "torchaudio" - ] -} diff --git a/samples/models/python-bert-1/model.py b/samples/models/python-bert-1/model.py deleted file mode 100644 index 5d5d225..0000000 --- a/samples/models/python-bert-1/model.py +++ /dev/null @@ -1,25 +0,0 @@ -import json -import typing - -import torch -import transformers - - -class Model: - NAME: str = "huawei-noah/TinyBERT_General_4L_312D" - - def __init__(self) -> None: - self.model = transformers.AutoModel.from_pretrained(self.NAME) - self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME) - - async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]: - inputs = self.tokenizer(prompt, return_tensors="pt") - - with torch.no_grad(): - outputs = self.model(**inputs) - - embeddings = outputs.last_hidden_state - - yield json.dumps({ - "data": embeddings.tolist() - }).encode("utf-8") diff --git a/samples/models/python-bert-2/config.json b/samples/models/python-bert-2/config.json deleted file mode 100644 index af6ab11..0000000 --- a/samples/models/python-bert-2/config.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "type": "python", - "file": "model.py" -} diff --git a/samples/models/python-bert-2/model.py b/samples/models/python-bert-2/model.py deleted file mode 100644 index 5d5d225..0000000 --- a/samples/models/python-bert-2/model.py +++ /dev/null @@ -1,25 +0,0 @@ -import json -import typing - -import torch -import transformers - - -class Model: - NAME: str = "huawei-noah/TinyBERT_General_4L_312D" - - def __init__(self) -> None: - self.model = 
transformers.AutoModel.from_pretrained(self.NAME) - self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME) - - async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]: - inputs = self.tokenizer(prompt, return_tensors="pt") - - with torch.no_grad(): - outputs = self.model(**inputs) - - embeddings = outputs.last_hidden_state - - yield json.dumps({ - "data": embeddings.tolist() - }).encode("utf-8") diff --git a/source/__main__.py b/source/__main__.py index 9253d56..7a0d881 100644 --- a/source/__main__.py +++ b/source/__main__.py @@ -3,6 +3,7 @@ import os from source import registry, model, api from source.api import interface + # create a fastapi application application = api.Application() diff --git a/source/api/interface/ChatInterface.py b/source/api/interface/ChatInterface.py index a608bc5..e3b30ee 100644 --- a/source/api/interface/ChatInterface.py +++ b/source/api/interface/ChatInterface.py @@ -30,19 +30,19 @@ class ChatInterface(base.BaseInterface): messages.insert(0, {"role": "system", "content": system_message}) # add the user message - # NOTE: gradio.ChatInterface add our message and the assistant message - # TODO(Faraphel): add support for files + # NOTE: gradio.ChatInterface adds our message and the assistant message automatically + # TODO(Faraphel): add support for files - directly use user_message ? apparently, the "image" field is supported. + # check "https://huggingface.co/docs/transformers/main_classes/pipelines" at "ImageTextToTextPipeline" + + # TODO(Faraphel): add a "MultimodalChatInterface" to support images messages.append({ "role": "user", "content": user_message["text"], }) # infer the message through the model - chunks = [chunk async for chunk in await self.model.infer(messages=messages)] - assistant_message: str = b"".join(chunks).decode("utf-8") - - # send back the messages, clear the user prompt, disable the system prompt - return assistant_message + async for chunk in self.model.infer(messages=messages): + yield chunk.decode("utf-8") def get_application(self): # create a gradio interface @@ -65,7 +65,7 @@ class ChatInterface(base.BaseInterface): gradio.ChatInterface( fn=self.send_message, type="messages", - multimodal=True, + multimodal=False, # TODO(Faraphel): should handle at least image and text files editable=True, save_history=True, additional_inputs=[system_prompt], diff --git a/source/model/PythonModel.py b/source/model/PythonModel.py index d5d8853..c80095b 100644 --- a/source/model/PythonModel.py +++ b/source/model/PythonModel.py @@ -1,18 +1,22 @@ -import importlib.util import subprocess -import sys +import tempfile +import time import typing -import uuid -import inspect +import textwrap +import os +import signal from pathlib import Path -import fastapi +import Pyro5 +import Pyro5.api -from source import utils from source.model import base from source.registry import ModelRegistry from source.utils.fastapi import UploadFileFix +# enable serpent to represent bytes directly +Pyro5.config.SERPENT_BYTES_REPR = True + class PythonModel(base.BaseModel): """ @@ -22,85 +26,114 @@ class PythonModel(base.BaseModel): def __init__(self, registry: ModelRegistry, configuration: dict, path: Path): super().__init__(registry, configuration, path) - # get the parameters of the model - self.parameters = utils.parameters.load(configuration.get("inputs", {})) + # get the environment + self.environment = self.path / "env" + if not self.environment.exists(): + raise Exception("The model is missing an environment") - # install custom requirements - 
requirements = configuration.get("requirements", []) - if len(requirements) > 0: - subprocess.run([sys.executable, "-m", "pip", "install", *requirements]) - - # get the name of the file containing the model code - file = configuration.get("file") - if file is None: - raise ValueError("Field 'file' is missing from the configuration") - - # create the module specification - module_spec = importlib.util.spec_from_file_location( - f"model-{uuid.uuid4()}", - self.path / file - ) - # get the module - module = importlib.util.module_from_spec(module_spec) - # load the module - module_spec.loader.exec_module(module) - - # create the internal model from the class defined in the module - self._model_type = module.Model - self._model: typing.Optional[module.Model] = None + # prepare the process that will hold the environment python interpreter + self._storage: typing.Optional[tempfile.TemporaryDirectory] + self._process: typing.Optional[subprocess.Popen] + self._model: typing.Optional[Pyro5.api.Proxy] async def _load(self) -> None: - self._model = self._model_type() + # create a temporary space for the unix socket + self._storage: tempfile.TemporaryDirectory = tempfile.TemporaryDirectory() + socket_file = Path(self._storage.name) / 'socket.unix' + + # create a process inside the conda environment + self._process = subprocess.Popen( + [ + "conda", "run", # run a command within conda + "--prefix", self.environment.relative_to(self.path), # use the model environment + "python3", "-c", # run a python command + + textwrap.dedent(f""" + # make sure that Pyro5 is installed for communication + import sys + import subprocess + subprocess.run(["python3", "-m", "pip", "install", "Pyro5"]) + + import os + import Pyro5 + import Pyro5.api + import model + + # allow Pyro5 to return bytes objects directly + Pyro5.config.SERPENT_BYTES_REPR = True + + # helper to check if a process is still alive + def is_pid_alive(pid: int) -> bool: + try: + # do nothing if the process is alive, raise an exception if it does not exist + os.kill(pid, 0) + except OSError: + return False + else: + return True + + # create a pyro daemon + daemon = Pyro5.api.Daemon(unixsocket={str(socket_file)!r}) + # export our model through it + daemon.register(Pyro5.api.expose(model.Model), objectId="model") + # handle requests + # stop the process if the manager is no longer alive + daemon.requestLoop(lambda: is_pid_alive({os.getpid()})) + """) + ], + + cwd=self.path, # use the model directory as the working directory + start_new_session=True, # put the process in a new group to avoid killing ourselves when we unload the process + ) + + # wait for the process to be initialized properly + while True: + # check if the process is still alive + if self._process.poll() is not None: + # if the process stopped, raise an error (it shall stay alive until the unloading) + raise Exception("Could not load the model.") + + # if the socket file has been created, the program is running successfully + if socket_file.exists(): + break + + time.sleep(0.5) + + # get the proxy model object from the environment + self._model = Pyro5.api.Proxy(f"PYRO:model@./u:{socket_file}") + async def _unload(self) -> None: - self._model = None + # clear the proxy object + self._model._pyroRelease() # NOQA + del self._model + # stop the environment process + os.killpg(os.getpgid(self._process.pid), signal.SIGTERM) + self._process.wait() + del self._process + # clear the storage + self._storage.cleanup() + del self._storage async def _infer(self, **kwargs) -> 
typing.AsyncIterator[bytes]: - return self._model.infer(**kwargs) + # Pyro5 is not capable of receiving an "UploadFile" object, so save it to a file and send the path instead + with tempfile.TemporaryDirectory() as working_directory: + for key, value in kwargs.items(): + # check if the argument is a file + if not isinstance(value, UploadFileFix): + continue - def _mount(self, application: fastapi.FastAPI): - # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ? + # copy the uploaded file to our working directory + path = Path(working_directory) / value.filename + with open(path, "wb") as file: + while content := await value.read(1024*1024): + file.write(content) - # create an endpoint wrapping the inference inside a fastapi call - async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse: - # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse - # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi do not use it own - # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation - # curiosity that may change in the future - kwargs = { - key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value - for key, value in kwargs.items() - } + # replace the argument + kwargs[key] = str(path) - return fastapi.responses.StreamingResponse( - content=await self.infer(**kwargs), - media_type=self.output_type, - headers={ - # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI - "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment" - } - ) + # run the inference + for chunk in self._model.infer(**kwargs): + yield chunk - infer_api.__signature__ = inspect.Signature(parameters=self.parameters) - # format the description - description_sections: list[str] = [] - if self.description is not None: - description_sections.append(f"# Description\n{self.description}") - if self.interface is not None: - description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**") - - description: str = "\n".join(description_sections) - - # add the inference endpoint on the API - application.add_api_route( - f"{self.api_base}/infer", - infer_api, - methods=["POST"], - tags=self.tags, - summary=self.summary, - description=description, - response_class=fastapi.responses.StreamingResponse, - responses={ - 200: {"content": {self.output_type: {}}} - }, - ) +# TODO(Faraphel): if the FastAPI app closes, it seems like it waits for conda to finish (or the async tasks ?) 
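For orientation, here is a minimal sketch of what the removed `dummy` sample could look like as a model compatible with this new loader. Only the `model.py` file name and the `Model` class name come from the loader itself; the echo behaviour and the choice of returning a list of `bytes` chunks (so that `_infer` above can iterate over the proxy's result) are illustrative assumptions:

```python
# samples/models/dummy/model.py (hypothetical port of the removed sample)
import typing


class Model:
    """Echo model: send back the content of the last message."""

    def infer(self, messages: list[dict]) -> typing.Iterable[bytes]:
        # return the output as a list of bytes chunks; Pyro5 serialises them back
        # to the manager process, which streams them through the FastAPI response
        return [messages[-1]["content"].encode("utf-8")]
```

Note that the daemon side already wraps the class with `Pyro5.api.expose(model.Model)`, so a Pyro decorator is not strictly required inside the model file itself.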
diff --git a/source/model/base/BaseModel.py b/source/model/base/BaseModel.py index 3fc347e..6ff8ca2 100644 --- a/source/model/base/BaseModel.py +++ b/source/model/base/BaseModel.py @@ -1,12 +1,16 @@ import abc import asyncio import gc +import tempfile import typing from pathlib import Path import fastapi +import inspect +from source import utils from source.registry import ModelRegistry +from source.utils.fastapi import UploadFileFix class BaseModel(abc.ABC): @@ -24,6 +28,9 @@ class BaseModel(abc.ABC): # the environment directory of the model self.path = path + # get the parameters of the model + self.inputs = configuration.get("inputs", {}) + self.parameters = utils.parameters.load(self.inputs) # the mimetype of the model responses self.output_type: str = configuration.get("output_type", "application/json") # get the tags of the model @@ -71,8 +78,12 @@ class BaseModel(abc.ABC): return { "name": self.name, + "summary": self.summary, + "description": self.description, + "inputs": self.inputs, "output_type": self.output_type, - "tags": self.tags + "tags": self.tags, + "interface": self.interface, } async def load(self) -> None: @@ -145,7 +156,8 @@ class BaseModel(abc.ABC): await self.load() # model specific inference part - return await self._infer(**kwargs) + async for chunk in self._infer(**kwargs): + yield chunk @abc.abstractmethod async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]: @@ -164,13 +176,50 @@ class BaseModel(abc.ABC): if self.interface is not None: self.interface.mount(application) - # implementation specific mount - self._mount(application) + # create an endpoint wrapping the inference inside a fastapi call + # the arguments will be loaded from the configuration files. Use kwargs for the definition + async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse: + # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse + # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi does not use its own + # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation + # curiosity that may change in the future + kwargs = { + key: UploadFileFix(value) if type(value).__name__ == "UploadFile" else value + for key, value in kwargs.items() + } - @abc.abstractmethod - def _mount(self, application: fastapi.FastAPI) -> None: - """ - Add the model to the api - Do not call manually, use `unload` instead. 
- :param application: the fastapi application - """ + # return a streaming response around the inference call + return fastapi.responses.StreamingResponse( + content=self.infer(**kwargs), + media_type=self.output_type, + headers={ + # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI + "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment" + } + ) + + # update the signature of the function to use the configuration parameters + infer_api.__signature__ = inspect.Signature(parameters=self.parameters) + + # format the description + description_sections: list[str] = [] + if self.description is not None: + description_sections.append(f"# Description\n{self.description}") + if self.interface is not None: + description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**") + + description: str = "\n".join(description_sections) + + # add the inference endpoint on the API + application.add_api_route( + f"{self.api_base}/infer", + infer_api, + methods=["POST"], + tags=self.tags, + summary=self.summary, + description=description, + response_class=fastapi.responses.StreamingResponse, + responses={ + 200: {"content": {self.output_type: {}}} + }, + ) diff --git a/source/registry/ModelRegistry.py b/source/registry/ModelRegistry.py index ecba609..ff4cf6e 100644 --- a/source/registry/ModelRegistry.py +++ b/source/registry/ModelRegistry.py @@ -60,7 +60,12 @@ class ModelRegistry: # load all the models in the library for model_path in self.model_library.iterdir(): + # get the model name model_name: str = model_path.name + if model_name.startswith("."): + # ignore model starting with a dot + continue + model_configuration_path: Path = model_path / "config.json" # check if the configuration file exists @@ -68,8 +73,11 @@ class ModelRegistry: warnings.warn(f"Model {model_name!r} is missing a config.json file.") continue - # load the configuration file - model_configuration = json.loads(model_configuration_path.read_text()) + try: + # load the configuration file + model_configuration = json.loads(model_configuration_path.read_text()) + except json.decoder.JSONDecodeError: + raise Exception(f"Model {model_name!r}'s configuration is invalid. See above.") # get the model type for this model model_type_name: str = model_configuration.get("type") diff --git a/source/utils/parameters.py b/source/utils/parameters.py index cc8c41f..69d8fff 100644 --- a/source/utils/parameters.py +++ b/source/utils/parameters.py @@ -1,24 +1,5 @@ import inspect -from datetime import datetime - -from fastapi import UploadFile - -# the list of types and their name that can be used by the API -types: dict[str, type] = { - "bool": bool, - "int": int, - "float": float, - "str": str, - "bytes": bytes, - "dict": dict, - "datetime": datetime, - "file": UploadFile, - - # TODO(Faraphel): use a "ParameterRegistry" or other functions to handle complex type ? - "list[dict]": list[dict], - # "tuple": tuple, - # "set": set, -} +import fastapi def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]: @@ -31,7 +12,6 @@ def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]: >>> parameters_definition = { ... "boolean": {"type": "bool", "default": False}, ... "list": {"type": "list", "default": [1, 2, 3]}, - ... "datetime": {"type": "datetime"}, ... "file": {"type": "file"}, ... 
} >>> parameters = load_parameters(parameters_definition) @@ -40,12 +20,19 @@ def load(parameters_definition: dict[str, dict]) -> list[inspect.Parameter]: parameters: list[inspect.Parameter] = [] for name, definition in parameters_definition.items(): + # preprocess the type + match definition["type"]: + case "file": + # shortcut for uploading a file + definition["type"] = fastapi.UploadFile + + # deserialize the parameter parameter = inspect.Parameter( name, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=definition.get("default", inspect.Parameter.empty), - annotation=types[definition["type"]], + annotation=definition["type"], ) parameters.append(parameter)
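To illustrate how the simplified `load()` is meant to be consumed, here is a small usage sketch. The `inputs` definition below is invented for illustration, and the import style mirrors the `utils.parameters.load` call used in `BaseModel`; the only behaviour assumed is what the code above does, namely that `"file"` is mapped to `fastapi.UploadFile` and the other type strings are kept as written:

```python
import inspect

import fastapi

from source import utils

# a made-up "inputs" block, as it could appear in a model's config.json
definition = {
    "document": {"type": "file"},
    "messages": {"type": "list[dict]", "default": [{"role": "user", "content": "hi"}]},
}

# build the inspect.Parameter list used for the generated /infer endpoint
parameters = utils.parameters.load(definition)

# "file" became fastapi.UploadFile; other type strings are passed through unchanged
assert parameters[0].annotation is fastapi.UploadFile

# this is the signature attached to infer_api in BaseModel.mount
print(inspect.Signature(parameters=parameters))
```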