Fedir Zadniprovskyi 2024-09-08
feat: api route to download model
@9ccd70484d6d0840d8a0d6dc417189aafb82260d
src/faster_whisper_server/main.py
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -25,8 +25,10 @@
 from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 import huggingface_hub
+from huggingface_hub.hf_api import RepositoryNotFoundError
 from pydantic import AfterValidator
 
+from faster_whisper_server import hf_utils
 from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
@@ -108,6 +110,17 @@
     return Response(status_code=200, content="OK")
 
 
+@app.post("/api/pull/{model_name:path}", tags=["experimental"], summary="Download a model from Hugging Face.")
+def pull_model(model_name: str) -> Response:
+    if hf_utils.does_local_model_exist(model_name):
+        return Response(status_code=200, content="Model already exists")
+    try:
+        huggingface_hub.snapshot_download(model_name, repo_type="model")
+    except RepositoryNotFoundError as e:
+        return Response(status_code=404, content=str(e))
+    return Response(status_code=201, content="Model downloaded")
+
+
 @app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
 def get_running_models() -> dict[str, list[str]]:
     return {"models": list(loaded_models.keys())}
Add a comment
List