Spaces:
Runtime error
Runtime error
## Alternative movie poster generator | |
import ast
import io
import json
import os
import random
import string
from datetime import time, datetime
from zipfile import ZipFile

import numpy as np
import pandas as pd
import requests
import streamlit as st
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
from htbuilder.funcs import rgba, rgb
from htbuilder.units import percent, px
from PIL import Image
from streamlit import session_state as session
############################### | |
## --- GLOBAL VARIABLES ---- ## | |
############################### | |
PATH_JSON = '/home/user/.kaggle/kaggle.json'
# Environment variables to authenticate Kaggle account.
# Kaggle's documented convention is that KAGGLE_CONFIG_DIR names the *directory*
# holding kaggle.json (the client appends 'kaggle.json' itself), so pass the parent
# directory of PATH_JSON rather than the file path.
os.environ['KAGGLE_USERNAME'] = st.secrets['username']
os.environ['KAGGLE_KEY'] = st.secrets['key']
os.environ['KAGGLE_CONFIG_DIR'] = os.path.dirname(PATH_JSON)
# Imported only after the variables above are set (presumably because the kaggle
# package reads them as soon as it is imported).
from kaggle.api.kaggle_api_extended import KaggleApi
############################### | |
## ------- FUNCTIONS ------- ## | |
############################### | |
def link(link, text, **style):
    """
    Build an HTML anchor element.
    -param link: URL the anchor points to (opened with target="_blank")
    -param text: visible text of the anchor
    -param style: CSS properties forwarded to htbuilder's `styles`
    -return: htbuilder `a` element wrapping *text*
    """
    css = styles(**style)
    anchor = a(_href=link, _target="_blank", style=css)
    return anchor(text)
def layout(*args):
    """
    Render a fixed-position footer at the bottom of the page.

    Also injects CSS that hides Streamlit's main menu and default footer and sets
    `bottom: 105px` on `.stApp` (presumably to keep content clear of this footer).
    -param args: footer contents appended in order; strings and htbuilder elements
                 are rendered, any other type is ignored
    """
    # An ID selector must be written '#MainMenu': with a space after '#' the selector
    # is invalid and the browser silently drops the whole rule.
    style = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
.stApp { bottom: 105px; }
</style>
"""
    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )
    style_hr = styles(
        display="block",
        margin=px(4, 4, "auto", "auto"),
        border_style="inset",
        border_width=px(0),
    )
    body = p()
    foot = div(style=style_div)(
        hr(style=style_hr),
        body,
    )
    st.markdown(style, unsafe_allow_html=True)
    for arg in args:
        # Text and htbuilder elements are appended to the <p> the same way;
        # anything else is skipped.
        if isinstance(arg, (str, HtmlElement)):
            body(arg)
    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """Render the page footer: a credit line ending in a link to the author's LinkedIn profile."""
    author = link("https://www.linkedin.com/in/gaspar-avit/?locale=en_US", "Gaspar Avit")
    layout("Made with ❤️ by ", author)
