Spaces:
Build error
Build error
Rongjiehuang
committed on
Commit
•
d9a4587
1
Parent(s):
90d9164
update
Browse files
inference/base_tts_infer.py
CHANGED
@@ -78,7 +78,7 @@ class BaseTTSInfer:
|
|
78 |
# processed ref audio
|
79 |
ref_audio = inp['ref_audio']
|
80 |
processed_ref_audio = 'example/temp.wav'
|
81 |
-
voice_encoder = VoiceEncoder().
|
82 |
encoder = [self.ph_encoder, self.word_encoder]
|
83 |
EmotionEncoder.load_model(self.hparams['emotion_encoder_path'])
|
84 |
binarizer_cls = self.hparams.get("binarizer_cls", 'data_gen.tts.base_binarizerr.BaseBinarizer')
|
@@ -185,7 +185,7 @@ class BaseTTSInfer:
|
|
185 |
input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
|
186 |
|
187 |
# retrieve logits & take argmax
|
188 |
-
logits = self.asr_model(input_values.
|
189 |
predicted_ids = torch.argmax(logits, dim=-1)
|
190 |
|
191 |
# transcribe
|
|
|
78 |
# processed ref audio
|
79 |
ref_audio = inp['ref_audio']
|
80 |
processed_ref_audio = 'example/temp.wav'
|
81 |
+
voice_encoder = VoiceEncoder().to(self.device)
|
82 |
encoder = [self.ph_encoder, self.word_encoder]
|
83 |
EmotionEncoder.load_model(self.hparams['emotion_encoder_path'])
|
84 |
binarizer_cls = self.hparams.get("binarizer_cls", 'data_gen.tts.base_binarizerr.BaseBinarizer')
|
|
|
185 |
input_values = self.asr_processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
|
186 |
|
187 |
# retrieve logits & take argmax
|
188 |
+
logits = self.asr_model(input_values.to(self.device)).logits
|
189 |
predicted_ids = torch.argmax(logits, dim=-1)
|
190 |
|
191 |
# transcribe
|