

feat: model unloading
@cb7375c9ca40bf121ee40e294acadf25466c7d1f
--- pyproject.toml
+++ pyproject.toml
@@ -19,10 +19,10 @@
 client = [
     "keyboard>=0.13.5",
 ]
-# NOTE: when installing `dev` group, all other groups should also be installed
 dev = [
     "anyio>=4.4.0",
     "basedpyright>=1.18.0",
+    "pytest-antilru>=2.0.0",
     "pytest-asyncio>=0.24.0",
     "pytest-xdist>=3.6.1",
     "pytest>=8.3.3",
--- src/faster_whisper_server/config.py
+++ src/faster_whisper_server/config.py
@@ -1,7 +1,6 @@
 import enum
-from typing import Self
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 SAMPLES_PER_SECOND = 16000
@@ -163,6 +162,12 @@
     compute_type: Quantization = Field(default=Quantization.DEFAULT)
     cpu_threads: int = 0
     num_workers: int = 1
+    ttl: int = Field(default=300, ge=-1)
+    """
+    Time in seconds until the model is unloaded if it is not being used.
+    -1: Never unload the model.
+    0: Unload the model immediately after usage.
+    """
 
 
 class Config(BaseSettings):
@@ -198,10 +203,6 @@
     """
     default_response_format: ResponseFormat = ResponseFormat.JSON
     whisper: WhisperConfig = WhisperConfig()
-    max_models: int = 1
-    """
-    Maximum number of models that can be loaded at a time.
-    """
     preload_models: list[str] = Field(
         default_factory=list,
         examples=[
@@ -210,8 +211,8 @@
         ],
     )
     """
-    List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on first request.
-    """  # noqa: E501
+    List of models to preload on startup. By default, the model is first loaded on first request.
+    """
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finilized and connection is closed.
@@ -230,11 +231,3 @@
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-
-    @model_validator(mode="after")
-    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
-        if len(self.preload_models) > self.max_models:
-            raise ValueError(
-                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
-            )
-        return self
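Since `Config` is a pydantic `BaseSettings` model, the new `ttl` knob can be set from the environment. A minimal sketch, assuming the `__` nesting delimiter implied by the tests below (which set `WHISPER__TTL`):

```python
# Sketch: configuring the idle TTL via environment variables.
# The WHISPER__TTL variable matches what the new tests set; the exact
# nesting delimiter is an assumption based on pydantic-settings defaults.
import os

os.environ["WHISPER__TTL"] = "300"  # unload after 300 idle seconds

from faster_whisper_server.config import Config

config = Config()
assert config.whisper.ttl == 300
# Special values: ttl=-1 never unloads; ttl=0 unloads right after each use.
```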
--- src/faster_whisper_server/dependencies.py
+++ src/faster_whisper_server/dependencies.py
@@ -18,7 +18,7 @@
 @lru_cache
 def get_model_manager() -> ModelManager:
     config = get_config()  # HACK
-    return ModelManager(config)
+    return ModelManager(config.whisper)
 
 
 ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
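`get_model_manager` stays `lru_cache`d, so the whole app shares one `ModelManager`. That is also why `pytest-antilru` enters the dev dependencies: each test builds a fresh app with different env vars, so the cached singletons have to be evicted between tests. A rough sketch of what the plugin automates, done by hand (the fixture name is hypothetical):

```python
# Hand-rolled equivalent of pytest-antilru: clear the lru_cache-backed
# singleton between tests so per-test env-var overrides take effect.
from collections.abc import Iterator

import pytest

from faster_whisper_server.dependencies import get_model_manager


@pytest.fixture(autouse=True)
def _reset_singletons() -> Iterator[None]:
    yield
    get_model_manager.cache_clear()  # functools.lru_cache exposes cache_clear()
```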
--- src/faster_whisper_server/model_manager.py
+++ src/faster_whisper_server/model_manager.py
@@ -3,48 +3,132 @@
 from collections import OrderedDict
 import gc
 import logging
+import threading
 import time
 from typing import TYPE_CHECKING
 
 from faster_whisper import WhisperModel
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from faster_whisper_server.config import (
-        Config,
+        WhisperConfig,
     )
 
 logger = logging.getLogger(__name__)
 
+# TODO: enable concurrent model downloads
+
+
+class SelfDisposingWhisperModel:
+    def __init__(
+        self,
+        model_id: str,
+        whisper_config: WhisperConfig,
+        *,
+        on_unload: Callable[[str], None] | None = None,
+    ) -> None:
+        self.model_id = model_id
+        self.whisper_config = whisper_config
+        self.on_unload = on_unload
+
+        self.ref_count: int = 0
+        self.rlock = threading.RLock()
+        self.expire_timer: threading.Timer | None = None
+        self.whisper: WhisperModel | None = None
+
+    def unload(self) -> None:
+        with self.rlock:
+            if self.whisper is None:
+                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
+            if self.ref_count > 0:
+                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
+            if self.expire_timer:
+                self.expire_timer.cancel()
+            self.whisper = None
+            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
+            gc.collect()
+            logger.info(f"Model {self.model_id} unloaded")
+            if self.on_unload is not None:
+                self.on_unload(self.model_id)
+
+    def _load(self) -> None:
+        with self.rlock:
+            assert self.whisper is None
+            logger.debug(f"Loading model {self.model_id}")
+            start = time.perf_counter()
+            self.whisper = WhisperModel(
+                self.model_id,
+                device=self.whisper_config.inference_device,
+                device_index=self.whisper_config.device_index,
+                compute_type=self.whisper_config.compute_type,
+                cpu_threads=self.whisper_config.cpu_threads,
+                num_workers=self.whisper_config.num_workers,
+            )
+            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
+
+    def _increment_ref(self) -> None:
+        with self.rlock:
+            self.ref_count += 1
+            if self.expire_timer:
+                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
+                self.expire_timer.cancel()
+            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")
+
+    def _decrement_ref(self) -> None:
+        with self.rlock:
+            self.ref_count -= 1
+            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
+            if self.ref_count <= 0:
+                if self.whisper_config.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
+                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                    self.expire_timer.start()
+                elif self.whisper_config.ttl == 0:
+                    logger.info(f"Model {self.model_id} is idle, unloading immediately")
+                    self.unload()
+                else:
+                    logger.info(f"Model {self.model_id} is idle, not unloading")
+
+    def __enter__(self) -> WhisperModel:
+        with self.rlock:
+            if self.whisper is None:
+                self._load()
+            self._increment_ref()
+        assert self.whisper is not None
+        return self.whisper
+
+    def __exit__(self, *_args) -> None:  # noqa: ANN002
+        self._decrement_ref()
+
 
 class ModelManager:
-    def __init__(self, config: Config) -> None:
-        self.config = config
-        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+    def __init__(self, whisper_config: WhisperConfig) -> None:
+        self.whisper_config = whisper_config
+        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self._lock = threading.Lock()
 
-    def load_model(self, model_name: str) -> WhisperModel:
-        if model_name in self.loaded_models:
-            logger.debug(f"{model_name} model already loaded")
-            return self.loaded_models[model_name]
-        if len(self.loaded_models) >= self.config.max_models:
-            oldest_model_name = next(iter(self.loaded_models))
-            logger.info(
-                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            self.loaded_models[model_name].unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+        with self._lock:
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+                model_name,
+                self.whisper_config,
+                on_unload=self._handle_model_unload,
             )
-            del self.loaded_models[oldest_model_name]
-            gc.collect()
-        logger.debug(f"Loading {model_name}...")
-        start = time.perf_counter()
-        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-        whisper = WhisperModel(
-            model_name,
-            device=self.config.whisper.inference_device,
-            device_index=self.config.whisper.device_index,
-            compute_type=self.config.whisper.compute_type,
-            cpu_threads=self.config.whisper.cpu_threads,
-            num_workers=self.config.whisper.num_workers,
-        )
-        logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
-        )
-        self.loaded_models[model_name] = whisper
-        return whisper
+            return self.loaded_models[model_name]
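For reference, a usage sketch of the new API, mirroring what the routers below do. It assumes `WhisperConfig`'s remaining fields all have defaults. Entering the context loads the model (or cancels a pending expiry timer) and bumps the ref count; exiting decrements it and arms the TTL timer once the count hits zero:

```python
# Hedged usage sketch of ModelManager / SelfDisposingWhisperModel.
# Assumes WhisperConfig can be constructed with defaults aside from ttl.
from faster_whisper_server.config import WhisperConfig
from faster_whisper_server.model_manager import ModelManager

manager = ModelManager(WhisperConfig(ttl=60))

with manager.load_model("Systran/faster-whisper-tiny.en") as whisper:
    # Inside the block ref_count > 0, so unload_model() raises ValueError
    # instead of pulling the model out from under the request.
    segments, info = whisper.transcribe("audio.wav")

# On exit ref_count drops to 0 and a threading.Timer is armed to call
# unload() in 60s, unless another request re-enters the context first.
```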
--- src/faster_whisper_server/routers/misc.py
+++ src/faster_whisper_server/routers/misc.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-import gc
-
 from fastapi import (
     APIRouter,
     Response,
@@ -42,15 +40,19 @@
 def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    model_manager.load_model(model_name)
+    with model_manager.load_model(model_name):
+        pass
     return Response(status_code=201)
 
 
 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
-    model = model_manager.loaded_models.get(model_name)
-    if model is not None:
-        del model_manager.loaded_models[model_name]
-        gc.collect()
+    try:
+        model_manager.unload_model(model_name)
         return Response(status_code=204)
-    return Response(status_code=404)
+    except (KeyError, ValueError) as e:
+        match e:
+            case KeyError():
+                return Response(status_code=404, content="Model not found")
+            case ValueError():
+                return Response(status_code=409, content=str(e))
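The resulting `/api/ps` lifecycle, sketched with `httpx` the way the tests exercise it. Status codes follow the handlers above; the base URL is an assumption:

```python
# Sketch of the /api/ps load/unload lifecycle against a running server.
# The base URL is assumed; status codes match the router handlers above.
import httpx

MODEL = "Systran/faster-whisper-tiny.en"

with httpx.Client(base_url="http://localhost:8000") as client:
    assert client.post(f"/api/ps/{MODEL}").status_code == 201    # loaded
    assert client.post(f"/api/ps/{MODEL}").status_code == 409    # already loaded
    assert client.delete(f"/api/ps/{MODEL}").status_code == 204  # unloaded
    assert client.delete(f"/api/ps/{MODEL}").status_code == 404  # no longer present
```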
--- src/faster_whisper_server/routers/stt.py
+++ src/faster_whisper_server/routers/stt.py
@@ -142,20 +142,20 @@
         model = config.whisper.model
     if response_format is None:
         response_format = config.default_response_format
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSLATE,
-        initial_prompt=prompt,
-        temperature=temperature,
-        vad_filter=vad_filter,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSLATE,
+            initial_prompt=prompt,
+            temperature=temperature,
+            vad_filter=vad_filter,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
 
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 # HACK: Since Form() doesn't support `alias`, we need to use a workaround.
@@ -206,23 +206,23 @@
         logger.warning(
             "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
         )
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSCRIBE,
-        language=language,
-        initial_prompt=prompt,
-        word_timestamps="word" in timestamp_granularities,
-        temperature=temperature,
-        vad_filter=vad_filter,
-        hotwords=hotwords,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSCRIBE,
+            language=language,
+            initial_prompt=prompt,
+            word_timestamps="word" in timestamp_granularities,
+            temperature=temperature,
+            vad_filter=vad_filter,
+            hotwords=hotwords,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
 
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -280,24 +280,24 @@
         "vad_filter": vad_filter,
         "condition_on_previous_text": False,
     }
-    whisper = model_manager.load_model(model)
-    asr = FasterWhisperASR(whisper, **transcribe_opts)
-    audio_stream = AudioStream()
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
-            logger.debug(f"Sending transcription: {transcription.text}")
-            if ws.client_state == WebSocketState.DISCONNECTED:
-                break
+    with model_manager.load_model(model) as whisper:
+        asr = FasterWhisperASR(whisper, **transcribe_opts)
+        audio_stream = AudioStream()
+        async with asyncio.TaskGroup() as tg:
+            tg.create_task(audio_receiver(ws, audio_stream))
+            async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
+                logger.debug(f"Sending transcription: {transcription.text}")
+                if ws.client_state == WebSocketState.DISCONNECTED:
+                    break
 
-            if response_format == ResponseFormat.TEXT:
-                await ws.send_text(transcription.text)
-            elif response_format == ResponseFormat.JSON:
-                await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(
-                    CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
-                )
+                if response_format == ResponseFormat.TEXT:
+                    await ws.send_text(transcription.text)
+                elif response_format == ResponseFormat.JSON:
+                    await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    await ws.send_json(
+                        CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
+                    )
 
     if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
--- /dev/null
+++ tests/model_manager_test.py
@@ -0,0 +1,122 @@
+import asyncio
+import os
+
+import anyio
+from httpx import ASGITransport, AsyncClient
+import pytest
+
+from faster_whisper_server.main import create_app
+
+
+@pytest.mark.asyncio
+async def test_model_unloaded_after_ttl() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl + 1)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_ttl_resets_after_usage() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        await asyncio.sleep(3)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+        # test the model can be used again after being unloaded
+        # this just ensures the model can be loaded again after being unloaded
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_unloaded_when_used() -> None:
+    ttl = 0
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+
+        task = asyncio.create_task(
+            aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        )
+        await asyncio.sleep(0.01)
+        res = await aclient.delete(f"/api/ps/{model}")
+        assert res.status_code == 409
+
+        await task
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_loaded_twice() -> None:
+    ttl = -1
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["ENABLE_UI"] = "false"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 201
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 409
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+
+@pytest.mark.asyncio
+async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
+    ttl = 0
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = await aclient.post(
+            "/v1/audio/transcriptions",
+            files={"file": ("audio.wav", data, "audio/wav")},
+            data={"model": "Systran/faster-whisper-tiny.en"},
+        )
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
--- uv.lock
+++ uv.lock
@@ -293,6 +293,7 @@
     { name = "anyio" },
     { name = "basedpyright" },
     { name = "pytest" },
+    { name = "pytest-antilru" },
     { name = "pytest-asyncio" },
    { name = "pytest-xdist" },
     { name = "ruff" },
@@ -322,6 +323,7 @@
     { name = "pydantic", specifier = ">=2.9.0" },
     { name = "pydantic-settings", specifier = ">=2.5.2" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" },
+    { name = "pytest-antilru", marker = "extra == 'dev'", specifier = ">=2.0.0" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
     { name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.6.1" },
     { name = "python-multipart", specifier = ">=0.0.10" },
@@ -3483,6 +3485,18 @@
 ]
 
 [[package]]
+name = "pytest-antilru"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c6/01/0b5ef3f143f335b5cb1c1e8e6497769dfb48aed5a791b5dfd119151e2b15/pytest_antilru-2.0.0.tar.gz", hash = "sha256:48cff342648b6a1ce4e5398cf203966905d546b3f2bee7bb55d7cb3ec87a85fb", size = 5569 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/23/f0/fc9f5aaaf2818a7d7f795e99fcf59719dd6ec5f98005e642e1efd63ad2a4/pytest_antilru-2.0.0-py3-none-any.whl", hash = "sha256:cf1d97db0e7b17ef568c1f0bf4c89b8748053fe07546f4eb2558bebf64c1ad33", size = 6301 },
+]
+
+[[package]]
 name = "pytest-asyncio"
 version = "0.24.0"
 source = { registry = "https://pypi.org/simple" }