Luca 2023-11-03
backend import in child load_model method and expose logfile arg
@cc7e524fc4c5be39875538e4b944c2a846b72316
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -30,11 +30,7 @@
         self.transcribe_kargs = {}
         self.original_language = lan 
 
-        self.import_backend()
         self.model = self.load_model(modelsize, cache_dir, model_dir)
-
-    def import_backend(self):
-        raise NotImplemented("must be implemented in the child class")
 
     def load_model(self, modelsize, cache_dir):
         raise NotImplemented("must be implemented in the child class")
@@ -52,15 +48,13 @@
     """
 
     sep = " "
-    
-    def import_backend(self):
-        global whisper, whisper_timestamped
-        import whisper
-        import whisper_timestamped
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+        global whisper_timestamped  # has to be global as it is used at each `transcribe` call
+        import whisper
+        import whisper_timestamped
         if model_dir is not None:
-            print("ignoring model_dir, not implemented",file=self.output)
+            print("ignoring model_dir, not implemented",file=self.logfile)
         return whisper.load_model(modelsize, download_root=cache_dir)
 
     def transcribe(self, audio, init_prompt=""):
@@ -89,13 +83,10 @@
 
     sep = ""
 
-    def import_backend(self):
-        global faster_whisper
-        import faster_whisper
-
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
+        from faster_whisper import WhisperModel
         if model_dir is not None:
-            print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.output)
+            print(f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.",file=self.logfile)
             model_size_or_path = model_dir
         elif modelsize is not None:
             model_size_or_path = modelsize
@@ -143,7 +134,7 @@
 
 class HypothesisBuffer:
 
-    def __init__(self, output=sys.stderr):
+    def __init__(self, logfile=sys.stderr):
         """output: where to store the log. Leave it unchanged to print to terminal."""
         self.commited_in_buffer = []
         self.buffer = []
@@ -152,7 +143,7 @@
         self.last_commited_time = 0
         self.last_commited_word = None
 
-        self.output = output
+        self.logfile = logfile
 
     def insert(self, new, offset):
         # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
@@ -172,9 +163,9 @@
                         c = " ".join([self.commited_in_buffer[-j][2] for j in range(1,i+1)][::-1])
                         tail = " ".join(self.new[j-1][2] for j in range(1,i+1))
                         if c == tail:
-                            print("removing last",i,"words:",file=self.output)
+                            print("removing last",i,"words:",file=self.logfile)
                             for j in range(i):
-                                print("\t",self.new.pop(0),file=self.output)
+                                print("\t",self.new.pop(0),file=self.logfile)
                             break
 
     def flush(self):
@@ -211,14 +202,14 @@
 
     SAMPLING_RATE = 16000
 
-    def __init__(self, asr, tokenizer, output=sys.stderr):
+    def __init__(self, asr, tokenizer, logfile=sys.stderr):
         """asr: WhisperASR object
         tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer.
         output: where to store the log. Leave it unchanged to print to terminal.
         """
         self.asr = asr
         self.tokenizer = tokenizer
-        self.output = output
+        self.logfile = logfile
 
         self.init()
 
@@ -227,7 +218,7 @@
         self.audio_buffer = np.array([],dtype=np.float32)
         self.buffer_time_offset = 0
 
-        self.transcript_buffer = HypothesisBuffer(output=self.output)
+        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
         self.commited = []
         self.last_chunked_at = 0
 
@@ -262,9 +253,9 @@
         """
 
         prompt, non_prompt = self.prompt()
-        print("PROMPT:", prompt, file=self.output)
-        print("CONTEXT:", non_prompt, file=self.output)
-        print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.output)
+        print("PROMPT:", prompt, file=self.logfile)
+        print("CONTEXT:", non_prompt, file=self.logfile)
+        print(f"transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}",file=self.logfile)
         res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
 
         # transform to [(beg,end,"word1"), ...]
@@ -273,8 +264,8 @@
         self.transcript_buffer.insert(tsw, self.buffer_time_offset)
         o = self.transcript_buffer.flush()
         self.commited.extend(o)
-        print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.output,flush=True)
-        print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
+        print(">>>>COMPLETE NOW:",self.to_flush(o),file=self.logfile,flush=True)
+        print("INCOMPLETE:",self.to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
 
         # there is a newly confirmed text
         if o:
@@ -293,14 +284,14 @@
 #        elif self.transcript_buffer.complete():
 #            self.silence_iters = 0
 #        elif not self.transcript_buffer.complete():
-#        #    print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.output,flush=True)
+#        #    print("NOT COMPLETE:",to_flush(self.transcript_buffer.complete()),file=self.logfile,flush=True)
 #            self.silence_iters += 1
 #            if self.silence_iters >= 3:
 #                n = self.last_chunked_at
 ##                self.chunk_completed_sentence()
 ##                if n == self.last_chunked_at:
 #                self.chunk_at(self.last_chunked_at+self.chunk)
-#                print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.output)
+#                print(f"\tCHUNK: 3-times silence! chunk_at {n}+{self.chunk}",file=self.logfile)
 ##                self.silence_iters = 0
 
 
@@ -316,18 +307,18 @@
             #while k>0 and self.commited[k][1] > l:
             #    k -= 1
             #t = self.commited[k][1] 
-            print(f"chunking because of len",file=self.output)
+            print(f"chunking because of len",file=self.logfile)
             #self.chunk_at(t)
 
-        print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.output)
+        print(f"len of buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:2.2f}",file=self.logfile)
         return self.to_flush(o)
 
     def chunk_completed_sentence(self):
         if self.commited == []: return
-        print(self.commited,file=self.output)
+        print(self.commited,file=self.logfile)
         sents = self.words_to_sentences(self.commited)
         for s in sents:
-            print("\t\tSENT:",s,file=self.output)
+            print("\t\tSENT:",s,file=self.logfile)
         if len(sents) < 2:
             return
         while len(sents) > 2:
@@ -335,7 +326,7 @@
         # we will continue with audio processing at this timestamp
         chunk_at = sents[-2][1]
 
-        print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.output)
+        print(f"--- sentence chunked at {chunk_at:2.2f}",file=self.logfile)
         self.chunk_at(chunk_at)
 
     def chunk_completed_segment(self, res):
@@ -352,12 +343,12 @@
                 ends.pop(-1)
                 e = ends[-2]+self.buffer_time_offset
             if e <= t:
-                print(f"--- segment chunked at {e:2.2f}",file=self.output)
+                print(f"--- segment chunked at {e:2.2f}",file=self.logfile)
                 self.chunk_at(e)
             else:
-                print(f"--- last segment not within commited area",file=self.output)
+                print(f"--- last segment not within commited area",file=self.logfile)
         else:
-            print(f"--- not enough segments to chunk",file=self.output)
+            print(f"--- not enough segments to chunk",file=self.logfile)
 
 
 
@@ -403,7 +394,7 @@
         """
         o = self.transcript_buffer.complete()
         f = self.to_flush(o)
-        print("last, noncommited:",f,file=self.output)
+        print("last, noncommited:",f,file=self.logfile)
         return f
 
 
Add a comment
List