

refactor: add `ModelManager`
@2c4f21c8d628cf945e51f51a347b774fffc20845
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 import asyncio
-from collections import OrderedDict
 from contextlib import asynccontextmanager
 import gc
 from io import BytesIO
-import time
 from typing import TYPE_CHECKING, Annotated, Literal
 
 from fastapi import (
@@ -22,7 +20,6 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
-from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 import huggingface_hub
 from huggingface_hub.hf_api import RepositoryNotFoundError
@@ -40,6 +37,7 @@
 )
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
+from faster_whisper_server.model_manager import ModelManager
 from faster_whisper_server.server_models import (
     ModelListResponse,
     ModelObject,
@@ -54,42 +52,16 @@
     from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
 
-loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
-
-
-def load_model(model_name: str) -> WhisperModel:
-    if model_name in loaded_models:
-        logger.debug(f"{model_name} model already loaded")
-        return loaded_models[model_name]
-    if len(loaded_models) >= config.max_models:
-        oldest_model_name = next(iter(loaded_models))
-        logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
-        del loaded_models[oldest_model_name]
-    logger.debug(f"Loading {model_name}...")
-    start = time.perf_counter()
-    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-    whisper = WhisperModel(
-        model_name,
-        device=config.whisper.inference_device,
-        device_index=config.whisper.device_index,
-        compute_type=config.whisper.compute_type,
-        cpu_threads=config.whisper.cpu_threads,
-        num_workers=config.whisper.num_workers,
-    )
-    logger.info(
-        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
-    )
-    loaded_models[model_name] = whisper
-    return whisper
-
 
 logger.debug(f"Config: {config}")
+
+model_manager = ModelManager()
 
 
 @asynccontextmanager
 async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
     for model_name in config.preload_models:
-        load_model(model_name)
+        model_manager.load_model(model_name)
     yield
 
 
@@ -123,22 +95,22 @@
 
 @app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
 def get_running_models() -> dict[str, list[str]]:
-    return {"models": list(loaded_models.keys())}
+    return {"models": list(model_manager.loaded_models.keys())}
 
 
 @app.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
 def load_model_route(model_name: str) -> Response:
-    if model_name in loaded_models:
+    if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    load_model(model_name)
+    model_manager.load_model(model_name)
     return Response(status_code=201)
 
 
 @app.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_name: str) -> Response:
-    model = loaded_models.get(model_name)
+    model = model_manager.loaded_models.get(model_name)
     if model is not None:
-        del loaded_models[model_name]
+        del model_manager.loaded_models[model_name]
         gc.collect()
         return Response(status_code=204)
     return Response(status_code=404)
@@ -291,7 +263,7 @@
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task=Task.TRANSLATE,
@@ -327,7 +299,7 @@
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
 ) -> Response | StreamingResponse:
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task=Task.TRANSCRIBE,
@@ -391,7 +363,7 @@
         "vad_filter": True,
         "condition_on_previous_text": False,
     }
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
--- /dev/null
+++ src/faster_whisper_server/model_manager.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+import gc
+import time
+
+from faster_whisper import WhisperModel
+
+from faster_whisper_server.config import (
+    config,
+)
+from faster_whisper_server.logger import logger
+
+
+class ModelManager:
+    def __init__(self) -> None:
+        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+
+    def load_model(self, model_name: str) -> WhisperModel:
+        if model_name in self.loaded_models:
+            logger.debug(f"{model_name} model already loaded")
+            return self.loaded_models[model_name]
+        if len(self.loaded_models) >= config.max_models:
+            oldest_model_name = next(iter(self.loaded_models))
+            logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
+            del self.loaded_models[oldest_model_name]
+            gc.collect()
+        logger.debug(f"Loading {model_name}...")
+        start = time.perf_counter()
+        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
+        whisper = WhisperModel(
+            model_name,
+            device=config.whisper.inference_device,
+            device_index=config.whisper.device_index,
+            compute_type=config.whisper.compute_type,
+            cpu_threads=config.whisper.cpu_threads,
+            num_workers=config.whisper.num_workers,
+        )
+        logger.info(
+            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
+        )
+        self.loaded_models[model_name] = whisper
+        return whisper
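For reviewers, a minimal sketch of how the new class is used outside the FastAPI routes. The model name here is a placeholder; the caching and eviction behaviour follows from the `OrderedDict` in `load_model` and `config.max_models` as shown in the diff above.

```python
from faster_whisper_server.model_manager import ModelManager

# One process-wide instance, mirroring `model_manager = ModelManager()` in main.py.
manager = ModelManager()

# First call constructs the WhisperModel (downloading weights if needed);
# repeated calls with the same name return the cached instance.
whisper = manager.load_model("Systran/faster-whisper-tiny.en")  # placeholder model name
assert manager.load_model("Systran/faster-whisper-tiny.en") is whisper

# Models are tracked in insertion order; once `config.max_models` is reached,
# the first model that was loaded is evicted before the new one is created.
print(list(manager.loaded_models.keys()))
```

Note that eviction is first-in-first-out rather than least-recently-used: a cache hit returns early without reordering the `OrderedDict`.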
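The `/api/ps` routes keep their existing behaviour and now read from `model_manager.loaded_models`. A hedged sketch of exercising them against a running server; the base URL and the `requests` dependency are assumptions, not part of this change.

```python
import requests

BASE_URL = "http://localhost:8000"  # assumption: default local deployment

# Load a model into memory: 201 on success, 409 if it is already loaded.
resp = requests.post(f"{BASE_URL}/api/ps/Systran/faster-whisper-tiny.en")
print(resp.status_code)

# List the models currently held by the ModelManager.
print(requests.get(f"{BASE_URL}/api/ps").json())  # e.g. {"models": ["Systran/faster-whisper-tiny.en"]}

# Unload it again: 204 on success, 404 if it was not loaded.
resp = requests.delete(f"{BASE_URL}/api/ps/Systran/faster-whisper-tiny.en")
print(resp.status_code)
```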