Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import torch | |
import re | |
from PIL import Image | |
from tqdm import tqdm | |
from train.scripts.generate_lm_multiple import gen_sequence, build_visorgpt | |
from utils.seq2coord import gen_cond_mask | |
from visor_gligen.gligen_inference_box import gligen_infer, build_gligen_model | |
from visor_controlnet.gradio_pose2image_v2 import control_infer, build_control_model, build_controlv11_model | |
# ---------------------------------------------------------------------------
# Model initialisation: every checkpoint below is loaded once, at import time.
# All paths are relative, so presumably the app is launched from the repo root.
# ---------------------------------------------------------------------------

# VisorGPT: the language model that produces the annotation sequences.
visorgpt_config_path = 'train/models/gpt2/config.json'
visorgpt_model_path = 'demo/ckpts/visorgpt/visorgpt_dagger_ta_tb.pt'
visorgpt_vocab_path = 'train/models/google_uncased_en_coord_vocab.txt'
visorgpt_args, visorgpt_model = build_visorgpt(
    model_config=visorgpt_config_path,
    model_path=visorgpt_model_path,
    vocab_path=visorgpt_vocab_path,
)

# ControlNet v1.1 (OpenPose) used for key-point-conditioned synthesis.
# The earlier v1.0 checkpoint was 'demo/ckpts/controlnet/control_sd15_openpose.pth'.
control_model_path = 'demo/ckpts/controlnet/control_v11p_sd15_openpose.pth'
control_sd_path = 'demo/ckpts/controlnet/v1-5-pruned-emaonly.safetensors'
control_model_config = 'demo/ckpts/controlnet/cldm_v15.yaml'
control_model, ddim_sampler = build_controlv11_model(
    model_path=control_model_path,
    sd_path=control_sd_path,
    config_path=control_model_config,
)

# GLIGEN used for box-conditioned synthesis.
gligen_model_path = 'demo/ckpts/gligen/diffusion_pytorch_model_box.bin'
(g_model, g_autoencoder, g_text_encoder, g_diffusion,
 g_config, g_grounding_tokenizer_input) = build_gligen_model(ckpt=gligen_model_path)

