gaspar-avit commited on
Commit
ac9d9ab
·
1 Parent(s): 5807d17

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -1
app.py CHANGED
@@ -1,3 +1,297 @@
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- st.title("""Movie Poster Generator""")
 
1
+ ## Alternative movie poster generator
2
+
3
+
4
+
5
  import streamlit as st
6
+ import pandas as pd
7
+ import numpy as np
8
+ import json
9
+ import requests
10
+ import os
11
+ import io
12
+
13
+ from streamlit import session_state as session
14
+ from datetime import time, datetime
15
+ from zipfile import ZipFile
16
+ from htbuilder import HtmlElement, div, ul, li, br, hr, a, p, img, styles, classes, fonts
17
+ from htbuilder.units import percent, px
18
+ from htbuilder.funcs import rgba, rgb
19
+ from PIL import Image
20
+
21
+ ###############################
22
+ ## --- GLOBAL VARIABLES ---- ##
23
+ ###############################
24
+
25
+
26
+ PATH_JSON = '/home/user/.kaggle/kaggle.json'
27
+
28
+
29
+
30
+ # Environment variables to authenticate Kaggle account
31
+ os.environ['KAGGLE_USERNAME'] = st.secrets['username']
32
+ os.environ['KAGGLE_KEY'] = st.secrets['key']
33
+ os.environ['KAGGLE_CONFIG_DIR'] = PATH_JSON
34
+
35
+ from kaggle.api.kaggle_api_extended import KaggleApi
36
+
37
+
38
+
39
+ ###############################
40
+ ## ------- FUNCTIONS ------- ##
41
+ ###############################
42
+
43
+ def link(link, text, **style):
44
+ return a(_href=link, _target="_blank", style=styles(**style))(text)
45
+
46
+ def image(src_as_string, **style):
47
+ return img(src=src_as_string, style=styles(**style))
48
+
49
+ def layout(*args):
50
+
51
+ style = """
52
+ <style>
53
+ # MainMenu {visibility: hidden;}
54
+ footer {visibility: hidden;}
55
+ .stApp { bottom: 105px; }
56
+ </style>
57
+ """
58
+
59
+ style_div = styles(
60
+ position="fixed",
61
+ left=0,
62
+ bottom=0,
63
+ margin=px(0, 0, 0, 0),
64
+ width=percent(100),
65
+ color="black",
66
+ text_align="center",
67
+ height="auto",
68
+ opacity=1
69
+ )
70
+
71
+ style_hr = styles(
72
+ display="block",
73
+ margin=px(8, 8, "auto", "auto"),
74
+ border_style="inset",
75
+ border_width=px(2)
76
+ )
77
+
78
+ body = p()
79
+ foot = div(
80
+ style=style_div
81
+ )(
82
+ hr(
83
+ style=style_hr
84
+ ),
85
+ body
86
+ )
87
+
88
+ st.markdown(style, unsafe_allow_html=True)
89
+
90
+ for arg in args:
91
+ if isinstance(arg, str):
92
+ body(arg)
93
+
94
+ elif isinstance(arg, HtmlElement):
95
+ body(arg)
96
+
97
+ st.markdown(str(foot), unsafe_allow_html=True)
98
+
99
+ def footer():
100
+ myargs = [
101
+ #"Made in ",
102
+ #image('https://avatars3.githubusercontent.com/u/45109972?s=400&v=4',
103
+ # width=px(25), height=px(25)),
104
+ #" with ❤️ by ",
105
+ "Made with ❤️ by ",
106
+ link("https://www.linkedin.com/in/gaspar-avit/", "Gaspar Avit"),
107
+ ]
108
+ layout(*myargs)
109
+
110
+ def authenticate_kaggle():
111
+ # Connect to kaggle API
112
+
113
+ # Save credentials to json file
114
+ if not os.path.exists(PATH_JSON):
115
+ api_token = {"username":st.secrets['username'],"key":st.secrets['key']}
116
+ with open(PATH_JSON, 'w') as file:
117
+ json.dump(api_token, file)
118
+
119
+ # Activate Kaggle API
120
+ global api
121
+ api = KaggleApi()
122
+ api.authenticate()
123
+
124
+
125
+ @st.experimental_memo(persist=True, show_spinner=False, suppress_st_warning=True)
126
+ def load_dataset():
127
+ """
128
+ Load Dataset from Kaggle
129
+ -return: dataframe containing dataset
130
+ """
131
+
132
+ ## --- Connect to kaggle API --- ##
133
+ # Save credentials to json file
134
+ if not os.path.exists(PATH_JSON):
135
+ api_token = {"username":st.secrets['username'],"key":st.secrets['key']}
136
+ with open(PATH_JSON, 'w') as file:
137
+ json.dump(api_token, file)
138
+
139
+ # Activate Kaggle API
140
+ global api
141
+ api = KaggleApi()
142
+ api.authenticate()
143
+ ## ----------------------------- ##
144
+
145
+ # Downloading Movies dataset
146
+ api.dataset_download_file('rounakbanik/the-movies-dataset', 'movies_metadata.csv')
147
+
148
+ # Extract data
149
+ zf = ZipFile('movies_metadata.csv.zip')
150
+ zf.extractall()
151
+ zf.close()
152
+
153
+ # Create dataframe
154
+ data = pd.read_csv('movies_metadata.csv', low_memory=False)
155
+
156
+ return data
157
+
158
+
159
+ def query_summary(text):
160
+ """
161
+ Get summarization from HuggingFace Inference API
162
+ -param text: text to be summarized
163
+ -return: summarized text
164
+ """
165
+ API_URL = "https://api-inference.huggingface.co/models/facebook/bart-large-cnn"
166
+ headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}
167
+ payload = {"inputs": f"{text}",}
168
+
169
+ response = requests.request("POST", API_URL, headers=headers, json=payload).json()
170
+
171
+ try:
172
+ text = response[0].get('summary_text')
173
+ except:
174
+ text = response[0]
175
+ return text
176
+
177
+
178
+ def query_generate(text):
179
+ """
180
+ Get image from HuggingFace Inference API
181
+ -param text: text to generate image
182
+ -return: generated image
183
+ """
184
+ API_URL = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
185
+ headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}
186
+ payload = {"inputs": f"{text}",}
187
+
188
+ response = requests.post(API_URL, headers=headers, json=payload)
189
+ return response.content
190
+
191
+
192
+ def generate_poster(movie_data):
193
+ """
194
+ Function for recommending movies
195
+ -param movie_data: metadata of movie selected by user
196
+ -return: image of generated alternative poster
197
+ """
198
+
199
+
200
+ # Get summarization of movie synopsis
201
+ with st.spinner("Please wait while the synopsis is being summarized..."):
202
+ synopsis_sum = query_summary(movie_data.overview.values[0])
203
+
204
+ st.text("")
205
+ st.text("")
206
+ st.title("""
207
+ #### Synopsis: """)
208
+ st.text("Synopsis summary: " + synopsis_sum)
209
+ st.text("")
210
+
211
+
212
+ # Get image based on synopsis
213
+ with st.spinner("Generating poster image..."):
214
+ poster_image = query_generate(synopsis_sum)
215
+
216
+ # Show image
217
+ try:
218
+ image = Image.open(io.BytesIO(poster_image))
219
+
220
+ st.text("")
221
+ st.text("")
222
+ st.title("""
223
+ #### Resulting poster: """)
224
+ col1, col2, col3 = st.columns(3)
225
+
226
+ with col1:
227
+ st.write(' ')
228
+
229
+ with col2:
230
+ st.text("")
231
+ st.image(image, caption="Movie: \"" + movie_data.title.values[0] + "\"")
232
+
233
+ with col3:
234
+ st.write(' ')
235
+
236
+ except:
237
+ st.text(poster_image)
238
+
239
+
240
+ return poster_image
241
+ # ------------------------------------------------------- #
242
+
243
+
244
+ ###############################
245
+ ## --------- MAIN ---------- ##
246
+ ###############################
247
+
248
+
249
+ if __name__ == "__main__":
250
+
251
+
252
+ # Initialize image variable
253
+ poster = None
254
+
255
+ ## --- Page config ------------ ##
256
+ # Set page title
257
+ st.title("""
258
+ Movie Poster Generator :film_frames:
259
+
260
+ #### This is a movie poster generator based on movie's synopsis :sunglasses:
261
+
262
+ #### Just select the title of a movie to generate an alternative poster.
263
+ """)
264
+
265
+ # Set page footer
266
+ footer()
267
+ ## ---------------------------- ##
268
+
269
+
270
+ ## Create dataset
271
+ data = load_dataset()
272
+
273
+ st.text("")
274
+ st.text("")
275
+ st.text("")
276
+ st.text("")
277
+
278
+ selected_movie = st.selectbox(label="Select a movie to generate alternative poster", options=data.title)
279
+
280
+ st.text("")
281
+ st.text("")
282
+
283
+ buffer1, col1, buffer2 = st.columns([1.3, 1, 1])
284
+
285
+ is_clicked = col1.button(label="Generate poster!")
286
+ is_clicked_rerun = None
287
+
288
+ if is_clicked:
289
+ poster = generate_poster(data[data.title==selected_movie])
290
+
291
+ if poster is not None:
292
+ buffer1, col1, buffer2 = st.columns([1.3, 1, 1])
293
+ is_clicked_rerun = col1.button(label="Rerun with same movie!")
294
+
295
+ if is_clicked_rerun:
296
+ poster = generate_poster(data[data.title==selected_movie])
297