Fedir Zadniprovskyi 2024-10-23
feat: tts
@1d2b5185ac43a404809126462384c83d9222b882
src/faster_whisper_server/dependencies.py
--- src/faster_whisper_server/dependencies.py
+++ src/faster_whisper_server/dependencies.py
@@ -4,7 +4,7 @@
 from fastapi import Depends
 
 from faster_whisper_server.config import Config
-from faster_whisper_server.model_manager import ModelManager
+from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
 
 
 @lru_cache
@@ -16,9 +16,18 @@
 
 
 @lru_cache
-def get_model_manager() -> ModelManager:
+def get_model_manager() -> WhisperModelManager:
     config = get_config()  # HACK
-    return ModelManager(config.whisper)
+    return WhisperModelManager(config.whisper)
 
 
-ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
+ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manager)]
+
+
+@lru_cache
+def get_piper_model_manager() -> PiperModelManager:
+    config = get_config()  # HACK
+    return PiperModelManager(config.whisper.ttl)  # HACK
+
+
+PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
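
A minimal usage sketch for the new dependency alias (the endpoint shown here is hypothetical, for illustration only; the real consumer is the speech router added below):

from fastapi import APIRouter

from faster_whisper_server.dependencies import PiperModelManagerDependency

router = APIRouter()


@router.get("/v1/audio/speech/voices")  # hypothetical route, not part of this commit
def list_loaded_piper_models(piper_model_manager: PiperModelManagerDependency) -> list[str]:
    # FastAPI resolves the Annotated[..., Depends(...)] alias; since
    # get_piper_model_manager is wrapped in @lru_cache, every request
    # shares a single PiperModelManager instance.
    return list(piper_model_manager.loaded_models.keys())
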
src/faster_whisper_server/hf_utils.py
--- src/faster_whisper_server/hf_utils.py
+++ src/faster_whisper_server/hf_utils.py
@@ -1,9 +1,16 @@
 from collections.abc import Generator
+from functools import lru_cache
+import json
 import logging
 from pathlib import Path
 import typing
+from typing import Any, Literal
 
 import huggingface_hub
+from huggingface_hub.constants import HF_HUB_CACHE
+from pydantic import BaseModel
+
+from faster_whisper_server.api_models import Model
 
 logger = logging.getLogger(__name__)
 
@@ -12,10 +19,36 @@
 
 
 def does_local_model_exist(model_id: str) -> bool:
-    return any(model_id == model.repo_id for model, _ in list_local_models())
+    return any(model_id == model.repo_id for model, _ in list_local_whisper_models())
 
 
-def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]:
+def list_whisper_models() -> Generator[Model, None, None]:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads or -1, reverse=True)
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = Model(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        yield transformed_model
+
+
+def list_local_whisper_models() -> (
+    Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]
+):
     hf_cache = huggingface_hub.scan_cache_dir()
     hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"]
     for model in hf_models:
@@ -36,3 +69,105 @@
             and TASK_NAME in model_card_data.tags
         ):
             yield model, model_card_data
+
+
+class PiperModel(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int
+    owned_by: Literal["rhasspy"] = "rhasspy"
+    path: Path
+    config_path: Path
+
+
+def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None:
+    if cache_dir is None:
+        cache_dir = HF_HUB_CACHE
+
+    cache_dir = Path(cache_dir).expanduser().resolve()
+    if not cache_dir.exists():
+        raise huggingface_hub.CacheNotFound(
+            f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.",  # noqa: E501
+            cache_dir=cache_dir,
+        )
+
+    if cache_dir.is_file():
+        raise ValueError(
+            f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable."  # noqa: E501
+        )
+
+    for repo_path in cache_dir.iterdir():
+        if not repo_path.is_dir():
+            continue
+        if repo_path.name == ".locks":  # skip './.locks/' folder
+            continue
+        repo_type, repo_id = repo_path.name.split("--", maxsplit=1)
+        repo_type = repo_type[:-1]  # "models" -> "model"
+        repo_id = repo_id.replace("--", "/")  # google--fleurs -> "google/fleurs"
+        if repo_type != "model":
+            continue
+        if model_id == repo_id:
+            return repo_path
+
+    return None
+
+
+def list_model_files(
+    model_id: str, glob_pattern: str = "**/*", *, cache_dir: str | Path | None = None
+) -> Generator[Path, None, None]:
+    repo_path = get_model_path(model_id, cache_dir=cache_dir)
+    if repo_path is None:
+        return None
+    snapshots_path = repo_path / "snapshots"
+    if not snapshots_path.exists():
+        return None
+    yield from list(snapshots_path.glob(glob_pattern))
+
+
+def list_piper_models() -> Generator[PiperModel, None, None]:
+    model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx")
+    for model_weights_file in model_weights_files:
+        model_config_file = model_weights_file.with_suffix(".json")
+        yield PiperModel(
+            id=model_weights_file.name,
+            created=int(model_weights_file.stat().st_mtime),
+            path=model_weights_file,
+            config_path=model_config_file,
+        )
+
+
+# NOTE: It's debatable whether caching should be done here or by the caller. Should be revisited.
+
+
+@lru_cache
+def read_piper_voices_config() -> dict[str, Any]:
+    voices_file = next(list_model_files("rhasspy/piper-voices", glob_pattern="**/voices.json"), None)
+    if voices_file is None:
+        raise FileNotFoundError("Could not find voices.json file")  # noqa: EM101
+    return json.loads(voices_file.read_text())
+
+
+@lru_cache
+def get_piper_voice_model_file(voice: str) -> Path:
+    model_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx"), None)
+    if model_file is None:
+        raise FileNotFoundError(f"Could not find model file for '{voice}' voice")
+    return model_file
+
+
+class PiperVoiceConfigAudio(BaseModel):
+    sample_rate: int
+    quality: int
+
+
+class PiperVoiceConfig(BaseModel):
+    audio: PiperVoiceConfigAudio
+    # NOTE: there are more fields in the config, but we don't care about them
+
+
+@lru_cache
+def read_piper_voice_config(voice: str) -> PiperVoiceConfig:
+    model_config_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx.json"), None)
+    if model_config_file is None:
+        raise FileNotFoundError(f"Could not find config file for '{voice}' voice")
+    return PiperVoiceConfig.model_validate_json(model_config_file.read_text())
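
These helpers walk the Hugging Face cache layout (`models--{org}--{name}/snapshots/...`) directly rather than going through `scan_cache_dir()`. A usage sketch, assuming `rhasspy/piper-voices` has already been downloaded (the test fixture below does exactly that):

from huggingface_hub import snapshot_download

from faster_whisper_server.hf_utils import list_piper_models, read_piper_voice_config

# Fetch voices.json plus a single voice so the cache-walking helpers find something.
snapshot_download("rhasspy/piper-voices", allow_patterns=["voices.json", "en/en_US/amy/**"])

for model in list_piper_models():
    print(model.id, model.path)  # e.g. "en_US-amy-medium.onnx" and its snapshot path

print(read_piper_voice_config("en_US-amy-medium").audio.sample_rate)  # 22050 for this voice
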
src/faster_whisper_server/main.py
--- src/faster_whisper_server/main.py
+++ src/faster_whisper_server/main.py
@@ -2,6 +2,7 @@
 
 from contextlib import asynccontextmanager
 import logging
+import platform
 from typing import TYPE_CHECKING
 
 from fastapi import (
@@ -30,6 +31,14 @@
 
     logger = logging.getLogger(__name__)
 
+    if platform.machine() == "x86_64":
+        from faster_whisper_server.routers.speech import (
+            router as speech_router,
+        )
+    else:
+        logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
+        speech_router = None
+
     config = get_config()  # HACK
     logger.debug(f"Config: {config}")
 
@@ -46,6 +55,8 @@
     app.include_router(stt_router)
     app.include_router(list_models_router)
     app.include_router(misc_router)
+    if speech_router is not None:
+        app.include_router(speech_router)
 
     if config.allow_origins is not None:
         app.add_middleware(
src/faster_whisper_server/model_manager.py
--- src/faster_whisper_server/model_manager.py
+++ src/faster_whisper_server/model_manager.py
@@ -9,8 +9,12 @@
 
 from faster_whisper import WhisperModel
 
+from faster_whisper_server.hf_utils import get_piper_voice_model_file
+
 if TYPE_CHECKING:
     from collections.abc import Callable
+
+    from piper.voice import PiperVoice
 
     from faster_whisper_server.config import (
         WhisperConfig,
@@ -18,54 +22,45 @@
 
 logger = logging.getLogger(__name__)
 
+
 # TODO: enable concurrent model downloads
 
 
-class SelfDisposingWhisperModel:
+class SelfDisposingModel[T]:
     def __init__(
-        self,
-        model_id: str,
-        whisper_config: WhisperConfig,
-        *,
-        on_unload: Callable[[str], None] | None = None,
+        self, model_id: str, load_fn: Callable[[], T], ttl: int, unload_fn: Callable[[str], None] | None = None
     ) -> None:
         self.model_id = model_id
-        self.whisper_config = whisper_config
-        self.on_unload = on_unload
+        self.load_fn = load_fn
+        self.ttl = ttl
+        self.unload_fn = unload_fn
 
         self.ref_count: int = 0
         self.rlock = threading.RLock()
         self.expire_timer: threading.Timer | None = None
-        self.whisper: WhisperModel | None = None
+        self.model: T | None = None
 
     def unload(self) -> None:
         with self.rlock:
-            if self.whisper is None:
+            if self.model is None:
                 raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
             if self.ref_count > 0:
                 raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
             if self.expire_timer:
                 self.expire_timer.cancel()
-            self.whisper = None
+            self.model = None
             # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
             gc.collect()
             logger.info(f"Model {self.model_id} unloaded")
-            if self.on_unload is not None:
-                self.on_unload(self.model_id)
+            if self.unload_fn is not None:
+                self.unload_fn(self.model_id)
 
     def _load(self) -> None:
         with self.rlock:
-            assert self.whisper is None
+            assert self.model is None
             logger.debug(f"Loading model {self.model_id}")
             start = time.perf_counter()
-            self.whisper = WhisperModel(
-                self.model_id,
-                device=self.whisper_config.inference_device,
-                device_index=self.whisper_config.device_index,
-                compute_type=self.whisper_config.compute_type,
-                cpu_threads=self.whisper_config.cpu_threads,
-                num_workers=self.whisper_config.num_workers,
-            )
+            self.model = self.load_fn()
             logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
 
     def _increment_ref(self) -> None:
@@ -81,33 +76,43 @@
             self.ref_count -= 1
             logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
             if self.ref_count <= 0:
-                if self.whisper_config.ttl > 0:
-                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
-                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                if self.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.ttl}s")
+                    self.expire_timer = threading.Timer(self.ttl, self.unload)
                     self.expire_timer.start()
-                elif self.whisper_config.ttl == 0:
+                elif self.ttl == 0:
                     logger.info(f"Model {self.model_id} is idle, unloading immediately")
                     self.unload()
                 else:
                     logger.info(f"Model {self.model_id} is idle, not unloading")
 
-    def __enter__(self) -> WhisperModel:
+    def __enter__(self) -> T:
         with self.rlock:
-            if self.whisper is None:
+            if self.model is None:
                 self._load()
             self._increment_ref()
-            assert self.whisper is not None
-            return self.whisper
+            assert self.model is not None
+            return self.model
 
     def __exit__(self, *_args) -> None:  # noqa: ANN002
         self._decrement_ref()
 
 
-class ModelManager:
+class WhisperModelManager:
     def __init__(self, whisper_config: WhisperConfig) -> None:
         self.whisper_config = whisper_config
-        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self.loaded_models: OrderedDict[str, SelfDisposingModel[WhisperModel]] = OrderedDict()
         self._lock = threading.Lock()
+
+    def _load_fn(self, model_id: str) -> WhisperModel:
+        return WhisperModel(
+            model_id,
+            device=self.whisper_config.inference_device,
+            device_index=self.whisper_config.device_index,
+            compute_type=self.whisper_config.compute_type,
+            cpu_threads=self.whisper_config.cpu_threads,
+            num_workers=self.whisper_config.num_workers,
+        )
 
     def _handle_model_unload(self, model_name: str) -> None:
         with self._lock:
@@ -121,14 +126,57 @@
                 raise KeyError(f"Model {model_name} not found")
             self.loaded_models[model_name].unload()
 
-    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+    def load_model(self, model_name: str) -> SelfDisposingModel[WhisperModel]:
+        logger.debug(f"Loading model {model_name}")
+        with self._lock:
+            logger.debug("Acquired lock")
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingModel[WhisperModel](
+                model_name,
+                load_fn=lambda: self._load_fn(model_name),
+                ttl=self.whisper_config.ttl,
+                unload_fn=self._handle_model_unload,
+            )
+            return self.loaded_models[model_name]
+
+
+class PiperModelManager:
+    def __init__(self, ttl: int) -> None:
+        self.ttl = ttl
+        self.loaded_models: OrderedDict[str, SelfDisposingModel[PiperVoice]] = OrderedDict()
+        self._lock = threading.Lock()
+
+    def _load_fn(self, model_id: str) -> PiperVoice:
+        from piper.voice import PiperVoice
+
+        model_path = get_piper_voice_model_file(model_id)
+        return PiperVoice.load(model_path)
+
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            self.loaded_models[model_name].unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingModel[PiperVoice]:
+        from piper.voice import PiperVoice
+
         with self._lock:
             if model_name in self.loaded_models:
                 logger.debug(f"{model_name} model already loaded")
                 return self.loaded_models[model_name]
-            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+            self.loaded_models[model_name] = SelfDisposingModel[PiperVoice](
                 model_name,
-                self.whisper_config,
-                on_unload=self._handle_model_unload,
+                load_fn=lambda: self._load_fn(model_name),
+                ttl=self.ttl,
+                unload_fn=self._handle_model_unload,
             )
             return self.loaded_models[model_name]
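
The generic `SelfDisposingModel[T]` keeps the ref-counting and TTL logic in one place: `__enter__` loads the model on first use and increments the ref count, `__exit__` decrements it, and once the count hits zero the model is unloaded immediately (ttl == 0), after a timer (ttl > 0), or never (ttl < 0). A behavioural sketch, assuming `piper` is installed and the voice files are in the HF cache:

import time

from faster_whisper_server.model_manager import PiperModelManager

manager = PiperModelManager(ttl=5)

with manager.load_model("en_US-amy-medium") as voice:
    ...  # synthesize with `voice`; the ref count is 1 inside the block

# On exit the ref count drops to 0 and a 5 s expiry timer is armed.
time.sleep(6)
assert "en_US-amy-medium" not in manager.loaded_models  # unloaded and evicted
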
src/faster_whisper_server/routers/list_models.py
--- src/faster_whisper_server/routers/list_models.py
+++ src/faster_whisper_server/routers/list_models.py
@@ -13,6 +13,7 @@
     ListModelsResponse,
     Model,
 )
+from faster_whisper_server.hf_utils import list_whisper_models
 
 if TYPE_CHECKING:
     from huggingface_hub.hf_api import ModelInfo
@@ -22,34 +23,13 @@
 
 @router.get("/v1/models")
 def get_models() -> ListModelsResponse:
-    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
-    models = list(models)
-    models.sort(key=lambda model: model.downloads or -1, reverse=True)
-    transformed_models: list[Model] = []
-    for model in models:
-        assert model.created_at is not None
-        assert model.card_data is not None
-        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
-        if model.card_data.language is None:
-            language = []
-        elif isinstance(model.card_data.language, str):
-            language = [model.card_data.language]
-        else:
-            language = model.card_data.language
-        transformed_model = Model(
-            id=model.id,
-            created=int(model.created_at.timestamp()),
-            object_="model",
-            owned_by=model.id.split("/")[0],
-            language=language,
-        )
-        transformed_models.append(transformed_model)
-    return ListModelsResponse(data=transformed_models)
+    whisper_models = list(list_whisper_models())
+    return ListModelsResponse(data=whisper_models)
 
 
 @router.get("/v1/models/{model_name:path}")
-# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
 def get_model(
+    # NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
     model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
 ) -> Model:
     models = huggingface_hub.list_models(
 
src/faster_whisper_server/routers/speech.py (added)
+++ src/faster_whisper_server/routers/speech.py
@@ -0,0 +1,164 @@
+from collections.abc import Generator
+import io
+import logging
+import time
+from typing import Annotated, Literal, Self
+
+from fastapi import APIRouter
+from fastapi.responses import StreamingResponse
+import numpy as np
+from piper.voice import PiperVoice
+from pydantic import BaseModel, BeforeValidator, Field, model_validator
+import soundfile as sf
+
+from faster_whisper_server.dependencies import PiperModelManagerDependency
+from faster_whisper_server.hf_utils import read_piper_voices_config
+
+DEFAULT_MODEL = "piper"
+# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-response_format
+DEFAULT_RESPONSE_FORMAT = "mp3"
+DEFAULT_VOICE = "en_US-amy-medium"  # TODO: make configurable
+DEFAULT_VOICE_SAMPLE_RATE = 22050  # NOTE: Dependent on the voice
+
+# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-model
+# https://platform.openai.com/docs/models/tts
+OPENAI_SUPPORTED_SPEECH_MODEL = ("tts-1", "tts-1-hd")
+
+# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice
+# https://platform.openai.com/docs/guides/text-to-speech/voice-options
+OPENAI_SUPPORTED_SPEECH_VOICE_NAMES = ("alloy", "echo", "fable", "onyx", "nova", "shimmer")
+
+# https://platform.openai.com/docs/guides/text-to-speech/supported-output-formats
+type ResponseFormat = Literal["mp3", "flac", "wav", "pcm"]
+SUPPORTED_RESPONSE_FORMATS = ("mp3", "flac", "wav", "pcm")
+UNSUPPORTED_RESPONSE_FORMATS = ("opus", "aac")
+
+MIN_SAMPLE_RATE = 8000
+MAX_SAMPLE_RATE = 48000
+
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+# aip 'Write a function `resample_audio` which would take in RAW PCM 16-bit signed, little-endian audio data represented as bytes (`audio_bytes`) and resample it (either downsample or upsample) from `sample_rate` to `target_sample_rate` using numpy'  # noqa: E501
+def resample_audio(audio_bytes: bytes, sample_rate: int, target_sample_rate: int) -> bytes:
+    audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
+    duration = len(audio_data) / sample_rate
+    target_length = int(duration * target_sample_rate)
+    resampled_data = np.interp(
+        np.linspace(0, len(audio_data), target_length, endpoint=False), np.arange(len(audio_data)), audio_data
+    )
+    return resampled_data.astype(np.int16).tobytes()
+
+
+def generate_audio(
+    piper_tts: PiperVoice, text: str, *, speed: float = 1.0, sample_rate: int | None = None
+) -> Generator[bytes, None, None]:
+    if sample_rate is None:
+        sample_rate = piper_tts.config.sample_rate
+    start = time.perf_counter()
+    for audio_bytes in piper_tts.synthesize_stream_raw(text, length_scale=1.0 / speed):
+        if sample_rate != piper_tts.config.sample_rate:
+            audio_bytes = resample_audio(audio_bytes, piper_tts.config.sample_rate, sample_rate)  # noqa: PLW2901
+        yield audio_bytes
+    logger.info(f"Generated audio for {len(text)} characters in {time.perf_counter() - start}s")
+
+
+def convert_audio_format(
+    audio_bytes: bytes,
+    sample_rate: int,
+    audio_format: ResponseFormat,
+    format: str = "RAW",  # noqa: A002
+    channels: int = 1,
+    subtype: str = "PCM_16",
+    endian: str = "LITTLE",
+) -> bytes:
+    # NOTE: the default dtype is float64. Should something else be used? Would that improve performance?
+    data, _ = sf.read(
+        io.BytesIO(audio_bytes),
+        samplerate=sample_rate,
+        format=format,
+        channels=channels,
+        subtype=subtype,
+        endian=endian,
+    )
+    converted_audio_bytes_buffer = io.BytesIO()
+    sf.write(converted_audio_bytes_buffer, data, samplerate=sample_rate, format=audio_format)
+    return converted_audio_bytes_buffer.getvalue()
+
+
+def handle_openai_supported_model_ids(model_id: str) -> str:
+    if model_id in OPENAI_SUPPORTED_SPEECH_MODEL:
+        logger.warning(f"{model_id} is not a valid model name. Using '{DEFAULT_MODEL}' instead.")
+        return DEFAULT_MODEL
+    return model_id
+
+
+ModelId = Annotated[
+    Literal["piper"],
+    BeforeValidator(handle_openai_supported_model_ids),
+    Field(
+        description=f"The ID of the model. The only supported model is '{DEFAULT_MODEL}'.",
+        examples=[DEFAULT_MODEL],
+    ),
+]
+
+
+def handle_openai_supported_voices(voice: str) -> str:
+    if voice in OPENAI_SUPPORTED_SPEECH_VOICE_NAMES:
+        logger.warning(f"{voice} is not a valid voice name. Using '{DEFAULT_VOICE}' instead.")
+        return DEFAULT_VOICE
+    return voice
+
+
+Voice = Annotated[str, BeforeValidator(handle_openai_supported_voices)]  # TODO: description and examples
+
+
+class CreateSpeechRequestBody(BaseModel):
+    model: ModelId = DEFAULT_MODEL
+    input: str = Field(
+        ...,
+        description="The text to generate audio for. ",
+        examples=[
+            "A rainbow is an optical phenomenon caused by refraction, internal reflection and dispersion of light in water droplets resulting in a continuous spectrum of light appearing in the sky. The rainbow takes the form of a multicoloured circular arc. Rainbows caused by sunlight always appear in the section of sky directly opposite the Sun. Rainbows can be caused by many forms of airborne water. These include not only rain, but also mist, spray, and airborne dew."  # noqa: E501
+        ],
+    )
+    voice: Voice = DEFAULT_VOICE
+    response_format: ResponseFormat = Field(
+        DEFAULT_RESPONSE_FORMAT,
+        description=f"The format to audio in. Supported formats are {", ".join(SUPPORTED_RESPONSE_FORMATS)}. {", ".join(UNSUPORTED_RESPONSE_FORMATS)} are not supported",  # noqa: E501
+        examples=list(SUPPORTED_RESPONSE_FORMATS),
+    )
+    # https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice
+    speed: float = Field(1.0, ge=0.25, le=4.0)
+    """The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default."""
+    sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE)  # TODO: document
+
+    # TODO: move into `Voice`
+    @model_validator(mode="after")
+    def verify_voice_is_valid(self) -> Self:
+        valid_voices = read_piper_voices_config()
+        if self.voice not in valid_voices:
+            raise ValueError(f"Voice '{self.voice}' is not supported. Supported voices: {valid_voices.keys()}")
+        return self
+
+
+# https://platform.openai.com/docs/api-reference/audio/createSpeech
+@router.post("/v1/audio/speech")
+def synthesize(
+    piper_model_manager: PiperModelManagerDependency,
+    body: CreateSpeechRequestBody,
+) -> StreamingResponse:
+    with piper_model_manager.load_model(body.voice) as piper_tts:
+        audio_generator = generate_audio(piper_tts, body.input, speed=body.speed, sample_rate=body.sample_rate)
+        if body.response_format != "pcm":
+            audio_generator = (
+                convert_audio_format(
+                    audio_bytes, body.sample_rate or piper_tts.config.sample_rate, body.response_format
+                )
+                for audio_bytes in audio_generator
+            )
+
+        return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}")
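
Since the route mirrors OpenAI's `/v1/audio/speech`, the stock OpenAI client works against it; a minimal sketch, assuming the server is listening on localhost:8000:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="cant-be-empty")

res = client.audio.speech.create(
    model="piper",
    voice="en_US-amy-medium",  # type: ignore  # any voice present in voices.json
    input="Hello, world!",
    response_format="wav",
    # `sample_rate` is a server-specific extension (8000-48000), hence `extra_body`.
    extra_body={"sample_rate": 16000},
)
with open("output.wav", "wb") as f:
    f.write(res.read())
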
src/faster_whisper_server/text_utils.py
--- src/faster_whisper_server/text_utils.py
+++ src/faster_whisper_server/text_utils.py
@@ -3,8 +3,6 @@
 import re
 from typing import TYPE_CHECKING
 
-from faster_whisper_server.dependencies import get_config
-
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
@@ -40,6 +38,8 @@
         self.words.extend(words)
 
     def _ensure_no_word_overlap(self, words: list[TranscriptionWord]) -> None:
+        from faster_whisper_server.dependencies import get_config  # HACK: avoid circular import
+
         config = get_config()  # HACK
         if len(self.words) > 0 and len(words) > 0:
             if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
tests/conftest.py
--- tests/conftest.py
+++ tests/conftest.py
@@ -4,6 +4,7 @@
 
 from fastapi.testclient import TestClient
 from httpx import ASGITransport, AsyncClient
+from huggingface_hub import snapshot_download
 from openai import AsyncOpenAI
 import pytest
 import pytest_asyncio
@@ -44,3 +45,10 @@
     return AsyncOpenAI(
         base_url="https://api.openai.com/v1"
     )  # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
+
+
+# TODO: remove the download after running the tests
+@pytest.fixture(scope="session", autouse=True)
+def download_piper_voices() -> None:
+    # Only download `voices.json` and the default voice
+    snapshot_download("rhasspy/piper-voices", allow_patterns=["voices.json", "en/en_US/amy/**"])
 
tests/speech_test.py (added)
+++ tests/speech_test.py
@@ -0,0 +1,158 @@
+import io
+import platform
+
+from openai import APIConnectionError, AsyncOpenAI, UnprocessableEntityError
+import pytest
+import soundfile as sf
+
+from faster_whisper_server.routers.speech import (
+    DEFAULT_MODEL,
+    DEFAULT_RESPONSE_FORMAT,
+    DEFAULT_VOICE,
+    SUPPORTED_RESPONSE_FORMATS,
+    ResponseFormat,
+)
+
+DEFAULT_INPUT = "Hello, world!"
+
+platform_machine = platform.machine()
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize("response_format", SUPPORTED_RESPONSE_FORMATS)
+async def test_create_speech_formats(openai_client: AsyncOpenAI, response_format: ResponseFormat) -> None:
+    await openai_client.audio.speech.create(
+        model=DEFAULT_MODEL,
+        voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+        input=DEFAULT_INPUT,
+        response_format=response_format,
+    )
+
+
+GOOD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [
+    ("tts-1", "alloy"),  # OpenAI and OpenAI
+    ("tts-1-hd", "echo"),  # OpenAI and OpenAI
+    ("tts-1", DEFAULT_VOICE),  # OpenAI and Piper
+    (DEFAULT_MODEL, "echo"),  # Piper and OpenAI
+    (DEFAULT_MODEL, DEFAULT_VOICE),  # Piper and Piper
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize(("model", "voice"), GOOD_MODEL_VOICE_PAIRS)
+async def test_create_speech_good_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
+    await openai_client.audio.speech.create(
+        model=model,
+        voice=voice,  # type: ignore  # noqa: PGH003
+        input=DEFAULT_INPUT,
+        response_format=DEFAULT_RESPONSE_FORMAT,
+    )
+
+
+BAD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [
+    ("tts-1", "invalid"),  # OpenAI and invalid
+    ("invalid", "echo"),  # Invalid and OpenAI
+    (DEFAULT_MODEL, "invalid"),  # Piper and invalid
+    ("invalid", DEFAULT_VOICE),  # Invalid and Piper
+    ("invalid", "invalid"),  # Invalid and invalid
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize(("model", "voice"), BAD_MODEL_VOICE_PAIRS)
+async def test_create_speech_bad_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
+    # NOTE: not sure why `APIConnectionError` is sometimes raised
+    with pytest.raises((UnprocessableEntityError, APIConnectionError)):
+        await openai_client.audio.speech.create(
+            model=model,
+            voice=voice,  # type: ignore  # noqa: PGH003
+            input=DEFAULT_INPUT,
+            response_format=DEFAULT_RESPONSE_FORMAT,
+        )
+
+
+SUPPORTED_SPEEDS = [0.25, 0.5, 1.0, 2.0, 4.0]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+async def test_create_speech_with_varying_speed(openai_client: AsyncOpenAI) -> None:
+    previous_size: int | None = None
+    for speed in SUPPORTED_SPEEDS:
+        res = await openai_client.audio.speech.create(
+            model=DEFAULT_MODEL,
+            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+            input=DEFAULT_INPUT,
+            response_format="pcm",
+            speed=speed,
+        )
+        audio_bytes = res.read()
+        if previous_size is not None:
+            assert len(audio_bytes) * 1.5 < previous_size  # TODO: document magic number
+        previous_size = len(audio_bytes)
+
+
+UNSUPPORTED_SPEEDS = [0.1, 4.1]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize("speed", UNSUPPORTED_SPEEDS)
+async def test_create_speech_with_unsupported_speed(openai_client: AsyncOpenAI, speed: float) -> None:
+    with pytest.raises(UnprocessableEntityError):
+        await openai_client.audio.speech.create(
+            model=DEFAULT_MODEL,
+            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+            input=DEFAULT_INPUT,
+            response_format="pcm",
+            speed=speed,
+        )
+
+
+VALID_SAMPLE_RATES = [16000, 22050, 24000, 48000]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize("sample_rate", VALID_SAMPLE_RATES)
+async def test_speech_valid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
+    res = await openai_client.audio.speech.create(
+        model=DEFAULT_MODEL,
+        voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+        input=DEFAULT_INPUT,
+        response_format="wav",
+        extra_body={"sample_rate": sample_rate},
+    )
+    _, actual_sample_rate = sf.read(io.BytesIO(res.content))
+    assert actual_sample_rate == sample_rate
+
+
+INVALID_SAMPLE_RATES = [7999, 48001]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize("sample_rate", INVALID_SAMPLE_RATES)
+async def test_speech_invalid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
+    with pytest.raises(UnprocessableEntityError):
+        await openai_client.audio.speech.create(
+            model=DEFAULT_MODEL,
+            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+            input=DEFAULT_INPUT,
+            response_format="wav",
+            extra_body={"sample_rate": sample_rate},
+        )
+
+
+# TODO: implement the following test
+
+# NUMBER_OF_MODELS = 1
+# NUMBER_OF_VOICES = 124
+#
+#
+# @pytest.mark.asyncio
+# async def test_list_tts_models(openai_client: AsyncOpenAI) -> None:
+#     raise NotImplementedError