

refactor: update response model names and module name
@b68af2b6a548ecd9a6a95165f0a63e74e9b3f47a
+++ src/faster_whisper_server/api_models.py
... | ... | @@ -0,0 +1,208 @@ |
1 | +from __future__ import annotations | |
2 | + | |
3 | +from typing import TYPE_CHECKING, Literal | |
4 | + | |
5 | +from pydantic import BaseModel, ConfigDict, Field | |
6 | + | |
7 | +from faster_whisper_server.text_utils import Transcription, canonicalize_word, segments_to_text | |
8 | + | |
9 | +if TYPE_CHECKING: | |
10 | + from collections.abc import Iterable | |
11 | + | |
12 | + import faster_whisper.transcribe | |
13 | + | |
14 | + | |
15 | +# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10909 | |
16 | +class TranscriptionWord(BaseModel): | |
17 | + start: float | |
18 | + end: float | |
19 | + word: str | |
20 | + probability: float | |
21 | + | |
22 | + @classmethod | |
23 | + def from_segments(cls, segments: Iterable[TranscriptionSegment]) -> list[TranscriptionWord]: | |
24 | + words: list[TranscriptionWord] = [] | |
25 | + for segment in segments: | |
26 | + # NOTE: a temporary "fix" for https://github.com/fedirz/faster-whisper-server/issues/58. | |
27 | + # TODO: properly address the issue | |
28 | + assert ( | |
29 | + segment.words is not None | |
30 | + ), "Segment must have words. If you are using an API ensure `timestamp_granularities[]=word` is set" | |
31 | + words.extend(segment.words) | |
32 | + return words | |
33 | + | |
34 | + def offset(self, seconds: float) -> None: | |
35 | + self.start += seconds | |
36 | + self.end += seconds | |
37 | + | |
38 | + @classmethod | |
39 | + def common_prefix(cls, a: list[TranscriptionWord], b: list[TranscriptionWord]) -> list[TranscriptionWord]: | |
40 | + i = 0 | |
41 | + while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word): | |
42 | + i += 1 | |
43 | + return a[:i] | |
44 | + | |
45 | + | |
46 | +# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10938 | |
47 | +class TranscriptionSegment(BaseModel): | |
48 | + id: int | |
49 | + seek: int | |
50 | + start: float | |
51 | + end: float | |
52 | + text: str | |
53 | + tokens: list[int] | |
54 | + temperature: float | |
55 | + avg_logprob: float | |
56 | + compression_ratio: float | |
57 | + no_speech_prob: float | |
58 | + words: list[TranscriptionWord] | None | |
59 | + | |
60 | + @classmethod | |
61 | + def from_faster_whisper_segments( | |
62 | + cls, segments: Iterable[faster_whisper.transcribe.Segment] | |
63 | + ) -> Iterable[TranscriptionSegment]: | |
64 | + for segment in segments: | |
65 | + yield cls( | |
66 | + id=segment.id, | |
67 | + seek=segment.seek, | |
68 | + start=segment.start, | |
69 | + end=segment.end, | |
70 | + text=segment.text, | |
71 | + tokens=segment.tokens, | |
72 | + temperature=segment.temperature, | |
73 | + avg_logprob=segment.avg_logprob, | |
74 | + compression_ratio=segment.compression_ratio, | |
75 | + no_speech_prob=segment.no_speech_prob, | |
76 | + words=[ | |
77 | + TranscriptionWord( | |
78 | + start=word.start, | |
79 | + end=word.end, | |
80 | + word=word.word, | |
81 | + probability=word.probability, | |
82 | + ) | |
83 | + for word in segment.words | |
84 | + ] | |
85 | + if segment.words is not None | |
86 | + else None, | |
87 | + ) | |
88 | + | |
89 | + | |
90 | +# https://platform.openai.com/docs/api-reference/audio/json-object | |
91 | +# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10924 | |
92 | +class CreateTranscriptionResponseJson(BaseModel): | |
93 | + text: str | |
94 | + | |
95 | + @classmethod | |
96 | + def from_segments(cls, segments: list[TranscriptionSegment]) -> CreateTranscriptionResponseJson: | |
97 | + return cls(text=segments_to_text(segments)) | |
98 | + | |
99 | + @classmethod | |
100 | + def from_transcription(cls, transcription: Transcription) -> CreateTranscriptionResponseJson: | |
101 | + return cls(text=transcription.text) | |
102 | + | |
103 | + | |
104 | +# https://platform.openai.com/docs/api-reference/audio/verbose-json-object | |
105 | +# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L11007 | |
106 | +class CreateTranscriptionResponseVerboseJson(BaseModel): | |
107 | + task: str = "transcribe" | |
108 | + language: str | |
109 | + duration: float | |
110 | + text: str | |
111 | + words: list[TranscriptionWord] | None | |
112 | + segments: list[TranscriptionSegment] | |
113 | + | |
114 | + @classmethod | |
115 | + def from_segment( | |
116 | + cls, segment: TranscriptionSegment, transcription_info: faster_whisper.transcribe.TranscriptionInfo | |
117 | + ) -> CreateTranscriptionResponseVerboseJson: | |
118 | + return cls( | |
119 | + language=transcription_info.language, | |
120 | + duration=segment.end - segment.start, | |
121 | + text=segment.text, | |
122 | + words=segment.words if transcription_info.transcription_options.word_timestamps else None, | |
123 | + segments=[segment], | |
124 | + ) | |
125 | + | |
126 | + @classmethod | |
127 | + def from_segments( | |
128 | + cls, segments: list[TranscriptionSegment], transcription_info: faster_whisper.transcribe.TranscriptionInfo | |
129 | + ) -> CreateTranscriptionResponseVerboseJson: | |
130 | + return cls( | |
131 | + language=transcription_info.language, | |
132 | + duration=transcription_info.duration, | |
133 | + text=segments_to_text(segments), | |
134 | + segments=segments, | |
135 | + words=TranscriptionWord.from_segments(segments) | |
136 | + if transcription_info.transcription_options.word_timestamps | |
137 | + else None, | |
138 | + ) | |
139 | + | |
140 | + @classmethod | |
141 | + def from_transcription(cls, transcription: Transcription) -> CreateTranscriptionResponseVerboseJson: | |
142 | + return cls( | |
143 | + language="english", # FIX: hardcoded | |
144 | + duration=transcription.duration, | |
145 | + text=transcription.text, | |
146 | + words=transcription.words, | |
147 | + segments=[], # FIX: hardcoded | |
148 | + ) | |
149 | + | |
150 | + | |
151 | +# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8730 | |
152 | +class ListModelsResponse(BaseModel): | |
153 | + data: list[Model] | |
154 | + object: Literal["list"] = "list" | |
155 | + | |
156 | + | |
157 | +# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L11146 | |
158 | +class Model(BaseModel): | |
159 | + id: str | |
160 | + """The model identifier, which can be referenced in the API endpoints.""" | |
161 | + created: int | |
162 | + """The Unix timestamp (in seconds) when the model was created.""" | |
163 | + object_: Literal["model"] = Field(serialization_alias="object") | |
164 | + """The object type, which is always "model".""" | |
165 | + owned_by: str | |
166 | + """The organization that owns the model.""" | |
167 | + language: list[str] = Field(default_factory=list) | |
168 | +    """List of ISO 639-3 language codes supported by the model. It's possible that the list will be empty. This field is not a part of the OpenAI API spec and is added for convenience.""" # noqa: E501 | |
169 | + | |
170 | + model_config = ConfigDict( | |
171 | + populate_by_name=True, | |
172 | + json_schema_extra={ | |
173 | + "examples": [ | |
174 | + { | |
175 | + "id": "Systran/faster-whisper-large-v3", | |
176 | + "created": 1700732060, | |
177 | + "object": "model", | |
178 | + "owned_by": "Systran", | |
179 | + }, | |
180 | + { | |
181 | + "id": "Systran/faster-distil-whisper-large-v3", | |
182 | + "created": 1711378296, | |
183 | + "object": "model", | |
184 | + "owned_by": "Systran", | |
185 | + }, | |
186 | + { | |
187 | + "id": "bofenghuang/whisper-large-v2-cv11-french-ct2", | |
188 | + "created": 1687968011, | |
189 | + "object": "model", | |
190 | + "owned_by": "bofenghuang", | |
191 | + }, | |
192 | + ] | |
193 | + }, | |
194 | + ) | |
195 | + | |
196 | + | |
197 | +# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10909 | |
198 | +TimestampGranularities = list[Literal["segment", "word"]] | |
199 | + | |
200 | + | |
201 | +DEFAULT_TIMESTAMP_GRANULARITIES: TimestampGranularities = ["segment"] | |
202 | +TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [ | |
203 | + [], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities | |
204 | + ["segment"], | |
205 | + ["word"], | |
206 | + ["word", "segment"], | |
207 | + ["segment", "word"], # same as ["word", "segment"] but order is different | |
208 | +] |
--- src/faster_whisper_server/asr.py
+++ src/faster_whisper_server/asr.py
... | ... | @@ -1,11 +1,17 @@ |
1 |
+from __future__ import annotations |
|
2 |
+ |
|
1 | 3 |
import asyncio |
2 | 4 |
import logging |
3 | 5 |
import time |
6 |
+from typing import TYPE_CHECKING |
|
4 | 7 |
|
5 |
-from faster_whisper import transcribe |
|
8 |
+from faster_whisper_server.api_models import TranscriptionSegment, TranscriptionWord |
|
9 |
+from faster_whisper_server.text_utils import Transcription |
|
6 | 10 |
|
7 |
-from faster_whisper_server.audio import Audio |
|
8 |
-from faster_whisper_server.core import Segment, Transcription, Word |
|
11 |
+if TYPE_CHECKING: |
|
12 |
+ from faster_whisper import transcribe |
|
13 |
+ |
|
14 |
+ from faster_whisper_server.audio import Audio |
|
9 | 15 |
|
10 | 16 |
logger = logging.getLogger(__name__) |
11 | 17 |
|
... | ... | @@ -31,8 +37,8 @@ |
31 | 37 |
word_timestamps=True, |
32 | 38 |
**self.transcribe_opts, |
33 | 39 |
) |
34 |
- segments = Segment.from_faster_whisper_segments(segments) |
|
35 |
- words = Word.from_segments(segments) |
|
40 |
+ segments = TranscriptionSegment.from_faster_whisper_segments(segments) |
|
41 |
+ words = TranscriptionWord.from_segments(segments) |
|
36 | 42 |
for word in words: |
37 | 43 |
word.offset(audio.start) |
38 | 44 |
transcription = Transcription(words) |
--- src/faster_whisper_server/core.py
... | ... | @@ -1,299 +0,0 @@ |
1 | -from __future__ import annotations | |
2 | - | |
3 | -import re | |
4 | -from typing import TYPE_CHECKING | |
5 | - | |
6 | -from pydantic import BaseModel | |
7 | - | |
8 | -from faster_whisper_server.dependencies import get_config | |
9 | - | |
10 | -if TYPE_CHECKING: | |
11 | - from collections.abc import Iterable | |
12 | - | |
13 | - import faster_whisper.transcribe | |
14 | - | |
15 | - | |
16 | -class Word(BaseModel): | |
17 | - start: float | |
18 | - end: float | |
19 | - word: str | |
20 | - probability: float | |
21 | - | |
22 | - @classmethod | |
23 | - def from_segments(cls, segments: Iterable[Segment]) -> list[Word]: | |
24 | - words: list[Word] = [] | |
25 | - for segment in segments: | |
26 | - # NOTE: a temporary "fix" for https://github.com/fedirz/faster-whisper-server/issues/58. | |
27 | - # TODO: properly address the issue | |
28 | - assert ( | |
29 | - segment.words is not None | |
30 | - ), "Segment must have words. If you are using an API ensure `timestamp_granularities[]=word` is set" | |
31 | - words.extend(segment.words) | |
32 | - return words | |
33 | - | |
34 | - def offset(self, seconds: float) -> None: | |
35 | - self.start += seconds | |
36 | - self.end += seconds | |
37 | - | |
38 | - @classmethod | |
39 | - def common_prefix(cls, a: list[Word], b: list[Word]) -> list[Word]: | |
40 | - i = 0 | |
41 | - while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word): | |
42 | - i += 1 | |
43 | - return a[:i] | |
44 | - | |
45 | - | |
46 | -class Segment(BaseModel): | |
47 | - id: int | |
48 | - seek: int | |
49 | - start: float | |
50 | - end: float | |
51 | - text: str | |
52 | - tokens: list[int] | |
53 | - temperature: float | |
54 | - avg_logprob: float | |
55 | - compression_ratio: float | |
56 | - no_speech_prob: float | |
57 | - words: list[Word] | None | |
58 | - | |
59 | - @classmethod | |
60 | - def from_faster_whisper_segments(cls, segments: Iterable[faster_whisper.transcribe.Segment]) -> Iterable[Segment]: | |
61 | - for segment in segments: | |
62 | - yield cls( | |
63 | - id=segment.id, | |
64 | - seek=segment.seek, | |
65 | - start=segment.start, | |
66 | - end=segment.end, | |
67 | - text=segment.text, | |
68 | - tokens=segment.tokens, | |
69 | - temperature=segment.temperature, | |
70 | - avg_logprob=segment.avg_logprob, | |
71 | - compression_ratio=segment.compression_ratio, | |
72 | - no_speech_prob=segment.no_speech_prob, | |
73 | - words=[ | |
74 | - Word( | |
75 | - start=word.start, | |
76 | - end=word.end, | |
77 | - word=word.word, | |
78 | - probability=word.probability, | |
79 | - ) | |
80 | - for word in segment.words | |
81 | - ] | |
82 | - if segment.words is not None | |
83 | - else None, | |
84 | - ) | |
85 | - | |
86 | - | |
87 | -class Transcription: | |
88 | - def __init__(self, words: list[Word] = []) -> None: | |
89 | - self.words: list[Word] = [] | |
90 | - self.extend(words) | |
91 | - | |
92 | - @property | |
93 | - def text(self) -> str: | |
94 | - return " ".join(word.word for word in self.words).strip() | |
95 | - | |
96 | - @property | |
97 | - def start(self) -> float: | |
98 | - return self.words[0].start if len(self.words) > 0 else 0.0 | |
99 | - | |
100 | - @property | |
101 | - def end(self) -> float: | |
102 | - return self.words[-1].end if len(self.words) > 0 else 0.0 | |
103 | - | |
104 | - @property | |
105 | - def duration(self) -> float: | |
106 | - return self.end - self.start | |
107 | - | |
108 | - def after(self, seconds: float) -> Transcription: | |
109 | - return Transcription(words=[word for word in self.words if word.start > seconds]) | |
110 | - | |
111 | - def extend(self, words: list[Word]) -> None: | |
112 | - self._ensure_no_word_overlap(words) | |
113 | - self.words.extend(words) | |
114 | - | |
115 | - def _ensure_no_word_overlap(self, words: list[Word]) -> None: | |
116 | - config = get_config() # HACK | |
117 | - if len(self.words) > 0 and len(words) > 0: | |
118 | - if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end: | |
119 | - raise ValueError( | |
120 | - f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}" # noqa: E501 | |
121 | - ) | |
122 | - for i in range(1, len(words)): | |
123 | - if words[i].start + config.word_timestamp_error_margin <= words[i - 1].end: | |
124 | - raise ValueError(f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}") | |
125 | - | |
126 | - | |
127 | -def is_eos(text: str) -> bool: | |
128 | - if text.endswith("..."): | |
129 | - return False | |
130 | - return any(text.endswith(punctuation_symbol) for punctuation_symbol in ".?!") | |
131 | - | |
132 | - | |
133 | -def test_is_eos() -> None: | |
134 | - assert not is_eos("Hello") | |
135 | - assert not is_eos("Hello...") | |
136 | - assert is_eos("Hello.") | |
137 | - assert is_eos("Hello!") | |
138 | - assert is_eos("Hello?") | |
139 | - assert not is_eos("Hello. Yo") | |
140 | - assert not is_eos("Hello. Yo...") | |
141 | - assert is_eos("Hello. Yo.") | |
142 | - | |
143 | - | |
144 | -def to_full_sentences(words: list[Word]) -> list[list[Word]]: | |
145 | - sentences: list[list[Word]] = [[]] | |
146 | - for word in words: | |
147 | - sentences[-1].append(word) | |
148 | - if is_eos(word.word): | |
149 | - sentences.append([]) | |
150 | - if len(sentences[-1]) == 0 or not is_eos(sentences[-1][-1].word): | |
151 | - sentences.pop() | |
152 | - return sentences | |
153 | - | |
154 | - | |
155 | -def tests_to_full_sentences() -> None: | |
156 | - def word(text: str) -> Word: | |
157 | - return Word(word=text, start=0.0, end=0.0, probability=0.0) | |
158 | - | |
159 | - assert to_full_sentences([]) == [] | |
160 | - assert to_full_sentences([word(text="Hello")]) == [] | |
161 | - assert to_full_sentences([word(text="Hello..."), word(" world")]) == [] | |
162 | - assert to_full_sentences([word(text="Hello..."), word(" world.")]) == [[word("Hello..."), word(" world.")]] | |
163 | - assert to_full_sentences([word(text="Hello..."), word(" world."), word(" How")]) == [ | |
164 | - [word("Hello..."), word(" world.")], | |
165 | - ] | |
166 | - | |
167 | - | |
168 | -def word_to_text(words: list[Word]) -> str: | |
169 | - return "".join(word.word for word in words) | |
170 | - | |
171 | - | |
172 | -def words_to_text_w_ts(words: list[Word]) -> str: | |
173 | - return "".join(f"{word.word}({word.start:.2f}-{word.end:.2f})" for word in words) | |
174 | - | |
175 | - | |
176 | -def segments_to_text(segments: Iterable[Segment]) -> str: | |
177 | - return "".join(segment.text for segment in segments).strip() | |
178 | - | |
179 | - | |
180 | -def srt_format_timestamp(ts: float) -> str: | |
181 | - hours = ts // 3600 | |
182 | - minutes = (ts % 3600) // 60 | |
183 | - seconds = ts % 60 | |
184 | - milliseconds = (ts * 1000) % 1000 | |
185 | - return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}" | |
186 | - | |
187 | - | |
188 | -def test_srt_format_timestamp() -> None: | |
189 | - assert srt_format_timestamp(0.0) == "00:00:00,000" | |
190 | - assert srt_format_timestamp(1.0) == "00:00:01,000" | |
191 | - assert srt_format_timestamp(1.234) == "00:00:01,234" | |
192 | - assert srt_format_timestamp(60.0) == "00:01:00,000" | |
193 | - assert srt_format_timestamp(61.0) == "00:01:01,000" | |
194 | - assert srt_format_timestamp(61.234) == "00:01:01,234" | |
195 | - assert srt_format_timestamp(3600.0) == "01:00:00,000" | |
196 | - assert srt_format_timestamp(3601.0) == "01:00:01,000" | |
197 | - assert srt_format_timestamp(3601.234) == "01:00:01,234" | |
198 | - assert srt_format_timestamp(23423.4234) == "06:30:23,423" | |
199 | - | |
200 | - | |
201 | -def vtt_format_timestamp(ts: float) -> str: | |
202 | - hours = ts // 3600 | |
203 | - minutes = (ts % 3600) // 60 | |
204 | - seconds = ts % 60 | |
205 | - milliseconds = (ts * 1000) % 1000 | |
206 | - return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}" | |
207 | - | |
208 | - | |
209 | -def test_vtt_format_timestamp() -> None: | |
210 | - assert vtt_format_timestamp(0.0) == "00:00:00.000" | |
211 | - assert vtt_format_timestamp(1.0) == "00:00:01.000" | |
212 | - assert vtt_format_timestamp(1.234) == "00:00:01.234" | |
213 | - assert vtt_format_timestamp(60.0) == "00:01:00.000" | |
214 | - assert vtt_format_timestamp(61.0) == "00:01:01.000" | |
215 | - assert vtt_format_timestamp(61.234) == "00:01:01.234" | |
216 | - assert vtt_format_timestamp(3600.0) == "01:00:00.000" | |
217 | - assert vtt_format_timestamp(3601.0) == "01:00:01.000" | |
218 | - assert vtt_format_timestamp(3601.234) == "01:00:01.234" | |
219 | - assert vtt_format_timestamp(23423.4234) == "06:30:23.423" | |
220 | - | |
221 | - | |
222 | -def segments_to_vtt(segment: Segment, i: int) -> str: | |
223 | - start = segment.start if i > 0 else 0.0 | |
224 | - result = f"{vtt_format_timestamp(start)} --> {vtt_format_timestamp(segment.end)}\n{segment.text}\n\n" | |
225 | - | |
226 | - if i == 0: | |
227 | - return f"WEBVTT\n\n{result}" | |
228 | - else: | |
229 | - return result | |
230 | - | |
231 | - | |
232 | -def segments_to_srt(segment: Segment, i: int) -> str: | |
233 | - return f"{i + 1}\n{srt_format_timestamp(segment.start)} --> {srt_format_timestamp(segment.end)}\n{segment.text}\n\n" | |
234 | - | |
235 | - | |
236 | -def canonicalize_word(text: str) -> str: | |
237 | - text = text.lower() | |
238 | - # Remove non-alphabetic characters using regular expression | |
239 | - text = re.sub(r"[^a-z]", "", text) | |
240 | - return text.lower().strip().strip(".,?!") | |
241 | - | |
242 | - | |
243 | -def test_canonicalize_word() -> None: | |
244 | - assert canonicalize_word("ABC") == "abc" | |
245 | - assert canonicalize_word("...ABC?") == "abc" | |
246 | - assert canonicalize_word("... AbC ...") == "abc" | |
247 | - | |
248 | - | |
249 | -def common_prefix(a: list[Word], b: list[Word]) -> list[Word]: | |
250 | - i = 0 | |
251 | - while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word): | |
252 | - i += 1 | |
253 | - return a[:i] | |
254 | - | |
255 | - | |
256 | -def test_common_prefix() -> None: | |
257 | - def word(text: str) -> Word: | |
258 | - return Word(word=text, start=0.0, end=0.0, probability=0.0) | |
259 | - | |
260 | - a = [word("a"), word("b"), word("c")] | |
261 | - b = [word("a"), word("b"), word("c")] | |
262 | - assert common_prefix(a, b) == [word("a"), word("b"), word("c")] | |
263 | - | |
264 | - a = [word("a"), word("b"), word("c")] | |
265 | - b = [word("a"), word("b"), word("d")] | |
266 | - assert common_prefix(a, b) == [word("a"), word("b")] | |
267 | - | |
268 | - a = [word("a"), word("b"), word("c")] | |
269 | - b = [word("a")] | |
270 | - assert common_prefix(a, b) == [word("a")] | |
271 | - | |
272 | - a = [word("a")] | |
273 | - b = [word("a"), word("b"), word("c")] | |
274 | - assert common_prefix(a, b) == [word("a")] | |
275 | - | |
276 | - a = [word("a")] | |
277 | - b = [] | |
278 | - assert common_prefix(a, b) == [] | |
279 | - | |
280 | - a = [] | |
281 | - b = [word("a")] | |
282 | - assert common_prefix(a, b) == [] | |
283 | - | |
284 | - a = [word("a"), word("b"), word("c")] | |
285 | - b = [word("b"), word("c")] | |
286 | - assert common_prefix(a, b) == [] | |
287 | - | |
288 | - | |
289 | -def test_common_prefix_and_canonicalization() -> None: | |
290 | - def word(text: str) -> Word: | |
291 | - return Word(word=text, start=0.0, end=0.0, probability=0.0) | |
292 | - | |
293 | - a = [word("A...")] | |
294 | - b = [word("a?"), word("b"), word("c")] | |
295 | - assert common_prefix(a, b) == [word("A...")] | |
296 | - | |
297 | - a = [word("A..."), word("B?"), word("C,")] | |
298 | - b = [word("a??"), word(" b"), word(" ,c")] | |
299 | - assert common_prefix(a, b) == [word("A..."), word("B?"), word("C,")] |
--- src/faster_whisper_server/routers/list_models.py
+++ src/faster_whisper_server/routers/list_models.py
... | ... | @@ -9,9 +9,9 @@ |
9 | 9 |
) |
10 | 10 |
import huggingface_hub |
11 | 11 |
|
12 |
-from faster_whisper_server.server_models import ( |
|
13 |
- ModelListResponse, |
|
14 |
- ModelObject, |
|
12 |
+from faster_whisper_server.api_models import ( |
|
13 |
+ ListModelsResponse, |
|
14 |
+ Model, |
|
15 | 15 |
) |
16 | 16 |
|
17 | 17 |
if TYPE_CHECKING: |
... | ... | @@ -21,11 +21,11 @@ |
21 | 21 |
|
22 | 22 |
|
23 | 23 |
@router.get("/v1/models") |
24 |
-def get_models() -> ModelListResponse: |
|
24 |
+def get_models() -> ListModelsResponse: |
|
25 | 25 |
models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True) |
26 | 26 |
models = list(models) |
27 | 27 |
models.sort(key=lambda model: model.downloads, reverse=True) # type: ignore # noqa: PGH003 |
28 |
- transformed_models: list[ModelObject] = [] |
|
28 |
+ transformed_models: list[Model] = [] |
|
29 | 29 |
for model in models: |
30 | 30 |
assert model.created_at is not None |
31 | 31 |
assert model.card_data is not None |
... | ... | @@ -36,7 +36,7 @@ |
36 | 36 |
language = [model.card_data.language] |
37 | 37 |
else: |
38 | 38 |
language = model.card_data.language |
39 |
- transformed_model = ModelObject( |
|
39 |
+ transformed_model = Model( |
|
40 | 40 |
id=model.id, |
41 | 41 |
created=int(model.created_at.timestamp()), |
42 | 42 |
object_="model", |
... | ... | @@ -44,14 +44,14 @@ |
44 | 44 |
language=language, |
45 | 45 |
) |
46 | 46 |
transformed_models.append(transformed_model) |
47 |
- return ModelListResponse(data=transformed_models) |
|
47 |
+ return ListModelsResponse(data=transformed_models) |
|
48 | 48 |
|
49 | 49 |
|
50 | 50 |
@router.get("/v1/models/{model_name:path}") |
51 | 51 |
# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537 |
52 | 52 |
def get_model( |
53 | 53 |
model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")], |
54 |
-) -> ModelObject: |
|
54 |
+) -> Model: |
|
55 | 55 |
models = huggingface_hub.list_models( |
56 | 56 |
model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True |
57 | 57 |
) |
... | ... | @@ -78,7 +78,7 @@ |
78 | 78 |
language = [exact_match.card_data.language] |
79 | 79 |
else: |
80 | 80 |
language = exact_match.card_data.language |
81 |
- return ModelObject( |
|
81 |
+ return Model( |
|
82 | 82 |
id=exact_match.id, |
83 | 83 |
created=int(exact_match.created_at.timestamp()), |
84 | 84 |
object_="model", |
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
... | ... | @@ -20,6 +20,14 @@ |
20 | 20 |
from faster_whisper.vad import VadOptions, get_speech_timestamps |
21 | 21 |
from pydantic import AfterValidator |
22 | 22 |
|
23 |
+from faster_whisper_server.api_models import ( |
|
24 |
+ DEFAULT_TIMESTAMP_GRANULARITIES, |
|
25 |
+ TIMESTAMP_GRANULARITIES_COMBINATIONS, |
|
26 |
+ CreateTranscriptionResponseJson, |
|
27 |
+ CreateTranscriptionResponseVerboseJson, |
|
28 |
+ TimestampGranularities, |
|
29 |
+ TranscriptionSegment, |
|
30 |
+) |
|
23 | 31 |
from faster_whisper_server.asr import FasterWhisperASR |
24 | 32 |
from faster_whisper_server.audio import AudioStream, audio_samples_from_file |
25 | 33 |
from faster_whisper_server.config import ( |
... | ... | @@ -28,15 +36,8 @@ |
28 | 36 |
ResponseFormat, |
29 | 37 |
Task, |
30 | 38 |
) |
31 |
-from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt |
|
32 | 39 |
from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config |
33 |
-from faster_whisper_server.server_models import ( |
|
34 |
- DEFAULT_TIMESTAMP_GRANULARITIES, |
|
35 |
- TIMESTAMP_GRANULARITIES_COMBINATIONS, |
|
36 |
- TimestampGranularities, |
|
37 |
- TranscriptionJsonResponse, |
|
38 |
- TranscriptionVerboseJsonResponse, |
|
39 |
-) |
|
40 |
+from faster_whisper_server.text_utils import segments_to_srt, segments_to_text, segments_to_vtt |
|
40 | 41 |
from faster_whisper_server.transcriber import audio_transcriber |
41 | 42 |
|
42 | 43 |
if TYPE_CHECKING: |
... | ... | @@ -51,7 +52,7 @@ |
51 | 52 |
|
52 | 53 |
|
53 | 54 |
def segments_to_response( |
54 |
- segments: Iterable[Segment], |
|
55 |
+ segments: Iterable[TranscriptionSegment], |
|
55 | 56 |
transcription_info: TranscriptionInfo, |
56 | 57 |
response_format: ResponseFormat, |
57 | 58 |
) -> Response: |
... | ... | @@ -60,12 +61,12 @@ |
60 | 61 |
return Response(segments_to_text(segments), media_type="text/plain") |
61 | 62 |
elif response_format == ResponseFormat.JSON: |
62 | 63 |
return Response( |
63 |
- TranscriptionJsonResponse.from_segments(segments).model_dump_json(), |
|
64 |
+ CreateTranscriptionResponseJson.from_segments(segments).model_dump_json(), |
|
64 | 65 |
media_type="application/json", |
65 | 66 |
) |
66 | 67 |
elif response_format == ResponseFormat.VERBOSE_JSON: |
67 | 68 |
return Response( |
68 |
- TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(), |
|
69 |
+ CreateTranscriptionResponseVerboseJson.from_segments(segments, transcription_info).model_dump_json(), |
|
69 | 70 |
media_type="application/json", |
70 | 71 |
) |
71 | 72 |
elif response_format == ResponseFormat.VTT: |
... | ... | @@ -83,7 +84,7 @@ |
83 | 84 |
|
84 | 85 |
|
85 | 86 |
def segments_to_streaming_response( |
86 |
- segments: Iterable[Segment], |
|
87 |
+ segments: Iterable[TranscriptionSegment], |
|
87 | 88 |
transcription_info: TranscriptionInfo, |
88 | 89 |
response_format: ResponseFormat, |
89 | 90 |
) -> StreamingResponse: |
... | ... | @@ -92,9 +93,11 @@ |
92 | 93 |
if response_format == ResponseFormat.TEXT: |
93 | 94 |
data = segment.text |
94 | 95 |
elif response_format == ResponseFormat.JSON: |
95 |
- data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json() |
|
96 |
+ data = CreateTranscriptionResponseJson.from_segments([segment]).model_dump_json() |
|
96 | 97 |
elif response_format == ResponseFormat.VERBOSE_JSON: |
97 |
- data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json() |
|
98 |
+ data = CreateTranscriptionResponseVerboseJson.from_segment( |
|
99 |
+ segment, transcription_info |
|
100 |
+ ).model_dump_json() |
|
98 | 101 |
elif response_format == ResponseFormat.VTT: |
99 | 102 |
data = segments_to_vtt(segment, i) |
100 | 103 |
elif response_format == ResponseFormat.SRT: |
... | ... | @@ -121,7 +124,7 @@ |
121 | 124 |
|
122 | 125 |
@router.post( |
123 | 126 |
"/v1/audio/translations", |
124 |
- response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse, |
|
127 |
+ response_model=str | CreateTranscriptionResponseJson | CreateTranscriptionResponseVerboseJson, |
|
125 | 128 |
) |
126 | 129 |
def translate_file( |
127 | 130 |
config: ConfigDependency, |
... | ... | @@ -145,7 +148,7 @@ |
145 | 148 |
temperature=temperature, |
146 | 149 |
vad_filter=True, |
147 | 150 |
) |
148 |
- segments = Segment.from_faster_whisper_segments(segments) |
|
151 |
+ segments = TranscriptionSegment.from_faster_whisper_segments(segments) |
|
149 | 152 |
|
150 | 153 |
if stream: |
151 | 154 |
return segments_to_streaming_response(segments, transcription_info, response_format) |
... | ... | @@ -169,7 +172,7 @@ |
169 | 172 |
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915 |
170 | 173 |
@router.post( |
171 | 174 |
"/v1/audio/transcriptions", |
172 |
- response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse, |
|
175 |
+ response_model=str | CreateTranscriptionResponseJson | CreateTranscriptionResponseVerboseJson, |
|
173 | 176 |
) |
174 | 177 |
def transcribe_file( |
175 | 178 |
config: ConfigDependency, |
... | ... | @@ -211,7 +214,7 @@ |
211 | 214 |
vad_filter=True, |
212 | 215 |
hotwords=hotwords, |
213 | 216 |
) |
214 |
- segments = Segment.from_faster_whisper_segments(segments) |
|
217 |
+ segments = TranscriptionSegment.from_faster_whisper_segments(segments) |
|
215 | 218 |
|
216 | 219 |
if stream: |
217 | 220 |
return segments_to_streaming_response(segments, transcription_info, response_format) |
... | ... | @@ -286,9 +289,11 @@ |
286 | 289 |
if response_format == ResponseFormat.TEXT: |
287 | 290 |
await ws.send_text(transcription.text) |
288 | 291 |
elif response_format == ResponseFormat.JSON: |
289 |
- await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump()) |
|
292 |
+ await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump()) |
|
290 | 293 |
elif response_format == ResponseFormat.VERBOSE_JSON: |
291 |
- await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump()) |
|
294 |
+ await ws.send_json( |
|
295 |
+ CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump() |
|
296 |
+ ) |
|
292 | 297 |
|
293 | 298 |
if ws.client_state != WebSocketState.DISCONNECTED: |
294 | 299 |
logger.info("Closing the connection.") |
--- src/faster_whisper_server/server_models.py
... | ... | @@ -1,122 +0,0 @@ |
1 | -from __future__ import annotations | |
2 | - | |
3 | -from typing import TYPE_CHECKING, Literal | |
4 | - | |
5 | -from pydantic import BaseModel, ConfigDict, Field | |
6 | - | |
7 | -from faster_whisper_server.core import Segment, Transcription, Word, segments_to_text | |
8 | - | |
9 | -if TYPE_CHECKING: | |
10 | - from faster_whisper.transcribe import TranscriptionInfo | |
11 | - | |
12 | - | |
13 | -# https://platform.openai.com/docs/api-reference/audio/json-object | |
14 | -class TranscriptionJsonResponse(BaseModel): | |
15 | - text: str | |
16 | - | |
17 | - @classmethod | |
18 | - def from_segments(cls, segments: list[Segment]) -> TranscriptionJsonResponse: | |
19 | - return cls(text=segments_to_text(segments)) | |
20 | - | |
21 | - @classmethod | |
22 | - def from_transcription(cls, transcription: Transcription) -> TranscriptionJsonResponse: | |
23 | - return cls(text=transcription.text) | |
24 | - | |
25 | - | |
26 | -# https://platform.openai.com/docs/api-reference/audio/verbose-json-object | |
27 | -class TranscriptionVerboseJsonResponse(BaseModel): | |
28 | - task: str = "transcribe" | |
29 | - language: str | |
30 | - duration: float | |
31 | - text: str | |
32 | - words: list[Word] | None | |
33 | - segments: list[Segment] | |
34 | - | |
35 | - @classmethod | |
36 | - def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -> TranscriptionVerboseJsonResponse: | |
37 | - return cls( | |
38 | - language=transcription_info.language, | |
39 | - duration=segment.end - segment.start, | |
40 | - text=segment.text, | |
41 | - words=segment.words if transcription_info.transcription_options.word_timestamps else None, | |
42 | - segments=[segment], | |
43 | - ) | |
44 | - | |
45 | - @classmethod | |
46 | - def from_segments( | |
47 | - cls, segments: list[Segment], transcription_info: TranscriptionInfo | |
48 | - ) -> TranscriptionVerboseJsonResponse: | |
49 | - return cls( | |
50 | - language=transcription_info.language, | |
51 | - duration=transcription_info.duration, | |
52 | - text=segments_to_text(segments), | |
53 | - segments=segments, | |
54 | - words=Word.from_segments(segments) if transcription_info.transcription_options.word_timestamps else None, | |
55 | - ) | |
56 | - | |
57 | - @classmethod | |
58 | - def from_transcription(cls, transcription: Transcription) -> TranscriptionVerboseJsonResponse: | |
59 | - return cls( | |
60 | - language="english", # FIX: hardcoded | |
61 | - duration=transcription.duration, | |
62 | - text=transcription.text, | |
63 | - words=transcription.words, | |
64 | - segments=[], # FIX: hardcoded | |
65 | - ) | |
66 | - | |
67 | - | |
68 | -class ModelListResponse(BaseModel): | |
69 | - data: list[ModelObject] | |
70 | - object: Literal["list"] = "list" | |
71 | - | |
72 | - | |
73 | -class ModelObject(BaseModel): | |
74 | - id: str | |
75 | - """The model identifier, which can be referenced in the API endpoints.""" | |
76 | - created: int | |
77 | - """The Unix timestamp (in seconds) when the model was created.""" | |
78 | - object_: Literal["model"] = Field(serialization_alias="object") | |
79 | - """The object type, which is always "model".""" | |
80 | - owned_by: str | |
81 | - """The organization that owns the model.""" | |
82 | - language: list[str] = Field(default_factory=list) | |
83 | - """List of ISO 639-3 supported by the model. It's possible that the list will be empty. This field is not a part of the OpenAI API spec and is added for convenience.""" # noqa: E501 | |
84 | - | |
85 | - model_config = ConfigDict( | |
86 | - populate_by_name=True, | |
87 | - json_schema_extra={ | |
88 | - "examples": [ | |
89 | - { | |
90 | - "id": "Systran/faster-whisper-large-v3", | |
91 | - "created": 1700732060, | |
92 | - "object": "model", | |
93 | - "owned_by": "Systran", | |
94 | - }, | |
95 | - { | |
96 | - "id": "Systran/faster-distil-whisper-large-v3", | |
97 | - "created": 1711378296, | |
98 | - "object": "model", | |
99 | - "owned_by": "Systran", | |
100 | - }, | |
101 | - { | |
102 | - "id": "bofenghuang/whisper-large-v2-cv11-french-ct2", | |
103 | - "created": 1687968011, | |
104 | - "object": "model", | |
105 | - "owned_by": "bofenghuang", | |
106 | - }, | |
107 | - ] | |
108 | - }, | |
109 | - ) | |
110 | - | |
111 | - | |
112 | -TimestampGranularities = list[Literal["segment", "word"]] | |
113 | - | |
114 | - | |
115 | -DEFAULT_TIMESTAMP_GRANULARITIES: TimestampGranularities = ["segment"] | |
116 | -TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [ | |
117 | - [], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities | |
118 | - ["segment"], | |
119 | - ["word"], | |
120 | - ["word", "segment"], | |
121 | - ["segment", "word"], # same as ["word", "segment"] but order is different | |
122 | -] |
+++ src/faster_whisper_server/text_utils.py
... | ... | @@ -0,0 +1,124 @@ |
1 | +from __future__ import annotations | |
2 | + | |
3 | +import re | |
4 | +from typing import TYPE_CHECKING | |
5 | + | |
6 | +from faster_whisper_server.dependencies import get_config | |
7 | + | |
8 | +if TYPE_CHECKING: | |
9 | + from collections.abc import Iterable | |
10 | + | |
11 | + from faster_whisper_server.api_models import TranscriptionSegment, TranscriptionWord | |
12 | + | |
13 | + | |
class Transcription:
    """An append-only, time-ordered sequence of transcribed words.

    Words are validated on insertion: consecutive word timestamps must not
    overlap (within the configured `word_timestamp_error_margin`).
    """

    def __init__(self, words: list[TranscriptionWord] | None = None) -> None:
        # NOTE: `None` sentinel instead of the previous mutable `[]` default
        # (ruff B006). Behavior is unchanged: words are always copied into a
        # fresh per-instance list via `extend`.
        self.words: list[TranscriptionWord] = []
        self.extend(words if words is not None else [])

    @property
    def text(self) -> str:
        """Full transcription text: word strings joined by single spaces."""
        return " ".join(word.word for word in self.words).strip()

    @property
    def start(self) -> float:
        """Start time of the first word, or 0.0 when empty."""
        return self.words[0].start if len(self.words) > 0 else 0.0

    @property
    def end(self) -> float:
        """End time of the last word, or 0.0 when empty."""
        return self.words[-1].end if len(self.words) > 0 else 0.0

    @property
    def duration(self) -> float:
        """Elapsed time between the first word's start and the last word's end."""
        return self.end - self.start

    def after(self, seconds: float) -> Transcription:
        """Return a new `Transcription` with only the words starting strictly after `seconds`."""
        return Transcription(words=[word for word in self.words if word.start > seconds])

    def extend(self, words: list[TranscriptionWord]) -> None:
        """Append `words`, raising `ValueError` if they overlap existing or each other's timestamps."""
        self._ensure_no_word_overlap(words)
        self.words.extend(words)

    def _ensure_no_word_overlap(self, words: list[TranscriptionWord]) -> None:
        """Validate that `words` do not overlap the stored words or one another.

        Raises:
            ValueError: if a word starts (within the error margin) before the previous word ends.
        """
        config = get_config()  # HACK
        if len(self.words) > 0 and len(words) > 0:
            # First incoming word must not start before the last stored word ends.
            if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
                raise ValueError(
                    f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}"  # noqa: E501
                )
        # Incoming words must also be mutually non-overlapping.
        for i in range(1, len(words)):
            if words[i].start + config.word_timestamp_error_margin <= words[i - 1].end:
                raise ValueError(f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}")
