

Merge branch 'main' into ayo-logging-fixes
@6db3f876147c0c9571c8f6d547c593fb3e234428
--- README.md
+++ README.md
@@ -3,41 +3,49 @@
 
 **Turning Whisper into Real-Time Transcription System**
 
-Demonstration paper, by Dominik Macháček, Raj Dabre, Ondřej Bojar, 2023
+Demonstration paper, by [Dominik Macháček](https://ufal.mff.cuni.cz/dominik-machacek), [Raj Dabre](https://prajdabre.github.io/), [Ondřej Bojar](https://ufal.mff.cuni.cz/ondrej-bojar), 2023
 
-Abstract: Whisper is one of the recent state-of-the-art multilingual speech recognition and translation models, however, it is not designed for real time transcription. In this paper, we build on top of Whisper and create Whisper-Streaming, an implementation of real-time speech transcription and translation of Whisper-like models. Whisper-Streaming uses local agreement policy with self-adaptive latency to enable streaming transcription. We show that Whisper-Streaming achieves high quality and 3.3 seconds latency on unsegmented long-form speech transcription test set, and we demonstrate its robustness and practical usability as a component in live transcription service at a multilingual conference.
+Abstract: Whisper is one of the recent state-of-the-art multilingual speech recognition and translation models, however, it is not designed for real-time transcription. In this paper, we build on top of Whisper and create Whisper-Streaming, an implementation of real-time speech transcription and translation of Whisper-like models. Whisper-Streaming uses local agreement policy with self-adaptive latency to enable streaming transcription. We show that Whisper-Streaming achieves high quality and 3.3 seconds latency on unsegmented long-form speech transcription test set, and we demonstrate its robustness and practical usability as a component in live transcription service at a multilingual conference.
 
 
-Paper in proceedings: http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/pdf/2023.ijcnlp-demo.3.pdf
-
-Demo video: https://player.vimeo.com/video/840442741
+[Paper PDF](https://aclanthology.org/2023.ijcnlp-demo.3.pdf), [Demo video](https://player.vimeo.com/video/840442741)
 
 [Slides](http://ufallab.ms.mff.cuni.cz/~machacek/pre-prints/AACL23-2.11.2023-Turning-Whisper-oral.pdf) -- 15 minutes oral presentation at IJCNLP-AACL 2023
 
-Please, cite us. [Bibtex citation](http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/bib/2023.ijcnlp-demo.3.bib):
+Please cite us. [ACL Anthology](https://aclanthology.org/2023.ijcnlp-demo.3/), [Bibtex citation](https://aclanthology.org/2023.ijcnlp-demo.3.bib):
 
 ```
-@InProceedings{machacek-dabre-bojar:2023:ijcnlp,
-  author = {Macháček, Dominik and Dabre, Raj and Bojar, Ondřej},
-  title = {Turning Whisper into Real-Time Transcription System},
-  booktitle = {System Demonstrations},
-  month = {November},
-  year = {2023},
-  address = {Bali, Indonesia},
-  publisher = {Asian Federation of Natural Language Processing},
-  pages = {17--24},
+@inproceedings{machacek-etal-2023-turning,
+    title = "Turning Whisper into Real-Time Transcription System",
+    author = "Mach{\'a}{\v{c}}ek, Dominik and
+      Dabre, Raj and
+      Bojar, Ond{\v{r}}ej",
+    editor = "Saha, Sriparna and
+      Sujaini, Herry",
+    booktitle = "Proceedings of the 13th International Joint Conference on Natural Language Processing and the 3rd Conference of the Asia-Pacific Chapter of the Association for Computational Linguistics: System Demonstrations",
+    month = nov,
+    year = "2023",
+    address = "Bali, Indonesia",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2023.ijcnlp-demo.3",
+    pages = "17--24",
 }
 ```
 
 ## Installation
 
-1) ``pip install librosa`` -- audio processing library
+1) ``pip install librosa soundfile`` -- audio processing library
 
 2) Whisper backend.
 
-Two alternative backends are integrated. The most recommended one is [faster-whisper](https://github.com/guillaumekln/faster-whisper) with GPU support. Follow their instructions for NVIDIA libraries -- we succeeded with CUDNN 8.5.0 and CUDA 11.7. Install with `pip install faster-whisper`.
+Several alternative backends are integrated. The most recommended one is [faster-whisper](https://github.com/guillaumekln/faster-whisper) with GPU support. Follow their instructions for NVIDIA libraries -- we succeeded with CUDNN 8.5.0 and CUDA 11.7. Install with `pip install faster-whisper`.
 
 Alternative, less restrictive, but slower backend is [whisper-timestamped](https://github.com/linto-ai/whisper-timestamped): `pip install git+https://github.com/linto-ai/whisper-timestamped`
+
+Thirdly, it's also possible to run this software with the [OpenAI Whisper API](https://platform.openai.com/docs/api-reference/audio/createTranscription). This solution is fast and requires no GPU -- a small VM is enough -- but you have to pay OpenAI for API access. Also note that, since each audio fragment is processed multiple times, the [price](https://openai.com/pricing) will be higher than the pricing page suggests, so keep an eye on costs while using it. Setting a higher chunk size reduces the costs significantly.
+Install with: `pip install openai`
+
+To run with the openai-api backend, make sure that your [OpenAI API key](https://platform.openai.com/api-keys) is set in the `OPENAI_API_KEY` environment variable. For example, before running, do: `export OPENAI_API_KEY=sk-xxx`, with *sk-xxx* replaced by your actual API key.
 
 The backend is loaded only when chosen. The unused one does not have to be installed.
 
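As a quick check that `pip install openai` and `OPENAI_API_KEY` are set up, here is a minimal sketch (not from this repository; `sample.wav` is a placeholder file name) that calls the same `whisper-1` endpoint the openai-api backend uses:

```python
# Minimal prerequisite check for the openai-api backend.
# Assumes `pip install openai` and `export OPENAI_API_KEY=...` were done;
# "sample.wav" is a hypothetical 16 kHz mono recording of your own.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

with open("sample.wav", "rb") as f:
    transcript = client.audio.transcriptions.create(
        model="whisper-1",
        file=f,
        response_format="verbose_json",
    )

print(transcript.text)
```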
@@ -69,7 +77,7 @@
 
 ```
 usage: whisper_online.py [-h] [--min-chunk-size MIN_CHUNK_SIZE] [--model {tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large}] [--model_cache_dir MODEL_CACHE_DIR] [--model_dir MODEL_DIR] [--lan LAN] [--task {transcribe,translate}]
-                         [--backend {faster-whisper,whisper_timestamped}] [--vad] [--buffer_trimming {sentence,segment}] [--buffer_trimming_sec BUFFER_TRIMMING_SEC] [--start_at START_AT] [--offline] [--comp_unaware]
+                         [--backend {faster-whisper,whisper_timestamped,openai-api}] [--vad] [--buffer_trimming {sentence,segment}] [--buffer_trimming_sec BUFFER_TRIMMING_SEC] [--start_at START_AT] [--offline] [--comp_unaware]
                          audio_path
 
 positional arguments:
@@ -86,10 +94,10 @@
   --model_dir MODEL_DIR
                         Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.
   --lan LAN, --language LAN
-                        Language code for transcription, e.g. en,de,cs.
+                        Source language code, e.g. en,de,cs, or 'auto' for language detection.
   --task {transcribe,translate}
                         Transcribe or translate.
-  --backend {faster-whisper,whisper_timestamped}
+  --backend {faster-whisper,whisper_timestamped,openai-api}
                         Load only this backend for Whisper processing.
   --vad                 Use VAD = voice activity detection, with the default parameters.
   --buffer_trimming {sentence,segment}
@@ -147,7 +155,7 @@
 
 This pseudocode describes the interface that we suggest for your implementation. You can implement any features that you need for your application.
 
-```
+```python
 from whisper_online import *
 
 src_lan = "en"  # source language
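# A hedged sketch of how the rest of this loop might continue with the new
# openai-api backend. OnlineASRProcessor and its process_iter()/finish() methods
# are assumed from the full README example; receive_new_audio_chunk() and
# audio_has_not_ended are placeholders for your own audio source.

asr = OpenaiApiASR(lan=src_lan)    # needs OPENAI_API_KEY set in the environment
online = OnlineASRProcessor(asr)   # streaming wrapper around the backend

while audio_has_not_ended:
    a = receive_new_audio_chunk()  # 16 kHz float32 numpy chunk (placeholder helper)
    online.insert_audio_chunk(a)
    o = online.process_iter()
    print(o)                       # partial output: (beg_s, end_s, "text")

o = online.finish()                # flush whatever remains at the end
print(o)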
@@ -216,12 +224,20 @@
 re-process confirmed sentence prefixes and skip them, making sure they don't
 overlap, and we limit the processing buffer window.
 
-Contributions are welcome.
-
 ### Performance evaluation
 
 [See the paper.](http://www.afnlp.org/conferences/ijcnlp2023/proceedings/main-demo/cdrom/pdf/2023.ijcnlp-demo.3.pdf)
 
+### Contributions
+
+Contributions are welcome. We acknowledge especially:
+
+- [The GitHub contributors](https://github.com/ufal/whisper_streaming/graphs/contributors) for their pull requests with new features and bugfixes.
+- [The translation of this repo into Chinese](https://github.com/Gloridust/whisper_streaming_CN).
+- [Ondřej Plátek](https://opla.cz/) for the paper pre-review.
+- [Peter Polák](https://ufal.mff.cuni.cz/peter-polak) for the original idea.
+- The UEDIN team of the [ELITR project](https://elitr.eu) for the original line_packet.py.
+
 
 ## Contact
 
--- whisper_online.py
+++ whisper_online.py
@@ -7,10 +7,13 @@
 import logging
 
 
+import io
+import soundfile as sf
+import math
 
 @lru_cache
 def load_audio(fname):
-    a, _ = librosa.load(fname, sr=16000)
+    a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
     return a
 
 def load_audio_chunk(fname, beg, end):
@@ -31,7 +34,10 @@
         self.logfile = logfile
 
         self.transcribe_kargs = {}
-        self.original_language = lan
+        if lan == "auto":
+            self.original_language = None
+        else:
+            self.original_language = lan
 
         self.model = self.load_model(modelsize, cache_dir, model_dir)
 
@@ -55,6 +61,7 @@
 
     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
         import whisper
+        import whisper_timestamped
         from whisper_timestamped import transcribe_timestamped
         self.transcribe_timestamped = transcribe_timestamped
         if model_dir is not None:
@@ -119,8 +126,11 @@
         return model
 
     def transcribe(self, audio, init_prompt=""):
+
         # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
         segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt, beam_size=5, word_timestamps=True, condition_on_previous_text=True, **self.transcribe_kargs)
+        #print(info)  # info contains language detection result
+
         return list(segments)
 
     def ts_words(self, segments):
@@ -141,6 +151,93 @@
 
     def set_translate_task(self):
         self.transcribe_kargs["task"] = "translate"
+
+
+class OpenaiApiASR(ASRBase):
+    """Uses OpenAI's Whisper API for audio transcription."""
+
+    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
+        self.logfile = logfile
+
+        self.modelname = "whisper-1"
+        self.original_language = None if lan == "auto" else lan  # ISO-639-1 language code
+        self.response_format = "verbose_json"
+        self.temperature = temperature
+
+        self.load_model()
+
+        self.use_vad_opt = False
+
+        # reset the task in set_translate_task
+        self.task = "transcribe"
+
+    def load_model(self, *args, **kwargs):
+        from openai import OpenAI
+        self.client = OpenAI()
+
+        self.transcribed_seconds = 0  # for logging how many seconds were processed by API, to know the cost
+
+
+    def ts_words(self, segments):
+        no_speech_segments = []
+        if self.use_vad_opt:
+            for segment in segments.segments:
+                # TODO: threshold can be set from outside
+                if segment["no_speech_prob"] > 0.8:
+                    no_speech_segments.append((segment.get("start"), segment.get("end")))
+
+        o = []
+        for word in segments.words:
+            start = word.get("start")
+            end = word.get("end")
+            if any(s[0] <= start <= s[1] for s in no_speech_segments):
+                # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
+                continue
+            o.append((start, end, word.get("word")))
+        return o
+
+
+    def segments_end_ts(self, res):
+        return [s["end"] for s in res.words]
+
+    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
+        # Write the audio data to a buffer
+        buffer = io.BytesIO()
+        buffer.name = "temp.wav"
+        sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
+        buffer.seek(0)  # Reset buffer's position to the beginning
+
+        self.transcribed_seconds += math.ceil(len(audio_data)/16000)  # it rounds up to the whole seconds
+
+        params = {
+            "model": self.modelname,
+            "file": buffer,
+            "response_format": self.response_format,
+            "temperature": self.temperature,
+            "timestamp_granularities": ["word", "segment"]
+        }
+        if self.task != "translate" and self.original_language:
+            params["language"] = self.original_language
+        if prompt:
+            params["prompt"] = prompt
+
+        if self.task == "translate":
+            proc = self.client.audio.translations
+        else:
+            proc = self.client.audio.transcriptions
+
+        # Process transcription/translation
+        transcript = proc.create(**params)
+        logging.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
+
+        return transcript
+
+    def use_vad(self):
+        self.use_vad_opt = True
+
+    def set_translate_task(self):
+        self.task = "translate"
+
 
 
 
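For orientation, a hedged sketch of exercising the new `OpenaiApiASR` wrapper directly; `sample.wav` is a placeholder file name and `OPENAI_API_KEY` is assumed to be set:

```python
# Hedged sketch: feed a 16 kHz float32 numpy array through the new wrapper.
# "sample.wav" is a hypothetical local recording.
import librosa
import numpy as np

audio, _ = librosa.load("sample.wav", sr=16000, dtype=np.float32)

asr = OpenaiApiASR(lan="en")    # or lan="auto" to let the API detect the language
result = asr.transcribe(audio)  # verbose_json response object from the API
words = asr.ts_words(result)    # [(start_s, end_s, word), ...]
print(words[:10])
```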
@@ -237,9 +334,6 @@
 
         self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
         self.commited = []
-        self.last_chunked_at = 0
-
-        self.silence_iters = 0
 
     def insert_audio_chunk(self, audio):
         self.audio_buffer = np.append(self.audio_buffer, audio)
@@ -249,7 +343,7 @@
         "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
         """
         k = max(0,len(self.commited)-1)
-        while k > 0 and self.commited[k-1][1] > self.last_chunked_at:
+        while k > 0 and self.commited[k-1][1] > self.buffer_time_offset:
             k -= 1
 
         p = self.commited[:k]
@@ -362,7 +456,6 @@
         cut_seconds = time - self.buffer_time_offset
         self.audio_buffer = self.audio_buffer[int(cut_seconds*self.SAMPLING_RATE):]
         self.buffer_time_offset = time
-        self.last_chunked_at = time
 
     def words_to_sentences(self, words):
         """Uses self.tokenizer for sentence segmentation of words.
@@ -456,12 +549,41 @@
     parser.add_argument('--model', type=str, default='large-v2', choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large".split(","),help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.")
     parser.add_argument('--model_cache_dir', type=str, default=None, help="Overriding the default model cache dir where models downloaded from the hub are saved")
     parser.add_argument('--model_dir', type=str, default=None, help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
-    parser.add_argument('--lan', '--language', type=str, default='en', help="Language code for transcription, e.g. en,de,cs.")
+    parser.add_argument('--lan', '--language', type=str, default='auto', help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
     parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe","translate"],help="Transcribe or translate.")
-    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped"],help='Load only this backend for Whisper processing.')
+    parser.add_argument('--backend', type=str, default="faster-whisper", choices=["faster-whisper", "whisper_timestamped", "openai-api"],help='Load only this backend for Whisper processing.')
     parser.add_argument('--vad', action="store_true", default=False, help='Use VAD = voice activity detection, with the default parameters.')
     parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
     parser.add_argument('--buffer_trimming_sec', type=float, default=15, help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
+
+def asr_factory(args, logfile=sys.stderr):
+    """
+    Creates and configures an ASR instance based on the specified backend and arguments.
+    """
+    backend = args.backend
+    if backend == "openai-api":
+        logging.debug("Using OpenAI API.")
+        asr = OpenaiApiASR(lan=args.lan)
+    else:
+        if backend == "faster-whisper":
+            asr_cls = FasterWhisperASR
+        else:
+            asr_cls = WhisperTimestampedASR
+
+        # Only for FasterWhisperASR and WhisperTimestampedASR
+        size = args.model
+        t = time.time()
+        logging.debug(f"Loading Whisper {size} model for {args.lan}...")
+        asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+        e = time.time()
+        logging.debug(f"done. It took {round(e-t,2)} seconds.")
+
+    # Apply common configurations
+    if getattr(args, 'vad', False):  # Checks if VAD argument is present and True
+        logging.info("Setting VAD filter")
+        asr.use_vad()
+
+    return asr
 
 ## main:
 
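Because `asr_factory()` only reads attributes from `args`, it can also be reused outside the command-line entry points; a minimal sketch with a hand-built namespace (the field values below are illustrative, not prescribed by this patch):

```python
# Hedged sketch: build the same fields the argparse options would provide
# and let asr_factory() pick and configure the backend.
import argparse
import sys

args = argparse.Namespace(
    backend="openai-api",   # or "faster-whisper" / "whisper_timestamped"
    lan="auto",             # 'auto' lets the backend detect the source language
    model="large-v2",       # only used by the local backends
    model_cache_dir=None,
    model_dir=None,
    vad=False,
)

asr = asr_factory(args, logfile=sys.stderr)
```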
@@ -490,18 +612,8 @@
     duration = len(load_audio(audio_path))/SAMPLING_RATE
     logging.info("Audio duration is: %2.2f seconds" % duration)
 
-    size = args.model
+    asr = asr_factory(args, logfile=logfile)
     language = args.lan
-
-    t = time.time()
-    logging.info(f"Loading Whisper {size} model for {language}...")
-
-    if args.backend == "faster-whisper":
-        asr_cls = FasterWhisperASR
-    else:
-        asr_cls = WhisperTimestampedASR
-
-    asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
 
     if args.task == "translate":
         asr.set_translate_task()
@@ -509,15 +621,6 @@
     else:
         tgt_language = language  # Whisper transcribes in this language
 
-
-    e = time.time()
-    logging.info(f"done. It took {round(e-t,2)} seconds.")
-
-    if args.vad:
-        logging.info("setting VAD filter")
-        asr.use_vad()
-
-
     min_chunk = args.min_chunk_size
     if args.buffer_trimming == "sentence":
         tokenizer = create_tokenizer(tgt_language)
@@ -548,7 +651,8 @@
             print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),file=logfile,flush=True)
             print("%1.4f %1.0f %1.0f %s" % (now*1000, o[0]*1000,o[1]*1000,o[2]),flush=True)
         else:
-            print("here?", o,file=logfile,flush=True)
+            # No text, so no output
+            pass
 
     if args.offline:  ## offline mode processing (for testing/debugging)
         a = load_audio(audio_path)
--- whisper_online_server.py
+++ whisper_online_server.py
@@ -5,6 +5,7 @@
 import argparse
 import os
 import logging
+import numpy as np
 
 parser = argparse.ArgumentParser()
 
@@ -33,34 +34,13 @@
 size = args.model
 language = args.lan
 
-t = time.time()
-logging.debug(f"Loading Whisper {size} model for {language}...")
-
-if args.backend == "faster-whisper":
-    from faster_whisper import WhisperModel
-    asr_cls = FasterWhisperASR
-    logging.getLogger("faster_whisper").setLevel(logging.WARNING)
-else:
-    import whisper
-    import whisper_timestamped
-#    from whisper_timestamped_model import WhisperTimestampedASR
-    asr_cls = WhisperTimestampedASR
-
-asr = asr_cls(modelsize=size, lan=language, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+asr = asr_factory(args)
 
 if args.task == "translate":
     asr.set_translate_task()
     tgt_language = "en"
 else:
     tgt_language = language
-
-e = time.time()
-logging.debug(f"done. It took {round(e-t,2)} seconds.")
-
-if args.vad:
-    logging.debug("setting VAD filter")
-    asr.use_vad()
-
 
 min_chunk = args.min_chunk_size
 
@@ -141,7 +121,7 @@
             if not raw_bytes:
                 break
             sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1,endian="LITTLE",samplerate=SAMPLING_RATE, subtype="PCM_16",format="RAW")
-            audio, _ = librosa.load(sf,sr=SAMPLING_RATE)
+            audio, _ = librosa.load(sf,sr=SAMPLING_RATE,dtype=np.float32)
             out.append(audio)
         if not out:
             return None