mrfakename committed on
Commit
b315dd9
1 Parent(s): b4fc33b

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.

pyproject.toml CHANGED
@@ -15,7 +15,7 @@ classifiers = [
15
  ]
16
  dependencies = [
17
  "accelerate>=0.33.0",
18
- "bitsandbytes>0.37.0",
19
  "cached_path",
20
  "click",
21
  "datasets",
 
15
  ]
16
  dependencies = [
17
  "accelerate>=0.33.0",
18
+ "bitsandbytes>0.37.0; platform_machine != 'arm64' and platform_system != 'Darwin'",
19
  "cached_path",
20
  "click",
21
  "datasets",
src/f5_tts/api.py CHANGED
@@ -3,7 +3,6 @@ import sys
3
  from importlib.resources import files
4
 
5
  import soundfile as sf
6
- import torch
7
  import tqdm
8
  from cached_path import cached_path
9
 
@@ -43,9 +42,12 @@ class F5TTS:
43
  self.mel_spec_type = vocoder_name
44
 
45
  # Set device
46
- self.device = device or (
47
- "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
48
- )
 
 
 
49
 
50
  # Load models
51
  self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
 
3
  from importlib.resources import files
4
 
5
  import soundfile as sf
 
6
  import tqdm
7
  from cached_path import cached_path
8
 
 
42
  self.mel_spec_type = vocoder_name
43
 
44
  # Set device
45
+ if device is not None:
46
+ self.device = device
47
+ else:
48
+ import torch
49
+
50
+ self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
51
 
52
  # Load models
53
  self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
src/f5_tts/infer/speech_edit.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
 
 
 
3
  import torch
4
  import torch.nn.functional as F
5
  import torchaudio
 
1
  import os
2
 
3
+ os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
4
+
5
  import torch
6
  import torch.nn.functional as F
7
  import torchaudio
src/f5_tts/infer/utils_infer.py CHANGED
@@ -3,6 +3,7 @@
3
  import os
4
  import sys
5
 
 
6
  sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
7
 
8
  import hashlib
@@ -33,8 +34,6 @@ from f5_tts.model.utils import (
33
  _ref_audio_cache = {}
34
 
35
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
36
- if device == "mps":
37
- os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1"
38
 
39
  # -----------------------------------------
40
 
 
3
  import os
4
  import sys
5
 
6
+ os.environ["PYTOCH_ENABLE_MPS_FALLBACK"] = "1" # for MPS device compatibility
7
  sys.path.append(f"../../{os.path.dirname(os.path.abspath(__file__))}/third_party/BigVGAN/")
8
 
9
  import hashlib
 
34
  _ref_audio_cache = {}
35
 
36
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
 
37
 
38
  # -----------------------------------------
39