

feat: add a playground
A notable change is that the Whisper model is no longer loaded on startup.
@6fad0746f1afa2b2be957d62d7a3cf6e92cdfe5e
--- faster_whisper_server/config.py
+++ faster_whisper_server/config.py
... | ... | @@ -168,6 +168,11 @@ |
168 | 168 |
ZH = "zh" |
169 | 169 |
|
170 | 170 |
|
171 |
+class Task(enum.StrEnum): |
|
172 |
+ TRANSCRIPTION = "transcription" |
|
173 |
+ TRANSLATION = "translation" |
|
174 |
+ |
|
175 |
+ |
|
171 | 176 |
class WhisperConfig(BaseModel): |
172 | 177 |
model: str = Field(default="Systran/faster-whisper-medium.en") |
173 | 178 |
""" |
+++ faster_whisper_server/gradio_app.py
... | ... | @@ -0,0 +1,102 @@ |
1 | +import os | |
2 | +from typing import Generator | |
3 | + | |
4 | +import gradio as gr | |
5 | +import httpx | |
6 | +from httpx_sse import connect_sse | |
7 | + | |
8 | +from faster_whisper_server.config import Config, Task | |
9 | + | |
10 | +TRANSCRIPTION_ENDPOINT = "/v1/audio/transcriptions" | |
11 | +TRANSLATION_ENDPOINT = "/v1/audio/translations" | |
12 | + | |
13 | + | |
def create_gradio_demo(config: Config) -> gr.Blocks:
    """Build the Gradio playground UI that talks to the local API server.

    The demo posts uploaded audio to the server's OpenAI-compatible
    transcription/translation endpoints, optionally streaming partial
    results over server-sent events.

    Args:
        config: Application config; only ``config.whisper.model`` is read,
            to seed the model dropdown.

    Returns:
        The assembled ``gr.Interface`` (a ``gr.Blocks`` subclass), ready to
        be mounted with ``gr.mount_gradio_app``.
    """
    # The demo talks back to the same uvicorn process it is mounted in,
    # so reuse uvicorn's host/port env vars as the client target.
    host = os.getenv("UVICORN_HOST", "0.0.0.0")
    port = os.getenv("UVICORN_PORT", 8000)
    # NOTE: worth looking into generated clients
    # timeout=None: transcription of long audio can far exceed any default.
    http_client = httpx.Client(base_url=f"http://{host}:{port}", timeout=None)

    def endpoint_for(task: Task) -> str:
        # Single source of truth for task -> endpoint. Raises on an unknown
        # task instead of silently leaving the endpoint undefined, as the
        # previous exhaustiveness-unchecked if/elif chain could.
        if task == Task.TRANSCRIPTION:
            return TRANSCRIPTION_ENDPOINT
        if task == Task.TRANSLATION:
            return TRANSLATION_ENDPOINT
        raise ValueError(f"Unknown task: {task}")

    def handler(
        file_path: str | None, model: str, task: Task, temperature: float, stream: bool
    ) -> Generator[str, None, None]:
        """Dispatch to the streaming or blocking request path."""
        if file_path is None:
            yield ""
            return
        if stream:
            yield from transcribe_audio_streaming(file_path, task, temperature, model)
            # BUG FIX: without this return, control fell through to the
            # blocking call below, transcribing the file a second time and
            # yielding the duplicate result after the stream finished.
            return
        yield transcribe_audio(file_path, task, temperature, model)

    def transcribe_audio(
        file_path: str, task: Task, temperature: float, model: str
    ) -> str:
        """Send a blocking request; return the full transcript as plain text."""
        with open(file_path, "rb") as file:
            response = http_client.post(
                endpoint_for(task),
                files={"file": file},
                data={
                    "model": model,
                    "response_format": "text",
                    "temperature": temperature,
                },
            )
        response.raise_for_status()
        return response.text

    def transcribe_audio_streaming(
        file_path: str, task: Task, temperature: float, model: str
    ) -> Generator[str, None, None]:
        """Yield transcript chunks as they arrive over server-sent events."""
        with open(file_path, "rb") as file:
            kwargs = {
                "files": {"file": file},
                "data": {
                    "response_format": "text",
                    "temperature": temperature,
                    "model": model,
                    "stream": True,
                },
            }
            with connect_sse(
                http_client, "POST", endpoint_for(task), **kwargs
            ) as event_source:
                for event in event_source.iter_sse():
                    yield event.data

    model_dropdown = gr.Dropdown(
        # TODO: use output from /v1/models
        choices=[config.whisper.model],
        label="Model",
        value=config.whisper.model,
    )
    task_dropdown = gr.Dropdown(
        choices=[task.value for task in Task],
        label="Task",
        value=Task.TRANSCRIPTION,
    )
    temperature_slider = gr.Slider(
        minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0
    )
    stream_checkbox = gr.Checkbox(label="Stream", value=True)
    demo = gr.Interface(
        title="Whisper Playground",
        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",
        inputs=[
            gr.Audio(type="filepath"),
            model_dropdown,
            task_dropdown,
            temperature_slider,
            stream_checkbox,
        ],
        fn=handler,
        outputs="text",
    )
    return demo
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
... | ... | @@ -2,10 +2,10 @@ |
2 | 2 |
|
3 | 3 |
import asyncio |
4 | 4 |
import time |
5 |
-from contextlib import asynccontextmanager |
|
6 | 5 |
from io import BytesIO |
7 | 6 |
from typing import Annotated, Generator, Iterable, Literal, OrderedDict |
8 | 7 |
|
8 |
+import gradio as gr |
|
9 | 9 |
import huggingface_hub |
10 | 10 |
from fastapi import ( |
11 | 11 |
FastAPI, |
... | ... | @@ -33,8 +33,10 @@ |
33 | 33 |
SAMPLES_PER_SECOND, |
34 | 34 |
Language, |
35 | 35 |
ResponseFormat, |
36 |
+ Task, |
|
36 | 37 |
config, |
37 | 38 |
) |
39 |
+from faster_whisper_server.gradio_app import create_gradio_demo |
|
38 | 40 |
from faster_whisper_server.logger import logger |
39 | 41 |
from faster_whisper_server.server_models import ( |
40 | 42 |
ModelObject, |
... | ... | @@ -71,16 +73,7 @@ |
71 | 73 |
return whisper |
72 | 74 |
|
73 | 75 |
|
74 |
-@asynccontextmanager |
|
75 |
-async def lifespan(_: FastAPI): |
|
76 |
- load_model(config.whisper.model) |
|
77 |
- yield |
|
78 |
- for model in loaded_models.keys(): |
|
79 |
- logger.info(f"Unloading {model}") |
|
80 |
- del loaded_models[model] |
|
81 |
- |
|
82 |
- |
|
83 |
-app = FastAPI(lifespan=lifespan) |
|
76 |
+app = FastAPI() |
|
84 | 77 |
|
85 | 78 |
|
86 | 79 |
@app.get("/health") |
... | ... | @@ -210,7 +203,7 @@ |
210 | 203 |
whisper = load_model(model) |
211 | 204 |
segments, transcription_info = whisper.transcribe( |
212 | 205 |
file.file, |
213 |
- task="translate", |
|
206 |
+ task=Task.TRANSLATION, |
|
214 | 207 |
initial_prompt=prompt, |
215 | 208 |
temperature=temperature, |
216 | 209 |
vad_filter=True, |
... | ... | @@ -251,7 +244,7 @@ |
251 | 244 |
whisper = load_model(model) |
252 | 245 |
segments, transcription_info = whisper.transcribe( |
253 | 246 |
file.file, |
254 |
- task="transcribe", |
|
247 |
+ task=Task.TRANSCRIPTION, |
|
255 | 248 |
language=language, |
256 | 249 |
initial_prompt=prompt, |
257 | 250 |
word_timestamps="word" in timestamp_granularities, |
... | ... | @@ -353,3 +346,6 @@ |
353 | 346 |
if not ws.client_state == WebSocketState.DISCONNECTED: |
354 | 347 |
logger.info("Closing the connection.") |
355 | 348 |
await ws.close() |
349 |
+ |
|
350 |
+ |
|
351 |
+app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/") |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?