Spaces:
Runtime error
Runtime error
## Alternative movie poster generator | |
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import json | |
import requests | |
import os | |
from streamlit import session_state as session | |
from datetime import time, datetime | |
from zipfile import ZipFile | |
from sentence_transformers import SentenceTransformer | |
from diffusers import DiffusionPipeline | |
from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts | |
from htbuilder.units import percent, px | |
from htbuilder.funcs import rgba, rgb | |
############################### | |
## --- GLOBAL VARIABLES ---- ## | |
############################### | |
# NOTE(review): nothing in this file ever sets this flag to True at module level;
# kept so any external reader of the name keeps working.
IS_MODEL_LOADED = False
# Location of the Kaggle API token file (kaggle.json) used by authenticate_kaggle().
PATH_JSON = '/home/user/app/.kaggle/kaggle.json'
# Environment variables to authenticate Kaggle account
os.environ['KAGGLE_USERNAME'] = st.secrets['username']
os.environ['KAGGLE_KEY'] = st.secrets['key']
# KAGGLE_CONFIG_DIR must name the *directory* that holds kaggle.json, not the
# file itself; pointing it at the file makes the token path resolve to
# ".../kaggle.json/kaggle.json".
os.environ['KAGGLE_CONFIG_DIR'] = os.path.dirname(PATH_JSON)
# Imported only after the environment is configured: the kaggle package reads
# these settings (and authenticates) when it is imported.
from kaggle.api.kaggle_api_extended import KaggleApi
############################### | |
## ------- FUNCTIONS ------- ## | |
############################### | |
def link(link, text, **style):
    """
    Build an <a> element that opens the given URL in a new browser tab
    -param link: target URL
    -param text: visible link text
    -param style: CSS properties forwarded to htbuilder's styles()
    -return: htbuilder anchor element
    """
    attributes = {"_href": link, "_target": "_blank", "style": styles(**style)}
    anchor = a(**attributes)
    return anchor(text)
def image(src_as_string, **style):
    """
    Build an <img> element
    -param src_as_string: image URL used as the src attribute
    -param style: CSS properties forwarded to htbuilder's styles()
    -return: htbuilder img element
    """
    css = styles(**style)
    return img(src=src_as_string, style=css)
def layout(*args):
    """
    Render a footer fixed to the bottom of the Streamlit page
    -param args: footer content; str and HtmlElement items are appended, in
                 order, to the footer paragraph (items of other types are ignored)
    """
    # Hide Streamlit's main menu and default <footer>, and offset .stApp 105px
    # from the bottom (presumably to leave room for this custom footer).
    # The id selector must be "#MainMenu": with a space after "#" the selector
    # is invalid CSS and the browser drops the whole rule.
    style = """
    <style>
    #MainMenu {visibility: hidden;}
    footer {visibility: hidden;}
    .stApp { bottom: 105px; }
    </style>
    """

    style_div = styles(
        position="fixed",
        left=0,
        bottom=0,
        margin=px(0, 0, 0, 0),
        width=percent(100),
        color="black",
        text_align="center",
        height="auto",
        opacity=1,
    )

    style_hr = styles(
        display="block",
        margin=px(8, 8, "auto", "auto"),
        border_style="inset",
        border_width=px(2),
    )

    # `body` is filled in below; `foot` holds a reference to it, so the
    # appended children show up when `foot` is rendered.
    body = p()
    foot = div(style=style_div)(hr(style=style_hr), body)

    st.markdown(style, unsafe_allow_html=True)

    for arg in args:
        # Strings and HTML elements are appended the same way.
        if isinstance(arg, (str, HtmlElement)):
            body(arg)

    st.markdown(str(foot), unsafe_allow_html=True)
def footer():
    """
    Render the page footer crediting the author
    (avatar image plus a LinkedIn link), laid out by layout()
    """
    avatar = image('https://avatars3.githubusercontent.com/u/45109972?s=400&v=4',
                   width=px(25), height=px(25))
    profile = link("https://www.linkedin.com/in/gaspar-avit/", "Gaspar Avit")
    layout("Made in ", avatar, " with ❤️ by ", profile)
