

feat: support model preloading (#66)
@8b11304e0f7e7394cbdfb753a4c3fcfe432703f1
--- faster_whisper_server/config.py
+++ faster_whisper_server/config.py
@@ -1,6 +1,7 @@
 import enum
+from typing import Self
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 SAMPLES_PER_SECOND = 16000
@@ -151,7 +152,9 @@
 
     model: str = Field(default="Systran/faster-whisper-medium.en")
     """
-    Huggingface model to use for transcription. Note, the model must support being run using CTranslate2.
+    Default Huggingface model to use for transcription. Note, the model must support being run using CTranslate2.
+    This model will be used if no model is specified in the request.
+
    Models created by authors of `faster-whisper` can be found at https://huggingface.co/Systran
    You can find other supported models at https://huggingface.co/models?p=2&sort=trending&search=ctranslate2 and https://huggingface.co/models?sort=trending&search=ct2
    """
@@ -199,6 +202,16 @@
     """
     Maximum number of models that can be loaded at a time.
     """
+    preload_models: list[str] = Field(
+        default_factory=list,
+        examples=[
+            ["Systran/faster-whisper-medium.en"],
+            ["Systran/faster-whisper-medium.en", "Systran/faster-whisper-small.en"],
+        ],
+    )
+    """
+    List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on the first request.
+    """  # noqa: E501
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finalized and connection is closed.
@@ -218,5 +231,13 @@
     Should be greater than `max_inactivity_seconds`
     """
 
+    @model_validator(mode="after")
+    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
+        if len(self.preload_models) > self.max_models:
+            raise ValueError(
+                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
+            )
+        return self
+
 
 config = Config()
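
The `@model_validator(mode="after")` hook added above runs on the fully constructed settings object and rejects configurations that ask to preload more models than `max_models` allows. A minimal sketch of that behavior, using a trimmed-down stand-in for the real `Config` (the `BaseModel` base and the `max_models` default of 1 here are illustrative; the actual class is a pydantic-settings `BaseSettings` subclass):

```python
from typing import Self

from pydantic import BaseModel, Field, ValidationError, model_validator


class Config(BaseModel):  # stand-in; the real Config subclasses BaseSettings
    max_models: int = 1  # illustrative default, not taken from the diff
    preload_models: list[str] = Field(default_factory=list)

    @model_validator(mode="after")
    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
        # Runs after field validation; raising here surfaces as a ValidationError.
        if len(self.preload_models) > self.max_models:
            raise ValueError(
                f"Number of preloaded models ({len(self.preload_models)}) "
                f"is greater than max_models ({self.max_models})"
            )
        return self


Config(preload_models=["Systran/faster-whisper-medium.en"])  # OK: 1 <= max_models

try:
    Config(preload_models=["Systran/faster-whisper-medium.en", "Systran/faster-whisper-small.en"])
except ValidationError as e:
    print(e)  # rejected: 2 preloaded models > max_models (1)
```

Because the real `Config` is a `BaseSettings`, `preload_models` should also be settable from the environment; pydantic-settings parses list-valued fields from env vars as JSON, e.g. `PRELOAD_MODELS='["Systran/faster-whisper-medium.en"]'` (assuming no custom env prefix is configured).
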
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -2,6 +2,7 @@
 
 import asyncio
 from collections import OrderedDict
+from contextlib import asynccontextmanager
 from io import BytesIO
 import time
 from typing import TYPE_CHECKING, Annotated, Literal
@@ -45,7 +46,7 @@
 from faster_whisper_server.transcriber import audio_transcriber
 
 if TYPE_CHECKING:
-    from collections.abc import Generator, Iterable
+    from collections.abc import AsyncGenerator, Generator, Iterable
 
     from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
@@ -63,7 +64,7 @@
         del loaded_models[oldest_model_name]
     logger.debug(f"Loading {model_name}...")
     start = time.perf_counter()
-    # NOTE: will raise an exception if the model name isn't valid
+    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
     whisper = WhisperModel(
         model_name,
         device=config.whisper.inference_device,
@@ -81,7 +82,15 @@
 
 logger.debug(f"Config: {config}")
 
-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
+    for model_name in config.preload_models:
+        load_model(model_name)
+    yield
+
+
+app = FastAPI(lifespan=lifespan)
 
 if config.allow_origins is not None:
     app.add_middleware(
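
The `lifespan` function above is FastAPI's standard replacement for startup/shutdown event handlers: everything before `yield` runs once, before the server starts accepting requests, so every model listed in `config.preload_models` is loaded into the `loaded_models` cache up front rather than on the first transcription request. A self-contained sketch of the pattern, with a dummy loader standing in for the server's real `load_model` (which instantiates a `WhisperModel` and evicts the oldest cached model once `max_models` is exceeded):

```python
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

preload_models = ["Systran/faster-whisper-medium.en"]  # stands in for config.preload_models
loaded_models: dict[str, str] = {}


def load_model(model_name: str) -> None:
    # Dummy loader; the real one constructs a WhisperModel and caches it.
    loaded_models[model_name] = f"<loaded {model_name}>"


@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    # Runs before the app begins serving requests.
    for model_name in preload_models:
        load_model(model_name)
    yield
    # Code after `yield` would run on shutdown; nothing is needed here.


app = FastAPI(lifespan=lifespan)
```

The net effect of the change is to move model download/initialization cost from the first request to server startup, which is the point of the preloading feature.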