gaspar-avit commited on
Commit
4a8f50c
1 Parent(s): ff60cfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -175,16 +175,29 @@ def query_summary(text):
175
  return text
176
 
177
 
178
- def query_generate(text, title, genres, year):
179
  """
180
  Get image from HuggingFace Inference API
181
  -param text: text to generate image
 
 
 
 
182
  -return: generated image
183
  """
184
 
185
- #API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
186
- API_URL = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
187
- #API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-2-1"
 
 
 
 
 
 
 
 
 
188
  headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}
189
  text = 'A Poster for the movie ' + title.split('(')[0] + 'in portrait mode based on the following synopsis: \"' + text + '\". Style: ' + genres + '. Year ' + year + \
190
  '. Ignore ' + ''.join(random.choices(string.ascii_letters, k=10))
@@ -205,7 +218,7 @@ def query_generate(text, title, genres, year):
205
  return response.content
206
 
207
  @st.experimental_memo(persist=False, show_spinner=False, suppress_st_warning=True)
208
- def generate_poster(movie_data):
209
  """
210
  Function for recommending movies
211
  -param movie_data: metadata of movie selected by user
@@ -240,7 +253,7 @@ def generate_poster(movie_data):
240
 
241
  # Get image based on synopsis
242
  with st.spinner("Generating poster..."):
243
- response_content = query_generate(synopsis_sum, title, genres_string, year)
244
 
245
  # Show image
246
  try:
@@ -315,8 +328,10 @@ if __name__ == "__main__":
315
  st.text("")
316
 
317
  ## Create button to trigger poster generation
318
- buffer1, col1, buffer2 = st.columns([1.3, 1, 1])
319
- is_clicked = col1.button(label="Generate poster!")
 
 
320
 
321
  st.text("")
322
  st.text("")
@@ -327,7 +342,7 @@ if __name__ == "__main__":
327
 
328
  ## Generate poster
329
  if is_clicked:
330
- poster = generate_poster(data[data.title_year==session.selected_movie])
331
  generate_poster.clear()
332
  st.runtime.legacy_caching.clear_cache()
333
 
 
175
  return text
176
 
177
 
178
+ def query_generate(text, title, genres, year, selected_model='Stable Diffusion v1.5'):
179
  """
180
  Get image from HuggingFace Inference API
181
  -param text: text to generate image
182
+ -param title: title of the movie
183
+ -param genres: genres of the movie
184
+ -param year: year of the movie
185
+
186
  -return: generated image
187
  """
188
 
189
+ if selected_model=='Stable Diffusion XL':
190
+ API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
191
+
192
+ elif selected_model=='Stable Diffusion v2.1':
193
+ API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-2-1"
194
+
195
+ elif selected_model=='Stable Diffusion v1.5':
196
+ API_URL = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
197
+
198
+ else:
199
+ raise ValueError("Value not valid for argument 'selected_model'.")
200
+
201
  headers = {"Authorization": f"Bearer {st.secrets['hf_token']}"}
202
  text = 'A Poster for the movie ' + title.split('(')[0] + 'in portrait mode based on the following synopsis: \"' + text + '\". Style: ' + genres + '. Year ' + year + \
203
  '. Ignore ' + ''.join(random.choices(string.ascii_letters, k=10))
 
218
  return response.content
219
 
220
  @st.experimental_memo(persist=False, show_spinner=False, suppress_st_warning=True)
221
+ def generate_poster(movie_data, selected_model):
222
  """
223
  Function for recommending movies
224
  -param movie_data: metadata of movie selected by user
 
253
 
254
  # Get image based on synopsis
255
  with st.spinner("Generating poster..."):
256
+ response_content = query_generate(synopsis_sum, title, genres_string, year, selected_model)
257
 
258
  # Show image
259
  try:
 
328
  st.text("")
329
 
330
  ## Create button to trigger poster generation
331
+ sd_options = ['Stable Diffusion XL', 'Stable Diffusion v2.1', 'Stable Diffusion v1.5']
332
+ buffer1, col1, col2, buffer2 = st.columns([0.3, 1, 1, 1])
333
+ session.selected_model = col1.selectbox(label="Select SD model version", options=sd_options)
334
+ is_clicked = col2.button(label="Generate poster!")
335
 
336
  st.text("")
337
  st.text("")
 
342
 
343
  ## Generate poster
344
  if is_clicked:
345
+ poster = generate_poster(data[data.title_year==session.selected_movie], session.selected_model)
346
  generate_poster.clear()
347
  st.runtime.legacy_caching.clear_cache()
348