

add whisper mlx backend
@2721b3cc035b16e58d3b09ebfaeda1a16250b6c7
--- whisper_online.py
+++ whisper_online.py
... | ... | @@ -156,6 +156,63 @@ |
156 | 156 |
def set_translate_task(self): |
157 | 157 |
self.transcribe_kargs["task"] = "translate" |
158 | 158 |
|
159 |
class MLXWhisper(ASRBase):
    """
    Uses MLX Whisper library as the backend, optimized for Apple Silicon.

    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
    Significantly faster than faster-whisper (without CUDA) on Apple M1.
    Model used by default: mlx-community/whisper-large-v3-mlx
    """

    # Separator used when joining transcribed words into output text.
    sep = " "

    def load_model(self, modelsize=None, model_dir=None):
        """Resolve the model path/repo and return mlx_whisper's transcribe function.

        Unlike the other backends, mlx_whisper loads (and caches) its model
        lazily inside transcribe(), so this method only records which model
        to use and hands back the transcribe callable as ``self.model``.

        Args:
            modelsize: name of an mlx-compatible model repo on the HF hub.
            model_dir: local directory containing the model; overrides modelsize.
        """
        from mlx_whisper import transcribe

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.")
            model_size_or_path = modelsize
        else:
            # Neither model_dir nor modelsize given: fall back to the default model.
            logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
            model_size_or_path = "mlx-community/whisper-large-v3-mlx"

        self.model_size_or_path = model_size_or_path
        return transcribe

    def transcribe(self, audio, init_prompt=""):
        """Transcribe audio and return the list of segment dicts with word timestamps."""
        result = self.model(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            word_timestamps=True,  # needed by ts_words() below
            condition_on_previous_text=True,
            path_or_hf_repo=self.model_size_or_path,
            **self.transcribe_kargs,
        )
        return result.get("segments", [])

    def ts_words(self, segments):
        """
        Extract timestamped words from transcription segments and skips words with high no-speech probability.

        Returns a list of (start, end, word) tuples.
        """
        return [
            (word["start"], word["end"], word["word"])
            for segment in segments
            for word in segment.get("words", [])
            if segment.get("no_speech_prob", 0) <= 0.9
        ]

    def segments_end_ts(self, res):
        """Return the end timestamp of every segment in the transcription result."""
        return [s['end'] for s in res]

    def use_vad(self):
        # NOTE(review): mlx_whisper's transcribe() may not accept a vad_filter
        # keyword (unlike faster-whisper) — confirm against the mlx_whisper API.
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        # Switch Whisper from transcription to translation mode.
        self.transcribe_kargs["task"] = "translate"
159 | 216 |
|
160 | 217 |
class OpenaiApiASR(ASRBase): |
161 | 218 |
"""Uses OpenAI's Whisper API for audio transcription.""" |
... | ... | @@ -660,7 +717,7 @@ |
660 | 717 |
parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.") |
661 | 718 |
parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.") |
662 | 719 |
parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.") |
663 |
- parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.') |
|
720 |
+ parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.') |
|
664 | 721 |
parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.') |
665 | 722 |
parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.') |
666 | 723 |
parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.') |
... | ... | @@ -679,6 +736,8 @@ |
679 | 736 |
else: |
680 | 737 |
if backend == "faster-whisper": |
681 | 738 |
asr_cls = FasterWhisperASR |
739 |
+ elif backend == "mlx-whisper": |
|
740 |
+ asr_cls = MLXWhisper |
|
682 | 741 |
else: |
683 | 742 |
asr_cls = WhisperTimestampedASR |
684 | 743 |
|
Add a comment
Delete comment
Once you delete this comment, you won't be able to recover it. Are you sure you want to delete this comment?