# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import subprocess
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional, Tuple

import numpy as np
import sherpa_onnx
from transformers import MarianMTModel, MarianTokenizer

from model import sample_rate


@dataclass
class Segment:
    start: float
    duration: float
    text: str = ""
    cn_text: str = ""

    @property
    def end(self):
        return self.start + self.duration

    def __str__(self):
        # Format the segment as an SRT entry: timestamps, English text,
        # then the Chinese translation.
        s = f"0{timedelta(seconds=self.start)}"[:-3]
        s += " --> "
        s += f"0{timedelta(seconds=self.end)}"[:-3]
        s = s.replace(".", ",")
        s += "\n"
        s += self.text
        s += "\n"
        s += self.cn_text
        return s


def decode(
    recognizer: sherpa_onnx.OfflineRecognizer,
    vad: sherpa_onnx.VoiceActivityDetector,
    punct: Optional[sherpa_onnx.OfflinePunctuation],
    filename: str,
) -> Tuple[str, str]:
    """Decode a media file and return (srt_content, full_text)."""
    ffmpeg_cmd = [
        "ffmpeg",
        "-i",
        filename,
        "-f",
        "s16le",
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",
        "-ar",
        str(sample_rate),
        "-",
    ]

    process = subprocess.Popen(
        ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
    )

    frames_per_read = int(sample_rate * 100)  # 100 seconds

    window_size = 512

    buffer = []
    segment_list = []

    logging.info("Started!")

    all_text = []

    is_last = False
    while True:
        # *2 because int16_t has two bytes
        data = process.stdout.read(frames_per_read * 2)
        if not data:
            if is_last:
                break
            # Feed 1 second of silence at the end so that the VAD
            # flushes the final speech segment.
            is_last = True
            data = np.zeros(sample_rate, dtype=np.int16)

        samples = np.frombuffer(data, dtype=np.int16)
        samples = samples.astype(np.float32) / 32768

        buffer = np.concatenate([buffer, samples])
        while len(buffer) > window_size:
            vad.accept_waveform(buffer[:window_size])
            buffer = buffer[window_size:]

        streams = []
        segments = []
        while not vad.empty():
            segment = Segment(
                start=vad.front.start / sample_rate,
                duration=len(vad.front.samples) / sample_rate,
            )
            segments.append(segment)

            stream = recognizer.create_stream()
            stream.accept_waveform(sample_rate, vad.front.samples)
            streams.append(stream)

            vad.pop()

        for s in streams:
            recognizer.decode_stream(s)

        for seg, stream in zip(segments, streams):
            en_text = stream.result.text.strip()
            seg.text = en_text
            if len(seg.text) == 0:
                logging.info("Skip empty segment")
                continue

            seg.cn_text = _llm_translator.translate(en_text)

            if len(all_text) == 0:
                all_text.append(seg.text)
            elif len(all_text[-1][0].encode()) == 1 and len(seg.text[0].encode()) == 1:
                # Both the previous and the current segments start with a
                # single-byte (ASCII) character, so separate them with a space.
                all_text.append(" ")
                all_text.append(seg.text)
            else:
                all_text.append(seg.text)

            if punct is not None:
                seg.text = punct.add_punctuation(seg.text)

            segment_list.append(seg)

    all_text = "".join(all_text)
    if punct is not None:
        all_text = punct.add_punctuation(all_text)

    return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1)), all_text


def translate_en_to_cn(src_text: str) -> str:
    model_name = "Helsinki-NLP/opus-mt-en-zh"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    translated = model.generate(
        **tokenizer(src_text, return_tensors="pt", padding=True)
    )
    res = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
    return "".join(res)


class LLMTranslator:
    _tokenizer: MarianTokenizer
    _model: MarianMTModel

    def __init__(self):
        model_name = "Helsinki-NLP/opus-mt-en-zh"
        self._tokenizer = MarianTokenizer.from_pretrained(model_name)
        self._model = MarianMTModel.from_pretrained(model_name)

    def translate(self, src_text: str) -> str:
        translated = self._model.generate(
            **self._tokenizer(src_text, return_tensors="pt", padding=True)
        )
        res = [self._tokenizer.decode(t, skip_special_tokens=True) for t in translated]
        return "".join(res)


_llm_translator = LLMTranslator()
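

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original app) showing how decode()
    # could be driven from the command line. The model file paths below
    # (silero_vad.onnx and the Whisper encoder/decoder/tokens files) are
    # assumed placeholders; adjust them to wherever your models actually live.
    import sys

    input_filename = sys.argv[1]  # any audio/video file that ffmpeg can read

    # Silero VAD provided by sherpa_onnx.
    vad_config = sherpa_onnx.VadModelConfig()
    vad_config.silero_vad.model = "./silero_vad.onnx"  # assumed local path
    vad_config.sample_rate = sample_rate
    vad = sherpa_onnx.VoiceActivityDetector(vad_config, buffer_size_in_seconds=180)

    # A Whisper-based offline recognizer; any sherpa_onnx offline model works.
    recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
        encoder="./tiny.en-encoder.onnx",  # assumed local path
        decoder="./tiny.en-decoder.onnx",  # assumed local path
        tokens="./tiny.en-tokens.txt",  # assumed local path
        num_threads=2,
    )

    # Punctuation restoration is optional; pass None to skip it.
    punct = None

    srt_content, full_text = decode(recognizer, vad, punct, input_filename)
    with open("output.srt", "w", encoding="utf-8") as f:
        f.write(srt_content)
    print(full_text)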