Fedir Zadniprovskyi 2024-05-27
feat: support loading multiple models
@aada575dd8c8d2054f53e744abcee704e2944632
speaches/config.py
--- speaches/config.py
+++ speaches/config.py
@@ -163,39 +163,41 @@
 
 
 class WhisperConfig(BaseModel):
-    model: Model = Field(default=Model.DISTIL_MEDIUM_EN)  # ENV: WHISPER_MODEL
-    inference_device: Device = Field(
-        default=Device.AUTO
-    )  # ENV: WHISPER_INFERENCE_DEVICE
-    compute_type: Quantization = Field(
-        default=Quantization.DEFAULT
-    )  # ENV: WHISPER_COMPUTE_TYPE
+    model: Model = Field(default=Model.DISTIL_MEDIUM_EN)
+    inference_device: Device = Field(default=Device.AUTO)
+    compute_type: Quantization = Field(default=Quantization.DEFAULT)
 
 
 class Config(BaseSettings):
+    """
+    Configuration for the application. Values can be set via environment variables.
+    Pydantic will automatically handle mapping uppercased environment variables to the corresponding fields.
+    To populate nested fields, the environment variable should be prefixed with the nested field name and an underscore. For example,
+    the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER_MODEL` to `whisper.model`, etc.
+    """
+
     model_config = SettingsConfigDict(env_nested_delimiter="_")
 
-    log_level: str = "info"  # ENV: LOG_LEVEL
-    default_language: Language | None = None  # ENV: DEFAULT_LANGUAGE
-    default_response_format: ResponseFormat = (
-        ResponseFormat.JSON
-    )  # ENV: DEFAULT_RESPONSE_FORMAT
-    whisper: WhisperConfig = WhisperConfig()  # ENV: WHISPER_*
+    log_level: str = "info"
+    default_language: Language | None = None
+    default_response_format: ResponseFormat = ResponseFormat.JSON
+    whisper: WhisperConfig = WhisperConfig()
+    max_models: int = 1
     """
     Max duration to wait for the next audio chunk before transcription is finalized and the connection is closed.
     """
-    max_no_data_seconds: float = 1.0  # ENV: MAX_NO_DATA_SECONDS
-    min_duration: float = 1.0  # ENV: MIN_DURATION
-    word_timestamp_error_margin: float = 0.2  # ENV: WORD_TIMESTAMP_ERROR_MARGIN
+    max_no_data_seconds: float = 1.0
+    min_duration: float = 1.0
+    word_timestamp_error_margin: float = 0.2
     """
     Max allowed audio duration without any speech being detected before transcription is finalized and the connection is closed.
     """
-    max_inactivity_seconds: float = 2.0  # ENV: MAX_INACTIVITY_SECONDS
+    max_inactivity_seconds: float = 2.0
     """
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-    inactivity_window_seconds: float = 3.0  # ENV: INACTIVITY_WINDOW_SECONDS
+    inactivity_window_seconds: float = 3.0
 
 
 config = Config()
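
The new `Config` docstring above describes how pydantic-settings maps environment variables onto nested fields. For reference, a minimal sketch of that behaviour, assuming pydantic-settings v2; the field set is trimmed down and a plain `str` stands in for the `Model` enum:

```python
import os

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class WhisperConfig(BaseModel):
    model: str = "distil-medium.en"


class Config(BaseSettings):
    model_config = SettingsConfigDict(env_nested_delimiter="_")

    log_level: str = "info"
    max_models: int = 1
    whisper: WhisperConfig = WhisperConfig()


# Top-level fields map to their uppercased names; nested fields are
# prefixed with the parent field name and the delimiter.
os.environ["LOG_LEVEL"] = "debug"
os.environ["MAX_MODELS"] = "2"
os.environ["WHISPER_MODEL"] = "medium.en"

config = Config()
assert config.log_level == "debug"
assert config.max_models == 2
assert config.whisper.model == "medium.en"
```
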
speaches/main.py
--- speaches/main.py
+++ speaches/main.py
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
 import asyncio
-import logging
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated, Literal
+from typing import Annotated, Literal, OrderedDict
 
 from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
                      WebSocketDisconnect)
@@ -19,29 +18,45 @@
 from speaches.audio import AudioStream, audio_samples_from_file
 from speaches.config import (SAMPLES_PER_SECOND, Language, Model,
                              ResponseFormat, config)
-from speaches.core import Transcription
 from speaches.logger import logger
 from speaches.server_models import (TranscriptionJsonResponse,
                                     TranscriptionVerboseJsonResponse)
 from speaches.transcriber import audio_transcriber
 
-whisper: WhisperModel = None  # type: ignore
+models: OrderedDict[Model, WhisperModel] = OrderedDict()
+
+
+def load_model(model_name: Model) -> WhisperModel:
+    if model_name in models:
+        logger.debug(f"{model_name} model already loaded")
+        return models[model_name]
+    if len(models) >= config.max_models:
+        oldest_model_name = next(iter(models))
+        logger.info(
+            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+        )
+        del models[oldest_model_name]
+    logger.debug(f"Loading {model_name}")
+    start = time.perf_counter()
+    whisper = WhisperModel(
+        model_name,
+        device=config.whisper.inference_device,
+        compute_type=config.whisper.compute_type,
+    )
+    logger.info(
+        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds"
+    )
+    models[model_name] = whisper
+    return whisper
 
 
 @asynccontextmanager
 async def lifespan(_: FastAPI):
-    global whisper
-    logging.debug(f"Loading {config.whisper.model}")
-    start = time.perf_counter()
-    whisper = WhisperModel(
-        config.whisper.model,
-        device=config.whisper.inference_device,
-        compute_type=config.whisper.compute_type,
-    )
-    logger.debug(
-        f"Loaded {config.whisper.model} loaded in {time.perf_counter() - start:.2f} seconds"
-    )
+    load_model(config.whisper.model)
     yield
+    for model in list(models.keys()):
+        logger.info(f"Unloading {model}")
+        del models[model]
 
 
 app = FastAPI(lifespan=lifespan)
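
The `load_model` helper above caches loaded models in an `OrderedDict`, capped at `max_models`; once the cap is reached, the model that was loaded earliest is unloaded. A standalone sketch of that caching behaviour, with plain strings standing in for `WhisperModel` instances and an illustrative cap of 2:

```python
from typing import OrderedDict

MAX_MODELS = 2
models: OrderedDict[str, str] = OrderedDict()


def load_model(model_name: str) -> str:
    if model_name in models:
        return models[model_name]
    if len(models) >= MAX_MODELS:
        # Evict the model that was inserted first. Note this is FIFO,
        # not true LRU: a cache hit does not move an entry to the end.
        oldest_model_name = next(iter(models))
        del models[oldest_model_name]
    models[model_name] = f"loaded:{model_name}"
    return models[model_name]


load_model("tiny.en")
load_model("base.en")
load_model("tiny.en")    # cache hit, insertion order unchanged
load_model("medium.en")  # evicts "tiny.en", the oldest entry
assert list(models) == ["base.en", "medium.en"]
```
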
@@ -53,7 +68,7 @@
 
 
 @app.post("/v1/audio/translations")
-async def translate_file(
+def translate_file(
     file: Annotated[UploadFile, Form()],
     model: Annotated[Model, Form()] = config.whisper.model,
     prompt: Annotated[str | None, Form()] = None,
@@ -61,11 +76,8 @@
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ):
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     start = time.perf_counter()
+    whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task="translate",
@@ -107,7 +119,7 @@
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @app.post("/v1/audio/transcriptions")
-async def transcribe_file(
+def transcribe_file(
     file: Annotated[UploadFile, Form()],
     model: Annotated[Model, Form()] = config.whisper.model,
     language: Annotated[Language | None, Form()] = config.default_language,
@@ -120,11 +132,8 @@
     ] = ["segments"],
     stream: Annotated[bool, Form()] = False,
 ):
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     start = time.perf_counter()
+    whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task="transcribe",
@@ -209,21 +218,6 @@
     audio_stream.close()
 
 
-def format_transcription(
-    transcription: Transcription, response_format: ResponseFormat
-) -> str:
-    if response_format == ResponseFormat.TEXT:
-        return transcription.text
-    elif response_format == ResponseFormat.JSON:
-        return TranscriptionJsonResponse.from_transcription(
-            transcription
-        ).model_dump_json()
-    elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_transcription(
-            transcription
-        ).model_dump_json()
-
-
 @app.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
     ws: WebSocket,
@@ -234,18 +228,7 @@
         ResponseFormat, Query()
     ] = config.default_response_format,
     temperature: Annotated[float, Query()] = 0.0,
-    timestamp_granularities: Annotated[
-        list[Literal["segments"] | Literal["words"]],
-        Query(
-            alias="timestamp_granularities[]",
-            description="No-op. Ignored. Only for compatibility.",
-        ),
-    ] = ["segments", "words"],
 ) -> None:
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     await ws.accept()
     transcribe_opts = {
         "language": language,
@@ -254,6 +237,7 @@
         "vad_filter": True,
         "condition_on_previous_text": False,
     }
+    whisper = load_model(model)
     asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
@@ -262,7 +246,21 @@
             logger.debug(f"Sending transcription: {transcription.text}")
             if ws.client_state == WebSocketState.DISCONNECTED:
                 break
-            await ws.send_text(format_transcription(transcription, response_format))
+
+            if response_format == ResponseFormat.TEXT:
+                await ws.send_text(transcription.text)
+            elif response_format == ResponseFormat.JSON:
+                await ws.send_json(
+                    TranscriptionJsonResponse.from_transcription(
+                        transcription
+                    ).model_dump()
+                )
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                await ws.send_json(
+                    TranscriptionVerboseJsonResponse.from_transcription(
+                        transcription
+                    ).model_dump()
+                )
 
     if not ws.client_state == WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")