Dominik Macháček 2024-08-19
clean code with VAC
@c300a03e47eab86e5de54d816f55fb05c3a15d02
 
silero_vad.py (added)
+++ silero_vad.py
@@ -0,0 +1,95 @@
+import torch
+
+# this is copy-pasted from silero-vad's utils_vad.py:
+# https://github.com/snakers4/silero-vad/blob/f6b1294cb27590fb2452899df98fb234dfef1134/utils_vad.py#L340
+
+class VADIterator:
+    def __init__(self,
+                 model,
+                 threshold: float = 0.5,
+                 sampling_rate: int = 16000,
+                 min_silence_duration_ms: int = 100,
+                 speech_pad_ms: int = 30
+                 ):
+
+        """
+        Class for stream imitation
+
+        Parameters
+        ----------
+        model: preloaded .jit silero VAD model
+
+        threshold: float (default - 0.5)
+            Speech threshold. Silero VAD outputs speech probabilities for each audio chunk; probabilities ABOVE this value are considered SPEECH.
+            It is better to tune this parameter for each dataset separately, but the "lazy" 0.5 is pretty good for most datasets.
+
+        sampling_rate: int (default - 16000)
+            Currently silero VAD models support 8000 and 16000 sample rates
+
+        min_silence_duration_ms: int (default - 100 milliseconds)
+            At the end of each speech chunk, wait for min_silence_duration_ms before separating it
+
+        speech_pad_ms: int (default - 30 milliseconds)
+            Final speech chunks are padded by speech_pad_ms on each side
+        """
+
+        self.model = model
+        self.threshold = threshold
+        self.sampling_rate = sampling_rate
+
+        if sampling_rate not in [8000, 16000]:
+            raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
+
+        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+        self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+        self.reset_states()
+
+    def reset_states(self):
+
+        self.model.reset_states()
+        self.triggered = False
+        self.temp_end = 0
+        self.current_sample = 0
+
+    def __call__(self, x, return_seconds=False):
+        """
+        x: torch.Tensor
+            audio chunk (see examples in repo)
+
+        return_seconds: bool (default - False)
+            whether to return timestamps in seconds (default - samples)
+        """
+
+        if not torch.is_tensor(x):
+            try:
+                x = torch.Tensor(x)
+            except Exception:
+                raise TypeError("Audio cannot be cast to tensor. Cast it manually")
+
+        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
+        self.current_sample += window_size_samples
+
+        speech_prob = self.model(x, self.sampling_rate).item()
+
+        if (speech_prob >= self.threshold) and self.temp_end:
+            self.temp_end = 0
+
+        if (speech_prob >= self.threshold) and not self.triggered:
+            self.triggered = True
+            speech_start = self.current_sample - self.speech_pad_samples
+            return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
+
+        if (speech_prob < self.threshold - 0.15) and self.triggered:
+            if not self.temp_end:
+                self.temp_end = self.current_sample
+            if self.current_sample - self.temp_end < self.min_silence_samples:
+                return None
+            else:
+                speech_end = self.temp_end + self.speech_pad_samples
+                self.temp_end = 0
+                self.triggered = False
+                return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
+
+        return None
+
+
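For orientation, a minimal usage sketch of VADIterator: the torch.hub load follows the controller below, the 512-sample window is an assumption for illustration, and torch.randn stands in for real 16 kHz audio.

    import torch
    from silero_vad import VADIterator

    model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
    vad = VADIterator(model)

    wav = torch.randn(16000)                      # stand-in for 1 s of real 16 kHz audio
    for i in range(0, len(wav) - 512 + 1, 512):   # feed fixed-size windows
        event = vad(wav[i:i + 512])
        if event is not None:
            print(event)                          # {'start': <sample>} or {'end': <sample>}
    vad.reset_states()                            # reset before processing another stream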
voice_activity_controller.py
--- voice_activity_controller.py
+++ voice_activity_controller.py
@@ -1,111 +1,35 @@
 import torch
-import numpy as np
+from silero_vad import VADIterator
+import time
 
 class VoiceActivityController:
-    def __init__(
-            self, 
-            sampling_rate = 16000,
-            min_silence_to_final_ms = 500,
-            min_speech_to_final_ms = 100,
-            min_silence_duration_ms = 100,
-            use_vad_result = True,
-#            activity_detected_callback=None,
-            threshold =0.3
-        ):
-#        self.activity_detected_callback=activity_detected_callback
-        self.model, self.utils = torch.hub.load(
+    SAMPLING_RATE = 16000
+    def __init__(self):
+        self.model, _ = torch.hub.load(
             repo_or_dir='snakers4/silero-vad',
             model='silero_vad'
         )
-        # (self.get_speech_timestamps,
-        # save_audio,
-        # read_audio,
-        # VADIterator,
-        # collect_chunks) = self.utils
+        # we use the default options: threshold 0.5, 100 ms min silence, 30 ms speech padding
+        self.iterator = VADIterator(self.model)
 
-        self.sampling_rate = sampling_rate  
-        self.final_silence_limit = min_silence_to_final_ms * self.sampling_rate / 1000 
-        self.final_speech_limit = min_speech_to_final_ms *self.sampling_rate / 1000
-        self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+    def reset(self):
+        self.iterator.reset_states()
 
-        self.use_vad_result = use_vad_result
-        self.threshold = threshold
-        self.reset_states()
-
-    def reset_states(self):
-        self.model.reset_states()
-        self.temp_end = 0
-        self.current_sample = 0
-
-        self.last_silence_len= 0
-        self.speech_len = 0
-
-    def apply_vad(self, audio):
-        """
-        returns: triple
-            (voice_audio,
-            speech_in_wav,
-            silence_in_wav)
-
-        """
-        print("applying vad here")
+    def __call__(self, audio):
+        '''
+        audio: audio chunk as np.ndarray, in the format used throughout this project
+        returns:
+        - { 'start': time_frame } ... when the start of voice is detected; time_frame is a sample index (can be converted to seconds)
+        - { 'end': time_frame }   ... when the end of voice is detected
+        - None                    ... when the current chunk causes no state change
+        '''
         x = audio
-        if not torch.is_tensor(x):
-            try:
-                x = torch.Tensor(x)
-            except:
-                raise TypeError("Audio cannot be casted to tensor. Cast it manually")
-
-        speech_prob = self.model(x, self.sampling_rate).item()
-        print("speech_prob",speech_prob)
-        
-        window_size_samples = len(x[0]) if x.dim() == 2 else len(x)
-        self.current_sample += window_size_samples 
-
-        if speech_prob >= self.threshold:  # speech is detected
-            self.temp_end = 0
-            return audio, window_size_samples, 0
-
-        else:  # silence detected, counting w
-            if not self.temp_end:
-                self.temp_end = self.current_sample
-
-            if self.current_sample - self.temp_end < self.min_silence_samples:
-                return audio, 0, window_size_samples
-            else:
-                return np.array([], dtype=np.float16) if self.use_vad_result else audio, 0, window_size_samples
-
-
-    def detect_speech_iter(self, data, audio_in_int16 = False):
-        audio_block = data
-        wav = audio_block
-
-        is_final = False
-        voice_audio, speech_in_wav, last_silent_in_wav = self.apply_vad(wav)
-        print("speech, last silence",speech_in_wav, last_silent_in_wav)
-
-
-        if speech_in_wav > 0 :
-            self.last_silence_len= 0                
-            self.speech_len += speech_in_wav
-#            if self.activity_detected_callback is not None:
-#                self.activity_detected_callback()
-
-        self.last_silence_len +=  last_silent_in_wav
-        print("self.last_silence_len",self.last_silence_len, self.final_silence_limit,self.last_silence_len>= self.final_silence_limit)
-        print("self.speech_len, final_speech_limit",self.speech_len , self.final_speech_limit,self.speech_len >= self.final_speech_limit)
-        if self.last_silence_len>= self.final_silence_limit and self.speech_len >= self.final_speech_limit:
-            for i in range(10): print("TADY!!!")
-
-            is_final = True
-            self.last_silence_len= 0
-            self.speech_len = 0                
-
-        return voice_audio, is_final
-
-    def detect_user_speech(self, audio_stream, audio_in_int16 = False):
-        self.last_silence_len= 0
-        self.speech_len = 0
-
-        for data in audio_stream:  # replace with your condition of choice
-            yield self.detect_speech_iter(data, audio_in_int16)
+#        if not torch.is_tensor(x):
+#            try:
+#                x = torch.Tensor(x)
+#            except:
+#                raise TypeError("Audio cannot be casted to tensor. Cast it manually")
+        t = time.time()
+        a = self.iterator(x)
+        print("VAD took ",time.time()-t,"seconds")
+        return a
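A minimal sketch of driving the simplified controller; the 512-sample chunk is an assumption for illustration, and np.zeros stands in for real audio.

    import numpy as np
    from voice_activity_controller import VoiceActivityController

    vac = VoiceActivityController()
    chunk = np.zeros(512, dtype=np.float32)   # stand-in for a real 16 kHz audio chunk
    res = vac(chunk)
    if res is None:
        pass                                  # no state change in this chunk
    elif 'start' in res:
        print("voice starts at sample", res['start'])
    elif 'end' in res:
        print("voice ends at sample", res['end'])
    vac.reset()                               # reset between streams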
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -331,16 +331,14 @@
 
         self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
 
-    def init(self, keep_offset=False):
+    def init(self, offset=None):
         """run this when starting or restarting processing"""
         self.audio_buffer = np.array([],dtype=np.float32)
         self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
-        if not keep_offset:
-            self.buffer_time_offset = 0
-            self.transcript_buffer.last_commited_time = 0
-        else:
-            self.transcript_buffer.last_commited_time = self.buffer_time_offset
-
+        self.buffer_time_offset = 0
+        if offset is not None:
+            self.buffer_time_offset = offset
+        self.transcript_buffer.last_commited_time = self.buffer_time_offset
         self.commited = []
 
     def insert_audio_chunk(self, audio):
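The new offset parameter replaces the keep_offset flag: the caller now states the absolute stream time at which processing restarts. A hypothetical illustration (the FasterWhisperASR arguments are placeholders):

    from whisper_online import FasterWhisperASR, OnlineASRProcessor

    asr = FasterWhisperASR("en", "large-v2")   # placeholder language and model
    online = OnlineASRProcessor(asr)
    online.init(offset=12.5)                   # restart processing 12.5 s into the stream
    # timestamps committed from now on start at 12.5 s, because both
    # buffer_time_offset and last_commited_time are initialized to the offset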
@@ -529,27 +527,71 @@
         self.online_chunk_size = online_chunk_size
 
         self.online = OnlineASRProcessor(*a, **kw)
-        from voice_activity_controller import VoiceActivityController
-        self.vac = VoiceActivityController(use_vad_result = False)
+
+        # VAC:
+        import torch
+        model, _ = torch.hub.load(
+            repo_or_dir='snakers4/silero-vad',
+            model='silero_vad'
+        )
+        from silero_vad import VADIterator
+        self.vac = VADIterator(model)  # we use the default options: threshold 0.5, 100 ms min silence, 30 ms speech padding
 
         self.logfile = self.online.logfile
-
         self.init()
 
     def init(self):
         self.online.init()
         self.vac.reset_states()
         self.current_online_chunk_buffer_size = 0
+
         self.is_currently_final = False
+
+        self.status = None  # or "voice" or "nonvoice"
+        self.audio_buffer = np.array([],dtype=np.float32)
+        self.buffer_offset = 0  # in frames
+
+    def clear_buffer(self):
+        self.buffer_offset += len(self.audio_buffer)
+        self.audio_buffer = np.array([],dtype=np.float32)
 
 
     def insert_audio_chunk(self, audio):
-        r = self.vac.detect_speech_iter(audio,audio_in_int16=False)
-        audio, is_final = r
-        print(is_final)
-        self.is_currently_final = is_final
-        self.online.insert_audio_chunk(audio)
-        self.current_online_chunk_buffer_size += len(audio)
+        res = self.vac(audio)
+        print(res)
+        self.audio_buffer = np.append(self.audio_buffer, audio)
+
+        if res is not None:
+            frame = list(res.values())[0]
+            if 'start' in res and 'end' not in res:
+                self.status = 'voice'
+                send_audio = self.audio_buffer[frame-self.buffer_offset:]
+                self.online.init(offset=frame/self.SAMPLING_RATE)
+                self.online.insert_audio_chunk(send_audio)
+                self.current_online_chunk_buffer_size += len(send_audio)
+                self.clear_buffer()
+            elif 'end' in res and 'start' not in res:
+                self.status = 'nonvoice'
+                send_audio = self.audio_buffer[:frame-self.buffer_offset]
+                self.online.insert_audio_chunk(send_audio)
+                self.current_online_chunk_buffer_size += len(send_audio)
+                self.is_currently_final = True
+                self.clear_buffer()
+            else:
+                # This does not happen with the current VADIterator, which reports at most one event per chunk.
+                raise NotImplementedError("both start and end of voice in one chunk")
+        else:
+            if self.status == 'voice':
+                self.online.insert_audio_chunk(self.audio_buffer)
+                self.current_online_chunk_buffer_size += len(self.audio_buffer)
+            if self.status is not None:
+                self.clear_buffer()
+            else:  # we are at the beginning of the process and no voice has been detected yet
+                # Keep the last second of audio, because the VAD may later find the start of voice in it,
+                # but trim the rest to bound memory use.
+                self.buffer_offset += max(0,len(self.audio_buffer)-self.SAMPLING_RATE)
+                self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
+
 
     def process_iter(self):
         if self.is_currently_final:
@@ -559,13 +601,13 @@
             ret = self.online.process_iter()
             return ret
         else:
-            print("no online update, only VAD", file=self.logfile)
+            print("no online update, only VAD", self.status, file=self.logfile)
             return (None, None, "")
 
     def finish(self):
         ret = self.online.finish()
-        self.online.init(keep_offset=True)
         self.current_online_chunk_buffer_size = 0
+        self.is_currently_final = False
         return ret
 
 
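Taken together, a sketch of how the refactored VACOnlineASRProcessor is driven; the constructor arguments and the chunking are assumptions based on the surrounding code, and np.zeros stands in for real audio.

    import numpy as np
    from whisper_online import FasterWhisperASR, VACOnlineASRProcessor

    asr = FasterWhisperASR("en", "large-v2")      # placeholder language and model
    proc = VACOnlineASRProcessor(1.0, asr)        # 1.0 s online chunk size
    wav = np.zeros(5 * 16000, dtype=np.float32)   # stand-in for 5 s of real audio
    for i in range(0, len(wav), 512):             # feed the stream chunk by chunk
        proc.insert_audio_chunk(wav[i:i + 512])
        beg, end, text = proc.process_iter()
        if text:
            print(beg, end, text)
    print(proc.finish())                          # flush at the end of the stream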