Pablinho commited on
Commit
f8fcf48
1 Parent(s): a754dac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -1,28 +1,28 @@
1
  import gradio as gr
2
- from brain import img2txt, generate_story
3
- import os
4
 
 
5
 
6
- def generate_story_from_image(image):
7
- """ Generate a story from an image."""
8
- temp_image_path = "assets/image.jpg"
9
- image.save(temp_image_path)
10
 
11
- scenario = img2txt(temp_image_path)
12
- story = generate_story(scenario)
13
-
14
- os.remove(temp_image_path)
15
-
16
- return story
17
 
18
 
19
  iface = gr.Interface(
20
  fn=generate_story_from_image,
21
- inputs=gr.Image(type="pil"),
 
 
 
 
 
 
 
22
  outputs="text",
23
  title="Kids Story Generator",
24
- description="Upload an image and get a kids story based on it!",
25
- examples=[["assets/image.jpg"]],
26
  )
27
 
28
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from brain import StoryGenerator
 
3
 
4
+ story_generator = StoryGenerator()
5
 
 
 
 
 
6
 
7
+ def generate_story_from_image(image, model_name):
8
+ """Wrapper function to use with Gradio interface"""
9
+ return story_generator.generate_story_from_image(image, model_name)
 
 
 
10
 
11
 
12
  iface = gr.Interface(
13
  fn=generate_story_from_image,
14
+ inputs=[
15
+ gr.Image(type="pil"),
16
+ gr.Dropdown(
17
+ choices=list(story_generator.text_models.keys()),
18
+ label="Choose a model",
19
+ value="Mistral-7B"
20
+ )
21
+ ],
22
  outputs="text",
23
  title="Kids Story Generator",
24
+ description="Upload an image, choose a model, and get a kids story based on it!",
25
+ examples=[["assets/image.jpg", "Mistral-7B"]],
26
  )
27
 
28
  if __name__ == "__main__":