Dominik Macháček 2024-01-26
missing features in openai-api, PR #52
@997a653f425fae2f4ada5664fffbd61376ae386f
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -6,8 +6,7 @@
 import time
 import io
 import soundfile as sf
-
-
+import math
 
 @lru_cache
 def load_audio(fname):
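The new math import backs the per-request cost counter added further down. A quick check of the rounding, assuming 16 kHz mono audio such as load_audio() returns:

    import math
    samples = 19200                    # 1.2 seconds of audio at 16 kHz
    print(math.ceil(samples / 16000))  # -> 2: a partial second counts as a whole one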
@@ -153,24 +152,34 @@
 class OpenaiApiASR(ASRBase):
     """Uses OpenAI's Whisper API for audio transcription."""
 
-    def __init__(self, modelsize=None, lan=None, cache_dir=None, model_dir=None, response_format="verbose_json", temperature=0):
-        self.modelname = "whisper-1"  # modelsize is not used but kept for interface consistency
+    def __init__(self, lan=None, response_format="verbose_json", temperature=0, logfile=sys.stderr):
+        self.logfile = logfile
+
+        self.modelname = "whisper-1"
         self.language = lan  # ISO-639-1 language code
         self.response_format = response_format
         self.temperature = temperature
-        self.model = self.load_model(modelsize, cache_dir, model_dir)
+
+        self.load_model()
+
+        self.use_vad_opt = False  # named so it does not shadow the use_vad() method below
+
+        # the task is overridden by set_translate_task()
+        self.task = "transcribe"
 
     def load_model(self, *args, **kwargs):
         from openai import OpenAI
         self.client = OpenAI()
-        # Since we're using the OpenAI API, there's no model to load locally.
-        print("Model configuration is set to use the OpenAI Whisper API.")
+
+        self.transcribed_seconds = 0  # accumulates seconds of audio sent to the API, for cost logging
+
 
     def ts_words(self, segments):
         o = []
         for segment in segments:
-            # Skip segments containing no speech
-            if segment["no_speech_prob"] > 0.8:
+            # If VAD is on, skip segments that contain no speech.
+            # TODO: make the 0.8 threshold configurable
+            if self.use_vad_opt and segment["no_speech_prob"] > 0.8:
                 continue
 
             # Splitting the text into words and filtering out empty strings
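For reference, a verbose_json response carries one entry per decoded segment. A minimal sketch of the filter above, assuming segments arrive as plain dicts with the fields the OpenAI schema documents (the values here are made up):

    segments = [
        {"start": 0.0, "end": 2.4, "text": " Hello world.", "no_speech_prob": 0.02},
        {"start": 2.4, "end": 4.0, "text": " ...", "no_speech_prob": 0.97},
    ]
    use_vad_opt = True
    kept = [s for s in segments
            if not (use_vad_opt and s["no_speech_prob"] > 0.8)]
    # only the first segment survives; the second is treated as silence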
@@ -203,22 +212,38 @@
         sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
         buffer.seek(0)  # Reset buffer's position to the beginning
 
-        # Prepare transcription parameters
-        transcription_params = {
+        self.transcribed_seconds += math.ceil(len(audio_data)/16000)  # rounded up to whole seconds
+
+        params = {
             "model": self.modelname,
             "file": buffer,
             "response_format": self.response_format,
             "temperature": self.temperature
         }
-        if self.language:
-            transcription_params["language"] = self.language
+        if self.task != "translate" and self.language:
+            params["language"] = self.language
         if prompt:
-            transcription_params["prompt"] = prompt
+            params["prompt"] = prompt
 
-        # Perform the transcription
-        transcript = self.client.audio.transcriptions.create(**transcription_params)
+        if self.task == "translate":
+            proc = self.client.audio.translations
+        else:
+            proc = self.client.audio.transcriptions
+
+        # Run the transcription or translation request
+
+        transcript = proc.create(**params)
+        print(f"OpenAI API processed {self.transcribed_seconds} seconds in total",file=self.logfile)
 
         return transcript.segments
+
+    def use_vad(self):
+        self.use_vad_opt = True
+
+    def set_translate_task(self):
+        self.task = "translate"
+
+
 
 
 class HypothesisBuffer:
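Detached from the class, the dispatch above just picks one of two endpoints on the same client. A minimal sketch, assuming openai>=1.0, OPENAI_API_KEY set in the environment, and a hypothetical sample.wav; note that the translations endpoint takes no language parameter, which is why the patch guards the "language" key with self.task != "translate":

    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    with open("sample.wav", "rb") as f:
        transcript = client.audio.transcriptions.create(
            model="whisper-1",
            file=f,
            response_format="verbose_json",  # includes per-segment timestamps
            temperature=0,
            language="en",  # transcriptions only; drop this for audio.translations
        )
    for segment in transcript.segments:
        print(segment)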
@@ -563,34 +588,33 @@
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     print("Audio duration is: %2.2f seconds" % duration, file=logfile)
 
-    size = args.model
     language = args.lan
 
-    t = time.time()
-    print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
-
-    if args.backend == "faster-whisper":
-        asr_cls = FasterWhisperASR
-    elif args.backend == "openai-api":
-        asr_cls = OpenaiApiASR
+    if args.backend == "openai-api":
+        print("Using OpenAI API.",file=logfile)
+        asr = OpenaiApiASR(lan=language)
     else:
-        asr_cls = WhisperTimestampedASR
+        if args.backend == "faster-whisper":
+            asr_cls = FasterWhisperASR
+        else:
+            asr_cls = WhisperTimestampedASR
 
-    asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        size = args.model
+        t = time.time()
+        print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True)
+        asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
+
+        if args.vad:
+            print("setting VAD filter",file=logfile)
+            asr.use_vad()
 
     if args.task == "translate":
         asr.set_translate_task()
         tgt_language = "en"  # Whisper translates into English
     else:
         tgt_language = language  # Whisper transcribes in this language
-
-
-    e = time.time()
-    print(f"done. It took {round(e-t,2)} seconds.",file=logfile)
-
-    if args.vad:
-        print("setting VAD filter",file=logfile)
-        asr.use_vad()
 
     
     min_chunk = args.min_chunk_size
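Putting the pieces together, the new branch in main() can be exercised directly. A hedged usage sketch (sample.wav is a placeholder; ts_words follows this file's (begin, end, word) convention):

    asr = OpenaiApiASR(lan="en")      # OpenAI API backend, always whisper-1
    asr.use_vad()                     # drop segments the API marks as non-speech
    # asr.set_translate_task()        # uncomment to translate into English instead
    audio = load_audio("sample.wav")  # 16 kHz float32 array, as elsewhere in this file
    segments = asr.transcribe(audio)
    words = asr.ts_words(segments)    # [(begin, end, word), ...]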