

fix: `timestamp_granularities[]` handling (#28, #58, #81)
@58273d963c5fb8597b3e99a26a9b0ae0293b50b5
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
@@ -9,6 +9,7 @@
     APIRouter,
     Form,
     Query,
+    Request,
     Response,
     UploadFile,
     WebSocket,
@@ -30,6 +31,8 @@
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
 from faster_whisper_server.server_models import (
+    DEFAULT_TIMESTAMP_GRANULARITIES,
+    TIMESTAMP_GRANULARITIES_COMBINATIONS,
     TimestampGranularities,
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
@@ -150,6 +153,18 @@
     return segments_to_response(segments, transcription_info, response_format)
 
 
+# HACK: Since Form() doesn't support `alias`, we need to use a workaround.
+async def get_timestamp_granularities(request: Request) -> TimestampGranularities:
+    form = await request.form()
+    if form.get("timestamp_granularities[]") is None:
+        return DEFAULT_TIMESTAMP_GRANULARITIES
+    timestamp_granularities = form.getlist("timestamp_granularities[]")
+    assert (
+        timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS
+    ), f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`."
+    return timestamp_granularities
+
+
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @router.post(
@@ -159,6 +174,7 @@
 def transcribe_file(
     config: ConfigDependency,
     model_manager: ModelManagerDependency,
+    request: Request,
     file: Annotated[UploadFile, Form()],
     model: Annotated[ModelName | None, Form()] = None,
     language: Annotated[Language | None, Form()] = None,
@@ -167,6 +183,7 @@
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
         TimestampGranularities,
+        # WARN: `alias` doesn't actually work.
         Form(alias="timestamp_granularities[]"),
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
@@ -178,6 +195,11 @@
         language = config.default_language
     if response_format is None:
         response_format = config.default_response_format
+    timestamp_granularities = asyncio.run(get_timestamp_granularities(request))
+    if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != ResponseFormat.VERBOSE_JSON:
+        logger.warning(
+            "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
+        )
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
--- src/faster_whisper_server/server_models.py
+++ src/faster_whisper_server/server_models.py
@@ -29,7 +29,7 @@
     language: str
     duration: float
     text: str
-    words: list[Word]
+    words: list[Word] | None
     segments: list[Segment]
 
     @classmethod
@@ -38,7 +38,7 @@
             language=transcription_info.language,
             duration=segment.end - segment.start,
             text=segment.text,
-            words=(segment.words if isinstance(segment.words, list) else []),
+            words=segment.words if transcription_info.transcription_options.word_timestamps else None,
             segments=[segment],
         )
 
@@ -51,7 +51,7 @@
             duration=transcription_info.duration,
             text=segments_to_text(segments),
             segments=segments,
-            words=Word.from_segments(segments),
+            words=Word.from_segments(segments) if transcription_info.transcription_options.word_timestamps else None,
         )
 
     @classmethod
@@ -112,6 +112,7 @@
 TimestampGranularities = list[Literal["segment", "word"]]
 
 
+DEFAULT_TIMESTAMP_GRANULARITIES: TimestampGranularities = ["segment"]
 TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [
     [],  # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities
     ["segment"],
--- /dev/null
+++ tests/api_timestamp_granularities_test.py
@@ -0,0 +1,43 @@
+"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`."""  # noqa: E501
+
+from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
+from openai import AsyncOpenAI
+import pytest
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
+async def test_api_json_response_format_and_timestamp_granularities_combinations(
+    openai_client: AsyncOpenAI,
+    timestamp_granularities: TimestampGranularities,
+) -> None:
+    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230
+
+    await openai_client.audio.transcriptions.create(
+        file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
+    )
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
+async def test_api_verbose_json_response_format_and_timestamp_granularities_combinations(
+    openai_client: AsyncOpenAI,
+    timestamp_granularities: TimestampGranularities,
+) -> None:
+    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230
+
+    transcription = await openai_client.audio.transcriptions.create(
+        file=audio_file,
+        model="whisper-1",
+        response_format="verbose_json",
+        timestamp_granularities=timestamp_granularities,
+    )
+
+    assert transcription.__pydantic_extra__
+    if "word" in timestamp_granularities:
+        assert transcription.__pydantic_extra__.get("segments") is not None
+        assert transcription.__pydantic_extra__.get("words") is not None
+    else:
+        # Unless explicitly requested, words are not present
+        assert transcription.__pydantic_extra__.get("segments") is not None
+        assert transcription.__pydantic_extra__.get("words") is None
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?