mirror of https://git.isriupjv.fr/ISRI/ai-server, synced 2025-04-24 18:18:11 +02:00
fixed the life cycle of the models (they could no longer be unloaded) and simplified the implementation of the Python models
This commit is contained in:
parent f647c960dd
commit 8bf28e4c48
9 changed files with 96 additions and 111 deletions
@@ -5,5 +5,9 @@ pydantic
 gradio
 python-multipart
 
+# data manipulation
+pillow
+numpy
+
 # AI
 accelerate
@@ -1,11 +1,6 @@
 import typing
 
 
-def load(model) -> None:
-    pass
-
-def unload(model) -> None:
-    pass
-
-async def infer(model, messages: list[dict]) -> typing.AsyncIterator[bytes]:
-    yield messages[-1]["content"].encode("utf-8")
+class Model:
+    async def infer(self, messages: list[dict]) -> typing.AsyncIterator[bytes]:
+        yield messages[-1]["content"].encode("utf-8")
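With the simplified convention, an example model module only needs to expose a `Model` class with an async `infer` method. A minimal sketch of driving such a module on its own (the message content is made up for illustration, and `Model` is assumed to be the class defined above):

```python
import asyncio


async def main() -> None:
    model = Model()  # the class exposed by the example module above
    async for chunk in model.infer(messages=[{"role": "user", "content": "hello"}]):
        print(chunk)  # b"hello": the content of the last message, encoded


asyncio.run(main())
```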
@@ -5,25 +5,21 @@ import torch
 import transformers
 
 
-MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
-
-
-def load(model) -> None:
-    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
-    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
-
-def unload(model) -> None:
-    model.model = None
-    model.tokenizer = None
-
-async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
-    inputs = model.tokenizer(prompt, return_tensors="pt")
-
-    with torch.no_grad():
-        outputs = model.model(**inputs)
-
-    embeddings = outputs.last_hidden_state
-
-    yield json.dumps({
-        "data": embeddings.tolist()
-    }).encode("utf-8")
+class Model:
+    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
+
+    def __init__(self) -> None:
+        self.model = transformers.AutoModel.from_pretrained(self.NAME)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)
+
+    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
+        inputs = self.tokenizer(prompt, return_tensors="pt")
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+
+        embeddings = outputs.last_hidden_state
+
+        yield json.dumps({
+            "data": embeddings.tolist()
+        }).encode("utf-8")
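Because the embedding model streams a single JSON document, a caller can reassemble the chunks and rebuild the array. A rough sketch, assuming `Model` is the class above and numpy (now listed in the requirements) is installed; the 312 hidden size is implied by the TinyBERT variant name:

```python
import asyncio
import json

import numpy


async def main() -> None:
    model = Model()  # downloads the TinyBERT weights on first use
    chunks = [chunk async for chunk in model.infer("an example sentence")]
    payload = json.loads(b"".join(chunks))
    embeddings = numpy.array(payload["data"])
    print(embeddings.shape)  # expected to be (1, sequence_length, 312)


asyncio.run(main())
```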
@@ -5,25 +5,21 @@ import torch
 import transformers
 
 
-MODEL_NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
-
-
-def load(model) -> None:
-    model.model = transformers.AutoModel.from_pretrained(MODEL_NAME)
-    model.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
-
-def unload(model) -> None:
-    model.model = None
-    model.tokenizer = None
-
-async def infer(model, prompt: str) -> typing.AsyncIterator[bytes]:
-    inputs = model.tokenizer(prompt, return_tensors="pt")
-
-    with torch.no_grad():
-        outputs = model.model(**inputs)
-
-    embeddings = outputs.last_hidden_state
-
-    yield json.dumps({
-        "data": embeddings.tolist()
-    }).encode("utf-8")
+class Model:
+    NAME: str = "huawei-noah/TinyBERT_General_4L_312D"
+
+    def __init__(self) -> None:
+        self.model = transformers.AutoModel.from_pretrained(self.NAME)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.NAME)
+
+    async def infer(self, prompt: str) -> typing.AsyncIterator[bytes]:
+        inputs = self.tokenizer(prompt, return_tensors="pt")
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+
+        embeddings = outputs.last_hidden_state
+
+        yield json.dumps({
+            "data": embeddings.tolist()
+        }).encode("utf-8")
@@ -44,7 +44,7 @@ class ChatInterface(base.BaseInterface):
         # send back the messages, clear the user prompt, disable the system prompt
         return assistant_message
 
-    def get_gradio_application(self):
+    def get_application(self):
         # create a gradio interface
         with gradio.Blocks(analytics_enabled=False) as application:
             # header
@@ -7,7 +7,7 @@ import source
 
 
 class BaseInterface(abc.ABC):
-    def __init__(self, model: "source.model.base.BaseModel"):
+    def __init__(self, model: "source._model.base.BaseModel"):
         self.model = model
 
     @property
@@ -20,7 +20,7 @@ class BaseInterface(abc.ABC):
         return f"{self.model.api_base}/interface"
 
     @abc.abstractmethod
-    def get_gradio_application(self) -> gradio.Blocks:
+    def get_application(self) -> gradio.Blocks:
         """
         Get a gradio application
         :return: a gradio application
@@ -35,6 +35,6 @@ class BaseInterface(abc.ABC):
 
         gradio.mount_gradio_app(
             application,
-            self.get_gradio_application(),
+            self.get_application(),
             self.route
         )
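A concrete interface now only has to implement the renamed `get_application` hook (at least as far as this diff shows); a hypothetical subclass could look like the sketch below. The import path and the class are assumptions for illustration, not part of this commit:

```python
import gradio

from source._interface import base  # assumed module path


class EchoInterface(base.BaseInterface):  # hypothetical example
    def get_application(self) -> gradio.Blocks:
        # build the gradio application that BaseInterface mounts under self.route
        with gradio.Blocks(analytics_enabled=False) as application:
            prompt = gradio.Textbox(label="prompt")
            echo = gradio.Textbox(label="echo")
            prompt.submit(lambda text: text, inputs=prompt, outputs=echo)
        return application
```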
@@ -41,18 +41,21 @@ class PythonModel(base.BaseModel):
             self.path / file
         )
         # get the module
-        self.module = importlib.util.module_from_spec(module_spec)
+        module = importlib.util.module_from_spec(module_spec)
         # load the module
-        module_spec.loader.exec_module(self.module)
-
-    def _load(self) -> None:
-        return self.module.load(self)
+        module_spec.loader.exec_module(module)
+        # create the internal model from the class defined in the module
+        self._model_type = module.Model
+        self._model: typing.Optional[module.Model] = None
 
-    def _unload(self) -> None:
-        return self.module.unload(self)
+    async def _load(self) -> None:
+        self._model = self._model_type()
+
+    async def _unload(self) -> None:
+        self._model = None
 
     async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
-        return self.module.infer(self, **kwargs)
+        return self._model.infer(**kwargs)
 
     def _mount(self, application: fastapi.FastAPI):
         # TODO(Faraphel): should this be done directly in the BaseModel ? How to handle the inputs then ?
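Stripped of the surrounding class, the loading logic above is the standard `importlib` pattern for importing a file as a module; a standalone sketch with an illustrative path and module name:

```python
import importlib.util

# import an arbitrary Python file as a module (path is illustrative)
module_spec = importlib.util.spec_from_file_location("example_model", "/models/example/model.py")
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)

# with the new convention, the module is expected to expose a `Model` class
model = module.Model()
```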
@@ -69,7 +72,7 @@ class PythonModel(base.BaseModel):
             }
 
             return fastapi.responses.StreamingResponse(
-                content=await self.registry.infer_model(self, **kwargs),
+                content=await self.infer(**kwargs),
                 media_type=self.output_type,
                 headers={
                     # if the data is not text-like, mark it as an attachment to avoid display issue with Swagger UI
@@ -82,9 +85,11 @@ class PythonModel(base.BaseModel):
         # format the description
         description_sections: list[str] = []
         if self.description is not None:
-            description_sections.append(self.description)
+            description_sections.append(f"# Description\n{self.description}")
         if self.interface is not None:
-            description_sections.append(f"**[Open Dedicated Interface]({self.interface.route})**")
+            description_sections.append(f"# Interface\n**[Open Dedicated Interface]({self.interface.route})**")
+
+        description: str = "\n".join(description_sections)
 
         # add the inference endpoint on the API
         application.add_api_route(
@@ -93,7 +98,7 @@ class PythonModel(base.BaseModel):
             methods=["POST"],
             tags=self.tags,
             summary=self.summary,
-            description="<br>".join(description_sections),
+            description=description,
             response_class=fastapi.responses.StreamingResponse,
             responses={
                 200: {"content": {self.output_type: {}}}
@@ -1,4 +1,5 @@
 import abc
+import asyncio
 import gc
 import typing
 from pathlib import Path
@@ -38,6 +39,9 @@ class BaseModel(abc.ABC):
         # is the model currently loaded
         self._loaded = False
 
+        # lock to avoid loading and unloading at the same time
+        self.load_lock = asyncio.Lock()
+
     def __repr__(self):
         return f"<{self.__class__.__name__}: {self.name}>"
 
@@ -71,48 +75,60 @@ class BaseModel(abc.ABC):
             "tags": self.tags
         }
 
-    def load(self) -> None:
+    async def load(self) -> None:
         """
         Load the model within the model manager
         """
 
-        # if the model is already loaded, skip
-        if self._loaded:
-            return
-
-        # load the model depending on the implementation
-        self._load()
-
-        # mark the model as loaded
-        self._loaded = True
+        async with self.load_lock:
+            # if the model is already loaded, skip
+            if self._loaded:
+                return
+
+            # unload the currently loaded model if any
+            if self.registry.current_loaded_model is not None:
+                await self.registry.current_loaded_model.unload()
+
+            # load the model depending on the implementation
+            await self._load()
+
+            # mark the model as loaded
+            self._loaded = True
+
+            # mark the model as the registry loaded model
+            self.registry.current_loaded_model = self
 
     @abc.abstractmethod
-    def _load(self):
+    async def _load(self):
         """
         Load the model
         Do not call manually, use `load` instead.
         """
 
-    def unload(self) -> None:
+    async def unload(self) -> None:
         """
         Unload the model within the model manager
         """
 
-        # if we are not already loaded, stop
-        if not self._loaded:
-            return
-
-        # unload the model depending on the implementation
-        self._unload()
-
-        # force the garbage collector to clean the memory
-        gc.collect()
-
-        # mark the model as unloaded
-        self._loaded = False
+        async with self.load_lock:
+            # if we are not already loaded, stop
+            if not self._loaded:
+                return
+
+            # unload the model depending on the implementation
+            await self._unload()
+
+            # force the garbage collector to clean the memory
+            gc.collect()
+
+            # mark the model as unloaded
+            self._loaded = False
+
+            # if we are the registry current loaded model, remove this status
+            if self.registry.current_loaded_model is self:
+                self.registry.current_loaded_model = None
 
     @abc.abstractmethod
-    def _unload(self):
+    async def _unload(self):
         """
         Unload the model
         Do not call manually, use `unload` instead.
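Combined with the registry bookkeeping, the intended life cycle is roughly the following sketch; `embedder` and `chat` are placeholder names for two models registered in the same registry, and the comments restate what the code above does:

```python
async def demo(embedder, chat) -> None:
    await embedder.load()   # loads, registry.current_loaded_model = embedder
    await chat.load()       # unloads embedder first, then loads chat
    await chat.load()       # already loaded: returns immediately
    await chat.unload()     # unloads chat and clears registry.current_loaded_model

# run with asyncio.run(demo(embedder, chat)) given two registered BaseModel instances
```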
@@ -124,11 +140,12 @@ class BaseModel(abc.ABC):
         :return: the response of the model
         """
 
-        # make sure we are loaded before an inference
-        self.load()
+        async with self.registry.infer_lock:
+            # ensure that the model is loaded
+            await self.load()
 
-        # model specific inference part
-        return await self._infer(**kwargs)
+            # model specific inference part
+            return await self._infer(**kwargs)
 
     @abc.abstractmethod
     async def _infer(self, **kwargs) -> typing.AsyncIterator[bytes]:
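Call sites no longer go through `registry.infer_model`; awaiting `model.infer(...)` directly now handles locking and on-demand loading. A hedged sketch of a caller, mirroring the StreamingResponse usage earlier in this diff (the `prompt` keyword matches the embedding example and may differ for other models):

```python
import fastapi


async def endpoint(model) -> fastapi.responses.StreamingResponse:
    # model is assumed to be any BaseModel; infer() loads it on demand under the infer lock
    return fastapi.responses.StreamingResponse(
        content=await model.infer(prompt="an example prompt"),
        media_type=model.output_type,
    )
```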
@@ -33,8 +33,8 @@ class ModelRegistry:
         # having two calculations at the same time might not be worth it either
         self.current_loaded_model: typing.Optional[BaseModel] = None
 
-        # lock to avoid concurrent inference and concurrent model loading and unloading
-        self.inference_lock = asyncio.Lock()
+        # lock to control access to model inference
+        self.infer_lock = asyncio.Lock()
 
     @property
     def api_base(self) -> str:
@@ -48,34 +48,6 @@ class ModelRegistry:
     def register_type(self, name: str, model_type: "typing.Type[BaseModel]"):
         self.model_types[name] = model_type
 
-    async def load_model(self, model: "BaseModel"):
-        # lock to avoid concurrent loading
-        async with self.inference_lock:
-            # if there is another currently loaded model, unload it
-            if self.current_loaded_model is not None and self.current_loaded_model is not model:
-                await self.unload_model(self.current_loaded_model)
-
-            # load the model
-            model.load()
-
-            # mark the model as the currently loaded model of the manager
-            self.current_loaded_model = model
-
-    async def unload_model(self, model: "BaseModel"):
-        # lock to avoid concurrent unloading
-        async with self.inference_lock:
-            # if we were the currently loaded model of the manager, demote ourselves
-            if self.current_loaded_model is model:
-                self.current_loaded_model = None
-
-            # model specific unloading part
-            model.unload()
-
-    async def infer_model(self, model: "BaseModel", **kwargs) -> typing.AsyncIterator[bytes]:
-        # lock to avoid concurrent inference
-        async with self.inference_lock:
-            return await model.infer(**kwargs)
-
     def reload_models(self) -> None:
         """
         Reload the list of models available