# Hugging Face Spaces status: Runtime error
## Alternative movie poster generator
# Install newer version of streamlit | |
import subprocess | |
import sys | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "streamlit==1.19.0"]) | |
import streamlit as st; print(st.__version__) | |
import pandas as pd | |
import numpy as np | |
import json | |
import requests | |
import os | |
import io | |
from streamlit import session_state as session | |
from datetime import time, datetime | |
from zipfile import ZipFile | |
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts | |
from htbuilder.units import percent, px | |
from htbuilder.funcs import rgba, rgb | |
from PIL import Image | |
###############################
## --- GLOBAL VARIABLES ---- ##
###############################

# Kaggle credentials file; it is written (if missing) before authenticating.
PATH_JSON = '/home/user/.kaggle/kaggle.json'

# Environment variables to authenticate Kaggle account.
# They are set before the kaggle import below because the kaggle package may
# try to authenticate as soon as it is imported.
os.environ['KAGGLE_USERNAME'] = st.secrets['username']
os.environ['KAGGLE_KEY'] = st.secrets['key']
# KAGGLE_CONFIG_DIR must name the *directory* that contains kaggle.json, not
# the file itself; pointing it at PATH_JSON would make kaggle look for
# ".../kaggle.json/kaggle.json".
os.environ['KAGGLE_CONFIG_DIR'] = os.path.dirname(PATH_JSON)

from kaggle.api.kaggle_api_extended import KaggleApi  # noqa: E402 (must follow the env setup above)
###############################
## ------- FUNCTIONS ------- ##
###############################

def link(link, text, **style):
    """
    Build an anchor element that opens its target in a new tab
    -param link: URL used as the href attribute
    -param text: content placed inside the anchor
    -param style: CSS properties forwarded to htbuilder's styles()
    -return: htbuilder <a> element
    """
    anchor = a(_href=link, _target="_blank", style=styles(**style))
    return anchor(text)
def image(src_as_string, **style):
    """
    Build an image element
    -param src_as_string: value used as the img src attribute
    -param style: CSS properties forwarded to htbuilder's styles()
    -return: htbuilder <img> element
    """
    css = styles(**style)
    return img(src=src_as_string, style=css)
def layout(*args):
    """
    Render a fixed footer at the bottom of the page
    -param args: footer content; plain strings and htbuilder HtmlElements are
        added to the footer paragraph in order, any other item is ignored
    """
    # Hide Streamlit's main menu and default footer, and shift the app
    # container (bottom: 105px), presumably to leave room for the fixed footer.
    # NOTE: the id selector must be "#MainMenu" with no space -- "# MainMenu"
    # is not a valid CSS selector and the browser would drop the rule.
    style = """
    <style>
      #MainMenu {visibility: hidden;}
      footer {visibility: hidden;}
      .stApp { bottom: 105px; }
    </style>
    """

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    # The paragraph is referenced by `foot` and filled in below; str(foot)
    # renders whatever children it holds at that point.
    body = p()
    foot = div(style=style_div)(hr(style=style_hr), body)

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """
    Render the page footer crediting the author
    """
    author_link = link("https://www.linkedin.com/in/gaspar-avit/", "Gaspar Avit")
    layout("Made with ❤️ by ", author_link)
def authenticate_kaggle():
    """
    Connect to the Kaggle API
    Writes the Kaggle credentials from Streamlit secrets to PATH_JSON when
    that file does not exist yet, then stores an authenticated KaggleApi
    client in the module-level `api` global.
    """
    global api

    # Save credentials to json file
    if not os.path.exists(PATH_JSON):
        # open() does not create missing parent directories
        os.makedirs(os.path.dirname(PATH_JSON), exist_ok=True)
        api_token = {"username": st.secrets['username'], "key": st.secrets['key']}
        with open(PATH_JSON, 'w') as file:
            json.dump(api_token, file)
        # The file holds an API key: keep it private to the owner
        # (Kaggle warns when kaggle.json is readable by other users).
        os.chmod(PATH_JSON, 0o600)

    # Activate Kaggle API
    api = KaggleApi()
    api.authenticate()
