cmagganas commited on
Commit
7585d18
1 Parent(s): 74f66bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -14,37 +14,43 @@ processor = preloaded["processor"]
14
 
15
  st.title("Audio Inversion with HuggingFace & Streamlit")
16
 
 
 
 
 
 
17
  uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "flac"])
18
 
19
- if uploaded_file:
 
 
 
 
20
  # Play the uploaded audio
21
- st.audio(uploaded_file, format="audio/wav")
 
22
 
23
  # Read the audio file
24
- audio, sr = sf.read(io.BytesIO(uploaded_file.getvalue()))
25
 
26
  # Convert audio to tensor
27
  audio_tensor = torch.tensor(audio).float()
28
 
29
- # Use Streamlit's session state to prevent re-inversion
30
- if "inverted_audio" not in st.session_state:
31
- with st.spinner("Inverting audio..."):
32
- # Invert the audio using the modified function
33
- inverted_audio_tensor = invert_audio(model, processor, audio_tensor, sr)
34
-
35
- # Convert tensor back to numpy
36
- inverted_audio_np = inverted_audio_tensor.numpy()
37
 
38
- # Store inverted audio in session state
39
- st.session_state.inverted_audio = inverted_audio_np
 
40
 
41
- # Play inverted audio from session state
42
  with io.BytesIO() as out_io:
43
- sf.write(out_io, st.session_state.inverted_audio, sr, format="wav")
44
  st.audio(out_io.getvalue(), format="audio/wav")
45
 
46
  # Offer a download button for the inverted audio
47
  if st.button("Download Inverted Audio"):
48
  with io.BytesIO() as out_io:
49
- sf.write(out_io, st.session_state.inverted_audio, sr, format="wav")
50
  st.download_button("Download Inverted Audio", data=out_io.getvalue(), file_name="inverted_output.wav", mime="audio/wav")
 
14
 
15
  st.title("Audio Inversion with HuggingFace & Streamlit")
16
 
17
+ # If this is the first run, create a new session state attribute for uploaded file
18
+ if 'uploaded_file' not in st.session_state:
19
+ st.session_state.uploaded_file = None
20
+
21
+ # Get the uploaded file
22
  uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "flac"])
23
 
24
+ # Update the session state only if a new file is uploaded
25
+ if uploaded_file is not None:
26
+ st.session_state.uploaded_file = uploaded_file.getvalue() # store content, not the file object
27
+
28
+ if st.session_state.uploaded_file:
29
  # Play the uploaded audio
30
+ audio_byte_content = st.session_state.uploaded_file
31
+ st.audio(audio_byte_content, format="audio/wav")
32
 
33
  # Read the audio file
34
+ audio, sr = sf.read(io.BytesIO(audio_byte_content))
35
 
36
  # Convert audio to tensor
37
  audio_tensor = torch.tensor(audio).float()
38
 
39
+ @st.cache(allow_output_mutation=True, suppress_st_warning=True)
40
+ def cache_inverted_audio(audio_tensor):
41
+ return invert_audio(model, processor, audio_tensor, sr)
 
 
 
 
 
42
 
43
+ # Use cached result
44
+ inverted_audio_tensor = cache_inverted_audio(audio_tensor)
45
+ inverted_audio_np = inverted_audio_tensor.numpy()
46
 
47
+ # Play inverted audio
48
  with io.BytesIO() as out_io:
49
+ sf.write(out_io, inverted_audio_np, sr, format="wav")
50
  st.audio(out_io.getvalue(), format="audio/wav")
51
 
52
  # Offer a download button for the inverted audio
53
  if st.button("Download Inverted Audio"):
54
  with io.BytesIO() as out_io:
55
+ sf.write(out_io, inverted_audio_np, sr, format="wav")
56
  st.download_button("Download Inverted Audio", data=out_io.getvalue(), file_name="inverted_output.wav", mime="audio/wav")