Fedir Zadniprovskyi 2024-12-17
feat: return 4xx on invalid files (#164)
@c3b4c8039a1c16eadd11f78afcae8a291819982c
src/faster_whisper_server/routers/stt.py
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
@@ -5,8 +5,10 @@
 import logging
 from typing import TYPE_CHECKING, Annotated
 
+import av.error
 from fastapi import (
     APIRouter,
+    Depends,
     Form,
     Query,
     Request,
@@ -15,9 +17,13 @@
     WebSocket,
     WebSocketDisconnect,
 )
+from fastapi.exceptions import HTTPException
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
+from faster_whisper.audio import decode_audio
 from faster_whisper.vad import VadOptions, get_speech_timestamps
+from numpy import float32
+from numpy.typing import NDArray
 from pydantic import AfterValidator, Field
 
 from faster_whisper_server.api_models import (
@@ -49,6 +55,35 @@
 logger = logging.getLogger(__name__)
 
 router = APIRouter()
+
+
+# TODO: test async vs sync performance
+def audio_file_dependency(  # decodes the uploaded `file` form field; decode failures are re-raised as HTTPException
+    file: Annotated[UploadFile, Form()],
+) -> NDArray[float32]:
+    try:
+        audio = decode_audio(file.file)
+    except av.error.InvalidDataError as e:
+        raise HTTPException(
+            status_code=415,
+            # TODO: list supported file types
+            detail="Failed to decode audio. The provided file type is not supported.",
+        ) from e
+    except av.error.ValueError as e:  # NOTE(review): reported to the client as a likely-empty upload; confirm no other cause raises this
+        raise HTTPException(
+            status_code=400,
+            detail="Failed to decode audio. The provided file is likely empty.",
+        ) from e
+    except Exception as e:  # any other failure is unexpected: log the traceback, then answer 500
+        logger.exception(
+            "Failed to decode audio. This is likely a bug. Please create an issue at https://github.com/fedirz/faster-whisper-server/issues/new."
+        )
+        raise HTTPException(status_code=500, detail="Failed to decode audio.") from e
+    else:
+        return audio  # pyright: ignore reportReturnType
+
+
+AudioFileDependency = Annotated[NDArray[float32], Depends(audio_file_dependency)]  # parameter type for handlers that receive the already-decoded audio
 
 
 def segments_to_response(
@@ -140,7 +175,7 @@
 def translate_file(
     config: ConfigDependency,
     model_manager: ModelManagerDependency,
-    file: Annotated[UploadFile, Form()],
+    audio: AudioFileDependency,
     model: Annotated[ModelName | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
     response_format: Annotated[ResponseFormat | None, Form()] = None,
@@ -154,7 +189,7 @@
         response_format = config.default_response_format
     with model_manager.load_model(model) as whisper:
         segments, transcription_info = whisper.transcribe(
-            file.file,
+            audio,
             task=Task.TRANSLATE,
             initial_prompt=prompt,
             temperature=temperature,
@@ -190,7 +225,7 @@
     config: ConfigDependency,
     model_manager: ModelManagerDependency,
     request: Request,
-    file: Annotated[UploadFile, Form()],
+    audio: AudioFileDependency,
     model: Annotated[ModelName | None, Form()] = None,
     language: Annotated[Language | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
@@ -218,7 +253,7 @@
         )
     with model_manager.load_model(model) as whisper:
         segments, transcription_info = whisper.transcribe(
-            file.file,
+            audio,
             task=Task.TRANSCRIBE,
             language=language,
             initial_prompt=prompt,
Add a comment
List