Fedir Zadniprovskyi 2024-10-01
feat: model unloading
@cb7375c9ca40bf121ee40e294acadf25466c7d1f
pyproject.toml
--- pyproject.toml
+++ pyproject.toml
@@ -19,10 +19,10 @@
 client = [
     "keyboard>=0.13.5",
 ]
-# NOTE: when installing `dev` group, all other groups should also be installed
 dev = [
     "anyio>=4.4.0",
     "basedpyright>=1.18.0",
+    "pytest-antilru>=2.0.0",
     "pytest-asyncio>=0.24.0",
     "pytest-xdist>=3.6.1",
     "pytest>=8.3.3",
src/faster_whisper_server/config.py
--- src/faster_whisper_server/config.py
+++ src/faster_whisper_server/config.py
@@ -1,7 +1,6 @@
 import enum
-from typing import Self
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 SAMPLES_PER_SECOND = 16000
@@ -163,6 +162,12 @@
     compute_type: Quantization = Field(default=Quantization.DEFAULT)
     cpu_threads: int = 0
     num_workers: int = 1
+    ttl: int = Field(default=300, ge=-1)
+    """
+    Time in seconds until the model is unloaded if it is not being used.
+    -1: Never unload the model.
+    0: Unload the model immediately after usage.
+    """
 
 
 class Config(BaseSettings):
@@ -198,10 +203,6 @@
     """
     default_response_format: ResponseFormat = ResponseFormat.JSON
     whisper: WhisperConfig = WhisperConfig()
-    max_models: int = 1
-    """
-    Maximum number of models that can be loaded at a time.
-    """
     preload_models: list[str] = Field(
         default_factory=list,
         examples=[
@@ -210,8 +211,8 @@
         ],
     )
     """
-    List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on first request.
-    """  # noqa: E501
+    List of models to preload on startup. By default, the model is first loaded on first request.
+    """
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finalized and the connection is closed.
@@ -230,11 +231,3 @@
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-
-    @model_validator(mode="after")
-    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
-        if len(self.preload_models) > self.max_models:
-            raise ValueError(
-                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
-            )
-        return self
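The new `ttl` field replaces `max_models` (and its validator): instead of evicting the oldest model once a cap is hit, each model now unloads itself after sitting idle for `ttl` seconds. A minimal sketch of the setting, with an illustrative value:

    from faster_whisper_server.config import WhisperConfig

    config = WhisperConfig(ttl=600)  # keep a model resident for 10 idle minutes (illustrative value)
    # ttl=-1 never unloads; ttl=0 unloads as soon as the last request using the model finishes.
    # The tests below drive the same knob through the WHISPER__TTL environment variable.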
src/faster_whisper_server/dependencies.py
--- src/faster_whisper_server/dependencies.py
+++ src/faster_whisper_server/dependencies.py
@@ -18,7 +18,7 @@
 @lru_cache
 def get_model_manager() -> ModelManager:
     config = get_config()  # HACK
-    return ModelManager(config)
+    return ModelManager(config.whisper)
 
 
 ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
src/faster_whisper_server/model_manager.py
--- src/faster_whisper_server/model_manager.py
+++ src/faster_whisper_server/model_manager.py
@@ -3,48 +3,132 @@
 from collections import OrderedDict
 import gc
 import logging
+import threading
 import time
 from typing import TYPE_CHECKING
 
 from faster_whisper import WhisperModel
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from faster_whisper_server.config import (
-        Config,
+        WhisperConfig,
     )
 
 logger = logging.getLogger(__name__)
 
+# TODO: enable concurrent model downloads
+
+
+class SelfDisposingWhisperModel:
+    def __init__(
+        self,
+        model_id: str,
+        whisper_config: WhisperConfig,
+        *,
+        on_unload: Callable[[str], None] | None = None,
+    ) -> None:
+        self.model_id = model_id
+        self.whisper_config = whisper_config
+        self.on_unload = on_unload
+
+        self.ref_count: int = 0
+        self.rlock = threading.RLock()
+        self.expire_timer: threading.Timer | None = None
+        self.whisper: WhisperModel | None = None
+
+    def unload(self) -> None:
+        with self.rlock:
+            if self.whisper is None:
+                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
+            if self.ref_count > 0:
+                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
+            if self.expire_timer:
+                self.expire_timer.cancel()
+            self.whisper = None
+            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
+            gc.collect()
+            logger.info(f"Model {self.model_id} unloaded")
+            if self.on_unload is not None:
+                self.on_unload(self.model_id)
+
+    def _load(self) -> None:
+        with self.rlock:
+            assert self.whisper is None
+            logger.debug(f"Loading model {self.model_id}")
+            start = time.perf_counter()
+            self.whisper = WhisperModel(
+                self.model_id,
+                device=self.whisper_config.inference_device,
+                device_index=self.whisper_config.device_index,
+                compute_type=self.whisper_config.compute_type,
+                cpu_threads=self.whisper_config.cpu_threads,
+                num_workers=self.whisper_config.num_workers,
+            )
+            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
+
+    def _increment_ref(self) -> None:
+        with self.rlock:
+            self.ref_count += 1
+            if self.expire_timer:
+                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
+                self.expire_timer.cancel()
+            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")
+
+    def _decrement_ref(self) -> None:
+        with self.rlock:
+            self.ref_count -= 1
+            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
+            if self.ref_count <= 0:
+                if self.whisper_config.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
+                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                    self.expire_timer.start()
+                elif self.whisper_config.ttl == 0:
+                    logger.info(f"Model {self.model_id} is idle, unloading immediately")
+                    self.unload()
+                else:
+                    logger.info(f"Model {self.model_id} is idle, not unloading")
+
+    def __enter__(self) -> WhisperModel:
+        with self.rlock:
+            if self.whisper is None:
+                self._load()
+            self._increment_ref()
+            assert self.whisper is not None
+            return self.whisper
+
+    def __exit__(self, *_args) -> None:  # noqa: ANN002
+        self._decrement_ref()
+
 
 class ModelManager:
-    def __init__(self, config: Config) -> None:
-        self.config = config
-        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+    def __init__(self, whisper_config: WhisperConfig) -> None:
+        self.whisper_config = whisper_config
+        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self._lock = threading.RLock()  # re-entrant: unload_model() re-acquires it via on_unload -> _handle_model_unload
 
-    def load_model(self, model_name: str) -> WhisperModel:
-        if model_name in self.loaded_models:
-            logger.debug(f"{model_name} model already loaded")
-            return self.loaded_models[model_name]
-        if len(self.loaded_models) >= self.config.max_models:
-            oldest_model_name = next(iter(self.loaded_models))
-            logger.info(
-                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            self.loaded_models[model_name].unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+        with self._lock:
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+                model_name,
+                self.whisper_config,
+                on_unload=self._handle_model_unload,
             )
-            del self.loaded_models[oldest_model_name]
-            gc.collect()
-        logger.debug(f"Loading {model_name}...")
-        start = time.perf_counter()
-        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-        whisper = WhisperModel(
-            model_name,
-            device=self.config.whisper.inference_device,
-            device_index=self.config.whisper.device_index,
-            compute_type=self.config.whisper.compute_type,
-            cpu_threads=self.config.whisper.cpu_threads,
-            num_workers=self.config.whisper.num_workers,
-        )
-        logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
-        )
-        self.loaded_models[model_name] = whisper
-        return whisper
+            return self.loaded_models[model_name]
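`ModelManager.load_model` now returns a `SelfDisposingWhisperModel` used as a context manager: entering it loads the underlying `WhisperModel` if needed and bumps a reference count, exiting decrements the count and either arms the TTL timer or unloads immediately. A rough usage sketch mirroring the routers (model name and audio path are placeholders):

    from faster_whisper_server.config import WhisperConfig
    from faster_whisper_server.model_manager import ModelManager

    manager = ModelManager(WhisperConfig(ttl=0))  # ttl=0: unload as soon as the last user exits

    with manager.load_model("Systran/faster-whisper-tiny.en") as whisper:
        # inside the block `whisper` is a plain faster_whisper.WhisperModel
        segments, info = whisper.transcribe("audio.wav")
        print([segment.text for segment in segments])
    # the ref count is back to zero here; with ttl=0 the model is unloaded immediately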
src/faster_whisper_server/routers/misc.py
--- src/faster_whisper_server/routers/misc.py
+++ src/faster_whisper_server/routers/misc.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-import gc
-
 from fastapi import (
     APIRouter,
     Response,
@@ -42,15 +40,19 @@
 def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    model_manager.load_model(model_name)
+    with model_manager.load_model(model_name):
+        pass
     return Response(status_code=201)
 
 
 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
-    model = model_manager.loaded_models.get(model_name)
-    if model is not None:
-        del model_manager.loaded_models[model_name]
-        gc.collect()
+    try:
+        model_manager.unload_model(model_name)
         return Response(status_code=204)
-    return Response(status_code=404)
+    except (KeyError, ValueError) as e:
+        match e:
+            case KeyError():
+                return Response(status_code=404, content="Model not found")
+            case ValueError():
+                return Response(status_code=409, content=str(e))
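The experimental /api/ps routes now map manager errors onto HTTP statuses instead of poking `loaded_models` directly. A hedged sketch of the expected responses against a running server (base URL is an assumption):

    import httpx

    client = httpx.Client(base_url="http://localhost:8000")
    model = "Systran/faster-whisper-tiny.en"

    client.post(f"/api/ps/{model}")    # 201 on load, 409 if the model is already loaded
    client.delete(f"/api/ps/{model}")  # 204 on unload, 404 if never loaded, 409 if still in use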
src/faster_whisper_server/routers/stt.py
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
@@ -142,20 +142,20 @@
         model = config.whisper.model
     if response_format is None:
         response_format = config.default_response_format
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSLATE,
-        initial_prompt=prompt,
-        temperature=temperature,
-        vad_filter=vad_filter,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSLATE,
+            initial_prompt=prompt,
+            temperature=temperature,
+            vad_filter=vad_filter,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
 
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 # HACK: Since Form() doesn't support `alias`, we need to use a workaround.
@@ -206,23 +206,23 @@
         logger.warning(
             "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
         )
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSCRIBE,
-        language=language,
-        initial_prompt=prompt,
-        word_timestamps="word" in timestamp_granularities,
-        temperature=temperature,
-        vad_filter=vad_filter,
-        hotwords=hotwords,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSCRIBE,
+            language=language,
+            initial_prompt=prompt,
+            word_timestamps="word" in timestamp_granularities,
+            temperature=temperature,
+            vad_filter=vad_filter,
+            hotwords=hotwords,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
 
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -280,24 +280,24 @@
         "vad_filter": vad_filter,
         "condition_on_previous_text": False,
     }
-    whisper = model_manager.load_model(model)
-    asr = FasterWhisperASR(whisper, **transcribe_opts)
-    audio_stream = AudioStream()
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
-            logger.debug(f"Sending transcription: {transcription.text}")
-            if ws.client_state == WebSocketState.DISCONNECTED:
-                break
+    with model_manager.load_model(model) as whisper:
+        asr = FasterWhisperASR(whisper, **transcribe_opts)
+        audio_stream = AudioStream()
+        async with asyncio.TaskGroup() as tg:
+            tg.create_task(audio_receiver(ws, audio_stream))
+            async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
+                logger.debug(f"Sending transcription: {transcription.text}")
+                if ws.client_state == WebSocketState.DISCONNECTED:
+                    break
 
-            if response_format == ResponseFormat.TEXT:
-                await ws.send_text(transcription.text)
-            elif response_format == ResponseFormat.JSON:
-                await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(
-                    CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
-                )
+                if response_format == ResponseFormat.TEXT:
+                    await ws.send_text(transcription.text)
+                elif response_format == ResponseFormat.JSON:
+                    await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    await ws.send_json(
+                        CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
+                    )
 
     if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
 
tests/model_manager_test.py (added)
+++ tests/model_manager_test.py
@@ -0,0 +1,122 @@
+import asyncio
+import os
+
+import anyio
+from httpx import ASGITransport, AsyncClient
+import pytest
+
+from faster_whisper_server.main import create_app
+
+
+@pytest.mark.asyncio
+async def test_model_unloaded_after_ttl() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl + 1)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_ttl_resets_after_usage() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        await asyncio.sleep(3)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+        # test the model can be used again after being unloaded
+        # this just ensures the model can be loaded again after being unloaded
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_unloaded_when_used() -> None:
+    ttl = 0
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+
+        task = asyncio.create_task(
+            aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        )
+        await asyncio.sleep(0.01)
+        res = await aclient.delete(f"/api/ps/{model}")
+        assert res.status_code == 409
+
+        await task
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_loaded_twice() -> None:
+    ttl = -1
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["ENABLE_UI"] = "false"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 201
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 409
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+
+@pytest.mark.asyncio
+async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
+    ttl = 0
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = await aclient.post(
+            "/v1/audio/transcriptions",
+            files={"file": ("audio.wav", data, "audio/wav")},
+            data={"model": "Systran/faster-whisper-tiny.en"},
+        )
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
uv.lock
--- uv.lock
+++ uv.lock
@@ -293,6 +293,7 @@
     { name = "anyio" },
     { name = "basedpyright" },
     { name = "pytest" },
+    { name = "pytest-antilru" },
     { name = "pytest-asyncio" },
     { name = "pytest-xdist" },
     { name = "ruff" },
@@ -322,6 +323,7 @@
     { name = "pydantic", specifier = ">=2.9.0" },
     { name = "pydantic-settings", specifier = ">=2.5.2" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" },
+    { name = "pytest-antilru", marker = "extra == 'dev'", specifier = ">=2.0.0" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
     { name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.6.1" },
     { name = "python-multipart", specifier = ">=0.0.10" },
@@ -3483,6 +3485,18 @@
 ]
 
 [[package]]
+name = "pytest-antilru"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c6/01/0b5ef3f143f335b5cb1c1e8e6497769dfb48aed5a791b5dfd119151e2b15/pytest_antilru-2.0.0.tar.gz", hash = "sha256:48cff342648b6a1ce4e5398cf203966905d546b3f2bee7bb55d7cb3ec87a85fb", size = 5569 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/23/f0/fc9f5aaaf2818a7d7f795e99fcf59719dd6ec5f98005e642e1efd63ad2a4/pytest_antilru-2.0.0-py3-none-any.whl", hash = "sha256:cf1d97db0e7b17ef568c1f0bf4c89b8748053fe07546f4eb2558bebf64c1ad33", size = 6301 },
+]
+
+[[package]]
 name = "pytest-asyncio"
 version = "0.24.0"
 source = { registry = "https://pypi.org/simple" }