Quentin Fuxa 2024-12-19
add translate_model_name function
@6b6a3419d5410e589b3aa6481842c322c75ec1ff
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -160,27 +160,71 @@
     """
     Uses MPX Whisper library as the backend, optimized for Apple Silicon.
     Models available: https://huggingface.co/collections/mlx-community/whisper-663256f9964fbb1177db93dc
-    Significantly faster than faster-whisper (without CUDA) on Apple M1. Model used by default: mlx-community/whisper-large-v3-mlx
+    Significantly faster than faster-whisper (without CUDA) on Apple M1. 
     """
 
     sep = " "
 
-    def load_model(self, modelsize=None, model_dir=None):
+    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+        """
+            Loads the MLX-compatible Whisper model.
+
+            Args:
+                modelsize (str, optional): The size or name of the Whisper model to load. 
+                    If provided, it will be translated to an MLX-compatible model path using the `translate_model_name` method.
+                    Example: "large-v3-turbo" -> "mlx-community/whisper-large-v3-turbo".
+                cache_dir (str, optional): Path to the directory for caching models. 
+                    **Note**: This is not supported by MLX Whisper and will be ignored.
+                model_dir (str, optional): Direct path to a custom model directory. 
+                    If specified, it overrides the `modelsize` parameter.
+        """
         from mlx_whisper import transcribe
 
         if model_dir is not None:
             logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
             model_size_or_path = model_dir
         elif modelsize is not None:
-            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so make sure you use a mlx-compatible model.")
-            model_size_or_path = modelsize
-        elif modelsize == None:
-            logger.debug("No model size or path specified. Using mlx-community/whisper-large-v3-mlx.")
-            model_size_or_path = "mlx-community/whisper-large-v3-mlx"
+            model_size_or_path = self.translate_model_name(modelsize)
+            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
         
         self.model_size_or_path = model_size_or_path
         return transcribe
     
+    def translate_model_name(self, model_name):
+        """
+        Translates a given model name to its corresponding MLX-compatible model path.
+
+        Args:
+            model_name (str): The name of the model to translate.
+
+        Returns:
+            str: The MLX-compatible model path.
+        """
+        # Dictionary mapping model names to MLX-compatible paths
+        model_mapping = {
+            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
+            "tiny": "mlx-community/whisper-tiny-mlx",
+            "base.en": "mlx-community/whisper-base.en-mlx",
+            "base": "mlx-community/whisper-base-mlx",
+            "small.en": "mlx-community/whisper-small.en-mlx",
+            "small": "mlx-community/whisper-small-mlx",
+            "medium.en": "mlx-community/whisper-medium.en-mlx",
+            "medium": "mlx-community/whisper-medium-mlx",
+            "large-v1": "mlx-community/whisper-large-v1-mlx",
+            "large-v2": "mlx-community/whisper-large-v2-mlx",
+            "large-v3": "mlx-community/whisper-large-v3-mlx",
+            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
+            "large": "mlx-community/whisper-large-mlx"
+        }
+
+        # Retrieve the corresponding MLX model path
+        mlx_model_path = model_mapping.get(model_name)
+
+        if mlx_model_path:
+            return mlx_model_path
+        else:
+            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
+    
     def transcribe(self, audio, init_prompt=""):
         segments = self.model(
             audio,
Add a comment
List