

feat: add /v1/models and /v1/model routes #14
Now users will also have to specify the full model name. Before: `tiny.en`. Now: `Systran/faster-whisper-tiny.en`.
@b5858eea50cb6602a0afcf523d3b374d7256a153
--- faster_whisper_server/main.py
+++ faster_whisper_server/main.py
... | ... | @@ -6,9 +6,11 @@ |
6 | 6 |
from io import BytesIO |
7 | 7 |
from typing import Annotated, Literal, OrderedDict |
8 | 8 |
|
9 |
+import huggingface_hub |
|
9 | 10 |
from fastapi import ( |
10 | 11 |
FastAPI, |
11 | 12 |
Form, |
13 |
+ HTTPException, |
|
12 | 14 |
Query, |
13 | 15 |
Response, |
14 | 16 |
UploadFile, |
... | ... | @@ -19,6 +21,7 @@ |
19 | 21 |
from fastapi.websockets import WebSocketState |
20 | 22 |
from faster_whisper import WhisperModel |
21 | 23 |
from faster_whisper.vad import VadOptions, get_speech_timestamps |
24 |
+from huggingface_hub.hf_api import ModelInfo |
|
22 | 25 |
|
23 | 26 |
from faster_whisper_server import utils |
24 | 27 |
from faster_whisper_server.asr import FasterWhisperASR |
... | ... | @@ -31,24 +34,25 @@ |
31 | 34 |
) |
32 | 35 |
from faster_whisper_server.logger import logger |
33 | 36 |
from faster_whisper_server.server_models import ( |
37 |
+ ModelObject, |
|
34 | 38 |
TranscriptionJsonResponse, |
35 | 39 |
TranscriptionVerboseJsonResponse, |
36 | 40 |
) |
37 | 41 |
from faster_whisper_server.transcriber import audio_transcriber |
38 | 42 |
|
39 |
-models: OrderedDict[str, WhisperModel] = OrderedDict() |
|
43 |
+loaded_models: OrderedDict[str, WhisperModel] = OrderedDict() |
|
40 | 44 |
|
41 | 45 |
|
42 | 46 |
def load_model(model_name: str) -> WhisperModel: |
43 |
- if model_name in models: |
|
47 |
+ if model_name in loaded_models: |
|
44 | 48 |
logger.debug(f"{model_name} model already loaded") |
45 |
- return models[model_name] |
|
46 |
- if len(models) >= config.max_models: |
|
47 |
- oldest_model_name = next(iter(models)) |
|
49 |
+ return loaded_models[model_name] |
|
50 |
+ if len(loaded_models) >= config.max_models: |
|
51 |
+ oldest_model_name = next(iter(loaded_models)) |
|
48 | 52 |
logger.info( |
49 | 53 |
f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}" |
50 | 54 |
) |
51 |
- del models[oldest_model_name] |
|
55 |
+ del loaded_models[oldest_model_name] |
|
52 | 56 |
logger.debug(f"Loading {model_name}...") |
53 | 57 |
start = time.perf_counter() |
54 | 58 |
# NOTE: will raise an exception if the model name isn't valid |
... | ... | @@ -60,7 +64,7 @@ |
60 | 64 |
logger.info( |
61 | 65 |
f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference." |
62 | 66 |
) |
63 |
- models[model_name] = whisper |
|
67 |
+ loaded_models[model_name] = whisper |
|
64 | 68 |
return whisper |
65 | 69 |
|
66 | 70 |
|
... | ... | @@ -68,9 +72,9 @@ |
68 | 72 |
async def lifespan(_: FastAPI): |
69 | 73 |
load_model(config.whisper.model) |
70 | 74 |
yield |
71 |
- for model in models.keys(): |
|
75 |
+ for model in loaded_models.keys(): |
|
72 | 76 |
logger.info(f"Unloading {model}") |
73 |
- del models[model] |
|
77 |
+ del loaded_models[model] |
|
74 | 78 |
|
75 | 79 |
|
76 | 80 |
app = FastAPI(lifespan=lifespan) |
... | ... | @@ -81,6 +85,48 @@ |
81 | 85 |
return Response(status_code=200, content="OK") |
82 | 86 |
|
83 | 87 |
|
88 |
+@app.get("/v1/models", response_model=list[ModelObject]) |
|
89 |
+def get_models() -> list[ModelObject]: |
|
90 |
+ models = huggingface_hub.list_models(library="ctranslate2") |
|
91 |
+ models = [ |
|
92 |
+ ModelObject( |
|
93 |
+ id=model.id, |
|
94 |
+ created=int(model.created_at.timestamp()), |
|
95 |
+ object_="model", |
|
96 |
+ owned_by=model.id.split("/")[0], |
|
97 |
+ ) |
|
98 |
+ for model in models |
|
99 |
+ if model.created_at is not None |
|
100 |
+ ] |
|
101 |
+ return models |
|
102 |
+ |
|
103 |
+ |
|
104 |
+@app.get("/v1/models/{model_name:path}", response_model=ModelObject) |
|
105 |
+def get_model(model_name: str) -> ModelObject: |
|
106 |
+ models = list( |
|
107 |
+ huggingface_hub.list_models(model_name=model_name, library="ctranslate2") |
|
108 |
+ ) |
|
109 |
+ if len(models) == 0: |
|
110 |
+ raise HTTPException(status_code=404, detail="Model doesn't exists") |
|
111 |
+ exact_match: ModelInfo | None = None |
|
112 |
+ for model in models: |
|
113 |
+ if model.id == model_name: |
|
114 |
+ exact_match = model |
|
115 |
+ break |
|
116 |
+ if exact_match is None: |
|
117 |
+ raise HTTPException( |
|
118 |
+ status_code=404, |
|
119 |
+ detail=f"Model doesn't exists. Possible matches: {", ".join([model.id for model in models])}", |
|
120 |
+ ) |
|
121 |
+ assert exact_match.created_at is not None |
|
122 |
+ return ModelObject( |
|
123 |
+ id=exact_match.id, |
|
124 |
+ created=int(exact_match.created_at.timestamp()), |
|
125 |
+ object_="model", |
|
126 |
+ owned_by=exact_match.id.split("/")[0], |
|
127 |
+ ) |
|
128 |
+ |
|
129 |
+ |
|
84 | 130 |
@app.post("/v1/audio/translations") |
85 | 131 |
def translate_file( |
86 | 132 |
file: Annotated[UploadFile, Form()], |
--- faster_whisper_server/server_models.py
+++ faster_whisper_server/server_models.py
... | ... | @@ -1,7 +1,9 @@ |
1 | 1 |
from __future__ import annotations |
2 | 2 |
|
3 |
+from typing import Literal |
|
4 |
+ |
|
3 | 5 |
from faster_whisper.transcribe import Segment, TranscriptionInfo, Word |
4 |
-from pydantic import BaseModel |
|
6 |
+from pydantic import BaseModel, ConfigDict, Field |
|
5 | 7 |
|
6 | 8 |
from faster_whisper_server import utils |
7 | 9 |
from faster_whisper_server.core import Transcription |
... | ... | @@ -125,3 +127,16 @@ |
125 | 127 |
], |
126 | 128 |
segments=[], # FIX: hardcoded |
127 | 129 |
) |
130 |
+ |
|
131 |
+ |
|
132 |
+class ModelObject(BaseModel): |
|
133 |
+ model_config = ConfigDict(populate_by_name=True) |
|
134 |
+ |
|
135 |
+ id: str |
|
136 |
+ """The model identifier, which can be referenced in the API endpoints.""" |
|
137 |
+ created: int |
|
138 |
+ """The Unix timestamp (in seconds) when the model was created.""" |
|
139 |
+ object_: Literal["model"] = Field(serialization_alias="object") |
|
140 |
+ """The object type, which is always "model".""" |
|
141 |
+ owned_by: str |
|
142 |
+ """The organization that owns the model.""" |
+++ tests/__init__.py
... | ... | @@ -0,0 +1,0 @@ |
+++ tests/api_model_test.py
... | ... | @@ -0,0 +1,48 @@ |
1 | +from typing import Generator | |
2 | + | |
3 | +import pytest | |
4 | +from fastapi.testclient import TestClient | |
5 | + | |
6 | +from faster_whisper_server.main import app | |
7 | +from faster_whisper_server.server_models import ModelObject | |
8 | + | |
9 | +MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en" | |
10 | +MODEL_THAT_DOES_NOT_EXIST = "i-do-not-exist" | |
11 | +MIN_EXPECTED_NUMBER_OF_MODELS = ( | |
12 | + 200 # At the time of the test creation there are 228 models | |
13 | +) | |
14 | + | |
15 | + | |
16 | +@pytest.fixture() | |
17 | +def client() -> Generator[TestClient, None, None]: | |
18 | + with TestClient(app) as client: | |
19 | + yield client | |
20 | + | |
21 | + | |
22 | +# HACK: because ModelObject(**data) doesn't work | |
23 | +def model_dict_to_object(model_dict: dict) -> ModelObject: | |
24 | + return ModelObject( | |
25 | + id=model_dict["id"], | |
26 | + created=model_dict["created"], | |
27 | + object_=model_dict["object"], | |
28 | + owned_by=model_dict["owned_by"], | |
29 | + ) | |
30 | + | |
31 | + | |
32 | +def test_list_models(client: TestClient): | |
33 | + response = client.get("/v1/models") | |
34 | + data = response.json() | |
35 | + models = [model_dict_to_object(model_dict) for model_dict in data] | |
36 | + assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS | |
37 | + | |
38 | + | |
39 | +def test_model_exists(client: TestClient): | |
40 | + response = client.get(f"/v1/model/{MODEL_THAT_EXISTS}") | |
41 | + data = response.json() | |
42 | + model = model_dict_to_object(data) | |
43 | + assert model.id == MODEL_THAT_EXISTS | |
44 | + | |
45 | + | |
46 | +def test_model_does_not_exist(client: TestClient): | |
47 | + response = client.get(f"/v1/model/{MODEL_THAT_DOES_NOT_EXIST}") | |
48 | + assert response.status_code == 404 |
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?