52 | + | |
53 | + | |
def is_eos(text: str) -> bool:
    """Return True if `text` ends a sentence.

    A sentence ends with '.', '?' or '!', except that a trailing ellipsis
    ("...") does not count as an end of sentence.
    """
    return not text.endswith("...") and text.endswith((".", "?", "!"))
58 | + | |
59 | + | |
def to_full_sentences(words: list[TranscriptionWord]) -> list[list[TranscriptionWord]]:
    """Group `words` into completed sentences.

    A sentence is complete when its last word satisfies `is_eos`. Any trailing
    words that do not form a complete sentence are dropped.
    """
    complete: list[list[TranscriptionWord]] = []
    current: list[TranscriptionWord] = []
    for word in words:
        current.append(word)
        if is_eos(word.word):
            complete.append(current)
            current = []
    # Whatever remains in `current` lacks closing punctuation and is discarded.
    return complete
69 | + | |
70 | + | |
def word_to_text(words: list[TranscriptionWord]) -> str:
    """Concatenate the raw text of `words` with no separator (words carry their own leading whitespace)."""
    parts = [word.word for word in words]
    return "".join(parts)
73 | + | |
74 | + | |
def words_to_text_w_ts(words: list[TranscriptionWord]) -> str:
    """Concatenate `words`, annotating each with its "(start-end)" timestamps (two decimals); debug helper."""
    pieces: list[str] = []
    for word in words:
        pieces.append(f"{word.word}({word.start:.2f}-{word.end:.2f})")
    return "".join(pieces)
