Update app.py
app.py CHANGED
@@ -5,7 +5,6 @@ import gradio as gr
 from transformers import AutoProcessor, AutoModelForCausalLM
 from transformers.dynamic_module_utils import get_imports
 import torch
-import requests
 from PIL import Image, ImageDraw
 import random
 import numpy as np
@@ -13,6 +12,8 @@ import matplotlib.pyplot as plt
 import matplotlib.patches as patches
 import cv2
 import io
+import uuid
+

 def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
     if not str(filename).endswith("/modeling_florence2.py"):
@@ -21,118 +22,69 @@ def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
     imports.remove("flash_attn")
     return imports

+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
-    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to(
+    model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to(device).eval()
     processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)

 colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
             'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']

-def fig_to_pil(fig):
-    buf = io.BytesIO()
-    fig.savefig(buf, format='png')
-    buf.seek(0)
-    return Image.open(buf)
-
-@spaces.GPU
 def run_example(task_prompt, image, text_input=None):
-    if text_input is None:
-        prompt = task_prompt
-    else:
-        prompt = task_prompt + text_input
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
+    prompt = task_prompt if text_input is None else task_prompt + text_input
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
     with torch.inference_mode():
-        generated_ids = model.generate(
-            input_ids=inputs["input_ids"],
-            pixel_values=inputs["pixel_values"],
-            max_new_tokens=1024,
-            early_stopping=False,
-            do_sample=False,
-            num_beams=3,
-        )
+        generated_ids = model.generate(**inputs, max_new_tokens=1024, early_stopping=False, do_sample=False, num_beams=3)
     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-    parsed_answer = processor.post_process_generation(
-        generated_text,
-        task=task_prompt,
-        image_size=(image.size[0], image.size[1])
-    )
-    return parsed_answer
+    return processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.size[0], image.size[1]))

 def plot_bbox(image, data):
+    img_draw = image.copy()
+    draw = ImageDraw.Draw(img_draw)
     for bbox, label in zip(data['bboxes'], data['labels']):
         x1, y1, x2, y2 = bbox
-    ax.axis('off')
-    return fig_to_pil(fig)
+        draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
+        draw.text((x1, y1), label, fill="white")
+    return np.array(img_draw)

 def draw_polygons(image, prediction, fill_mask=False):
-    for polygons, label in zip(prediction['polygons'], prediction['labels']):
-        color = random.choice(colormap)
-        fill_color = random.choice(colormap) if fill_mask else None
-        for _polygon in polygons:
-            _polygon = np.array(_polygon).reshape(-1, 2)
-            if _polygon.shape[0] < 3:
-                continue
-            _polygon = (_polygon * scale).reshape(-1).tolist()
-            if len(_polygon) % 2 != 0:
-                continue
-            polygon_points = np.array(_polygon).reshape(-1, 2)
-            if fill_mask:
-                polygon = patches.Polygon(polygon_points, edgecolor=color, facecolor=fill_color, linewidth=2)
-            else:
-                polygon = patches.Polygon(polygon_points, edgecolor=color, fill=False, linewidth=2)
-            ax.add_patch(polygon)
-            plt.text(polygon_points[0, 0], polygon_points[0, 1], label, color='white', fontsize=8, bbox=dict(facecolor=color, alpha=0.5))
-    ax.axis('off')
-    return fig_to_pil(fig)
-
-def draw_ocr_bboxes(image, prediction):
-    fig, ax = plt.subplots()
-    ax.imshow(image)
-    scale = 1
-    bboxes, labels = prediction['quad_boxes'], prediction['labels']
-    for box, label in zip(bboxes, labels):
-        color = random.choice(colormap)
-    return fig_to_pil(fig)
+    img_draw = image.copy()
+    draw = ImageDraw.Draw(img_draw)
+    for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
+        color = random.choice(colormap)
+        for polygon in polygons:
+            if isinstance(polygon[0], (int, float)):
+                polygon = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)]
+            draw.polygon(polygon, outline=color, fill=color if fill_mask else None)
+            if polygon:
+                draw.text(polygon[0], label, fill="white")
+    return np.array(img_draw)

 @spaces.GPU(duration=120)
 def process_video(input_video_path, task_prompt):
     cap = cv2.VideoCapture(input_video_path)
     if not cap.isOpened():
-        return None
+        return None, []

     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
     fps = cap.get(cv2.CAP_PROP_FPS)
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

-        cap.release()
-        return None
-
+    result_file_name = f"{uuid.uuid4()}.mp4"
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    out = cv2.VideoWriter(
-    if not out.isOpened():
-        cap.release()
-        return None
+    out = cv2.VideoWriter(result_file_name, fourcc, fps, (frame_width, frame_height))

     processed_frames = 0
+    frame_results = []
+    color_map = {}  # consistency for chromakey possibility
+
+    def get_color(label):
+        if label not in color_map:
+            color_map[label] = random.choice(colormap)
+        return color_map[label]
+
     while cap.isOpened():
         ret, frame = cap.read()
         if not ret:
@@ -141,39 +93,105 @@ def process_video(input_video_path, task_prompt):
         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
         pil_image = Image.fromarray(frame_rgb)

-        if task_prompt == "<OD>":
-            if "<OD>" in result and "bboxes" in result["<OD>"] and "labels" in result["<OD>"]:
-                processed_image = plot_bbox(pil_image, result['<OD>'])
-            processed_image =
+        try:
+            result = run_example(task_prompt, pil_image)
+
+            if task_prompt == "<OD>":
+                processed_image = plot_bbox(pil_image, result['<OD>'])
+                frame_results.append((processed_frames + 1, result['<OD>']))
+            elif task_prompt == "<DENSE_REGION_CAPTION>":
+                processed_image = pil_image.copy()
+                draw = ImageDraw.Draw(processed_image)
+                for i, label in enumerate(result['<DENSE_REGION_CAPTION>'].get('labels', [])):
+                    draw.text((10, 10 + i*20), label, fill="white")
+                processed_image = np.array(processed_image)
+                frame_results.append((processed_frames + 1, result['<DENSE_REGION_CAPTION>']))
+            elif task_prompt in ["<REFERRING_EXPRESSION_SEGMENTATION>", "<REGION_TO_SEGMENTATION>"]:
+                if isinstance(result[task_prompt], dict) and 'polygons' in result[task_prompt]:
+                    processed_image = draw_vid_polygons(pil_image, result[task_prompt], get_color)
+                else:
+                    processed_image = np.array(pil_image)
+                frame_results.append((processed_frames + 1, result[task_prompt]))
+            else:
+                processed_image = np.array(pil_image)
+
+            out.write(cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR))
+            processed_frames += 1
+
+        except Exception as e:
+            print(f"Error processing frame {processed_frames + 1}: {str(e)}")
+            processed_image = np.array(pil_image)
+            out.write(cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR))
+            processed_frames += 1

     cap.release()
     out.release()
     cv2.destroyAllWindows()

     if processed_frames == 0:
-        return None
+        return None, frame_results

-    return
+    return result_file_name, frame_results
+
+def draw_vid_polygons(image, prediction, get_color):
+    img_draw = image.copy()
+    draw = ImageDraw.Draw(img_draw)
+    for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
+        color = get_color(label)
+        for polygon in polygons:
+            if isinstance(polygon[0], (int, float)):
+                polygon = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)]
+            draw.polygon(polygon, outline=color, fill=color)
+            if polygon:
+                draw.text(polygon[0], label, fill="white")
+    return np.array(img_draw)
+
+def process_image(image, task, text):
+    task_mapping = {
+        "Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
+        "Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
+        "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result['<MORE_DETAILED_CAPTION>'], image)),
+        "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), Image.fromarray(plot_bbox(image, result['<CAPTION_TO_PHRASE_GROUNDING>'])))),
+        "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), Image.fromarray(plot_bbox(image, result['<OD>'])))),
+        "Dense Region Caption": ("<DENSE_REGION_CAPTION>", lambda result: (str(result['<DENSE_REGION_CAPTION>']), Image.fromarray(draw_polygons(image, result['<DENSE_REGION_CAPTION>'], fill_mask=True)))),
+        "Region Proposal": ("<REGION_PROPOSAL>", lambda result: (str(result['<REGION_PROPOSAL>']), Image.fromarray(plot_bbox(image, result['<REGION_PROPOSAL>'])))),
+        "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), Image.fromarray(draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)))),
+        "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), Image.fromarray(draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True)))),
+        "Open Vocabulary Detection": ("<OPEN_VOCABULARY_DETECTION>", lambda result: (str(result['<OPEN_VOCABULARY_DETECTION>']), Image.fromarray(plot_bbox(image, result['<OPEN_VOCABULARY_DETECTION>'])))),
+        "Region to Category": ("<REGION_TO_CATEGORY>", lambda result: (result['<REGION_TO_CATEGORY>'], image)),
+        "Region to Description": ("<REGION_TO_DESCRIPTION>", lambda result: (result['<REGION_TO_DESCRIPTION>'], image)),
+        "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
+        "OCR with Region": ("<OCR_WITH_REGION>", lambda result: (str(result['<OCR_WITH_REGION>']), Image.fromarray(plot_bbox(image, result['<OCR_WITH_REGION>'])))),
+    }
+
+    if task in task_mapping:
+        prompt, process_func = task_mapping[task]
+        result = run_example(prompt, image, text)
+        return process_func(result)
+    else:
+        return "", image
+
+def map_task_to_prompt(task):
+    task_mapping = {
+        "Object Detection": "<OD>",
+        "Dense Region Caption": "<DENSE_REGION_CAPTION>",
+        "Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
+        "Region to Segmentation": "<REGION_TO_SEGMENTATION>"
+    }
+    return task_mapping.get(task, "")
+
+def process_video_p(input_video, task, text_input):
+    prompt = map_task_to_prompt(task)
+    if task == "Referring Expression Segmentation" and text_input:
+        prompt += text_input
+    result, frame_results = process_video(input_video, prompt)
+    if result is None:
+        return None, "Error: Video processing failed. Check logs above for info.", str(frame_results)
+    return result, result, str(frame_results)
+
+with gr.Blocks() as demo:
     gr.HTML("<h1><center>Microsoft Florence-2-large-ft</center></h1>")
+
     with gr.Tab(label="Image"):
         with gr.Row():
             with gr.Column():
@@ -221,12 +239,14 @@ with gr.Blocks(css=css) as demo:
         with gr.Column():
             input_video = gr.Video(label="Video")
             video_task_dropdown = gr.Dropdown(
-                choices=["Object Detection", "Dense Region Caption"],
+                choices=["Object Detection", "Dense Region Caption", "Referring Expression Segmentation", "Region to Segmentation"],
                 label="Video Task", value="Object Detection"
            )
+            video_text_input = gr.Textbox(label="Text Input (for Referring Expression Segmentation)", visible=False)
             video_submit_btn = gr.Button(value="Process Video")
         with gr.Column():
-            output_video = gr.Video(label="Video")
+            output_video = gr.Video(label="Processed Video")
+            frame_results_output = gr.Textbox(label="Frame Results")

     def update_text_input(task):
         return gr.update(visible=task in ["Caption to Phrase Grounding", "Referring Expression Segmentation",
@@ -235,32 +255,12 @@ with gr.Blocks(css=css) as demo:

     task_dropdown.change(fn=update_text_input, inputs=task_dropdown, outputs=text_input)

-    def
-        "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result['<MORE_DETAILED_CAPTION>'], image)),
-        "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
-        "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox(image, result['<OD>']))),
-        "Dense Region Caption": ("<DENSE_REGION_CAPTION>", lambda result: (str(result['<DENSE_REGION_CAPTION>']), plot_bbox(image, result['<DENSE_REGION_CAPTION>']))),
-        "Region Proposal": ("<REGION_PROPOSAL>", lambda result: (str(result['<REGION_PROPOSAL>']), plot_bbox(image, result['<REGION_PROPOSAL>']))),
-        "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
-        "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
-        "Open Vocabulary Detection": ("<OPEN_VOCABULARY_DETECTION>", lambda result: (str(convert_to_od_format(result['<OPEN_VOCABULARY_DETECTION>'])), plot_bbox(image, convert_to_od_format(result['<OPEN_VOCABULARY_DETECTION>'])))),
-        "Region to Category": ("<REGION_TO_CATEGORY>", lambda result: (result['<REGION_TO_CATEGORY>'], image)),
-        "Region to Description": ("<REGION_TO_DESCRIPTION>", lambda result: (result['<REGION_TO_DESCRIPTION>'], image)),
-        "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
-        "OCR with Region": ("<OCR_WITH_REGION>", lambda result: (str(result['<OCR_WITH_REGION>']), draw_ocr_bboxes(image, result['<OCR_WITH_REGION>']))),
-    }
-
-    if task in task_mapping:
-        prompt, process_func = task_mapping[task]
-        result = run_example(prompt, image, text)
-        return process_func(result)
-    else:
-        return "", image
+    def update_video_text_input(task):
+        return gr.update(visible=task == "Referring Expression Segmentation")
+
+    video_task_dropdown.change(fn=update_video_text_input, inputs=video_task_dropdown, outputs=video_text_input)

     submit_btn.click(fn=process_image, inputs=[input_img, task_dropdown, text_input], outputs=[output_text, output_image])
-    video_submit_btn.click(fn=
+    video_submit_btn.click(fn=process_video_p, inputs=[input_video, video_task_dropdown, video_text_input], outputs=[output_video, output_video, frame_results_output])

 demo.launch()
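
For reference, a minimal usage sketch (not part of the commit) of how the refactored image path can be exercised outside the Gradio UI. It assumes the objects defined in the new app.py (model, processor, run_example, process_image) are already in scope in the same Python session; "example.jpg" and "annotated.jpg" are placeholder paths.

# Sketch only: drive the image pipeline head-less, assuming app.py's
# run_example()/process_image() and the loaded model/processor are in scope.
from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # placeholder path

# Raw task call: returns the parsed Florence-2 output,
# e.g. {'<OD>': {'bboxes': [...], 'labels': [...]}}
od_result = run_example("<OD>", image)
print(od_result["<OD>"]["labels"])

# UI-level call: same path the Image tab uses; returns (text, annotated PIL image)
text_out, annotated = process_image(image, "Object Detection", None)
annotated.save("annotated.jpg")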
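
A similar hedged sketch for the new video path: "clip.mp4" is a placeholder input, process_video() writes the annotated clip to a uuid-named .mp4 in the working directory, and running this outside a ZeroGPU Space assumes the @spaces.GPU decorator is a no-op locally.

# Sketch only: run the video pipeline head-less, assuming app.py's functions are in scope.
video_out, status, per_frame = process_video_p("clip.mp4", "Object Detection", None)  # placeholder path

if video_out is None:
    print(status)                             # error message from process_video_p()
else:
    print("annotated video:", video_out)      # uuid-named .mp4 written by process_video()
    print("frame results:", per_frame[:200])  # stringified per-frame Florence-2 output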