

test: capture openai's param handling
@db3aa0e9b3d7108bbff1c81322ca0251af245292
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
... | ... | @@ -3,7 +3,7 @@ |
3 | 3 |
import asyncio |
4 | 4 |
from io import BytesIO |
5 | 5 |
import logging |
6 |
-from typing import TYPE_CHECKING, Annotated, Literal |
|
6 |
+from typing import TYPE_CHECKING, Annotated |
|
7 | 7 |
|
8 | 8 |
from fastapi import ( |
9 | 9 |
APIRouter, |
... | ... | @@ -30,6 +30,7 @@ |
30 | 30 |
from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt |
31 | 31 |
from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config |
32 | 32 |
from faster_whisper_server.server_models import ( |
33 |
+ TimestampGranularities, |
|
33 | 34 |
TranscriptionJsonResponse, |
34 | 35 |
TranscriptionVerboseJsonResponse, |
35 | 36 |
) |
... | ... | @@ -165,7 +166,7 @@ |
165 | 166 |
response_format: Annotated[ResponseFormat | None, Form()] = None, |
166 | 167 |
temperature: Annotated[float, Form()] = 0.0, |
167 | 168 |
timestamp_granularities: Annotated[ |
168 |
- list[Literal["segment", "word"]], |
|
169 |
+ TimestampGranularities, |
|
169 | 170 |
Form(alias="timestamp_granularities[]"), |
170 | 171 |
] = ["segment"], |
171 | 172 |
stream: Annotated[bool, Form()] = False, |
--- src/faster_whisper_server/server_models.py
+++ src/faster_whisper_server/server_models.py
... | ... | @@ -107,3 +107,15 @@ |
107 | 107 |
] |
108 | 108 |
}, |
109 | 109 |
) |
110 |
+ |
|
111 |
+ |
|
112 |
+TimestampGranularities = list[Literal["segment", "word"]] |
|
113 |
+ |
|
114 |
+ |
|
115 |
+TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [ |
|
116 |
+ [], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities |
|
117 |
+ ["segment"], |
|
118 |
+ ["word"], |
|
119 |
+ ["word", "segment"], |
|
120 |
+ ["segment", "word"], # same as ["word", "segment"] but order is different |
|
121 |
+] |
--- tests/conftest.py
+++ tests/conftest.py
... | ... | @@ -5,7 +5,7 @@ |
5 | 5 |
from fastapi.testclient import TestClient |
6 | 6 |
from faster_whisper_server.main import create_app |
7 | 7 |
from httpx import ASGITransport, AsyncClient |
8 |
-from openai import OpenAI |
|
8 |
+from openai import AsyncOpenAI, OpenAI |
|
9 | 9 |
import pytest |
10 | 10 |
import pytest_asyncio |
11 | 11 |
|
... | ... | @@ -35,3 +35,10 @@ |
35 | 35 |
@pytest.fixture() |
36 | 36 |
def openai_client(client: TestClient) -> OpenAI: |
37 | 37 |
return OpenAI(api_key="cant-be-empty", http_client=client) |
38 |
+ |
|
39 |
+ |
|
40 |
+@pytest.fixture() |
|
41 |
+def actual_openai_client() -> AsyncOpenAI: |
|
42 |
+ return AsyncOpenAI( |
|
43 |
+ base_url="https://api.openai.com/v1" |
|
44 |
+ ) # `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value
+++ tests/openai_timestamp_granularities_test.py
... | ... | @@ -0,0 +1,56 @@ |
1 | +"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501 | |
2 | + | |
3 | +from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities | |
4 | +from openai import AsyncOpenAI, BadRequestError | |
5 | +import pytest | |
6 | + | |
7 | + | |
8 | +@pytest.mark.asyncio() | |
9 | +@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) | |
10 | +async def test_openai_json_response_format_and_timestamp_granularities_combinations( | |
11 | + actual_openai_client: AsyncOpenAI, | |
12 | + timestamp_granularities: TimestampGranularities, | |
13 | +) -> None: | |
14 | + audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 | |
15 | + | |
16 | + if "word" in timestamp_granularities: | |
17 | + with pytest.raises(BadRequestError): | |
18 | + await actual_openai_client.audio.transcriptions.create( | |
19 | + file=audio_file, | |
20 | + model="whisper-1", | |
21 | + response_format="json", | |
22 | + timestamp_granularities=timestamp_granularities, | |
23 | + ) | |
24 | + else: | |
25 | + await actual_openai_client.audio.transcriptions.create( | |
26 | + file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities | |
27 | + ) | |
28 | + | |
29 | + | |
30 | +@pytest.mark.asyncio() | |
31 | +@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) | |
32 | +async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations( | |
33 | + actual_openai_client: AsyncOpenAI, | |
34 | + timestamp_granularities: TimestampGranularities, | |
35 | +) -> None: | |
36 | + audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 | |
37 | + | |
38 | + transcription = await actual_openai_client.audio.transcriptions.create( | |
39 | + file=audio_file, | |
40 | + model="whisper-1", | |
41 | + response_format="verbose_json", | |
42 | + timestamp_granularities=timestamp_granularities, | |
43 | + ) | |
44 | + | |
45 | + assert transcription.__pydantic_extra__ | |
46 | + if timestamp_granularities == ["word"]: | |
47 | + # This is an exception where segments are not present | |
48 | + assert transcription.__pydantic_extra__.get("segments") is None | |
49 | + assert transcription.__pydantic_extra__.get("words") is not None | |
50 | + elif "word" in timestamp_granularities: | |
51 | + assert transcription.__pydantic_extra__.get("segments") is not None | |
52 | + assert transcription.__pydantic_extra__.get("words") is not None | |
53 | + else: | |
54 | + # Unless explicitly requested, words are not present | |
55 | + assert transcription.__pydantic_extra__.get("segments") is not None | |
56 | + assert transcription.__pydantic_extra__.get("words") is None |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?