

feat: add kokoro tts support (#230)
@8f9d361746962707366583242327786db80b1270
--- docs/usage/text-to-speech.md
+++ docs/usage/text-to-speech.md
... | ... | @@ -1,12 +1,6 @@ |
1 | 1 |
!!! warning |
2 | 2 |
|
3 |
- This feature not supported on ARM devices only x86_64. I was unable to build [piper-phonemize](https://github.com/rhasspy/piper-phonemize)(my [fork](https://github.com/fedirz/piper-phonemize)) |
|
4 |
- |
|
5 |
-TODO: add a note about automatic downloads |
|
6 |
-TODO: add a demo |
|
7 |
-TODO: add a note about tts only running on cpu |
|
8 |
-TODO: add a note about exploring other models |
|
9 |
-TODO: add a note about performance |
|
3 |
+ `rhasspy/piper-voices` is only supported on x86_64. I was unable to build [piper-phonemize](https://github.com/rhasspy/piper-phonemize) for ARM. If you have experience building Python packages with third-party C++ dependencies, please consider contributing. See [#234](https://github.com/speaches-ai/speaches/issues/234) for more information. |
|
10 | 4 |
|
11 | 5 |
!!! note |
12 | 6 |
|
... | ... | @@ -14,10 +8,28 @@ |
14 | 8 |
|
15 | 9 |
## Prerequisite |
16 | 10 |
|
11 |
+!!! note |
|
12 |
+ |
|
13 |
+ `rhasspy/piper-voices` audio samples can be found [here](https://rhasspy.github.io/piper-samples/). |
|
14 |
+ |
|
15 |
+Download the Kokoro model and voices. |
|
16 |
+ |
|
17 |
+```bash |
|
18 |
+# Download the ONNX model (~346 MB). The output will contain the path to the downloaded model, which you'll need for the next step. |
|
19 |
+docker exec -it speaches huggingface-cli download hexgrad/Kokoro-82M --include 'kokoro-v0_19.onnx' |
|
20 |
+# ... |
|
21 |
+# /home/ubuntu/.cache/huggingface/hub/models--hexgrad--Kokoro-82M/snapshots/c97b7bbc3e60f447383c79b2f94fee861ff156ac |
|
22 |
+ |
|
23 |
+# Download the voices.json file (~54 MB). We aren't using `docker exec` here since the container doesn't have `curl` or `wget` installed. |
|
24 |
+curl --location -O https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/voices.json |
|
25 |
+# Replace the path with the one you got from the previous step |
|
26 |
+docker cp voices.json speaches:/home/ubuntu/.cache/huggingface/hub/models--hexgrad--Kokoro-82M/snapshots/c97b7bbc3e60f447383c79b2f94fee861ff156ac/voices.json |
|
27 |
+``` |
|
28 |
+ |
|
17 | 29 |
Download the piper voices from [HuggingFace model repository](https://huggingface.co/rhasspy/piper-voices) |
18 | 30 |
|
19 | 31 |
```bash |
20 |
-# Download all voices (~15 minutes / 7.7 Gbs) |
|
32 |
+# Download all voices (~15 minutes / 7.7 GB) |
|
21 | 33 |
docker exec -it speaches huggingface-cli download rhasspy/piper-voices |
22 | 34 |
# Download all English voices (~4.5 minutes) |
23 | 35 |
docker exec -it speaches huggingface-cli download rhasspy/piper-voices --include 'en/**/*' 'voices.json' |
... | ... | @@ -27,14 +39,10 @@ |
27 | 39 |
docker exec -it speaches huggingface-cli download rhasspy/piper-voices --include 'en/en_US/amy/medium/*' 'voices.json' |
28 | 40 |
``` |
29 | 41 |
|
30 |
-!!! note |
|
31 |
- |
|
32 |
- You can find audio samples of all the available voices [here](https://rhasspy.github.io/piper-samples/) |
|
33 |
- |
|
34 | 42 |
## Curl |
35 | 43 |
|
36 | 44 |
```bash |
37 |
-# Generate speech from text using the default values (response_format="mp3", speed=1.0, voice="en_US-amy-medium", etc.) |
|
45 |
+# Generate speech from text using the default values (model="hexgrad/Kokoro-82M", voice="af", response_format="mp3", speed=1.0, etc.) |
|
38 | 46 |
curl http://localhost:8000/v1/audio/speech --header "Content-Type: application/json" --data '{"input": "Hello World!"}' --output audio.mp3 |
39 | 47 |
# Specifying the output format |
40 | 48 |
curl http://localhost:8000/v1/audio/speech --header "Content-Type: application/json" --data '{"input": "Hello World!", "response_format": "wav"}' --output audio.wav |
... | ... | @@ -42,13 +50,18 @@ |
42 | 50 |
curl http://localhost:8000/v1/audio/speech --header "Content-Type: application/json" --data '{"input": "Hello World!", "speed": 2.0}' --output audio.mp3 |
43 | 51 |
|
44 | 52 |
# List available (downloaded) voices |
45 |
-curl http://localhost:8000/v1/audio/speech/voices |
|
53 |
+curl --silent http://localhost:8000/v1/audio/speech/voices |
|
46 | 54 |
# List just the voice names |
47 |
-curl http://localhost:8000/v1/audio/speech/voices | jq --raw-output '.[] | .voice' |
|
48 |
-# List just the voices in your language |
|
49 |
-curl --silent http://localhost:8000/v1/audio/speech/voices | jq --raw-output '.[] | select(.voice | startswith("en")) | .voice' |
|
55 |
+curl --silent http://localhost:8000/v1/audio/speech/voices | jq --raw-output '.[] | .voice_id' |
|
56 |
+# List just the rhasspy/piper-voices voice names |
|
57 |
+curl --silent 'http://localhost:8000/v1/audio/speech/voices?model_id=rhasspy/piper-voices' | jq --raw-output '.[] | .voice_id' |
|
58 |
+# List just the hexgrad/Kokoro-82M voice names |
|
59 |
+curl --silent 'http://localhost:8000/v1/audio/speech/voices?model_id=hexgrad/Kokoro-82M' | jq --raw-output '.[] | .voice_id' |
|
50 | 60 |
|
51 |
-curl http://localhost:8000/v1/audio/speech --header "Content-Type: application/json" --data '{"input": "Hello World!", "voice": "en_US-ryan-high"}' --output audio.mp3 |
|
61 |
+# List just the voices in your language (piper) |
|
62 |
+curl --silent http://localhost:8000/v1/audio/speech/voices | jq --raw-output '.[] | select(.voice_id | startswith("en")) | .voice_id' |
|
63 |
+ |
|
64 |
+curl http://localhost:8000/v1/audio/speech --header "Content-Type: application/json" --data '{"input": "Hello World!", "voice": "af_sky"}' --output audio.mp3 |
|
52 | 65 |
``` |
53 | 66 |
|
54 | 67 |
## Python |
... | ... | @@ -64,8 +77,8 @@ |
64 | 77 |
res = client.post( |
65 | 78 |
"v1/audio/speech", |
66 | 79 |
json={ |
67 |
- "model": "piper", |
|
68 |
- "voice": "en_US-amy-medium", |
|
80 |
+ "model": "hexgrad/Kokoro-82M", |
|
81 |
+ "voice": "af", |
|
69 | 82 |
"input": "Hello, world!", |
70 | 83 |
"response_format": "mp3", |
71 | 84 |
"speed": 1, |
... | ... | @@ -92,8 +105,8 @@ |
92 | 105 |
|
93 | 106 |
openai = OpenAI(base_url="http://localhost:8000/v1", api_key="cant-be-empty") |
94 | 107 |
res = openai.audio.speech.create( |
95 |
- model="piper", |
|
96 |
- voice="en_US-amy-medium", # pyright: ignore[reportArgumentType] |
|
108 |
+ model="hexgrad/Kokoro-82M", |
|
109 |
+ voice="af", # pyright: ignore[reportArgumentType] |
|
97 | 110 |
input="Hello, world!", |
98 | 111 |
response_format="mp3", |
99 | 112 |
speed=1, |
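
Since `POST /v1/audio/speech` returns a streaming response, the audio can also be consumed incrementally rather than buffered whole. A minimal sketch using `httpx` directly (assumes the server from the examples above is reachable at `localhost:8000`; `pcm` is used so each chunk is raw playable audio):

```python
import httpx

# Stream raw PCM chunks to disk as they are generated instead of
# buffering the entire response in memory.
with httpx.stream(
    "POST",
    "http://localhost:8000/v1/audio/speech",
    json={"model": "hexgrad/Kokoro-82M", "voice": "af", "input": "Hello World!", "response_format": "pcm"},
    timeout=60.0,
) as res, open("audio.pcm", "wb") as f:
    res.raise_for_status()
    for chunk in res.iter_bytes():
        f.write(chunk)
```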
--- src/speaches/api_models.py
+++ src/speaches/api_types.py
... | ... | @@ -1,8 +1,10 @@ |
1 | 1 |
from __future__ import annotations |
2 | 2 |
|
3 |
+from functools import cached_property |
|
4 |
+from pathlib import Path # noqa: TC003 |
|
3 | 5 |
from typing import TYPE_CHECKING, Literal |
4 | 6 |
|
5 |
-from pydantic import BaseModel, ConfigDict, Field |
|
7 |
+from pydantic import BaseModel, ConfigDict, Field, computed_field |
|
6 | 8 |
|
7 | 9 |
from speaches.text_utils import Transcription, canonicalize_word, segments_to_text |
8 | 10 |
|
... | ... | @@ -25,9 +27,9 @@ |
25 | 27 |
for segment in segments: |
26 | 28 |
# NOTE: a temporary "fix" for https://github.com/speaches-ai/speaches/issues/58. |
27 | 29 |
# TODO: properly address the issue |
28 |
- assert ( |
|
29 |
- segment.words is not None |
|
30 |
- ), "Segment must have words. If you are using an API ensure `timestamp_granularities[]=word` is set" |
|
30 |
+ assert segment.words is not None, ( |
|
31 |
+ "Segment must have words. If you are using an API ensure `timestamp_granularities[]=word` is set" |
|
32 |
+ ) |
|
31 | 33 |
words.extend(segment.words) |
32 | 34 |
return words |
33 | 35 |
|
... | ... | @@ -206,3 +208,30 @@ |
206 | 208 |
["word", "segment"], |
207 | 209 |
["segment", "word"], # same as ["word", "segment"] but order is different |
208 | 210 |
] |
211 |
+ |
|
212 |
+ |
|
213 |
+class Voice(BaseModel): |
|
214 |
+ """Similar structure to the GET /v1/models response but with extra fields.""" |
|
215 |
+ |
|
216 |
+ model_id: str |
|
217 |
+ voice_id: str |
|
218 |
+ created: int |
|
219 |
+ owned_by: str = Field( |
|
220 |
+ examples=[ |
|
221 |
+ "hexgrad", |
|
222 |
+ "rhaaspy", |
|
223 |
+ ] |
|
224 |
+ ) |
|
225 |
+ sample_rate: int |
|
226 |
+ model_path: Path = Field( |
|
227 |
+ examples=[ |
|
228 |
+ "/home/nixos/.cache/huggingface/hub/models--rhasspy--piper-voices/snapshots/3d796cc2f2c884b3517c527507e084f7bb245aea/en/en_US/amy/medium/en_US-amy-medium.onnx" |
|
229 |
+ ] |
|
230 |
+ ) |
|
231 |
+ object: Literal["voice"] = "voice" |
|
232 |
+ |
|
233 |
+ @computed_field(examples=["rhasspy/piper-voices/en_US-amy-medium"]) |
|
234 |
+ @cached_property |
|
235 |
+ def id(self) -> str: |
|
236 |
+ """Unique identifier for the model + voice.""" |
|
237 |
+ return self.model_id + "/" + self.voice_id |
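
A quick illustration of how the computed `id` field composes the two identifiers (the field values below are illustrative, and the `model_path` is made up):

```python
from pathlib import Path

voice = Voice(
    model_id="rhasspy/piper-voices",
    voice_id="en_US-amy-medium",
    created=0,
    owned_by="rhasspy",
    sample_rate=22050,
    model_path=Path("/tmp/en_US-amy-medium.onnx"),  # illustrative path
)
assert voice.id == "rhasspy/piper-voices/en_US-amy-medium"
```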
--- src/speaches/asr.py
+++ src/speaches/asr.py
... | ... | @@ -5,7 +5,7 @@ |
5 | 5 |
import time |
6 | 6 |
from typing import TYPE_CHECKING |
7 | 7 |
|
8 |
-from speaches.api_models import TranscriptionSegment, TranscriptionWord |
|
8 |
+from speaches.api_types import TranscriptionSegment, TranscriptionWord |
|
9 | 9 |
from speaches.text_utils import Transcription |
10 | 10 |
|
11 | 11 |
if TYPE_CHECKING: |
--- src/speaches/audio.py
+++ src/speaches/audio.py
... | ... | @@ -1,6 +1,7 @@ |
1 | 1 |
from __future__ import annotations |
2 | 2 |
|
3 | 3 |
import asyncio |
4 |
+import io |
|
4 | 5 |
import logging |
5 | 6 |
from typing import TYPE_CHECKING, BinaryIO |
6 | 7 |
|
... | ... | @@ -14,10 +15,46 @@ |
14 | 15 |
|
15 | 16 |
from numpy.typing import NDArray |
16 | 17 |
|
18 |
+ from speaches.routers.speech import ResponseFormat |
|
19 |
+ |
|
17 | 20 |
|
18 | 21 |
logger = logging.getLogger(__name__) |
19 | 22 |
|
20 | 23 |
|
24 |
+# aip 'Write a function `resample_audio` which would take in RAW PCM 16-bit signed, little-endian audio data represented as bytes (`audio_bytes`) and resample it (either downsample or upsample) from `sample_rate` to `target_sample_rate` using numpy' # noqa: E501 |
|
25 |
+def resample_audio(audio_bytes: bytes, sample_rate: int, target_sample_rate: int) -> bytes: |
|
26 |
+ audio_data = np.frombuffer(audio_bytes, dtype=np.int16) |
|
27 |
+ duration = len(audio_data) / sample_rate |
|
28 |
+ target_length = int(duration * target_sample_rate) |
|
29 |
+ resampled_data = np.interp( |
|
30 |
+ np.linspace(0, len(audio_data), target_length, endpoint=False), np.arange(len(audio_data)), audio_data |
|
31 |
+ ) |
|
32 |
+ return resampled_data.astype(np.int16).tobytes() |
|
33 |
+ |
|
34 |
+ |
|
35 |
+def convert_audio_format( |
|
36 |
+ audio_bytes: bytes, |
|
37 |
+ sample_rate: int, |
|
38 |
+ audio_format: ResponseFormat, |
|
39 |
+ format: str = "RAW", # noqa: A002 |
|
40 |
+ channels: int = 1, |
|
41 |
+ subtype: str = "PCM_16", |
|
42 |
+ endian: str = "LITTLE", |
|
43 |
+) -> bytes: |
|
44 |
+ # NOTE: the default dtype is float64. Should something else be used? Would that improve performance? |
|
45 |
+ data, _ = sf.read( |
|
46 |
+ io.BytesIO(audio_bytes), |
|
47 |
+ samplerate=sample_rate, |
|
48 |
+ format=format, |
|
49 |
+ channels=channels, |
|
50 |
+ subtype=subtype, |
|
51 |
+ endian=endian, |
|
52 |
+ ) |
|
53 |
+ converted_audio_bytes_buffer = io.BytesIO() |
|
54 |
+ sf.write(converted_audio_bytes_buffer, data, samplerate=sample_rate, format=audio_format) |
|
55 |
+ return converted_audio_bytes_buffer.getvalue() |
|
56 |
+ |
|
57 |
+ |
|
21 | 58 |
def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]: |
22 | 59 |
audio_and_sample_rate = sf.read( |
23 | 60 |
file, |
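
A small sanity check of the `resample_audio` helper above, as a sketch (byte counts assume 16-bit mono PCM):

```python
import numpy as np

# 1 second of a 440 Hz tone at 22050 Hz, encoded as 16-bit little-endian PCM.
sr, target_sr = 22050, 16000
t = np.linspace(0, 1, sr, endpoint=False)
pcm = (np.sin(2 * np.pi * 440 * t) * 0.5 * np.iinfo(np.int16).max).astype(np.int16).tobytes()

resampled = resample_audio(pcm, sr, target_sr)
assert len(resampled) == target_sr * 2  # still 1 second: 16000 samples x 2 bytes
```

Note that `np.interp` performs plain linear interpolation with no anti-aliasing filter, which is adequate for the modest rate changes used here.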
--- src/speaches/dependencies.py
+++ src/speaches/dependencies.py
... | ... | @@ -10,7 +10,7 @@ |
10 | 10 |
from openai.resources.chat.completions import AsyncCompletions |
11 | 11 |
|
12 | 12 |
from speaches.config import Config |
13 |
-from speaches.model_manager import PiperModelManager, WhisperModelManager |
|
13 |
+from speaches.model_manager import KokoroModelManager, PiperModelManager, WhisperModelManager |
|
14 | 14 |
|
15 | 15 |
logger = logging.getLogger(__name__) |
16 | 16 |
|
... | ... | @@ -45,6 +45,14 @@ |
45 | 45 |
PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)] |
46 | 46 |
|
47 | 47 |
|
48 |
+@lru_cache |
|
49 |
+def get_kokoro_model_manager() -> KokoroModelManager: |
|
50 |
+ config = get_config() |
|
51 |
+ return KokoroModelManager(config.whisper.ttl) # HACK: should have its own config |
|
52 |
+ |
|
53 |
+ |
|
54 |
+KokoroModelManagerDependency = Annotated[KokoroModelManager, Depends(get_kokoro_model_manager)] |
|
55 |
+ |
|
48 | 56 |
security = HTTPBearer() |
49 | 57 |
|
50 | 58 |
|
--- src/speaches/gradio_app.py
+++ src/speaches/gradio_app.py
... | ... | @@ -7,13 +7,20 @@ |
7 | 7 |
from httpx_sse import aconnect_sse |
8 | 8 |
from openai import AsyncOpenAI |
9 | 9 |
|
10 |
+from speaches import kokoro_utils |
|
11 |
+from speaches.api_types import Voice |
|
10 | 12 |
from speaches.config import Config, Task |
11 |
-from speaches.hf_utils import PiperModel |
|
13 |
+from speaches.routers.speech import ( |
|
14 |
+ MAX_SAMPLE_RATE, |
|
15 |
+ MIN_SAMPLE_RATE, |
|
16 |
+ SUPPORTED_RESPONSE_FORMATS, |
|
17 |
+) |
|
12 | 18 |
|
13 | 19 |
TRANSCRIPTION_ENDPOINT = "/v1/audio/transcriptions" |
14 | 20 |
TRANSLATION_ENDPOINT = "/v1/audio/translations" |
15 | 21 |
TIMEOUT_SECONDS = 180 |
16 | 22 |
TIMEOUT = httpx.Timeout(timeout=TIMEOUT_SECONDS) |
23 |
+DEFAULT_TEXT = "A rainbow is an optical phenomenon caused by refraction, internal reflection and dispersion of light in water droplets resulting in a continuous spectrum of light appearing in the sky." # noqa: E501 |
|
17 | 24 |
|
18 | 25 |
# NOTE: `gr.Request` seems to be passed in as the last positional (not keyword) argument |
19 | 26 |
|
... | ... | @@ -104,23 +111,34 @@ |
104 | 111 |
value=config.whisper.model, |
105 | 112 |
) |
106 | 113 |
|
107 |
- async def update_piper_voices_dropdown(request: gr.Request) -> gr.Dropdown: |
|
114 |
+ async def update_voices_and_language_dropdown(model_id: str | None, request: gr.Request) -> dict: |
|
115 |
+ params = httpx.QueryParams({"model_id": model_id}) |
|
108 | 116 |
http_client = http_client_from_gradio_req(request, config) |
109 |
- res = (await http_client.get("/v1/audio/speech/voices")).raise_for_status() |
|
110 |
- piper_models = [PiperModel.model_validate(x) for x in res.json()] |
|
111 |
- return gr.Dropdown(choices=[model.voice for model in piper_models], label="Voice", value=DEFAULT_VOICE) |
|
117 |
+ res = (await http_client.get("/v1/audio/speech/voices", params=params)).raise_for_status() |
|
118 |
+ voice_ids = [Voice.model_validate(x).voice_id for x in res.json()] |
|
119 |
+ return { |
|
120 |
+ voice_dropdown: gr.update(choices=voice_ids, value=voice_ids[0]), |
|
121 |
+ language_dropdown: gr.update(visible=model_id == "hexgrad/Kokoro-82M"), |
|
122 |
+ } |
|
112 | 123 |
|
113 | 124 |
async def handle_audio_speech( |
114 |
- text: str, voice: str, response_format: str, speed: float, sample_rate: int | None, request: gr.Request |
|
125 |
+ text: str, |
|
126 |
+ model: str, |
|
127 |
+ voice: str, |
|
128 |
+ language: str | None, |
|
129 |
+ response_format: str, |
|
130 |
+ speed: float, |
|
131 |
+ sample_rate: int | None, |
|
132 |
+ request: gr.Request, |
|
115 | 133 |
) -> Path: |
116 | 134 |
openai_client = openai_client_from_gradio_req(request, config) |
117 | 135 |
res = await openai_client.audio.speech.create( |
118 | 136 |
input=text, |
119 |
- model="piper", |
|
137 |
+ model=model, |
|
120 | 138 |
voice=voice, # pyright: ignore[reportArgumentType] |
121 | 139 |
response_format=response_format, # pyright: ignore[reportArgumentType] |
122 | 140 |
speed=speed, |
123 |
- extra_body={"sample_rate": sample_rate}, |
|
141 |
+ extra_body={"language": language, "sample_rate": sample_rate}, |
|
124 | 142 |
) |
125 | 143 |
audio_bytes = res.response.read() |
126 | 144 |
file_path = Path(f"audio.{response_format}") |
... | ... | @@ -129,12 +147,18 @@ |
129 | 147 |
return file_path |
130 | 148 |
|
131 | 149 |
with gr.Blocks(title="Speaches Playground") as demo: |
150 |
+ gr.Markdown("# Speaches Playground") |
|
132 | 151 |
gr.Markdown( |
133 |
- "### Consider supporting the project by starring the [repository on GitHub](https://github.com/speaches-ai/speaches)." |
|
152 |
+ "### Consider supporting the project by starring the [speaches-ai/speaches repository on GitHub](https://github.com/speaches-ai/speaches)." |
|
134 | 153 |
) |
135 |
- with gr.Tab(label="Transcribe/Translate"): |
|
154 |
+ gr.Markdown("### Documentation Website: https://speaches-ai.github.io/speaches") |
|
155 |
+ gr.Markdown( |
|
156 |
+ "### For additional details regarding the parameters, see the [API Documentation](https://speaches-ai.github.io/speaches/api)" |
|
157 |
+ ) |
|
158 |
+ |
|
159 |
+ with gr.Tab(label="Speech-to-Text"): |
|
136 | 160 |
audio = gr.Audio(type="filepath") |
137 |
- model_dropdown = gr.Dropdown( |
|
161 |
+ whisper_model_dropdown = gr.Dropdown( |
|
138 | 162 |
choices=[config.whisper.model], |
139 | 163 |
label="Model", |
140 | 164 |
value=config.whisper.model, |
... | ... | @@ -152,59 +176,76 @@ |
152 | 176 |
|
153 | 177 |
# NOTE: the inputs order must match the `whisper_handler` signature |
154 | 178 |
button.click( |
155 |
- whisper_handler, [audio, model_dropdown, task_dropdown, temperature_slider, stream_checkbox], output |
|
179 |
+ whisper_handler, |
|
180 |
+ [audio, whisper_model_dropdown, task_dropdown, temperature_slider, stream_checkbox], |
|
181 |
+ output, |
|
156 | 182 |
) |
157 | 183 |
|
158 |
- with gr.Tab(label="Speech Generation"): |
|
159 |
- if platform.machine() == "x86_64": |
|
160 |
- from speaches.routers.speech import ( |
|
161 |
- DEFAULT_VOICE, |
|
162 |
- MAX_SAMPLE_RATE, |
|
163 |
- MIN_SAMPLE_RATE, |
|
164 |
- SUPPORTED_RESPONSE_FORMATS, |
|
165 |
- ) |
|
184 |
+ with gr.Tab(label="Text-to-Speech"): |
|
185 |
+ model_dropdown_choices = ["hexgrad/Kokoro-82M", "rhasspy/piper-voices"] |
|
186 |
+ if platform.machine() != "x86_64": |
|
187 |
+ model_dropdown_choices.remove("rhasspy/piper-voices") |
|
188 |
+ gr.Textbox("Speech generation using `rhasspy/piper-voices` model is only supported on x86_64 machines.") |
|
166 | 189 |
|
167 |
- text = gr.Textbox(label="Input Text") |
|
168 |
- voice_dropdown = gr.Dropdown( |
|
169 |
- choices=["en_US-amy-medium"], |
|
170 |
- label="Voice", |
|
171 |
- value="en_US-amy-medium", |
|
172 |
- info=""" |
|
173 |
-The last part of the voice name is the quality (x_low, low, medium, high). |
|
174 |
-Each quality has a different default sample rate: |
|
175 |
-- x_low: 16000 Hz |
|
176 |
-- low: 16000 Hz |
|
177 |
-- medium: 22050 Hz |
|
178 |
-- high: 22050 Hz |
|
179 |
-""", |
|
180 |
- ) |
|
181 |
- response_fromat_dropdown = gr.Dropdown( |
|
182 |
- choices=SUPPORTED_RESPONSE_FORMATS, |
|
183 |
- label="Response Format", |
|
184 |
- value="wav", |
|
185 |
- ) |
|
186 |
- speed_slider = gr.Slider(minimum=0.25, maximum=4.0, step=0.05, label="Speed", value=1.0) |
|
187 |
- sample_rate_slider = gr.Number( |
|
188 |
- minimum=MIN_SAMPLE_RATE, |
|
189 |
- maximum=MAX_SAMPLE_RATE, |
|
190 |
- label="Desired Sample Rate", |
|
191 |
- info=""" |
|
190 |
+ text = gr.Textbox( |
|
191 |
+ label="Input Text", |
|
192 |
+ value=DEFAULT_TEXT, |
|
193 |
+ ) |
|
194 |
+ tts_model_dropdown = gr.Dropdown( |
|
195 |
+ choices=model_dropdown_choices, |
|
196 |
+ label="Model", |
|
197 |
+ value="hexgrad/Kokoro-82M", |
|
198 |
+ ) |
|
199 |
+ voice_dropdown = gr.Dropdown( |
|
200 |
+ choices=["af"], |
|
201 |
+ label="Voice", |
|
202 |
+ value="af", |
|
203 |
+ ) |
|
204 |
+ language_dropdown = gr.Dropdown( |
|
205 |
+ choices=kokoro_utils.LANGUAGES, label="Language", value="en-us", visible=True |
|
206 |
+ ) |
|
207 |
+ tts_model_dropdown.change( |
|
208 |
+ update_voices_and_language_dropdown, |
|
209 |
+ inputs=[tts_model_dropdown], |
|
210 |
+ outputs=[voice_dropdown, language_dropdown], |
|
211 |
+ ) |
|
212 |
+ response_fromat_dropdown = gr.Dropdown( |
|
213 |
+ choices=SUPPORTED_RESPONSE_FORMATS, |
|
214 |
+ label="Response Format", |
|
215 |
+ value="wav", |
|
216 |
+ ) |
|
217 |
+ speed_slider = gr.Slider(minimum=0.25, maximum=4.0, step=0.05, label="Speed", value=1.0) |
|
218 |
+ sample_rate_slider = gr.Number( |
|
219 |
+ minimum=MIN_SAMPLE_RATE, |
|
220 |
+ maximum=MAX_SAMPLE_RATE, |
|
221 |
+ label="Desired Sample Rate", |
|
222 |
+ info=""" |
|
192 | 223 |
Setting this will resample the generated audio to the desired sample rate. |
193 |
-You may want to set this if you are going to use voices of different qualities but want to keep the same sample rate. |
|
224 |
+You may want to set this if you are going to use 'rhasspy/piper-voices' with voices of different qualities but want to keep the same sample rate. |
|
194 | 225 |
Default: None (No resampling) |
195 |
-""", |
|
196 |
- value=lambda: None, |
|
197 |
- ) |
|
198 |
- button = gr.Button("Generate Speech") |
|
199 |
- output = gr.Audio(type="filepath") |
|
200 |
- button.click( |
|
201 |
- handle_audio_speech, |
|
202 |
- [text, voice_dropdown, response_fromat_dropdown, speed_slider, sample_rate_slider], |
|
203 |
- output, |
|
204 |
- ) |
|
205 |
- demo.load(update_piper_voices_dropdown, inputs=None, outputs=voice_dropdown) |
|
206 |
- else: |
|
207 |
- gr.Textbox("Speech generation is only supported on x86_64 machines.") |
|
226 |
+""", # noqa: E501 |
|
227 |
+ value=lambda: None, |
|
228 |
+ ) |
|
229 |
+ button = gr.Button("Generate Speech") |
|
230 |
+ output = gr.Audio(type="filepath") |
|
231 |
+ button.click( |
|
232 |
+ handle_audio_speech, |
|
233 |
+ [ |
|
234 |
+ text, |
|
235 |
+ tts_model_dropdown, |
|
236 |
+ voice_dropdown, |
|
237 |
+ language_dropdown, |
|
238 |
+ response_format_dropdown, |
|
239 |
+ speed_slider, |
|
240 |
+ sample_rate_slider, |
|
241 |
+ ], |
|
242 |
+ output, |
|
243 |
+ ) |
|
208 | 244 |
|
209 |
- demo.load(update_whisper_model_dropdown, inputs=None, outputs=model_dropdown) |
|
245 |
+ demo.load(update_whisper_model_dropdown, inputs=None, outputs=whisper_model_dropdown) |
|
246 |
+ demo.load( |
|
247 |
+ update_voices_and_language_dropdown, |
|
248 |
+ inputs=[tts_model_dropdown], |
|
249 |
+ outputs=[voice_dropdown, language_dropdown], |
|
250 |
+ ) |
|
210 | 251 |
return demo |
--- src/speaches/hf_utils.py
+++ src/speaches/hf_utils.py
... | ... | @@ -1,16 +1,17 @@ |
1 | 1 |
from collections.abc import Generator |
2 |
-from functools import cached_property, lru_cache |
|
2 |
+from functools import lru_cache |
|
3 | 3 |
import json |
4 | 4 |
import logging |
5 | 5 |
from pathlib import Path |
6 | 6 |
import typing |
7 | 7 |
from typing import Any, Literal |
8 | 8 |
|
9 |
+import httpx |
|
9 | 10 |
import huggingface_hub |
10 | 11 |
from huggingface_hub.constants import HF_HUB_CACHE |
11 |
-from pydantic import BaseModel, Field, computed_field |
|
12 |
+from pydantic import BaseModel |
|
12 | 13 |
|
13 |
-from speaches.api_models import Model |
|
14 |
+from speaches.api_types import Model, Voice |
|
14 | 15 |
|
15 | 16 |
logger = logging.getLogger(__name__) |
16 | 17 |
|
... | ... | @@ -18,8 +19,13 @@ |
18 | 19 |
TASK_NAME = "automatic-speech-recognition" |
19 | 20 |
|
20 | 21 |
|
22 |
+def list_local_model_ids() -> list[str]: |
|
23 |
+ model_dirs = list(Path(HF_HUB_CACHE).glob("models--*")) |
|
24 |
+ return [model_id_from_path(model_dir) for model_dir in model_dirs] |
|
25 |
+ |
|
26 |
+ |
|
21 | 27 |
def does_local_model_exist(model_id: str) -> bool: |
22 |
- return any(model_id == model.repo_id for model, _ in list_local_whisper_models()) |
|
28 |
+ return model_id in list_local_model_ids() |
|
23 | 29 |
|
24 | 30 |
|
25 | 31 |
def list_whisper_models() -> Generator[Model, None, None]: |
... | ... | @@ -46,9 +52,9 @@ |
46 | 52 |
yield transformed_model |
47 | 53 |
|
48 | 54 |
|
49 |
-def list_local_whisper_models() -> ( |
|
50 |
- Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None] |
|
51 |
-): |
|
55 |
+def list_local_whisper_models() -> Generator[ |
|
56 |
+ tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None |
|
57 |
+]: |
|
52 | 58 |
hf_cache = huggingface_hub.scan_cache_dir() |
53 | 59 |
hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"] |
54 | 60 |
for model in hf_models: |
... | ... | @@ -69,6 +75,14 @@ |
69 | 75 |
and TASK_NAME in model_card_data.tags |
70 | 76 |
): |
71 | 77 |
yield model, model_card_data |
78 |
+ |
|
79 |
+ |
|
80 |
+def model_id_from_path(repo_path: Path) -> str: |
|
81 |
+ repo_type, repo_id = repo_path.name.split("--", maxsplit=1) |
|
82 |
+ repo_type = repo_type[:-1] # "models" -> "model" |
|
83 |
+ assert repo_type == "model" |
|
84 |
+ repo_id = repo_id.replace("--", "/") # google--fleurs -> "google/fleurs" |
|
85 |
+ return repo_id |
|
72 | 86 |
|
73 | 87 |
|
74 | 88 |
def get_whisper_models() -> Generator[Model, None, None]: |
... | ... | @@ -102,44 +116,6 @@ |
102 | 116 |
"medium": 22050, |
103 | 117 |
"high": 22050, |
104 | 118 |
} |
105 |
- |
|
106 |
- |
|
107 |
-class PiperModel(BaseModel): |
|
108 |
- """Similar structure to the GET /v1/models response but with extra fields.""" |
|
109 |
- |
|
110 |
- object: Literal["model"] = "model" |
|
111 |
- created: int |
|
112 |
- owned_by: Literal["rhasspy"] = "rhasspy" |
|
113 |
- model_path: Path = Field( |
|
114 |
- examples=[ |
|
115 |
- "/home/nixos/.cache/huggingface/hub/models--rhasspy--piper-voices/snapshots/3d796cc2f2c884b3517c527507e084f7bb245aea/en/en_US/amy/medium/en_US-amy-medium.onnx" |
|
116 |
- ] |
|
117 |
- ) |
|
118 |
- |
|
119 |
- @computed_field(examples=["rhasspy/piper-voices/en_US-amy-medium"]) |
|
120 |
- @cached_property |
|
121 |
- def id(self) -> str: |
|
122 |
- return f"rhasspy/piper-voices/{self.model_path.name.removesuffix(".onnx")}" |
|
123 |
- |
|
124 |
- @computed_field(examples=["rhasspy/piper-voices/en_US-amy-medium"]) |
|
125 |
- @cached_property |
|
126 |
- def voice(self) -> str: |
|
127 |
- return self.model_path.name.removesuffix(".onnx") |
|
128 |
- |
|
129 |
- @computed_field |
|
130 |
- @cached_property |
|
131 |
- def config_path(self) -> Path: |
|
132 |
- return Path(str(self.model_path) + ".json") |
|
133 |
- |
|
134 |
- @computed_field |
|
135 |
- @cached_property |
|
136 |
- def quality(self) -> PiperVoiceQuality: |
|
137 |
- return self.id.split("-")[-1] # pyright: ignore[reportReturnType] |
|
138 |
- |
|
139 |
- @computed_field |
|
140 |
- @cached_property |
|
141 |
- def sample_rate(self) -> int: |
|
142 |
- return PIPER_VOICE_QUALITY_SAMPLE_RATE_MAP[self.quality] |
|
143 | 119 |
|
144 | 120 |
|
145 | 121 |
def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None: |
... | ... | @@ -186,12 +162,19 @@ |
186 | 162 |
yield from list(snapshots_path.glob(glob_pattern)) |
187 | 163 |
|
188 | 164 |
|
189 |
-def list_piper_models() -> Generator[PiperModel, None, None]: |
|
190 |
- model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx") |
|
165 |
+def list_piper_models() -> Generator[Voice, None, None]: |
|
166 |
+ model_id = "rhasspy/piper-voices" |
|
167 |
+ model_weights_files = list_model_files(model_id, glob_pattern="**/*.onnx") |
|
191 | 168 |
for model_weights_file in model_weights_files: |
192 |
- yield PiperModel( |
|
169 |
+ yield Voice( |
|
193 | 170 |
created=int(model_weights_file.stat().st_mtime), |
194 | 171 |
model_path=model_weights_file, |
172 |
+ voice_id=model_weights_file.name.removesuffix(".onnx"), |
|
173 |
+ model_id=model_id, |
|
174 |
+ owned_by=model_id.split("/")[0], |
|
175 |
+ sample_rate=PIPER_VOICE_QUALITY_SAMPLE_RATE_MAP[ |
|
176 |
+ model_weights_file.name.removesuffix(".onnx").split("-")[-1] |
|
177 |
+ ], # pyright: ignore[reportArgumentType] |
|
195 | 178 |
) |
196 | 179 |
|
197 | 180 |
|
... | ... | @@ -230,3 +213,33 @@ |
230 | 213 |
if model_config_file is None: |
231 | 214 |
raise FileNotFoundError(f"Could not find config file for '{voice}' voice") |
232 | 215 |
return PiperVoiceConfig.model_validate_json(model_config_file.read_text()) |
216 |
+ |
|
217 |
+ |
|
218 |
+def get_kokoro_model_path() -> Path: |
|
219 |
+ file_name = "kokoro-v0_19.onnx" |
|
220 |
+ onnx_files = list(list_model_files("hexgrad/Kokoro-82M", glob_pattern=f"**/{file_name}")) |
|
221 |
+ if len(onnx_files) == 0: |
|
222 |
+ raise ValueError(f"Could not find {file_name} file for 'hexgrad/Kokoro-82M' model") |
|
223 |
+ return onnx_files[0] |
|
224 |
+ |
|
225 |
+ |
|
226 |
+def download_kokoro_model() -> None: |
|
227 |
+ model_id = "hexgrad/Kokoro-82M" |
|
228 |
+ model_repo_path = Path( |
|
229 |
+ huggingface_hub.snapshot_download(model_id, repo_type="model", allow_patterns="**/kokoro-v0_19.onnx") |
|
230 |
+ ) |
|
231 |
+ # HACK |
|
232 |
+ res = httpx.get( |
|
233 |
+ "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files/voices.json", follow_redirects=True |
|
234 |
+ ).raise_for_status() |
|
235 |
+ voices_path = model_repo_path / "voices.json" |
|
236 |
+ voices_path.write_bytes(res.content) |
|
237 |
+ |
|
238 |
+ |
|
239 |
+# alternative implementation that uses `huggingface_hub.scan_cache_dir`. Slightly cleaner but much slower |
|
240 |
+# def list_local_model_ids() -> list[str]: |
|
241 |
+# start = time.perf_counter() |
|
242 |
+# hf_cache = huggingface_hub.scan_cache_dir() |
|
243 |
+# logger.debug(f"Scanned HuggingFace cache in {time.perf_counter() - start:.2f} seconds") |
|
244 |
+# hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"] |
|
245 |
+# return [model.repo_id for model in hf_models] |
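
The HF hub cache names repo directories `models--{org}--{name}`; `model_id_from_path` reverses that convention. A quick sketch:

```python
from pathlib import Path

assert model_id_from_path(Path("models--hexgrad--Kokoro-82M")) == "hexgrad/Kokoro-82M"
assert model_id_from_path(Path("models--rhasspy--piper-voices")) == "rhasspy/piper-voices"
```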
+++ src/speaches/kokoro_utils.py
... | ... | @@ -0,0 +1,51 @@ |
1 | +from collections.abc import AsyncGenerator | |
2 | +import logging | |
3 | +import time | |
4 | +from typing import Literal | |
5 | + | |
6 | +from kokoro_onnx import Kokoro | |
7 | +import numpy as np | |
8 | + | |
9 | +from speaches.audio import resample_audio | |
10 | + | |
11 | +logger = logging.getLogger(__name__) | |
12 | + | |
13 | +SAMPLE_RATE = 24000 # the default sample rate for Kokoro | |
14 | +Language = Literal["en-us", "en-gb", "fr-fr", "ja", "ko", "cmn"] | |
15 | +LANGUAGES: list[Language] = ["en-us", "en-gb", "fr-fr", "ja", "ko", "cmn"] | |
16 | + | |
17 | +VOICE_IDS = [ | |
18 | + "af", # Default voice is a 50-50 mix of Bella & Sarah | |
19 | + "af_bella", | |
20 | + "af_sarah", | |
21 | + "am_adam", | |
22 | + "am_michael", | |
23 | + "bf_emma", | |
24 | + "bf_isabella", | |
25 | + "bm_george", | |
26 | + "bm_lewis", | |
27 | + "af_nicole", | |
28 | + "af_sky", | |
29 | +] | |
30 | + | |
31 | + | |
32 | +async def generate_audio( | |
33 | + kokoro_tts: Kokoro, | |
34 | + text: str, | |
35 | + voice: str, | |
36 | + *, | |
37 | + language: Language = "en-us", | |
38 | + speed: float = 1.0, | |
39 | + sample_rate: int | None = None, | |
40 | +) -> AsyncGenerator[bytes, None]: | |
41 | + if sample_rate is None: | |
42 | + sample_rate = SAMPLE_RATE | |
43 | + start = time.perf_counter() | |
44 | + async for audio_data, _ in kokoro_tts.create_stream(text, voice, lang=language, speed=speed): | |
45 | + assert isinstance(audio_data, np.ndarray) and audio_data.dtype == np.float32 and isinstance(sample_rate, int) | |
46 | + normalized_audio_data = (audio_data * np.iinfo(np.int16).max).astype(np.int16) | |
47 | + audio_bytes = normalized_audio_data.tobytes() | |
48 | + if sample_rate != SAMPLE_RATE: | |
49 | + audio_bytes = resample_audio(audio_bytes, SAMPLE_RATE, sample_rate) | |
50 | + yield audio_bytes | |
51 | + logger.info(f"Generated audio for {len(text)} characters in {time.perf_counter() - start}s") |
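
For reference, a minimal consumer of the async generator (a sketch: `kokoro` is assumed to be an already-loaded `Kokoro` instance, which `KokoroModelManager` provides in practice):

```python
async def synthesize_to_pcm(kokoro: Kokoro, text: str) -> bytes:
    # Chunks are raw 16-bit little-endian PCM at 24000 Hz (SAMPLE_RATE) unless resampled.
    return b"".join([chunk async for chunk in generate_audio(kokoro, text, "af")])
```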
--- src/speaches/main.py
+++ src/speaches/main.py
... | ... | @@ -18,6 +18,9 @@ |
18 | 18 |
from speaches.routers.models import ( |
19 | 19 |
router as models_router, |
20 | 20 |
) |
21 |
+from speaches.routers.speech import ( |
|
22 |
+ router as speech_router, |
|
23 |
+) |
|
21 | 24 |
from speaches.routers.stt import ( |
22 | 25 |
router as stt_router, |
23 | 26 |
) |
... | ... | @@ -47,12 +50,7 @@ |
47 | 50 |
logger.debug(f"Config: {config}") |
48 | 51 |
|
49 | 52 |
if platform.machine() == "x86_64": |
50 |
- from speaches.routers.speech import ( |
|
51 |
- router as speech_router, |
|
52 |
- ) |
|
53 |
- else: |
|
54 |
- logger.warning("`/v1/audio/speech` is only supported on x86_64 machines") |
|
55 |
- speech_router = None |
|
53 |
+ logger.warning("`POST /v1/audio/speech` with `model=rhasspy/piper-voices` is only supported on x86_64 machines") |
|
56 | 54 |
|
57 | 55 |
model_manager = get_model_manager() # HACK |
58 | 56 |
|
... | ... | @@ -71,8 +69,7 @@ |
71 | 69 |
app.include_router(stt_router) |
72 | 70 |
app.include_router(models_router) |
73 | 71 |
app.include_router(misc_router) |
74 |
- if speech_router is not None: |
|
75 |
- app.include_router(speech_router) |
|
72 |
+ app.include_router(speech_router) |
|
76 | 73 |
|
77 | 74 |
if config.allow_origins is not None: |
78 | 75 |
app.add_middleware( |
--- src/speaches/model_manager.py
+++ src/speaches/model_manager.py
... | ... | @@ -2,14 +2,18 @@ |
2 | 2 |
|
3 | 3 |
from collections import OrderedDict |
4 | 4 |
import gc |
5 |
+import json |
|
5 | 6 |
import logging |
7 |
+from pathlib import Path |
|
6 | 8 |
import threading |
7 | 9 |
import time |
8 | 10 |
from typing import TYPE_CHECKING |
9 | 11 |
|
10 | 12 |
from faster_whisper import WhisperModel |
13 |
+from kokoro_onnx import Kokoro |
|
14 |
+from onnxruntime import InferenceSession |
|
11 | 15 |
|
12 |
-from speaches.hf_utils import get_piper_voice_model_file |
|
16 |
+from speaches.hf_utils import get_kokoro_model_path, get_piper_voice_model_file |
|
13 | 17 |
|
14 | 18 |
if TYPE_CHECKING: |
15 | 19 |
from collections.abc import Callable |
... | ... | @@ -142,6 +146,9 @@ |
142 | 146 |
return self.loaded_models[model_name] |
143 | 147 |
|
144 | 148 |
|
149 |
+ONNX_PROVIDERS = ["CUDAExecutionProvider", "CPUExecutionProvider"] |
|
150 |
+ |
|
151 |
+ |
|
145 | 152 |
class PiperModelManager: |
146 | 153 |
def __init__(self, ttl: int) -> None: |
147 | 154 |
self.ttl = ttl |
... | ... | @@ -149,10 +156,13 @@ |
149 | 156 |
self._lock = threading.Lock() |
150 | 157 |
|
151 | 158 |
def _load_fn(self, model_id: str) -> PiperVoice: |
152 |
- from piper.voice import PiperVoice |
|
159 |
+ from piper.voice import PiperConfig, PiperVoice |
|
153 | 160 |
|
154 | 161 |
model_path = get_piper_voice_model_file(model_id) |
155 |
- return PiperVoice.load(model_path) |
|
162 |
+ inf_sess = InferenceSession(model_path, providers=ONNX_PROVIDERS) |
|
163 |
+ config_path = Path(str(model_path) + ".json") |
|
164 |
+ conf = PiperConfig.from_dict(json.loads(config_path.read_text())) |
|
165 |
+ return PiperVoice(session=inf_sess, config=conf) |
|
156 | 166 |
|
157 | 167 |
def _handle_model_unload(self, model_name: str) -> None: |
158 | 168 |
with self._lock: |
... | ... | @@ -180,3 +190,42 @@ |
180 | 190 |
unload_fn=self._handle_model_unload, |
181 | 191 |
) |
182 | 192 |
return self.loaded_models[model_name] |
193 |
+ |
|
194 |
+ |
|
195 |
+class KokoroModelManager: |
|
196 |
+ def __init__(self, ttl: int) -> None: |
|
197 |
+ self.ttl = ttl |
|
198 |
+ self.loaded_models: OrderedDict[str, SelfDisposingModel[Kokoro]] = OrderedDict() |
|
199 |
+ self._lock = threading.Lock() |
|
200 |
+ |
|
201 |
+ # TODO |
|
202 |
+ def _load_fn(self, _model_id: str) -> Kokoro: |
|
203 |
+ model_path = get_kokoro_model_path() |
|
204 |
+ voices_path = model_path.parent / "voices.json" |
|
205 |
+ inf_sess = InferenceSession(model_path, providers=ONNX_PROVIDERS) |
|
206 |
+ return Kokoro.from_session(inf_sess, str(voices_path)) |
|
207 |
+ |
|
208 |
+ def _handle_model_unload(self, model_name: str) -> None: |
|
209 |
+ with self._lock: |
|
210 |
+ if model_name in self.loaded_models: |
|
211 |
+ del self.loaded_models[model_name] |
|
212 |
+ |
|
213 |
+ def unload_model(self, model_name: str) -> None: |
|
214 |
+ with self._lock: |
|
215 |
+ model = self.loaded_models.get(model_name) |
|
216 |
+ if model is None: |
|
217 |
+ raise KeyError(f"Model {model_name} not found") |
|
218 |
+ self.loaded_models[model_name].unload() |
|
219 |
+ |
|
220 |
+ def load_model(self, model_name: str) -> SelfDisposingModel[Kokoro]: |
|
221 |
+ with self._lock: |
|
222 |
+ if model_name in self.loaded_models: |
|
223 |
+ logger.debug(f"{model_name} model already loaded") |
|
224 |
+ return self.loaded_models[model_name] |
|
225 |
+ self.loaded_models[model_name] = SelfDisposingModel[Kokoro]( |
|
226 |
+ model_name, |
|
227 |
+ load_fn=lambda: self._load_fn(model_name), |
|
228 |
+ ttl=self.ttl, |
|
229 |
+ unload_fn=self._handle_model_unload, |
|
230 |
+ ) |
|
231 |
+ return self.loaded_models[model_name] |
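
Usage mirrors `PiperModelManager`: `load_model` returns a `SelfDisposingModel` that is used as a context manager, which is how the `synthesize` route below consumes it. A sketch (`ttl=300` is an arbitrary value):

```python
manager = KokoroModelManager(ttl=300)  # unload 300 seconds after last use
with manager.load_model("af") as kokoro:
    ...  # run inference while the model is guaranteed to stay loaded
```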
+++ src/speaches/piper_utils.py
... | ... | @@ -0,0 +1,23 @@ |
1 | +from collections.abc import Generator | |
2 | +import logging | |
3 | +import time | |
4 | + | |
5 | +from piper.voice import PiperVoice | |
6 | + | |
7 | +from speaches.audio import resample_audio | |
8 | + | |
9 | +logger = logging.getLogger(__name__) | |
10 | + | |
11 | + | |
12 | +# TODO: async generator https://github.com/mikeshardmind/async-utils/blob/354b93a276572aa54c04212ceca5ac38fedf34ab/src/async_utils/gen_transform.py#L147 | |
13 | +def generate_audio( | |
14 | + piper_tts: PiperVoice, text: str, *, speed: float = 1.0, sample_rate: int | None = None | |
15 | +) -> Generator[bytes, None, None]: | |
16 | + if sample_rate is None: | |
17 | + sample_rate = piper_tts.config.sample_rate | |
18 | + start = time.perf_counter() | |
19 | + for audio_bytes in piper_tts.synthesize_stream_raw(text, length_scale=1.0 / speed): | |
20 | + if sample_rate != piper_tts.config.sample_rate: | |
21 | + audio_bytes = resample_audio(audio_bytes, piper_tts.config.sample_rate, sample_rate) # noqa: PLW2901 | |
22 | + yield audio_bytes | |
23 | + logger.info(f"Generated audio for {len(text)} characters in {time.perf_counter() - start}s") |
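
As with the Kokoro helper, a minimal consumer sketch (`synthesize_to_pcm` is a hypothetical wrapper; the loaded `PiperVoice` comes from `PiperModelManager` in practice):

```python
def synthesize_to_pcm(piper_voice: PiperVoice, text: str) -> bytes:
    # Collect the generator's raw PCM chunks, resampled to 16 kHz for illustration.
    return b"".join(generate_audio(piper_voice, text, speed=1.0, sample_rate=16000))
```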
--- src/speaches/routers/misc.py
+++ src/speaches/routers/misc.py
... | ... | @@ -18,15 +18,19 @@ |
18 | 18 |
return Response(status_code=200, content="OK") |
19 | 19 |
|
20 | 20 |
|
21 |
-@router.post("/api/pull/{model_name:path}", tags=["experimental"], summary="Download a model from Hugging Face.") |
|
22 |
-def pull_model(model_name: str) -> Response: |
|
23 |
- if hf_utils.does_local_model_exist(model_name): |
|
24 |
- return Response(status_code=200, content="Model already exists") |
|
21 |
+@router.post( |
|
22 |
+ "/api/pull/{model_id:path}", |
|
23 |
+ tags=["experimental"], |
|
24 |
+ summary="Download a model from Hugging Face if it doesn't exist locally.", |
|
25 |
+) |
|
26 |
+def pull_model(model_id: str) -> Response: |
|
27 |
+ if hf_utils.does_local_model_exist(model_id): |
|
28 |
+ return Response(status_code=200, content=f"Model {model_id} already exists") |
|
25 | 29 |
try: |
26 |
- huggingface_hub.snapshot_download(model_name, repo_type="model") |
|
30 |
+ huggingface_hub.snapshot_download(model_id, repo_type="model") |
|
27 | 31 |
except RepositoryNotFoundError as e: |
28 | 32 |
return Response(status_code=404, content=str(e)) |
29 |
- return Response(status_code=201, content="Model downloaded") |
|
33 |
+ return Response(status_code=201, content=f"Model {model_id} downloaded") |
|
30 | 34 |
|
31 | 35 |
|
32 | 36 |
@router.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.") |
... | ... | @@ -36,19 +40,19 @@ |
36 | 40 |
return {"models": list(model_manager.loaded_models.keys())} |
37 | 41 |
|
38 | 42 |
|
39 |
-@router.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.") |
|
40 |
-def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response: |
|
41 |
- if model_name in model_manager.loaded_models: |
|
43 |
+@router.post("/api/ps/{model_id:path}", tags=["experimental"], summary="Load a model into memory.") |
|
44 |
+def load_model_route(model_manager: ModelManagerDependency, model_id: str) -> Response: |
|
45 |
+ if model_id in model_manager.loaded_models: |
|
42 | 46 |
return Response(status_code=409, content="Model already loaded") |
43 |
- with model_manager.load_model(model_name): |
|
47 |
+ with model_manager.load_model(model_id): |
|
44 | 48 |
pass |
45 | 49 |
return Response(status_code=201) |
46 | 50 |
|
47 | 51 |
|
48 |
-@router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.") |
|
49 |
-def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response: |
|
52 |
+@router.delete("/api/ps/{model_id:path}", tags=["experimental"], summary="Unload a model from memory.") |
|
53 |
+def stop_running_model(model_manager: ModelManagerDependency, model_id: str) -> Response: |
|
50 | 54 |
try: |
51 |
- model_manager.unload_model(model_name) |
|
55 |
+ model_manager.unload_model(model_id) |
|
52 | 56 |
return Response(status_code=204) |
53 | 57 |
except (KeyError, ValueError) as e: |
54 | 58 |
match e: |
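
The renamed routes keep the same response shapes. A sketch of exercising them (server assumed at `localhost:8000`):

```python
import httpx

base = "http://localhost:8000"
# 200 if the model is already cached, 201 once downloaded, 404 if the repo doesn't exist.
print(httpx.post(f"{base}/api/pull/hexgrad/Kokoro-82M", timeout=None).status_code)
# Models currently held in memory.
print(httpx.get(f"{base}/api/ps").json())
```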
--- src/speaches/routers/models.py
+++ src/speaches/routers/models.py
... | ... | @@ -9,7 +9,7 @@ |
9 | 9 |
) |
10 | 10 |
import huggingface_hub |
11 | 11 |
|
12 |
-from speaches.api_models import ( |
|
12 |
+from speaches.api_types import ( |
|
13 | 13 |
ListModelsResponse, |
14 | 14 |
Model, |
15 | 15 |
) |
--- src/speaches/routers/speech.py
+++ src/speaches/routers/speech.py
... | ... | @@ -1,28 +1,24 @@ |
1 |
-from collections.abc import Generator |
|
2 |
-import io |
|
3 | 1 |
import logging |
4 |
-import time |
|
5 | 2 |
from typing import Annotated, Literal, Self |
6 | 3 |
|
7 | 4 |
from fastapi import APIRouter |
8 | 5 |
from fastapi.responses import StreamingResponse |
9 |
-import numpy as np |
|
10 |
-from piper.voice import PiperVoice |
|
11 |
-from pydantic import BaseModel, BeforeValidator, Field, ValidationError, model_validator |
|
12 |
-import soundfile as sf |
|
6 |
+from pydantic import BaseModel, BeforeValidator, Field, model_validator |
|
13 | 7 |
|
14 |
-from speaches.dependencies import PiperModelManagerDependency |
|
8 |
+from speaches import kokoro_utils |
|
9 |
+from speaches.api_types import Voice |
|
10 |
+from speaches.audio import convert_audio_format |
|
11 |
+from speaches.dependencies import KokoroModelManagerDependency, PiperModelManagerDependency |
|
15 | 12 |
from speaches.hf_utils import ( |
16 |
- PiperModel, |
|
13 |
+ get_kokoro_model_path, |
|
17 | 14 |
list_piper_models, |
18 | 15 |
read_piper_voices_config, |
19 | 16 |
) |
20 | 17 |
|
21 |
-DEFAULT_MODEL = "piper" |
|
18 |
+DEFAULT_MODEL_ID = "hexgrad/Kokoro-82M" |
|
22 | 19 |
# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-response_format |
23 | 20 |
DEFAULT_RESPONSE_FORMAT = "mp3" |
24 |
-DEFAULT_VOICE = "en_US-amy-medium" # TODO: make configurable |
|
25 |
-DEFAULT_VOICE_SAMPLE_RATE = 22050 # NOTE: Dependant on the voice |
|
21 |
+DEFAULT_VOICE_ID = "af" # TODO: make configurable |
|
26 | 22 |
|
27 | 23 |
# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-model |
28 | 24 |
# https://platform.openai.com/docs/models/tts |
... | ... | @@ -46,82 +42,35 @@ |
46 | 42 |
router = APIRouter(tags=["speech-to-text"]) |
47 | 43 |
|
48 | 44 |
|
49 |
-# aip 'Write a function `resample_audio` which would take in RAW PCM 16-bit signed, little-endian audio data represented as bytes (`audio_bytes`) and resample it (either downsample or upsample) from `sample_rate` to `target_sample_rate` using numpy' # noqa: E501 |
|
50 |
-def resample_audio(audio_bytes: bytes, sample_rate: int, target_sample_rate: int) -> bytes: |
|
51 |
- audio_data = np.frombuffer(audio_bytes, dtype=np.int16) |
|
52 |
- duration = len(audio_data) / sample_rate |
|
53 |
- target_length = int(duration * target_sample_rate) |
|
54 |
- resampled_data = np.interp( |
|
55 |
- np.linspace(0, len(audio_data), target_length, endpoint=False), np.arange(len(audio_data)), audio_data |
|
56 |
- ) |
|
57 |
- return resampled_data.astype(np.int16).tobytes() |
|
58 |
- |
|
59 |
- |
|
60 |
-def generate_audio( |
|
61 |
- piper_tts: PiperVoice, text: str, *, speed: float = 1.0, sample_rate: int | None = None |
|
62 |
-) -> Generator[bytes, None, None]: |
|
63 |
- if sample_rate is None: |
|
64 |
- sample_rate = piper_tts.config.sample_rate |
|
65 |
- start = time.perf_counter() |
|
66 |
- for audio_bytes in piper_tts.synthesize_stream_raw(text, length_scale=1.0 / speed): |
|
67 |
- if sample_rate != piper_tts.config.sample_rate: |
|
68 |
- audio_bytes = resample_audio(audio_bytes, piper_tts.config.sample_rate, sample_rate) # noqa: PLW2901 |
|
69 |
- yield audio_bytes |
|
70 |
- logger.info(f"Generated audio for {len(text)} characters in {time.perf_counter() - start}s") |
|
71 |
- |
|
72 |
- |
|
73 |
-def convert_audio_format( |
|
74 |
- audio_bytes: bytes, |
|
75 |
- sample_rate: int, |
|
76 |
- audio_format: ResponseFormat, |
|
77 |
- format: str = "RAW", # noqa: A002 |
|
78 |
- channels: int = 1, |
|
79 |
- subtype: str = "PCM_16", |
|
80 |
- endian: str = "LITTLE", |
|
81 |
-) -> bytes: |
|
82 |
- # NOTE: the default dtype is float64. Should something else be used? Would that improve performance? |
|
83 |
- data, _ = sf.read( |
|
84 |
- io.BytesIO(audio_bytes), |
|
85 |
- samplerate=sample_rate, |
|
86 |
- format=format, |
|
87 |
- channels=channels, |
|
88 |
- subtype=subtype, |
|
89 |
- endian=endian, |
|
90 |
- ) |
|
91 |
- converted_audio_bytes_buffer = io.BytesIO() |
|
92 |
- sf.write(converted_audio_bytes_buffer, data, samplerate=sample_rate, format=audio_format) |
|
93 |
- return converted_audio_bytes_buffer.getvalue() |
|
94 |
- |
|
95 |
- |
|
96 | 45 |
def handle_openai_supported_model_ids(model_id: str) -> str: |
97 | 46 |
if model_id in OPENAI_SUPPORTED_SPEECH_MODEL: |
98 |
- logger.warning(f"{model_id} is not a valid model name. Using '{DEFAULT_MODEL}' instead.") |
|
99 |
- return DEFAULT_MODEL |
|
47 |
+ logger.warning(f"{model_id} is not a valid model name. Using '{DEFAULT_MODEL_ID}' instead.") |
|
48 |
+ return DEFAULT_MODEL_ID |
|
100 | 49 |
return model_id |
101 | 50 |
|
102 | 51 |
|
103 | 52 |
ModelId = Annotated[ |
104 |
- Literal["piper"], |
|
53 |
+ Literal["hexgrad/Kokoro-82M", "rhasspy/piper-voices"], |
|
105 | 54 |
BeforeValidator(handle_openai_supported_model_ids), |
106 | 55 |
Field( |
107 |
- description=f"The ID of the model. The only supported model is '{DEFAULT_MODEL}'.", |
|
108 |
- examples=[DEFAULT_MODEL], |
|
56 |
+ description="The ID of the model", |
|
57 |
+ examples=["hexgrad/Kokoro-82M", "rhasspy/piper-voices"], |
|
109 | 58 |
), |
110 | 59 |
] |
111 | 60 |
|
112 | 61 |
|
113 |
-def handle_openai_supported_voices(voice: str) -> str: |
|
114 |
- if voice in OPENAI_SUPPORTED_SPEECH_VOICE_NAMES: |
|
115 |
- logger.warning(f"{voice} is not a valid voice name. Using '{DEFAULT_VOICE}' instead.") |
|
116 |
- return DEFAULT_VOICE |
|
117 |
- return voice |
|
62 |
+def handle_openai_supported_voices(voice_id: str) -> str: |
|
63 |
+ if voice_id in OPENAI_SUPPORTED_SPEECH_VOICE_NAMES: |
|
64 |
+ logger.warning(f"{voice_id} is not a valid voice id. Using '{DEFAULT_VOICE_ID}' instead.") |
|
65 |
+ return DEFAULT_VOICE_ID |
|
66 |
+ return voice_id |
|
118 | 67 |
|
119 | 68 |
|
120 |
-Voice = Annotated[str, BeforeValidator(handle_openai_supported_voices)] # TODO: description and examples |
|
69 |
+VoiceId = Annotated[str, BeforeValidator(handle_openai_supported_voices)] # TODO: description and examples |
|
121 | 70 |
|
122 | 71 |
|
123 | 72 |
class CreateSpeechRequestBody(BaseModel): |
124 |
- model: ModelId = DEFAULT_MODEL |
|
73 |
+ model: ModelId = DEFAULT_MODEL_ID |
|
125 | 74 |
input: str = Field( |
126 | 75 |
..., |
127 | 76 |
description="The text to generate audio for. ", |
... | ... | @@ -129,55 +78,114 @@ |
129 | 78 |
"A rainbow is an optical phenomenon caused by refraction, internal reflection and dispersion of light in water droplets resulting in a continuous spectrum of light appearing in the sky. The rainbow takes the form of a multicoloured circular arc. Rainbows caused by sunlight always appear in the section of sky directly opposite the Sun. Rainbows can be caused by many forms of airborne water. These include not only rain, but also mist, spray, and airborne dew." # noqa: E501 |
130 | 79 |
], |
131 | 80 |
) |
132 |
- voice: Voice = DEFAULT_VOICE |
|
81 |
+ voice: VoiceId = DEFAULT_VOICE_ID |
|
133 | 82 |
""" |
134 |
-The last part of the voice name is the quality (x_low, low, medium, high). |
|
83 |
+For 'rhasspy/piper-voices' voices the last part of the voice name is the quality (x_low, low, medium, high). |
|
135 | 84 |
Each quality has a different default sample rate: |
136 | 85 |
- x_low: 16000 Hz |
137 | 86 |
- low: 16000 Hz |
138 | 87 |
- medium: 22050 Hz |
139 | 88 |
- high: 22050 Hz |
140 | 89 |
""" |
90 |
+ language: kokoro_utils.Language | None = None |
|
91 |
+ """ |
|
92 |
+ Only used for 'hexgrad/Kokoro-82M' models. The language of the text to generate audio for. |
|
93 |
+ """ |
|
141 | 94 |
response_format: ResponseFormat = Field( |
142 | 95 |
DEFAULT_RESPONSE_FORMAT, |
143 |
- description=f"The format to audio in. Supported formats are {", ".join(SUPPORTED_RESPONSE_FORMATS)}. {", ".join(UNSUPORTED_RESPONSE_FORMATS)} are not supported", # noqa: E501 |
|
96 |
+ description=f"The format to audio in. Supported formats are {', '.join(SUPPORTED_RESPONSE_FORMATS)}. {', '.join(UNSUPORTED_RESPONSE_FORMATS)} are not supported", # noqa: E501 |
|
144 | 97 |
examples=list(SUPPORTED_RESPONSE_FORMATS), |
145 | 98 |
) |
146 | 99 |
# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice |
147 |
- speed: float = Field(1.0, ge=0.25, le=4.0) |
|
148 |
- """The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default.""" |
|
100 |
+ speed: float = Field(1.0) |
|
101 |
+ """The speed of the generated audio. 1.0 is the default. |
|
102 |
+ For 'hexgrad/Kokoro-82M' models, the speed can be set to 0.5 to 2.0. |
|
103 |
+ For 'rhasspy/piper-voices' models, the speed can be set to 0.25 to 4.0. |
|
104 |
+ """ |
|
149 | 105 |
sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE) |
150 |
- """Desired sample rate to convert the generated audio to. If not provided, the model's default sample rate will be used.""" # noqa: E501 |
|
151 |
- # TODO: document default sample rate for each voice quality |
|
106 |
+ """Desired sample rate to convert the generated audio to. If not provided, the model's default sample rate will be used. |
|
107 |
+ For 'hexgrad/Kokoro-82M' models, the default sample rate is 24000 Hz. |
|
108 |
+ For 'rhasspy/piper-voices' models, the sample rate differs based on the voice quality (see `voice`). |
|
109 |
+ """ # noqa: E501 |
|
152 | 110 |
|
153 |
- # TODO: move into `Voice` |
|
154 | 111 |
@model_validator(mode="after") |
155 | 112 |
def verify_voice_is_valid(self) -> Self: |
156 |
- valid_voices = read_piper_voices_config() |
|
157 |
- if self.voice not in valid_voices: |
|
158 |
- raise ValidationError(f"Voice '{self.voice}' is not supported. Supported voices: {valid_voices.keys()}") |
|
113 |
+ if self.model == "hexgrad/Kokoro-82M": |
|
114 |
+ assert self.voice in kokoro_utils.VOICE_IDS |
|
115 |
+ elif self.model == "rhasspy/piper-voices": |
|
116 |
+ assert self.voice in read_piper_voices_config() |
|
117 |
+ return self |
|
118 |
+ |
|
119 |
+ @model_validator(mode="after") |
|
120 |
+ def validate_speed(self) -> Self: |
|
121 |
+ if self.model == "hexgrad/Kokoro-82M": |
|
122 |
+ assert 0.5 <= self.speed <= 2.0 |
|
123 |
+ if self.model == "rhasspy/piper-voices": |
|
124 |
+ assert 0.25 <= self.speed <= 4.0 |
|
159 | 125 |
return self |
160 | 126 |
|
161 | 127 |
|
162 | 128 |
# https://platform.openai.com/docs/api-reference/audio/createSpeech |
163 | 129 |
@router.post("/v1/audio/speech") |
164 |
-def synthesize( |
|
130 |
+async def synthesize( |
|
165 | 131 |
piper_model_manager: PiperModelManagerDependency, |
132 |
+ kokoro_model_manager: KokoroModelManagerDependency, |
|
166 | 133 |
body: CreateSpeechRequestBody, |
167 | 134 |
) -> StreamingResponse: |
168 |
- with piper_model_manager.load_model(body.voice) as piper_tts: |
|
169 |
- audio_generator = generate_audio(piper_tts, body.input, speed=body.speed, sample_rate=body.sample_rate) |
|
170 |
- if body.response_format != "pcm": |
|
171 |
- audio_generator = ( |
|
172 |
- convert_audio_format( |
|
173 |
- audio_bytes, body.sample_rate or piper_tts.config.sample_rate, body.response_format |
|
135 |
+ match body.model: |
|
136 |
+ case "hexgrad/Kokoro-82M": |
|
137 |
+ # TODO: download the `voices.json` file |
|
138 |
+ with kokoro_model_manager.load_model(body.voice) as tts: |
|
139 |
+ audio_generator = kokoro_utils.generate_audio( |
|
140 |
+ tts, |
|
141 |
+ body.input, |
|
142 |
+ body.voice, |
|
143 |
+ language=body.language or "en-us", |
|
144 |
+ speed=body.speed, |
|
145 |
+ sample_rate=body.sample_rate, |
|
174 | 146 |
) |
175 |
- for audio_bytes in audio_generator |
|
176 |
- ) |
|
147 |
+ if body.response_format != "pcm": |
|
148 |
+ audio_generator = ( |
|
149 |
+ convert_audio_format( |
|
150 |
+ audio_bytes, body.sample_rate or kokoro_utils.SAMPLE_RATE, body.response_format |
|
151 |
+ ) |
|
152 |
+ async for audio_bytes in audio_generator |
|
153 |
+ ) |
|
154 |
+ return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}") |
|
155 |
+ case "rhasspy/piper-voices": |
|
156 |
+ from speaches import piper_utils |
|
177 | 157 |
|
178 |
- return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}") |
|
158 |
+ with piper_model_manager.load_model(body.voice) as piper_tts: |
|
159 |
+ # TODO: async generator |
|
160 |
+ audio_generator = piper_utils.generate_audio( |
|
161 |
+ piper_tts, body.input, speed=body.speed, sample_rate=body.sample_rate |
|
162 |
+ ) |
|
163 |
+ if body.response_format != "pcm": |
|
164 |
+ audio_generator = ( |
|
165 |
+ convert_audio_format( |
|
166 |
+ audio_bytes, body.sample_rate or piper_tts.config.sample_rate, body.response_format |
|
167 |
+ ) |
|
168 |
+ for audio_bytes in audio_generator |
|
169 |
+ ) |
|
170 |
+ return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}") |
|
179 | 171 |
|
180 | 172 |
|
181 | 173 |
@router.get("/v1/audio/speech/voices") |
182 |
-def list_voices() -> list[PiperModel]: |
|
183 |
- return list(list_piper_models()) |
|
174 |
+def list_voices(model_id: ModelId | None = None) -> list[Voice]: |
|
175 |
+ voices: list[Voice] = [] |
|
176 |
+ if model_id == "hexgrad/Kokoro-82M" or model_id is None: |
|
177 |
+ kokoro_model_path = get_kokoro_model_path() |
|
178 |
+ for voice_id in kokoro_utils.VOICE_IDS: |
|
179 |
+ voice = Voice( |
|
180 |
+ created=0, |
|
181 |
+ model_path=kokoro_model_path, |
|
182 |
+ model_id="hexgrad/Kokoro-82M", |
|
183 |
+ owned_by="hexgrad", |
|
184 |
+ sample_rate=kokoro_utils.SAMPLE_RATE, |
|
185 |
+ voice_id=voice_id, |
|
186 |
+ ) |
|
187 |
+ voices.append(voice) |
|
188 |
+ if model_id == "rhasspy/piper-voices" or model_id is None: |
|
189 |
+ voices.extend(list(list_piper_models())) |
|
190 |
+ |
|
191 |
+ return voices |
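
Because the per-model constraints are enforced with `assert` inside `model_validator`s, pydantic reports violations as validation errors, which FastAPI turns into HTTP 422 responses. A sketch with illustrative values:

```python
from pydantic import ValidationError

try:
    CreateSpeechRequestBody(input="Hello World!", model="hexgrad/Kokoro-82M", voice="af", speed=3.0)
except ValidationError:
    pass  # Kokoro accepts speeds in [0.5, 2.0]; 3.0 is rejected
```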
--- src/speaches/routers/stt.py
+++ src/speaches/routers/stt.py
... | ... | @@ -27,7 +27,7 @@ |
27 | 27 |
from numpy.typing import NDArray |
28 | 28 |
from pydantic import AfterValidator, Field |
29 | 29 |
|
30 |
-from speaches.api_models import ( |
|
30 |
+from speaches.api_types import ( |
|
31 | 31 |
DEFAULT_TIMESTAMP_GRANULARITIES, |
32 | 32 |
TIMESTAMP_GRANULARITIES_COMBINATIONS, |
33 | 33 |
CreateTranscriptionResponseJson, |
... | ... | @@ -211,9 +211,9 @@ |
211 | 211 |
if form.get("timestamp_granularities[]") is None: |
212 | 212 |
return DEFAULT_TIMESTAMP_GRANULARITIES |
213 | 213 |
timestamp_granularities = form.getlist("timestamp_granularities[]") |
214 |
- assert ( |
|
215 |
- timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS |
|
216 |
- ), f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`." |
|
214 |
+ assert timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS, ( |
|
215 |
+ f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`." |
|
216 |
+ ) |
|
217 | 217 |
return timestamp_granularities |
218 | 218 |
|
219 | 219 |
|
--- src/speaches/text_utils.py
+++ src/speaches/text_utils.py
... | ... | @@ -6,7 +6,7 @@ |
6 | 6 |
if TYPE_CHECKING: |
7 | 7 |
from collections.abc import Iterable |
8 | 8 |
|
9 |
- from speaches.api_models import TranscriptionSegment, TranscriptionWord |
|
9 |
+ from speaches.api_types import TranscriptionSegment, TranscriptionWord |
|
10 | 10 |
|
11 | 11 |
|
12 | 12 |
class Transcription: |
--- src/speaches/text_utils_test.py
+++ src/speaches/text_utils_test.py
... | ... | @@ -1,4 +1,4 @@ |
1 |
-from speaches.api_models import TranscriptionWord |
|
1 |
+from speaches.api_types import TranscriptionWord |
|
2 | 2 |
from speaches.text_utils import ( |
3 | 3 |
canonicalize_word, |
4 | 4 |
common_prefix, |
--- src/speaches/transcriber.py
+++ src/speaches/transcriber.py
... | ... | @@ -9,7 +9,7 @@ |
9 | 9 |
if TYPE_CHECKING: |
10 | 10 |
from collections.abc import AsyncGenerator |
11 | 11 |
|
12 |
- from speaches.api_models import TranscriptionWord |
|
12 |
+ from speaches.api_types import TranscriptionWord |
|
13 | 13 |
from speaches.asr import FasterWhisperASR |
14 | 14 |
|
15 | 15 |
logger = logging.getLogger(__name__) |
--- tests/api_timestamp_granularities_test.py
+++ tests/api_timestamp_granularities_test.py
... | ... | @@ -5,7 +5,7 @@ |
5 | 5 |
from openai import AsyncOpenAI |
6 | 6 |
import pytest |
7 | 7 |
|
8 |
-from speaches.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
|
8 |
+from speaches.api_types import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
|
9 | 9 |
|
10 | 10 |
|
11 | 11 |
@pytest.mark.asyncio |
--- tests/openai_timestamp_granularities_test.py
+++ tests/openai_timestamp_granularities_test.py
... | ... | @@ -5,7 +5,7 @@ |
5 | 5 |
from openai import AsyncOpenAI, BadRequestError |
6 | 6 |
import pytest |
7 | 7 |
|
8 |
-from speaches.api_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
|
8 |
+from speaches.api_types import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities |
|
9 | 9 |
|
10 | 10 |
|
11 | 11 |
@pytest.mark.asyncio |
--- tests/speech_test.py
+++ tests/speech_test.py
... | ... | @@ -10,9 +10,9 @@ |
10 | 10 |
pytest.skip("Only supported on x86_64", allow_module_level=True) |
11 | 11 |
|
12 | 12 |
from speaches.routers.speech import ( # noqa: E402 |
13 |
- DEFAULT_MODEL, |
|
13 |
+ DEFAULT_MODEL_ID, |
|
14 | 14 |
DEFAULT_RESPONSE_FORMAT, |
15 |
- DEFAULT_VOICE, |
|
15 |
+ DEFAULT_VOICE_ID, |
|
16 | 16 |
SUPPORTED_RESPONSE_FORMATS, |
17 | 17 |
ResponseFormat, |
18 | 18 |
) |
... | ... | @@ -25,8 +25,8 @@ |
25 | 25 |
@pytest.mark.parametrize("response_format", SUPPORTED_RESPONSE_FORMATS) |
26 | 26 |
async def test_create_speech_formats(openai_client: AsyncOpenAI, response_format: ResponseFormat) -> None: |
27 | 27 |
await openai_client.audio.speech.create( |
28 |
- model=DEFAULT_MODEL, |
|
29 |
- voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 |
|
28 |
+ model=DEFAULT_MODEL_ID, |
|
29 |
+ voice=DEFAULT_VOICE_ID, # type: ignore # noqa: PGH003 |
|
30 | 30 |
input=DEFAULT_INPUT, |
31 | 31 |
response_format=response_format, |
32 | 32 |
) |
... | ... | @@ -35,9 +35,9 @@ |
35 | 35 |
GOOD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [ |
36 | 36 |
("tts-1", "alloy"), # OpenAI and OpenAI |
37 | 37 |
("tts-1-hd", "echo"), # OpenAI and OpenAI |
38 |
- ("tts-1", DEFAULT_VOICE), # OpenAI and Piper |
|
39 |
- (DEFAULT_MODEL, "echo"), # Piper and OpenAI |
|
40 |
- (DEFAULT_MODEL, DEFAULT_VOICE), # Piper and Piper |
|
38 |
+ ("tts-1", DEFAULT_VOICE_ID), # OpenAI and Piper |
|
39 |
+ (DEFAULT_MODEL_ID, "echo"), # Piper and OpenAI |
|
40 |
+ (DEFAULT_MODEL_ID, DEFAULT_VOICE_ID), # Kokoro and Kokoro |
|
41 | 41 |
] |
42 | 42 |
|
43 | 43 |
|
... | ... | @@ -56,8 +56,8 @@ |
56 | 56 |
BAD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [ |
57 | 57 |
("tts-1", "invalid"), # OpenAI and invalid |
58 | 58 |
("invalid", "echo"), # Invalid and OpenAI |
59 |
- (DEFAULT_MODEL, "invalid"), # Piper and invalid |
|
60 |
- ("invalid", DEFAULT_VOICE), # Invalid and Piper |
|
59 |
+ (DEFAULT_MODEL_ID, "invalid"), # Piper and invalid |
|
60 |
+ ("invalid", DEFAULT_VOICE_ID), # Invalid and Piper |
|
61 | 61 |
("invalid", "invalid"), # Invalid and invalid |
62 | 62 |
] |
63 | 63 |
|
... | ... | @@ -85,8 +85,8 @@ |
85 | 85 |
previous_size: int | None = None |
86 | 86 |
for speed in SUPPORTED_SPEEDS: |
87 | 87 |
res = await openai_client.audio.speech.create( |
88 |
- model=DEFAULT_MODEL, |
|
89 |
- voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 |
|
88 |
+ model=DEFAULT_MODEL_ID, |
|
89 |
+ voice=DEFAULT_VOICE_ID, # type: ignore # noqa: PGH003 |
|
90 | 90 |
input=DEFAULT_INPUT, |
91 | 91 |
response_format="pcm", |
92 | 92 |
speed=speed, |
... | ... | @@ -106,8 +106,8 @@ |
106 | 106 |
async def test_create_speech_with_unsupported_speed(openai_client: AsyncOpenAI, speed: float) -> None: |
107 | 107 |
with pytest.raises(UnprocessableEntityError): |
108 | 108 |
await openai_client.audio.speech.create( |
109 |
- model=DEFAULT_MODEL, |
|
110 |
- voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 |
|
109 |
+ model=DEFAULT_MODEL_ID, |
|
110 |
+ voice=DEFAULT_VOICE_ID, # type: ignore # noqa: PGH003 |
|
111 | 111 |
input=DEFAULT_INPUT, |
112 | 112 |
response_format="pcm", |
113 | 113 |
speed=speed, |
... | ... | @@ -122,8 +122,8 @@ |
122 | 122 |
@pytest.mark.parametrize("sample_rate", VALID_SAMPLE_RATES) |
123 | 123 |
async def test_speech_valid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None: |
124 | 124 |
res = await openai_client.audio.speech.create( |
125 |
- model=DEFAULT_MODEL, |
|
126 |
- voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 |
|
125 |
+ model=DEFAULT_MODEL_ID, |
|
126 |
+ voice=DEFAULT_VOICE_ID, # type: ignore # noqa: PGH003 |
|
127 | 127 |
input=DEFAULT_INPUT, |
128 | 128 |
response_format="wav", |
129 | 129 |
extra_body={"sample_rate": sample_rate}, |
... | ... | @@ -141,8 +141,8 @@ |
141 | 141 |
async def test_speech_invalid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None: |
142 | 142 |
with pytest.raises(UnprocessableEntityError): |
143 | 143 |
await openai_client.audio.speech.create( |
144 |
- model=DEFAULT_MODEL, |
|
145 |
- voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 |
|
144 |
+ model=DEFAULT_MODEL_ID, |
|
145 |
+ voice=DEFAULT_VOICE_ID, # type: ignore # noqa: PGH003 |
|
146 | 146 |
input=DEFAULT_INPUT, |
147 | 147 |
response_format="wav", |
148 | 148 |
extra_body={"sample_rate": sample_rate}, |
--- tests/sse_test.py
+++ tests/sse_test.py
... | ... | @@ -9,7 +9,7 @@ |
9 | 9 |
import webvtt |
10 | 10 |
import webvtt.vtt |
11 | 11 |
|
12 |
-from speaches.api_models import ( |
|
12 |
+from speaches.api_types import ( |
|
13 | 13 |
CreateTranscriptionResponseJson, |
14 | 14 |
CreateTranscriptionResponseVerboseJson, |
15 | 15 |
) |