mrfakename commited on
Commit
f6a409f
1 Parent(s): b315dd9

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. src/f5_tts/infer/utils_infer.py +4 -8
src/f5_tts/infer/utils_infer.py CHANGED
@@ -135,12 +135,10 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
135
  asr_pipe = None
136
 
137
 
138
- def initialize_asr_pipeline(device=device, dtype=None):
139
  if dtype is None:
140
  dtype = (
141
- torch.float16
142
- if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 6
143
- else torch.float32
144
  )
145
  global asr_pipe
146
  asr_pipe = pipeline(
@@ -170,12 +168,10 @@ def transcribe(ref_audio, language=None):
170
  # load model checkpoint for inference
171
 
172
 
173
- def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
174
  if dtype is None:
175
  dtype = (
176
- torch.float16
177
- if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 6
178
- else torch.float32
179
  )
180
  model = model.to(dtype)
181
 
 
135
  asr_pipe = None
136
 
137
 
138
+ def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  dtype = (
141
+ torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
 
 
142
  )
143
  global asr_pipe
144
  asr_pipe = pipeline(
 
168
  # load model checkpoint for inference
169
 
170
 
171
+ def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
172
  if dtype is None:
173
  dtype = (
174
+ torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
 
 
175
  )
176
  model = model.to(dtype)
177