Fedir Zadniprovskyi 2024-06-03
feat: add /v1/models and /v1/model routes #14
Now users will also have to specify a full model name.
Before: tiny.en
Now: Systran/faster-whisper-tiny.en
@b5858eea50cb6602a0afcf523d3b374d7256a153
faster_whisper_server/main.py
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
@@ -6,9 +6,11 @@
 from io import BytesIO
 from typing import Annotated, Literal, OrderedDict
 
+import huggingface_hub
 from fastapi import (
     FastAPI,
     Form,
+    HTTPException,
     Query,
     Response,
     UploadFile,
@@ -19,6 +21,7 @@
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
+from huggingface_hub.hf_api import ModelInfo
 
 from faster_whisper_server import utils
 from faster_whisper_server.asr import FasterWhisperASR
@@ -31,24 +34,25 @@
 )
 from faster_whisper_server.logger import logger
 from faster_whisper_server.server_models import (
+    ModelObject,
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
 )
 from faster_whisper_server.transcriber import audio_transcriber
 
-models: OrderedDict[str, WhisperModel] = OrderedDict()
+loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
 
 
 def load_model(model_name: str) -> WhisperModel:
-    if model_name in models:
+    if model_name in loaded_models:
         logger.debug(f"{model_name} model already loaded")
-        return models[model_name]
-    if len(models) >= config.max_models:
-        oldest_model_name = next(iter(models))
+        return loaded_models[model_name]
+    if len(loaded_models) >= config.max_models:
+        oldest_model_name = next(iter(loaded_models))
         logger.info(
             f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
         )
-        del models[oldest_model_name]
+        del loaded_models[oldest_model_name]
     logger.debug(f"Loading {model_name}...")
     start = time.perf_counter()
     # NOTE: will raise an exception if the model name isn't valid
@@ -60,7 +64,7 @@
     logger.info(
         f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."
     )
-    models[model_name] = whisper
+    loaded_models[model_name] = whisper
     return whisper
 
 
@@ -68,9 +72,9 @@
 async def lifespan(_: FastAPI):
     load_model(config.whisper.model)
     yield
-    for model in models.keys():
+    for model in loaded_models.keys():
         logger.info(f"Unloading {model}")
-        del models[model]
+        del loaded_models[model]
 
 
 app = FastAPI(lifespan=lifespan)
@@ -81,6 +85,48 @@
     return Response(status_code=200, content="OK")
 
 
+@app.get("/v1/models", response_model=list[ModelObject])
+def get_models() -> list[ModelObject]:
+    models = huggingface_hub.list_models(library="ctranslate2")
+    models = [
+        ModelObject(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+        )
+        for model in models
+        if model.created_at is not None
+    ]
+    return models
+
+
+@app.get("/v1/models/{model_name:path}", response_model=ModelObject)
+def get_model(model_name: str) -> ModelObject:
+    models = list(
+        huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
+    )
+    if len(models) == 0:
+        raise HTTPException(status_code=404, detail="Model doesn't exists")
+    exact_match: ModelInfo | None = None
+    for model in models:
+        if model.id == model_name:
+            exact_match = model
+            break
+    if exact_match is None:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Model doesn't exists. Possible matches: {", ".join([model.id for model in models])}",
+        )
+    assert exact_match.created_at is not None
+    return ModelObject(
+        id=exact_match.id,
+        created=int(exact_match.created_at.timestamp()),
+        object_="model",
+        owned_by=exact_match.id.split("/")[0],
+    )
+
+
 @app.post("/v1/audio/translations")
 def translate_file(
     file: Annotated[UploadFile, Form()],
faster_whisper_server/server_models.py
--- faster_whisper_server/server_models.py
+++ faster_whisper_server/server_models.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+from typing import Literal
+
 from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, Field
 
 from faster_whisper_server import utils
 from faster_whisper_server.core import Transcription
@@ -125,3 +127,16 @@
             ],
             segments=[],  # FIX: hardcoded
         )
+
+
+class ModelObject(BaseModel):
+    model_config = ConfigDict(populate_by_name=True)
+
+    id: str
+    """The model identifier, which can be referenced in the API endpoints."""
+    created: int
+    """The Unix timestamp (in seconds) when the model was created."""
+    object_: Literal["model"] = Field(serialization_alias="object")
+    """The object type, which is always "model"."""
+    owned_by: str
+    """The organization that owns the model."""
 
tests/__init__.py (added)
+++ tests/__init__.py
@@ -0,0 +1,0 @@
 
tests/api_model_test.py (added)
+++ tests/api_model_test.py
@@ -0,0 +1,48 @@
+from typing import Generator
+
+import pytest
+from fastapi.testclient import TestClient
+
+from faster_whisper_server.main import app
+from faster_whisper_server.server_models import ModelObject
+
+MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en"
+MODEL_THAT_DOES_NOT_EXIST = "i-do-not-exist"
+MIN_EXPECTED_NUMBER_OF_MODELS = (
+    200  # At the time of the test creation there are 228 models
+)
+
+
+@pytest.fixture()
+def client() -> Generator[TestClient, None, None]:
+    with TestClient(app) as client:
+        yield client
+
+
+# HACK: because ModelObject(**data) doesn't work
+def model_dict_to_object(model_dict: dict) -> ModelObject:
+    return ModelObject(
+        id=model_dict["id"],
+        created=model_dict["created"],
+        object_=model_dict["object"],
+        owned_by=model_dict["owned_by"],
+    )
+
+
+def test_list_models(client: TestClient):
+    response = client.get("/v1/models")
+    data = response.json()
+    models = [model_dict_to_object(model_dict) for model_dict in data]
+    assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS
+
+
+def test_model_exists(client: TestClient):
+    response = client.get(f"/v1/model/{MODEL_THAT_EXISTS}")
+    data = response.json()
+    model = model_dict_to_object(data)
+    assert model.id == MODEL_THAT_EXISTS
+
+
+def test_model_does_not_exist(client: TestClient):
+    response = client.get(f"/v1/model/{MODEL_THAT_DOES_NOT_EXIST}")
+    assert response.status_code == 404
Add a comment
List