Dominik Macháček 2023-04-19
faster-whisper support
@3605f32ffc97618d1bfb34062571462b4acc607b
README.md
--- README.md
+++ README.md
@@ -3,19 +3,24 @@
 
 ## Installation
 
+This code work with two kinds of backends. Both require
+
 ```
-pip install git+https://github.com/linto-ai/whisper-timestamped
-XDG_CACHE_HOME=$(pwd)/pip-cache pip install git+https://github.com/linto-ai/whisper-timestamped
 pip install librosa
 pip install opus-fast-mosestokenizer
-pip install torch
 ```
+
+The most recommended backend is [faster-whisper](https://github.com/guillaumekln/faster-whisper) with GPU support. Follow their instructions for NVIDIA libraries -- we succeeded with CUDNN 8.5.0 and CUDA 11.7. Install with `pip install faster-whisper`.
+
+Alternative, less restrictive, but slowe backend is [whisper-timestamped](https://github.com/linto-ai/whisper-timestamped): `pip install git+https://github.com/linto-ai/whisper-timestamped`
+
+The backend is loaded only when chosen. The unused one does not have to be installed.
 
 ## Usage
 
 ```
 (p3) $ python3 whisper_online.py -h
-usage: whisper_online.py [-h] [--min-chunk-size MIN_CHUNK_SIZE] [--model MODEL] [--model_dir MODEL_DIR] [--lan LAN] [--start_at START_AT] audio_path
+usage: whisper_online.py [-h] [--min-chunk-size MIN_CHUNK_SIZE] [--model MODEL] [--model_dir MODEL_DIR] [--lan LAN] [--start_at START_AT] [--backend {faster-whisper,whisper_timestamped}] audio_path
 
 positional arguments:
   audio_path
@@ -30,6 +35,8 @@
   --lan LAN, --language LAN
                         Language code for transcription, e.g. en,de,cs.
   --start_at START_AT   Start processing audio at this time.
+  --backend {faster-whisper,whisper_timestamped}
+                        Load only this backend for Whisper processing.
 ```
 
 Example:
whisper_online.py
--- whisper_online.py
+++ whisper_online.py
@@ -1,15 +1,10 @@
 #!/usr/bin/env python3
 import sys
 import numpy as np
-import whisper
-import whisper_timestamped
-import librosa
+import librosa  
 from functools import lru_cache
-import torch
 import time
 from mosestokenizer import MosesTokenizer
-import json
-
 
 @lru_cache
 def load_audio(fname):
@@ -22,10 +17,38 @@
     end_s = int(end*16000)
     return audio[beg_s:end_s]
 
-class WhisperASR:
-    def __init__(self, modelsize="small", lan="en", cache_dir="disk-cache-dir"):
+
+# Whisper backend
+
+class ASRBase:
+
+    def __init__(self, modelsize, lan, cache_dir):
         self.original_language = lan 
-        self.model = whisper.load_model(modelsize, download_root=cache_dir)
+
+        self.model = self.load_model(modelsize, cache_dir)
+
+    def load_model(self, modelsize, cache_dir):
+        raise NotImplemented("mus be implemented in the child class")
+
+    def transcribe(self, audio, init_prompt=""):
+        raise NotImplemented("mus be implemented in the child class")
+
+
+## requires imports:
+#      import whisper
+#      import whisper_timestamped
+
+class WhisperTimestampedASR(ASRBase):
+    """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
+    On the other hand, the installation for GPU could be easier.
+
+    If used, requires imports:
+        import whisper
+        import whisper_timestamped
+    """
+
+    def load_model(self, modelsize, cache_dir):
+        return whisper.load_model(modelsize, download_root=cache_dir)
 
     def transcribe(self, audio, init_prompt=""):
         result = whisper_timestamped.transcribe_timestamped(self.model, audio, language=self.original_language, initial_prompt=init_prompt, verbose=None, condition_on_previous_text=True)
@@ -39,6 +62,52 @@
                 t = (w["start"],w["end"],w["text"])
                 o.append(t)
         return o
+
+    def segments_end_ts(self, res):
+        return [s["end"] for s in res["segments"]]
+
+
+class FasterWhisperASR(ASRBase):
+    """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
+
+    Requires imports, if used:
+        import faster_whisper
+    """
+
+    def load_model(self, modelsize, cache_dir):
+        # cache_dir is not set, it seemed not working. Default ~/.cache/huggingface/hub is used.
+
+        # this worked fast and reliably on NVIDIA L40
+        model = WhisperModel(modelsize, device="cuda", compute_type="float16")
+
+        # or run on GPU with INT8
+        # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
+        #model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
+
+        # or run on CPU with INT8
+        # tested: works, but slow, appx 10-times than cuda FP16
+        #model = WhisperModel(model_size, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
+        return model
+
+    def transcribe(self, audio, init_prompt=""):
+        wt = False
+        segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True)
+        return list(segments)
+
+    def ts_words(self, segments):
+        o = []
+        for segment in segments:
+            for word in segment.words:
+                # stripping the spaces
+                w = word.word.strip()
+                t = (word.start, word.end, w)
+                o.append(t)
+        return o
+
+    def segments_end_ts(self, res):
+        return [s.end for s in res]
+
+
 
 def to_flush(sents, offset=0):
     # concatenates the timestamped words or sentences into one sequence that is flushed in one line
@@ -253,7 +322,7 @@
     def chunk_completed_segment(self, res):
         if self.commited == []: return
 
-        ends = [s["end"] for s in res["segments"]]
+        ends = self.asr.segments_end_ts(res)
 
         t = self.commited[-1][1]
 
@@ -320,6 +389,7 @@
 
 
 
+
 ## main:
 
 import argparse
@@ -330,6 +400,7 @@
 parser.add_argument('--model_dir', type=str, default='disk-cache-dir', help="the path where Whisper models are saved (or downloaded to). Default: ./disk-cache-dir")
 parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
 parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
+parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
 args = parser.parse_args()
 
 audio_path = args.audio_path
@@ -343,7 +414,18 @@
 
 t = time.time()
 print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True)
-asr = WhisperASR(lan=language, modelsize=size)
+#asr = WhisperASR(lan=language, modelsize=size)
+
+if args.backend == "faster-whisper":
+    from faster_whisper import WhisperModel
+    asr_cls = FasterWhisperASR
+else:
+    import whisper
+    import whisper_timestamped
+#    from whisper_timestamped_model import WhisperTimestampedASR
+    asr_cls = WhisperTimestampedASR
+
+asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_dir)
 e = time.time()
 print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr)
 
Add a comment
List