Fedir Zadniprovskyi 01-03
tests: proper `get_config` dependency override
@32a4072fc789fa66c3601e1074bf6c1f9ad29b5f
src/faster_whisper_server/dependencies.py
--- src/faster_whisper_server/dependencies.py
+++ src/faster_whisper_server/dependencies.py
@@ -1,4 +1,5 @@
 from functools import lru_cache
+import logging
 from typing import Annotated
 
 from fastapi import Depends, HTTPException, status
@@ -11,7 +12,13 @@
 from faster_whisper_server.config import Config
 from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
 
+logger = logging.getLogger(__name__)
 
+# NOTE: `get_config` is called directly instead of using sub-dependencies so that these functions could be used outside of `FastAPI`  # noqa: E501
+
+
+# https://fastapi.tiangolo.com/advanced/settings/?h=setti#creating-the-settings-only-once-with-lru_cache
+# WARN: Any new module that ends up calling this function directly (not through `FastAPI` dependency injection) should be patched in `tests/conftest.py`  # noqa: E501
 @lru_cache
 def get_config() -> Config:
     return Config()
@@ -22,7 +29,7 @@
 
 @lru_cache
 def get_model_manager() -> WhisperModelManager:
-    config = get_config()  # HACK
+    config = get_config()
     return WhisperModelManager(config.whisper)
 
 
@@ -31,8 +38,8 @@
 
 @lru_cache
 def get_piper_model_manager() -> PiperModelManager:
-    config = get_config()  # HACK
-    return PiperModelManager(config.whisper.ttl)  # HACK
+    config = get_config()
+    return PiperModelManager(config.whisper.ttl)  # HACK: should have its own config
 
 
 PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
@@ -53,7 +60,7 @@
 
 @lru_cache
 def get_completion_client() -> AsyncCompletions:
-    config = get_config()  # HACK
+    config = get_config()
     oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
     return oai_client.chat.completions
 
@@ -63,9 +70,9 @@
 
 @lru_cache
 def get_speech_client() -> AsyncSpeech:
-    config = get_config()  # HACK
+    config = get_config()
     if config.speech_base_url is None:
-        # this might not work as expected if the `speech_router` won't have shared state with the main FastAPI `app`. TODO: verify  # noqa: E501
+        # this might not work as expected if `speech_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify  # noqa: E501
         from faster_whisper_server.routers.speech import (
             router as speech_router,
         )
@@ -86,7 +93,7 @@
 def get_transcription_client() -> AsyncTranscriptions:
     config = get_config()
     if config.transcription_base_url is None:
-        # this might not work as expected if the `transcription_router` won't have shared state with the main FastAPI `app`. TODO: verify  # noqa: E501
+        # this might not work as expected if `transcription_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify  # noqa: E501
         from faster_whisper_server.routers.stt import (
             router as stt_router,
         )
src/faster_whisper_server/logger.py
--- src/faster_whisper_server/logger.py
+++ src/faster_whisper_server/logger.py
@@ -1,11 +1,8 @@
 import logging
 
-from faster_whisper_server.dependencies import get_config
 
-
-def setup_logger() -> None:
-    config = get_config()  # HACK
+def setup_logger(log_level: str) -> None:
     logging.getLogger().setLevel(logging.INFO)
     logger = logging.getLogger(__name__)
-    logger.setLevel(config.log_level.upper())
+    logger.setLevel(log_level.upper())
     logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
src/faster_whisper_server/main.py
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -27,9 +27,11 @@
 
 
 def create_app() -> FastAPI:
-    setup_logger()
-
+    config = get_config()  # HACK
+    setup_logger(config.log_level)
     logger = logging.getLogger(__name__)
+
+    logger.debug(f"Config: {config}")
 
     if platform.machine() == "x86_64":
         from faster_whisper_server.routers.speech import (
@@ -38,9 +40,6 @@
     else:
         logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
         speech_router = None
-
-    config = get_config()  # HACK
-    logger.debug(f"Config: {config}")
 
     model_manager = get_model_manager()  # HACK
 
tests/conftest.py
--- tests/conftest.py
+++ tests/conftest.py
@@ -1,6 +1,8 @@
 from collections.abc import AsyncGenerator, Generator
+from contextlib import AbstractAsyncContextManager, asynccontextmanager
 import logging
 import os
+from typing import Protocol
 
 from fastapi.testclient import TestClient
 from httpx import ASGITransport, AsyncClient
@@ -8,19 +10,31 @@
 from openai import AsyncOpenAI
 import pytest
 import pytest_asyncio
+from pytest_mock import MockerFixture
 
+from faster_whisper_server.config import Config, WhisperConfig
+from faster_whisper_server.dependencies import get_config
 from faster_whisper_server.main import create_app
 
-disable_loggers = ["multipart.multipart", "faster_whisper"]
+DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
+OPENAI_BASE_URL = "https://api.openai.com/v1"
+DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en"
+# TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests  # noqa: E501
+DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0)
+DEFAULT_CONFIG = Config(
+    whisper=DEFAULT_WHISPER_CONFIG,
+    # disable the UI as it slightly increases the app startup time due to the imports it's doing
+    enable_ui=False,
+)
 
 
 def pytest_configure() -> None:
-    for logger_name in disable_loggers:
+    for logger_name in DISABLE_LOGGERS:
         logger = logging.getLogger(logger_name)
         logger.disabled = True
 
 
