teganmosi committed on
Commit
588bca7
1 Parent(s): 6c985bd

Create app.py

Files changed (1)
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
+ import gradio as gr
+
+ from types import SimpleNamespace
+
+ from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
+ from llava.conversation import SeparatorStyle, conv_templates
+ from llava.mm_utils import (
+     KeywordsStoppingCriteria,
+     get_model_name_from_path,
+     process_images,
+     tokenizer_image_token,
+ )
+ from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ import torch
+
+ # Skip torch's redundant default weight initialization to speed up model creation
+ disable_torch_init()
+
+ # Load the pretrained model
+ MODEL = "4bit/llava-v1.5-13b-3GB"
+ model_name = get_model_name_from_path(MODEL)
+
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+     model_path=MODEL, model_base=None, model_name=model_name, load_4bit=True
+ )
+
+ # Build the conversation prompt, prepending the image token to the user text
+ def create_prompt(prompt: str):
+     conv = conv_templates["llava_v0"].copy()
+     roles = conv.roles
+     prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
+     conv.append_message(roles[0], prompt)
+     conv.append_message(roles[1], None)
+     return conv.get_prompt(), conv
+
+ # Preprocess the image into a half-precision tensor on the model's device
+ def process_image(image):
+     # process_images reads image_aspect_ratio via getattr, so pass an object
+     # with attributes rather than a plain dict (a dict key would be ignored)
+     args = SimpleNamespace(image_aspect_ratio="pad")
+     image_tensor = process_images([image], image_processor, args)
+     return image_tensor.to(model.device, dtype=torch.float16)
+
+ # Generate a description of the uploaded image
+ def describe_image(image):
+     # gr.Image(type="pil") already supplies a PIL image, so no Image.open is needed;
+     # resize returns a new image rather than modifying in place
+     image = image.resize((500, 500))
+     processed_image = process_image(image)
+     prompt, conv = create_prompt("Describe the image")
+     input_ids = (
+         tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+         .unsqueeze(0)
+         .to(model.device)
+     )
+
+     stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+     stopping_criteria = KeywordsStoppingCriteria(
+         keywords=[stop_str], tokenizer=tokenizer, input_ids=input_ids
+     )
+
+     with torch.inference_mode():
+         output_ids = model.generate(
+             input_ids,
+             images=processed_image,
+             do_sample=True,
+             temperature=0.01,
+             max_new_tokens=512,
+             use_cache=True,
+             stopping_criteria=[stopping_criteria],
+         )
+
+     # Decode only the newly generated tokens, skipping the prompt
+     description = tokenizer.decode(
+         output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
+     ).strip()
+
+     return description
+
+
+ iface = gr.Interface(
+     fn=describe_image,
+     inputs=gr.Image(type="pil", label="Image"),
+     outputs=gr.Textbox(),
+     live=True,
+ )
+
+ # Launch the Gradio interface
+ iface.launch()