JacobLinCool commited on
Commit
0293e4c
1 Parent(s): 8fdcaa3

fix: infer when no index file

Browse files
Files changed (1) hide show
  1. app/infer.py +47 -20
app/infer.py CHANGED
@@ -12,15 +12,19 @@ from model import device
12
 
13
 
14
  @zero(duration=120)
15
- def infer(exp_dir: str, original_audio: str, f0add: int) -> Tuple[int, np.ndarray]:
 
 
16
  name = os.path.basename(exp_dir)
17
  model = os.path.join(exp_dir, f"{name}.pth")
18
  if not os.path.exists(model):
19
  raise gr.Error("Model not found")
20
 
21
  index = glob(f"{exp_dir}/added_*.index")
22
- if not index:
23
- raise gr.Error("Index not found")
 
 
24
 
25
  base = os.path.basename(original_audio)
26
  base = os.path.splitext(base)[0]
@@ -40,11 +44,11 @@ def infer(exp_dir: str, original_audio: str, f0add: int) -> Tuple[int, np.ndarra
40
  "rmvpe",
41
  index,
42
  None,
43
- 0.5,
44
- 3,
45
  0,
46
  1,
47
- 0.33,
48
  )
49
 
50
  sr = wav_opt[0]
@@ -80,19 +84,36 @@ class InferenceTab:
80
  )
81
 
82
  with gr.Row():
83
- self.original_audio = gr.Audio(
84
- label="Upload original audio",
85
- type="filepath",
86
- show_download_button=True,
87
- )
88
- self.f0add = gr.Slider(
89
- label="F0 add",
90
- minimum=-16,
91
- maximum=16,
92
- step=1,
93
- value=0,
94
- )
95
- self.infer_btn = gr.Button(value="Infer", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  with gr.Row():
97
  self.infer_output = gr.Audio(label="Inferred audio")
98
  with gr.Row():
@@ -101,7 +122,13 @@ class InferenceTab:
101
  def build(self, exp_dir: gr.Textbox):
102
  self.infer_btn.click(
103
  fn=infer,
104
- inputs=[exp_dir, self.original_audio, self.f0add],
 
 
 
 
 
 
105
  outputs=[self.infer_output],
106
  ).success(
107
  fn=merge,
 
12
 
13
 
14
  @zero(duration=120)
15
+ def infer(
16
+ exp_dir: str, original_audio: str, f0add: int, index_rate: float, protect: float
17
+ ) -> Tuple[int, np.ndarray]:
18
  name = os.path.basename(exp_dir)
19
  model = os.path.join(exp_dir, f"{name}.pth")
20
  if not os.path.exists(model):
21
  raise gr.Error("Model not found")
22
 
23
  index = glob(f"{exp_dir}/added_*.index")
24
+ if index:
25
+ index = index[0]
26
+ else:
27
+ index = None
28
 
29
  base = os.path.basename(original_audio)
30
  base = os.path.splitext(base)[0]
 
44
  "rmvpe",
45
  index,
46
  None,
47
+ index_rate,
48
+ 3, # this only has effect when f0_method is "harvest"
49
  0,
50
  1,
51
+ protect,
52
  )
53
 
54
  sr = wav_opt[0]
 
84
  )
85
 
86
  with gr.Row():
87
+ with gr.Column():
88
+ self.original_audio = gr.Audio(
89
+ label="Upload original audio",
90
+ type="filepath",
91
+ show_download_button=True,
92
+ )
93
+ with gr.Column():
94
+ self.f0add = gr.Slider(
95
+ label="F0 add",
96
+ minimum=-16,
97
+ maximum=16,
98
+ step=1,
99
+ value=0,
100
+ )
101
+ self.index_rate = gr.Slider(
102
+ label="Index rate",
103
+ minimum=-0,
104
+ maximum=1,
105
+ step=0.01,
106
+ value=0.5,
107
+ )
108
+ self.protect = gr.Slider(
109
+ label="Protect",
110
+ minimum=0,
111
+ maximum=1,
112
+ step=0.01,
113
+ value=0.33,
114
+ )
115
+ with gr.Column():
116
+ self.infer_btn = gr.Button(value="Infer", variant="primary")
117
  with gr.Row():
118
  self.infer_output = gr.Audio(label="Inferred audio")
119
  with gr.Row():
 
122
  def build(self, exp_dir: gr.Textbox):
123
  self.infer_btn.click(
124
  fn=infer,
125
+ inputs=[
126
+ exp_dir,
127
+ self.original_audio,
128
+ self.f0add,
129
+ self.index_rate,
130
+ self.protect,
131
+ ],
132
  outputs=[self.infer_output],
133
  ).success(
134
  fn=merge,