Fedir Zadniprovskyi 2024-09-05
feat: support model preloading (#66)
@8b11304e0f7e7394cbdfb753a4c3fcfe432703f1
faster_whisper_server/config.py
--- faster_whisper_server/config.py
+++ faster_whisper_server/config.py
@@ -1,6 +1,7 @@
 import enum
+from typing import Self
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 SAMPLES_PER_SECOND = 16000
@@ -151,7 +152,9 @@
 
     model: str = Field(default="Systran/faster-whisper-medium.en")
     """
-    Huggingface model to use for transcription. Note, the model must support being ran using CTranslate2.
+    Default Hugging Face model to use for transcription. Note: the model must support being run with CTranslate2.
+    This model will be used if no model is specified in the request.
+
     Models created by authors of `faster-whisper` can be found at https://huggingface.co/Systran
     You can find other supported models at https://huggingface.co/models?p=2&sort=trending&search=ctranslate2 and https://huggingface.co/models?sort=trending&search=ct2
     """
@@ -199,6 +202,16 @@
     """
     Maximum number of models that can be loaded at a time.
     """
+    preload_models: list[str] = Field(
+        default_factory=list,
+        examples=[
+            ["Systran/faster-whisper-medium.en"],
+            ["Systran/faster-whisper-medium.en", "Systran/faster-whisper-small.en"],
+        ],
+    )
+    """
    List of models to preload on startup. The number of models shouldn't exceed `max_models`. By default, a model is loaded on the first request that uses it.
+    """  # noqa: E501
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finalized and the connection is closed.
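A hedged aside on configuration: assuming the project's `Config` uses pydantic-settings' default env mapping (no `env_prefix`; complex fields parsed as JSON), the new list field could be supplied at startup as sketched below. The `SettingsConfigDict` options aren't visible in this diff, so the `PRELOAD_MODELS` variable name is an assumption.

import os

# Hypothetical: assumes the default pydantic-settings env mapping and JSON
# parsing for list-typed fields; not verified from this diff.
os.environ["PRELOAD_MODELS"] = '["Systran/faster-whisper-medium.en", "Systran/faster-whisper-small.en"]'

# Config() would then pick the list up when the server starts, e.g.:
# config = Config()  # needs the full package; shown for illustration only
# assert len(config.preload_models) == 2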
@@ -218,5 +231,13 @@
     Should be greater than `max_inactivity_seconds`
     """
 
+    @model_validator(mode="after")
+    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
+        if len(self.preload_models) > self.max_models:
+            raise ValueError(
+                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
+            )
+        return self
+
 
 config = Config()
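To see the validator in isolation, here is a runnable sketch with a standalone stand-in for the project's `Config` (the real class carries many more fields): constructing it with more preloaded models than `max_models` fails fast, with pydantic wrapping the `ValueError` in a `ValidationError`.

from typing import Self

from pydantic import BaseModel, Field, model_validator


class SketchConfig(BaseModel):
    # Stand-in fields only; not the project's actual Config.
    max_models: int = 1
    preload_models: list[str] = Field(default_factory=list)

    @model_validator(mode="after")
    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
        # mode="after" runs once all fields are validated and populated.
        if len(self.preload_models) > self.max_models:
            raise ValueError(
                f"Number of preloaded models ({len(self.preload_models)}) "
                f"is greater than max_models ({self.max_models})"
            )
        return self


SketchConfig(preload_models=["Systran/faster-whisper-medium.en"])  # passes
try:
    SketchConfig(preload_models=["model-a", "model-b"])  # 2 > max_models=1
except ValueError as e:  # pydantic's ValidationError subclasses ValueError
    print(e)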
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -2,6 +2,7 @@
 
 import asyncio
 from collections import OrderedDict
+from contextlib import asynccontextmanager
 from io import BytesIO
 import time
 from typing import TYPE_CHECKING, Annotated, Literal
@@ -45,7 +46,7 @@
 from faster_whisper_server.transcriber import audio_transcriber
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, Iterable
+    from collections.abc import AsyncGenerator, Generator, Iterable
 
     from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
@@ -63,7 +64,7 @@
         del loaded_models[oldest_model_name]
     logger.debug(f"Loading {model_name}...")
     start = time.perf_counter()
-    # NOTE: will raise an exception if the model name isn't valid
+    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
     whisper = WhisperModel(
         model_name,
         device=config.whisper.inference_device,
@@ -81,7 +82,15 @@
 
 logger.debug(f"Config: {config}")
 
-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
+    for model_name in config.preload_models:
+        load_model(model_name)
+    yield
+
+
+app = FastAPI(lifespan=lifespan)
 
 if config.allow_origins is not None:
     app.add_middleware(
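The lifespan hook in isolation, as a minimal runnable sketch; the stub loader below stands in for the real `load_model`, and the module-level list stands in for `config.preload_models` (both are illustrative names, not the project's code).

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

PRELOAD_MODELS = ["Systran/faster-whisper-medium.en"]  # stand-in for config.preload_models


def load_model(model_name: str) -> None:
    # Stub: the real function constructs and caches a WhisperModel.
    print(f"loading {model_name}...")


@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    for model_name in PRELOAD_MODELS:  # runs once, before requests are served
        load_model(model_name)
    yield  # the app serves requests here; teardown code would follow the yield


app = FastAPI(lifespan=lifespan)

Because `load_model` is synchronous, preloading blocks startup until every listed model is in memory, which is the trade-off this feature opts into: slower startup in exchange for no cold-start latency on the first request.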