Fedir Zadniprovskyi 2024-07-20
feat: handle srt and vtt response formats
@47d26173bf1f002536c51d26ca56ba53e47c1dac
faster_whisper_server/config.py
--- faster_whisper_server/config.py
+++ faster_whisper_server/config.py
@@ -15,35 +15,8 @@
     TEXT = "text"
     JSON = "json"
     VERBOSE_JSON = "verbose_json"
-    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response. # noqa: E501
-
-    # VTT = "vtt" # TODO
-    # 1
-    # 00:00:00,000 --> 00:00:09,220
-    # In his video on Large Language Models or LLMs, OpenAI co-founder and YouTuber Andrej Karpathy
-    #
-    # 2
-    # 00:00:09,220 --> 00:00:12,280
-    # likened LLMs to operating systems.
-    #
-    # 3
-    # 00:00:12,280 --> 00:00:13,280
-    # Karpathy said,
-    #
-    # SRT = "srt" # TODO
-    # WEBVTT
-    #
-    # 00:00:00.000 --> 00:00:09.220
-    # In his video on Large Language Models or LLMs, OpenAI co-founder and YouTuber Andrej Karpathy
-    #
-    # 00:00:09.220 --> 00:00:12.280
-    # likened LLMs to operating systems.
-    #
-    # 00:00:12.280 --> 00:00:13.280
-    # Karpathy said,
-    #
-    # 00:00:13.280 --> 00:00:19.799
-    # I see a lot of equivalence between this new LLM OS and operating systems of today.
+    SRT = "srt"
+    VTT = "vtt"
 
 
 class Device(enum.StrEnum):
faster_whisper_server/core.py
--- faster_whisper_server/core.py
+++ faster_whisper_server/core.py
@@ -172,6 +172,62 @@
     return "".join(segment.text for segment in segments).strip()
 
 
+def srt_format_timestamp(ts: float) -> str:
+    hours = ts // 3600
+    minutes = (ts % 3600) // 60
+    seconds = ts % 60
+    milliseconds = (ts * 1000) % 1000
+    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
+
+
+def test_srt_format_timestamp() -> None:
+    assert srt_format_timestamp(0.0) == "00:00:00,000"
+    assert srt_format_timestamp(1.0) == "00:00:01,000"
+    assert srt_format_timestamp(1.234) == "00:00:01,234"
+    assert srt_format_timestamp(60.0) == "00:01:00,000"
+    assert srt_format_timestamp(61.0) == "00:01:01,000"
+    assert srt_format_timestamp(61.234) == "00:01:01,234"
+    assert srt_format_timestamp(3600.0) == "01:00:00,000"
+    assert srt_format_timestamp(3601.0) == "01:00:01,000"
+    assert srt_format_timestamp(3601.234) == "01:00:01,234"
+    assert srt_format_timestamp(23423.4234) == "06:30:23,423"
+
+
+def vtt_format_timestamp(ts: float) -> str:
+    hours = ts // 3600
+    minutes = (ts % 3600) // 60
+    seconds = ts % 60
+    milliseconds = (ts * 1000) % 1000
+    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
+
+
+def test_vtt_format_timestamp() -> None:
+    assert vtt_format_timestamp(0.0) == "00:00:00.000"
+    assert vtt_format_timestamp(1.0) == "00:00:01.000"
+    assert vtt_format_timestamp(1.234) == "00:00:01.234"
+    assert vtt_format_timestamp(60.0) == "00:01:00.000"
+    assert vtt_format_timestamp(61.0) == "00:01:01.000"
+    assert vtt_format_timestamp(61.234) == "00:01:01.234"
+    assert vtt_format_timestamp(3600.0) == "01:00:00.000"
+    assert vtt_format_timestamp(3601.0) == "01:00:01.000"
+    assert vtt_format_timestamp(3601.234) == "01:00:01.234"
+    assert vtt_format_timestamp(23423.4234) == "06:30:23.423"
+
+
+def segments_to_vtt(segment: Segment, i: int) -> str:
+    start = segment.start if i > 0 else 0.0
+    result = f"{vtt_format_timestamp(start)} --> {vtt_format_timestamp(segment.end)}\n{segment.text}\n\n"
+
+    if i == 0:
+        return f"WEBVTT\n\n{result}"
+    else:
+        return result
+
+
+def segments_to_srt(segment: Segment, i: int) -> str:
+    return f"{i + 1}\n{srt_format_timestamp(segment.start)} --> {srt_format_timestamp(segment.end)}\n{segment.text}\n\n"
+
+
 def canonicalize_word(text: str) -> str:
     text = text.lower()
     # Remove non-alphabetic characters using regular expression
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -33,7 +33,7 @@
     Task,
     config,
 )
-from faster_whisper_server.core import Segment, segments_to_text
+from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
 from faster_whisper_server.server_models import (
     ModelListResponse,
@@ -154,14 +154,28 @@
     segments: Iterable[Segment],
     transcription_info: TranscriptionInfo,
     response_format: ResponseFormat,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
+) -> Response:
     segments = list(segments)
     if response_format == ResponseFormat.TEXT:  # noqa: RET503
-        return segments_to_text(segments)
+        return Response(segments_to_text(segments), media_type="text/plain")
     elif response_format == ResponseFormat.JSON:
-        return TranscriptionJsonResponse.from_segments(segments)
+        return Response(
+            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
+            media_type="application/json",
+        )
     elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info)