77 | + | |
78 | + | |
def segments_to_text(segments: Iterable[TranscriptionSegment]) -> str:
    """Join segment texts into one string and strip surrounding whitespace."""
    parts = [segment.text for segment in segments]
    return "".join(parts).strip()
81 | + | |
82 | + | |
def srt_format_timestamp(ts: float) -> str:
    """Format `ts` seconds as an SRT timestamp: HH:MM:SS,mmm (comma before milliseconds)."""
    hours, remainder = divmod(ts, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = (ts * 1000) % 1000
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
89 | + | |
90 | + | |
def vtt_format_timestamp(ts: float) -> str:
    """Format `ts` seconds as a WebVTT timestamp: HH:MM:SS.mmm (period before milliseconds)."""
    hours, remainder = divmod(ts, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = (ts * 1000) % 1000
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
97 | + | |
98 | + | |
def segments_to_vtt(segment: TranscriptionSegment, i: int) -> str:
    """Render `segment` as a WebVTT cue; the first cue (i == 0) is prefixed with the file header.

    NOTE(review): the first cue's start is forced to 0.0 regardless of
    `segment.start` — presumably so the track covers leading audio; confirm intent.
    """
    cue_start = segment.start if i > 0 else 0.0
    cue = f"{vtt_format_timestamp(cue_start)} --> {vtt_format_timestamp(segment.end)}\n{segment.text}\n\n"
    header = "WEBVTT\n\n" if i == 0 else ""
    return header + cue
107 | + | |
108 | + | |
def segments_to_srt(segment: TranscriptionSegment, i: int) -> str:
    """Render `segment` as a numbered SRT cue (cue numbers are 1-based)."""
    index = i + 1
    timing = f"{srt_format_timestamp(segment.start)} --> {srt_format_timestamp(segment.end)}"
    return f"{index}\n{timing}\n{segment.text}\n\n"
111 | + | |
112 | + | |
def canonicalize_word(text: str) -> str:
    """Normalize a word for fuzzy comparison: lowercase it and drop every non-a-z character.

    After the substitution only characters a-z can remain, so the previous
    trailing `.lower().strip().strip(".,?!")` calls were no-ops and have been
    removed.
    """
    return re.sub(r"[^a-z]", "", text.lower())
118 | + | |
119 | + | |
def common_prefix(a: list[TranscriptionWord], b: list[TranscriptionWord]) -> list[TranscriptionWord]:
    """Return the longest leading run of `a` whose words match `b` under `canonicalize_word`."""
    matched = 0
    for word_a, word_b in zip(a, b):
        if canonicalize_word(word_a.word) != canonicalize_word(word_b.word):
            break
        matched += 1
    return a[:matched]
+++ src/faster_whisper_server/text_utils_test.py
... | ... | @@ -0,0 +1,111 @@ |
1 | +from faster_whisper_server.api_models import TranscriptionWord | |
2 | +from faster_whisper_server.text_utils import ( | |
3 | + canonicalize_word, | |
4 | + common_prefix, | |
5 | + is_eos, | |
6 | + srt_format_timestamp, | |
7 | + to_full_sentences, | |
8 | + vtt_format_timestamp, | |
9 | +) | |
10 | + | |
11 | + | |
def test_is_eos() -> None:
    """`is_eos` is True only for '.', '!' or '?' endings that are not an ellipsis."""
    for not_eos in ("Hello", "Hello...", "Hello. Yo", "Hello. Yo..."):
        assert not is_eos(not_eos)
    for eos in ("Hello.", "Hello!", "Hello?", "Hello. Yo."):
        assert is_eos(eos)
21 | + | |
22 | + | |
def test_to_full_sentences() -> None:
    """Only completed sentences are returned; a trailing unfinished sentence is dropped.

    NOTE: renamed from `tests_to_full_sentences` so the name follows the
    `test_*` convention used by every other test in this module and matched by
    pytest's discovery patterns.
    """

    def word(text: str) -> TranscriptionWord:
        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)

    assert to_full_sentences([]) == []
    assert to_full_sentences([word("Hello")]) == []
    assert to_full_sentences([word("Hello..."), word(" world")]) == []
    assert to_full_sentences([word("Hello..."), word(" world.")]) == [[word("Hello..."), word(" world.")]]
    assert to_full_sentences([word("Hello..."), word(" world."), word(" How")]) == [
        [word("Hello..."), word(" world.")],
    ]
34 | + | |
35 | + | |
def test_srt_format_timestamp() -> None:
    """SRT timestamps are HH:MM:SS,mmm with zero-padded fields and a comma separator."""
    cases = {
        0.0: "00:00:00,000",
        1.0: "00:00:01,000",
        1.234: "00:00:01,234",
        60.0: "00:01:00,000",
        61.0: "00:01:01,000",
        61.234: "00:01:01,234",
        3600.0: "01:00:00,000",
        3601.0: "01:00:01,000",
        3601.234: "01:00:01,234",
        23423.4234: "06:30:23,423",
    }
    for ts, expected in cases.items():
        assert srt_format_timestamp(ts) == expected
47 | + | |
48 | + | |
def test_vtt_format_timestamp() -> None:
    """WebVTT timestamps are HH:MM:SS.mmm with zero-padded fields and a period separator."""
    cases = {
        0.0: "00:00:00.000",
        1.0: "00:00:01.000",
        1.234: "00:00:01.234",
        60.0: "00:01:00.000",
        61.0: "00:01:01.000",
        61.234: "00:01:01.234",
        3600.0: "01:00:00.000",
        3601.0: "01:00:01.000",
        3601.234: "01:00:01.234",
        23423.4234: "06:30:23.423",
    }
    for ts, expected in cases.items():
        assert vtt_format_timestamp(ts) == expected
60 | + | |
61 | + | |
def test_canonicalize_word() -> None:
    """Canonicalization lowercases and strips punctuation/whitespace down to bare letters."""
    for raw in ("ABC", "...ABC?", "... AbC ..."):
        assert canonicalize_word(raw) == "abc"
66 | + | |
67 | + | |
def test_common_prefix() -> None:
    """The common prefix is the longest matching leading run, taken from the first argument."""

    def word(text: str) -> TranscriptionWord:
        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)

    # (left, right, expected-prefix) triples.
    cases = [
        (["a", "b", "c"], ["a", "b", "c"], ["a", "b", "c"]),
        (["a", "b", "c"], ["a", "b", "d"], ["a", "b"]),
        (["a", "b", "c"], ["a"], ["a"]),
        (["a"], ["a", "b", "c"], ["a"]),
        (["a"], [], []),
        ([], ["a"], []),
        (["a", "b", "c"], ["b", "c"], []),
    ]
    for left, right, expected in cases:
        a = [word(t) for t in left]
        b = [word(t) for t in right]
        assert common_prefix(a, b) == [word(t) for t in expected]
