

refactor: split out app into multiple router modules
@75fe8a441b0e12ba39bc6f824c266c4420072a66
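
Move the HTTP endpoints out of `main.py` into dedicated router modules: the speech-to-text endpoints go to `routers/stt.py`, the model listing endpoints to `routers/list_models.py`, and the health/model-management endpoints to `routers/misc.py`. `main.py` now only builds the `FastAPI` app and mounts the routers, and the shared `ModelManager` instance moves into `model_manager.py` as a module-level `model_manager` singleton so the router modules can import it.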
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -1,61 +1,33 @@
 from __future__ import annotations

-import asyncio
 from contextlib import asynccontextmanager
-import gc
-from io import BytesIO
-from typing import TYPE_CHECKING, Annotated, Literal
+from typing import TYPE_CHECKING

 from fastapi import (
     FastAPI,
-    Form,
-    HTTPException,
-    Path,
-    Query,
-    Response,
-    UploadFile,
-    WebSocket,
-    WebSocketDisconnect,
 )
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
-from fastapi.websockets import WebSocketState
-from faster_whisper.vad import VadOptions, get_speech_timestamps
-import huggingface_hub
-from huggingface_hub.hf_api import RepositoryNotFoundError
-from pydantic import AfterValidator

-from faster_whisper_server import hf_utils
-from faster_whisper_server.asr import FasterWhisperASR
-from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
-    SAMPLES_PER_SECOND,
-    Language,
-    ResponseFormat,
-    Task,
     config,
 )
-from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
-from faster_whisper_server.model_manager import ModelManager
-from faster_whisper_server.server_models import (
-    ModelListResponse,
-    ModelObject,
-    TranscriptionJsonResponse,
-    TranscriptionVerboseJsonResponse,
+from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.routers.list_models import (
+    router as list_models_router,
 )
-from faster_whisper_server.transcriber import audio_transcriber
+from faster_whisper_server.routers.misc import (
+    router as misc_router,
+)
+from faster_whisper_server.routers.stt import (
+    router as stt_router,
+)

 if TYPE_CHECKING:
-    from collections.abc import AsyncGenerator, Generator, Iterable
-
-    from faster_whisper.transcribe import TranscriptionInfo
-    from huggingface_hub.hf_api import ModelInfo
+    from collections.abc import AsyncGenerator


 logger.debug(f"Config: {config}")
-
-model_manager = ModelManager()


 @asynccontextmanager
@@ -67,6 +39,10 @@

 app = FastAPI(lifespan=lifespan)

+app.include_router(stt_router)
+app.include_router(list_models_router)
+app.include_router(misc_router)
+
 if config.allow_origins is not None:
     app.add_middleware(
         CORSMiddleware,
@@ -75,315 +51,6 @@
         allow_methods=["*"],
         allow_headers=["*"],
     )
-
-
-@app.get("/health")
-def health() -> Response:
-    return Response(status_code=200, content="OK")
-
-
-@app.post("/api/pull/{model_name:path}", tags=["experimental"], summary="Download a model from Hugging Face.")
-def pull_model(model_name: str) -> Response:
-    if hf_utils.does_local_model_exist(model_name):
-        return Response(status_code=200, content="Model already exists")
-    try:
-        huggingface_hub.snapshot_download(model_name, repo_type="model")
-    except RepositoryNotFoundError as e:
-        return Response(status_code=404, content=str(e))
-    return Response(status_code=201, content="Model downloaded")
-
-
-@app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
-def get_running_models() -> dict[str, list[str]]:
-    return {"models": list(model_manager.loaded_models.keys())}
-
-
-@app.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
-def load_model_route(model_name: str) -> Response:
-    if model_name in model_manager.loaded_models:
-        return Response(status_code=409, content="Model already loaded")
-    model_manager.load_model(model_name)
-    return Response(status_code=201)
-
-
-@app.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
-def stop_running_model(model_name: str) -> Response:
-    model = model_manager.loaded_models.get(model_name)
-    if model is not None:
-        del model_manager.loaded_models[model_name]
-        gc.collect()
-        return Response(status_code=204)
-    return Response(status_code=404)
-
-
-@app.get("/v1/models")
-def get_models() -> ModelListResponse:
-    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
-    models = list(models)
-    models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore # noqa: PGH003
-    transformed_models: list[ModelObject] = []
-    for model in models:
-        assert model.created_at is not None
-        assert model.card_data is not None
-        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
-        if model.card_data.language is None:
-            language = []
-        elif isinstance(model.card_data.language, str):
-            language = [model.card_data.language]
-        else:
-            language = model.card_data.language
-        transformed_model = ModelObject(
-            id=model.id,
-            created=int(model.created_at.timestamp()),
-            object_="model",
-            owned_by=model.id.split("/")[0],
-            language=language,
-        )
-        transformed_models.append(transformed_model)
-    return ModelListResponse(data=transformed_models)
-
-
-@app.get("/v1/models/{model_name:path}")
-# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
-def get_model(
-    model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
-) -> ModelObject:
-    models = huggingface_hub.list_models(
-        model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
-    )
-    models = list(models)
-    models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore # noqa: PGH003
-    if len(models) == 0:
-        raise HTTPException(status_code=404, detail="Model doesn't exists")
-    exact_match: ModelInfo | None = None
-    for model in models:
-        if model.id == model_name:
-            exact_match = model
-            break
-    if exact_match is None:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Model doesn't exists. Possible matches: {', '.join([model.id for model in models])}",
-        )
-    assert exact_match.created_at is not None
-    assert exact_match.card_data is not None
-    assert exact_match.card_data.language is None or isinstance(exact_match.card_data.language, str | list)
-    if exact_match.card_data.language is None:
-        language = []
-    elif isinstance(exact_match.card_data.language, str):
-        language = [exact_match.card_data.language]
-    else:
-        language = exact_match.card_data.language
-    return ModelObject(
-        id=exact_match.id,
-        created=int(exact_match.created_at.timestamp()),
-        object_="model",
-        owned_by=exact_match.id.split("/")[0],
-        language=language,
-    )
-
-
-def segments_to_response(
-    segments: Iterable[Segment],
-    transcription_info: TranscriptionInfo,
-    response_format: ResponseFormat,
-) -> Response:
-    segments = list(segments)
-    if response_format == ResponseFormat.TEXT:  # noqa: RET503
-        return Response(segments_to_text(segments), media_type="text/plain")
-    elif response_format == ResponseFormat.JSON:
-        return Response(
-            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
-            media_type="application/json",
-        )
-    elif response_format == ResponseFormat.VERBOSE_JSON:
-        return Response(
-            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
-            media_type="application/json",
-        )
-    elif response_format == ResponseFormat.VTT:
-        return Response(
-            "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
-        )
-    elif response_format == ResponseFormat.SRT:
-        return Response(
-            "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
-        )
-
-
-def format_as_sse(data: str) -> str:
-    return f"data: {data}\n\n"
-
-
-def segments_to_streaming_response(
-    segments: Iterable[Segment],
-    transcription_info: TranscriptionInfo,
-    response_format: ResponseFormat,
-) -> StreamingResponse:
-    def segment_responses() -> Generator[str, None, None]:
-        for i, segment in enumerate(segments):
-            if response_format == ResponseFormat.TEXT:
-                data = segment.text
-            elif response_format == ResponseFormat.JSON:
-                data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
-            elif response_format == ResponseFormat.VTT:
-                data = segments_to_vtt(segment, i)
-            elif response_format == ResponseFormat.SRT:
-                data = segments_to_srt(segment, i)
-            yield format_as_sse(data)
-
-    return StreamingResponse(segment_responses(), media_type="text/event-stream")
-
-
-def handle_default_openai_model(model_name: str) -> str:
-    """Exists because some callers may not be able override the default("whisper-1") model name.
-
-    For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
-    """
-    if model_name == "whisper-1":
-        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
-        return config.whisper.model
-    return model_name
-
-
-ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
-
-
-@app.post(
-    "/v1/audio/translations",
-    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
-)
-def translate_file(
-    file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
-    prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
-    temperature: Annotated[float, Form()] = 0.0,
-    stream: Annotated[bool, Form()] = False,
-) -> Response | StreamingResponse:
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSLATE,
-        initial_prompt=prompt,
-        temperature=temperature,
-        vad_filter=True,
-    )
-    segments = Segment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
-
-
-# https://platform.openai.com/docs/api-reference/audio/createTranscription
-# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
-@app.post(
-    "/v1/audio/transcriptions",
-    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
-)
-def transcribe_file(
-    file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
-    language: Annotated[Language | None, Form()] = config.default_language,
-    prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
-    temperature: Annotated[float, Form()] = 0.0,
-    timestamp_granularities: Annotated[
-        list[Literal["segment", "word"]],
-        Form(alias="timestamp_granularities[]"),
-    ] = ["segment"],
-    stream: Annotated[bool, Form()] = False,
-    hotwords: Annotated[str | None, Form()] = None,
-) -> Response | StreamingResponse:
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSCRIBE,
-        language=language,
-        initial_prompt=prompt,
-        word_timestamps="word" in timestamp_granularities,
-        temperature=temperature,
-        vad_filter=True,
-        hotwords=hotwords,
-    )
-    segments = Segment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
-
-
-async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
-    try:
-        while True:
-            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
-            logger.debug(f"Received {len(bytes_)} bytes of audio data")
-            audio_samples = audio_samples_from_file(BytesIO(bytes_))
-            audio_stream.extend(audio_samples)
-            if audio_stream.duration - config.inactivity_window_seconds >= 0:
-                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
-                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
-                # NOTE: This is a synchronous operation that runs every time new data is received.
-                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato. # noqa: E501
-                timestamps = get_speech_timestamps(audio.data, vad_opts)
-                if len(timestamps) == 0:
-                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
-                    break
-                elif (
-                    # last speech end time
-                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
-                    >= config.max_inactivity_seconds
-                ):
-                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
-                    break
-    except TimeoutError:
-        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
-    except WebSocketDisconnect as e:
-        logger.info(f"Client disconnected: {e}")
-    audio_stream.close()
-
-
-@app.websocket("/v1/audio/transcriptions")
-async def transcribe_stream(
-    ws: WebSocket,
-    model: Annotated[ModelName, Query()] = config.whisper.model,
-    language: Annotated[Language | None, Query()] = config.default_language,
-    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
-    temperature: Annotated[float, Query()] = 0.0,
-) -> None:
-    await ws.accept()
-    transcribe_opts = {
-        "language": language,
-        "temperature": temperature,
-        "vad_filter": True,
-        "condition_on_previous_text": False,
-    }
-    whisper = model_manager.load_model(model)
-    asr = FasterWhisperASR(whisper, **transcribe_opts)
-    audio_stream = AudioStream()
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream):
-            logger.debug(f"Sending transcription: {transcription.text}")
-            if ws.client_state == WebSocketState.DISCONNECTED:
-                break
-
-            if response_format == ResponseFormat.TEXT:
-                await ws.send_text(transcription.text)
-            elif response_format == ResponseFormat.JSON:
-                await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
-
-    if ws.client_state != WebSocketState.DISCONNECTED:
-        logger.info("Closing the connection.")
-        await ws.close()
-

 if config.enable_ui:
     import gradio as gr
--- src/faster_whisper_server/model_manager.py
+++ src/faster_whisper_server/model_manager.py
@@ -41,3 +41,6 @@
         )
         self.loaded_models[model_name] = whisper
         return whisper
+
+
+model_manager = ModelManager()
+++ src/faster_whisper_server/routers/__init__.py
@@ -0,0 +1,0 @@
+++ src/faster_whisper_server/routers/list_models.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Annotated
+
+from fastapi import (
+    APIRouter,
+    HTTPException,
+    Path,
+)
+import huggingface_hub
+
+from faster_whisper_server.server_models import (
+    ModelListResponse,
+    ModelObject,
+)
+
+if TYPE_CHECKING:
+    from huggingface_hub.hf_api import ModelInfo
+
+router = APIRouter()
+
+
+@router.get("/v1/models")
+def get_models() -> ModelListResponse:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore # noqa: PGH003
+    transformed_models: list[ModelObject] = []
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = ModelObject(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        transformed_models.append(transformed_model)
+    return ModelListResponse(data=transformed_models)
+
+
+@router.get("/v1/models/{model_name:path}")
+# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
+def get_model(
+    model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
+) -> ModelObject:
+    models = huggingface_hub.list_models(
+        model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
+    )
+    models = list(models)
+    models.sort(key=lambda model: model.downloads, reverse=True)  # type: ignore # noqa: PGH003
+    if len(models) == 0:
+        raise HTTPException(status_code=404, detail="Model doesn't exists")
+    exact_match: ModelInfo | None = None
+    for model in models:
+        if model.id == model_name:
+            exact_match = model
+            break
+    if exact_match is None:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Model doesn't exists. Possible matches: {', '.join([model.id for model in models])}",
+        )
+    assert exact_match.created_at is not None
+    assert exact_match.card_data is not None
+    assert exact_match.card_data.language is None or isinstance(exact_match.card_data.language, str | list)
+    if exact_match.card_data.language is None:
+        language = []
+    elif isinstance(exact_match.card_data.language, str):
+        language = [exact_match.card_data.language]
+    else:
+        language = exact_match.card_data.language
+    return ModelObject(
+        id=exact_match.id,
+        created=int(exact_match.created_at.timestamp()),
+        object_="model",
+        owned_by=exact_match.id.split("/")[0],
+        language=language,
+    )
+++ src/faster_whisper_server/routers/misc.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import gc
+
+from fastapi import (
+    APIRouter,
+    Response,
+)
+from faster_whisper_server import hf_utils
+from faster_whisper_server.model_manager import model_manager
+import huggingface_hub
+from huggingface_hub.hf_api import RepositoryNotFoundError
+
+router = APIRouter()
+
+
+@router.get("/health")
+def health() -> Response:
+    return Response(status_code=200, content="OK")
+
+
+@router.post("/api/pull/{model_name:path}", tags=["experimental"], summary="Download a model from Hugging Face.")
+def pull_model(model_name: str) -> Response:
+    if hf_utils.does_local_model_exist(model_name):
+        return Response(status_code=200, content="Model already exists")
+    try:
+        huggingface_hub.snapshot_download(model_name, repo_type="model")
+    except RepositoryNotFoundError as e:
+        return Response(status_code=404, content=str(e))
+    return Response(status_code=201, content="Model downloaded")
+
+
+@router.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
+def get_running_models() -> dict[str, list[str]]:
+    return {"models": list(model_manager.loaded_models.keys())}
+
+
+@router.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
+def load_model_route(model_name: str) -> Response:
+    if model_name in model_manager.loaded_models:
+        return Response(status_code=409, content="Model already loaded")
+    model_manager.load_model(model_name)
+    return Response(status_code=201)
+
+
+@router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
+def stop_running_model(model_name: str) -> Response:
+    model = model_manager.loaded_models.get(model_name)
+    if model is not None:
+        del model_manager.loaded_models[model_name]
+        gc.collect()
+        return Response(status_code=204)
+    return Response(status_code=404)
+++ src/faster_whisper_server/routers/stt.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+import asyncio
+from io import BytesIO
+from typing import TYPE_CHECKING, Annotated, Literal
+
+from fastapi import (
+    APIRouter,
+    Form,
+    Query,
+    Response,
+    UploadFile,
+    WebSocket,
+    WebSocketDisconnect,
+)
+from fastapi.responses import StreamingResponse
+from fastapi.websockets import WebSocketState
+from faster_whisper.vad import VadOptions, get_speech_timestamps
+from faster_whisper_server.asr import FasterWhisperASR
+from faster_whisper_server.audio import AudioStream, audio_samples_from_file
+from faster_whisper_server.config import (
+    SAMPLES_PER_SECOND,
+    Language,
+    ResponseFormat,
+    Task,
+    config,
+)
+from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
+from faster_whisper_server.logger import logger
+from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.server_models import (
+    TranscriptionJsonResponse,
+    TranscriptionVerboseJsonResponse,
+)
+from faster_whisper_server.transcriber import audio_transcriber
+from pydantic import AfterValidator
+
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterable
+
+    from faster_whisper.transcribe import TranscriptionInfo
+
+
+router = APIRouter()
+
+
+def segments_to_response(
+    segments: Iterable[Segment],
+    transcription_info: TranscriptionInfo,
+    response_format: ResponseFormat,
+) -> Response:
+    segments = list(segments)
+    if response_format == ResponseFormat.TEXT:  # noqa: RET503
+        return Response(segments_to_text(segments), media_type="text/plain")
+    elif response_format == ResponseFormat.JSON:
+        return Response(
+            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
+            media_type="application/json",
+        )
+    elif response_format == ResponseFormat.VERBOSE_JSON:
+        return Response(
+            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
+            media_type="application/json",
+        )
+    elif response_format == ResponseFormat.VTT:
+        return Response(
+            "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
+        )
+    elif response_format == ResponseFormat.SRT:
+        return Response(
+            "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
+        )
+
+
+def format_as_sse(data: str) -> str:
+    return f"data: {data}\n\n"
+
+
+def segments_to_streaming_response(
+    segments: Iterable[Segment],
+    transcription_info: TranscriptionInfo,
+    response_format: ResponseFormat,
+) -> StreamingResponse:
+    def segment_responses() -> Generator[str, None, None]:
+        for i, segment in enumerate(segments):
+            if response_format == ResponseFormat.TEXT:
+                data = segment.text
+            elif response_format == ResponseFormat.JSON:
+                data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
+            elif response_format == ResponseFormat.VTT:
+                data = segments_to_vtt(segment, i)
+            elif response_format == ResponseFormat.SRT:
+                data = segments_to_srt(segment, i)
+            yield format_as_sse(data)
+
+    return StreamingResponse(segment_responses(), media_type="text/event-stream")
+
+
+def handle_default_openai_model(model_name: str) -> str:
+    """Exists because some callers may not be able override the default("whisper-1") model name.
+
+    For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
+    """
+    if model_name == "whisper-1":
+        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
+        return config.whisper.model
+    return model_name
+
+
+ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
+
+
+@router.post(
+    "/v1/audio/translations",
+    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
+)
+def translate_file(
+    file: Annotated[UploadFile, Form()],
+    model: Annotated[ModelName, Form()] = config.whisper.model,
+    prompt: Annotated[str | None, Form()] = None,
+    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    temperature: Annotated[float, Form()] = 0.0,
+    stream: Annotated[bool, Form()] = False,
+) -> Response | StreamingResponse:
+    whisper = model_manager.load_model(model)
+    segments, transcription_info = whisper.transcribe(
+        file.file,
+        task=Task.TRANSLATE,
+        initial_prompt=prompt,
+        temperature=temperature,
+        vad_filter=True,
+    )
+    segments = Segment.from_faster_whisper_segments(segments)
+
+    if stream:
+        return segments_to_streaming_response(segments, transcription_info, response_format)
+    else:
+        return segments_to_response(segments, transcription_info, response_format)
+
+
+# https://platform.openai.com/docs/api-reference/audio/createTranscription
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
+@router.post(
+    "/v1/audio/transcriptions",
+    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
+)
+def transcribe_file(
+    file: Annotated[UploadFile, Form()],
+    model: Annotated[ModelName, Form()] = config.whisper.model,
+    language: Annotated[Language | None, Form()] = config.default_language,
+    prompt: Annotated[str | None, Form()] = None,
+    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    temperature: Annotated[float, Form()] = 0.0,
+    timestamp_granularities: Annotated[
+        list[Literal["segment", "word"]],
+        Form(alias="timestamp_granularities[]"),
+    ] = ["segment"],
+    stream: Annotated[bool, Form()] = False,
+    hotwords: Annotated[str | None, Form()] = None,
+) -> Response | StreamingResponse:
+    whisper = model_manager.load_model(model)
+    segments, transcription_info = whisper.transcribe(
+        file.file,
+        task=Task.TRANSCRIBE,
+        language=language,
+        initial_prompt=prompt,
+        word_timestamps="word" in timestamp_granularities,
+        temperature=temperature,
+        vad_filter=True,
+        hotwords=hotwords,
+    )
+    segments = Segment.from_faster_whisper_segments(segments)
+
+    if stream:
+        return segments_to_streaming_response(segments, transcription_info, response_format)
+    else:
+        return segments_to_response(segments, transcription_info, response_format)
+
+
+async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
+    try:
+        while True:
+            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
+            logger.debug(f"Received {len(bytes_)} bytes of audio data")
+            audio_samples = audio_samples_from_file(BytesIO(bytes_))
+            audio_stream.extend(audio_samples)
+            if audio_stream.duration - config.inactivity_window_seconds >= 0:
+                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
+                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
+                # NOTE: This is a synchronous operation that runs every time new data is received.
+                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato. # noqa: E501
+                timestamps = get_speech_timestamps(audio.data, vad_opts)
+                if len(timestamps) == 0:
+                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
+                    break
+                elif (
+                    # last speech end time
+                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
+                    >= config.max_inactivity_seconds
+                ):
+                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
+                    break
+    except TimeoutError:
+        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
+    except WebSocketDisconnect as e:
+        logger.info(f"Client disconnected: {e}")
+    audio_stream.close()
+
+
+@router.websocket("/v1/audio/transcriptions")
+async def transcribe_stream(
+    ws: WebSocket,
+    model: Annotated[ModelName, Query()] = config.whisper.model,
+    language: Annotated[Language | None, Query()] = config.default_language,
+    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
+    temperature: Annotated[float, Query()] = 0.0,
+) -> None:
+    await ws.accept()
+    transcribe_opts = {
+        "language": language,
+        "temperature": temperature,
+        "vad_filter": True,
+        "condition_on_previous_text": False,
+    }
+    whisper = model_manager.load_model(model)
+    asr = FasterWhisperASR(whisper, **transcribe_opts)
+    audio_stream = AudioStream()
+    async with asyncio.TaskGroup() as tg:
+        tg.create_task(audio_receiver(ws, audio_stream))
+        async for transcription in audio_transcriber(asr, audio_stream):
+            logger.debug(f"Sending transcription: {transcription.text}")
+            if ws.client_state == WebSocketState.DISCONNECTED:
+                break
+
+            if response_format == ResponseFormat.TEXT:
+                await ws.send_text(transcription.text)
+            elif response_format == ResponseFormat.JSON:
+                await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
+
+    if ws.client_state != WebSocketState.DISCONNECTED:
+        logger.info("Closing the connection.")
+        await ws.close()
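
A quick way to sanity-check a refactor like this is to confirm the routes are still registered on the assembled app. The sketch below is illustrative only and not part of this commit; it assumes the `faster_whisper_server` package is importable and that the `/v1/models` route can reach the Hugging Face Hub (it calls `huggingface_hub.list_models` under the hood).

# Minimal smoke test for the router split (illustrative, not part of this commit).
# Assumes the package is installed; /v1/models additionally needs network access.
from fastapi.testclient import TestClient

from faster_whisper_server.main import app

client = TestClient(app)

# /health is served by routers/misc.py
assert client.get("/health").status_code == 200

# /v1/models is served by routers/list_models.py
assert client.get("/v1/models").status_code == 200

# /v1/audio/transcriptions (routers/stt.py) should fail validation on an empty
# POST (422) rather than return 404, which shows the route is mounted.
assert client.post("/v1/audio/transcriptions").status_code == 422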