awacke1 committed on
Commit
8e1683e
1 Parent(s): 973f818

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -5,12 +5,17 @@ import PIL
5
 
6
  from flamingo_mini import FlamingoConfig, FlamingoModel, FlamingoProcessor
7
 
 
 
8
  EXAMPLES_DIR = 'examples'
9
  DEFAULT_PROMPT = "<image>"
 
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
11
  model = FlamingoModel.from_pretrained('dhansmair/flamingo-mini')
12
  model.to(device)
13
  model.eval()
 
14
  processor = FlamingoProcessor(model.config, load_vision_processor=True)
15
 
16
  # setup some example images
@@ -20,16 +25,21 @@ if os.path.isdir(EXAMPLES_DIR):
20
  path = EXAMPLES_DIR + "/" + file
21
  examples.append([path, DEFAULT_PROMPT])
22
 
 
23
  def predict_caption(image, prompt):
24
  assert isinstance(prompt, str)
 
25
  features = processor.extract_features(image).to(device)
26
  caption = model.generate_captions(processor,
27
  visual_features=features,
28
  prompt=prompt)
 
29
  if isinstance(caption, list):
30
  caption = caption[0]
31
- return caption
32
 
 
 
 
33
  iface = gr.Interface(fn=predict_caption,
34
  inputs=[gr.Image(type="pil"), gr.Textbox(value=DEFAULT_PROMPT, label="Prompt")],
35
  examples=examples,
 
5
 
6
  from flamingo_mini import FlamingoConfig, FlamingoModel, FlamingoProcessor
7
 
8
+
9
+
10
  EXAMPLES_DIR = 'examples'
11
  DEFAULT_PROMPT = "<image>"
12
+
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+
15
  model = FlamingoModel.from_pretrained('dhansmair/flamingo-mini')
16
  model.to(device)
17
  model.eval()
18
+
19
  processor = FlamingoProcessor(model.config, load_vision_processor=True)
20
 
21
  # setup some example images
 
25
  path = EXAMPLES_DIR + "/" + file
26
  examples.append([path, DEFAULT_PROMPT])
27
 
28
+
29
  def predict_caption(image, prompt):
30
  assert isinstance(prompt, str)
31
+
32
  features = processor.extract_features(image).to(device)
33
  caption = model.generate_captions(processor,
34
  visual_features=features,
35
  prompt=prompt)
36
+
37
  if isinstance(caption, list):
38
  caption = caption[0]
 
39
 
40
+ return caption
41
+
42
+
43
  iface = gr.Interface(fn=predict_caption,
44
  inputs=[gr.Image(type="pil"), gr.Textbox(value=DEFAULT_PROMPT, label="Prompt")],
45
  examples=examples,