mrfakename
commited on
Commit
•
734bf0e
1
Parent(s):
7f66593
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- model/cfm.py +1 -1
- model/utils.py +1 -1
model/cfm.py
CHANGED
@@ -96,7 +96,7 @@ class CFM(nn.Module):
|
|
96 |
):
|
97 |
self.eval()
|
98 |
|
99 |
-
if cond.device
|
100 |
cond = cond.half()
|
101 |
|
102 |
# raw wave
|
|
|
96 |
):
|
97 |
self.eval()
|
98 |
|
99 |
+
if cond.device == torch.device('cuda'):
|
100 |
cond = cond.half()
|
101 |
|
102 |
# raw wave
|
model/utils.py
CHANGED
@@ -555,7 +555,7 @@ def repetition_found(text, length = 2, tolerance = 10):
|
|
555 |
# load model checkpoint for inference
|
556 |
|
557 |
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
558 |
-
if device
|
559 |
model = model.half()
|
560 |
|
561 |
ckpt_type = ckpt_path.split(".")[-1]
|
|
|
555 |
# load model checkpoint for inference
|
556 |
|
557 |
def load_checkpoint(model, ckpt_path, device, use_ema = True):
|
558 |
+
if device == "cuda":
|
559 |
model = model.half()
|
560 |
|
561 |
ckpt_type = ckpt_path.split(".")[-1]
|