+        return Response(
+            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
+            media_type="application/json",
+        )
+    elif response_format == ResponseFormat.VTT:
+        return Response(
+            "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
+        )
+    elif response_format == ResponseFormat.SRT:
+        return Response(
+            "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
+        )
 
 
 def format_as_sse(data: str) -> str:
@@ -174,13 +188,17 @@
     response_format: ResponseFormat,
 ) -> StreamingResponse:
     def segment_responses() -> Generator[str, None, None]:
-        for segment in segments:
+        for i, segment in enumerate(segments):
             if response_format == ResponseFormat.TEXT:
                 data = segment.text
             elif response_format == ResponseFormat.JSON:
                 data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
             elif response_format == ResponseFormat.VERBOSE_JSON:
                 data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
+            elif response_format == ResponseFormat.VTT:
+                data = segments_to_vtt(segment, i)
+            elif response_format == ResponseFormat.SRT:
+                data = segments_to_srt(segment, i)
             yield format_as_sse(data)
 
     return StreamingResponse(segment_responses(), media_type="text/event-stream")
@@ -211,7 +229,7 @@
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
+) -> Response | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -247,7 +265,7 @@
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
+) -> Response | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
pyproject.toml
--- pyproject.toml
+++ pyproject.toml
@@ -18,7 +18,7 @@
 ]
 
 [project.optional-dependencies]
-dev = ["ruff==0.5.3", "pytest", "basedpyright==1.13.0", "pytest-xdist"]
+dev = ["ruff==0.5.3", "pytest", "webvtt-py", "srt", "basedpyright==1.13.0", "pytest-xdist"]
 
 other = ["youtube-dl @ git+https://github.com/ytdl-org/youtube-dl.git@37cea84f775129ad715b9bcd617251c831fcc980", "aider-chat==0.39.0"]
 
requirements-all.txt
--- requirements-all.txt
+++ requirements-all.txt
@@ -496,7 +496,7 @@
     # via aider-chat
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -524,11 +524,13 @@
     # via
     #   aider-chat
     #   beautifulsoup4
+srt==3.5.3
+    # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
 streamlit==1.35.0
     # via aider-chat
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tenacity==8.3.0
     # via
@@ -623,6 +625,8 @@
     # via
     #   gradio-client
     #   uvicorn
+webvtt-py==0.5.1
+    # via faster-whisper-server (pyproject.toml)
 yarl==1.9.4
     # via
     #   aider-chat
requirements-dev.txt
--- requirements-dev.txt
+++ requirements-dev.txt
@@ -146,7 +146,7 @@
     #   pandas
 onnxruntime==1.18.1
     # via faster-whisper
-openai==1.35.15
+openai==1.36.0
     # via faster-whisper-server (pyproject.toml)
 orjson==3.10.6
     # via gradio
@@ -235,7 +235,7 @@
     #   gradio
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -248,9 +248,11 @@
     #   openai
 soundfile==0.12.1
     # via faster-whisper-server (pyproject.toml)
+srt==3.5.3
+    # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tokenizers==0.19.1
     # via faster-whisper
@@ -295,3 +297,5 @@
     # via
     #   gradio-client
     #   uvicorn
+webvtt-py==0.5.1
+    # via faster-whisper-server (pyproject.toml)
requirements.txt
--- requirements.txt
+++ requirements.txt
@@ -138,7 +138,7 @@
     #   pandas
 onnxruntime==1.18.1
     # via faster-whisper
-openai==1.35.15
+openai==1.36.0
     # via faster-whisper-server (pyproject.toml)
 orjson==3.10.6
     # via gradio
@@ -216,7 +216,7 @@
     # via gradio
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -231,7 +231,7 @@
     # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tokenizers==0.19.1
     # via faster-whisper
tests/conftest.py
--- tests/conftest.py
+++ tests/conftest.py
@@ -1,10 +1,12 @@
 from collections.abc import Generator
 import logging
+import os
 
 from fastapi.testclient import TestClient
 from openai import OpenAI
 import pytest
 
+os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
 from faster_whisper_server.main import app
 
 disable_loggers = ["multipart.multipart", "faster_whisper"]
tests/sse_test.py
--- tests/sse_test.py
+++ tests/sse_test.py
@@ -4,6 +4,9 @@
 from fastapi.testclient import TestClient
 from httpx_sse import connect_sse
 import pytest
+import srt
+import webvtt
+import webvtt.vtt
 
 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
@@ -61,3 +64,38 @@
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
             TranscriptionVerboseJsonResponse(**json.loads(event.data))
+
+
+def test_transcription_vtt(client: TestClient) -> None:
+    with open("audio.wav", "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": ("audio.wav", data, "audio/wav")},
+        "data": {"response_format": "vtt", "stream": False},
+    }
+    response = client.post("/v1/audio/transcriptions", **kwargs)
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "text/vtt; charset=utf-8"
+    text = response.text
+    webvtt.from_string(text)
+    text = text.replace("WEBVTT", "YO")
+    with pytest.raises(webvtt.vtt.MalformedFileError):
+        webvtt.from_string(text)
+
+
+def test_transcription_srt(client: TestClient) -> None:
+    with open("audio.wav", "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": ("audio.wav", data, "audio/wav")},
+        "data": {"response_format": "srt", "stream": False},
+    }
+    response = client.post("/v1/audio/transcriptions", **kwargs)
+    assert response.status_code == 200
+    assert "text/plain" in response.headers["content-type"]
+
+    text = response.text
+    list(srt.parse(text))
+    text = text.replace("1", "YO")
+    with pytest.raises(srt.SRTParseError):
+        list(srt.parse(text))