Fedir Zadniprovskyi 2024-07-03
fix: models route not returning openai compatible response
@eb226c55c23254df90c0f6fd542b41b8c3fe1499
faster_whisper_server/gradio_app.py
--- faster_whisper_server/gradio_app.py
+++ faster_whisper_server/gradio_app.py
@@ -67,7 +67,7 @@
     def update_model_dropdown() -> gr.Dropdown:
         res = http_client.get("/v1/models")
         res_data = res.json()
-        models: list[str] = [model["id"] for model in res_data]
+        models: list[str] = [model["id"] for model in res_data["data"]]
         assert config.whisper.model in models
         recommended_models = {model for model in models if model.startswith("Systran")}
         other_models = [model for model in models if model not in recommended_models]
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -38,6 +38,7 @@
 from faster_whisper_server.gradio_app import create_gradio_demo
 from faster_whisper_server.logger import logger
 from faster_whisper_server.server_models import (
+    ModelListResponse,
     ModelObject,
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
@@ -85,7 +86,7 @@
 
 
 @app.get("/v1/models")
-def get_models() -> list[ModelObject]:
+def get_models() -> ModelListResponse:
     models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition")
     models = [
         ModelObject(
@@ -97,7 +98,7 @@
         for model in models
         if model.created_at is not None
     ]
-    return models
+    return ModelListResponse(data=models)
 
 
 @app.get("/v1/models/{model_name:path}")
faster_whisper_server/server_models.py
--- faster_whisper_server/server_models.py
+++ faster_whisper_server/server_models.py
@@ -119,6 +119,11 @@
         )
 
 
+class ModelListResponse(BaseModel):
+    data: list[ModelObject]
+    object: Literal["list"] = "list"
+
+
 class ModelObject(BaseModel):
     id: str
     """The model identifier, which can be referenced in the API endpoints."""
pyproject.toml
--- pyproject.toml
+++ pyproject.toml
@@ -17,8 +17,8 @@
 ]
 
 [project.optional-dependencies]
-dev = ["ruff", "pytest", "pytest-xdist"]
-other = ["youtube-dl@git+https://github.com/ytdl-org/youtube-dl.git", "openai", "aider-chat"]
+dev = ["ruff", "pytest", "pytest-xdist", "openai"]
+other = ["youtube-dl@git+https://github.com/ytdl-org/youtube-dl.git", "aider-chat"]
 
 # https://docs.astral.sh/ruff/configuration/
 [tool.ruff]
requirements-dev.txt
--- requirements-dev.txt
+++ requirements-dev.txt
@@ -9,6 +9,7 @@
 anyio==4.4.0
     # via
     #   httpx
+    #   openai
     #   starlette
     #   watchfiles
 attrs==23.2.0
@@ -38,6 +39,8 @@
     # via faster-whisper
 cycler==0.12.1
     # via matplotlib
+distro==1.9.0
+    # via openai
 dnspython==2.6.1
     # via email-validator
 email-validator==2.2.0
@@ -82,6 +85,7 @@
     #   fastapi
     #   gradio
     #   gradio-client
+    #   openai
 httpx-sse==0.4.0
     # via faster-whisper-server (pyproject.toml)
 huggingface-hub==0.23.4
@@ -138,6 +142,8 @@
     #   pandas
 onnxruntime==1.18.0
     # via faster-whisper
+openai==1.35.9
+    # via faster-whisper-server (pyproject.toml)
 orjson==3.10.5
     # via
     #   fastapi
@@ -170,6 +176,7 @@
     #   faster-whisper-server (pyproject.toml)
     #   fastapi
     #   gradio
+    #   openai
     #   pydantic-settings
 pydantic-core==2.20.0
     # via pydantic
@@ -236,6 +243,7 @@
     # via
     #   anyio
     #   httpx
+    #   openai
 soundfile==0.12.1
     # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
@@ -249,7 +257,9 @@
 toolz==0.12.1
     # via altair
 tqdm==4.66.4
-    # via huggingface-hub
+    # via
+    #   huggingface-hub
+    #   openai
 typer==0.12.3
     # via
     #   fastapi-cli
@@ -260,6 +270,7 @@
     #   gradio
     #   gradio-client
     #   huggingface-hub
+    #   openai
     #   pydantic
     #   pydantic-core
     #   typer
tests/api_model_test.py
--- tests/api_model_test.py
+++ tests/api_model_test.py
@@ -1,4 +1,5 @@
 from fastapi.testclient import TestClient
+from openai import OpenAI
 
 from faster_whisper_server.server_models import ModelObject
 
@@ -17,10 +18,8 @@
     )
 
 
-def test_list_models(client: TestClient) -> None:
-    response = client.get("/v1/models")
-    data = response.json()
-    models = [model_dict_to_object(model_dict) for model_dict in data]
+def test_list_models(openai_client: OpenAI) -> None:
+    models = openai_client.models.list().data
     assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS
 
 
tests/conftest.py
--- tests/conftest.py
+++ tests/conftest.py
@@ -2,6 +2,7 @@
 import logging
 
 from fastapi.testclient import TestClient
+from openai import OpenAI
 import pytest
 
 from faster_whisper_server.main import app
@@ -19,3 +20,8 @@
 def client() -> Generator[TestClient, None, None]:
     with TestClient(app) as client:
         yield client
+
+
+@pytest.fixture()
+def openai_client(client: TestClient) -> OpenAI:
+    return OpenAI(api_key="cant-be-empty", http_client=client)