Tijs Zwinkels 2024-03-21
Move creation of OnlineASRProcessor inside the factory method
Preventing more code duplication between whisper_online.py and whisper_online_server.py
@812aefa40f55094e2c0d2df68ea4c7113a79ef0f
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -551,7 +551,7 @@
 
 def asr_factory(args, logfile=sys.stderr):
     """
-    Creates and configures an ASR instance based on the specified backend and arguments.
+    Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
     """
     backend = args.backend
     if backend == "openai-api":
@@ -576,8 +576,23 @@
         print("Setting VAD filter", file=logfile)
         asr.use_vad()
 
-    return asr
+    language = args.lan
+    if args.task == "translate":
+        asr.set_translate_task()
+        tgt_language = "en"  # Whisper translates into English
+    else:
+        tgt_language = language  # Whisper transcribes in this language
 
+    # Create the tokenizer
+    if args.buffer_trimming == "sentence":
+        tokenizer = create_tokenizer(tgt_language)
+    else:
+        tokenizer = None
+
+    # Create the OnlineASRProcessor
+    online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
+
+    return asr, online
 ## main:
 
 if __name__ == "__main__":
@@ -605,22 +620,8 @@
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
-    asr = asr_factory(args, logfile=logfile)
-    language = args.lan
-    if args.task == "translate":
-        asr.set_translate_task()
-        tgt_language = "en"  # Whisper translates into English
-    else:
-        tgt_language = language  # Whisper transcribes in this language
-
-    
+    asr, online = asr_factory(args, logfile=logfile)
     min_chunk = args.min_chunk_size
-    if args.buffer_trimming == "sentence":
-        tokenizer = create_tokenizer(tgt_language)
-    else:
-        tokenizer = None
-    online = OnlineASRProcessor(asr,tokenizer,logfile=logfile,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
-
 
     # load the audio into the LRU cache before we start the timer
     a = load_audio_chunk(audio_path,0,1)
whisper_online_server.py
--- whisper_online_server.py
+++ whisper_online_server.py
@@ -23,23 +23,8 @@
 
 size = args.model
 language = args.lan
-
-asr = asr_factory(args)
-if args.task == "translate":
-    asr.set_translate_task()
-    tgt_language = "en"
-else:
-    tgt_language = language
-
+asr, online = asr_factory(args)
 min_chunk = args.min_chunk_size
-
-if args.buffer_trimming == "sentence":
-    tokenizer = create_tokenizer(tgt_language)
-else:
-    tokenizer = None
-online = OnlineASRProcessor(asr,tokenizer,buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
-
-
 
 demo_audio_path = "cs-maji-2.16k.wav"
 if os.path.exists(demo_audio_path):
Add a comment
List