maxiw committed on
Commit
de6a4a2
1 Parent(s): 7fa80e9

initial setup

Files changed (2)
  1. app.py +140 -4
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,7 +1,143 @@
  import gradio as gr
 
- def greet(name):
-     return "Hello " + name + "!!"
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import spaces
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ import torch
+ import base64
+ from PIL import Image, ImageDraw
+ from io import BytesIO
+ import re
+
+
+ models = {
+     "OS-Copilot/OS-Atlas-Base-7B": Qwen2VLForConditionalGeneration.from_pretrained("OS-Copilot/OS-Atlas-Base-7B", torch_dtype="auto", device_map="auto"),
+ }
+
+ processors = {
+     "OS-Copilot/OS-Atlas-Base-7B": AutoProcessor.from_pretrained("OS-Copilot/OS-Atlas-Base-7B")
+ }
+
+
+ def image_to_base64(image):
+     # Serialize a PIL image to a base64 PNG string for the chat message payload.
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     return img_str
+
+
+ def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
+     draw = ImageDraw.Draw(image)
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         draw.rectangle([xmin, ymin, xmax, ymax], outline=outline_color, width=line_width)
+     return image
+
+
+ def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
+     # The model reports coordinates on a 1000x1000 grid; map them back to the
+     # original image resolution.
+     x_scale = original_width / scaled_width
+     y_scale = original_height / scaled_height
+     rescaled_boxes = []
+     for box in bounding_boxes:
+         xmin, ymin, xmax, ymax = box
+         rescaled_box = [
+             xmin * x_scale,
+             ymin * y_scale,
+             xmax * x_scale,
+             ymax * y_scale
+         ]
+         rescaled_boxes.append(rescaled_box)
+     return rescaled_boxes
+
+
+ @spaces.GPU
+ def run_example(image, text_input, system_prompt, model_id="OS-Copilot/OS-Atlas-Base-7B"):
+     model = models[model_id].eval()
+     processor = processors[model_id]
+
+     messages = [
+         {
+             # Pass the system prompt from the UI into the conversation.
+             "role": "system",
+             "content": system_prompt,
+         },
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": f"data:image;base64,{image_to_base64(image)}"},
+                 {"type": "text", "text": text_input},
+             ],
+         }
+     ]
+
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to("cuda")
+
+     generated_ids = model.generate(**inputs, max_new_tokens=128)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     # Keep special tokens: the grounding result is wrapped in
+     # <|object_ref_start|>...<|object_ref_end|> and <|box_start|>...<|box_end|>.
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
+     )
+     print(output_text)
+
+     object_ref_pattern = r"<\|object_ref_start\|>(.*?)<\|object_ref_end\|>"
+     box_pattern = r"<\|box_start\|>(.*?)<\|box_end\|>"
+
+     # Parse the generated text, not the input prompt.
+     object_ref = re.search(object_ref_pattern, output_text[0]).group(1)
+     box_content = re.search(box_pattern, output_text[0]).group(1)
+
+     # "(x1,y1),(x2,y2)" -> [xmin, ymin, xmax, ymax]
+     boxes = [tuple(map(int, pair.strip("()").split(','))) for pair in box_content.split("),(")]
+     boxes = [boxes[0][0], boxes[0][1], boxes[1][0], boxes[1][1]]
+
+     scaled_boxes = rescale_bounding_boxes([boxes], image.width, image.height)
+     return output_text, boxes, draw_bounding_boxes(image, scaled_boxes)
+
+ css = """
+ #output {
+     height: 500px;
+     overflow: auto;
+     border: 1px solid #ccc;
+ }
+ """
+ default_system_prompt = "You are a helpful assistant that detects objects in images. When asked to detect elements based on a description, you return bounding boxes for all matching elements in the form [xmin, ymin, xmax, ymax], with the values scaled to 1000 by 1000 pixels. When there is more than one result, answer with a list of bounding boxes in the form [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]."
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(
+         """
+         # OS-Atlas Demo
+         """)
+     with gr.Tab(label="OS-Atlas Input"):
+         with gr.Row():
+             with gr.Column():
+                 input_img = gr.Image(label="Input Image", type="pil")
+                 model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="OS-Copilot/OS-Atlas-Base-7B")
+                 system_prompt = gr.Textbox(label="System Prompt", value=default_system_prompt)
+                 text_input = gr.Textbox(label="User Prompt")
+                 submit_btn = gr.Button(value="Submit")
+             with gr.Column():
+                 model_output_text = gr.Textbox(label="Model Output Text")
+                 parsed_boxes = gr.Textbox(label="Parsed Boxes")
+                 annotated_image = gr.Image(label="Annotated Image")
+
+         gr.Examples(
+             examples=[
+                 ["assets/image1.jpg", "detect goats", default_system_prompt],
+                 ["assets/image2.jpg", "detect blue button", default_system_prompt],
+                 ["assets/image3.jpg", "detect person on bike", default_system_prompt],
+             ],
+             inputs=[input_img, text_input, system_prompt],
+             outputs=[model_output_text, parsed_boxes, annotated_image],
+             fn=run_example,
+             cache_examples=True,
+             label="Try examples"
+         )
+
+     submit_btn.click(run_example, [input_img, text_input, system_prompt, model_selector], [model_output_text, parsed_boxes, annotated_image])
+
+ demo.launch(debug=True)
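
For reference, a minimal sketch (not part of the commit) of what the parsing and rescaling in run_example do to a sample model response; the response string and the 1920x1080 image size below are made-up assumptions:

import re

# Hypothetical OS-Atlas response in the Qwen2-VL grounding format (assumed for illustration).
sample = "<|object_ref_start|>blue button<|object_ref_end|><|box_start|>(432,121),(587,203)<|box_end|>"

# Same regex and corner parsing as run_example.
box_content = re.search(r"<\|box_start\|>(.*?)<\|box_end\|>", sample).group(1)
corners = [tuple(map(int, pair.strip("()").split(","))) for pair in box_content.split("),(")]
box = [corners[0][0], corners[0][1], corners[1][0], corners[1][1]]  # [432, 121, 587, 203] on the 1000x1000 grid

# Same math as rescale_bounding_boxes, for an assumed 1920x1080 image.
scaled = [box[0] * 1920 / 1000, box[1] * 1080 / 1000, box[2] * 1920 / 1000, box[3] * 1080 / 1000]
print(scaled)  # [829.44, 130.68, 1127.04, 219.24]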
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy==1.24.4
+ Pillow==10.3.0
+ Requests==2.31.0
+ torch
+ torchvision
+ transformers
+ accelerate==0.30.0
+ qwen-vl-utils
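
As a quick environment check (not part of the commit), a sketch that loads the same model and processor pair app.py uses with the stack pinned above; expect a large download and a GPU with enough memory for a 7B model:

# Sanity check: load the model/processor exactly as app.py does.
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

model_id = "OS-Copilot/OS-Atlas-Base-7B"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)
print(model.dtype, next(model.parameters()).device)  # e.g. torch.bfloat16 cuda:0 on a GPU machine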