

tests: proper `get_config` dependency override
@32a4072fc789fa66c3601e1074bf6c1f9ad29b5f
--- src/faster_whisper_server/dependencies.py
+++ src/faster_whisper_server/dependencies.py
@@ -1,4 +1,5 @@
 from functools import lru_cache
+import logging
 from typing import Annotated
 
 from fastapi import Depends, HTTPException, status
@@ -11,7 +12,13 @@
 from faster_whisper_server.config import Config
 from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
 
+logger = logging.getLogger(__name__)
 
+# NOTE: `get_config` is called directly instead of using sub-dependencies so that these functions could be used outside of `FastAPI`  # noqa: E501
+
+
+# https://fastapi.tiangolo.com/advanced/settings/?h=setti#creating-the-settings-only-once-with-lru_cache
+# WARN: Any new module that ends up calling this function directly (not through `FastAPI` dependency injection) should be patched in `tests/conftest.py`  # noqa: E501
 @lru_cache
 def get_config() -> Config:
     return Config()
@@ -22,7 +29,7 @@
 
 @lru_cache
 def get_model_manager() -> WhisperModelManager:
-    config = get_config()  # HACK
+    config = get_config()
     return WhisperModelManager(config.whisper)
 
 
@@ -31,8 +38,8 @@
 
 @lru_cache
 def get_piper_model_manager() -> PiperModelManager:
-    config = get_config()  # HACK
-    return PiperModelManager(config.whisper.ttl)  # HACK
+    config = get_config()
+    return PiperModelManager(config.whisper.ttl)  # HACK: should have its own config
 
 
 PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
@@ -53,7 +60,7 @@
 
 @lru_cache
 def get_completion_client() -> AsyncCompletions:
-    config = get_config()  # HACK
+    config = get_config()
     oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
     return oai_client.chat.completions
 
@@ -63,9 +70,9 @@
 
 @lru_cache
 def get_speech_client() -> AsyncSpeech:
-    config = get_config()  # HACK
+    config = get_config()
     if config.speech_base_url is None:
-        # this might not work as expected if the `speech_router` won't have shared state with the main FastAPI `app`. TODO: verify  # noqa: E501
+        # this might not work as expected if `speech_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify  # noqa: E501
         from faster_whisper_server.routers.speech import (
             router as speech_router,
         )
@@ -86,7 +93,7 @@
 def get_transcription_client() -> AsyncTranscriptions:
     config = get_config()
     if config.transcription_base_url is None:
-        # this might not work as expected if the `transcription_router` won't have shared state with the main FastAPI `app`. TODO: verify  # noqa: E501
+        # this might not work as expected if `transcription_router` won't have shared state (access to the same `model_manager`) with the main FastAPI `app`. TODO: verify  # noqa: E501
         from faster_whisper_server.routers.stt import (
             router as stt_router,
         )
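
Note on the pattern above: `get_config` stays wrapped in `lru_cache` (per the linked FastAPI settings docs), so every direct call returns the same process-wide `Config`. Direct calls bypass FastAPI's dependency injection, which is what the WARN comment is about: overriding the dependency on the app has no effect on them, so they have to be patched in `tests/conftest.py`. A minimal sketch of the two call paths, using a hypothetical stand-in rather than the real `Config`:

from functools import lru_cache


@lru_cache
def get_config() -> dict:  # stand-in for the real `Config`-returning dependency
    return {"log_level": "info"}


def some_direct_caller() -> dict:
    # A plain function call, as `main.create_app` and `get_model_manager` do.
    # `app.dependency_overrides[get_config] = ...` does not touch this path;
    # in tests it must be handled with `mocker.patch("<module>.get_config", ...)`.
    return get_config()


# Handlers that receive the config through `Depends(get_config)` go through
# FastAPI's injection and *are* covered by `app.dependency_overrides`.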
--- src/faster_whisper_server/logger.py
+++ src/faster_whisper_server/logger.py
@@ -1,11 +1,8 @@
 import logging
 
-from faster_whisper_server.dependencies import get_config
 
-
-def setup_logger() -> None:
-    config = get_config()  # HACK
+def setup_logger(log_level: str) -> None:
     logging.getLogger().setLevel(logging.INFO)
     logger = logging.getLogger(__name__)
-    logger.setLevel(config.log_level.upper())
+    logger.setLevel(log_level.upper())
     logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(lineno)d:%(message)s")
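
With the `get_config` import gone, `setup_logger` is a plain function of the level string and no longer participates in the dependency graph. A minimal usage sketch, assuming the module path stays as in this diff (the "debug" value is illustrative):

from faster_whisper_server.logger import setup_logger

# Equivalent to the old behaviour, where the level came from `config.log_level`.
setup_logger("debug")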
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -27,9 +27,11 @@
 
 
 def create_app() -> FastAPI:
-    setup_logger()
-
+    config = get_config()  # HACK
+    setup_logger(config.log_level)
     logger = logging.getLogger(__name__)
+
+    logger.debug(f"Config: {config}")
 
     if platform.machine() == "x86_64":
         from faster_whisper_server.routers.speech import (
@@ -38,9 +40,6 @@
     else:
         logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
         speech_router = None
-
-    config = get_config()  # HACK
-    logger.debug(f"Config: {config}")
 
     model_manager = get_model_manager()  # HACK
 
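
Hoisting the `get_config()` call (still marked `# HACK`) to the top of `create_app` does not create a second configuration object: because the function is wrapped in `functools.lru_cache` and takes no arguments, every direct caller in the process sees the same cached instance. A small sketch of the behaviour this relies on; the assertion is illustrative:

from faster_whisper_server.dependencies import get_config

config_a = get_config()
config_b = get_config()  # served from the lru_cache, the environment is not re-parsed
assert config_a is config_b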
--- tests/conftest.py
+++ tests/conftest.py
@@ -1,6 +1,8 @@
 from collections.abc import AsyncGenerator, Generator
+from contextlib import AbstractAsyncContextManager, asynccontextmanager
 import logging
 import os
+from typing import Protocol
 
 from fastapi.testclient import TestClient
 from httpx import ASGITransport, AsyncClient
@@ -8,19 +10,31 @@
 from openai import AsyncOpenAI
 import pytest
 import pytest_asyncio
+from pytest_mock import MockerFixture
 
+from faster_whisper_server.config import Config, WhisperConfig
+from faster_whisper_server.dependencies import get_config
 from faster_whisper_server.main import create_app
 
-disable_loggers = ["multipart.multipart", "faster_whisper"]
+DISABLE_LOGGERS = ["multipart.multipart", "faster_whisper"]
+OPENAI_BASE_URL = "https://api.openai.com/v1"
+DEFAULT_WHISPER_MODEL = "Systran/faster-whisper-tiny.en"
+# TODO: figure out a way to initialize the config without parsing environment variables, as those may interfere with the tests  # noqa: E501
+DEFAULT_WHISPER_CONFIG = WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=0)
+DEFAULT_CONFIG = Config(
+    whisper=DEFAULT_WHISPER_CONFIG,
+    # disable the UI as it slightly increases the app startup time due to the imports it's doing
+    enable_ui=False,
+)
 
 
 def pytest_configure() -> None:
-    for logger_name in disable_loggers:
+    for logger_name in DISABLE_LOGGERS:
         logger = logging.getLogger(logger_name)
         logger.disabled = True
 
 
-# NOTE: not being used. Keeping just in case
+# NOTE: not being used. Keeping just in case. Needs to be modified to work similarly to `aclient_factory`
 @pytest.fixture
 def client() -> Generator[TestClient, None, None]:
     os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
@@ -28,10 +42,37 @@
         yield client
 
 
+# https://stackoverflow.com/questions/74890214/type-hint-callback-function-with-optional-parameters-aka-callable-with-optional
+class AclientFactory(Protocol):
+    def __call__(self, config: Config = DEFAULT_CONFIG) -> AbstractAsyncContextManager[AsyncClient]: ...
+
+
 @pytest_asyncio.fixture()
-async def aclient() -> AsyncGenerator[AsyncClient, None]:
-    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+async def aclient_factory(mocker: MockerFixture) -> AclientFactory:
+    """Returns a context manager that provides an `AsyncClient` instance with `app` using the provided configuration."""
+
+    @asynccontextmanager
+    async def inner(config: Config = DEFAULT_CONFIG) -> AsyncGenerator[AsyncClient, None]:
+        # NOTE: all calls to `get_config` should be patched. One way to test that this works is to update the original `get_config` to raise an exception and see if the tests fail  # noqa: E501
+        mocker.patch("faster_whisper_server.dependencies.get_config", return_value=config)
+        mocker.patch("faster_whisper_server.main.get_config", return_value=config)
+        # NOTE: I couldn't get the following to work but it shouldn't matter
+        # mocker.patch(
+        #     "faster_whisper_server.text_utils.Transcription._ensure_no_word_overlap.get_config", return_value=config
+        # )
+
+        app = create_app()
+        # https://fastapi.tiangolo.com/advanced/testing-dependencies/
+        app.dependency_overrides[get_config] = lambda: config
+        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
+            yield aclient
+
+    return inner
+
+
+@pytest_asyncio.fixture()
+async def aclient(aclient_factory: AclientFactory) -> AsyncGenerator[AsyncClient, None]:
+    async with aclient_factory() as aclient:
         yield aclient
 
 
@@ -43,11 +84,13 @@
 @pytest.fixture
 def actual_openai_client() -> AsyncOpenAI:
     return AsyncOpenAI(
-        base_url="https://api.openai.com/v1"
-    )  # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
+        # `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value
+        base_url=OPENAI_BASE_URL
+    )
 
 
 # TODO: remove the download after running the tests
+# TODO: do not download when not needed
 @pytest.fixture(scope="session", autouse=True)
 def download_piper_voices() -> None:
     # Only download `voices.json` and the default voice
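
The `aclient_factory` fixture covers both configuration paths at once: `mocker.patch` replaces the module attributes that call `get_config` directly (outside dependency injection), while `app.dependency_overrides[get_config]` covers routes that declare the dependency via `Depends`. A usage sketch in a test module, mirroring `tests/model_manager_test.py`; the `ttl` value and the final assertion are illustrative:

import pytest

from faster_whisper_server.config import Config, WhisperConfig
from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory


@pytest.mark.asyncio
async def test_with_custom_config(aclient_factory: AclientFactory) -> None:
    # Build a config in-process instead of mutating os.environ before create_app().
    config = Config(whisper=WhisperConfig(model=DEFAULT_WHISPER_MODEL, ttl=5), enable_ui=False)
    async with aclient_factory(config) as aclient:
        res = (await aclient.get("/api/ps")).json()
        assert len(res["models"]) == 0  # nothing loaded yet

The `AclientFactory` Protocol is used because a plain `Callable[...]` annotation cannot express a keyword argument with a default value, per the linked Stack Overflow answer.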
--- tests/model_manager_test.py
+++ tests/model_manager_test.py
@@ -1,23 +1,22 @@
 import asyncio
-import os
 
 import anyio
-from httpx import ASGITransport, AsyncClient
 import pytest
 
-from faster_whisper_server.main import create_app
+from faster_whisper_server.config import Config, WhisperConfig
+from tests.conftest import DEFAULT_WHISPER_MODEL, AclientFactory
+
+MODEL = DEFAULT_WHISPER_MODEL  # just to make the test more readable
 
 
 @pytest.mark.asyncio
-async def test_model_unloaded_after_ttl() -> None:
+async def test_model_unloaded_after_ttl(aclient_factory: AclientFactory) -> None:
     ttl = 5
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 0
-        await aclient.post(f"/api/ps/{model}")
+        await aclient.post(f"/api/ps/{MODEL}")
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
         await asyncio.sleep(ttl + 1)  # wait for the model to be unloaded
@@ -26,13 +25,11 @@
 
 
 @pytest.mark.asyncio
-async def test_ttl_resets_after_usage() -> None:
+async def test_ttl_resets_after_usage(aclient_factory: AclientFactory) -> None:
     ttl = 5
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
-        await aclient.post(f"/api/ps/{model}")
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
+        await aclient.post(f"/api/ps/{MODEL}")
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
         await asyncio.sleep(ttl - 2)  # sleep for less than the ttl. The model should not be unloaded
@@ -43,7 +40,9 @@
             data = await f.read()
         res = (
             await aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions",
+                files={"file": ("audio.wav", data, "audio/wav")},
+                data={"model": MODEL},
             )
         ).json()
         res = (await aclient.get("/api/ps")).json()
@@ -60,28 +59,28 @@
         # this just ensures the model can be loaded again after being unloaded
         res = (
             await aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions",
+                files={"file": ("audio.wav", data, "audio/wav")},
+                data={"model": MODEL},
             )
         ).json()
 
 
 @pytest.mark.asyncio
-async def test_model_cant_be_unloaded_when_used() -> None:
+async def test_model_cant_be_unloaded_when_used(aclient_factory: AclientFactory) -> None:
     ttl = 0
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
         async with await anyio.open_file("audio.wav", "rb") as f:
             data = await f.read()
 
         task = asyncio.create_task(
             aclient.post(
-                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": MODEL}
            )
        )
        await asyncio.sleep(0.1)  # wait for the server to start processing the request
-        res = await aclient.delete(f"/api/ps/{model}")
+        res = await aclient.delete(f"/api/ps/{MODEL}")
         assert res.status_code == 409
 
         await task
@@ -90,27 +89,23 @@
 
 
 @pytest.mark.asyncio
-async def test_model_cant_be_loaded_twice() -> None:
+async def test_model_cant_be_loaded_twice(aclient_factory: AclientFactory) -> None:
     ttl = -1
-    model = "Systran/faster-whisper-tiny.en"
-    os.environ["ENABLE_UI"] = "false"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
-        res = await aclient.post(f"/api/ps/{model}")
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
+        res = await aclient.post(f"/api/ps/{MODEL}")
         assert res.status_code == 201
-        res = await aclient.post(f"/api/ps/{model}")
+        res = await aclient.post(f"/api/ps/{MODEL}")
         assert res.status_code == 409
         res = (await aclient.get("/api/ps")).json()
         assert len(res["models"]) == 1
 
 
 @pytest.mark.asyncio
-async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
+async def test_model_is_unloaded_after_request_when_ttl_is_zero(aclient_factory: AclientFactory) -> None:
     ttl = 0
-    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
-    os.environ["WHISPER__TTL"] = str(ttl)
-    os.environ["ENABLE_UI"] = "false"
-    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+    config = Config(whisper=WhisperConfig(model=MODEL, ttl=ttl), enable_ui=False)
+    async with aclient_factory(config) as aclient:
         async with await anyio.open_file("audio.wav", "rb") as f:
             data = await f.read()
         res = await aclient.post(