Fedir Zadniprovskyi 2024-09-20
feat: dependency injection
The main purpose of this change is to allow modifying the configuration
for testing. It does, however, lead to some ugly code where the
`get_config` function gets called in seemingly arbitrary places (those
call sites are marked with `# HACK` comments).
@7cfaf5b1f698ed9900f3c7349a3603ac73d16873
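For context, a minimal sketch of how the new `get_config` dependency could be overridden in a test (the `Config` field value is illustrative, and `Config()` is assumed to be constructible from environment variables alone, as it is in `get_config`). Note that the direct `get_config()` calls marked `# HACK` below bypass `dependency_overrides` and still see the `lru_cache`d value; only handlers declaring `ConfigDependency` receive the override:

```python
from fastapi.testclient import TestClient

from faster_whisper_server.config import Config
from faster_whisper_server.dependencies import get_config
from faster_whisper_server.main import create_app


def test_with_custom_config() -> None:
    app = create_app()
    # Handlers that declare `ConfigDependency` now receive this Config
    # instead of the one built from environment variables.
    app.dependency_overrides[get_config] = lambda: Config(default_response_format="text")
    with TestClient(app) as client:
        assert client.get("/api/ps").status_code == 200
```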
Taskfile.yaml
--- Taskfile.yaml
+++ Taskfile.yaml
@@ -1,6 +1,6 @@
 version: "3"
 tasks:
-  server: uvicorn --host 0.0.0.0 faster_whisper_server.main:app {{.CLI_ARGS}}
+  server: uvicorn --factory --host 0.0.0.0 faster_whisper_server.main:create_app {{.CLI_ARGS}}
   test:
     cmds:
       - pytest -o log_cli=true -o log_cli_level=DEBUG {{.CLI_ARGS}}
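(uvicorn's `--factory` flag makes it call `faster_whisper_server.main:create_app` to build the application, rather than importing a ready-made module-level `app` object.)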
pyproject.toml
--- pyproject.toml
+++ pyproject.toml
@@ -75,6 +75,7 @@
     "ISC001", # recommended to disable for formatting
     "INP001",
     "PT018",
+    "G004", # logging f string
 ]
 
 [tool.ruff.lint.isort]
src/faster_whisper_server/asr.py
--- src/faster_whisper_server/asr.py
+++ src/faster_whisper_server/asr.py
@@ -1,11 +1,13 @@
 import asyncio
+import logging
 import time
 
 from faster_whisper import transcribe
 
 from faster_whisper_server.audio import Audio
 from faster_whisper_server.core import Segment, Transcription, Word
-from faster_whisper_server.logger import logger
+
+logger = logging.getLogger(__name__)
 
 
 class FasterWhisperASR:
src/faster_whisper_server/audio.py
--- src/faster_whisper_server/audio.py
+++ src/faster_whisper_server/audio.py
@@ -1,13 +1,13 @@
 from __future__ import annotations
 
 import asyncio
+import logging
 from typing import TYPE_CHECKING, BinaryIO
 
 import numpy as np
 import soundfile as sf
 
 from faster_whisper_server.config import SAMPLES_PER_SECOND
-from faster_whisper_server.logger import logger
 
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
@@ -15,6 +15,9 @@
     from numpy.typing import NDArray
 
 
+logger = logging.getLogger(__name__)
+
+
 def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
     audio_and_sample_rate = sf.read(
         file,
src/faster_whisper_server/config.py
--- src/faster_whisper_server/config.py
+++ src/faster_whisper_server/config.py
@@ -238,6 +238,3 @@
                 f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
             )
         return self
-
-
-config = Config()
src/faster_whisper_server/core.py
--- src/faster_whisper_server/core.py
+++ src/faster_whisper_server/core.py
@@ -5,7 +5,7 @@
 
 from pydantic import BaseModel
 
-from faster_whisper_server.config import config
+from faster_whisper_server.dependencies import get_config
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -113,6 +113,7 @@
         self.words.extend(words)
 
     def _ensure_no_word_overlap(self, words: list[Word]) -> None:
+        config = get_config()  # HACK
         if len(self.words) > 0 and len(words) > 0:
             if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
                 raise ValueError(
 
src/faster_whisper_server/dependencies.py (added)
+++ src/faster_whisper_server/dependencies.py
@@ -0,0 +1,24 @@
+from functools import lru_cache
+from typing import Annotated
+
+from fastapi import Depends
+
+from faster_whisper_server.config import Config
+from faster_whisper_server.model_manager import ModelManager
+
+
+@lru_cache
+def get_config() -> Config:
+    return Config()
+
+
+ConfigDependency = Annotated[Config, Depends(get_config)]
+
+
+@lru_cache
+def get_model_manager() -> ModelManager:
+    config = get_config()  # HACK
+    return ModelManager(config)
+
+
+ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
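As a usage sketch (the route below is hypothetical, not part of this commit), a handler consumes these `Annotated` aliases like any other FastAPI dependency:

```python
from fastapi import APIRouter

from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency

router = APIRouter()


@router.get("/example")  # hypothetical route, for illustration only
def example(config: ConfigDependency, model_manager: ModelManagerDependency) -> dict[str, str | int]:
    # FastAPI resolves both parameters through the `Depends(...)` baked
    # into the aliases, so tests can swap either one out via
    # `app.dependency_overrides`.
    return {
        "default_model": config.whisper.model,
        "loaded_model_count": len(model_manager.loaded_models),
    }
```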
src/faster_whisper_server/hf_utils.py
--- src/faster_whisper_server/hf_utils.py
+++ src/faster_whisper_server/hf_utils.py
@@ -1,10 +1,11 @@
 from collections.abc import Generator
+import logging
 from pathlib import Path
 import typing
 
 import huggingface_hub
 
-from faster_whisper_server.logger import logger
+logger = logging.getLogger(__name__)
 
 LIBRARY_NAME = "ctranslate2"
 TASK_NAME = "automatic-speech-recognition"
src/faster_whisper_server/logger.py
--- src/faster_whisper_server/logger.py
+++ src/faster_whisper_server/logger.py
@@ -1,8 +1,11 @@
 import logging
 
-from faster_whisper_server.config import config
+from faster_whisper_server.dependencies import get_config
 
-logging.getLogger().setLevel(logging.INFO)
-logger = logging.getLogger(__name__)
-logger.setLevel(config.log_level.upper())
-logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
+
+def setup_logger() -> None:
+    config = get_config()  # HACK
+    logging.getLogger().setLevel(logging.INFO)
+    logger = logging.getLogger(__name__)
+    logger.setLevel(config.log_level.upper())
+    logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
src/faster_whisper_server/main.py
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from contextlib import asynccontextmanager
+import logging
 from typing import TYPE_CHECKING
 
 from fastapi import (
@@ -8,11 +9,8 @@
 )
 from fastapi.middleware.cors import CORSMiddleware
 
-from faster_whisper_server.config import (
-    config,
-)
-from faster_whisper_server.logger import logger
-from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.dependencies import get_config, get_model_manager
+from faster_whisper_server.logger import setup_logger
 from faster_whisper_server.routers.list_models import (
     router as list_models_router,
 )
@@ -27,34 +25,42 @@
     from collections.abc import AsyncGenerator
 
 
-logger.debug(f"Config: {config}")
+def create_app() -> FastAPI:
+    setup_logger()
 
+    logger = logging.getLogger(__name__)
 
-@asynccontextmanager
-async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
-    for model_name in config.preload_models:
-        model_manager.load_model(model_name)
-    yield
+    config = get_config()  # HACK
+    logger.debug(f"Config: {config}")
 
+    model_manager = get_model_manager()  # HACK
 
-app = FastAPI(lifespan=lifespan)
+    @asynccontextmanager
+    async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
+        for model_name in config.preload_models:
+            model_manager.load_model(model_name)
+        yield
 
-app.include_router(stt_router)
-app.include_router(list_models_router)
-app.include_router(misc_router)
+    app = FastAPI(lifespan=lifespan)
 
-if config.allow_origins is not None:
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=config.allow_origins,
-        allow_credentials=True,
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )
+    app.include_router(stt_router)
+    app.include_router(list_models_router)
+    app.include_router(misc_router)
 
-if config.enable_ui:
-    import gradio as gr
+    if config.allow_origins is not None:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=config.allow_origins,
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )
 
-    from faster_whisper_server.gradio_app import create_gradio_demo
+    if config.enable_ui:
+        import gradio as gr
 
-    app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
+        from faster_whisper_server.gradio_app import create_gradio_demo
+
+        app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
+
+    return app
src/faster_whisper_server/model_manager.py
--- src/faster_whisper_server/model_manager.py
+++ src/faster_whisper_server/model_manager.py
@@ -2,27 +2,34 @@
 
 from collections import OrderedDict
 import gc
+import logging
 import time
+from typing import TYPE_CHECKING
 
 from faster_whisper import WhisperModel
 
-from faster_whisper_server.config import (
-    config,
-)
-from faster_whisper_server.logger import logger
+if TYPE_CHECKING:
+    from faster_whisper_server.config import (
+        Config,
+    )
+
+logger = logging.getLogger(__name__)
 
 
 class ModelManager:
-    def __init__(self) -> None:
+    def __init__(self, config: Config) -> None:
+        self.config = config
         self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
 
     def load_model(self, model_name: str) -> WhisperModel:
         if model_name in self.loaded_models:
             logger.debug(f"{model_name} model already loaded")
             return self.loaded_models[model_name]
-        if len(self.loaded_models) >= config.max_models:
+        if len(self.loaded_models) >= self.config.max_models:
             oldest_model_name = next(iter(self.loaded_models))
-            logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
+            logger.info(
+                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+            )
             del self.loaded_models[oldest_model_name]
             gc.collect()
         logger.debug(f"Loading {model_name}...")
@@ -30,17 +37,14 @@
         # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
         whisper = WhisperModel(
             model_name,
-            device=config.whisper.inference_device,
-            device_index=config.whisper.device_index,
-            compute_type=config.whisper.compute_type,
-            cpu_threads=config.whisper.cpu_threads,
-            num_workers=config.whisper.num_workers,
+            device=self.config.whisper.inference_device,
+            device_index=self.config.whisper.device_index,
+            compute_type=self.config.whisper.compute_type,
+            cpu_threads=self.config.whisper.cpu_threads,
+            num_workers=self.config.whisper.num_workers,
         )
         logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
+            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
         )
         self.loaded_models[model_name] = whisper
         return whisper
-
-
-model_manager = ModelManager()
src/faster_whisper_server/routers/misc.py
--- src/faster_whisper_server/routers/misc.py
+++ src/faster_whisper_server/routers/misc.py
@@ -6,10 +6,11 @@
     APIRouter,
     Response,
 )
-from faster_whisper_server import hf_utils
-from faster_whisper_server.model_manager import model_manager
 import huggingface_hub
 from huggingface_hub.hf_api import RepositoryNotFoundError
+
+from faster_whisper_server import hf_utils
+from faster_whisper_server.dependencies import ModelManagerDependency  # noqa: TCH001
 
 router = APIRouter()
 
@@ -31,12 +32,14 @@
 
 
 @router.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
-def get_running_models() -> dict[str, list[str]]:
+def get_running_models(
+    model_manager: ModelManagerDependency,
+) -> dict[str, list[str]]:
     return {"models": list(model_manager.loaded_models.keys())}
 
 
 @router.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
-def load_model_route(model_name: str) -> Response:
+def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
     model_manager.load_model(model_name)
@@ -44,7 +47,7 @@
 
 
 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
-def stop_running_model(model_name: str) -> Response:
+def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
     model = model_manager.loaded_models.get(model_name)
     if model is not None:
         del model_manager.loaded_models[model_name]
src/faster_whisper_server/routers/stt.py
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
@@ -2,6 +2,7 @@
 
 import asyncio
 from io import BytesIO
+import logging
 from typing import TYPE_CHECKING, Annotated, Literal
 
 from fastapi import (
@@ -16,6 +17,8 @@
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper.vad import VadOptions, get_speech_timestamps
+from pydantic import AfterValidator
+
 from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
@@ -23,23 +26,22 @@
     Language,
     ResponseFormat,
     Task,
-    config,
 )
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
-from faster_whisper_server.logger import logger
-from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
 )
 from faster_whisper_server.transcriber import audio_transcriber
-from pydantic import AfterValidator
 
 if TYPE_CHECKING:
     from collections.abc import Generator, Iterable
 
     from faster_whisper.transcribe import TranscriptionInfo
 
+
+logger = logging.getLogger(__name__)
 
 router = APIRouter()
 
@@ -103,6 +105,7 @@
 
     For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
     """
+    config = get_config()  # HACK
     if model_name == "whisper-1":
         logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
         return config.whisper.model
@@ -117,13 +120,19 @@
     response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
 )
 def translate_file(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
+    model: Annotated[ModelName | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    response_format: Annotated[ResponseFormat | None, Form()] = None,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
+    if model is None:
+        model = config.whisper.model
+    if response_format is None:
+        response_format = config.default_response_format
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -147,11 +156,13 @@
     response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
 )
 def transcribe_file(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
-    language: Annotated[Language | None, Form()] = config.default_language,
+    model: Annotated[ModelName | None, Form()] = None,
+    language: Annotated[Language | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    response_format: Annotated[ResponseFormat | None, Form()] = None,
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
         list[Literal["segment", "word"]],
@@ -160,6 +171,12 @@
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
 ) -> Response | StreamingResponse:
+    if model is None:
+        model = config.whisper.model
+    if language is None:
+        language = config.default_language
+    if response_format is None:
+        response_format = config.default_response_format
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -180,6 +197,7 @@
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
+    config = get_config()  # HACK
     try:
         while True:
             bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
@@ -211,12 +229,20 @@
 
 @router.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     ws: WebSocket,
-    model: Annotated[ModelName, Query()] = config.whisper.model,
-    language: Annotated[Language | None, Query()] = config.default_language,
-    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
+    model: Annotated[ModelName | None, Query()] = None,
+    language: Annotated[Language | None, Query()] = None,
+    response_format: Annotated[ResponseFormat | None, Query()] = None,
     temperature: Annotated[float, Query()] = 0.0,
 ) -> None:
+    if model is None:
+        model = config.whisper.model
+    if language is None:
+        language = config.default_language
+    if response_format is None:
+        response_format = config.default_response_format
     await ws.accept()
     transcribe_opts = {
         "language": language,
@@ -229,7 +255,7 @@
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
         tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream):
+        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
             logger.debug(f"Sending transcription: {transcription.text}")
             if ws.client_state == WebSocketState.DISCONNECTED:
                 break
src/faster_whisper_server/transcriber.py
--- src/faster_whisper_server/transcriber.py
+++ src/faster_whisper_server/transcriber.py
@@ -1,16 +1,17 @@
 from __future__ import annotations
 
+import logging
 from typing import TYPE_CHECKING
 
 from faster_whisper_server.audio import Audio, AudioStream
-from faster_whisper_server.config import config
 from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text
-from faster_whisper_server.logger import logger
 
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
 
     from faster_whisper_server.asr import FasterWhisperASR
+
+logger = logging.getLogger(__name__)
 
 
 class LocalAgreement:
@@ -47,11 +48,12 @@
 async def audio_transcriber(
     asr: FasterWhisperASR,
     audio_stream: AudioStream,
+    min_duration: float,
 ) -> AsyncGenerator[Transcription, None]:
     local_agreement = LocalAgreement()
     full_audio = Audio()
     confirmed = Transcription()
-    async for chunk in audio_stream.chunks(config.min_duration):
+    async for chunk in audio_stream.chunks(min_duration):
         full_audio.extend(chunk)
         audio = full_audio.after(needs_audio_after(confirmed))
         transcription, _ = await asr.transcribe(audio, prompt(confirmed))
tests/conftest.py
--- tests/conftest.py
+++ tests/conftest.py
@@ -1,7 +1,9 @@
 from collections.abc import AsyncGenerator, Generator
 import logging
+import os
 
 from fastapi.testclient import TestClient
+from faster_whisper_server.main import create_app
 from httpx import ASGITransport, AsyncClient
 from openai import OpenAI
 import pytest
@@ -18,17 +20,15 @@
 
 @pytest.fixture()
 def client() -> Generator[TestClient, None, None]:
-    from faster_whisper_server.main import app
-
-    with TestClient(app) as client:
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    with TestClient(create_app()) as client:
         yield client
 
 
 @pytest_asyncio.fixture()
 async def aclient() -> AsyncGenerator[AsyncClient, None]:
-    from faster_whisper_server.main import app
-
-    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
         yield aclient
 
 
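One caveat with the environment-variable approach above: `get_config` is wrapped in `@lru_cache`, so `WHISPER__MODEL` is only read the first time a `Config` is built in the process. A sketch of an autouse fixture that clears both caches between tests, using the `cache_clear()` method that `functools.lru_cache` provides:

```python
import pytest

from faster_whisper_server.dependencies import get_config, get_model_manager


@pytest.fixture(autouse=True)
def _reset_dependency_caches() -> None:
    # Force each test to rebuild Config (and the ModelManager derived
    # from it) from the current environment.
    get_config.cache_clear()
    get_model_manager.cache_clear()
```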