

Fix crash when using openai-api with whisper_online_server + refactored creation of the ASR into a factory method
@2d56ec87bf3a3b85cbf1f5a815fa68d7f49cf264
--- whisper_online.py
+++ whisper_online.py
... | ... | @@ -548,6 +548,37 @@ |
548 | 548 |
parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.') |
549 | 549 |
parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.') |
550 | 550 |
|
551 |
+def asr_factory(args, logfile=sys.stderr): |
|
552 |
+ """ |
|
553 |
+ Creates and configures an ASR instance based on the specified backend and arguments. |
|
554 |
+ """ |
|
555 |
+ backend = args.backend |
|
556 |
+ if backend == "openai-api": |
|
557 |
+ print("Using OpenAI API.", file=logfile) |
|
558 |
+ asr = OpenaiApiASR(lan=args.lan) |
|
559 |
+ else: |
|
560 |
+ if backend == "faster-whisper": |
|
561 |
+ from faster_whisper import FasterWhisperASR |
|
562 |
+ asr_cls = FasterWhisperASR |
|
563 |
+ else: |
|
564 |
+ from whisper_timestamped import WhisperTimestampedASR |
|
565 |
+ asr_cls = WhisperTimestampedASR |
|
566 |
+ |
|
567 |
+ # Only for FasterWhisperASR and WhisperTimestampedASR |
|
568 |
+ size = args.model |
|
569 |
+ t = time.time() |
|
570 |
+ print(f"Loading Whisper {size} model for {args.lan}...", file=logfile, end=" ", flush=True) |
|
571 |
+ asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir) |
|
572 |
+ e = time.time() |
|
573 |
+ print(f"done. It took {round(e-t,2)} seconds.", file=logfile) |
|
574 |
+ |
|
575 |
+ # Apply common configurations |
|
576 |
+ if getattr(args, 'vad', False): # Checks if VAD argument is present and True |
|
577 |
+ print("Setting VAD filter", file=logfile) |
|
578 |
+ asr.use_vad() |
|
579 |
+ |
|
580 |
+ return asr |
|
581 |
+ |
|
551 | 582 |
## main: |
552 | 583 |
|
553 | 584 |
if __name__ == "__main__": |
... | ... | @@ -575,28 +606,8 @@ |
575 | 606 |
duration = len(load_audio(audio_path))/SAMPLING_RATE |
576 | 607 |
print("Audio duration is: %2.2f seconds" % duration, file=logfile) |
577 | 608 |
|
609 |
+ asr = asr_factory(args, logfile=logfile) |
|
578 | 610 |
language = args.lan |
579 |
- |
|
580 |
- if args.backend == "openai-api": |
|
581 |
- print("Using OpenAI API.",file=logfile) |
|
582 |
- asr = OpenaiApiASR(lan=language) |
|
583 |
- else: |
|
584 |
- if args.backend == "faster-whisper": |
|
585 |
- asr_cls = FasterWhisperASR |
|
586 |
- else: |
|
587 |
- asr_cls = WhisperTimestampedASR |
|
588 |
- |
|
589 |
- size = args.model |
|
590 |
- t = time.time() |
|
591 |
- print(f"Loading Whisper {size} model for {language}...",file=logfile,end=" ",flush=True) |
|
592 |
- asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir) |
|
593 |
- e = time.time() |
|
594 |
- print(f"done. It took {round(e-t,2)} seconds.",file=logfile) |
|
595 |
- |
|
596 |
- if args.vad: |
|
597 |
- print("setting VAD filter",file=logfile) |
|
598 |
- asr.use_vad() |
|
599 |
- |
|
600 | 611 |
if args.task == "translate": |
601 | 612 |
asr.set_translate_task() |
602 | 613 |
tgt_language = "en" # Whisper translates into English |
--- whisper_online_server.py
+++ whisper_online_server.py
... | ... | @@ -24,35 +24,12 @@ |
24 | 24 |
size = args.model |
25 | 25 |
language = args.lan |
26 | 26 |
|
27 |
-t = time.time() |
|
28 |
-print(f"Loading Whisper {size} model for {language}...",file=sys.stderr,end=" ",flush=True) |
|
29 |
- |
|
30 |
-if args.backend == "faster-whisper": |
|
31 |
- from faster_whisper import WhisperModel |
|
32 |
- asr_cls = FasterWhisperASR |
|
33 |
-elif args.backend == "openai-api": |
|
34 |
- asr_cls = OpenaiApiASR |
|
35 |
-else: |
|
36 |
- import whisper |
|
37 |
- import whisper_timestamped |
|
38 |
-# from whisper_timestamped_model import WhisperTimestampedASR |
|
39 |
- asr_cls = WhisperTimestampedASR |
|
40 |
- |
|
41 |
-asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir) |
|
42 |
- |
|
27 |
+asr = asr_factory(args) |
|
43 | 28 |
if args.task == "translate": |
44 | 29 |
asr.set_translate_task() |
45 | 30 |
tgt_language = "en" |
46 | 31 |
else: |
47 | 32 |
tgt_language = language |
48 |
- |
|
49 |
-e = time.time() |
|
50 |
-print(f"done. It took {round(e-t,2)} seconds.",file=sys.stderr) |
|
51 |
- |
|
52 |
-if args.vad: |
|
53 |
- print("setting VAD filter",file=sys.stderr) |
|
54 |
- asr.use_vad() |
|
55 |
- |
|
56 | 33 |
|
57 | 34 |
min_chunk = args.min_chunk_size |
58 | 35 |
|