awacke1 committed
Commit eeebcf5
1 Parent(s): c91a3e7

Update app.py

Files changed (1)
  1. app.py +12 -14
app.py CHANGED
@@ -1,21 +1,21 @@
 import os
 import gradio as gr
 import torch
-import PIL
-from transformers import AutoProcessor, AutoModelForCausalLM  # Using AutoModel classes
+from transformers import Blip2Processor, Blip2ForConditionalGeneration
+from PIL import Image  # PIL is used for image handling
 
 EXAMPLES_DIR = 'examples'
 DEFAULT_PROMPT = "<image>"
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-# Load model using AutoModel with trust_remote_code=True
-model = AutoModelForCausalLM.from_pretrained('dhansmair/flamingo-mini', trust_remote_code=True)
+# Load the BLIP-2 model
+model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-flan-t5-xl', device_map="auto", torch_dtype=torch.float16)
 model.to(device)
 model.eval()
 
-# Initialize processor without the `device` argument
-processor = AutoProcessor.from_pretrained('dhansmair/flamingo-mini')
+# Initialize processor
+processor = Blip2Processor.from_pretrained('Salesforce/blip2-flan-t5-xl')
 
 # Setup some example images
 examples = []
@@ -28,14 +28,12 @@ if os.path.isdir(EXAMPLES_DIR):
 def predict_caption(image, prompt):
     assert isinstance(prompt, str)
 
-    # Process the image using the model
-    caption = model.generate(
-        processor(images=image, prompt=prompt),  # Pass processed inputs to the model
-        max_length=50
-    )
-
-    if isinstance(caption, list):
-        caption = caption[0]
+    # Convert the PIL image to the format expected by the processor
+    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
+
+    # Generate the caption
+    generated_ids = model.generate(**inputs, max_length=50)
+    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
     return caption
 
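
A side note on the new loading code (an observation, not part of the commit): device_map="auto" hands weight placement to accelerate, so the subsequent model.to(device) is at best redundant and can raise an error on recent transformers versions; torch.float16 is also a poor fit for CPU-only Spaces, and the fp32 tensors produced by the processor may need casting to match an fp16 model. A minimal sketch of a loading pattern that sidesteps these issues, assuming the same checkpoint:

import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Sketch only: choose dtype and placement from the available hardware.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float16 if device.type == 'cuda' else torch.float32  # fp16 kernels are unreliable on CPU

model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-flan-t5-xl', torch_dtype=dtype)
model.to(device)  # safe here because no device_map was given
model.eval()

processor = Blip2Processor.from_pretrained('Salesforce/blip2-flan-t5-xl')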
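
For completeness, a quick way to exercise the new predict_caption path outside the Gradio UI, using the names defined in app.py ('example.jpg' is a hypothetical local file). Note that DEFAULT_PROMPT is still the flamingo-mini-style "<image>" tag, which BLIP-2 does not treat as a special token; an empty string or a short natural-language prompt may caption better.

from PIL import Image

image = Image.open('example.jpg').convert('RGB')  # hypothetical input image
print(predict_caption(image, DEFAULT_PROMPT))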