# Upper bounds on the number of instances per image, per annotation type.
max_num_keypoint = max_num_bbox = 16
max_num_mask = 8
def generate_sequence(gen_type,
                      data_type,
                      instance_size,
                      num_instance,
                      object_name_inbox):
    """Generate one annotation sequence with VisorGPT and render it.

    Builds a textual prompt from the UI parameters, samples a sequence with
    ``gen_sequence`` and converts it into a condition image with
    ``gen_cond_mask``.

    Args:
        gen_type: ``'key point'``, ``'box'`` or ``'mask'``.
        data_type: ``'multiple instances'`` or ``'object centric'``.
        instance_size: ``'small'``, ``'medium'`` or ``'large'``.
        num_instance: requested number of instances; clamped to the
            per-type maximum and forced to 1 for object-centric data.
        object_name_inbox: semicolon-separated category names (box/mask
            only). Blank entries are ignored; if no name remains, the same
            example categories the UI pre-fills are used.

    Returns:
        A dict mapping Gradio components to their new values. When the
        input is rejected, only ``raw_sequence`` is updated (with a message).

    Raises:
        ValueError: if ``gen_type`` is not a supported annotation type.
    """
    ctn = True
    if gen_type == 'key point':
        num_keypoint = 18
        num_instance = min(num_instance, max_num_keypoint)
        seq_prompt = '; '.join([gen_type, data_type, instance_size,
                                str(num_instance), str(num_keypoint)]) + ' ; [person'
    elif gen_type in ('box', 'mask'):
        num_keypoint = 0
        # Strip whitespace and drop empty entries (e.g. from a trailing ';').
        names = [n.strip() for n in (object_name_inbox or '').split(';') if n.strip()]
        if not names:
            if gen_type == 'mask':
                names = ['bottle', 'cup']
            elif data_type == 'object centric':
                names = ['great white shark']
            else:
                names = ['person', 'frisbee']
        num_instance = min(num_instance, max_num_mask if gen_type == 'mask' else max_num_bbox)
        if data_type == 'object centric':
            num_instance = 1
        # Every requested category needs at least one instance.
        if len(names) > num_instance:
            return {
                raw_sequence: gr.update(
                    value="The number of category names should not exceed the number of instances, please try again :)",
                    visible=True)
            }
        seq_prompt = '; '.join([gen_type, data_type, instance_size,
                                str(num_instance), str(num_keypoint)]) + '; ' + ', '.join(names)
    else:
        raise ValueError(f"Unsupported annotation type: {gen_type!r}")

    print("input prompt: \n", seq_prompt)
    # Inference only: skip autograd bookkeeping and release cached GPU
    # memory afterwards, as add_contents does.
    with torch.no_grad():
        sequence = gen_sequence(visorgpt_args, visorgpt_model, seq_prompt)
    torch.cuda.empty_cache()
    assert isinstance(sequence, list)

    # Index of this annotation type's entry inside cond_json.
    json_idx = {'box': 0, 'mask': 1, 'key point': 2}[gen_type]
    try:
        cond_mask, cond_json = gen_cond_mask(sequence, ctn)
        ori_sequence = cond_json[json_idx]['sequences'][0][0] + '[SEP]'
    except Exception:
        # Decoding with ctn=True failed; retry once with the flag flipped.
        cond_mask, cond_json = gen_cond_mask(sequence, not ctn)
        ori_sequence = cond_json[json_idx]['sequences'][0][0] + '[SEP]'

    return {
        result_gallery: [Image.fromarray(cond_mask)],
        raw_sequence: gr.update(value=ori_sequence, visible=True),
        # Image synthesis from mask sequences is not supported (see
        # generate_images), so the synthesis button is hidden for masks.
        images_button: gr.update(visible=gen_type != 'mask'),
        text_container: cond_json,
        sequence_container: ori_sequence,
    }
# Matches the text between '[' and the next ']' (used to pull out per-instance content).
_BRACKET_RE = re.compile(r'[\[](.*?)[]]', re.S)


