Dominik Macháček 2024-10-05
FixedVADIterator: support chunk sizes other than 512 with Silero VAD v5
issue #116
@c5798e1ca8b6558b1453cc7b0f2e203f756ac2d0
silero_vad.py
--- silero_vad.py
+++ silero_vad.py
@@ -94,4 +94,41 @@
 
         return None
 
+#######################
+# this is our workaround for Silero v5 requiring at least 512-sized audio chunks 
+# (see https://github.com/ufal/whisper_streaming/issues/116 )
 
+import numpy as np
+class FixedVADIterator(VADIterator):
+    """VADIterator that buffers incoming audio until at least 512 samples are available."""
+    def reset_states(self):
+        super().reset_states()
+        self.buffer = np.array([],dtype=np.float32)
+
+    def __call__(self, x, return_seconds=False):
+        self.buffer = np.append(self.buffer, x) 
+        if len(self.buffer) >= 512:
+            ret = super().__call__(self.buffer, return_seconds=return_seconds)  # NOTE(review): the whole buffer (possibly more than 512 samples) is passed on -- confirm the model accepts that
+            self.buffer = np.array([],dtype=np.float32)
+            return ret
+        return None
+
+if __name__ == "__main__":
+    # Test/demonstrate the need for FixedVADIterator: feed a 512-sample chunk, then a 511-sample one.
+
+    import torch
+    model, _ = torch.hub.load(
+        repo_or_dir='snakers4/silero-vad',
+        model='silero_vad'
+    )
+    vac = FixedVADIterator(model)
+#   vac = VADIterator(model)  # swap this in to see the second case below crash
+
+    # A 512-sample chunk works with both FixedVADIterator and the plain VADIterator.
+    audio_buffer = np.array([0]*(512),dtype=np.float32)
+    vac(audio_buffer)
+
+    # A 511-sample chunk crashes the plain (non-fixed) VADIterator with
+    # ops.prim.RaiseException("Input audio chunk is too short", "builtins.ValueError")
+    audio_buffer = np.array([0]*(512-1),dtype=np.float32)
+    vac(audio_buffer)
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -531,7 +531,7 @@
         # VAC:
         import torch
         model, _ = torch.hub.load(
-            repo_or_dir='snakers4/silero-vad:v4.0',
+            repo_or_dir='snakers4/silero-vad',
             model='silero_vad'
         )
         from silero_vad import VADIterator
Add a comment
List