99 | + | |
100 | + | |
def test_common_prefix_and_canonicalization() -> None:
    """Prefix matching ignores case, punctuation and surrounding whitespace; originals are returned."""

    def word(text: str) -> TranscriptionWord:
        return TranscriptionWord(word=text, start=0.0, end=0.0, probability=0.0)

    cases = [
        (["A..."], ["a?", "b", "c"], ["A..."]),
        (["A...", "B?", "C,"], ["a??", " b", " ,c"], ["A...", "B?", "C,"]),
    ]
    for left, right, expected in cases:
        a = [word(t) for t in left]
        b = [word(t) for t in right]
        assert common_prefix(a, b) == [word(t) for t in expected]
--- src/faster_whisper_server/transcriber.py
+++ src/faster_whisper_server/transcriber.py
... | ... | @@ -4,11 +4,12 @@ |
4 | 4 |
from typing import TYPE_CHECKING |
5 | 5 |
|
6 | 6 |
from faster_whisper_server.audio import Audio, AudioStream |
7 |
-from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text |
|
7 |
+from faster_whisper_server.text_utils import Transcription, common_prefix, to_full_sentences, word_to_text |
|
8 | 8 |
|
9 | 9 |
if TYPE_CHECKING: |
10 | 10 |
from collections.abc import AsyncGenerator |
11 | 11 |
|
12 |
+ from faster_whisper_server.api_models import TranscriptionWord |
|
12 | 13 |
from faster_whisper_server.asr import FasterWhisperASR |
13 | 14 |
|
14 | 15 |
logger = logging.getLogger(__name__) |
... | ... | @@ -18,7 +19,7 @@ |
18 | 19 |
def __init__(self) -> None: |
19 | 20 |
self.unconfirmed = Transcription() |
20 | 21 |
|
21 |
- def merge(self, confirmed: Transcription, incoming: Transcription) -> list[Word]: |
|
22 |
+ def merge(self, confirmed: Transcription, incoming: Transcription) -> list[TranscriptionWord]: |
|
22 | 23 |
# https://github.com/ufal/whisper_streaming/blob/main/whisper_online.py#L264 |
23 | 24 |
incoming = incoming.after(confirmed.end - 0.1) |
24 | 25 |
prefix = common_prefix(incoming.words, self.unconfirmed.words) |
--- tests/api_timestamp_granularities_test.py
+++ tests/api_timestamp_granularities_test.py
... | ... | @@ -1,6 +1,6 @@ |
1 | 1 |
"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`.""" # noqa: E501 |
2 | 2 |
|
3 |
-from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
|
3 |
+from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
|
4 | 4 |
from openai import AsyncOpenAI |
5 | 5 |
import pytest |
6 | 6 |
|
--- tests/openai_timestamp_granularities_test.py
+++ tests/openai_timestamp_granularities_test.py
... | ... | @@ -1,6 +1,6 @@ |
1 | 1 |
"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501 |
2 | 2 |
|
3 |
-from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
|
3 |
+from faster_whisper_server.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
|
4 | 4 |
from openai import AsyncOpenAI, BadRequestError |
5 | 5 |
import pytest |
6 | 6 |
|
--- tests/sse_test.py
+++ tests/sse_test.py
... | ... | @@ -2,9 +2,9 @@ |
2 | 2 |
import os |
3 | 3 |
|
4 | 4 |
from fastapi.testclient import TestClient |
5 |
-from faster_whisper_server.server_models import ( |
|
6 |
- TranscriptionJsonResponse, |
|
7 |
- TranscriptionVerboseJsonResponse, |
|
5 |
+from faster_whisper_server.api_models import ( |
|
6 |
+ CreateTranscriptionResponseJson, |
|
7 |
+ CreateTranscriptionResponseVerboseJson, |
|
8 | 8 |
) |
9 | 9 |
from httpx_sse import connect_sse |
10 | 10 |
import pytest |
... | ... | @@ -48,7 +48,7 @@ |
48 | 48 |
} |
49 | 49 |
with connect_sse(client, "POST", endpoint, **kwargs) as event_source: |
50 | 50 |
for event in event_source.iter_sse(): |
51 |
- TranscriptionJsonResponse(**json.loads(event.data)) |
|
51 |
+ CreateTranscriptionResponseJson(**json.loads(event.data)) |
|
52 | 52 |
|
53 | 53 |
|
54 | 54 |
@pytest.mark.parametrize(("file_path", "endpoint"), parameters) |
... | ... | @@ -62,7 +62,7 @@ |
62 | 62 |
} |
63 | 63 |
with connect_sse(client, "POST", endpoint, **kwargs) as event_source: |
64 | 64 |
for event in event_source.iter_sse(): |
65 |
- TranscriptionVerboseJsonResponse(**json.loads(event.data)) |
|
65 |
+ CreateTranscriptionResponseVerboseJson(**json.loads(event.data)) |
|
66 | 66 |
|
67 | 67 |
|
68 | 68 |
def test_transcription_vtt(client: TestClient) -> None: |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?