def add_contents(gen_type,
                 data_type,
                 instance_size,
                 num_instance,
                 object_name_inbox,
                 num_continuous_gen,
                 global_seq):
    """Add instances to the current scene ("continuous generation").

    When ``num_continuous_gen`` is non-zero, the previously generated
    sequence ``global_seq`` is rewritten into a prompt that keeps its
    existing instances but raises the instance count by
    ``num_continuous_gen`` (clamped to the per-type maximum). VisorGPT then
    completes that prompt. Otherwise a fresh prompt is built from the UI
    parameters.

    Args:
        gen_type: ``'key point'``, ``'box'`` or ``'mask'``.
        data_type: ``'multiple instances'`` or ``'object centric'``.
        instance_size: ``'small'``, ``'medium'`` or ``'large'``.
        num_instance: instance count for a fresh prompt.
        object_name_inbox: semicolon-separated category names (box/mask).
        num_continuous_gen: number of instances to add.
        global_seq: raw sequence stored in ``sequence_container`` by a
            previous generation, or ``None`` when there is none.

    Returns:
        A dict mapping Gradio components to their new values. When the
        input is rejected, only ``raw_sequence`` is updated (with a message).

    Raises:
        ValueError: if ``gen_type`` is not a supported annotation type.
    """
    ctn = True
    inbox = object_name_inbox or ''
    if gen_type == 'key point':
        num_keypoint = 18
        max_num = max_num_keypoint
        seq_prompt = '; '.join([gen_type, data_type, instance_size,
                                str(num_instance), str(num_keypoint)]) + ' ; [person'
    elif gen_type in ('box', 'mask'):
        num_keypoint = 0
        max_num = max_num_mask if gen_type == 'mask' else max_num_bbox
        if data_type == 'object centric':
            num_instance = 1
        objects = ', '.join(inbox.strip().split(";"))
        seq_prompt = '; '.join([gen_type, data_type, instance_size,
                                str(num_instance), str(num_keypoint)]) + '; ' + objects
        if len(inbox.split(';')) > num_instance:
            return {
                raw_sequence: gr.update(
                    value="The number of category names should not exceed the number of instances, please try again :)",
                    visible=True)
            }
    else:
        raise ValueError(f"Unsupported annotation type: {gen_type!r}")

    if num_continuous_gen:
        # sequence_container is None before the first generation and is
        # reset to None whenever the annotation type changes.
        if not global_seq:
            return {
                raw_sequence: gr.update(
                    value="Please customize a sequential output first, then add instances to it :)",
                    visible=True)
            }
        raw_fields = global_seq.split(';')
        new_number = min(int(raw_fields[3].strip()) + num_continuous_gen, max_num)
        clean_seq = global_seq.replace('[CLS]', '').replace('[SEP]', '')
        if '[' not in raw_fields[5]:
            # Prompt type a: re-attach each bracketed instance to its name.
            coords = _BRACKET_RE.findall(clean_seq)
            if gen_type == 'key point':
                objects = ' '.join(['[ person' + c + ']' for c in coords])
            else:
                names = clean_seq.split(';')[5].split(',')
                objects = ' '.join(['[ ' + names[i] + c + ']' for i, c in enumerate(coords)])
            seq_prompt = '; '.join([gen_type, data_type, instance_size,
                                    str(new_number), str(num_keypoint), objects])
        else:
            # Prompt type b: keep the sequence, only bump the instance count.
            fields = clean_seq.split(';')
            fields[3] = str(new_number)
            seq_prompt = ';'.join(fields)

    print("input prompt: \n", seq_prompt)
    with torch.no_grad():
        sequence = gen_sequence(visorgpt_args, visorgpt_model, seq_prompt)
    torch.cuda.empty_cache()
    assert isinstance(sequence, list)

    # Index of this annotation type's entry inside cond_json.
    json_idx = {'box': 0, 'mask': 1, 'key point': 2}[gen_type]
    try:
        cond_mask, cond_json = gen_cond_mask(sequence, ctn)
        ori_sequence = cond_json[json_idx]['sequences'][0][0] + '[SEP]'
    except Exception:
        # Decoding with ctn=True failed; retry once with the flag flipped.
        cond_mask, cond_json = gen_cond_mask(sequence, not ctn)
        ori_sequence = cond_json[json_idx]['sequences'][0][0] + '[SEP]'

    return {
        result_gallery: [Image.fromarray(cond_mask)],
        raw_sequence: gr.update(value=ori_sequence, visible=True),
        # Image synthesis from mask sequences is not supported, so hide the button.
        images_button: gr.update(visible=gen_type != 'mask'),
        text_container: cond_json,
        sequence_container: ori_sequence,
    }
