|
from audiocraft.models import MusicGen |
|
import streamlit as st |
|
import torch |
|
import torchaudio |
|
import io |
|
import base64 |
|
|
|
@st.cache_resource
def load_model():
    """Return the pretrained MusicGen model used by this app.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded model
    across script reruns instead of reloading the weights every time.
    """
    checkpoint = 'facebook/musicgen-small'
    return MusicGen.get_pretrained(checkpoint)
|
|
|
def generate_music_tensors(description, duration: int):
    """Generate audio for a single text prompt with MusicGen.

    Args:
        description: Text prompt describing the desired music.
        duration: Requested length of the generated audio in seconds; forwarded
            to ``set_generation_params``.

    Returns:
        The first item of the batch returned by ``model.generate`` (the batch
        holds exactly one prompt). Presumably a (channels, samples) waveform
        tensor — TODO confirm against the audiocraft API.
    """
    st.write(f"Generating music for: '{description}' (Duration: {duration}s)")

    musicgen = load_model()
    # NOTE(review): the model comes from st.cache_resource and is shared, so
    # these generation params are mutated on a shared object; concurrent
    # sessions could presumably race here — confirm if multi-user use matters.
    musicgen.set_generation_params(use_sampling=True, top_k=250, duration=duration)

    prompts = [description]
    batch = musicgen.generate(descriptions=prompts, progress=True)
    first_item = batch[0]
    return first_item
|
|
|
def create_audio_buffer(samples: torch.Tensor, sample_rate: int = 32000) -> io.BytesIO:
    """Encode a waveform tensor as an in-memory WAV file.

    Args:
        samples: Audio tensor in one of these layouts:
            ``(time,)``, ``(channels, time)`` or ``(batch, channels, time)``.
            For a batched tensor only the first item is encoded.
        sample_rate: Sample rate in Hz written to the WAV header. The default
            of 32000 keeps the previous hard-coded value. NOTE(review):
            presumably this matches musicgen-small's output rate; confirm
            against the model's ``sample_rate``.

    Returns:
        A ``BytesIO`` positioned at offset 0 that holds the encoded WAV data.

    Raises:
        ValueError: If ``samples`` is 0-dimensional or has more than 3 dimensions.
    """
    samples = samples.detach().cpu()

    # torchaudio.save expects a (channels, time) tensor by default.
    if samples.dim() == 3:
        # Batched output: keep only the first item.
        samples = samples[0]
    elif samples.dim() == 1:
        # Mono waveform without a channel axis.
        samples = samples.unsqueeze(0)
    elif samples.dim() != 2:
        raise ValueError(
            f"expected a 1-D, 2-D or 3-D audio tensor, got shape {tuple(samples.shape)}"
        )

    buffer = io.BytesIO()
    # Writing to a file-like object requires an explicit format.
    torchaudio.save(buffer, samples, sample_rate, format="wav")
    buffer.seek(0)
    return buffer
|
|
|
def generate_download_link(buffer, file_label="Download Music", file_name="generated_music.wav"):
    """Build an HTML anchor that downloads the WAV audio held in *buffer*.

    The audio is embedded directly in the link as a base64 ``data:`` URI.

    Args:
        buffer: Seekable binary file-like object containing WAV bytes. It is
            rewound before reading, so the whole content is used even if
            another reader already consumed it.
        file_label: Visible text of the link.
        file_name: File name suggested to the browser for the download.

    Returns:
        The ``<a>`` element as an HTML string; render it with
        ``unsafe_allow_html=True``.
    """
    # Rewind first: a previous consumer may have left the read position at
    # the end, in which case read() would return b"" and the link would be empty.
    buffer.seek(0)
    encoded = base64.b64encode(buffer.read()).decode("ascii")
    return (
        f'<a href="data:audio/wav;base64,{encoded}" '
        f'download="{file_name}">{file_label}</a>'
    )
|
|
|
|
|
# Page-level CSS for the ".title" heading and ".footer" bar rendered in main().
# Kept unindented so the <style> tag is not parsed as a Markdown code block.
_PAGE_CSS = """
<style>
.title {
    font-size: 3em;
    text-align: center;
    color: #4A90E2;
    margin-top: 0;
}
.footer {
    position: fixed;
    left: 0;
    bottom: 0;
    width: 100%;
    background-color: #f1f1f1;
    text-align: center;
    padding: 10px;
    font-size: 0.8em;
    color: #555;
}
</style>
"""

st.markdown(body=_PAGE_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
def main():
    """Render the music-generator page.

    Shows a title, a prompt text area and a duration slider. Once a non-blank
    prompt is entered, it generates audio with MusicGen and shows an audio
    player plus a download link. The fixed footer is always rendered last.
    """
    st.markdown('<h1 class="title">Theaimart: Music Generator 🎵</h1>', unsafe_allow_html=True)
    st.write("Generate music based on your text input using Meta's Audiocraft library!")

    description = st.text_area("Enter a description:")
    duration = st.slider("Select duration (seconds)", 1, 20, 10)

    # Ignore whitespace-only input; otherwise it would start a full (slow)
    # generation run with an effectively empty prompt.
    prompt = description.strip()
    if prompt and duration:
        music_tensors = generate_music_tensors(prompt, duration)
        audio_buffer = create_audio_buffer(music_tensors)

        st.audio(audio_buffer, format="audio/wav")
        # Rewind defensively in case st.audio left the read position at the
        # end, so the download link encodes the whole WAV file.
        audio_buffer.seek(0)
        st.markdown(generate_download_link(audio_buffer), unsafe_allow_html=True)

    st.markdown('<div class="footer">Made with ❤️ by Theaimart</div>', unsafe_allow_html=True)


if __name__ == "__main__":
    main()