Fedir Zadniprovskyi 2024-06-23
extract segments to response logic
@3673dfa4e158fa02fa4fea5be0f6e17d12cae09f
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -4,7 +4,7 @@
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated, Generator, Literal, OrderedDict
+from typing import Annotated, Generator, Iterable, Literal, OrderedDict
 
 import huggingface_hub
 from fastapi import (
@@ -21,6 +21,7 @@
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
+from faster_whisper.transcribe import Segment, TranscriptionInfo
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 from huggingface_hub.hf_api import ModelInfo
 from pydantic import AfterValidator
@@ -132,8 +133,46 @@
     )
 
 
+def segments_to_response(
+    segments: Iterable[Segment],
+    transcription_info: TranscriptionInfo,
+    response_format: ResponseFormat,
+) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
+    segments = list(segments)
+    if response_format == ResponseFormat.TEXT:
+        return utils.segments_text(segments)
+    elif response_format == ResponseFormat.JSON:
+        return TranscriptionJsonResponse.from_segments(segments)
+    elif response_format == ResponseFormat.VERBOSE_JSON:
+        return TranscriptionVerboseJsonResponse.from_segments(
+            segments, transcription_info
+        )
+
+
 def format_as_sse(data: str) -> str:
     return f"data: {data}\n\n"
+
+
+def segments_to_streaming_response(
+    segments: Iterable[Segment],
+    transcription_info: TranscriptionInfo,
+    response_format: ResponseFormat,
+) -> StreamingResponse:
+    def segment_responses() -> Generator[str, None, None]:
+        for segment in segments:
+            if response_format == ResponseFormat.TEXT:
+                data = segment.text
+            elif response_format == ResponseFormat.JSON:
+                data = TranscriptionJsonResponse.from_segments(
+                    [segment]
+                ).model_dump_json()
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                data = TranscriptionVerboseJsonResponse.from_segment(
+                    segment, transcription_info
+                ).model_dump_json()
+            yield format_as_sse(data)
+
+    return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
 def handle_default_openai_model(model_name: str) -> str:
@@ -168,7 +207,6 @@
     | TranscriptionVerboseJsonResponse
     | StreamingResponse
 ):
-    start = time.perf_counter()
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -178,36 +216,12 @@
         vad_filter=True,
     )
 
-    if not stream:
-        segments = list(segments)
-        logger.info(
-            f"Translated {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
+    if stream:
+        return segments_to_streaming_response(
+            segments, transcription_info, response_format
         )
-        if response_format == ResponseFormat.TEXT:
-            return utils.segments_text(segments)
-        elif response_format == ResponseFormat.JSON:
-            return TranscriptionJsonResponse.from_segments(segments)
-        elif response_format == ResponseFormat.VERBOSE_JSON:
-            return TranscriptionVerboseJsonResponse.from_segments(
-                segments, transcription_info
-            )
     else:
-
-        def segment_responses() -> Generator[str, None, None]:
-            for segment in segments:
-                if response_format == ResponseFormat.TEXT:
-                    data = segment.text
-                elif response_format == ResponseFormat.JSON:
-                    data = TranscriptionJsonResponse.from_segments(
-                        [segment]
-                    ).model_dump_json()
-                elif response_format == ResponseFormat.VERBOSE_JSON:
-                    data = TranscriptionVerboseJsonResponse.from_segment(
-                        segment, transcription_info
-                    ).model_dump_json()
-                yield format_as_sse(data)
-
-        return StreamingResponse(segment_responses(), media_type="text/event-stream")
+        return segments_to_response(segments, transcription_info, response_format)
 
 
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
@@ -234,7 +248,6 @@
     | TranscriptionVerboseJsonResponse
     | StreamingResponse
 ):
-    start = time.perf_counter()
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -246,39 +259,12 @@
         vad_filter=True,
     )
 
-    if not stream:
-        segments = list(segments)
-        logger.info(
-            f"Transcribed {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
+    if stream:
+        return segments_to_streaming_response(
+            segments, transcription_info, response_format
         )
-        if response_format == ResponseFormat.TEXT:
-            return utils.segments_text(segments)
-        elif response_format == ResponseFormat.JSON:
-            return TranscriptionJsonResponse.from_segments(segments)
-        elif response_format == ResponseFormat.VERBOSE_JSON:
-            return TranscriptionVerboseJsonResponse.from_segments(
-                segments, transcription_info
-            )
     else:
-
-        def segment_responses() -> Generator[str, None, None]:
-            for segment in segments:
-                logger.info(
-                    f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
-                )
-                if response_format == ResponseFormat.TEXT:
-                    data = segment.text
-                elif response_format == ResponseFormat.JSON:
-                    data = TranscriptionJsonResponse.from_segments(
-                        [segment]
-                    ).model_dump_json()
-                elif response_format == ResponseFormat.VERBOSE_JSON:
-                    data = TranscriptionVerboseJsonResponse.from_segment(
-                        segment, transcription_info
-                    ).model_dump_json()
-                yield format_as_sse(data)
-
-        return StreamingResponse(segment_responses(), media_type="text/event-stream")
+        return segments_to_response(segments, transcription_info, response_format)
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
Add a comment
List