Fedir Zadniprovskyi 2024-06-23
feat: add a playground
Notable change is that whisper model won't be loaded on startup anymore
@6fad0746f1afa2b2be957d62d7a3cf6e92cdfe5e
faster_whisper_server/config.py
--- faster_whisper_server/config.py
+++ faster_whisper_server/config.py
@@ -168,6 +168,11 @@
     ZH = "zh"
 
 
+class Task(enum.StrEnum):
+    TRANSCRIPTION = "transcription"
+    TRANSLATION = "translation"
+
+
 class WhisperConfig(BaseModel):
     model: str = Field(default="Systran/faster-whisper-medium.en")
     """
 
faster_whisper_server/gradio_app.py (added)
+++ faster_whisper_server/gradio_app.py
@@ -0,0 +1,102 @@
+import os
+from typing import Generator
+
+import gradio as gr
+import httpx
+from httpx_sse import connect_sse
+
+from faster_whisper_server.config import Config, Task
+
+TRANSCRIPTION_ENDPOINT = "/v1/audio/transcriptions"
+TRANSLATION_ENDPOINT = "/v1/audio/translations"
+
+
+def create_gradio_demo(config: Config) -> gr.Blocks:
+    host = os.getenv("UVICORN_HOST", "0.0.0.0")
+    port = os.getenv("UVICORN_PORT", 8000)
+    # NOTE: worth looking into generated clients
+    http_client = httpx.Client(base_url=f"http://{host}:{port}", timeout=None)
+
+    def handler(
+        file_path: str | None, model: str, task: Task, temperature: float, stream: bool
+    ) -> Generator[str, None, None]:
+        if file_path is None:
+            yield ""
+            return
+        if stream:
+            yield from transcribe_audio_streaming(file_path, task, temperature, model)
+        yield transcribe_audio(file_path, task, temperature, model)
+
+    def transcribe_audio(
+        file_path: str, task: Task, temperature: float, model: str
+    ) -> str:
+        if task == Task.TRANSCRIPTION:
+            endpoint = TRANSCRIPTION_ENDPOINT
+        elif task == Task.TRANSLATION:
+            endpoint = TRANSLATION_ENDPOINT
+
+        with open(file_path, "rb") as file:
+            response = http_client.post(
+                endpoint,
+                files={"file": file},
+                data={
+                    "model": model,
+                    "response_format": "text",
+                    "temperature": temperature,
+                },
+            )
+
+        response.raise_for_status()
+        return response.text
+
+    def transcribe_audio_streaming(
+        file_path: str, task: Task, temperature: float, model: str
+    ) -> Generator[str, None, None]:
+        with open(file_path, "rb") as file:
+            kwargs = {
+                "files": {"file": file},
+                "data": {
+                    "response_format": "text",
+                    "temperature": temperature,
+                    "model": model,
+                    "stream": True,
+                },
+            }
+            endpoint = (
+                TRANSCRIPTION_ENDPOINT
+                if task == Task.TRANSCRIPTION
+                else TRANSLATION_ENDPOINT
+            )
+            with connect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
+                for event in event_source.iter_sse():
+                    yield event.data
+
+    model_dropdown = gr.Dropdown(
+        # TODO: use output from /v1/models
+        choices=[config.whisper.model],
+        label="Model",
+        value=config.whisper.model,
+    )
+    task_dropdown = gr.Dropdown(
+        choices=[task.value for task in Task],
+        label="Task",
+        value=Task.TRANSCRIPTION,
+    )
+    temperature_slider = gr.Slider(
+        minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0
+    )
+    stream_checkbox = gr.Checkbox(label="Stream", value=True)
+    demo = gr.Interface(
+        title="Whisper Playground",
+        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",
+        inputs=[
+            gr.Audio(type="filepath"),
+            model_dropdown,
+            task_dropdown,
+            temperature_slider,
+            stream_checkbox,
+        ],
+        fn=handler,
+        outputs="text",
+    )
+    return demo
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -2,10 +2,10 @@
 
 import asyncio
 import time
-from contextlib import asynccontextmanager
 from io import BytesIO
 from typing import Annotated, Generator, Iterable, Literal, OrderedDict
 
+import gradio as gr
 import huggingface_hub
 from fastapi import (
     FastAPI,
@@ -33,8 +33,10 @@
     SAMPLES_PER_SECOND,
     Language,
     ResponseFormat,
+    Task,
     config,
 )
+from faster_whisper_server.gradio_app import create_gradio_demo
 from faster_whisper_server.logger import logger
 from faster_whisper_server.server_models import (
     ModelObject,
@@ -71,16 +73,7 @@
     return whisper
 
 
-@asynccontextmanager
-async def lifespan(_: FastAPI):
-    load_model(config.whisper.model)
-    yield
-    for model in loaded_models.keys():
-        logger.info(f"Unloading {model}")
-        del loaded_models[model]
-
-
-app = FastAPI(lifespan=lifespan)
+app = FastAPI()
 
 
 @app.get("/health")
@@ -210,7 +203,7 @@
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
-        task="translate",
+        task=Task.TRANSLATION,
         initial_prompt=prompt,
         temperature=temperature,
         vad_filter=True,
@@ -251,7 +244,7 @@
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
-        task="transcribe",
+        task=Task.TRANSCRIPTION,
         language=language,
         initial_prompt=prompt,
         word_timestamps="word" in timestamp_granularities,
@@ -353,3 +346,6 @@
     if not ws.client_state == WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
         await ws.close()
+
+
+app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
Add a comment
List