

feat: dependency injection
The main purpose of this change is to allow modifying the configuration for testing. It does lead to some ugly code where the `get_config` function gets called in random places (those call sites are marked with `# HACK`).
@7cfaf5b1f698ed9900f3c7349a3603ac73d16873
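
What this indirection buys: the app is now built by `create_app()` (uvicorn's `--factory` flag, see the Taskfile change below, tells it the import string names a callable returning the app rather than an app instance), so a test can construct its own app and substitute the config. A minimal sketch of what that enables, assuming `Config` accepts field overrides as keyword arguments (as pydantic-settings models do). Note the caveat: the direct `get_config()  # HACK` calls bypass `dependency_overrides`, so those code paths still read the cached, environment-derived config.

```python
# Sketch (not part of this commit): per-test config override via FastAPI's
# dependency_overrides. max_models=1 is an illustrative field override.
from fastapi.testclient import TestClient

from faster_whisper_server.config import Config
from faster_whisper_server.dependencies import get_config
from faster_whisper_server.main import create_app


def test_with_custom_config() -> None:
    get_config.cache_clear()  # get_config is lru_cache'd; drop any stale Config
    app = create_app()
    app.dependency_overrides[get_config] = lambda: Config(max_models=1)
    with TestClient(app) as client:
        assert client.get("/api/ps").status_code == 200
```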
--- Taskfile.yaml
+++ Taskfile.yaml
@@ -1,6 +1,6 @@
 version: "3"
 tasks:
-  server: uvicorn --host 0.0.0.0 faster_whisper_server.main:app {{.CLI_ARGS}}
+  server: uvicorn --factory --host 0.0.0.0 faster_whisper_server.main:create_app {{.CLI_ARGS}}
   test:
     cmds:
       - pytest -o log_cli=true -o log_cli_level=DEBUG {{.CLI_ARGS}}
--- pyproject.toml
+++ pyproject.toml
@@ -75,6 +75,7 @@
     "ISC001", # recommended to disable for formatting
     "INP001",
     "PT018",
+    "G004", # logging f string
 ]

 [tool.ruff.lint.isort]
--- src/faster_whisper_server/asr.py
+++ src/faster_whisper_server/asr.py
@@ -1,11 +1,13 @@
 import asyncio
+import logging
 import time

 from faster_whisper import transcribe

 from faster_whisper_server.audio import Audio
 from faster_whisper_server.core import Segment, Transcription, Word
-from faster_whisper_server.logger import logger
+
+logger = logging.getLogger(__name__)


 class FasterWhisperASR:
--- src/faster_whisper_server/audio.py
+++ src/faster_whisper_server/audio.py
@@ -1,13 +1,13 @@
 from __future__ import annotations

 import asyncio
+import logging
 from typing import TYPE_CHECKING, BinaryIO

 import numpy as np
 import soundfile as sf

 from faster_whisper_server.config import SAMPLES_PER_SECOND
-from faster_whisper_server.logger import logger

 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
@@ -15,6 +15,9 @@
     from numpy.typing import NDArray


+logger = logging.getLogger(__name__)
+
+
 def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
     audio_and_sample_rate = sf.read(
         file,
--- src/faster_whisper_server/config.py
+++ src/faster_whisper_server/config.py
@@ -238,6 +238,3 @@
             f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
         )
         return self
-
-
-config = Config()
--- src/faster_whisper_server/core.py
+++ src/faster_whisper_server/core.py
@@ -5,7 +5,7 @@

 from pydantic import BaseModel

-from faster_whisper_server.config import config
+from faster_whisper_server.dependencies import get_config

 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -113,6 +113,7 @@
         self.words.extend(words)

     def _ensure_no_word_overlap(self, words: list[Word]) -> None:
+        config = get_config()  # HACK
         if len(self.words) > 0 and len(words) > 0:
             if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
                 raise ValueError(
--- /dev/null
+++ src/faster_whisper_server/dependencies.py
@@ -0,0 +1,24 @@
+from functools import lru_cache
+from typing import Annotated
+
+from fastapi import Depends
+
+from faster_whisper_server.config import Config
+from faster_whisper_server.model_manager import ModelManager
+
+
+@lru_cache
+def get_config() -> Config:
+    return Config()
+
+
+ConfigDependency = Annotated[Config, Depends(get_config)]
+
+
+@lru_cache
+def get_model_manager() -> ModelManager:
+    config = get_config()  # HACK
+    return ModelManager(config)
+
+
+ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
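
For context, route handlers consume these aliases as ordinary parameters: FastAPI resolves `Depends(get_config)` at request time, and `@lru_cache` guarantees that handlers and the direct `get_config()` / `get_model_manager()` call sites all share the same instances. A hypothetical handler (not from this commit) illustrating the pattern:

```python
# Hypothetical endpoint showing how the Annotated dependency is consumed.
from fastapi import APIRouter

from faster_whisper_server.dependencies import ConfigDependency

router = APIRouter()


@router.get("/debug/whisper-model")
def whisper_model(config: ConfigDependency) -> str:
    # FastAPI injects the lru_cache'd Config; tests can swap it out
    # through app.dependency_overrides[get_config].
    return config.whisper.model
```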
--- src/faster_whisper_server/hf_utils.py
+++ src/faster_whisper_server/hf_utils.py
@@ -1,10 +1,11 @@
 from collections.abc import Generator
+import logging
 from pathlib import Path
 import typing

 import huggingface_hub

-from faster_whisper_server.logger import logger
+logger = logging.getLogger(__name__)

 LIBRARY_NAME = "ctranslate2"
 TASK_NAME = "automatic-speech-recognition"
--- src/faster_whisper_server/logger.py
+++ src/faster_whisper_server/logger.py
@@ -1,8 +1,11 @@
 import logging

-from faster_whisper_server.config import config
+from faster_whisper_server.dependencies import get_config

-logging.getLogger().setLevel(logging.INFO)
-logger = logging.getLogger(__name__)
-logger.setLevel(config.log_level.upper())
-logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
+
+def setup_logger() -> None:
+    config = get_config()  # HACK
+    logging.getLogger().setLevel(logging.INFO)
+    logger = logging.getLogger(__name__)
+    logger.setLevel(config.log_level.upper())
+    logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from contextlib import asynccontextmanager
+import logging
 from typing import TYPE_CHECKING

 from fastapi import (
@@ -8,11 +9,8 @@
 )
 from fastapi.middleware.cors import CORSMiddleware

-from faster_whisper_server.config import (
-    config,
-)
-from faster_whisper_server.logger import logger
-from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.dependencies import get_config, get_model_manager
+from faster_whisper_server.logger import setup_logger
 from faster_whisper_server.routers.list_models import (
     router as list_models_router,
 )
@@ -27,34 +25,42 @@
     from collections.abc import AsyncGenerator


-logger.debug(f"Config: {config}")
+def create_app() -> FastAPI:
+    setup_logger()

+    logger = logging.getLogger(__name__)

-@asynccontextmanager
-async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
-    for model_name in config.preload_models:
-        model_manager.load_model(model_name)
-    yield
+    config = get_config()  # HACK
+    logger.debug(f"Config: {config}")

+    model_manager = get_model_manager()  # HACK

-app = FastAPI(lifespan=lifespan)
+    @asynccontextmanager
+    async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
+        for model_name in config.preload_models:
+            model_manager.load_model(model_name)
+        yield

-app.include_router(stt_router)
-app.include_router(list_models_router)
-app.include_router(misc_router)
+    app = FastAPI(lifespan=lifespan)

-if config.allow_origins is not None:
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=config.allow_origins,
-        allow_credentials=True,
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )
+    app.include_router(stt_router)
+    app.include_router(list_models_router)
+    app.include_router(misc_router)

-if config.enable_ui:
-    import gradio as gr
+    if config.allow_origins is not None:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=config.allow_origins,
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )

-    from faster_whisper_server.gradio_app import create_gradio_demo
+    if config.enable_ui:
+        import gradio as gr

-    app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
+        from faster_whisper_server.gradio_app import create_gradio_demo
+
+        app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
+
+    return app
--- src/faster_whisper_server/model_manager.py
+++ src/faster_whisper_server/model_manager.py
@@ -2,27 +2,34 @@

 from collections import OrderedDict
 import gc
+import logging
 import time
+from typing import TYPE_CHECKING

 from faster_whisper import WhisperModel

-from faster_whisper_server.config import (
-    config,
-)
-from faster_whisper_server.logger import logger
+if TYPE_CHECKING:
+    from faster_whisper_server.config import (
+        Config,
+    )
+
+logger = logging.getLogger(__name__)


 class ModelManager:
-    def __init__(self) -> None:
+    def __init__(self, config: Config) -> None:
+        self.config = config
         self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()

     def load_model(self, model_name: str) -> WhisperModel:
         if model_name in self.loaded_models:
             logger.debug(f"{model_name} model already loaded")
             return self.loaded_models[model_name]
-        if len(self.loaded_models) >= config.max_models:
+        if len(self.loaded_models) >= self.config.max_models:
             oldest_model_name = next(iter(self.loaded_models))
-            logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
+            logger.info(
+                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+            )
             del self.loaded_models[oldest_model_name]
             gc.collect()
             logger.debug(f"Loading {model_name}...")
@@ -30,17 +37,14 @@
         # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
         whisper = WhisperModel(
             model_name,
-            device=config.whisper.inference_device,
-            device_index=config.whisper.device_index,
-            compute_type=config.whisper.compute_type,
-            cpu_threads=config.whisper.cpu_threads,
-            num_workers=config.whisper.num_workers,
+            device=self.config.whisper.inference_device,
+            device_index=self.config.whisper.device_index,
+            compute_type=self.config.whisper.compute_type,
+            cpu_threads=self.config.whisper.cpu_threads,
+            num_workers=self.config.whisper.num_workers,
         )
         logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
+            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
         )
         self.loaded_models[model_name] = whisper
         return whisper
-
-
-model_manager = ModelManager()
--- src/faster_whisper_server/routers/misc.py
+++ src/faster_whisper_server/routers/misc.py
@@ -6,10 +6,11 @@
     APIRouter,
     Response,
 )
-from faster_whisper_server import hf_utils
-from faster_whisper_server.model_manager import model_manager
 import huggingface_hub
 from huggingface_hub.hf_api import RepositoryNotFoundError
+
+from faster_whisper_server import hf_utils
+from faster_whisper_server.dependencies import ModelManagerDependency  # noqa: TCH001

 router = APIRouter()

@@ -31,12 +32,14 @@


 @router.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
-def get_running_models() -> dict[str, list[str]]:
+def get_running_models(
+    model_manager: ModelManagerDependency,
+) -> dict[str, list[str]]:
     return {"models": list(model_manager.loaded_models.keys())}


 @router.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
-def load_model_route(model_name: str) -> Response:
+def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
     model_manager.load_model(model_name)
@@ -44,7 +47,7 @@


 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
-def stop_running_model(model_name: str) -> Response:
+def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
     model = model_manager.loaded_models.get(model_name)
     if model is not None:
         del model_manager.loaded_models[model_name]
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
@@ -2,6 +2,7 @@

 import asyncio
 from io import BytesIO
+import logging
 from typing import TYPE_CHECKING, Annotated, Literal

 from fastapi import (
@@ -16,6 +17,8 @@
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper.vad import VadOptions, get_speech_timestamps
+from pydantic import AfterValidator
+
 from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
@@ -23,23 +26,22 @@
     Language,
     ResponseFormat,
     Task,
-    config,
 )
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
-from faster_whisper_server.logger import logger
-from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
 )
 from faster_whisper_server.transcriber import audio_transcriber
-from pydantic import AfterValidator

 if TYPE_CHECKING:
     from collections.abc import Generator, Iterable

     from faster_whisper.transcribe import TranscriptionInfo

+
+logger = logging.getLogger(__name__)

 router = APIRouter()

@@ -103,6 +105,7 @@

     For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
     """
+    config = get_config()  # HACK
     if model_name == "whisper-1":
         logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
         return config.whisper.model
@@ -117,13 +120,19 @@
     response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
 )
 def translate_file(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
+    model: Annotated[ModelName | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    response_format: Annotated[ResponseFormat | None, Form()] = None,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
+    if model is None:
+        model = config.whisper.model
+    if response_format is None:
+        response_format = config.default_response_format
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -147,11 +156,13 @@
     response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
 )
 def transcribe_file(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
-    language: Annotated[Language | None, Form()] = config.default_language,
+    model: Annotated[ModelName | None, Form()] = None,
+    language: Annotated[Language | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    response_format: Annotated[ResponseFormat | None, Form()] = None,
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
         list[Literal["segment", "word"]],
@@ -160,6 +171,12 @@
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
 ) -> Response | StreamingResponse:
+    if model is None:
+        model = config.whisper.model
+    if language is None:
+        language = config.default_language
+    if response_format is None:
+        response_format = config.default_response_format
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -180,6 +197,7 @@


 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
+    config = get_config()  # HACK
     try:
         while True:
             bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
@@ -211,12 +229,20 @@

 @router.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     ws: WebSocket,
-    model: Annotated[ModelName, Query()] = config.whisper.model,
-    language: Annotated[Language | None, Query()] = config.default_language,
-    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
+    model: Annotated[ModelName | None, Query()] = None,
+    language: Annotated[Language | None, Query()] = None,
+    response_format: Annotated[ResponseFormat | None, Query()] = None,
     temperature: Annotated[float, Query()] = 0.0,
 ) -> None:
+    if model is None:
+        model = config.whisper.model
+    if language is None:
+        language = config.default_language
+    if response_format is None:
+        response_format = config.default_response_format
     await ws.accept()
     transcribe_opts = {
         "language": language,
@@ -229,7 +255,7 @@
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
         tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream):
+        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
             logger.debug(f"Sending transcription: {transcription.text}")
             if ws.client_state == WebSocketState.DISCONNECTED:
                 break
--- src/faster_whisper_server/transcriber.py
+++ src/faster_whisper_server/transcriber.py
@@ -1,16 +1,17 @@
 from __future__ import annotations

+import logging
 from typing import TYPE_CHECKING

 from faster_whisper_server.audio import Audio, AudioStream
-from faster_whisper_server.config import config
 from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text
-from faster_whisper_server.logger import logger

 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator

     from faster_whisper_server.asr import FasterWhisperASR
+
+logger = logging.getLogger(__name__)


 class LocalAgreement:
@@ -47,11 +48,12 @@
 async def audio_transcriber(
     asr: FasterWhisperASR,
     audio_stream: AudioStream,
+    min_duration: float,
 ) -> AsyncGenerator[Transcription, None]:
     local_agreement = LocalAgreement()
     full_audio = Audio()
     confirmed = Transcription()
-    async for chunk in audio_stream.chunks(config.min_duration):
+    async for chunk in audio_stream.chunks(min_duration):
         full_audio.extend(chunk)
         audio = full_audio.after(needs_audio_after(confirmed))
         transcription, _ = await asr.transcribe(audio, prompt(confirmed))
--- tests/conftest.py
+++ tests/conftest.py
@@ -1,7 +1,9 @@
 from collections.abc import AsyncGenerator, Generator
 import logging
+import os

 from fastapi.testclient import TestClient
+from faster_whisper_server.main import create_app
 from httpx import ASGITransport, AsyncClient
 from openai import OpenAI
 import pytest
@@ -18,17 +20,15 @@

 @pytest.fixture()
 def client() -> Generator[TestClient, None, None]:
-    from faster_whisper_server.main import app
-
-    with TestClient(app) as client:
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    with TestClient(create_app()) as client:
         yield client


 @pytest_asyncio.fixture()
 async def aclient() -> AsyncGenerator[AsyncClient, None]:
-    from faster_whisper_server.main import app
-
-    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
         yield aclient
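
One caveat with the env-var approach above: because `get_config` is wrapped in `@lru_cache`, `WHISPER__MODEL` only takes effect if it is set before the first `get_config()` call in the process; changing it afterwards is silently ignored. A sketch (not in this commit) of an autouse fixture that would make each test rebuild its dependencies from the current environment:

```python
# Sketch: reset the cached dependency singletons between tests so each
# test's environment variables are actually read.
import pytest

from faster_whisper_server.dependencies import get_config, get_model_manager


@pytest.fixture(autouse=True)
def _reset_dependency_caches() -> None:
    get_config.cache_clear()
    get_model_manager.cache_clear()
```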