-# NOTE: not being used. Keeping just in case
+# NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory`
 @pytest.fixture
 def client() -> Generator[TestClient, None, None]:
     os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
@@ -28,10 +42,37 @@
         yield client
 
 
+# https://stackoverflow.com/questions/74890214/type-hint-callback-function-with-optional-parameters-aka-callable-with-optional
+class AclientFactory(Protocol):
+    def __call__(self, config: Config = DEFAULT_CONFIG) -> AbstractAsyncContextManager[AsyncClient]: ...
+
+
 @pytest_asyncio.fixture()
-async def aclient() -> AsyncGenerator[AsyncClient, None]:
-    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+async def aclient_factory(mocker: MockerFixture) -> AclientFactory:
+    """Returns a context manager that provides an `AsyncClient` instance with `app` using the provided configuration."""
+
+    @asynccontextmanager
+    async def inner(config: Config = DEFAULT_CONFIG) -> AsyncGenerator[AsyncClient, None]:
+        # NOTE: all calls to `get_config` should be patched. One way to test that this works is to update the original `get_config` to raise an exception and see if the tests fail  # noqa: E501
+        mocker.patch("faster_whisper_server.dependencies.get_config", return_value=config)
+        mocker.patch("faster_whisper_server.main.get_config", return_value=config)
+        # NOTE: I couldn't get the following to work but it shouldn't matter
+        # mocker.patch(
+        #     "faster_whisper_server.text_utils.Transcription._ensure_no_word_overlap.get_config", return_value=config
+        # )
+
+        app = create_app()
+        # https://fastapi.tiangolo.com/advanced/testing-dependencies/
+        app.dependency_overrides[get_config] = lambda: config
+        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
+            yield aclient
+
+    return inner
+
+
+@pytest_asyncio.fixture()
+async def aclient(aclient_factory: AclientFactory) -> AsyncGenerator[AsyncClient, None]:
+    async with aclient_factory() as aclient:
         yield aclient
 
 
@@ -43,11 +84,13 @@
 @pytest.fixture
 def actual_openai_client() -> AsyncOpenAI:
     return AsyncOpenAI(
-        base_url="https://api.openai.com/v1"
-    )  # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
+        # `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value
+        base_url=OPENAI_BASE_URL
+    )
 
 
 # TODO: remove the download after running the tests
+# TODO: do not download when not needed
 @pytest.fixture(scope="session", autouse=True)
 def download_piper_voices() -> None:
     # Only download `voices.json` and the default voice
tests/model_manager_test.py
--- tests/model_manager_test.py
+++ tests/model_manager_test.py
@@ -1,23 +1,22 @@
 import asyncio
-import os
 
 import anyio
-from httpx import ASGITransport, AsyncClient
 import pytest
 
-from faster_whisper_server.main import create_app
+from faster_whisper_server.config import Config, WhisperConfig
+from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
+
+MODEL = DEFAULT_WHISPER_MODEL  # just to make the test more readable
 
 
 @pytest.mark.asyncio
-async def test_model_unloaded_after_ttl() -> None:
+async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
     ttl = 5
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 0
-        await aclient.post(f"/api/ps/{model}")
+        await aclient.post(f"/api/ps/{MODEL}")
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
         await asyncio.sleep(ttl + 1)  # wait for the model to be unloaded
@@ -26,13 +25,11 @@
 
 
 @pytest.mark.asyncio
-async def test_ttl_resets_after_usage() -> None:
+async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
     ttl = 5
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
-        await aclient.post(f"/api/ps/{model}")
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
+        await aclient.post(f"/api/ps/{MODEL}")
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
         await asyncio.sleep(ttl - 2)  # sleep for less than the ttl. The model should not be unloaded
@@ -43,7 +40,9 @@
             data = await f.read()
         res = (
             await aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions",
+                files={"file": ("audio.wav", data, "audio/wav")},
+                data={"model": MODEL},
             )
         ).json()
         res = (await aclient.get("/api/ps")).json()
@@ -60,28 +59,28 @@
         # this just ensures the model can be loaded again after being unloaded
         res = (
             await aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions",
+                files={"file": ("audio.wav", data, "audio/wav")},
+                data={"model": MODEL},
             )
         ).json()
 
 
 @pytest.mark.asyncio
-async def test_model_cant_be_unloaded_when_used() -> None:
+async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
     ttl = 0
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
         async with await anyio.open_file("audio.wav", "rb") as f:
             data = await f.read()
 
         task = asyncio.create_task(
             aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL}
             )
         )
         await asyncio.sleep(0.1)  # wait for the server to start processing the request
-        res = await aclient.delete(f"/api/ps/{model}")
+        res = await aclient.delete(f"/api/ps/{MODEL}")
         assert res.status_code == 409
 
         await task
@@ -90,27 +89,23 @@
 
 
 @pytest.mark.asyncio
-async def test_model_cant_be_loaded_twice() -> None:
+async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
     ttl = -1
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["ENABLE_UI"] = "false"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
-        res = await aclient.post(f"/api/ps/{model}")
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
+        res = await aclient.post(f"/api/ps/{MODEL}")
         assert res.status_code == 201
-        res = await aclient.post(f"/api/ps/{model}")
+        res = await aclient.post(f"/api/ps/{MODEL}")
         assert res.status_code == 409
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
 
 
 @pytest.mark.asyncio
-async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
+async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
     ttl = 0
-    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
         async with await anyio.open_file("audio.wav", "rb") as f:
             data = await f.read()
         res = await aclient.post(
Add a comment
List