Fedir Zadniprovskyi 2024-06-13
chore: handle "whisper-1" model name
@3dda14e4a25775cdba0188cdc4f656e2248c86d0
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -11,6 +11,7 @@
     FastAPI,
     Form,
     HTTPException,
+    Path,
     Query,
     Response,
     UploadFile,
@@ -22,6 +23,7 @@
 from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 from huggingface_hub.hf_api import ModelInfo
+from pydantic import AfterValidator
 
 from faster_whisper_server import utils
 from faster_whisper_server.asr import FasterWhisperASR
@@ -85,7 +87,7 @@
     return Response(status_code=200, content="OK")
 
 
-@app.get("/v1/models", response_model=list[ModelObject])
+@app.get("/v1/models")
 def get_models() -> list[ModelObject]:
     models = huggingface_hub.list_models(library="ctranslate2")
     models = [
@@ -101,8 +103,8 @@
     return models
 
 
-@app.get("/v1/models/{model_name:path}", response_model=ModelObject)
-def get_model(model_name: str) -> ModelObject:
+@app.get("/v1/models/{model_name:path}")
+def get_model(model_name: Annotated[str, Path()]) -> ModelObject:
     models = list(
         huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
     )
@@ -131,10 +133,25 @@
     return f"data: {data}\n\n"
 
 
+def handle_default_openai_model(model_name: str) -> str:
+    """This exists because some callers may not be able override the default("whisper-1") model name.
+    For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
+    """
+    if model_name == "whisper-1":
+        logger.info(
+            f"{model_name} is not a valid model name. Using {config.whisper.model} instead."
+        )
+        return config.whisper.model
+    return model_name
+
+
+ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
+
+
 @app.post("/v1/audio/translations")
 def translate_file(
     file: Annotated[UploadFile, Form()],
-    model: Annotated[str, Form()] = config.whisper.model,
+    model: Annotated[ModelName, Form()] = config.whisper.model,
     prompt: Annotated[str | None, Form()] = None,
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
@@ -187,7 +204,7 @@
 @app.post("/v1/audio/transcriptions")
 def transcribe_file(
     file: Annotated[UploadFile, Form()],
-    model: Annotated[str, Form()] = config.whisper.model,
+    model: Annotated[ModelName, Form()] = config.whisper.model,
     language: Annotated[Language | None, Form()] = config.default_language,
     prompt: Annotated[str | None, Form()] = None,
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
@@ -289,7 +306,7 @@
 @app.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
     ws: WebSocket,
-    model: Annotated[str, Query()] = config.whisper.model,
+    model: Annotated[ModelName, Query()] = config.whisper.model,
     language: Annotated[Language | None, Query()] = config.default_language,
     response_format: Annotated[
         ResponseFormat, Query()
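
Below is a minimal, self-contained sketch (not part of the commit) of how the new ModelName annotated type resolves "whisper-1" to the configured model when validated with Pydantic v2. The DEFAULT_MODEL constant stands in for config.whisper.model, and TypeAdapter is used only for demonstration; in the server itself FastAPI performs the equivalent validation on the Form() and Query() parameters.

    from typing import Annotated

    from pydantic import AfterValidator, TypeAdapter

    # Stand-in for config.whisper.model; the real value comes from server config.
    DEFAULT_MODEL = "Systran/faster-whisper-small.en"

    def handle_default_openai_model(model_name: str) -> str:
        # Callers such as Open WebUI always send OpenAI's default "whisper-1";
        # map it to the locally configured model instead of rejecting the request.
        if model_name == "whisper-1":
            return DEFAULT_MODEL
        return model_name

    ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]

    adapter = TypeAdapter(ModelName)
    print(adapter.validate_python("whisper-1"))  # -> Systran/faster-whisper-small.en
    print(adapter.validate_python("tiny.en"))    # -> tiny.en (passed through unchanged)

Because the validator is attached via Annotated, the same ModelName alias can be reused across the translation, transcription, and websocket endpoints without repeating the check in each one.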