Fedir Zadniprovskyi 2024-09-20
refactor: add `ModelManager`
@2c4f21c8d628cf945e51f51a347b774fffc20845
src/faster_whisper_server/main.py
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 import asyncio
-from collections import OrderedDict
 from contextlib import asynccontextmanager
 import gc
 from io import BytesIO
-import time
 from typing import TYPE_CHECKING, Annotated, Literal
 
 from fastapi import (
@@ -22,7 +20,6 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
-from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 import huggingface_hub
 from huggingface_hub.hf_api import RepositoryNotFoundError
@@ -40,6 +37,7 @@
 )
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
+from faster_whisper_server.model_manager import ModelManager
 from faster_whisper_server.server_models import (
     ModelListResponse,
     ModelObject,
@@ -54,42 +52,16 @@
     from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
 
-loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
-
-
-def load_model(model_name: str) -> WhisperModel:
-    if model_name in loaded_models:
-        logger.debug(f"{model_name} model already loaded")
-        return loaded_models[model_name]
-    if len(loaded_models) >= config.max_models:
-        oldest_model_name = next(iter(loaded_models))
-        logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
-        del loaded_models[oldest_model_name]
-    logger.debug(f"Loading {model_name}...")
-    start = time.perf_counter()
-    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-    whisper = WhisperModel(
-        model_name,
-        device=config.whisper.inference_device,
-        device_index=config.whisper.device_index,
-        compute_type=config.whisper.compute_type,
-        cpu_threads=config.whisper.cpu_threads,
-        num_workers=config.whisper.num_workers,
-    )
-    logger.info(
-        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
-    )
-    loaded_models[model_name] = whisper
-    return whisper
-
 
 logger.debug(f"Config: {config}")
+
+model_manager = ModelManager()
 
 
 @asynccontextmanager
 async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
     for model_name in config.preload_models:
-        load_model(model_name)
+        model_manager.load_model(model_name)
     yield
 
 
@@ -123,22 +95,22 @@
 
 @app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
 def get_running_models() -> dict[str, list[str]]:
-    return {"models": list(loaded_models.keys())}
+    return {"models": list(model_manager.loaded_models.keys())}
 
 
 @app.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
 def load_model_route(model_name: str) -> Response:
-    if model_name in loaded_models:
+    if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    load_model(model_name)
+    model_manager.load_model(model_name)
     return Response(status_code=201)
 
 
 @app.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_name: str) -> Response:
-    model = loaded_models.get(model_name)
+    model = model_manager.loaded_models.get(model_name)
     if model is not None:
-        del loaded_models[model_name]
+        del model_manager.loaded_models[model_name]
         gc.collect()
         return Response(status_code=204)
     return Response(status_code=404)
@@ -291,7 +263,7 @@
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task=Task.TRANSLATE,
@@ -327,7 +299,7 @@
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
 ) -> Response | StreamingResponse:
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task=Task.TRANSCRIBE,
@@ -391,7 +363,7 @@
         "vad_filter": True,
         "condition_on_previous_text": False,
     }
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
 
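For reference, the experimental model-management endpoints touched above can be exercised as follows. This is only a sketch: it assumes the server is reachable at http://localhost:8000, that the `requests` package is installed, and uses an example model name; none of these are part of the diff.

import requests

BASE_URL = "http://localhost:8000"  # assumed default address; adjust for your deployment
MODEL = "Systran/faster-whisper-tiny"  # example model name

# List models currently loaded in memory (GET /api/ps).
print(requests.get(f"{BASE_URL}/api/ps").json())

# Load a model (POST /api/ps/{model_name}); returns 201, or 409 if it is already loaded.
print(requests.post(f"{BASE_URL}/api/ps/{MODEL}").status_code)

# Unload it again (DELETE /api/ps/{model_name}); returns 204, or 404 if it is not loaded.
print(requests.delete(f"{BASE_URL}/api/ps/{MODEL}").status_code)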
src/faster_whisper_server/model_manager.py (added)
+++ src/faster_whisper_server/model_manager.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+import gc
+import time
+
+from faster_whisper import WhisperModel
+
+from faster_whisper_server.config import (
+    config,
+)
+from faster_whisper_server.logger import logger
+
+
+class ModelManager:
+    def __init__(self) -> None:
+        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+
+    def load_model(self, model_name: str) -> WhisperModel:
+        if model_name in self.loaded_models:
+            logger.debug(f"{model_name} model already loaded")
+            return self.loaded_models[model_name]
+        if len(self.loaded_models) >= config.max_models:
+            oldest_model_name = next(iter(self.loaded_models))
+            logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
+            del self.loaded_models[oldest_model_name]
+            gc.collect()
+        logger.debug(f"Loading {model_name}...")
+        start = time.perf_counter()
+        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
+        whisper = WhisperModel(
+            model_name,
+            device=config.whisper.inference_device,
+            device_index=config.whisper.device_index,
+            compute_type=config.whisper.compute_type,
+            cpu_threads=config.whisper.cpu_threads,
+            num_workers=config.whisper.num_workers,
+        )
+        logger.info(
+            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
+        )
+        self.loaded_models[model_name] = whisper
+        return whisper
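
A minimal usage sketch of the new class, assuming `config.max_models` is set to 1 and a CPU-friendly compute type is configured; the model names are examples from the Systran collection on Hugging Face and are not part of the diff.

from faster_whisper_server.model_manager import ModelManager

manager = ModelManager()

# The first call loads the model and caches it in manager.loaded_models.
tiny_en = manager.load_model("Systran/faster-whisper-tiny.en")

# A second call with the same name returns the cached WhisperModel instance.
assert manager.load_model("Systran/faster-whisper-tiny.en") is tiny_en

# Once len(loaded_models) >= config.max_models, loading another model evicts the
# oldest entry (insertion order of the OrderedDict) and runs gc.collect().
manager.load_model("Systran/faster-whisper-tiny")
print(list(manager.loaded_models.keys()))  # ["Systran/faster-whisper-tiny"] when max_models == 1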