

feat: support loading multiple models
@aada575dd8c8d2054f53e744abcee704e2944632
--- speaches/config.py
+++ speaches/config.py
@@ -163,39 +163,41 @@


 class WhisperConfig(BaseModel):
-    model: Model = Field(default=Model.DISTIL_MEDIUM_EN)  # ENV: WHISPER_MODEL
-    inference_device: Device = Field(
-        default=Device.AUTO
-    )  # ENV: WHISPER_INFERENCE_DEVICE
-    compute_type: Quantization = Field(
-        default=Quantization.DEFAULT
-    )  # ENV: WHISPER_COMPUTE_TYPE
+    model: Model = Field(default=Model.DISTIL_MEDIUM_EN)
+    inference_device: Device = Field(default=Device.AUTO)
+    compute_type: Quantization = Field(default=Quantization.DEFAULT)


 class Config(BaseSettings):
+    """
+    Configuration for the application. Values can be set via environment variables.
+    Pydantic automatically maps uppercased environment variables to the corresponding fields.
+    To populate nested fields, the environment variable should be prefixed with the nested field name and an underscore. For example,
+    the environment variable `LOG_LEVEL` maps to `log_level`, `WHISPER_MODEL` to `whisper.model`, etc.
+    """
+
     model_config = SettingsConfigDict(env_nested_delimiter="_")

-    log_level: str = "info"  # ENV: LOG_LEVEL
-    default_language: Language | None = None  # ENV: DEFAULT_LANGUAGE
-    default_response_format: ResponseFormat = (
-        ResponseFormat.JSON
-    )  # ENV: DEFAULT_RESPONSE_FORMAT
-    whisper: WhisperConfig = WhisperConfig()  # ENV: WHISPER_*
+    log_level: str = "info"
+    default_language: Language | None = None
+    default_response_format: ResponseFormat = ResponseFormat.JSON
+    whisper: WhisperConfig = WhisperConfig()
+    max_models: int = 1
     """
     Max duration to wait for the next audio chunk before the transcription is finalized and the connection is closed.
     """
-    max_no_data_seconds: float = 1.0  # ENV: MAX_NO_DATA_SECONDS
-    min_duration: float = 1.0  # ENV: MIN_DURATION
-    word_timestamp_error_margin: float = 0.2  # ENV: WORD_TIMESTAMP_ERROR_MARGIN
+    max_no_data_seconds: float = 1.0
+    min_duration: float = 1.0
+    word_timestamp_error_margin: float = 0.2
     """
     Max allowed audio duration without any speech being detected before the transcription is finalized and the connection is closed.
     """
-    max_inactivity_seconds: float = 2.0  # ENV: MAX_INACTIVITY_SECONDS
+    max_inactivity_seconds: float = 2.0
     """
     Controls how many of the latest seconds of audio are passed through VAD.
     Should be greater than `max_inactivity_seconds`.
     """
-    inactivity_window_seconds: float = 3.0  # ENV: INACTIVITY_WINDOW_SECONDS
+    inactivity_window_seconds: float = 3.0


 config = Config()
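
For reference, a minimal sketch of how the nested env-var mapping described in the new `Config` docstring plays out, assuming the `speaches.config` module above. The values shown are placeholders; `WHISPER_MODEL` must be one of the `Model` enum values.

```python
import os

# Environment variables must be set before Config() is instantiated;
# pydantic-settings reads them at construction time.
os.environ["LOG_LEVEL"] = "debug"                 # -> config.log_level
os.environ["MAX_MODELS"] = "2"                    # -> config.max_models (parsed as int)
os.environ["WHISPER_MODEL"] = "distil-medium.en"  # -> config.whisper.model (nested via "_"); placeholder value

from speaches.config import Config

config = Config()
print(config.log_level, config.max_models, config.whisper.model)
```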
--- speaches/main.py
+++ speaches/main.py
@@ -1,11 +1,10 @@
 from __future__ import annotations

 import asyncio
-import logging
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated, Literal
+from typing import Annotated, Literal, OrderedDict

 from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
                      WebSocketDisconnect)
@@ -19,29 +18,45 @@
 from speaches.audio import AudioStream, audio_samples_from_file
 from speaches.config import (SAMPLES_PER_SECOND, Language, Model,
                              ResponseFormat, config)
-from speaches.core import Transcription
 from speaches.logger import logger
 from speaches.server_models import (TranscriptionJsonResponse,
                                     TranscriptionVerboseJsonResponse)
 from speaches.transcriber import audio_transcriber

-whisper: WhisperModel = None  # type: ignore
+models: OrderedDict[Model, WhisperModel] = OrderedDict()
+
+
+def load_model(model_name: Model) -> WhisperModel:
+    if model_name in models:
+        logger.debug(f"{model_name} model already loaded")
+        return models[model_name]
+    if len(models) >= config.max_models:
+        oldest_model_name = next(iter(models))
+        logger.info(
+            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+        )
+        del models[oldest_model_name]
+    logger.debug(f"Loading {model_name}")
+    start = time.perf_counter()
+    whisper = WhisperModel(
+        model_name,
+        device=config.whisper.inference_device,
+        compute_type=config.whisper.compute_type,
+    )
+    logger.info(
+        f"Loaded {model_name} in {time.perf_counter() - start:.2f} seconds"
+    )
+    models[model_name] = whisper
+    return whisper


 @asynccontextmanager
 async def lifespan(_: FastAPI):
-    global whisper
-    logging.debug(f"Loading {config.whisper.model}")
-    start = time.perf_counter()
-    whisper = WhisperModel(
-        config.whisper.model,
-        device=config.whisper.inference_device,
-        compute_type=config.whisper.compute_type,
-    )
-    logger.debug(
-        f"Loaded {config.whisper.model} loaded in {time.perf_counter() - start:.2f} seconds"
-    )
+    load_model(config.whisper.model)
     yield
+    for model in list(models.keys()):
+        logger.info(f"Unloading {model}")
+        del models[model]


 app = FastAPI(lifespan=lifespan)
@@ -53,7 +68,7 @@


 @app.post("/v1/audio/translations")
-async def translate_file(
+def translate_file(
     file: Annotated[UploadFile, Form()],
     model: Annotated[Model, Form()] = config.whisper.model,
     prompt: Annotated[str | None, Form()] = None,
@@ -61,11 +76,8 @@
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ):
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     start = time.perf_counter()
+    whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task="translate",
@@ -107,7 +119,7 @@
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @app.post("/v1/audio/transcriptions")
-async def transcribe_file(
+def transcribe_file(
     file: Annotated[UploadFile, Form()],
     model: Annotated[Model, Form()] = config.whisper.model,
     language: Annotated[Language | None, Form()] = config.default_language,
@@ -120,11 +132,8 @@
     ] = ["segments"],
     stream: Annotated[bool, Form()] = False,
 ):
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     start = time.perf_counter()
+    whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task="transcribe",
@@ -209,21 +218,6 @@
     audio_stream.close()


-def format_transcription(
-    transcription: Transcription, response_format: ResponseFormat
-) -> str:
-    if response_format == ResponseFormat.TEXT:
-        return transcription.text
-    elif response_format == ResponseFormat.JSON:
-        return TranscriptionJsonResponse.from_transcription(
-            transcription
-        ).model_dump_json()
-    elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_transcription(
-            transcription
-        ).model_dump_json()
-
-
 @app.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
     ws: WebSocket,
@@ -234,18 +228,7 @@
         ResponseFormat, Query()
     ] = config.default_response_format,
     temperature: Annotated[float, Query()] = 0.0,
-    timestamp_granularities: Annotated[
-        list[Literal["segments"] | Literal["words"]],
-        Query(
-            alias="timestamp_granularities[]",
-            description="No-op. Ignored. Only for compatibility.",
-        ),
-    ] = ["segments", "words"],
 ) -> None:
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     await ws.accept()
     transcribe_opts = {
         "language": language,
@@ -254,6 +237,7 @@
         "vad_filter": True,
         "condition_on_previous_text": False,
     }
+    whisper = load_model(model)
     asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
@@ -262,7 +246,21 @@
             logger.debug(f"Sending transcription: {transcription.text}")
             if ws.client_state == WebSocketState.DISCONNECTED:
                 break
-            await ws.send_text(format_transcription(transcription, response_format))
+
+            if response_format == ResponseFormat.TEXT:
+                await ws.send_text(transcription.text)
+            elif response_format == ResponseFormat.JSON:
+                await ws.send_json(
+                    TranscriptionJsonResponse.from_transcription(
+                        transcription
+                    ).model_dump()
+                )
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                await ws.send_json(
+                    TranscriptionVerboseJsonResponse.from_transcription(
+                        transcription
+                    ).model_dump()
+                )

     if not ws.client_state == WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")