

feat: further improve openai compatibility + refactor
@4bdd7f27bd66f87bdd342df2ba4ad8f3c7206ae3
--- speaches/asr.py
+++ speaches/asr.py
@@ -3,28 +3,20 @@
 from typing import Iterable
 
 from faster_whisper import transcribe
-from pydantic import BaseModel
 
 from speaches.audio import Audio
-from speaches.config import Language
 from speaches.core import Transcription, Word
 from speaches.logger import logger
-
-
-class TranscribeOpts(BaseModel):
-    language: Language | None
-    vad_filter: bool
-    condition_on_previous_text: bool
 
 
 class FasterWhisperASR:
     def __init__(
         self,
         whisper: transcribe.WhisperModel,
-        transcribe_opts: TranscribeOpts,
+        **kwargs,
    ) -> None:
         self.whisper = whisper
-        self.transcribe_opts = transcribe_opts
+        self.transcribe_opts = kwargs
 
     def _transcribe(
         self,
@@ -36,7 +28,7 @@
             audio.data,
             initial_prompt=prompt,
             word_timestamps=True,
-            **self.transcribe_opts.model_dump(),
+            **self.transcribe_opts,
         )
         words = words_from_whisper_segments(segments)
         for word in words:
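For context, a minimal sketch (not part of the diff) of how the refactored constructor is now used: whatever keyword arguments are passed in are stored and later forwarded verbatim to WhisperModel.transcribe(). The model name and option values below are illustrative assumptions.

from faster_whisper import WhisperModel

from speaches.asr import FasterWhisperASR

# Illustrative model/options; FasterWhisperASR simply forwards the kwargs.
whisper = WhisperModel("large-v3", device="auto", compute_type="default")
asr = FasterWhisperASR(
    whisper,
    language="en",
    vad_filter=True,
    condition_on_previous_text=False,
)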
--- speaches/main.py
+++ speaches/main.py
@@ -5,17 +5,18 @@
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated
+from typing import Annotated, Literal
 
-from fastapi import (Depends, FastAPI, Response, UploadFile, WebSocket,
+from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
                      WebSocketDisconnect)
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 
-from speaches.asr import FasterWhisperASR, TranscribeOpts
+from speaches import utils
+from speaches.asr import FasterWhisperASR
 from speaches.audio import AudioStream, audio_samples_from_file
-from speaches.config import SAMPLES_PER_SECOND, Language, config
+from speaches.config import SAMPLES_PER_SECOND, Language, Model, config
 from speaches.core import Transcription
 from speaches.logger import logger
 from speaches.server_models import (ResponseFormat, TranscriptionJsonResponse,
@@ -48,32 +49,40 @@
     return Response(status_code=200, content="Everything is peachy!")
 
 
-async def transcription_parameters(
-    language: Language = Language.EN,
-    vad_filter: bool = True,
-    condition_on_previous_text: bool = False,
-) -> TranscribeOpts:
-    return TranscribeOpts(
-        language=language,
-        vad_filter=vad_filter,
-        condition_on_previous_text=condition_on_previous_text,
-    )
-
-
-TranscribeParams = Annotated[TranscribeOpts, Depends(transcription_parameters)]
-
-
+# https://platform.openai.com/docs/api-reference/audio/createTranscription
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @app.post("/v1/audio/transcriptions")
 async def transcribe_file(
-    file: UploadFile,
-    transcription_opts: TranscribeParams,
-    response_format: ResponseFormat = ResponseFormat.JSON,
-) -> str:
-    asr = FasterWhisperASR(whisper, transcription_opts)
-    audio_samples = audio_samples_from_file(file.file)
-    audio = AudioStream(audio_samples)
-    transcription, _ = await asr.transcribe(audio)
-    return format_transcription(transcription, response_format)
+    file: Annotated[UploadFile, Form()],
+    model: Annotated[Model, Form()] = config.whisper.model,
+    language: Annotated[Language | None, Form()] = None,
+    prompt: Annotated[str | None, Form()] = None,
+    response_format: Annotated[ResponseFormat, Form()] = ResponseFormat.JSON,
+    temperature: Annotated[float, Form()] = 0.0,
+    timestamp_granularities: Annotated[
+        list[Literal["segments"] | Literal["words"]],
+        Form(alias="timestamp_granularities[]"),
+    ] = ["segments"],
+):
+    assert (
+        model == config.whisper.model
+    ), "Specifying a model that is different from the default is not supported yet."
+    segments, transcription_info = whisper.transcribe(
+        file.file,
+        language=language,
+        initial_prompt=prompt,
+        word_timestamps="words" in timestamp_granularities,
+        temperature=temperature,
+    )
+    segments = list(segments)
+    if response_format == ResponseFormat.TEXT:
+        return utils.segments_text(segments)
+    elif response_format == ResponseFormat.JSON:
+        return TranscriptionJsonResponse.from_segments(segments)
+    elif response_format == ResponseFormat.VERBOSE_JSON:
+        return TranscriptionVerboseJsonResponse.from_segments(
+            segments, transcription_info
+        )
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -135,11 +144,31 @@
 @app.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
     ws: WebSocket,
-    transcription_opts: TranscribeParams,
-    response_format: ResponseFormat = ResponseFormat.JSON,
+    model: Annotated[Model, Query()] = config.whisper.model,
+    language: Annotated[Language | None, Query()] = None,
+    prompt: Annotated[str | None, Query()] = None,
+    response_format: Annotated[ResponseFormat, Query()] = ResponseFormat.JSON,
+    temperature: Annotated[float, Query()] = 0.0,
+    timestamp_granularities: Annotated[
+        list[Literal["segments"] | Literal["words"]],
+        Query(
+            alias="timestamp_granularities[]",
+            description="No-op. Ignored. Only for compatibility.",
+        ),
+    ] = ["segments", "words"],
 ) -> None:
+    assert (
+        model == config.whisper.model
+    ), "Specifying a model that is different from the default is not supported yet."
     await ws.accept()
-    asr = FasterWhisperASR(whisper, transcription_opts)
+    transcribe_opts = {
+        "language": language,
+        "initial_prompt": prompt,
+        "temperature": temperature,
+        "vad_filter": True,
+        "condition_on_previous_text": False,
+    }
+    asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
         tg.create_task(audio_receiver(ws, audio_stream))
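Because the form parameters now mirror the OpenAI audio API, an off-the-shelf OpenAI client should be able to talk to the endpoint. A rough client-side sketch (not part of the diff); the base URL, port, model name, and audio file name are assumptions for illustration only.

from openai import OpenAI

# Point the client at the local speaches server instead of api.openai.com.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

with open("audio.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(
        model="whisper-1",  # placeholder; must match config.whisper.model or the assert fails
        file=f,
        response_format="json",
        temperature=0.0,
    )
print(transcription.text)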
--- speaches/server_models.py
+++ speaches/server_models.py
@@ -2,9 +2,10 @@
 
 import enum
 
-from faster_whisper.transcribe import Segment, Word
+from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
 from pydantic import BaseModel
 
+from speaches import utils
 from speaches.core import Transcription
 
 
@@ -22,10 +23,58 @@
     text: str
 
     @classmethod
+    def from_segments(cls, segments: list[Segment]) -> TranscriptionJsonResponse:
+        return cls(text=utils.segments_text(segments))
+
+    @classmethod
     def from_transcription(
         cls, transcription: Transcription
     ) -> TranscriptionJsonResponse:
         return cls(text=transcription.text)
+
+
+class WordObject(BaseModel):
+    start: float
+    end: float
+    word: str
+    probability: float
+
+    @classmethod
+    def from_word(cls, word: Word) -> WordObject:
+        return cls(
+            start=word.start,
+            end=word.end,
+            word=word.word,
+            probability=word.probability,
+        )
+
+
+class SegmentObject(BaseModel):
+    id: int
+    seek: int
+    start: float
+    end: float
+    text: str
+    tokens: list[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
+    no_speech_prob: float
+
+    @classmethod
+    def from_segment(cls, segment: Segment) -> SegmentObject:
+        return cls(
+            id=segment.id,
+            seek=segment.seek,
+            start=segment.start,
+            end=segment.end,
+            text=segment.text,
+            tokens=segment.tokens,
+            temperature=segment.temperature,
+            avg_logprob=segment.avg_logprob,
+            compression_ratio=segment.compression_ratio,
+            no_speech_prob=segment.no_speech_prob,
+        )
 
 
 # https://platform.openai.com/docs/api-reference/audio/verbose-json-object
@@ -34,8 +83,23 @@
     language: str
     duration: float
     text: str
-    words: list[Word]
-    segments: list[Segment]
+    words: list[WordObject]
+    segments: list[SegmentObject]
+
+    @classmethod
+    def from_segments(
+        cls, segments: list[Segment], transcription_info: TranscriptionInfo
+    ) -> TranscriptionVerboseJsonResponse:
+        return cls(
+            language=transcription_info.language,
+            duration=transcription_info.duration,
+            text=utils.segments_text(segments),
+            segments=[SegmentObject.from_segment(segment) for segment in segments],
+            words=[
+                WordObject.from_word(word)
+                for word in utils.words_from_segments(segments)
+            ],
+        )
 
     @classmethod
     def from_transcription(
@@ -46,7 +110,7 @@
             duration=transcription.duration,
             text=transcription.text,
             words=[
-                Word(
+                WordObject(
                    start=word.start,
                    end=word.end,
                    word=word.text,
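A rough server-side sketch (not part of the diff) of how the new from_segments constructors turn faster-whisper output into the OpenAI-shaped responses; the model and file names are placeholders.

from faster_whisper import WhisperModel

from speaches.server_models import (TranscriptionJsonResponse,
                                    TranscriptionVerboseJsonResponse)

whisper = WhisperModel("large-v3")  # placeholder model name
segments, info = whisper.transcribe("audio.wav", word_timestamps=True)
segments = list(segments)
print(TranscriptionJsonResponse.from_segments(segments).model_dump_json())
print(TranscriptionVerboseJsonResponse.from_segments(segments, info).model_dump_json())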
+++ speaches/utils.py
@@ -0,0 +1,14 @@
+from faster_whisper.transcribe import Segment, Word
+
+
+def segments_text(segments: list[Segment]) -> str:
+    return "".join(segment.text for segment in segments).strip()
+
+
+def words_from_segments(segments: list[Segment]) -> list[Word]:
+    words = []
+    for segment in segments:
+        if segment.words is None:
+            continue
+        words.extend(segment.words)
+    return words
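And the two new helpers in isolation, a sketch with placeholder model and file names: segments_text() joins and strips the segment texts, while words_from_segments() flattens word timestamps and quietly skips segments whose words attribute is None (as it is when word_timestamps=True was not requested).

from faster_whisper import WhisperModel

from speaches import utils

whisper = WhisperModel("large-v3")  # placeholder model name
segments = list(whisper.transcribe("audio.wav")[0])  # no word_timestamps requested
print(utils.segments_text(segments))         # full transcript as one stripped string
print(utils.words_from_segments(segments))   # [] here, since segment.words is None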