Fedir Zadniprovskyi 2024-12-17
feat: support BatchedInferencePipeline (#169)
@0e60c81b0c2085d8c242959f1d1a374237d3dcfd
pyproject.toml
--- pyproject.toml
+++ pyproject.toml
@@ -6,7 +6,7 @@
 dependencies = [
     "ctranslate2>=4.5.0",
     "fastapi>=0.115.0",
-    "faster-whisper>=1.0.3",
+    "faster-whisper>=1.1.0",
     "huggingface-hub>=0.25.1",
     "numpy>=2.1.1",
     "piper-phonemize ; platform_machine == 'x86_64'",
src/faster_whisper_server/asr.py
--- src/faster_whisper_server/asr.py
+++ src/faster_whisper_server/asr.py
@@ -31,6 +31,7 @@
         prompt: str | None = None,
     ) -> tuple[Transcription, transcribe.TranscriptionInfo]:
         start = time.perf_counter()
+        # NOTE: should `BatchedInferencePipeline` be used here?
         segments, transcription_info = self.whisper.transcribe(
             audio.data,
             initial_prompt=prompt,
src/faster_whisper_server/config.py
--- src/faster_whisper_server/config.py
+++ src/faster_whisper_server/config.py
@@ -168,6 +168,10 @@
     -1: Never unload the model.
     0: Unload the model immediately after usage.
     """
+    use_batched_mode: bool = False
+    """
+    Whether to use batch mode (introduced in the 1.1.0 `faster-whisper` release) for inference. This will likely become the default in the future and the configuration option will be removed.
+    """  # noqa: E501
 
 
 class Config(BaseSettings):
src/faster_whisper_server/routers/stt.py
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
@@ -21,6 +21,7 @@
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper.audio import decode_audio
+from faster_whisper.transcribe import BatchedInferencePipeline
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 from numpy import float32
 from numpy.typing import NDArray
@@ -188,7 +189,8 @@
     if response_format is None:
         response_format = config.default_response_format
     with model_manager.load_model(model) as whisper:
-        segments, transcription_info = whisper.transcribe(
+        whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
+        segments, transcription_info = whisper_model.transcribe(
             audio,
             task=Task.TRANSLATE,
             initial_prompt=prompt,
@@ -252,7 +254,8 @@
             "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
         )
     with model_manager.load_model(model) as whisper:
-        segments, transcription_info = whisper.transcribe(
+        whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
+        segments, transcription_info = whisper_model.transcribe(
             audio,
             task=Task.TRANSCRIBE,
             language=language,
uv.lock
--- uv.lock
+++ uv.lock
@@ -230,7 +230,7 @@
 
 [[package]]
 name = "faster-whisper"
-version = "1.0.3"
+version = "1.1.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "av" },
@@ -238,10 +238,11 @@
     { name = "huggingface-hub" },
     { name = "onnxruntime" },
     { name = "tokenizers" },
+    { name = "tqdm" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/1e/f2/77437ee937233d6e8259e3df511a4662cd7e833dabeaaddbfc929d2a3ed5/faster-whisper-1.0.3.tar.gz", hash = "sha256:1a145db86450b56aaa623c8df7d4ef86e8a1159900f60533e2890e98e8453a17", size = 1980019 }
+sdist = { url = "https://files.pythonhosted.org/packages/31/b1/124f6d5a547756170e11eea405ae6c08afa2b96e8ccd10947a1244b50cdb/faster-whisper-1.1.0.tar.gz", hash = "sha256:cea4bba5d4527173fdbacafa56f2ffb17dd322688f6c3fdf5fd7b6b6c193ce17", size = 1124950 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/7f/00/4742b1cd3afd23d0ff9b7e72ec40b2c398988332a5578115728fd83415d1/faster_whisper-1.0.3-py3-none-any.whl", hash = "sha256:364d0e378ab232ed26f39656e5c98548b38045224e206b20f7d8c90e2745b9d3", size = 1974982 },
+    { url = "https://files.pythonhosted.org/packages/7b/03/ab118cb743dcf671da01ad0cfd7564465dda115db32976fdc95e21ce8feb/faster_whisper-1.1.0-py3-none-any.whl", hash = "sha256:0f2d025676bbff1e46c4108b6f9a82578d6e33826c174af2990e45b33fab6182", size = 1118168 },
 ]
 
 [[package]]
@@ -295,7 +296,7 @@
     { name = "basedpyright", marker = "extra == 'dev'", specifier = ">=1.18.0" },
     { name = "ctranslate2", specifier = ">=4.5.0" },
     { name = "fastapi", specifier = ">=0.115.0" },
-    { name = "faster-whisper", specifier = ">=1.0.3" },
+    { name = "faster-whisper", specifier = ">=1.1.0" },
     { name = "gradio", marker = "extra == 'ui'", specifier = ">=5.0.2" },
     { name = "httpx", marker = "extra == 'ui'", specifier = ">=0.27.2" },
     { name = "httpx-sse", marker = "extra == 'ui'", specifier = ">=0.4.0" },
Add a comment
List