fix some bugs
- app/run_app.sh +5 -0
- app/src/brushedit_app.py +53 -42
- app/src/vlm_pipeline.py +24 -18
app/run_app.sh
ADDED
@@ -0,0 +1,5 @@
+export PYTHONPATH=.:$PYTHONPATH
+
+export CUDA_VISIBLE_DEVICES=0
+
+python app/src/brushedit_app.py
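The new launch script just prepares the environment and starts the Gradio app: the repo root goes on the import path, the app is pinned to GPU 0, and the entry point is run. A rough Python equivalent, for illustration only (the project itself uses the shell script above):

import os, subprocess

# Mirror app/run_app.sh: repo root on the import path, GPU 0 only.
os.environ["PYTHONPATH"] = "." + os.pathsep + os.environ.get("PYTHONPATH", "")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
subprocess.run(["python", "app/src/brushedit_app.py"], check=True)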
app/src/brushedit_app.py
CHANGED
@@ -337,7 +337,7 @@ vlm_type, vlm_local_path, vlm_processor, vlm_model = vlms_template[DEFAULT_VLM_M
 if vlm_processor != "" and vlm_model != "":
     vlm_model.to(device)
 else:
-    gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")
+    raise gr.Error("Please Download default VLM model "+ DEFAULT_VLM_MODEL_NAME +" first.")


 ## init base model
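In Gradio, gr.Error is an exception class: merely constructing it, as the old line did, creates and discards the object without ever reaching the UI. Raising it is what makes Gradio display the message. A minimal sketch of the difference:

import gradio as gr

def check_model(vlm_model):
    if vlm_model == "":
        # gr.Error("model missing")  # old pattern: constructed, then silently discarded
        raise gr.Error("model missing")  # fixed pattern: surfaced in the UI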
@@ -504,7 +504,7 @@ def random_mask_func(mask, dilation_type='square', dilation_size=20):
     dilated_mask = np.zeros_like(binary_mask, dtype=bool)
     dilated_mask[ellipse_mask] = True
 else:
-
+    ValueError("dilation_type must be 'square' or 'ellipse'")

 # use binary dilation
 dilated_mask = np.uint8(dilated_mask[:,:,np.newaxis]) * 255
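Note that the replacement line constructs a ValueError but, unlike the gr.Error fix above, still does not raise it, so the else branch falls through to the dilation code below. Sketch only; making the check effective would need an explicit raise:

# Sketch only; the committed code omits the raise.
if dilation_type not in ("square", "ellipse"):
    raise ValueError("dilation_type must be 'square' or 'ellipse'")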
@@ -637,7 +637,8 @@ def process(input_image,
 image_pil = input_image["background"].convert("RGB")
 original_image = np.array(image_pil)
 if prompt is None or prompt == "":
-
+    if target_prompt is None or target_prompt == "":
+        raise gr.Error("Please input your instructions, e.g., remove the xxx")

 alpha_mask = input_image["layers"][0].split()[3]
 input_mask = np.asarray(alpha_mask)
@@ -687,17 +688,23 @@ def process(input_image,
 original_mask = input_mask


-
+## inpainting directly if target_prompt is not None
 if category is not None:
-    pass
+    pass
+elif target_prompt is not None and len(target_prompt) >= 1 and original_mask is not None:
+    pass
 else:
-
-
+    try:
+        category = vlm_response_editing_type(vlm_processor, vlm_model, original_image, prompt, device)
+    except Exception as e:
+        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

+
 if original_mask is not None:
     original_mask = np.clip(original_mask, 0, 255).astype(np.uint8)
 else:
-
+    try:
+        object_wait_for_edit = vlm_response_object_wait_for_edit(
             vlm_processor,
             vlm_model,
             original_image,
@@ -705,30 +712,37 @@ def process(input_image,
             prompt,
             device)

-
-
-
-
-
-
-
-
-
-
-
+        original_mask = vlm_response_mask(vlm_processor,
+                                          vlm_model,
+                                          category,
+                                          original_image,
+                                          prompt,
+                                          object_wait_for_edit,
+                                          sam,
+                                          sam_predictor,
+                                          sam_automask_generator,
+                                          groundingdino_model,
+                                          device)
+    except Exception as e:
+        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
+
 if original_mask.ndim == 2:
     original_mask = original_mask[:,:,None]


-if len(target_prompt)
-prompt_after_apply_instruction =
+if target_prompt is not None and len(target_prompt) >= 1:
+    prompt_after_apply_instruction = target_prompt
+
+else:
+    try:
+        prompt_after_apply_instruction = vlm_response_prompt_after_apply_instruction(
             vlm_processor,
             vlm_model,
             original_image,
             prompt,
             device)
-
-
+    except Exception as e:
+        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")

 generator = torch.Generator(device).manual_seed(random.randint(0, 2147483647) if randomize_seed else seed)
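The two hunks above apply the same fix repeatedly: each VLM call is wrapped in try/except so that a missing model or bad API key surfaces as a gr.Error instead of a raw traceback, and a non-empty target_prompt now short-circuits the VLM call entirely. A hypothetical helper condensing the repeated pattern (safe_vlm_call is not part of the codebase):

import gradio as gr

def safe_vlm_call(fn, *args, **kwargs):
    # Turn any backend failure into the user-facing message this commit uses.
    try:
        return fn(*args, **kwargs)
    except Exception:
        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")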
@@ -758,7 +772,8 @@ def process(input_image,
 # image[3].save(f"outputs/image_edit_{uuid}_3.png")
 # mask_image.save(f"outputs/mask_{uuid}.png")
 # masked_image.save(f"outputs/masked_image_{uuid}.png")
-
+# gr.Info(f"Target Prompt: {prompt_after_apply_instruction}", duration=16)
+return image, [mask_image], [masked_image], prompt, '', False


 def generate_target_prompt(input_image,
@@ -774,7 +789,7 @@ def generate_target_prompt(input_image,
                             original_image,
                             prompt,
                             device)
-return prompt_after_apply_instruction
+    return prompt_after_apply_instruction


 def process_mask(input_image,
@@ -1415,7 +1430,7 @@ def init_img(base,
     original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
     return base, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, "", "", "", "Custom resolution", False, False, example_change_times
 else:
-    return base, original_image, None, "", None, None, None, "", "",
+    return base, original_image, None, "", None, None, None, "", "", aspect_ratio, True, False, 0


 def reset_func(input_image,
@@ -1423,7 +1438,7 @@ def reset_func(input_image,
                original_mask,
                prompt,
                target_prompt,
-
+               ):
     input_image = None
     original_image = None
     original_mask = None
@@ -1432,10 +1447,9 @@ def reset_func(input_image,
     masked_gallery = []
     result_gallery = []
     target_prompt = ''
-    target_prompt_output = ''
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-    return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt,
+    return input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, True, False


 def update_example(example_type,
@@ -1458,7 +1472,8 @@ def update_example(example_type,
     original_mask = np.array(mask_gallery[0]).astype(np.uint8)[:,:,None] # h,w,1
     aspect_ratio = "Custom resolution"
     example_change_times += 1
-    return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "",
+    return input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, "", False, example_change_times
+

 block = gr.Blocks(
     theme=gr.themes.Soft(
@@ -1498,6 +1513,8 @@ with block as demo:
         sources=["upload"],
     )

+    prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
+    run_button = gr.Button("💫 Run")

     vlm_model_dropdown = gr.Dropdown(label="VLM model", choices=VLM_MODEL_NAMES, value=DEFAULT_VLM_MODEL_NAME, interactive=True)
     with gr.Group():
@@ -1510,12 +1527,6 @@ with block as demo:
         aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
         resize_default = gr.Checkbox(label="Short edge resize to 640px", value=True)

-
-    prompt = gr.Textbox(label="⌨️ Instruction", placeholder="Please input your instruction.", value="",lines=1)
-
-    run_button = gr.Button("💫 Run")
-
-
     with gr.Row():
         mask_button = gr.Button("Generate Mask")
         random_mask_button = gr.Button("Square/Circle Mask ")
@@ -1603,7 +1614,7 @@ with block as demo:
         with gr.Tab(elem_classes="feedback", label="Output"):
             result_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", preview=True, height=400)

-            target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)
+            # target_prompt_output = gr.Text(label="Output Target Prompt", value="", lines=1, interactive=False)

     reset_button = gr.Button("Reset")

@@ -1634,9 +1645,9 @@ with block as demo:
     input_image.upload(
         init_img,
         [input_image, init_type, prompt, aspect_ratio, example_change_times],
-        [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt,
+        [input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, init_type, aspect_ratio, resize_default, invert_mask_state, example_change_times]
     )
-    example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt,
+    example_type.change(fn=update_example, inputs=[example_type, prompt, example_change_times], outputs=[input_image, prompt, original_image, original_mask, mask_gallery, masked_gallery, result_gallery, aspect_ratio, target_prompt, invert_mask_state, example_change_times])

     ## vlm and base model dropdown
     vlm_model_dropdown.change(fn=update_vlm_model, inputs=[vlm_model_dropdown], outputs=[status])
@@ -1666,7 +1677,7 @@ with block as demo:
            invert_mask_state]

     ## run brushedit
-    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt,
+    run_button.click(fn=process, inputs=ips, outputs=[result_gallery, mask_gallery, masked_gallery, prompt, target_prompt, invert_mask_state])

     ## mask func
     mask_button.click(fn=process_mask, inputs=[input_image, original_image, prompt, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask, category])
@@ -1681,10 +1692,10 @@ with block as demo:
     move_down_button.click(fn=move_mask_down, inputs=[input_image, original_image, original_mask, moving_pixels, resize_default, aspect_ratio], outputs=[masked_gallery, mask_gallery, original_mask])

     ## prompt func
-    generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt
+    generate_target_prompt_button.click(fn=generate_target_prompt, inputs=[input_image, original_image, prompt], outputs=[target_prompt])

     ## reset func
-    reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt
+    reset_button.click(fn=reset_func, inputs=[input_image, original_image, original_mask, prompt, target_prompt], outputs=[input_image, original_image, original_mask, prompt, mask_gallery, masked_gallery, result_gallery, target_prompt, resize_default, invert_mask_state])


 demo.launch()
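Several of the wiring fixes above complete truncated outputs lists. In Gradio, the values a handler returns are assigned positionally to its outputs components, so the two must match in count and order (process now returns six values for the six outputs of run_button.click). A self-contained sketch of the rule, with made-up components:

import gradio as gr

def reset():
    # Three return values for the three output components below.
    return None, "", []

with gr.Blocks() as demo:
    img = gr.Image()
    txt = gr.Textbox()
    gal = gr.Gallery()
    gr.Button("Reset").click(fn=reset, inputs=[], outputs=[img, txt, gal])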
app/src/vlm_pipeline.py
CHANGED
@@ -98,10 +98,12 @@ def vlm_response_editing_type(vlm_processor,
     messages = create_editing_category_messages_qwen2(editing_prompt)
     response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device=device)

-
-
-
-
+    try:
+        for category_name in ["Addition","Remove","Local","Global","Background"]:
+            if category_name.lower() in response_str.lower():
+                return category_name
+    except Exception as e:
+        raise gr.Error("Please input OpenAI API Key. Or please input correct commands, including add, delete, and modify commands. If it still does not work, please switch to a more powerful VLM.")


 ### response object to be edited
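The added loop does a case-insensitive substring match of each known category against the model's free-form reply; the except can only fire if the inference step failed to produce a string, and when no category matches, the function falls through and returns None. A small standalone illustration with an assumed reply:

response_str = "The editing category is: Remove."  # assumed VLM reply
for category_name in ["Addition", "Remove", "Local", "Global", "Background"]:
    if category_name.lower() in response_str.lower():
        print(category_name)  # -> Remove
        break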
@@ -206,17 +208,21 @@ def vlm_response_prompt_after_apply_instruction(vlm_processor,
                                                image,
                                                editing_prompt,
                                                device):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    try:
+        if isinstance(vlm_model, OpenAI):
+            base64_image = encode_image(image)
+            messages = create_apply_editing_messages_gpt4o(editing_prompt, base64_image)
+            response_str = run_gpt4o_vl_inference(vlm_model, messages)
+        elif isinstance(vlm_model, LlavaNextForConditionalGeneration):
+            messages = create_apply_editing_messages_llava(editing_prompt)
+            response_str = run_llava_next_inference(vlm_processor, vlm_model, messages, image, device)
+        elif isinstance(vlm_model, Qwen2VLForConditionalGeneration):
+            base64_image = encode_image(image)
+            messages = create_apply_editing_messages_qwen2(editing_prompt, base64_image)
+            response_str = run_qwen2_vl_inference(vlm_processor, vlm_model, messages, image, device)
+        else:
+            raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
+    except Exception as e:
+        raise gr.Error("Please select the correct VLM model and input the correct API Key first!")
+    return response_str
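vlm_response_prompt_after_apply_instruction now routes by the concrete client type: an OpenAI client gets the GPT-4o path, a LlavaNextForConditionalGeneration the LLaVA path, and a Qwen2VLForConditionalGeneration the Qwen2-VL path, with anything else rejected up front. A self-contained sketch of that dispatch shape, using stand-in classes rather than the real openai/transformers types:

class GPT4oClient: pass
class LlavaNext: pass
class Qwen2VL: pass

def build_response(vlm_model, prompt):
    # Pick the backend by concrete type; fail loudly on anything unknown.
    if isinstance(vlm_model, GPT4oClient):
        return f"gpt4o:{prompt}"
    elif isinstance(vlm_model, LlavaNext):
        return f"llava:{prompt}"
    elif isinstance(vlm_model, Qwen2VL):
        return f"qwen2:{prompt}"
    raise ValueError("unsupported VLM backend")

print(build_response(Qwen2VL(), "remove the dog"))  # -> qwen2:remove the dog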