

Merge branch 'whisper-mlx'
@3d91a2aa9808c8d43dcf084d40e5e9647afe8e83
--- whisper_online.py
+++ whisper_online.py
@@ -160,27 +160,71 @@
     """
     Uses MLX Whisper library as the backend, optimized for Apple Silicon.
    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
-    Significantly faster than faster-whisper (without CUDA) on Apple M1. Model used by default: mlx-community/whisper-large-v3-mlx
+    Significantly faster than faster-whisper (without CUDA) on Apple M1.
     """

     sep = " "

-    def load_model(self, modelsize=None, model_dir=None):
+    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+        """
+        Loads the MLX-compatible Whisper model.
+
+        Args:
+            modelsize (str, optional): The size or name of the Whisper model to load.
+                If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
+                Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
+            cache_dir (str, optional): Path to the directory for caching models.
+                **Note**: This is not supported by MLX Whisper and will be ignored.
+            model_dir (str, optional): Direct path to a custom model directory.
+                If specified, it overrides the `modelsize` parameter.
+        """
         from mlx_whisper import transcribe

         if model_dir is not None:
             logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
             model_size_or_path = model_dir
         elif modelsize is not None:
-            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.")
-            model_size_or_path = modelsize
-        elif modelsize == None:
-            logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
-            model_size_or_path = "mlx-community/whisper-large-v3-mlx"
+            model_size_or_path = self.translate_model_name(modelsize)
+            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")

         self.model_size_or_path = model_size_or_path
         return transcribe

+    def translate_model_name(self, model_name):
+        """
+        Translates a given model name to its corresponding MLX-compatible model path.
+
+        Args:
+            model_name (str): The name of the model to translate.
+
+        Returns:
+            str: The MLX-compatible model path.
+        """
+        # Dictionary mapping model names to MLX-compatible paths
+        model_mapping = {
+            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
+            "tiny": "mlx-community/whisper-tiny-mlx",
+            "base.en": "mlx-community/whisper-base.en-mlx",
+            "base": "mlx-community/whisper-base-mlx",
+            "small.en": "mlx-community/whisper-small.en-mlx",
+            "small": "mlx-community/whisper-small-mlx",
+            "medium.en": "mlx-community/whisper-medium.en-mlx",
+            "medium": "mlx-community/whisper-medium-mlx",
+            "large-v1": "mlx-community/whisper-large-v1-mlx",
+            "large-v2": "mlx-community/whisper-large-v2-mlx",
+            "large-v3": "mlx-community/whisper-large-v3-mlx",
+            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
+            "large": "mlx-community/whisper-large-mlx"
+        }
+
+        # Retrieve the corresponding MLX model path
+        mlx_model_path = model_mapping.get(model_name)
+
+        if mlx_model_path:
+            return mlx_model_path
+        else:
+            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
+
     def transcribe(self, audio, init_prompt=""):
         segments = self.model(
             audio,
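
As before, `load_model` returns mlx_whisper's module-level `transcribe` function rather than a model object, and the stored repo path is passed on each call. For trying the merged behavior outside the class, here is a minimal sketch; the audio filename is a placeholder, and `word_timestamps` is an assumption based on mlx-whisper's documented transcribe options, not something shown in this diff:

```python
import mlx_whisper  # assumes `pip install mlx-whisper` on Apple Silicon

# The repo path below is what the new translate_model_name("large-v3-turbo")
# returns; "audio.wav" is a placeholder input file.
result = mlx_whisper.transcribe(
    "audio.wav",
    path_or_hf_repo="mlx-community/whisper-large-v3-turbo",
    word_timestamps=True,  # assumption: mirrors openai-whisper's option
)
print(result["text"])
```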
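One behavioral change worth flagging: the old fallback to mlx-community/whisper-large-v3-mlx was removed, so calling `load_model()` with neither `modelsize` nor `model_dir` now reaches `self.model_size_or_path = model_size_or_path` with the local variable never assigned and fails with UnboundLocalError. A stripped-down sketch of the new precedence that fails loudly instead (`resolve_model` and `MLX_REPOS` are illustrative stand-ins, not names in the module):

```python
# Subset of the mapping added in this diff.
MLX_REPOS = {
    "large-v3": "mlx-community/whisper-large-v3-mlx",
    "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
}

def resolve_model(modelsize=None, cache_dir=None, model_dir=None):
    # Mirrors the precedence in the new load_model(): model_dir wins,
    # then modelsize; cache_dir is accepted but ignored by MLX Whisper.
    if model_dir is not None:
        return model_dir
    if modelsize is not None:
        try:
            return MLX_REPOS[modelsize]
        except KeyError:
            raise ValueError(
                f"Model name '{modelsize}' is not recognized or not supported."
            ) from None
    # The merged code has no default branch anymore; fail loudly here
    # instead of hitting UnboundLocalError later.
    raise ValueError("Either modelsize or model_dir must be given.")

assert resolve_model(model_dir="/models/custom") == "/models/custom"
assert resolve_model(modelsize="large-v3-turbo") == "mlx-community/whisper-large-v3-turbo"
```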