Tijs Zwinkels 2024-03-20
Fix crash when using openai-api with whisper_online_server + refactored creation of the ASR into a factory method
+ refactored creation of the ASR into a factory method
@2d56ec87bf3a3b85cbf1f5a815fa68d7f49cf264
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -548,6 +548,37 @@
     parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
     parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
 
+def asr_factory(args, logfile=sys.stderr):
+    """
+    Creates and configures an ASR instance based on the specified backend and arguments.
+    """
+    backend = args.backend
+    if backend == "openai-api":
+        print("Using OpenAI API.", file=logfile)
+        asr = OpenaiApiASR(lan=args.lan)
+    else:
+        if backend == "faster-whisper":
+            from faster_whisper import FasterWhisperASR
+            asr_cls = FasterWhisperASR
+        else:
+            from whisper_timestamped import WhisperTimestampedASR
+            asr_cls = WhisperTimestampedASR
+
+        # Only for FasterWhisperASR and WhisperTimestampedASR
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {args.lan}...", file=logfile, end=" ", flush=True)
+        asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.", file=logfile)
+
+    # Apply common configurations
+    if getattr(args, 'vad', False):  # Checks if VAD argument is present and True
+        print("Setting VAD filter", file=logfile)
+        asr.use_vad()
+
+    return asr
+
 ## main:
 
 if __name__ == "__main__":
@@ -575,28 +606,8 @@
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
+    asr = asr_factory(args, logfile=logfile)
     language = args.lan
-
-    if args.backend == "openai-api":
-        print("Using OpenAI API.",file=logfile)
-        asr = OpenaiApiASR(lan=language)
-    else:
-        if args.backend == "faster-whisper":
-            asr_cls = FasterWhisperASR
-        else:
-            asr_cls = WhisperTimestampedASR
-
-        size = args.model
-        t = time.time()
-        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-        e = time.time()
-        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-    if args.vad:
-        print("setting VAD filter",file=logfile)
-        asr.use_vad()
-
     if args.task == "translate":
         asr.set_translate_task()
         tgt_language = "en"  # Whisper translates into English
whisper_online_server.py
--- whisper_online_server.py
+++ whisper_online_server.py
@@ -24,35 +24,12 @@
 size = args.model
 language = args.lan
 
-t = time.time()
-print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
-
-if args.backend == "faster-whisper":
-    from faster_whisper import WhisperModel
-    asr_cls = FasterWhisperASR
-elif args.backend == "openai-api":
-    asr_cls = OpenaiApiASR
-else:
-    import whisper
-    import whisper_timestamped
-#    from whisper_timestamped_model import WhisperTimestampedASR
-    asr_cls = WhisperTimestampedASR
-
-asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-
+asr = asr_factory(args)
 if args.task == "translate":
     asr.set_translate_task()
     tgt_language = "en"
 else:
     tgt_language = language
-
-e = time.time()
-print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
-
-if args.vad:
-    print("setting VAD filter",file=sys.stderr)
-    asr.use_vad()
-
 
 min_chunk = args.min_chunk_size
 
Add a comment
List