Fedir Zadniprovskyi 2024-06-03
fix: streaming doesn't use sse #15
@762b5abb7e730aff697b8179f8d1982c07a6be5a
examples/youtube/script.sh
--- examples/youtube/script.sh
+++ examples/youtube/script.sh
@@ -14,7 +14,7 @@
 youtube-dl --extract-audio --audio-format mp3 -o the-evolution-of-the-operating-system.mp3 'https://www.youtube.com/watch?v=1lG7lFLXBIs'
 
 # Make a request to the API to transcribe the audio. The response will be streamed to the terminal and saved to a file. The video is 30 minutes long, so it might take a while to transcribe, especially if you are running this on a CPU. `Systran/faster-distil-whisper-large-v3` takes ~30 seconds on Nvidia L4. `Systran/faster-whisper-tiny.en` takes ~1 minute on Ryzen 7 7700X. The .txt file in the example was transcribed using `Systran/faster-distil-whisper-large-v3`.
-curl -s http://localhost:8000/v1/audio/transcriptions -F "file=@the-evolution-of-the-operating-system.mp3" -F "stream=true" -F "language=en" -F "response_format=text" | tee the-evolution-of-the-operating-system.txt
+curl -s http://localhost:8000/v1/audio/transcriptions -F "file=@the-evolution-of-the-operating-system.mp3" -F "language=en" -F "response_format=text" | tee the-evolution-of-the-operating-system.txt
 
 # Here I'm using `aichat` which is a CLI LLM client. You could use any other client that supports attaching/uploading files. https://github.com/sigoden/aichat
 aichat -m openai:gpt-4o -f the-evolution-of-the-operating-system.txt 'What companies are mentioned in the following Youtube video transcription? Responed with just a list of names'
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -4,7 +4,7 @@
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated, Literal, OrderedDict
+from typing import Annotated, Generator, Literal, OrderedDict
 
 import huggingface_hub
 from fastapi import (
@@ -127,6 +127,10 @@
     )
 
 
+def format_as_sse(data: str) -> str:
+    return f"data: {data}\n\n"
+
+
 @app.post("/v1/audio/translations")
 def translate_file(
     file: Annotated[UploadFile, Form()],
@@ -146,19 +150,6 @@
         vad_filter=True,
     )
 
-    def segment_responses():
-        for segment in segments:
-            if response_format == ResponseFormat.TEXT:
-                yield segment.text
-            elif response_format == ResponseFormat.JSON:
-                yield TranscriptionJsonResponse.from_segments(
-                    [segment]
-                ).model_dump_json()
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                yield TranscriptionVerboseJsonResponse.from_segment(
-                    segment, transcription_info
-                ).model_dump_json()
-
     if not stream:
         segments = list(segments)
         logger.info(
@@ -173,6 +164,21 @@
                 segments, transcription_info
             )
     else:
+
+        def segment_responses() -> Generator[str, None, None]:
+            for segment in segments:
+                if response_format == ResponseFormat.TEXT:
+                    data = segment.text
+                elif response_format == ResponseFormat.JSON:
+                    data = TranscriptionJsonResponse.from_segments(
+                        [segment]
+                    ).model_dump_json()
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    data = TranscriptionVerboseJsonResponse.from_segment(
+                        segment, transcription_info
+                    ).model_dump_json()
+                yield format_as_sse(data)
+
         return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
@@ -204,22 +210,6 @@
         vad_filter=True,
     )
 
-    def segment_responses():
-        for segment in segments:
-            logger.info(
-                f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
-            )
-            if response_format == ResponseFormat.TEXT:
-                yield segment.text
-            elif response_format == ResponseFormat.JSON:
-                yield TranscriptionJsonResponse.from_segments(
-                    [segment]
-                ).model_dump_json()
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                yield TranscriptionVerboseJsonResponse.from_segment(
-                    segment, transcription_info
-                ).model_dump_json()
-
     if not stream:
         segments = list(segments)
         logger.info(
@@ -234,6 +224,24 @@
                 segments, transcription_info
             )
     else:
+
+        def segment_responses() -> Generator[str, None, None]:
+            for segment in segments:
+                logger.info(
+                    f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
+                )
+                if response_format == ResponseFormat.TEXT:
+                    data = segment.text
+                elif response_format == ResponseFormat.JSON:
+                    data = TranscriptionJsonResponse.from_segments(
+                        [segment]
+                    ).model_dump_json()
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    data = TranscriptionVerboseJsonResponse.from_segment(
+                        segment, transcription_info
+                    ).model_dump_json()
+                yield format_as_sse(data)
+
         return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
 
tests/sse_test.py (added)
+++ tests/sse_test.py
@@ -0,0 +1,82 @@
+import json
+import os
+from typing import Generator
+
+import pytest
+from fastapi.testclient import TestClient
+from httpx_sse import connect_sse
+
+from faster_whisper_server.main import app
+from faster_whisper_server.server_models import (
+    TranscriptionJsonResponse,
+    TranscriptionVerboseJsonResponse,
+)
+
+
+@pytest.fixture()
+def client() -> Generator[TestClient, None, None]:
+    with TestClient(app) as client:
+        yield client
+
+
+FILE_PATHS = ["audio.wav"]  # HACK
+ENDPOINTS = [
+    "/v1/audio/transcriptions",
+    "/v1/audio/translations",
+]
+
+
+parameters = [
+    (file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS
+]
+
+
+@pytest.mark.parametrize("file_path,endpoint", parameters)
+def test_streaming_transcription_text(
+    client: TestClient, file_path: str, endpoint: str
+):
+    extension = os.path.splitext(file_path)[1]
+    with open(file_path, "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
+        "data": {"response_format": "text", "stream": True},
+    }
+    with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
+        for event in event_source.iter_sse():
+            print(event)
+            assert (
+                len(event.data) > 1
+            )  # HACK: 1 because of the space character that's always prepended
+
+
+@pytest.mark.parametrize("file_path,endpoint", parameters)
+def test_streaming_transcription_json(
+    client: TestClient, file_path: str, endpoint: str
+):
+    extension = os.path.splitext(file_path)[1]
+    with open(file_path, "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
+        "data": {"response_format": "json", "stream": True},
+    }
+    with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
+        for event in event_source.iter_sse():
+            TranscriptionJsonResponse(**json.loads(event.data))
+
+
+@pytest.mark.parametrize("file_path,endpoint", parameters)
+def test_streaming_transcription_verbose_json(
+    client: TestClient, file_path: str, endpoint: str
+):
+    extension = os.path.splitext(file_path)[1]
+    with open(file_path, "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": (f"audio.{extension}", data, f"audio/{extension}")},
+        "data": {"response_format": "verbose_json", "stream": True},
+    }
+    with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
+        for event in event_source.iter_sse():
+            TranscriptionVerboseJsonResponse(**json.loads(event.data))
Add a comment
List