def generate_images(gen_type,
                    num_samples,
                    ddim_steps,
                    object_prompt,
                    seed,
                    global_text,
                    global_seq):
    """Synthesize images conditioned on the currently generated sequence.

    Key-point sequences are rendered with ControlNet (``control_infer``) and
    box sequences with GLIGEN (``gligen_infer``). Mask sequences are not
    supported yet.

    Args:
        gen_type: ``'key point'``, ``'box'`` or ``'mask'``.
        num_samples: number of images to generate (batch size).
        ddim_steps: number of sampling steps.
        object_prompt: free-text context prompt; surrounding whitespace is
            stripped.
        seed: random seed forwarded to the sampler.
        global_text: the ``cond_json`` kept in ``text_container``.
        global_seq: the raw sequence kept in ``sequence_container``; it is
            passed through unchanged.

    Returns:
        A dict mapping Gradio components to their new values.

    Raises:
        ValueError: if ``gen_type`` is not a supported annotation type.
    """
    passthrough = {text_container: global_text, sequence_container: global_seq}
    if gen_type == 'mask':
        return {raw_sequence: "sequence to mask is not supported yet :)", **passthrough}
    if global_text is None:
        return {
            raw_sequence: gr.update(value="Please customize a sequential output first :)", visible=True),
            **passthrough,
        }

    prompt = object_prompt.strip()
    if gen_type == 'key point':
        # Empty gallery if the sequence contains no key-point entries.
        ret_img = []
        # NOTE(review): each entry overwrites ret_img, so only the last
        # entry's images are returned (unchanged from the original behavior).
        for item in tqdm(global_text[2]['keypoints']):
            # Only the first value of each instance dict is used; np.array
            # round-trip normalizes it into plain nested Python lists.
            keypoint_list = [np.array(list(ins.values())[0]).tolist() for ins in item]
            with torch.no_grad():
                ret_img = control_infer(model=control_model,
                                        ddim_sampler=ddim_sampler,
                                        keypoint_list=keypoint_list,
                                        prompt=prompt,
                                        num_samples=num_samples,
                                        ddim_steps=ddim_steps,
                                        seed=seed)
            torch.cuda.empty_cache()
    elif gen_type == 'box':
        with torch.no_grad():
            ret_img = gligen_infer(model=g_model,
                                   autoencoder=g_autoencoder,
                                   text_encoder=g_text_encoder,
                                   diffusion=g_diffusion,
                                   config=g_config,
                                   grounding_tokenizer_input=g_grounding_tokenizer_input,
                                   context_prompt=prompt,
                                   bbox_lists=global_text[0]['bboxes'],
                                   ddim_steps=ddim_steps,
                                   batch_size=num_samples,
                                   seed=seed)
        torch.cuda.empty_cache()
    else:
        raise ValueError(f"Unsupported annotation type: {gen_type!r}")

    return {result_gallery: ret_img, **passthrough}
def object_name_inbox_fn(gen_type):
    """Reconfigure the widgets after the annotation-type dropdown changes.

    Args:
        gen_type: the newly selected annotation type.

    Returns:
        A dict of Gradio updates for the category textbox, data-type
        dropdown, synthesis button, step slider, context-prompt placeholder
        and instance slider. It also resets the stored sequence to ``None``.
        Returns ``None`` for an unrecognised type.
    """
    if gen_type == 'key point':
        inbox_update = gr.update(visible=False)
        data_choices = ['multiple instances']
        button_text = 'Synthesize images using ControlNet'
        steps, prompt_hint, instance_cap = 20, 'in suit', 16
    elif gen_type == 'box':
        inbox_update = gr.update(visible=True, value='person; frisbee')
        data_choices = ['multiple instances', 'object centric']
        button_text = 'Synthesize images using GLIGEN'
        steps, prompt_hint, instance_cap = 50, 'man and frisbee', 16
    elif gen_type == 'mask':
        inbox_update = gr.update(visible=True,
                                 label="MS COCO categories to be generated (separated by semicolon)",
                                 value='bottle; cup')
        data_choices = ['multiple instances']
        button_text = 'Synthesize images using GLIGEN'
        steps, prompt_hint, instance_cap = 50, 'bottle and cup', 8
    else:
        return None

    return {
        object_name_inbox: inbox_update,
        data_type: gr.update(choices=data_choices),
        images_button: gr.update(value=button_text),
        ddim_steps: gr.update(value=steps),
        object_prompt: gr.update(placeholder=prompt_hint),
        num_instance: gr.update(visible=True, minimum=1, maximum=instance_cap, value=2, step=1),
        sequence_container: None,
    }
def instance_type_change_fn(data_type):
    """Adapt the widgets to the selected data type.

    Continuous-generation controls are shown only for multi-instance scenes.
    Object-centric scenes switch the category textbox to ImageNet-1K names
    and fix the instance count at 1.

    Args:
        data_type: ``'multiple instances'`` or ``'object centric'``.

    Returns:
        A dict of Gradio updates, or ``None`` for any other value.
    """
    if data_type not in ('multiple instances', 'object centric'):
        return None

    multi = data_type == 'multiple instances'
    if multi:
        inbox_update = gr.update(label="MS COCO categories to be generated (separated by semicolon)",
                                 value='person; frisbee')
        prompt_hint = 'man and frisbee'
        instance_update = gr.update(visible=True, minimum=1, maximum=16, value=2, step=1)
    else:
        inbox_update = gr.update(label="ImageNet-1K categories to be generated",
                                 value='great white shark')
        prompt_hint = 'great white shark'
        instance_update = gr.update(visible=False, value=1)

    return {
        md_title: gr.update(visible=multi),
        num_continuous_gen: gr.update(visible=multi),
        continuous_btn: gr.update(visible=multi),
        object_name_inbox: inbox_update,
        object_prompt: gr.update(placeholder=prompt_hint),
        num_instance: instance_update,
    }
