import json
import math
import random
import os

import streamlit as st
import lyricsgenius
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

st.set_page_config(page_title="HuggingArtists")

st.title("HuggingArtists")

st.sidebar.markdown(
    """
    <style>
    .aligncenter {
        text-align: center;
    }
    </style>
    <p class="aligncenter">
        <img src="https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/master/img/logo.jpg" width="420" />
    </p>
    """,
    unsafe_allow_html=True,
)
st.sidebar.markdown(
    """
    <style>
    .aligncenter {
        text-align: center;
    }
    </style>

    <p style='text-align: center'>
        <a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">GitHub</a> | <a href="https://wandb.ai/huggingartists/huggingartists/reportlist" target="_blank">Project Report</a>
    </p>

    <p class="aligncenter">
        <a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">
            <img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingartists?style=social"/>
        </a>
    </p>
    <p class="aligncenter">
        <a href="https://t.me/joinchat/_CQ04KjcJ-4yZTky" target="_blank">
            <img src="https://img.shields.io/badge/dynamic/json?color=blue&label=Telegram%20Channel&query=%24.result&url=https%3A%2F%2Fapi.telegram.org%2Fbot1929545866%3AAAFGhV-KKnegEcLiyYJxsc4zV6C-bdPEBtQ%2FgetChatMemberCount%3Fchat_id%3D-1001253621662&style=social&logo=telegram"/>
        </a>
    </p>
    <p class="aligncenter">
        <a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank">
            <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
        </a>
    </p>
    """,
    unsafe_allow_html=True,
)

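# Sidebar: sampling hyperparameters that are passed to model.generate() below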
st.sidebar.header("Generation settings:")
num_sequences = st.sidebar.number_input(
    "Number of sequences to generate",
    min_value=1,
    value=5,
    help="The number of texts to generate",
)
min_length = st.sidebar.number_input(
    "Minimum length of the sequence",
    min_value=1,
    value=100,
    help="The minimum length of the sequence to be generated",
)
max_length = st.sidebar.number_input(
    "Maximum length of the sequence",
    min_value=1,
    value=160,
    help="The maximum length of the sequence to be generated",
)
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.0,
    max_value=3.0,
    step=0.01,
    value=1.0,
    help="The value used to modulate the next-token probabilities",
)
top_p = st.sidebar.slider(
    "Top-P",
    min_value=0.0,
    max_value=1.0,
    step=0.01,
    value=0.95,
    help="If set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
top_k = st.sidebar.number_input(
    "Top-K",
    min_value=0,
    value=50,
    step=1,
    help="The number of highest-probability vocabulary tokens to keep for top-k filtering.",
)

caption = (
    "In [HuggingArtists](https://github.com/AlekseyKorshuk/huggingartists), we can generate lyrics by a specific artist. This was made by fine-tuning a pre-trained [HuggingFace Transformer](https://huggingface.co) on parsed datasets from [Genius](https://genius.com)."
)
st.markdown("[HuggingArtists](https://github.com/AlekseyKorshuk/huggingartists) - Train a model to generate lyrics 🎵")
st.markdown(caption)

st.subheader("Settings:")
artist_name = st.text_input("Artist name:", "Headie One")
start = st.text_input("Beginning of the song:", "Bad B come to the niz")

TOKEN = "q_JK_BFy9OMiG7fGTzL-nUto9JDv3iXI24aYRrQnkOvjSCSbY4BuFIindweRsr5I"
genius = lyricsgenius.Genius(TOKEN)

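# HTML card template rendered once an artist is found; the USER_PROFILE,
# USER_NAME and USER_HANDLE placeholders are filled in from the Genius data.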
model_html = """
<div class="inline-flex flex-col" style="line-height: 1.5;">
    <div class="flex">
        <div
            style="display:DISPLAY_1; margin-left: auto; margin-right: auto; width: 92px; height:92px; border-radius: 50%; background-size: cover; background-image: url('USER_PROFILE')">
        </div>
    </div>
    <div style="text-align: center; margin-top: 3px; font-size: 16px; font-weight: 800">🤗 HuggingArtists Model 🤗</div>
    <div style="text-align: center; font-size: 16px; font-weight: 800">USER_NAME</div>
    <a href="https://genius.com/artists/USER_HANDLE">
        <div style="text-align: center; font-size: 14px;">@USER_HANDLE</div>
    </a>
</div>
"""

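# Clean up raw model output: decode the token IDs to text, collapse blank lines,
# and limit how many times the same line may repeat in a row.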
def post_process(output_sequences):
    predictions = []
    generated_sequences = []

    # Allow at most this many identical lines in a row in the final lyrics
    max_repeat = 2

    # Decode every generated sequence back into plain text
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        generated_sequences.append(text.strip())

    for i, g in enumerate(generated_sequences):
        res = str(g).replace('\n\n\n', '\n').replace('\n\n', '\n')
        lines = res.split('\n')

        # Drop a line once it has already appeared `max_repeat` times in a row
        filtered_lines = []
        for line in lines:
            if len(filtered_lines) >= max_repeat and all(
                line == previous for previous in filtered_lines[-max_repeat:]
            ):
                continue
            filtered_lines.append(line)

        predictions.append('\n'.join(filtered_lines))

    return predictions

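# Main flow: look the artist up on Genius, load the matching huggingartists/<artist>
# model from the Hugging Face Hub, then sample and display the lyrics.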
if st.button("Run"):
    model_name = None
    with st.spinner(text=f"Searching for {artist_name} in Genius..."):
        artist = genius.search_artist(artist_name, max_songs=0, get_full_info=False)
        if artist is not None:
            artist_dict = genius.artist(artist.id)['artist']
            artist_url = str(artist_dict['url'])
            model_name = artist_url[artist_url.rfind('/') + 1:].lower()
            st.markdown(
                model_html.replace("USER_PROFILE", artist.image_url)
                .replace("USER_NAME", artist.name)
                .replace("USER_HANDLE", model_name),
                unsafe_allow_html=True,
            )
        else:
            st.markdown(f"Could not find {artist_name}! Make sure the artist exists on [Genius](https://genius.com/).")
    if model_name is not None:
        with st.spinner(text=f"Downloading the model of {artist_name}..."):
            model = None
            tokenizer = None
            try:
                tokenizer = AutoTokenizer.from_pretrained(f"huggingartists/{model_name}")
                model = AutoModelForCausalLM.from_pretrained(f"huggingartists/{model_name}")
            except Exception:
                st.markdown("Model for this artist does not exist yet. Create it in just 5 minutes with the [Colab Notebook](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb):")
                st.markdown(
                    """
                    <style>
                    .aligncenter {
                        text-align: center;
                    }
                    </style>
                    <p class="aligncenter">
                        <a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank">
                            <img src="https://colab.research.google.com/assets/colab-badge.svg"/>
                        </a>
                    </p>
                    """,
                    unsafe_allow_html=True,
                )
        if model is not None:
            with st.spinner(text="Generating lyrics..."):
                encoded_prompt = tokenizer(start, add_special_tokens=False, return_tensors="pt").input_ids
                encoded_prompt = encoded_prompt.to(model.device)

                output_sequences = model.generate(
                    input_ids=encoded_prompt,
                    max_length=max_length,
                    min_length=min_length,
                    temperature=float(temperature),
                    top_p=float(top_p),
                    top_k=int(top_k),
                    do_sample=True,
                    repetition_penalty=1.0,
                    num_return_sequences=num_sequences,
                )

                predictions = post_process(output_sequences)
                st.subheader("Results")
                for prediction in predictions:
                    st.text(prediction)

st.subheader("Link to the original repository:")
st.markdown(
    """
    <style>
    .aligncenter {
        text-align: center;
    }
    </style>
    <p class="aligncenter">
        <a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">
            <img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingartists?style=social"/>
        </a>
    </p>
    """,
    unsafe_allow_html=True,
)