def authenticate_kaggle():
    """
    Connect to the Kaggle API and store the client in the module-level `api`.

    When nothing exists yet at PATH_JSON, the credentials from st.secrets are saved
    there first. The parent directory is created if missing, and the file is created
    with owner-only permissions (0o600) because it contains the API key.
    """
    global api
    # Save credentials to json file
    if not os.path.exists(PATH_JSON):
        # open() cannot create missing parent directories on its own
        os.makedirs(os.path.dirname(PATH_JSON), exist_ok=True)
        api_token = {"username": st.secrets['username'], "key": st.secrets['key']}
        # os.open applies the mode at creation time, so the key is never readable by others
        fd = os.open(PATH_JSON, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as file:
            json.dump(api_token, file)
    # Activate Kaggle API
    api = KaggleApi()
    api.authenticate()
def load_dataset():
    """
    Load Dataset from Kaggle.

    Streamlit re-executes the script on every interaction, so the CSV is only
    downloaded (and unzipped) when it is not already in the working directory.
    -return: dataframe containing the dataset plus two derived columns:
             'year' (text before the first '-' of release_date, '0' when
             release_date is not a string) and 'title_year' ("<title> (<year>)")
    """
    csv_path = 'movies_metadata.csv'
    zip_path = csv_path + '.zip'

    # Connect to kaggle API; this also sets the module-level `api` used below
    authenticate_kaggle()

    if not os.path.exists(csv_path):
        # Downloading Movies dataset
        api.dataset_download_file('rounakbanik/the-movies-dataset', csv_path)
        # The file may arrive zipped; extract it into the working directory if so
        if os.path.exists(zip_path):
            with ZipFile(zip_path) as zf:
                zf.extractall()

    # Create dataframe
    data = pd.read_csv(csv_path, low_memory=False)
    data['year'] = data["release_date"].map(lambda x: x.split('-')[0] if isinstance(x, str) else '0')
    data['title_year'] = data['title'] + ' (' + data['year'] + ')'
    return data
def query_summary(text):
    """
    Get summarization from HuggingFace Inference API
    -param text: text to be summarized
    -return: the 'summary_text' field of the first result (or the first result
             itself if it is not a dict)
    -raises RuntimeError: if the API answers with a non-JSON body or an error payload
    """
    API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
    headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}
    # Ask the API to wait for a cold model instead of failing immediately
    # (the same option query_generate uses when retrying).
    payload = {"inputs": f"{text}", "options": {"wait_for_model": True}}
    response = requests.post(API_URL, headers=headers, json=payload)
    try:
        result = response.json()
    except ValueError as err:
        raise RuntimeError(
            f"Summarization API returned a non-JSON response (HTTP {response.status_code})"
        ) from err

    # Success is a non-empty list whose first item carries 'summary_text'
    if isinstance(result, list) and result:
        first = result[0]
        return first.get('summary_text') if isinstance(first, dict) else first

    # Anything else (e.g. {"error": "..."}) is a failure; report it explicitly instead
    # of letting result[0] fail with an opaque KeyError.
    error = result.get('error', result) if isinstance(result, dict) else result
    raise RuntimeError(f"Summarization failed: {error}")
def query_generate(text, title, genres, year, selected_model='Stable Diffusion v1.5'):
    """
    Get image from HuggingFace Inference API
    -param text: text (summarized synopsis) to generate the image from
    -param title: title of the movie; anything from the first '(' onward is dropped
    -param genres: genres of the movie, used as the style hint
    -param year: year of the movie
    -param selected_model: 'Stable Diffusion v1.5', 'Stable Diffusion v2.1' or
                           'Stable Diffusion XL'
    -return: raw response body (image bytes on success, otherwise the API's error payload)
    -raises ValueError: if selected_model is not one of the supported names
    """
    if selected_model == 'Stable Diffusion XL':
        API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
    elif selected_model == 'Stable Diffusion v2.1':
        API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-2-1"
    elif selected_model == 'Stable Diffusion v1.5':
        API_URL = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
    else:
        raise ValueError("Value not valid for argument 'selected_model'.")
    headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}

    movie_title = title.split('(')[0].strip()
    # Random suffix makes every prompt unique, so a previously generated image is never reused
    nonce = ''.join(random.choices(string.ascii_letters, k=10))
    prompt = (
        f'A Poster for the movie {movie_title} in portrait mode based on the following '
        f'synopsis: "{text}". Style: {genres}. Year {year}. Ignore {nonce}'
    )

    # use_cache must be a JSON boolean; the string "false" is not a boolean
    payload = {"inputs": prompt, "options": {"use_cache": False}}
    response = requests.post(API_URL, headers=headers, json=payload)
    # Success returns raw image bytes. An error status (e.g. model still loading) is
    # retried once, this time asking the API to wait for the model.
    if not response.ok:
        payload["options"]["wait_for_model"] = True
        response = requests.post(API_URL, headers=headers, json=payload)
    return response.content
