Fedir Zadniprovskyi 2024-09-22
test: capture openai's param handling
@db3aa0e9b3d7108bbff1c81322ca0251af245292
src/faster_whisper_server/routers/stt.py
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
@@ -3,7 +3,7 @@
 import asyncio
 from io import BytesIO
 import logging
-from typing import TYPE_CHECKING, Annotated, Literal
+from typing import TYPE_CHECKING, Annotated
 
 from fastapi import (
     APIRouter,
@@ -30,6 +30,7 @@
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
 from faster_whisper_server.server_models import (
+    TimestampGranularities,
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
 )
@@ -165,7 +166,7 @@
     response_format: Annotated[ResponseFormat | None, Form()] = None,
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
-        list[Literal["segment", "word"]],
+        TimestampGranularities,
         Form(alias="timestamp_granularities[]"),
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
src/faster_whisper_server/server_models.py
--- src/faster_whisper_server/server_models.py
+++ src/faster_whisper_server/server_models.py
@@ -107,3 +107,15 @@
             ]
         },
     )
+
+
+TimestampGranularities = list[Literal["segment", "word"]]
+
+
+TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [
+    [],  # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities
+    ["segment"],
+    ["word"],
+    ["word", "segment"],
+    ["segment", "word"],  # same as ["word", "segment"] but order is different
+]
tests/conftest.py
--- tests/conftest.py
+++ tests/conftest.py
@@ -5,7 +5,7 @@
 from fastapi.testclient import TestClient
 from faster_whisper_server.main import create_app
 from httpx import ASGITransport, AsyncClient
-from openai import OpenAI
+from openai import AsyncOpenAI, OpenAI
 import pytest
 import pytest_asyncio
 
@@ -35,3 +35,10 @@
 @pytest.fixture()
 def openai_client(client: TestClient) -> OpenAI:
     return OpenAI(api_key="cant-be-empty", http_client=client)
+
+
+@pytest.fixture()
+def actual_openai_client() -> AsyncOpenAI:
+    return AsyncOpenAI(
+        base_url="https://api.openai.com/v1"
+    )  # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
 
tests/openai_timestamp_granularities_test.py (added)
+++ tests/openai_timestamp_granularities_test.py
@@ -0,0 +1,56 @@
+"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters."""  # noqa: E501
+
+from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
+from openai import AsyncOpenAI, BadRequestError
+import pytest
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
+async def test_openai_json_response_format_and_timestamp_granularities_combinations(
+    actual_openai_client: AsyncOpenAI,
+    timestamp_granularities: TimestampGranularities,
+) -> None:
+    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230
+
+    if "word" in timestamp_granularities:
+        with pytest.raises(BadRequestError):
+            await actual_openai_client.audio.transcriptions.create(
+                file=audio_file,
+                model="whisper-1",
+                response_format="json",
+                timestamp_granularities=timestamp_granularities,
+            )
+    else:
+        await actual_openai_client.audio.transcriptions.create(
+            file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
+        )
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
+async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations(
+    actual_openai_client: AsyncOpenAI,
+    timestamp_granularities: TimestampGranularities,
+) -> None:
+    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230
+
+    transcription = await actual_openai_client.audio.transcriptions.create(
+        file=audio_file,
+        model="whisper-1",
+        response_format="verbose_json",
+        timestamp_granularities=timestamp_granularities,
+    )
+
+    assert transcription.__pydantic_extra__
+    if timestamp_granularities == ["word"]:
+        # This is an exception where segments are not present
+        assert transcription.__pydantic_extra__.get("segments") is None
+        assert transcription.__pydantic_extra__.get("words") is not None
+    elif "word" in timestamp_granularities:
+        assert transcription.__pydantic_extra__.get("segments") is not None
+        assert transcription.__pydantic_extra__.get("words") is not None
+    else:
+        # Unless explicitly requested, words are not present
+        assert transcription.__pydantic_extra__.get("segments") is not None
+        assert transcription.__pydantic_extra__.get("words") is None
Add a comment
List