B-K commited on
Commit
a149a56
·
verified ·
1 Parent(s): bb63f88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -261,7 +261,7 @@ def load_input(song_path, arranger_id):
261
 
262
  def download_piano(youtube_link):
263
  yt = pytube.YouTube(youtube_link)
264
- download_path = os.path.join(yt_dir, yt.title + ".mp4")
265
  yt.streams.filter(only_audio=True).first().download(filename=download_path)
266
 
267
  # convert to mp3
@@ -290,11 +290,12 @@ def inference(yt_link, arranger_id):
290
 
291
  def post_process(generated):
292
  print("post processing")
293
- midi = tokenizer.decode(generated.argmax(dim=-1).cpu())
 
294
 
295
  # random name
296
  output_midi_path = os.path.join(midi_dir, f"{binascii.hexlify(os.urandom(8)).decode()}.mid")
297
- midi.dump_midi(os.path.join(midi_dir, output_midi_path))
298
 
299
  print("exporting")
300
  return output_midi_path
 
261
 
262
  def download_piano(youtube_link):
263
  yt = pytube.YouTube(youtube_link)
264
+ download_path = os.path.join(yt_dir, f"{binascii.hexlify(os.urandom(8)).decode()}.mp4")
265
  yt.streams.filter(only_audio=True).first().download(filename=download_path)
266
 
267
  # convert to mp3
 
290
 
291
  def post_process(generated):
292
  print("post processing")
293
+ print(generated.argmax(dim=-1).shape)
294
+ midi = tokenizer.decode(generated.argmax(dim=-1).unsqueeze(0).cpu())
295
 
296
  # random name
297
  output_midi_path = os.path.join(midi_dir, f"{binascii.hexlify(os.urandom(8)).decode()}.mid")
298
+ midi.dump_midi(output_midi_path)
299
 
300
  print("exporting")
301
  return output_midi_path