def generate_poster(movie_data, selected_model):
    """
    Summarize the selected movie's synopsis and generate an alternative poster for it.

    Renders the summarized synopsis and then either the generated image or, when the
    response bytes are not a readable image, the API's response as text.
    -param movie_data: dataframe rows of the movie selected by user (first row is used)
    -param selected_model: Stable Diffusion version name forwarded to query_generate
    -return: raw response content from the image-generation API
    """
    # Get movie metadata. The genres cell is parsed with ast.literal_eval rather than
    # eval(): it comes from a downloaded CSV, and eval would execute arbitrary code.
    # Assumes the cell is a literal list of dicts with a 'name' key.
    genres = [i['name'] for i in ast.literal_eval(movie_data['genres'].values[0])]
    genres_string = ', '.join(genres)
    year = movie_data['year'].values[0]
    title = movie_data['title'].values[0]

    # Get summarization of movie synopsis
    st.text("")
    with st.spinner("Summarizing synopsis..."):
        synopsis_sum = query_summary(movie_data.overview.values[0])

    # Print summarized synopsis
    st.text("")
    synopsis_expander = st.expander("Show synopsis", expanded=False)
    with synopsis_expander:
        st.subheader("Summarized synopsis:")
        col1, _ = st.columns([5, 1])
        with col1:
            st.write(synopsis_sum)
    st.text("")
    st.text("")
    st.text("")
    st.text("")

    # Get image based on synopsis
    with st.spinner("Generating poster..."):
        response_content = query_generate(synopsis_sum, title, genres_string, year, selected_model)

    # Decode the image inside the try. Image.open is lazy, so load() forces decoding
    # here. Pillow signals undecodable bytes (e.g. a JSON error body) with an OSError
    # subclass.
    try:
        image = Image.open(io.BytesIO(response_content))
        image.load()
    except OSError:
        col1, _ = st.columns([5, 1])
        with col1:
            st.write(response_content.decode('utf-8', errors='replace'))
    else:
        st.text("")
        st.text("")
        st.subheader("Resulting poster:")
        st.text("")
        _, col2, _ = st.columns([1, 5, 1])
        with col2:
            st.image(image, caption="Movie: \"" + title + "\"")
        st.text("")
        st.text("")
        st.text("")
        st.text("")
    return response_content
# ------------------------------------------------------- # | |
############################### | |
## --------- MAIN ---------- ## | |
############################### | |
if __name__ == "__main__":
    # Initialize image variable
    poster = None

    ## --- Page config ------------ ##
    # Set page title
    st.title("""
Movie Poster Generator :film_frames:
#### This is a movie poster generator based on movie's synopsis :sunglasses:
#### Just select the title of a movie to generate an alternative poster.
""")
    # Set page footer
    footer()
    # Set sidebar with info
    st.sidebar.markdown("## Generating movie posters using Stable Diffusion")
    st.sidebar.markdown("This streamlit space aims to generate movie posters based on synopsis.")
    st.sidebar.markdown("Firstly, the synopsis of the selected movie is extracted from the dataset and then summarized using Facebook's BART model.")
    st.sidebar.markdown("Once the movie's summary is ready, it is passed to the selected Stable Diffusion model (v1.5, v2.1 or XL) using HF's Inference API, with some prompt tuning.")
    ## ---------------------------- ##

    ## Create dataset
    data = load_dataset()
    st.text("")
    st.text("")
    st.text("")
    st.text("")

    ## Select box with all the movies as choices
    session.selected_movie = st.selectbox(label="Select a movie to generate alternative poster", options=data.title_year)
    st.text("")
    st.text("")

    ## Create button to trigger poster generation
    sd_options = ['Stable Diffusion v1.5', 'Stable Diffusion v2.1', 'Stable Diffusion XL']
    buffer1, col1, col2, buffer2 = st.columns([0.3, 1, 1, 1])
    session.selected_model = col1.selectbox(label="Select SD model version", options=sd_options, label_visibility="collapsed")
    is_clicked = col2.button(label="Generate poster!")
    st.text("")
    st.text("")

    # No cache clearing between runs: generate_poster is a plain (uncached) function
    # with no .clear() method, and nothing in this script uses Streamlit caching.

    ## Generate poster
    if is_clicked:
        poster = generate_poster(data[data.title_year == session.selected_movie], session.selected_model)
    st.text("")
    st.text("")
    st.text("")
    st.text("")