Upload 2 files
- app.py +119 -179
- caption_anything.py +27 -45
app.py
CHANGED
@@ -5,6 +5,7 @@ import requests
 from caption_anything import CaptionAnything
 import torch
 import json
+from diffusers import StableDiffusionInpaintPipeline
 import sys
 import argparse
 from caption_anything import parse_augment
@@ -15,10 +16,7 @@ import copy
 from tools import mask_painter
 from PIL import Image
 import os
-
-from segment_anything import sam_model_registry
-from text_refiner import build_text_refiner
-from segmenter import build_segmenter
+import cv2
 
 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
@@ -39,8 +37,8 @@ filename = "sam_vit_h_4b8939.pth"
 download_checkpoint(checkpoint_url, folder, filename)
 
 
-title = """<h1 align="center">
-description = """Gradio demo for
+title = """<h1 align="center">Edit Anything</h1>"""
+description = """Gradio demo for Segment Anything, image to dense Segment generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
 """
 
 examples = [
@@ -55,108 +53,62 @@ examples = [
 
 args = parse_augment()
 # args.device = 'cuda:5'
-# args.disable_gpt =
-# args.enable_reduce_tokens =
+# args.disable_gpt = False
+# args.enable_reduce_tokens = True
 # args.port=20322
-
-# args.regular_box = True
-shared_captioner = build_captioner(args.captioner, args.device, args)
-shared_sam_model = sam_model_registry['vit_h'](checkpoint=args.segmenter_checkpoint).to(args.device)
+model = CaptionAnything(args)
 
+def init_openai_api_key(api_key):
+    # os.environ['OPENAI_API_KEY'] = api_key
+    model.init_refiner(api_key)
+    openai_available = model.text_refiner is not None
+    return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True)
 
-def
-
-
-    if session_id is not None:
-        print('Init caption anything for session {}'.format(session_id))
-    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
-
-
-def init_openai_api_key(api_key=""):
-    text_refiner = None
-    if api_key and len(api_key) > 30:
-        try:
-            text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
-            text_refiner.llm('hi') # test
-        except:
-            text_refiner = None
-    openai_available = text_refiner is not None
-    return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), text_refiner
-
-
-def get_prompt(chat_input, click_state, click_mode):
+def get_prompt(chat_input, click_state):
+    points = click_state[0]
+    labels = click_state[1]
     inputs = json.loads(chat_input)
-
-        points
-        labels
-        for input in inputs:
-            points.append(input[:2])
-            labels.append(input[2])
-    elif click_mode == 'Single':
-        points = []
-        labels = []
-        for input in inputs:
-            points.append(input[:2])
-            labels.append(input[2])
-        click_state[0] = points
-        click_state[1] = labels
-    else:
-        raise NotImplementedError
+    for input in inputs:
+        points.append(input[:2])
+        labels.append(input[2])
 
     prompt = {
         "prompt_type":["click"],
-        "input_point":
-        "input_label":
+        "input_point":points,
+        "input_label":labels,
         "multimask_output":"True",
     }
     return prompt
 
-def
-    if click_mode == 'Continuous':
-        click_state[2].append(caption)
-    elif click_mode == 'Single':
-        click_state[2] = [caption]
-    else:
-        raise NotImplementedError
-
-
-def chat_with_points(chat_input, click_state, state, text_refiner):
-    if text_refiner is None:
-        response = "Text refiner is not initilzed, please input openai api key."
-        state = state + [(chat_input, response)]
-        return state, state
+def chat_with_points(chat_input, click_state, state, mask_save_path,image_input):
 
     points, labels, captions = click_state
-    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
-    # # "The image is of width {width} and height {height}."
-    point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
-    prev_visual_context = ""
-    pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
-    if len(captions):
-        prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
-    else:
-        prev_visual_context = 'no point exists.'
-    chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
-    response = text_refiner.llm(chat_prompt)
-    state = state + [(chat_input, response)]
-    return state, state
-
-def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment, factuality,
-                      length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
 
-
-
-
-
-
-        text_refiner=text_refiner,
-        session_id=iface.app_id
+
+    # inpainting
+    pipe = StableDiffusionInpaintPipeline.from_pretrained(
+        "stabilityai/stable-diffusion-2-inpainting",
+        torch_dtype=torch.float32,
     )
+
+
+    pipe = pipe.to("cuda")
+    mask = cv2.imread(mask_save_path)
 
-
-
-
-
+    image_input = np.array(image_input)
+    h,w = image_input.shape[:2]
+
+    image = cv2.resize(image_input,(512,512))
+    mask = cv2.resize(mask,(512,512)).astype(np.uint8)[:,:,0]
+    print(image.shape,mask.shape)
+    print("chat_input:",chat_input)
+    image = pipe(prompt=chat_input, image=image, mask_image=mask).images[0]
+    image = image.resize((w,h))
+
+    # image = Image.fromarray(image, mode='RGB')
+    return state, state, image
+
+def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
 
     if point_prompt == 'Positive':
         coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
@@ -170,32 +122,33 @@ def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment
 
     # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
     # chat_input = click_coordinate
-    prompt = get_prompt(coordinate, click_state
+    prompt = get_prompt(coordinate, click_state)
    print('prompt: ', prompt, 'controls: ', controls)
 
-    out = model.inference(image_input, prompt, controls
+    out = model.inference(image_input, prompt, controls)
     state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
     # for k, v in out['generated_captions'].items():
     #     state = state + [(f'{k}: {v}', None)]
-
-
-
-
-
+    # state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
+    # wiki = out['generated_captions'].get('wiki', "")
+    # click_state[2].append(out['generated_captions']['raw_caption'])
+
+    # text = out['generated_captions']['raw_caption']
     # draw = ImageDraw.Draw(image_input)
     # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
     input_mask = np.array(out['mask'].convert('P'))
     image_input = mask_painter(np.array(image_input), input_mask)
     origin_image_input = image_input
+    text = "edit"
     image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
 
-    yield state, state, click_state,
-
-
-
-
-
-
+    yield state, state, click_state, image_input, out["mask_save_path"]
+    # if not args.disable_gpt and model.text_refiner:
+    #     refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
+    #     # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
+    #     new_cap = refined_caption['caption']
+    #     refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
+    #     yield state, state, click_state, chat_input, refined_image_input, wiki
 
 
 def upload_callback(image_input, state):
@@ -207,19 +160,10 @@ def upload_callback(image_input, state):
     if ratio < 1.0:
         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
         print('Scaling input image to {}'.format(image_input.size))
-
-    model =
-        args,
-        api_key="",
-        captioner=shared_captioner,
-        sam_model=shared_sam_model,
-        session_id=iface.app_id
-    )
+    model.segmenter.image = None
+    model.segmenter.image_embedding = None
     model.segmenter.set_image(image_input)
-
-    original_size = model.segmenter.predictor.original_size
-    input_size = model.segmenter.predictor.input_size
-    return state, state, image_input, click_state, image_input, image_embedding, original_size, input_size
+    return state, image_input, click_state, image_input
 
 with gr.Blocks(
     css='''
@@ -230,38 +174,28 @@ with gr.Blocks(
     state = gr.State([])
     click_state = gr.State([[],[],[]])
     origin_image = gr.State(None)
-
-    text_refiner = gr.State(None)
-    original_size = gr.State(None)
-    input_size = gr.State(None)
+    mask_save_path = gr.State(None)
 
     gr.Markdown(title)
     gr.Markdown(description)
 
     with gr.Row():
         with gr.Column(scale=1.0):
-            with gr.Column(visible=
+            with gr.Column(visible=True) as modules_not_need_gpt:
                 image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
                 example_image = gr.Image(type="pil", interactive=False, visible=False)
                 with gr.Row(scale=1.0):
-
-
-
-
-
-
-
-
-                        value="Continuous",
-                        label="Clicking Mode",
-                        interactive=True)
-                with gr.Row(scale=0.4):
-                    clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
-                    clear_button_image = gr.Button(value="Clear Image", interactive=True)
-            with gr.Column(visible=False) as modules_need_gpt:
+                    point_prompt = gr.Radio(
+                        choices=["Positive", "Negative"],
+                        value="Positive",
+                        label="Point Prompt",
+                        interactive=True)
+                    clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
+                    clear_button_image = gr.Button(value="Clear Image", interactive=True)
+            with gr.Column(visible=True) as modules_need_gpt:
                 with gr.Row(scale=1.0):
                     language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
-
+
                     sentiment = gr.Radio(
                         choices=["Positive", "Natural", "Negative"],
                         value="Natural",
@@ -282,47 +216,40 @@
                         step=1,
                         interactive=True,
                         label="Length",
-                    )
-
-            gr.Examples(
-                examples=examples,
-                inputs=[example_image],
-            )
+                    )
+
         with gr.Column(scale=0.5):
-
-
-
-
-
-
-
-
-
-            with gr.Column(visible=
-
-
-
-            with gr.Column(visible=False) as modules_need_gpt3:
-                chat_input = gr.Textbox(lines=1, label="Chat Input")
+            # openai_api_key = gr.Textbox(
+            #     placeholder="Input openAI API key and press Enter (Input blank will disable GPT)",
+            #     show_label=False,
+            #     label = "OpenAI API Key",
+            #     lines=1,
+            #     type="password"
+            # )
+            # with gr.Column(visible=True) as modules_need_gpt2:
+            #     wiki_output = gr.Textbox(lines=6, label="Wiki")
+            with gr.Column(visible=True) as modules_not_need_gpt2:
+                chatbot = gr.Chatbot(label="History",).style(height=450,scale=0.5)
+            with gr.Column(visible=True) as modules_need_gpt3:
+                chat_input = gr.Textbox(lines=1, label="Edit Prompt")
             with gr.Row():
                 clear_button_text = gr.Button(value="Clear Text", interactive=True)
                 submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
-
-
-    enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
-    disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+
+    # openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2])
 
     clear_button_clike.click(
         lambda x: ([[], [], []], x, ""),
         [origin_image],
-        [click_state, image_input
+        [click_state, image_input],
         queue=False,
         show_progress=False
     )
+
     clear_button_image.click(
         lambda: (None, [], [], [[], [], []], "", ""),
         [],
-        [image_input, chatbot, state, click_state,
+        [image_input, chatbot, state, click_state, origin_image],
         queue=False,
         show_progress=False
     )
@@ -333,37 +260,50 @@
         queue=False,
         show_progress=False
     )
+
+
     image_input.clear(
         lambda: (None, [], [], [[], [], []], "", ""),
         [],
-        [image_input, chatbot, state, click_state,
+        [image_input, chatbot, state, click_state, origin_image],
         queue=False,
         show_progress=False
     )
 
-
-
-
+    def example_callback(x):
+        model.image_embedding = None
+        return x
+
+    gr.Examples(
+        examples=examples,
+        inputs=[example_image],
+    )
+
+    submit_button_text.click(
+        chat_with_points,
+        [chat_input, click_state, state, mask_save_path,image_input],
+        [chatbot, state, image_input]
+    )
+
+
+    image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state, image_input])
+    chat_input.submit(chat_with_points, [chat_input, click_state, state, mask_save_path,image_input], [chatbot, state, image_input])
+    example_image.change(upload_callback,[example_image, state], [state, origin_image, click_state, image_input])
 
     # select coordinate
     image_input.select(inference_seg_cap,
                        inputs=[
                            origin_image,
                            point_prompt,
-                           click_mode,
                            language,
                            sentiment,
                            factuality,
                            length,
-                           image_embedding,
                            state,
-                           click_state
-                           original_size,
-                           input_size,
-                           text_refiner
+                           click_state
                        ],
-                       outputs=[chatbot, state, click_state,
+                       outputs=[chatbot, state, click_state, image_input, mask_save_path],
                        show_progress=False, queue=True)
 
-iface.queue(concurrency_count=
-iface.launch(server_name="0.0.0.0", enable_queue=True)
+iface.queue(concurrency_count=1, api_open=False, max_size=10)
+iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=True)
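The substance of the app.py change: chat_with_points no longer chats about the clicked points. The text box becomes an edit prompt, and the handler runs Stable Diffusion inpainting over the region whose mask CaptionAnything.inference saved to disk. Below is a minimal standalone sketch of that step, using the same pipeline id and 512x512 working size as the diff; the function name edit_with_mask and its signature are illustrative, not part of the commit.

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

def edit_with_mask(image: Image.Image, mask_path: str, edit_prompt: str) -> Image.Image:
    # Load the mask PNG that CaptionAnything.inference wrote for the clicked region;
    # cv2 returns an HxWx3 array, and any single channel carries the binary mask.
    mask = cv2.imread(mask_path)[:, :, 0]

    np_image = np.array(image)
    h, w = np_image.shape[:2]

    # The diff resizes to 512x512 before editing and restores the original size after.
    image_512 = Image.fromarray(cv2.resize(np_image, (512, 512)))
    mask_512 = Image.fromarray(cv2.resize(mask, (512, 512)).astype(np.uint8))

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float32,
    ).to("cuda")

    edited = pipe(prompt=edit_prompt, image=image_512, mask_image=mask_512).images[0]
    return edited.resize((w, h))

One caveat carried over from the diff: the pipeline is constructed inside the handler, so every submitted edit reloads the checkpoint; hoisting from_pretrained to module level, next to the CaptionAnything construction, would avoid that per-request cost.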
caption_anything.py
CHANGED
@@ -1,45 +1,26 @@
-
+
 from segmenter import build_segmenter
-from text_refiner import build_text_refiner
 import os
 import argparse
 import pdb
 import time
 from PIL import Image
-
-
+
+
 
 class CaptionAnything():
-    def __init__(self, args, api_key=""
+    def __init__(self, args, api_key=""):
         self.args = args
-
-        self.segmenter = build_segmenter(args.segmenter, args.device, args)
-
+        # self.captioner = build_captioner(args.captioner, args.device, args)
+        self.segmenter = build_segmenter(args.segmenter, args.device, args)
         self.text_refiner = None
-
-        if text_refiner is not None:
-            self.text_refiner = text_refiner
-        else:
-            self.init_refiner(api_key)
-
-    def init_refiner(self, api_key):
-        try:
-            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
-            self.text_refiner.llm('hi') # test
-        except:
-            self.text_refiner = None
-            print('OpenAI GPT is not available')
+
 
     def inference(self, image, prompt, controls, disable_gpt=False):
         # segment with prompt
         print("CA prompt: ", prompt, "CA controls",controls)
+        print(image)
         seg_mask = self.segmenter.inference(image, prompt)[0, ...]
-        if self.args.enable_morphologyex:
-            seg_mask = 255 * seg_mask.astype(np.uint8)
-            seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis = -1)
-            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel = np.ones((6, 6), np.uint8))
-            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel = np.ones((6, 6), np.uint8))
-            seg_mask = seg_mask[:,:,0] > 0
         mask_save_path = f'result/mask_{time.time()}.png'
         if not os.path.exists(os.path.dirname(mask_save_path)):
             os.makedirs(os.path.dirname(mask_save_path))
@@ -49,24 +30,26 @@ class CaptionAnything():
         seg_mask_img.save(mask_save_path)
         print('seg_mask path: ', mask_save_path)
         print("seg_mask.shape: ", seg_mask.shape)
+
+        # mask_image = mask_image(image,np.array(seg_mask_img))
+        # cv2.imwrite(f'result/mask_vis.png',mask_image)
         # captioning with mask
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # if self.args.enable_reduce_tokens:
+        #     caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
+        # else:
+        #     caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
+
+        # # refining with TextRefiner
+        # context_captions = []
+        # if self.args.context_captions:
+        #     context_captions.append(self.captioner.inference(image))
+        # if not disable_gpt and self.text_refiner is not None:
+        #     refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
+        # else:
+        #     refined_caption = {'raw_caption': caption}
+        out = {
             'mask_save_path': mask_save_path,
-            'mask': seg_mask_img
-            'context_captions': context_captions}
+            'mask': seg_mask_img}
         return out
 
 def parse_augment():
@@ -86,7 +69,6 @@ def parse_augment():
     parser.add_argument('--disable_gpt', action="store_true")
     parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
     parser.add_argument('--disable_reuse_features', action="store_true", default=False)
-    parser.add_argument('--enable_morphologyex', action="store_true", default=False)
     args = parser.parse_args()
 
     if args.debug:
@@ -129,4 +111,4 @@ if __name__ == "__main__":
     print('Language controls:\n', controls)
     out = model.inference(image_path, prompt, controls)
 
-
+
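After this commit, CaptionAnything is essentially a thin wrapper around the SAM segmenter: inference returns only the mask image and the path it was saved under, with the captioning and text-refinement branches commented out. Below is a short sketch of driving it directly, using the click-prompt format that app.py's get_prompt assembles; the image path, click coordinates, and control values are placeholders.

from caption_anything import CaptionAnything, parse_augment

args = parse_augment()              # CLI defaults used by the demo (segmenter, device, checkpoint, ...)
model = CaptionAnything(args)

# One positive click, in the format get_prompt builds from the Gradio SelectData event.
prompt = {
    "prompt_type": ["click"],
    "input_point": [[300, 200]],    # (w, h) pixel coordinates of the click, placeholder values
    "input_label": [1],             # 1 = positive point, 0 = negative point
    "multimask_output": "True",
}
controls = {"length": "30", "sentiment": "natural", "factuality": "factual", "language": "English"}

out = model.inference("path/to/your_image.jpg", prompt, controls)   # placeholder path
print(out["mask_save_path"])        # PNG mask written under result/
seg_mask = out["mask"]              # PIL image of the selected region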