@st.cache_data
def load_dataset():
    """
    Load Dataset from Kaggle
    Downloads movies_metadata.csv from the 'rounakbanik/the-movies-dataset'
    Kaggle dataset into the working directory and parses it.
    The result is cached with st.cache_data: Streamlit re-executes the whole
    script on every widget interaction, and without the cache every rerun
    would re-authenticate, re-download and re-parse the CSV.
    -return: dataframe containing dataset, with two added columns:
        'year' (release year string, '0' when the release date is missing)
        and 'title_year' ("<title> (<year>)"). Rows whose title_year is
        missing (no title) are dropped.
    """
    # Connect to Kaggle API (writes credentials if needed, sets global `api`)
    authenticate_kaggle()

    # Downloading Movies dataset
    api.dataset_download_file('rounakbanik/the-movies-dataset', 'movies_metadata.csv')

    # The download may arrive zipped; extract it only when it does
    zip_path = 'movies_metadata.csv.zip'
    if os.path.exists(zip_path):
        with ZipFile(zip_path) as zf:
            zf.extractall()

    # Create dataframe
    data = pd.read_csv('movies_metadata.csv', low_memory=False)
    data['year'] = data["release_date"].map(
        lambda x: x.split('-')[0] if isinstance(x, str) else '0')
    data['title_year'] = data['title'] + ' (' + data['year'] + ')'
    # A missing title makes title_year NaN; such an option can never be
    # matched back to a row (NaN != NaN), so drop those rows.
    return data.dropna(subset=['title_year'])
def query_summary(text): | |
""" | |
Get summarization from HuggingFace Inference API | |
-param text: text to be summarized | |
-return: summarized text | |
""" | |
API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn" | |
headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"} | |
payload = {"inputs": f"{text}",} | |
response = requests.request("POST", API_URL, headers=headers, json=payload).json() | |
try: | |
text = response[0].get('summary_text') | |
except: | |
text = response[0] | |
return text | |
def query_generate(text): | |
""" | |
Get image from HuggingFace Inference API | |
-param text: text to generate image | |
-return: generated image | |
""" | |
API_URL = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5" | |
headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"} | |
text = "Poster of movie. " + text | |
payload = {"inputs": f"{text}",} | |
response = requests.post(API_URL, headers=headers, json=payload) | |
return response.content | |
def generate_poster(movie_data):
    """
    Summarize a movie's synopsis and display a poster generated from it
    -param movie_data: dataframe with the metadata of the movie selected by
        the user; only the first row's 'overview' and 'title' are used
    -return: raw bytes returned by the image-generation API (the image on
        success, the API's error body otherwise)
    """
    # Get summarization of movie synopsis
    with st.spinner("Please wait while the synopsis is being summarized..."):
        synopsis_sum = query_summary(movie_data.overview.values[0])
    st.text("")
    st.text("")
    st.subheader("Synopsis:")
    st.text("Synopsis summary: " + synopsis_sum)
    st.text("")

    # Get image based on synopsis
    with st.spinner("Generating poster image..."):
        poster_image = query_generate(synopsis_sum)

    # Only the decode is guarded: when the API sends an error payload instead
    # of an image, PIL cannot identify it (UnidentifiedImageError is an
    # OSError subclass) and the payload is shown as text instead.
    try:
        poster = Image.open(io.BytesIO(poster_image))
    except (OSError, ValueError):
        st.text(poster_image.decode("utf-8", errors="replace"))
        return poster_image

    # Show image, centred by two narrow spacer columns
    st.text("")
    st.text("")
    st.subheader("Resulting poster:")
    col1, col2, col3 = st.columns([1, 10, 1])
    with col1:
        st.write("")
    with col2:
        st.text("")
        st.image(poster, caption="Movie: \"" + movie_data.title.values[0] + "\"")
    with col3:
        st.write("")
    return poster_image
# ------------------------------------------------------- #

###############################
## --------- MAIN ---------- ##
###############################

if __name__ == "__main__":
    ## --- Page config ------------ ##
    # Set page title
    st.title("""
Movie Poster Generator :film_frames:
#### This is a movie poster generator based on movie's synopsis :sunglasses:
#### Just select the title of a movie to generate an alternative poster.
""")
    # Set page footer
    footer()
    ## ---------------------------- ##

    ## Create dataset
    data = load_dataset()

    # Vertical spacing before the selector
    for _ in range(4):
        st.text("")
    selected_movie = st.selectbox(
        label="Select a movie to generate alternative poster",
        options=data.title_year,
    )
    for _ in range(2):
        st.text("")

    # Button in the middle column; the wider left spacer nudges it to centre
    _, col1, _ = st.columns([1.3, 1, 1])
    if col1.button(label="Generate poster!"):
        movie_data = data[data.title_year == selected_movie]
        # An option that matches no row (e.g. a NaN title_year, since
        # NaN != NaN) would make generate_poster fail with an IndexError.
        if movie_data.empty:
            st.warning("Could not find the selected movie in the dataset.")
        else:
            generate_poster(movie_data)