Quentin Fuxa 2024-12-19
add whisper mlx backend
@2721b3cc035b16e58d3b09ebfaeda1a16250b6c7
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -156,6 +156,63 @@
     def set_translate_task(self):
         self.transcribe_kargs["task"] = "translate"
 
+class MLXWhisper(ASRBase):
+    """
+    Uses MPX Whisper library as the backend, optimized for Apple Silicon.
+    Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
+    Significantly faster than faster-whisper (without CUDA) on Apple M1. Model used by default: mlx-community/whisper-large-v3-mlx
+    """
+
+    sep = " "
+
+    def load_model(self, modelsize=None, model_dir=None):
+        from mlx_whisper import transcribe
+
+        if model_dir is not None:
+            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
+            model_size_or_path = model_dir
+        elif modelsize is not None:
+            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.")
+            model_size_or_path = modelsize
+        elif modelsize == None:
+            logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
+            model_size_or_path = "mlx-community/whisper-large-v3-mlx"
+        
+        self.model_size_or_path = model_size_or_path
+        return transcribe
+    
+    def transcribe(self, audio, init_prompt=""):
+        segments = self.model(
+            audio,
+            language=self.original_language,
+            initial_prompt=init_prompt,
+            word_timestamps=True,
+            condition_on_previous_text=True,
+            path_or_hf_repo=self.model_size_or_path,
+            **self.transcribe_kargs
+        )
+        return segments.get("segments", [])
+
+
+    def ts_words(self, segments):
+        """
+        Extract timestamped words from transcription segments and skips words with high no-speech probability.
+        """
+        return [
+            (word["start"], word["end"], word["word"])
+            for segment in segments
+            for word in segment.get("words", [])
+            if segment.get("no_speech_prob", 0) <= 0.9
+        ]
+    
+    def segments_end_ts(self, res):
+        return [s['end'] for s in res]
+
+    def use_vad(self):
+        self.transcribe_kargs["vad_filter"] = True
+
+    def set_translate_task(self):
+        self.transcribe_kargs["task"] = "translate"
 
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""
@@ -660,7 +717,7 @@
     parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
     parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
     parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
-    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
+    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],help='Load only this backend for Whisper processing.')
     parser.add_argument('--vac', action="store_true", default=False, help='Use VAC = voice activity controller. Recommended. Requires torch.')
     parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
     parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
@@ -679,6 +736,8 @@
     else:
         if backend == "faster-whisper":
             asr_cls = FasterWhisperASR
+        elif backend == "mlx-whisper":
+            asr_cls = MLXWhisper
         else:
             asr_cls = WhisperTimestampedASR
 
Add a comment
List