saefro991 commited on
Commit
b6af22a
1 Parent(s): bc80791

eval mode for inference

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -22,7 +22,7 @@ class ChangeSampleRate(nn.Module):
22
  output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
23
  return output
24
 
25
- model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt")
26
  def calc_mos(audio_path):
27
  wav, sr = torchaudio.load(audio_path)
28
  osr = 16_000
@@ -34,7 +34,8 @@ def calc_mos(audio_path):
34
  'domains': torch.tensor([0]),
35
  'judge_id': torch.tensor([288])
36
  }
37
- output = model(batch)
 
38
  return output.mean(dim=1).squeeze().detach().numpy()*2 + 3
39
 
40
 
 
22
  output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
23
  return output
24
 
25
+ model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()
26
  def calc_mos(audio_path):
27
  wav, sr = torchaudio.load(audio_path)
28
  osr = 16_000
 
34
  'domains': torch.tensor([0]),
35
  'judge_id': torch.tensor([288])
36
  }
37
+ with torch.no_grad():
38
+ output = model(batch)
39
  return output.mean(dim=1).squeeze().detach().numpy()*2 + 3
40
 
41