added a lock to avoid two inferences running at the same time, and added corresponding support for asynchronous generator based models

faraphel 2025-01-10 19:11:48 +01:00
parent c6d779f591
commit 775c78c6cb
4 changed files with 19 additions and 12 deletions
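The heart of the change is to funnel every inference through a single asyncio.Lock: a coroutine that requests the lock while another holds it is suspended until release, so two requests can no longer run inference concurrently. A minimal standalone sketch of that behaviour (toy names, not the project's code):

import asyncio

inference_lock = asyncio.Lock()

async def fake_inference(name: str) -> None:
    # a second caller suspends here until the first one releases the lock
    async with inference_lock:
        print(f"{name}: inference started")
        await asyncio.sleep(0.1)  # stand-in for real model work
        print(f"{name}: inference finished")

async def main() -> None:
    # both coroutines start concurrently, but the lock forces them to run one after the other
    await asyncio.gather(fake_inference("request-1"), fake_inference("request-2"))

asyncio.run(main())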

@@ -1,4 +1,3 @@
-import json
 import typing
@@ -8,5 +7,5 @@ def load(model) -> None:
 def unload(model) -> None:
     pass
 
-def infer(model, file) -> typing.Iterator[bytes]:
-    yield json.dumps({"hello": "world!"}).encode("utf-8")
+async def infer(model, file) -> typing.AsyncIterator[bytes]:
+    yield await file.read()
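The example model above shows the new contract: infer is now an asynchronous generator, so a model can stream its output as it is produced instead of materialising it first. A hedged variant of the same module, assuming file is a starlette/FastAPI UploadFile (whose read() accepts a size argument) and picking an arbitrary 64 KiB chunk size:

import typing

async def infer(model, file) -> typing.AsyncIterator[bytes]:
    # stream the upload back chunk by chunk instead of reading it whole
    while chunk := await file.read(64 * 1024):  # 64 KiB chunks, arbitrary size
        yield chunk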

@@ -23,6 +23,9 @@ class ModelManager:
         # TODO(Faraphel): load more than one model at a time ? require a way more complex manager to handle memory issue
         self.current_loaded_model: typing.Optional[model.base.BaseModel] = None
 
+        # lock to avoid concurrent inference and concurrent model loading and unloading
+        self.inference_lock = asyncio.Lock()
+
         @self.application.get("/models")
         async def get_models() -> list[str]:
             """

@@ -50,7 +50,7 @@ class PythonModel(base.BaseModel):
         parameters = utils.parameters.load(configuration.get("inputs", {}))
 
         # create an endpoint wrapping the inference inside a fastapi call
-        async def infer_api(**kwargs):
+        async def infer_api(**kwargs) -> fastapi.responses.StreamingResponse:
             # NOTE: fix an issue where it is not possible to give an UploadFile to a StreamingResponse
             # NOTE: perform a naive type(value).__name__ == "type_name" because fastapi does not use its own
             # fastapi.UploadFile class, but instead the starlette UploadFile class that is more of an implementation
@@ -61,8 +61,12 @@ class PythonModel(base.BaseModel):
             }
 
             return fastapi.responses.StreamingResponse(
-                content=self.infer(**kwargs),
+                content=await self.infer(**kwargs),
                 media_type=self.output_type,
+                headers={
+                    # if the data is not text-like, mark it as an attachment to avoid display issues with Swagger UI
+                    "content-disposition": "inline" if utils.mimetypes.is_textlike(self.output_type) else "attachment"
+                }
             )
 
         infer_api.__signature__ = inspect.Signature(parameters=parameters)
@@ -81,5 +85,5 @@ class PythonModel(base.BaseModel):
     def _unload(self) -> None:
         return self.module.unload(self)
 
-    def _infer(self, **kwargs) -> typing.Iterator[bytes]:
+    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
         return self.module.infer(self, **kwargs)
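Two details make this endpoint work: starlette's StreamingResponse consumes either a plain or an asynchronous iterable of bytes, so both kinds of model generator can be passed through unchanged, and the content-disposition header decides whether Swagger UI tries to render the payload inline or offers it as a download. A self-contained sketch under those assumptions (toy route, not the project's endpoint):

import typing

import fastapi
import fastapi.responses

app = fastapi.FastAPI()

@app.get("/demo")
async def demo() -> fastapi.responses.StreamingResponse:
    async def generate() -> typing.AsyncIterator[bytes]:
        yield b"hello "
        yield b"world"

    # "inline" lets Swagger UI display the body, "attachment" forces a download
    return fastapi.responses.StreamingResponse(
        content=generate(),
        media_type="text/plain",
        headers={"content-disposition": "inline"},
    )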

@@ -106,20 +106,21 @@ class BaseModel(abc.ABC):
         Do not call manually, use `unload` instead.
         """
 
-    def infer(self, **kwargs) -> typing.Iterator[bytes]:
+    async def infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model within the model manager
         :return: the response of the model
         """
 
-        # make sure we are loaded before an inference
-        self.load()
+        async with self.manager.inference_lock:
+            # make sure we are loaded before an inference
+            self.load()
 
-        # model specific inference part
-        return self._infer(**kwargs)
+            # model specific inference part
+            return self._infer(**kwargs)
 
     @abc.abstractmethod
-    def _infer(self, **kwargs) -> typing.Iterator[bytes]:
+    def _infer(self, **kwargs) -> typing.Iterator[bytes] | typing.AsyncIterator[bytes]:
         """
         Infer our payload through the model
         :return: the response of the model
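Since _infer may now return either a plain or an asynchronous iterator, a caller that wants a single shape has to bridge the two. A hypothetical helper (not part of this commit), sketched for illustration:

import collections.abc
import typing

async def as_async_iterator(
    source: typing.Iterator[bytes] | typing.AsyncIterator[bytes],
) -> typing.AsyncIterator[bytes]:
    # pass asynchronous iterators through; lift plain iterators into async ones
    if isinstance(source, collections.abc.AsyncIterator):
        async for chunk in source:
            yield chunk
    else:
        for chunk in source:
            yield chunk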