def _write_kaggle_token(path, token):
    """Write the Kaggle API token dict as JSON to `path`, creating parent dirs."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as file:
        json.dump(token, file)
    # The file contains an API secret: make it readable by the owner only.
    os.chmod(path, 0o600)


def authenticate_kaggle():
    """
    Authenticate against the Kaggle API with the credentials in st.secrets
    -return: authenticated KaggleApi instance
    -raises OSError: if authentication still fails after the fallback token
                     file has been written
    """
    # Built unconditionally so the fallback below can always use it.
    api_token = {"username": st.secrets['username'], "key": st.secrets['key']}

    # Make sure a token file exists at the configured location.
    if not os.path.exists(PATH_JSON):
        _write_kaggle_token(PATH_JSON, api_token)

    # Activate Kaggle API
    api = KaggleApi()
    try:
        api.authenticate()
    except OSError:
        # Fallback: write the token to an alternative home directory and retry.
        _write_kaggle_token('/home/appuser/.kaggle/kaggle.json', api_token)
        api.authenticate()
    return api


def load_dataset():
    """
    Load Dataset from Kaggle
    -return: dataframe containing dataset
    """
    # Connect to kaggle API
    api = authenticate_kaggle()

    # Downloading Movies dataset (written to the current working directory)
    api.dataset_download_file('rounakbanik/the-movies-dataset', 'movies_metadata.csv')

    # Extract data. The download is expected as a .zip; extract it when present
    # so an uncompressed download is also handled.
    zip_path = 'movies_metadata.csv.zip'
    if os.path.exists(zip_path):
        with ZipFile(zip_path) as zf:
            zf.extractall()

    # Create dataframe
    return pd.read_csv('movies_metadata.csv', low_memory=False)
def load_model():
    """
    Load the Stable Diffusion text-to-image pipeline, at most once per session
    -return: DiffusionPipeline for "runwayml/stable-diffusion-v1-5"
    """
    # Streamlit re-executes the script on every interaction, so a module-level
    # flag is reset on each rerun (and assigning it here without `global` only
    # created a local). Keep the pipeline in session state so it is loaded
    # once per session instead of on every click.
    if 'pipeline' not in session:
        session.pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        # Alternatives: "CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-2"
    return session.pipeline
def query_summarization(text, timeout=120):
    """
    Get summarization from HuggingFace Inference API
    -param text: text to be summarized (converted to str)
    -param timeout: seconds to wait for the HTTP request before giving up
    -return: summarized text
    -raises RuntimeError: if the API response is not a non-empty list of results
    """
    API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
    headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}
    # wait_for_model asks the API to block while the model is loading instead
    # of answering immediately with an error payload.
    payload = {"inputs": f"{text}", "options": {"wait_for_model": True}}
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout).json()

    # A successful call is expected to return [{"summary_text": ...}]; surface
    # anything else (e.g. an error dict) clearly instead of a bare KeyError.
    if not isinstance(response, list) or not response:
        raise RuntimeError(f"Summarization request failed: {response}")
    return response[0].get('summary_text')
def generate_poster(movie_data):
    """
    Generate and display an alternative poster from a movie's synopsis
    -param movie_data: dataframe rows of the movie selected by user; the first
                       row's `overview` and `title` are used
    -return: image of generated alternative poster
    """
    # Get summarization of movie synopsis
    with st.spinner("Please wait while the synopsis is being summarized..."):
        synopsis_sum = query_summarization(movie_data.overview.values[0])

    st.text("")
    st.text("")
    st.text(synopsis_sum)

    # Load text-to-image model. Always go through load_model() so `pipeline` is
    # bound on every call; a flag check that skipped it left `pipeline` undefined.
    with st.spinner("Loading Text to Image model..."):
        pipeline = load_model()

    # Get image based on synopsis
    with st.spinner("Generating poster..."):
        poster_image = pipeline(synopsis_sum).images[0]

    # Caption with the title string, not the whole `title` Series.
    st.image(poster_image, caption=movie_data.title.values[0])

    return poster_image
# ------------------------------------------------------- # | |
############################### | |
## --------- MAIN ---------- ## | |
############################### | |
if __name__ == "__main__":
    # Do not name any variable here `image`: that would rebind the module-level
    # image() helper, which footer() calls below.

    ## Create dataset
    data = load_dataset()

    ## --- Page config --- ##
    # Set page title
    st.title("""
Alternative Movie Poster Generator :film_frames:
This is a movie poster generator based on movie's synopsis :sunglasses:
Just select the title of a movie to generate an alternative poster.
""")

    ## Set page footer
    footer()

    ## ------------------- ##
    for _ in range(4):
        st.text("")

    session.selected_movie = st.selectbox(
        label="Select a movie to generate alternative poster",
        options=data.title,
    )

    st.text("")
    st.text("")

    # Outer columns only pad the button towards the centre.
    _, col1, _ = st.columns([1.3, 1, 1])
    is_clicked = col1.button(label="Generate poster!")

    if is_clicked:
        # generate_poster() already renders the poster with its caption, so it
        # is not displayed a second time here.
        generate_poster(data[data.title == session.selected_movie])

    for _ in range(4):
        st.text("")