Fedir Zadniprovskyi 01-10
feat: gradio speech generation tab
@9cf1e387e26ffd749313673f0aa31470aa090ef3
src/faster_whisper_server/gradio_app.py
--- src/faster_whisper_server/gradio_app.py
+++ src/faster_whisper_server/gradio_app.py
@@ -7,6 +7,15 @@
 from openai import OpenAI
 
 from faster_whisper_server.config import Config, Task
+from faster_whisper_server.hf_utils import PiperModel
+
+# FIX: this won't work on ARM
+from faster_whisper_server.routers.speech import (
+    DEFAULT_VOICE,
+    MAX_SAMPLE_RATE,
+    MIN_SAMPLE_RATE,
+    SUPPORTED_RESPONSE_FORMATS,
+)
 
 TRANSCRIPTION_ENDPOINT = "/v1/audio/transcriptions"
 TRANSLATION_ENDPOINT = "/v1/audio/translations"
@@ -14,12 +23,15 @@
 TIMEOUT = httpx.Timeout(timeout=TIMEOUT_SECONDS)
 
 
-def create_gradio_demo(config: Config) -> gr.Blocks:
+def create_gradio_demo(config: Config) -> gr.Blocks:  # noqa: C901, PLR0915
     base_url = f"http://{config.host}:{config.port}"
     http_client = httpx.Client(base_url=base_url, timeout=TIMEOUT)
     openai_client = OpenAI(base_url=f"{base_url}/v1", api_key="cant-be-empty")
 
-    def handler(file_path: str, model: str, task: Task, temperature: float, stream: bool) -> Generator[str, None, None]:
+    # TODO: make async
+    def whisper_handler(
+        file_path: str, model: str, task: Task, temperature: float, stream: bool
+    ) -> Generator[str, None, None]:
         if task == Task.TRANSCRIBE:
             endpoint = TRANSCRIPTION_ENDPOINT
         elif task == Task.TRANSLATE:
@@ -65,7 +77,7 @@
                 for event in event_source.iter_sse():
                     yield event.data
 
-    def update_model_dropdown() -> gr.Dropdown:
+    def update_whisper_model_dropdown() -> gr.Dropdown:
         models = openai_client.models.list().data
         model_names: list[str] = [model.id for model in models]
         assert config.whisper.model in model_names
@@ -73,37 +85,100 @@
         other_models = [model for model in model_names if model not in recommended_models]
         model_names = list(recommended_models) + other_models
         return gr.Dropdown(
-            # no idea why it's complaining
-            choices=model_names,  # pyright: ignore[reportArgumentType]
+            choices=model_names,
             label="Model",
             value=config.whisper.model,
         )
 
-    model_dropdown = gr.Dropdown(
-        choices=[config.whisper.model],
-        label="Model",
-        value=config.whisper.model,
-    )
-    task_dropdown = gr.Dropdown(
-        choices=[task.value for task in Task],
-        label="Task",
-        value=Task.TRANSCRIBE,
-    )
-    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
-    stream_checkbox = gr.Checkbox(label="Stream", value=True)
-    with gr.Interface(
-        title="Whisper Playground",
-        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",  # noqa: E501
-        inputs=[
-            gr.Audio(type="filepath"),
-            model_dropdown,
-            task_dropdown,
-            temperature_slider,
-            stream_checkbox,
-        ],
-        fn=handler,
-        outputs="text",
-        analytics_enabled=False,  # disable telemetry
-    ) as demo:
-        demo.load(update_model_dropdown, inputs=None, outputs=model_dropdown)
+    def update_piper_voices_dropdown() -> gr.Dropdown:
+        res = http_client.get("/v1/audio/speech/voices").raise_for_status()
+        piper_models = [PiperModel.model_validate(x) for x in res.json()]
+        return gr.Dropdown(choices=[model.voice for model in piper_models], label="Voice", value=DEFAULT_VOICE)
+
+    # TODO: make async
+    def handle_audio_speech(text: str, voice: str, response_format: str, speed: float, sample_rate: int | None) -> Path:
+        res = openai_client.audio.speech.create(
+            input=text,
+            model="piper",
+            voice=voice,  # pyright: ignore[reportArgumentType]
+            response_format=response_format,  # pyright: ignore[reportArgumentType]
+            speed=speed,
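+            # sample_rate is not part of the OpenAI spec; extra_body forwards it to this server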
+            extra_body={"sample_rate": sample_rate},
+        )
+        audio_bytes = res.response.read()
+        file_path = Path(f"audio.{response_format}")
+        with file_path.open("wb") as file:
+            file.write(audio_bytes)
+        return file_path
+
+    with gr.Blocks(title="faster-whisper-server Playground") as demo:
+        gr.Markdown(
+            "### Consider supporting the project by starring the [repository on GitHub](https://github.com/fedirz/faster-whisper-server)."
+        )
+        with gr.Tab(label="Transcribe/Translate"):
+            audio = gr.Audio(type="filepath")
+            model_dropdown = gr.Dropdown(
+                choices=[config.whisper.model],
+                label="Model",
+                value=config.whisper.model,
+            )
+            task_dropdown = gr.Dropdown(
+                choices=[task.value for task in Task],
+                label="Task",
+                value=Task.TRANSCRIBE,
+            )
+            temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
+            stream_checkbox = gr.Checkbox(label="Stream", value=True)
+            button = gr.Button("Generate")
+
+            output = gr.Textbox()
+
+            # NOTE: the inputs order must match the `whisper_handler` signature
+            button.click(
+                whisper_handler, [audio, model_dropdown, task_dropdown, temperature_slider, stream_checkbox], output
+            )
+
+        with gr.Tab(label="Speech Generation"):
+            # TODO: add warning about ARM
+            text = gr.Textbox(label="Input Text")
+            voice_dropdown = gr.Dropdown(
+                choices=["en_US-amy-medium"],
+                label="Voice",
+                value="en_US-amy-medium",
+                info="""
+The last part of the voice name is the quality (x_low, low, medium, high).
+Each quality has a different default sample rate:
+- x_low: 16000 Hz
+- low: 16000 Hz
+- medium: 22050 Hz
+- high: 22050 Hz
+""",
+            )
+            response_format_dropdown = gr.Dropdown(

+                choices=SUPPORTED_RESPONSE_FORMATS,
+                label="Response Format",
+                value="wav",
+            )
+            speed_slider = gr.Slider(minimum=0.25, maximum=4.0, step=0.05, label="Speed", value=1.0)
+            sample_rate_slider = gr.Number(
+                minimum=MIN_SAMPLE_RATE,
+                maximum=MAX_SAMPLE_RATE,
+                label="Desired Sample Rate",
+                info="""
+Setting this will resample the generated audio to the desired sample rate.
+You may want to set this if you are going to use voices of different qualities but want to keep the same sample rate.
+Default: None (No resampling)
+""",
+                value=lambda: None,
+            )
+            button = gr.Button("Generate Speech")
+            output = gr.Audio(type="filepath")
+            button.click(
+                handle_audio_speech,
+                [text, voice_dropdown, response_format_dropdown, speed_slider, sample_rate_slider],
+                output,
+            )
+
+        demo.load(update_whisper_model_dropdown, inputs=None, outputs=model_dropdown)
+        demo.load(update_piper_voices_dropdown, inputs=None, outputs=voice_dropdown)
     return demo
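
Note on the new "Speech Generation" tab: it drives the standard OpenAI speech API, with the server-specific sample_rate parameter passed through extra_body. A minimal sketch of the same request outside Gradio (server address, voice, and output path are assumptions, not part of this commit):

from pathlib import Path

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="cant-be-empty")

res = client.audio.speech.create(
    input="Hello from Piper!",
    model="piper",
    voice="en_US-amy-medium",  # any voice listed by GET /v1/audio/speech/voices
    response_format="wav",
    speed=1.0,
    extra_body={"sample_rate": 16000},  # non-standard field, hence extra_body
)
# same read pattern as handle_audio_speech above
Path("speech.wav").write_bytes(res.response.read())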
src/faster_whisper_server/hf_utils.py
--- src/faster_whisper_server/hf_utils.py
+++ src/faster_whisper_server/hf_utils.py
@@ -1,5 +1,5 @@
 from collections.abc import Generator
-from functools import lru_cache
+from functools import cached_property, lru_cache
 import json
 import logging
 from pathlib import Path
@@ -8,7 +8,7 @@
 
 import huggingface_hub
 from huggingface_hub.constants import HF_HUB_CACHE
-from pydantic import BaseModel
+from pydantic import BaseModel, Field, computed_field
 
 from faster_whisper_server.api_models import Model
 
@@ -95,13 +95,51 @@
         yield transformed_model
 
 
+PiperVoiceQuality = Literal["x_low", "low", "medium", "high"]
+PIPER_VOICE_QUALITY_SAMPLE_RATE_MAP: dict[PiperVoiceQuality, int] = {
+    "x_low": 16000,
+    "low": 22050,
+    "medium": 22050,
+    "high": 22050,
+}
+
+
 class PiperModel(BaseModel):
-    id: str
+    """Similar structure to the GET /v1/models response but with extra fields."""
+
     object: Literal["model"] = "model"
     created: int
     owned_by: Literal["rhasspy"] = "rhasspy"
-    path: Path
-    config_path: Path
+    model_path: Path = Field(
+        examples=[
+            "/home/nixos/.cache/huggingface/hub/models--rhasspy--piper-voices/snapshots/3d796cc2f2c884b3517c527507e084f7bb245aea/en/en_US/amy/medium/en_US-amy-medium.onnx"
+        ]
+    )
+
+    @computed_field(examples=["rhasspy/piper-voices/en_US-amy-medium"])
+    @cached_property
+    def id(self) -> str:
+        return f"rhasspy/piper-voices/{self.model_path.name.removesuffix(".onnx")}"
+
+    @computed_field(examples=["rhasspy/piper-voices/en_US-amy-medium"])
+    @cached_property
+    def voice(self) -> str:
+        return self.model_path.name.removesuffix(".onnx")
+
+    @computed_field
+    @cached_property
+    def config_path(self) -> Path:
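+        # configs in rhasspy/piper-voices sit next to the weights as <name>.onnx.json,
+        # so the suffix is appended rather than replaced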
+        return Path(str(self.model_path) + ".json")
+
+    @computed_field
+    @cached_property
+    def quality(self) -> PiperVoiceQuality:
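+        # the quality is the last dash-separated part of the voice name,
+        # e.g. "en_US-amy-medium" -> "medium"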
+        return self.id.split("-")[-1]  # pyright: ignore[reportReturnType]
+
+    @computed_field
+    @cached_property
+    def sample_rate(self) -> int:
+        return PIPER_VOICE_QUALITY_SAMPLE_RATE_MAP[self.quality]
 
 
 def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None:
@@ -151,12 +189,9 @@
 def list_piper_models() -> Generator[PiperModel, None, None]:
     model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx")
     for model_weights_file in model_weights_files:
-        model_config_file = model_weights_file.with_suffix(".json")
         yield PiperModel(
-            id=model_weights_file.name,
             created=int(model_weights_file.stat().st_mtime),
-            path=model_weights_file,
-            config_path=model_config_file,
+            model_path=model_weights_file,
         )
 
 
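Note on PiperModel: everything except created is now derived from model_path, and config_path now appends ".json" instead of replacing the suffix, matching the *.onnx.json naming used in the rhasspy/piper-voices repo. A quick sketch of the derivations (the path below is shortened for readability):

from pathlib import Path

from faster_whisper_server.hf_utils import PiperModel

model = PiperModel(created=0, model_path=Path("en/en_US/amy/medium/en_US-amy-medium.onnx"))
assert model.voice == "en_US-amy-medium"
assert model.id == "rhasspy/piper-voices/en_US-amy-medium"
assert model.quality == "medium"
assert model.sample_rate == 22050
assert model.config_path.name == "en_US-amy-medium.onnx.json"
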
src/faster_whisper_server/routers/speech.py
--- src/faster_whisper_server/routers/speech.py
+++ src/faster_whisper_server/routers/speech.py
@@ -12,7 +12,11 @@
 import soundfile as sf
 
 from faster_whisper_server.dependencies import PiperModelManagerDependency
-from faster_whisper_server.hf_utils import read_piper_voices_config
+from faster_whisper_server.hf_utils import (
+    PiperModel,
+    list_piper_models,
+    read_piper_voices_config,
+)
 
 DEFAULT_MODEL = "piper"
 # https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-response_format
@@ -126,6 +130,14 @@
         ],
     )
     voice: Voice = DEFAULT_VOICE
+    """
+The last part of the voice name is the quality (x_low, low, medium, high).
+Each quality has a different default sample rate:
+- x_low: 16000 Hz
+- low: 16000 Hz
+- medium: 22050 Hz
+- high: 22050 Hz
+    """
     response_format: ResponseFormat = Field(
         DEFAULT_RESPONSE_FORMAT,
         description=f"The format to audio in. Supported formats are {", ".join(SUPPORTED_RESPONSE_FORMATS)}. {", ".join(UNSUPORTED_RESPONSE_FORMATS)} are not supported",  # noqa: E501
@@ -136,6 +148,7 @@
     """The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default."""
     sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE)
     """Desired sample rate to convert the generated audio to. If not provided, the model's default sample rate will be used."""  # noqa: E501
+    # TODO: document default sample rate for each voice quality
 
     # TODO: move into `Voice`
     @model_validator(mode="after")
@@ -163,3 +176,8 @@
             )
 
         return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}")
+
+
+@router.get("/v1/audio/speech/voices")
+def list_voices() -> list[PiperModel]:
+    return list(list_piper_models())
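
The new endpoint serializes each PiperModel with its computed fields, so clients get voice, quality, and sample_rate alongside the paths. A minimal sketch of consuming it (server address assumed; same raise_for_status chaining as in the Gradio app):

import httpx

res = httpx.get("http://localhost:8000/v1/audio/speech/voices").raise_for_status()
medium_voices = [v["voice"] for v in res.json() if v["quality"] == "medium"]
print(medium_voices)  # e.g. ['en_US-amy-medium', ...]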