Fedir Zadniprovskyi 2024-05-25
feat: further improve openai compatibility + refactor
@4bdd7f27bd66f87bdd342df2ba4ad8f3c7206ae3
speaches/asr.py
--- speaches/asr.py
+++ speaches/asr.py
@@ -3,28 +3,20 @@
 from typing import Iterable
 
 from faster_whisper import transcribe
-from pydantic import BaseModel
 
 from speaches.audio import Audio
-from speaches.config import Language
 from speaches.core import Transcription, Word
 from speaches.logger import logger
-
-
-class TranscribeOpts(BaseModel):
-    language: Language | None
-    vad_filter: bool
-    condition_on_previous_text: bool
 
 
 class FasterWhisperASR:
     def __init__(
         self,
         whisper: transcribe.WhisperModel,
-        transcribe_opts: TranscribeOpts,
+        **kwargs,
     ) -> None:
         self.whisper = whisper
-        self.transcribe_opts = transcribe_opts
+        self.transcribe_opts = kwargs
 
     def _transcribe(
         self,
@@ -36,7 +28,7 @@
             audio.data,
             initial_prompt=prompt,
             word_timestamps=True,
-            **self.transcribe_opts.model_dump(),
+            **self.transcribe_opts,
         )
         words = words_from_whisper_segments(segments)
         for word in words:
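
With TranscribeOpts removed, FasterWhisperASR now forwards arbitrary keyword arguments straight to faster-whisper's transcribe() call. A minimal sketch of the new constructor usage (the model name and option values here are illustrative, not taken from the commit):

from faster_whisper import WhisperModel

from speaches.asr import FasterWhisperASR

# Any keyword accepted by WhisperModel.transcribe() can be passed through.
whisper = WhisperModel("tiny.en")  # illustrative model name
asr = FasterWhisperASR(
    whisper,
    language="en",
    vad_filter=True,
    condition_on_previous_text=False,
)
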
speaches/main.py
--- speaches/main.py
+++ speaches/main.py
@@ -5,17 +5,18 @@
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated
+from typing import Annotated, Literal
 
-from fastapi import (Depends, FastAPI, Response, UploadFile, WebSocket,
+from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
                      WebSocketDisconnect)
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 
-from speaches.asr import FasterWhisperASR, TranscribeOpts
+from speaches import utils
+from speaches.asr import FasterWhisperASR
 from speaches.audio import AudioStream, audio_samples_from_file
-from speaches.config import SAMPLES_PER_SECOND, Language, config
+from speaches.config import SAMPLES_PER_SECOND, Language, Model, config
 from speaches.core import Transcription
 from speaches.logger import logger
 from speaches.server_models import (ResponseFormat, TranscriptionJsonResponse,
@@ -48,32 +49,40 @@
     return Response(status_code=200, content="Everything is peachy!")
 
 
-async def transcription_parameters(
-    language: Language = Language.EN,
-    vad_filter: bool = True,
-    condition_on_previous_text: bool = False,
-) -> TranscribeOpts:
-    return TranscribeOpts(
-        language=language,
-        vad_filter=vad_filter,
-        condition_on_previous_text=condition_on_previous_text,
-    )
-
-
-TranscribeParams = Annotated[TranscribeOpts, Depends(transcription_parameters)]
-
-
+# https://platform.openai.com/docs/api-reference/audio/createTranscription
+# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @app.post("/v1/audio/transcriptions")
 async def transcribe_file(
-    file: UploadFile,
-    transcription_opts: TranscribeParams,
-    response_format: ResponseFormat = ResponseFormat.JSON,
-) -> str:
-    asr = FasterWhisperASR(whisper, transcription_opts)
-    audio_samples = audio_samples_from_file(file.file)
-    audio = AudioStream(audio_samples)
-    transcription, _ = await asr.transcribe(audio)
-    return format_transcription(transcription, response_format)
+    file: Annotated[UploadFile, Form()],
+    model: Annotated[Model, Form()] = config.whisper.model,
+    language: Annotated[Language | None, Form()] = None,
+    prompt: Annotated[str | None, Form()] = None,
+    response_format: Annotated[ResponseFormat, Form()] = ResponseFormat.JSON,
+    temperature: Annotated[float, Form()] = 0.0,
+    timestamp_granularities: Annotated[
+        list[Literal["segments"] | Literal["words"]],
+        Form(alias="timestamp_granularities[]"),
+    ] = ["segments"],
+):
+    assert (
+        model == config.whisper.model
+    ), "Specifying a model that is different from the default is not supported yet."
+    segments, transcription_info = whisper.transcribe(
+        file.file,
+        language=language,
+        initial_prompt=prompt,
+        word_timestamps="words" in timestamp_granularities,
+        temperature=temperature,
+    )
+    segments = list(segments)
+    if response_format == ResponseFormat.TEXT:
+        return utils.segments_text(segments)
+    elif response_format == ResponseFormat.JSON:
+        return TranscriptionJsonResponse.from_segments(segments)
+    elif response_format == ResponseFormat.VERBOSE_JSON:
+        return TranscriptionVerboseJsonResponse.from_segments(
+            segments, transcription_info
+        )
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
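
The rewritten POST /v1/audio/transcriptions handler above mirrors OpenAI's createTranscription request: options arrive as multipart form fields rather than through the old dependency-injected TranscribeOpts. A hedged client sketch of the new request shape (the server URL, audio file name, and the "verbose_json" wire value are assumptions, not shown in this diff):

import requests

# POST the same multipart fields the endpoint declares above.
with open("audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": f},
        data={
            "language": "en",
            "temperature": "0.0",
            "response_format": "verbose_json",
            "timestamp_granularities[]": ["words"],  # repeated form field, per the alias above
        },
    )
print(resp.json())
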
@@ -135,11 +144,31 @@
 @app.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
     ws: WebSocket,
-    transcription_opts: TranscribeParams,
-    response_format: ResponseFormat = ResponseFormat.JSON,
+    model: Annotated[Model, Query()] = config.whisper.model,
+    language: Annotated[Language | None, Query()] = None,
+    prompt: Annotated[str | None, Query()] = None,
+    response_format: Annotated[ResponseFormat, Query()] = ResponseFormat.JSON,
+    temperature: Annotated[float, Query()] = 0.0,
+    timestamp_granularities: Annotated[
+        list[Literal["segments"] | Literal["words"]],
+        Query(
+            alias="timestamp_granularities[]",
+            description="No-op. Ignored. Only for compatibility.",
+        ),
+    ] = ["segments", "words"],
 ) -> None:
+    assert (
+        model == config.whisper.model
+    ), "Specifying a model that is different from the default is not supported yet."
     await ws.accept()
-    asr = FasterWhisperASR(whisper, transcription_opts)
+    transcribe_opts = {
+        "language": language,
+        "initial_prompt": prompt,
+        "temperature": temperature,
+        "vad_filter": True,
+        "condition_on_previous_text": False,
+    }
+    asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
         tg.create_task(audio_receiver(ws, audio_stream))
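
The WebSocket endpoint accepts the same OpenAI-style parameters as query-string values and builds the transcribe_opts dict shown above. A rough client sketch, assuming raw audio bytes are sent over the socket and a transcription message is read back (the URL, chunk size, file name, and exact wire format are placeholders, not defined in this hunk):

import asyncio

import websockets


async def stream_transcription() -> None:
    uri = "ws://localhost:8000/v1/audio/transcriptions?language=en&response_format=text"
    async with websockets.connect(uri) as ws:
        with open("audio.pcm", "rb") as f:  # raw audio samples, assumed format
            while chunk := f.read(4000):
                await ws.send(chunk)
        print(await ws.recv())  # read back a transcription message


asyncio.run(stream_transcription())
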
speaches/server_models.py
--- speaches/server_models.py
+++ speaches/server_models.py
@@ -2,9 +2,10 @@
 
 import enum
 
-from faster_whisper.transcribe import Segment, Word
+from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
 from pydantic import BaseModel
 
+from speaches import utils
 from speaches.core import Transcription
 
 
@@ -22,10 +23,58 @@
     text: str
 
     @classmethod
+    def from_segments(cls, segments: list[Segment]) -> TranscriptionJsonResponse:
+        return cls(text=utils.segments_text(segments))
+
+    @classmethod
     def from_transcription(
         cls, transcription: Transcription
     ) -> TranscriptionJsonResponse:
         return cls(text=transcription.text)
+
+
+class WordObject(BaseModel):
+    start: float
+    end: float
+    word: str
+    probability: float
+
+    @classmethod
+    def from_word(cls, word: Word) -> WordObject:
+        return cls(
+            start=word.start,
+            end=word.end,
+            word=word.word,
+            probability=word.probability,
+        )
+
+
+class SegmentObject(BaseModel):
+    id: int
+    seek: int
+    start: float
+    end: float
+    text: str
+    tokens: list[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
+    no_speech_prob: float
+
+    @classmethod
+    def from_segment(cls, segment: Segment) -> SegmentObject:
+        return cls(
+            id=segment.id,
+            seek=segment.seek,
+            start=segment.start,
+            end=segment.end,
+            text=segment.text,
+            tokens=segment.tokens,
+            temperature=segment.temperature,
+            avg_logprob=segment.avg_logprob,
+            compression_ratio=segment.compression_ratio,
+            no_speech_prob=segment.no_speech_prob,
+        )
 
 
 # https://platform.openai.com/docs/api-reference/audio/verbose-json-object
@@ -34,8 +83,23 @@
     language: str
     duration: float
     text: str
-    words: list[Word]
-    segments: list[Segment]
+    words: list[WordObject]
+    segments: list[SegmentObject]
+
+    @classmethod
+    def from_segments(
+        cls, segments: list[Segment], transcription_info: TranscriptionInfo
+    ) -> TranscriptionVerboseJsonResponse:
+        return cls(
+            language=transcription_info.language,
+            duration=transcription_info.duration,
+            text=utils.segments_text(segments),
+            segments=[SegmentObject.from_segment(segment) for segment in segments],
+            words=[
+                WordObject.from_word(word)
+                for word in utils.words_from_segments(segments)
+            ],
+        )
 
     @classmethod
     def from_transcription(
@@ -46,7 +110,7 @@
             duration=transcription.duration,
             text=transcription.text,
             words=[
-                Word(
+                WordObject(
                     start=word.start,
                     end=word.end,
                     word=word.text,
 
speaches/utils.py (added)
+++ speaches/utils.py
@@ -0,0 +1,14 @@
+from faster_whisper.transcribe import Segment, Word
+
+
+def segments_text(segments: list[Segment]) -> str:
+    return "".join(segment.text for segment in segments).strip()
+
+
+def words_from_segments(segments: list[Segment]) -> list[Word]:
+    words = []
+    for segment in segments:
+        if segment.words is None:
+            continue
+        words.extend(segment.words)
+    return words
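
Together with the new from_segments constructors in server_models.py, these helpers turn raw faster-whisper output into OpenAI-style response bodies. An illustrative end-to-end sketch (model name and audio path are placeholders):

from faster_whisper import WhisperModel

from speaches import utils
from speaches.server_models import TranscriptionVerboseJsonResponse

whisper = WhisperModel("tiny.en")  # illustrative model name
segments, info = whisper.transcribe("audio.wav", word_timestamps=True)
segments = list(segments)

print(utils.segments_text(segments))             # full transcript, surrounding whitespace stripped
print(len(utils.words_from_segments(segments)))  # word timestamps flattened across segments
print(TranscriptionVerboseJsonResponse.from_segments(segments, info).model_dump_json())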