Fedir Zadniprovskyi 2024-12-17
feat: add chat, speech, and transcription dependencies
@dc7bb9d51ff1e10178144b51762f168491886dec
src/faster_whisper_server/config.py
--- src/faster_whisper_server/config.py
+++ src/faster_whisper_server/config.py
@@ -236,3 +236,14 @@
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
+
+    chat_completion_base_url: str = "https://api.openai.com/v1"
+    chat_completion_api_key: str | None = None
+
+    speech_base_url: str | None = None
+    speech_api_key: str | None = None
+    speech_model: str = "piper"
+    speech_extra_body: dict = {"sample_rate": 24000}
+
+    transcription_base_url: str | None = None
+    transcription_api_key: str | None = None
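These settings mirror the OpenAI client conventions; for speech and transcription, a base URL of `None` means the request is served by the in-process routers (see `dependencies.py` below). A minimal sketch of supplying them, assuming `Config` is a pydantic-settings `BaseSettings` that reads uppercase environment variables (the Ollama-style URL is illustrative, not part of this commit):

    import os

    # Route chat completions to an external OpenAI-compatible server, while
    # leaving speech/transcription unset so they fall back to the local routers.
    os.environ["CHAT_COMPLETION_BASE_URL"] = "http://localhost:11434/v1"  # hypothetical endpoint
    os.environ["CHAT_COMPLETION_API_KEY"] = "can-be-anything-for-local-servers"

    from faster_whisper_server.config import Config

    config = Config()
    assert config.chat_completion_base_url == "http://localhost:11434/v1"
    assert config.speech_base_url is None  # handled by the in-process speech router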
src/faster_whisper_server/dependencies.py
--- src/faster_whisper_server/dependencies.py
+++ src/faster_whisper_server/dependencies.py
@@ -3,6 +3,10 @@
 
 from fastapi import Depends, HTTPException, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from httpx import ASGITransport, AsyncClient
+from openai import AsyncOpenAI
+from openai.resources.audio import AsyncSpeech, AsyncTranscriptions
+from openai.resources.chat.completions import AsyncCompletions
 
 from faster_whisper_server.config import Config
 from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
@@ -45,3 +49,56 @@
 
 
 ApiKeyDependency = Depends(verify_api_key)
+
+
+@lru_cache
+def get_completion_client() -> AsyncCompletions:
+    config = get_config()  # HACK: fetched directly rather than injected, since this factory is cached with `lru_cache`
+    oai_client = AsyncOpenAI(base_url=config.chat_completion_base_url, api_key=config.chat_completion_api_key)
+    return oai_client.chat.completions
+
+
+CompletionClientDependency = Annotated[AsyncCompletions, Depends(get_completion_client)]
+
+
+@lru_cache
+def get_speech_client() -> AsyncSpeech:
+    config = get_config()  # HACK: fetched directly rather than injected, since this factory is cached with `lru_cache`
+    if config.speech_base_url is None:
+        # this might not work as expected if the `speech_router` doesn't share state with the main FastAPI `app`. TODO: verify  # noqa: E501
+        from faster_whisper_server.routers.speech import (
+            router as speech_router,
+        )
+
+        http_client = AsyncClient(
+            transport=ASGITransport(speech_router), base_url="http://test/v1"
+        )  # NOTE: the "test" hostname is arbitrary; ASGITransport dispatches requests in-process, so it is never resolved
+        oai_client = AsyncOpenAI(http_client=http_client, api_key=config.speech_api_key)
+    else:
+        oai_client = AsyncOpenAI(base_url=config.speech_base_url, api_key=config.speech_api_key)
+    return oai_client.audio.speech
+
+
+SpeechClientDependency = Annotated[AsyncSpeech, Depends(get_speech_client)]
+
+
+@lru_cache
+def get_transcription_client() -> AsyncTranscriptions:
+    config = get_config()  # HACK: fetched directly rather than injected, since this factory is cached with `lru_cache`
+    if config.transcription_base_url is None:
+        # this might not work as expected if the `transcription_router` doesn't share state with the main FastAPI `app`. TODO: verify  # noqa: E501
+        from faster_whisper_server.routers.stt import (
+            router as stt_router,
+        )
+
+        http_client = AsyncClient(
+            transport=ASGITransport(stt_router), base_url="http://test/v1"
+        )  # NOTE: the "test" hostname is arbitrary; ASGITransport dispatches requests in-process, so it is never resolved
+
+        oai_client = AsyncOpenAI(http_client=http_client, api_key=config.transcription_api_key)
+    else:
+        oai_client = AsyncOpenAI(base_url=config.transcription_base_url, api_key=config.transcription_api_key)
+    return oai_client.audio.transcriptions
+
+
+TranscriptionClientDependency = Annotated[AsyncTranscriptions, Depends(get_transcription_client)]
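When `speech_base_url` / `transcription_base_url` are unset, the factories above loop the OpenAI SDK back into the server itself: httpx's `ASGITransport` hands each request directly to the router coroutine, so no socket is opened. A hypothetical consumer route sketching how the new dependencies could be used together (the path, model, voice, and media type are illustrative, not part of this commit):

    from fastapi import FastAPI, Response

    from faster_whisper_server.dependencies import (
        CompletionClientDependency,
        SpeechClientDependency,
    )

    app = FastAPI()


    @app.post("/v1/say")
    async def say(
        text: str,
        completion_client: CompletionClientDependency,
        speech_client: SpeechClientDependency,
    ) -> Response:
        # Ask the configured chat backend for a reply...
        chat = await completion_client.create(
            model="gpt-4o-mini",  # hypothetical model name
            messages=[{"role": "user", "content": text}],
        )
        reply = chat.choices[0].message.content or ""
        # ...then synthesize it with whichever speech backend the config selected.
        speech = await speech_client.create(
            model="piper",
            voice="en_US-amy-medium",  # hypothetical Piper voice
            input=reply,
        )
        return Response(content=speech.content, media_type="audio/wav")  # assumed content type

One caveat: because the factories are `lru_cache`d, each client is constructed once per process, so config changes after startup won't be picked up.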