Fedir Zadniprovskyi 01-10
chore: use async handlers in gradio
@b9c9ba8472852eea8e690d51caee53e274f40c25
src/faster_whisper_server/gradio_app.py
--- src/faster_whisper_server/gradio_app.py
+++ src/faster_whisper_server/gradio_app.py
@@ -1,10 +1,10 @@
-from collections.abc import Generator
+from collections.abc import AsyncGenerator
 from pathlib import Path
 
 import gradio as gr
 import httpx
-from httpx_sse import connect_sse
-from openai import OpenAI
+from httpx_sse import aconnect_sse
+from openai import AsyncOpenAI
 
 from faster_whisper_server.config import Config, Task
 from faster_whisper_server.hf_utils import PiperModel
@@ -25,13 +25,19 @@
 
 def create_gradio_demo(config: Config) -> gr.Blocks:  # noqa: C901, PLR0915
     base_url = f"http://{config.host}:{config.port}"
-    http_client = httpx.Client(base_url=base_url, timeout=TIMEOUT)
-    openai_client = OpenAI(base_url=f"{base_url}/v1", api_key="cant-be-empty")
+    # TODO: test that auth works
+    http_client = httpx.AsyncClient(
+        base_url=base_url,
+        timeout=TIMEOUT,
+        headers={"Authorization": f"Bearer {config.api_key}"} if config.api_key else {},
+    )
+    openai_client = AsyncOpenAI(
+        base_url=f"{base_url}/v1", api_key=config.api_key if config.api_key else "cant-be-empty"
+    )
 
-    # TODO: make async
-    def whisper_handler(
+    async def whisper_handler(
         file_path: str, model: str, task: Task, temperature: float, stream: bool
-    ) -> Generator[str, None, None]:
+    ) -> AsyncGenerator[str, None]:
         if task == Task.TRANSCRIBE:
             endpoint = TRANSCRIPTION_ENDPOINT
         elif task == Task.TRANSLATE:
@@ -39,15 +45,15 @@
 
         if stream:
             previous_transcription = ""
-            for transcription in streaming_audio_task(file_path, endpoint, temperature, model):
+            async for transcription in streaming_audio_task(file_path, endpoint, temperature, model):
                 previous_transcription += transcription
                 yield previous_transcription
         else:
-            yield audio_task(file_path, endpoint, temperature, model)
+            yield await audio_task(file_path, endpoint, temperature, model)
 
-    def audio_task(file_path: str, endpoint: str, temperature: float, model: str) -> str:
-        with Path(file_path).open("rb") as file:
-            response = http_client.post(
+    async def audio_task(file_path: str, endpoint: str, temperature: float, model: str) -> str:
+        with Path(file_path).open("rb") as file:  # noqa: ASYNC230
+            response = await http_client.post(
                 endpoint,
                 files={"file": file},
                 data={
@@ -60,10 +66,10 @@
         response.raise_for_status()
         return response.text
 
-    def streaming_audio_task(
+    async def streaming_audio_task(
         file_path: str, endpoint: str, temperature: float, model: str
-    ) -> Generator[str, None, None]:
-        with Path(file_path).open("rb") as file:
+    ) -> AsyncGenerator[str, None]:
+        with Path(file_path).open("rb") as file:  # noqa: ASYNC230
             kwargs = {
                 "files": {"file": file},
                 "data": {
@@ -73,12 +79,12 @@
                     "stream": True,
                 },
             }
-            with connect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
-                for event in event_source.iter_sse():
+            async with aconnect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
+                async for event in event_source.aiter_sse():
                     yield event.data
 
-    def update_whisper_model_dropdown() -> gr.Dropdown:
-        models = openai_client.models.list().data
+    async def update_whisper_model_dropdown() -> gr.Dropdown:
+        models = (await openai_client.models.list()).data
         model_names: list[str] = [model.id for model in models]
         assert config.whisper.model in model_names
         recommended_models = {model for model in model_names if model.startswith("Systran")}
@@ -90,14 +96,15 @@
             value=config.whisper.model,
         )
 
-    def update_piper_voices_dropdown() -> gr.Dropdown:
-        res = http_client.get("/v1/audio/speech/voices").raise_for_status()
+    async def update_piper_voices_dropdown() -> gr.Dropdown:
+        res = (await http_client.get("/v1/audio/speech/voices")).raise_for_status()
         piper_models = [PiperModel.model_validate(x) for x in res.json()]
         return gr.Dropdown(choices=[model.voice for model in piper_models], label="Voice", value=DEFAULT_VOICE)
 
-    # TODO: make async
-    def handle_audio_speech(text: str, voice: str, response_format: str, speed: float, sample_rate: int | None) -> Path:
-        res = openai_client.audio.speech.create(
+    async def handle_audio_speech(
+        text: str, voice: str, response_format: str, speed: float, sample_rate: int | None
+    ) -> Path:
+        res = await openai_client.audio.speech.create(
             input=text,
             model="piper",
             voice=voice,  # pyright: ignore[reportArgumentType]
@@ -107,7 +114,7 @@
         )
         audio_bytes = res.response.read()
         file_path = Path(f"audio.{response_format}")
-        with file_path.open("wb") as file:
+        with file_path.open("wb") as file:  # noqa: ASYNC230
             file.write(audio_bytes)
         return file_path
 
Add a comment
List