# ---------------------------------------------------------------------------
# UI layout and event wiring.
# ---------------------------------------------------------------------------
block = gr.Blocks()
with block:
    # Per-session state: the cond_json of the last generation and its raw sequence.
    text_container = gr.State()
    sequence_container = gr.State()
    gr.Markdown('<div align=center> <img src="file/visorgpt_title_all.jpg" width = "100%" height = "100%" /> </div>')
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Params to generate sequences")
            # gr.Dropdown(value=...) replaces the deprecated gr.inputs.Dropdown(default=...).
            gen_type = gr.Dropdown(choices=['key point', 'box', 'mask'], type='value',
                                   value='key point', label='Annotation Type')
            data_type = gr.Dropdown(choices=['multiple instances'], type='value',
                                    value='multiple instances', label='Data Type')
            instance_size = gr.Dropdown(choices=['small', 'medium', 'large'], type='value',
                                        value='large', label='Instance Size')
            num_instance = gr.Slider(label="Number of instances per image", minimum=1, maximum=16, value=2, step=1)
            object_name_inbox = gr.Textbox(label="MS COCO categories to be generated (separated by semicolon)",
                                           placeholder="person; frisbee", visible=False)
            sequence_button = gr.Button(value="Customize sequential output")
            md_title = gr.Markdown("### Continuous generation (Optional)")
            num_continuous_gen = gr.Slider(label="Add instances to the current scene", minimum=1, maximum=16, value=1, step=1)
            continuous_btn = gr.Button(value="Add instances to the current scene")
            gr.Markdown("### Params to synthesize images")
            object_prompt = gr.Textbox(label="Context Prompt", placeholder="in suit", visible=True)
            num_samples = gr.Slider(label="Batch Size", minimum=1, maximum=36, value=1, step=1)
            ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
            seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
            # Hidden until a sequence has been generated.
            images_button = gr.Button(value="Synthesize images using ControlNet", visible=False)
        with gr.Column():
            raw_sequence = gr.Textbox(label="Raw Sequence", visible=False)
            result_gallery = gr.Gallery(label='Output', show_label=False,
                                        elem_id="gallery").style(grid=2, height='auto', preview=True)

    gen_type.change(object_name_inbox_fn, inputs=[gen_type],
                    outputs=[object_name_inbox, data_type, images_button, ddim_steps,
                             object_prompt, num_instance, sequence_container])
    data_type.change(instance_type_change_fn, inputs=[data_type],
                     outputs=[md_title, num_continuous_gen, continuous_btn,
                              object_name_inbox, object_prompt, num_instance])

    # Both sequence generators update the same set of components.
    sequence_outputs = [result_gallery, raw_sequence, images_button, text_container, sequence_container]
    sequence_button.click(fn=generate_sequence,
                          inputs=[gen_type, data_type, instance_size, num_instance, object_name_inbox],
                          outputs=sequence_outputs)
    continuous_btn.click(fn=add_contents,
                         inputs=[gen_type, data_type, instance_size, num_instance, object_name_inbox,
                                 num_continuous_gen, sequence_container],
                         outputs=sequence_outputs)
    images_button.click(fn=generate_images,
                        inputs=[gen_type, num_samples, ddim_steps, object_prompt, seed,
                                text_container, sequence_container],
                        outputs=[result_gallery, raw_sequence, text_container, sequence_container])

block.launch(server_name='0.0.0.0', server_port=10086, debug=False, share=False)