Fedir Zadniprovskyi 2024-07-20
refactor
@9f134f9b35fb81ad7363b521be543f1957acdfe1
faster_whisper_server/asr.py
--- faster_whisper_server/asr.py
+++ faster_whisper_server/asr.py
@@ -1,11 +1,10 @@
 import asyncio
-from collections.abc import Iterable
 import time
 
 from faster_whisper import transcribe
 
 from faster_whisper_server.audio import Audio
-from faster_whisper_server.core import Transcription, Word
+from faster_whisper_server.core import Segment, Transcription, Word
 from faster_whisper_server.logger import logger
 
 
@@ -30,7 +29,8 @@
             word_timestamps=True,
             **self.transcribe_opts,
         )
-        words = words_from_whisper_segments(segments)
+        segments = Segment.from_faster_whisper_segments(segments)
+        words = Word.from_segments(segments)
         for word in words:
             word.offset(audio.start)
         transcription = Transcription(words)
@@ -54,19 +54,3 @@
             audio,
             prompt,
         )
-
-
-def words_from_whisper_segments(segments: Iterable[transcribe.Segment]) -> list[Word]:
-    words: list[Word] = []
-    for segment in segments:
-        assert segment.words is not None
-        words.extend(
-            Word(
-                start=word.start,
-                end=word.end,
-                text=word.word,
-                probability=word.probability,
-            )
-            for word in segment.words
-        )
-    return words
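
For reference, a minimal sketch of the new conversion path in asr.py (the model name and audio file are illustrative; Segment.from_faster_whisper_segments and Word.from_segments are the helpers added to core below):

from faster_whisper import WhisperModel

from faster_whisper_server.core import Segment, Word

model = WhisperModel("tiny.en")  # illustrative model name
whisper_segments, _info = model.transcribe("sample.wav", word_timestamps=True)

# Convert faster-whisper segments into the server's own pydantic models,
# then flatten them into a single word list, as the ASR code now does.
segments = Segment.from_faster_whisper_segments(whisper_segments)
words = Word.from_segments(segments)
print([w.word for w in words])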
faster_whisper_server/core.py
--- faster_whisper_server/core.py
+++ faster_whisper_server/core.py
@@ -1,41 +1,83 @@
-# TODO: rename module
 from __future__ import annotations
 
-from dataclasses import dataclass
 import re
+from typing import TYPE_CHECKING
+
+from pydantic import BaseModel
 
 from faster_whisper_server.config import config
 
+if TYPE_CHECKING:
+    from collections.abc import Iterable
 
-# TODO: use the `Segment` from `faster-whisper.transcribe` instead
-@dataclass
-class Segment:
-    text: str
-    start: float = 0.0
-    end: float = 0.0
+    import faster_whisper.transcribe
 
-    @property
-    def is_eos(self) -> bool:
-        if self.text.endswith("..."):
-            return False
-        return any(self.text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")
+
+class Word(BaseModel):
+    start: float
+    end: float
+    word: str
+    probability: float
+
+    @classmethod
+    def from_segments(cls, segments: Iterable[Segment]) -> list[Word]:
+        words: list[Word] = []
+        for segment in segments:
+            assert segment.words is not None
+            words.extend(segment.words)
+        return words
 
     def offset(self, seconds: float) -> None:
         self.start += seconds
         self.end += seconds
 
-
-# TODO: use the `Word` from `faster-whisper.transcribe` instead
-@dataclass
-class Word(Segment):
-    probability: float = 0.0
-
     @classmethod
     def common_prefix(cls, a: list[Word], b: list[Word]) -> list[Word]:
         i = 0
-        while i < len(a) and i < len(b) and canonicalize_word(a[i].text) == canonicalize_word(b[i].text):
+        while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word):
             i += 1
         return a[:i]
+
+
+class Segment(BaseModel):
+    id: int
+    seek: int
+    start: float
+    end: float
+    text: str
+    tokens: list[int]
+    temperature: float
+    avg_logprob: float
+    compression_ratio: float
+    no_speech_prob: float
+    words: list[Word] | None
+
+    @classmethod
+    def from_faster_whisper_segments(cls, segments: Iterable[faster_whisper.transcribe.Segment]) -> Iterable[Segment]:
+        for segment in segments:
+            yield cls(
+                id=segment.id,
+                seek=segment.seek,
+                start=segment.start,
+                end=segment.end,
+                text=segment.text,
+                tokens=segment.tokens,
+                temperature=segment.temperature,
+                avg_logprob=segment.avg_logprob,
+                compression_ratio=segment.compression_ratio,
+                no_speech_prob=segment.no_speech_prob,
+                words=[
+                    Word(
+                        start=word.start,
+                        end=word.end,
+                        word=word.word,
+                        probability=word.probability,
+                    )
+                    for word in segment.words
+                ]
+                if segment.words is not None
+                else None,
+            )
 
 
 class Transcription:
@@ -45,7 +87,7 @@
 
     @property
     def text(self) -> str:
-        return " ".join(word.text for word in self.words).strip()
+        return " ".join(word.word for word in self.words).strip()
 
     @property
     def start(self) -> float:
@@ -77,48 +119,57 @@
                 raise ValueError(f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}")
 
 
-def test_segment_is_eos() -> None:
-    assert not Segment("Hello").is_eos
-    assert not Segment("Hello...").is_eos
-    assert Segment("Hello.").is_eos
-    assert Segment("Hello!").is_eos
-    assert Segment("Hello?").is_eos
-    assert not Segment("Hello. Yo").is_eos
-    assert not Segment("Hello. Yo...").is_eos
-    assert Segment("Hello. Yo.").is_eos
+def is_eos(text: str) -> bool:
+    if text.endswith("..."):
+        return False
+    return any(text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")
 
 
-def to_full_sentences(words: list[Word]) -> list[Segment]:
-    sentences: list[Segment] = [Segment("")]
+def test_is_eos() -> None:
+    assert not is_eos("Hello")
+    assert not is_eos("Hello...")
+    assert is_eos("Hello.")
+    assert is_eos("Hello!")
+    assert is_eos("Hello?")
+    assert not is_eos("Hello. Yo")
+    assert not is_eos("Hello. Yo...")
+    assert is_eos("Hello. Yo.")
+
+
+def to_full_sentences(words: list[Word]) -> list[list[Word]]:
+    sentences: list[list[Word]] = [[]]
     for word in words:
-        sentences[-1] = Segment(
-            start=sentences[-1].start,
-            end=word.end,
-            text=sentences[-1].text + word.text,
-        )
-        if word.is_eos:
-            sentences.append(Segment(""))
-    if len(sentences) > 0 and not sentences[-1].is_eos:
+        sentences[-1].append(word)
+        if is_eos(word.word):
+            sentences.append([])
+    if len(sentences[-1]) == 0 or not is_eos(sentences[-1][-1].word):
         sentences.pop()
     return sentences
 
 
 def tests_to_full_sentences() -> None:
+    def word(text: str) -> Word:
+        return Word(word=text, start=0.0, end=0.0, probability=0.0)
+
     assert to_full_sentences([]) == []
-    assert to_full_sentences([Word(text="Hello")]) == []
-    assert to_full_sentences([Word(text="Hello..."), Word(" world")]) == []
-    assert to_full_sentences([Word(text="Hello..."), Word(" world.")]) == [Segment(text="Hello... world.")]
-    assert to_full_sentences([Word(text="Hello..."), Word(" world."), Word(" How")]) == [
-        Segment(text="Hello... world.")
+    assert to_full_sentences([word(text="Hello")]) == []
+    assert to_full_sentences([word(text="Hello..."), word(" world")]) == []
+    assert to_full_sentences([word(text="Hello..."), word(" world.")]) == [[word("Hello..."), word(" world.")]]
+    assert to_full_sentences([word(text="Hello..."), word(" world."), word(" How")]) == [
+        [word("Hello..."), word(" world.")],
     ]
 
 
-def to_text(words: list[Word]) -> str:
-    return "".join(word.text for word in words)
+def word_to_text(words: list[Word]) -> str:
+    return "".join(word.word for word in words)
 
 
-def to_text_w_ts(words: list[Word]) -> str:
-    return "".join(f"{word.text}({word.start:.2f}-{word.end:.2f})" for word in words)
+def words_to_text_w_ts(words: list[Word]) -> str:
+    return "".join(f"{word.word}({word.start:.2f}-{word.end:.2f})" for word in words)
+
+
+def segments_to_text(segments: Iterable[Segment]) -> str:
+    return "".join(segment.text for segment in segments).strip()
 
 
 def canonicalize_word(text: str) -> str:
@@ -136,14 +187,14 @@
 
 def common_prefix(a: list[Word], b: list[Word]) -> list[Word]:
     i = 0
-    while i < len(a) and i < len(b) and canonicalize_word(a[i].text) == canonicalize_word(b[i].text):
+    while i < len(a) and i < len(b) and canonicalize_word(a[i].word) == canonicalize_word(b[i].word):
         i += 1
     return a[:i]
 
 
 def test_common_prefix() -> None:
     def word(text: str) -> Word:
-        return Word(text=text, start=0.0, end=0.0, probability=0.0)
+        return Word(word=text, start=0.0, end=0.0, probability=0.0)
 
     a = [word("a"), word("b"), word("c")]
     b = [word("a"), word("b"), word("c")]
@@ -176,7 +227,7 @@
 
 def test_common_prefix_and_canonicalization() -> None:
     def word(text: str) -> Word:
-        return Word(text=text, start=0.0, end=0.0, probability=0.0)
+        return Word(word=text, start=0.0, end=0.0, probability=0.0)
 
     a = [word("A...")]
     b = [word("a?"), word("b"), word("c")]
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -24,7 +24,6 @@
 import huggingface_hub
 from pydantic import AfterValidator
 
-from faster_whisper_server import utils
 from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
@@ -34,6 +33,7 @@
     Task,
     config,
 )
+from faster_whisper_server.core import Segment, segments_to_text
 from faster_whisper_server.logger import logger
 from faster_whisper_server.server_models import (
     ModelListResponse,
@@ -46,7 +46,7 @@
 if TYPE_CHECKING:
     from collections.abc import Generator, Iterable
 
-    from faster_whisper.transcribe import Segment, TranscriptionInfo
+    from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
 
 loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
@@ -157,7 +157,7 @@
 ) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
     segments = list(segments)
     if response_format == ResponseFormat.TEXT:  # noqa: RET503
-        return utils.segments_text(segments)
+        return segments_to_text(segments)
     elif response_format == ResponseFormat.JSON:
         return TranscriptionJsonResponse.from_segments(segments)
     elif response_format == ResponseFormat.VERBOSE_JSON:
@@ -220,6 +220,7 @@
         temperature=temperature,
         vad_filter=True,
     )
+    segments = Segment.from_faster_whisper_segments(segments)
 
     if stream:
         return segments_to_streaming_response(segments, transcription_info, response_format)
@@ -258,6 +259,7 @@
         vad_filter=True,
         hotwords=hotwords,
     )
+    segments = Segment.from_faster_whisper_segments(segments)
 
     if stream:
         return segments_to_streaming_response(segments, transcription_info, response_format)
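
Roughly, the endpoint-side flow after this change looks as follows (a sketch; the import location of ResponseFormat is assumed, and the real segments_to_response also handles the verbose format and streaming):

from faster_whisper_server.config import ResponseFormat  # assumed import location
from faster_whisper_server.core import Segment, segments_to_text
from faster_whisper_server.server_models import TranscriptionJsonResponse


def format_transcript(whisper_segments, response_format: ResponseFormat):
    """Sketch of the post-refactor formatting path in main.py."""
    segments = list(Segment.from_faster_whisper_segments(whisper_segments))
    if response_format == ResponseFormat.TEXT:
        return segments_to_text(segments)
    if response_format == ResponseFormat.JSON:
        return TranscriptionJsonResponse.from_segments(segments)
    raise NotImplementedError(response_format)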
faster_whisper_server/server_models.py
--- faster_whisper_server/server_models.py
+++ faster_whisper_server/server_models.py
@@ -4,12 +4,10 @@
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from faster_whisper_server import utils
+from faster_whisper_server.core import Segment, Transcription, Word, segments_to_text
 
 if TYPE_CHECKING:
-    from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
-
-    from faster_whisper_server.core import Transcription
+    from faster_whisper.transcribe import TranscriptionInfo
 
 
 # https://platform.openai.com/docs/api-reference/audio/json-object
@@ -18,55 +16,11 @@
 
     @classmethod
     def from_segments(cls, segments: list[Segment]) -> TranscriptionJsonResponse:
-        return cls(text=utils.segments_text(segments))
+        return cls(text=segments_to_text(segments))
 
     @classmethod
     def from_transcription(cls, transcription: Transcription) -> TranscriptionJsonResponse:
         return cls(text=transcription.text)
-
-
-class WordObject(BaseModel):
-    start: float
-    end: float
-    word: str
-    probability: float
-
-    @classmethod
-    def from_word(cls, word: Word) -> WordObject:
-        return cls(
-            start=word.start,
-            end=word.end,
-            word=word.word,
-            probability=word.probability,
-        )
-
-
-class SegmentObject(BaseModel):
-    id: int
-    seek: int
-    start: float
-    end: float
-    text: str
-    tokens: list[int]
-    temperature: float
-    avg_logprob: float
-    compression_ratio: float
-    no_speech_prob: float
-
-    @classmethod
-    def from_segment(cls, segment: Segment) -> SegmentObject:
-        return cls(
-            id=segment.id,
-            seek=segment.seek,
-            start=segment.start,
-            end=segment.end,
-            text=segment.text,
-            tokens=segment.tokens,
-            temperature=segment.temperature,
-            avg_logprob=segment.avg_logprob,
-            compression_ratio=segment.compression_ratio,
-            no_speech_prob=segment.no_speech_prob,
-        )
 
 
 # https://platform.openai.com/docs/api-reference/audio/verbose-json-object
@@ -75,8 +29,8 @@
     language: str
     duration: float
     text: str
-    words: list[WordObject]
-    segments: list[SegmentObject]
+    words: list[Word]
+    segments: list[Segment]
 
     @classmethod
     def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -> TranscriptionVerboseJsonResponse:
@@ -84,8 +38,8 @@
             language=transcription_info.language,
             duration=segment.end - segment.start,
             text=segment.text,
-            words=([WordObject.from_word(word) for word in segment.words] if isinstance(segment.words, list) else []),
-            segments=[SegmentObject.from_segment(segment)],
+            words=(segment.words if isinstance(segment.words, list) else []),
+            segments=[segment],
         )
 
     @classmethod
@@ -95,9 +49,9 @@
         return cls(
             language=transcription_info.language,
             duration=transcription_info.duration,
-            text=utils.segments_text(segments),
-            segments=[SegmentObject.from_segment(segment) for segment in segments],
-            words=[WordObject.from_word(word) for word in utils.words_from_segments(segments)],
+            text=segments_to_text(segments),
+            segments=segments,
+            words=Word.from_segments(segments),
         )
 
     @classmethod
@@ -106,15 +60,7 @@
             language="english",  # FIX: hardcoded
             duration=transcription.duration,
             text=transcription.text,
-            words=[
-                WordObject(
-                    start=word.start,
-                    end=word.end,
-                    word=word.text,
-                    probability=word.probability,
-                )
-                for word in transcription.words
-            ],
+            words=transcription.words,
             segments=[],  # FIX: hardcoded
         )
 
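Since Segment and Word are now pydantic models, the verbose response can hold them directly instead of copying into SegmentObject/WordObject. A hand-built sketch (field values are made up; any remaining defaulted fields are left out):

from faster_whisper_server.core import Segment, Word
from faster_whisper_server.server_models import TranscriptionVerboseJsonResponse

word = Word(word=" Hello.", start=0.0, end=0.5, probability=0.99)
segment = Segment(
    id=0, seek=0, start=0.0, end=0.5, text=" Hello.", tokens=[],
    temperature=0.0, avg_logprob=-0.1, compression_ratio=1.0,
    no_speech_prob=0.01, words=[word],
)
response = TranscriptionVerboseJsonResponse(
    language="en", duration=0.5, text="Hello.", words=[word], segments=[segment],
)
print(response.model_dump_json())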
faster_whisper_server/transcriber.py
--- faster_whisper_server/transcriber.py
+++ faster_whisper_server/transcriber.py
@@ -4,12 +4,7 @@
 
 from faster_whisper_server.audio import Audio, AudioStream
 from faster_whisper_server.config import config
-from faster_whisper_server.core import (
-    Transcription,
-    Word,
-    common_prefix,
-    to_full_sentences,
-)
+from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text
 from faster_whisper_server.logger import logger
 
 if TYPE_CHECKING:
@@ -37,30 +32,16 @@
 
         return prefix
 
-    @classmethod
-    def prompt(cls, confirmed: Transcription) -> str | None:
-        sentences = to_full_sentences(confirmed.words)
-        if len(sentences) == 0:
-            return None
-        return sentences[-1].text
 
-    # TODO: better name
-    @classmethod
-    def needs_audio_after(cls, confirmed: Transcription) -> float:
-        full_sentences = to_full_sentences(confirmed.words)
-        return full_sentences[-1].end if len(full_sentences) > 0 else 0.0
-
-
+# TODO: needs a better name
 def needs_audio_after(confirmed: Transcription) -> float:
     full_sentences = to_full_sentences(confirmed.words)
-    return full_sentences[-1].end if len(full_sentences) > 0 else 0.0
+    return full_sentences[-1][-1].end if len(full_sentences) > 0 else 0.0
 
 
 def prompt(confirmed: Transcription) -> str | None:
     sentences = to_full_sentences(confirmed.words)
-    if len(sentences) == 0:
-        return None
-    return sentences[-1].text
+    return word_to_text(sentences[-1]) if len(sentences) > 0 else None
 
 
 async def audio_transcriber(
 
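To illustrate the shape change: to_full_sentences now returns a list of word lists rather than Segments, so prompt and needs_audio_after read the last sentence like this (a sketch with made-up timings):

from faster_whisper_server.core import Word, to_full_sentences, word_to_text

confirmed_words = [
    Word(word="Hi.", start=0.0, end=0.3, probability=0.99),
    Word(word=" Bye", start=0.3, end=0.6, probability=0.9),
]
sentences = to_full_sentences(confirmed_words)  # [[ the "Hi." word ]]; " Bye" is not a complete sentence yet
prompt_text = word_to_text(sentences[-1]) if len(sentences) > 0 else None  # "Hi."
cut_after = sentences[-1][-1].end if len(sentences) > 0 else 0.0  # 0.3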
faster_whisper_server/utils.py (deleted)
--- faster_whisper_server/utils.py
@@ -1,14 +0,0 @@
-from faster_whisper.transcribe import Segment, Word
-
-
-def segments_text(segments: list[Segment]) -> str:
-    return "".join(segment.text for segment in segments).strip()
-
-
-def words_from_segments(segments: list[Segment]) -> list[Word]:
-    words = []
-    for segment in segments:
-        if segment.words is None:
-            continue
-        words.extend(segment.words)
-    return words
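
The two deleted helpers now have equivalents in core. A sketch of the mapping (note that Word.from_segments asserts that word timestamps are present, whereas words_from_segments skipped segments without them):

from faster_whisper_server.core import Segment, Word, segments_to_text


def old_style_outputs(segments: list[Segment]) -> tuple[str, list[Word]]:
    # was: utils.segments_text(segments) and utils.words_from_segments(segments)
    return segments_to_text(segments), Word.from_segments(segments)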