Tijs Zwinkels 2024-02-10
Make --vad work with --backend openai-api
@5da3267add56cdd63aaef11cee53b508ec95a4be
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -162,7 +162,7 @@
 
         self.load_model()
 
-        self.use_vad = False
+        self.use_vad_opt = False
 
         # reset the task in set_translate_task
         self.task = "transcribe"
@@ -175,21 +175,27 @@
         
 
     def ts_words(self, segments):
-        o = []
-        # If VAD on, skip segments containing no speech. 
-        # TODO: threshold can be set from outside
-        # TODO: Make VAD work again with word-level timestamps
-        #if self.use_vad and segment["no_speech_prob"] > 0.8:
-        #    continue
+        no_speech_segments = []
+        if self.use_vad_opt:
+            for segment in segments.segments:
+                # TODO: threshold can be set from outside
+                if segment["no_speech_prob"] > 0.8:
+                    no_speech_segments.append((segment.get("start"), segment.get("end")))
 
-        for word in segments:
-            o.append((word.get("start"), word.get("end"), word.get("word")))
+        o = []
+        for word in segments.words:
+            start = word.get("start")
+            end = word.get("end")
+            if any(s[0] <= start <= s[1] for s in no_speech_segments):
+                # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
+                continue
+            o.append((start, end, word.get("word")))
 
         return o
 
 
     def segments_end_ts(self, res):
-        return [s["end"] for s in res]
+        return [s["end"] for s in res.words]
 
     def transcribe(self, audio_data, prompt=None, *args, **kwargs):
         # Write the audio data to a buffer
@@ -205,7 +211,7 @@
             "file": buffer,
             "response_format": self.response_format,
             "temperature": self.temperature,
-            "timestamp_granularities": ["word"]
+            "timestamp_granularities": ["word", "segment"]
         }
         if self.task != "translate" and self.language:
             params["language"] = self.language
@@ -221,10 +227,10 @@
         transcript = proc.create(**params)
         print(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds",file=self.logfile)
 
-        return transcript.words
+        return transcript
 
     def use_vad(self):
-        self.use_vad = True
+        self.use_vad_opt = True
 
     def set_translate_task(self):
         self.task = "translate"
@@ -592,9 +598,9 @@
         e = time.time()
         print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
 
-        if args.vad:
-            print("setting VAD filter",file=logfile)
-            asr.use_vad()
+    if args.vad:
+        print("setting VAD filter",file=logfile)
+        asr.use_vad()
 
     if args.task == "translate":
         asr.set_translate_task()
Add a comment
List