SixOpen committed
Commit 6f31b98 · verified · 1 Parent(s): c09715b

Update app.py

Files changed (1)
  1. app.py +130 -130
app.py CHANGED
@@ -5,7 +5,6 @@ import gradio as gr
  from transformers import AutoProcessor, AutoModelForCausalLM
  from transformers.dynamic_module_utils import get_imports
  import torch
- import requests
  from PIL import Image, ImageDraw
  import random
  import numpy as np
@@ -13,6 +12,8 @@ import matplotlib.pyplot as plt
  import matplotlib.patches as patches
  import cv2
  import io
+ import uuid
+

  def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
  if not str(filename).endswith("/modeling_florence2.py"):
@@ -21,118 +22,69 @@ def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
  imports.remove("flash_attn")
  return imports

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
  with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
- model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to("cuda").eval()
+ model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True).to(device).eval()
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", trust_remote_code=True)

  colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
  'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']

- def fig_to_pil(fig):
- buf = io.BytesIO()
- fig.savefig(buf, format='png')
- buf.seek(0)
- return Image.open(buf)
-
- @spaces.GPU
  def run_example(task_prompt, image, text_input=None):
- if text_input is None:
- prompt = task_prompt
- else:
- prompt = task_prompt + text_input
- inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda")
+ prompt = task_prompt if text_input is None else task_prompt + text_input
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
  with torch.inference_mode():
- generated_ids = model.generate(
- input_ids=inputs["input_ids"],
- pixel_values=inputs["pixel_values"],
- max_new_tokens=1024,
- early_stopping=False,
- do_sample=False,
- num_beams=3,
- )
+ generated_ids = model.generate(**inputs, max_new_tokens=1024, early_stopping=False, do_sample=False, num_beams=3)
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
- parsed_answer = processor.post_process_generation(
- generated_text,
- task=task_prompt,
- image_size=(image.size[0], image.size[1])
- )
- return parsed_answer
+ return processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.size[0], image.size[1]))

  def plot_bbox(image, data):
- fig, ax = plt.subplots()
- ax.imshow(image)
+ img_draw = image.copy()
+ draw = ImageDraw.Draw(img_draw)
  for bbox, label in zip(data['bboxes'], data['labels']):
  x1, y1, x2, y2 = bbox
- rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
- ax.add_patch(rect)
- plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='indigo', alpha=0.5))
- ax.axis('off')
- return fig_to_pil(fig)
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
+ draw.text((x1, y1), label, fill="white")
+ return np.array(img_draw)

  def draw_polygons(image, prediction, fill_mask=False):
- fig, ax = plt.subplots()
- ax.imshow(image)
- scale = 1
- for polygons, label in zip(prediction['polygons'], prediction['labels']):
- color = random.choice(colormap)
- fill_color = random.choice(colormap) if fill_mask else None
- for _polygon in polygons:
- _polygon = np.array(_polygon).reshape(-1, 2)
- if _polygon.shape[0] < 3:
- continue
- _polygon = (_polygon * scale).reshape(-1).tolist()
- if len(_polygon) % 2 != 0:
- continue
- polygon_points = np.array(_polygon).reshape(-1, 2)
- if fill_mask:
- polygon = patches.Polygon(polygon_points, edgecolor=color, facecolor=fill_color, linewidth=2)
- else:
- polygon = patches.Polygon(polygon_points, edgecolor=color, fill=False, linewidth=2)
- ax.add_patch(polygon)
- plt.text(polygon_points[0, 0], polygon_points[0, 1], label, color='white', fontsize=8, bbox=dict(facecolor=color, alpha=0.5))
- ax.axis('off')
- return fig_to_pil(fig)
-
- def draw_ocr_bboxes(image, prediction):
- fig, ax = plt.subplots()
- ax.imshow(image)
- scale = 1
- bboxes, labels = prediction['quad_boxes'], prediction['labels']
- for box, label in zip(bboxes, labels):
+ img_draw = image.copy()
+ draw = ImageDraw.Draw(img_draw)
+ for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
  color = random.choice(colormap)
- new_box = np.array(box) * scale
- if new_box.ndim == 1:
- new_box = new_box.reshape(-1, 2)
- polygon = patches.Polygon(new_box, edgecolor=color, fill=False, linewidth=3)
- ax.add_patch(polygon)
- plt.text(new_box[0, 0], new_box[0, 1], label, color='white', fontsize=8, bbox=dict(facecolor=color, alpha=0.5))
- ax.axis('off')
- return fig_to_pil(fig)
-
+ for polygon in polygons:
+ if isinstance(polygon[0], (int, float)):
+ polygon = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)]
+ draw.polygon(polygon, outline=color, fill=color if fill_mask else None)
+ if polygon:
+ draw.text(polygon[0], label, fill="white")
+ return np.array(img_draw)

  @spaces.GPU(duration=120)
  def process_video(input_video_path, task_prompt):
  cap = cv2.VideoCapture(input_video_path)
  if not cap.isOpened():
- return None
+ return None, []

  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  fps = cap.get(cv2.CAP_PROP_FPS)
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

- if frame_width <= 0 or frame_height <= 0 or fps <= 0 or total_frames <= 0:
- cap.release()
- return None
-
+ result_file_name = f"{uuid.uuid4()}.mp4"
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- out = cv2.VideoWriter("output_vid.mp4", fourcc, fps, (frame_width, frame_height))
-
- if not out.isOpened():
- cap.release()
- return None
+ out = cv2.VideoWriter(result_file_name, fourcc, fps, (frame_width, frame_height))

  processed_frames = 0
+ frame_results = []
+ color_map = {} #consistency for chromakey possibility
+
+ def get_color(label):
+ if label not in color_map:
+ color_map[label] = random.choice(colormap)
+ return color_map[label]
+
  while cap.isOpened():
  ret, frame = cap.read()
  if not ret:
@@ -141,39 +93,105 @@ def process_video(input_video_path, task_prompt):
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  pil_image = Image.fromarray(frame_rgb)

- result = run_example(task_prompt, pil_image)
+ try:
+ result = run_example(task_prompt, pil_image)

- processed_image = pil_image
- if task_prompt == "<OD>":
- if "<OD>" in result and "bboxes" in result["<OD>"] and "labels" in result["<OD>"]:
+ if task_prompt == "<OD>":
  processed_image = plot_bbox(pil_image, result['<OD>'])
- elif task_prompt == "<DENSE_REGION_CAPTION>":
- if "<DENSE_REGION_CAPTION>" in result and "polygons" in result["<DENSE_REGION_CAPTION>"] and "labels" in result["<DENSE_REGION_CAPTION>"]:
- processed_image = draw_polygons(pil_image, result['<DENSE_REGION_CAPTION>'], fill_mask=True)
+ frame_results.append((processed_frames + 1, result['<OD>']))
+ elif task_prompt == "<DENSE_REGION_CAPTION>":
+ processed_image = pil_image.copy()
+ draw = ImageDraw.Draw(processed_image)
+ for i, label in enumerate(result['<DENSE_REGION_CAPTION>'].get('labels', [])):
+ draw.text((10, 10 + i*20), label, fill="white")
+ processed_image = np.array(processed_image)
+ frame_results.append((processed_frames + 1, result['<DENSE_REGION_CAPTION>']))
+ elif task_prompt in ["<REFERRING_EXPRESSION_SEGMENTATION>", "<REGION_TO_SEGMENTATION>"]:
+ if isinstance(result[task_prompt], dict) and 'polygons' in result[task_prompt]:
+ processed_image = draw_vid_polygons(pil_image, result[task_prompt], get_color)
+ else:
+ processed_image = np.array(pil_image)
+ frame_results.append((processed_frames + 1, result[task_prompt]))
+ else:
+ processed_image = np.array(pil_image)
+
+ out.write(cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR))
+ processed_frames += 1

- processed_frame = cv2.cvtColor(np.array(processed_image), cv2.COLOR_RGB2BGR)
- out.write(processed_frame)
- processed_frames += 1
+ except Exception as e:
+ print(f"Error processing frame {processed_frames + 1}: {str(e)}")
+ processed_image = np.array(pil_image)
+ out.write(cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR))
+ processed_frames += 1

  cap.release()
  out.release()
  cv2.destroyAllWindows()

  if processed_frames == 0:
- return None
+ return None, frame_results

- return "output_vid.mp4"
+ return result_file_name, frame_results

- css = """
- #output {
- min-height: 100px;
- overflow: auto;
- border: 1px solid #ccc;
- }
- """
+ def draw_vid_polygons(image, prediction, get_color):
+ img_draw = image.copy()
+ draw = ImageDraw.Draw(img_draw)
+ for polygons, label in zip(prediction.get('polygons', []), prediction.get('labels', [])):
+ color = get_color(label)
+ for polygon in polygons:
+ if isinstance(polygon[0], (int, float)):
+ polygon = [(polygon[i], polygon[i+1]) for i in range(0, len(polygon), 2)]
+ draw.polygon(polygon, outline=color, fill=color)
+ if polygon:
+ draw.text(polygon[0], label, fill="white")
+ return np.array(img_draw)

- with gr.Blocks(css=css) as demo:
+ def process_image(image, task, text):
+ task_mapping = {
+ "Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
+ "Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
+ "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result['<MORE_DETAILED_CAPTION>'], image)),
+ "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), Image.fromarray(plot_bbox(image, result['<CAPTION_TO_PHRASE_GROUNDING>'])))),
+ "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), Image.fromarray(plot_bbox(image, result['<OD>'])))),
+ "Dense Region Caption": ("<DENSE_REGION_CAPTION>", lambda result: (str(result['<DENSE_REGION_CAPTION>']), Image.fromarray(draw_polygons(image, result['<DENSE_REGION_CAPTION>'], fill_mask=True)))),
+ "Region Proposal": ("<REGION_PROPOSAL>", lambda result: (str(result['<REGION_PROPOSAL>']), Image.fromarray(plot_bbox(image, result['<REGION_PROPOSAL>'])))),
+ "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), Image.fromarray(draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)))),
+ "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), Image.fromarray(draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True)))),
+ "Open Vocabulary Detection": ("<OPEN_VOCABULARY_DETECTION>", lambda result: (str(result['<OPEN_VOCABULARY_DETECTION>']), Image.fromarray(plot_bbox(image, result['<OPEN_VOCABULARY_DETECTION>'])))),
+ "Region to Category": ("<REGION_TO_CATEGORY>", lambda result: (result['<REGION_TO_CATEGORY>'], image)),
+ "Region to Description": ("<REGION_TO_DESCRIPTION>", lambda result: (result['<REGION_TO_DESCRIPTION>'], image)),
+ "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
+ "OCR with Region": ("<OCR_WITH_REGION>", lambda result: (str(result['<OCR_WITH_REGION>']), Image.fromarray(plot_bbox(image, result['<OCR_WITH_REGION>'])))),
+ }
+
+ if task in task_mapping:
+ prompt, process_func = task_mapping[task]
+ result = run_example(prompt, image, text)
+ return process_func(result)
+ else:
+ return "", image
+
+ def map_task_to_prompt(task):
+ task_mapping = {
+ "Object Detection": "<OD>",
+ "Dense Region Caption": "<DENSE_REGION_CAPTION>",
+ "Referring Expression Segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
+ "Region to Segmentation": "<REGION_TO_SEGMENTATION>"
+ }
+ return task_mapping.get(task, "")
+
+ def process_video_p(input_video, task, text_input):
+ prompt = map_task_to_prompt(task)
+ if task == "Referring Expression Segmentation" and text_input:
+ prompt += text_input
+ result, frame_results = process_video(input_video, prompt)
+ if result is None:
+ return None, "Error: Video processing failed. Check logs above for info.", str(frame_results)
+ return result, result, str(frame_results)
+
+ with gr.Blocks() as demo:
  gr.HTML("<h1><center>Microsoft Florence-2-large-ft</center></h1>")
+
  with gr.Tab(label="Image"):
  with gr.Row():
  with gr.Column():
@@ -221,12 +239,14 @@ with gr.Blocks(css=css) as demo:
  with gr.Column():
  input_video = gr.Video(label="Video")
  video_task_dropdown = gr.Dropdown(
- choices=["Object Detection", "Dense Region Caption"],
+ choices=["Object Detection", "Dense Region Caption", "Referring Expression Segmentation", "Region to Segmentation"],
  label="Video Task", value="Object Detection"
  )
+ video_text_input = gr.Textbox(label="Text Input (for Referring Expression Segmentation)", visible=False)
  video_submit_btn = gr.Button(value="Process Video")
  with gr.Column():
- output_video = gr.Video(label="Video")
+ output_video = gr.Video(label="Processed Video")
+ frame_results_output = gr.Textbox(label="Frame Results")

  def update_text_input(task):
  return gr.update(visible=task in ["Caption to Phrase Grounding", "Referring Expression Segmentation",
@@ -235,32 +255,12 @@ with gr.Blocks(css=css) as demo:

  task_dropdown.change(fn=update_text_input, inputs=task_dropdown, outputs=text_input)

- def process_image(image, task, text):
- task_mapping = {
- "Caption": ("<CAPTION>", lambda result: (result['<CAPTION>'], image)),
- "Detailed Caption": ("<DETAILED_CAPTION>", lambda result: (result['<DETAILED_CAPTION>'], image)),
- "More Detailed Caption": ("<MORE_DETAILED_CAPTION>", lambda result: (result['<MORE_DETAILED_CAPTION>'], image)),
- "Caption to Phrase Grounding": ("<CAPTION_TO_PHRASE_GROUNDING>", lambda result: (str(result['<CAPTION_TO_PHRASE_GROUNDING>']), plot_bbox(image, result['<CAPTION_TO_PHRASE_GROUNDING>']))),
- "Object Detection": ("<OD>", lambda result: (str(result['<OD>']), plot_bbox(image, result['<OD>']))),
- "Dense Region Caption": ("<DENSE_REGION_CAPTION>", lambda result: (str(result['<DENSE_REGION_CAPTION>']), plot_bbox(image, result['<DENSE_REGION_CAPTION>']))),
- "Region Proposal": ("<REGION_PROPOSAL>", lambda result: (str(result['<REGION_PROPOSAL>']), plot_bbox(image, result['<REGION_PROPOSAL>']))),
- "Referring Expression Segmentation": ("<REFERRING_EXPRESSION_SEGMENTATION>", lambda result: (str(result['<REFERRING_EXPRESSION_SEGMENTATION>']), draw_polygons(image, result['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True))),
- "Region to Segmentation": ("<REGION_TO_SEGMENTATION>", lambda result: (str(result['<REGION_TO_SEGMENTATION>']), draw_polygons(image, result['<REGION_TO_SEGMENTATION>'], fill_mask=True))),
- "Open Vocabulary Detection": ("<OPEN_VOCABULARY_DETECTION>", lambda result: (str(convert_to_od_format(result['<OPEN_VOCABULARY_DETECTION>'])), plot_bbox(image, convert_to_od_format(result['<OPEN_VOCABULARY_DETECTION>'])))),
- "Region to Category": ("<REGION_TO_CATEGORY>", lambda result: (result['<REGION_TO_CATEGORY>'], image)),
- "Region to Description": ("<REGION_TO_DESCRIPTION>", lambda result: (result['<REGION_TO_DESCRIPTION>'], image)),
- "OCR": ("<OCR>", lambda result: (result['<OCR>'], image)),
- "OCR with Region": ("<OCR_WITH_REGION>", lambda result: (str(result['<OCR_WITH_REGION>']), draw_ocr_bboxes(image, result['<OCR_WITH_REGION>']))),
- }
-
- if task in task_mapping:
- prompt, process_func = task_mapping[task]
- result = run_example(prompt, image, text)
- return process_func(result)
- else:
- return "", image
+ def update_video_text_input(task):
+ return gr.update(visible=task == "Referring Expression Segmentation")
+
+ video_task_dropdown.change(fn=update_video_text_input, inputs=video_task_dropdown, outputs=video_text_input)

  submit_btn.click(fn=process_image, inputs=[input_img, task_dropdown, text_input], outputs=[output_text, output_image])
- video_submit_btn.click(fn=process_video, inputs=[input_video, video_task_dropdown], outputs=output_video)
+ video_submit_btn.click(fn=process_video_p, inputs=[input_video, video_task_dropdown, video_text_input], outputs=[output_video, output_video, frame_results_output])

  demo.launch()
 
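For reference, a minimal sketch of how the updated helpers can be exercised from a Python session in which the definitions above have already been run (i.e. stopping before demo.launch()). The file name "example.jpg" is a hypothetical local image, not part of the commit.

# Minimal usage sketch, assuming model, processor, run_example and
# process_image from the new app.py are already defined in this session,
# and that a local file "example.jpg" exists (both are assumptions).
from PIL import Image

image = Image.open("example.jpg").convert("RGB")

# Raw task-token API: returns a dict keyed by the task token.
od = run_example("<OD>", image)
print(od["<OD>"]["labels"])          # detected class labels
print(od["<OD>"]["bboxes"][:3])      # first few [x1, y1, x2, y2] boxes

# UI-level helper: maps the dropdown task name to its token and draws boxes.
text, annotated = process_image(image, "Object Detection", None)
annotated.save("annotated.jpg")      # annotated is a PIL.Image here
print(text)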