Spaces:
Runtime error
Runtime error
Songwei Ge
committed on
Commit
•
61c1bd4
1
Parent(s):
ab7db7f
demo
Browse files
- requirements.txt +6 -0
- sample.py +0 -109
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.11.0
|
2 |
+
torchvision==0.12.0
|
3 |
+
diffusers==0.12.1
|
4 |
+
transformers==4.25.1
|
5 |
+
numpy==1.24.2
|
6 |
+
seaborn==0.12.2
|
sample.py
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import json
|
3 |
-
import time
|
4 |
-
import argparse
|
5 |
-
import imageio
|
6 |
-
import torch
|
7 |
-
import numpy as np
|
8 |
-
from torchvision import transforms
|
9 |
-
|
10 |
-
from models.region_diffusion import RegionDiffusion
|
11 |
-
from utils.attention_utils import get_token_maps
|
12 |
-
from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
|
13 |
-
get_attention_control_input, get_gradient_guidance_input
|
14 |
-
|
15 |
-
|
16 |
-
def main(args, param):
|
17 |
-
|
18 |
-
# Create the folder to store outputs.
|
19 |
-
run_dir = args.run_dir
|
20 |
-
os.makedirs(args.run_dir, exist_ok=True)
|
21 |
-
|
22 |
-
# Load region diffusion model.
|
23 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
24 |
-
model = RegionDiffusion(device)
|
25 |
-
|
26 |
-
# parse json to span attributes
|
27 |
-
base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
|
28 |
-
color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance = parse_json(
|
29 |
-
param['text_input'])
|
30 |
-
|
31 |
-
# create control input for region diffusion
|
32 |
-
region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
|
33 |
-
model, base_text_prompt, style_text_prompts, footnote_text_prompts,
|
34 |
-
footnote_target_tokens, color_text_prompts, color_names)
|
35 |
-
|
36 |
-
# create control input for cross attention
|
37 |
-
text_format_dict = get_attention_control_input(
|
38 |
-
model, base_tokens, size_text_prompts_and_sizes)
|
39 |
-
|
40 |
-
# create control input for region guidance
|
41 |
-
text_format_dict, color_target_token_ids = get_gradient_guidance_input(
|
42 |
-
model, base_tokens, color_text_prompts, color_rgbs, text_format_dict)
|
43 |
-
|
44 |
-
height = param['height']
|
45 |
-
width = param['width']
|
46 |
-
seed = param['noise_index']
|
47 |
-
negative_text = param['negative_prompt']
|
48 |
-
seed_everything(seed)
|
49 |
-
|
50 |
-
# get token maps from plain text to image generation.
|
51 |
-
begin_time = time.time()
|
52 |
-
if model.attention_maps is None:
|
53 |
-
model.register_evaluation_hooks()
|
54 |
-
else:
|
55 |
-
model.reset_attention_maps()
|
56 |
-
plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
|
57 |
-
height=height, width=width, num_inference_steps=param['steps'],
|
58 |
-
guidance_scale=param['guidance_weight'])
|
59 |
-
fn_base = os.path.join(run_dir, 'seed%d_plain.png' % (seed))
|
60 |
-
imageio.imwrite(fn_base, plain_img[0])
|
61 |
-
print('time lapses to get attention maps: %.4f' % (time.time()-begin_time))
|
62 |
-
color_obj_masks = get_token_maps(
|
63 |
-
model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
|
64 |
-
model.masks = get_token_maps(
|
65 |
-
model.attention_maps, run_dir, width//8, height//8, region_target_token_ids, seed, base_tokens)
|
66 |
-
color_obj_masks = [transforms.functional.resize(color_obj_mask, (height, width),
|
67 |
-
interpolation=transforms.InterpolationMode.BICUBIC,
|
68 |
-
antialias=True)
|
69 |
-
for color_obj_mask in color_obj_masks]
|
70 |
-
text_format_dict['color_obj_atten'] = color_obj_masks
|
71 |
-
model.remove_evaluation_hooks()
|
72 |
-
|
73 |
-
# generate image from rich text
|
74 |
-
begin_time = time.time()
|
75 |
-
seed_everything(seed)
|
76 |
-
rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
|
77 |
-
height=height, width=width, num_inference_steps=param['steps'],
|
78 |
-
guidance_scale=param['guidance_weight'], use_grad_guidance=use_grad_guidance,
|
79 |
-
text_format_dict=text_format_dict)
|
80 |
-
print('time lapses to generate image from rich text: %.4f' %
|
81 |
-
(time.time()-begin_time))
|
82 |
-
fn_style = os.path.join(run_dir, 'seed%d_rich.png' % (seed))
|
83 |
-
imageio.imwrite(fn_style, rich_img[0])
|
84 |
-
# imageio.imwrite(fn_cat, np.concatenate([img[0], rich_img[0]], 1))
|
85 |
-
|
86 |
-
|
87 |
-
if __name__ == '__main__':
|
88 |
-
parser = argparse.ArgumentParser()
|
89 |
-
parser.add_argument('--run_dir', type=str, default='results/release/debug')
|
90 |
-
parser.add_argument('--height', type=int, default=512)
|
91 |
-
parser.add_argument('--width', type=int, default=512)
|
92 |
-
parser.add_argument('--seed', type=int, default=6)
|
93 |
-
parser.add_argument('--sample_steps', type=int, default=41)
|
94 |
-
parser.add_argument('--rich_text_json', type=str,
|
95 |
-
default='{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. There are palm trees in the background."}]}')
|
96 |
-
parser.add_argument('--negative_prompt', type=str, default='')
|
97 |
-
parser.add_argument('--guidance_weight', type=float, default=8.5)
|
98 |
-
args = parser.parse_args()
|
99 |
-
param = {
|
100 |
-
'text_input': json.loads(args.rich_text_json),
|
101 |
-
'height': args.height,
|
102 |
-
'width': args.width,
|
103 |
-
'guidance_weight': args.guidance_weight,
|
104 |
-
'steps': args.sample_steps,
|
105 |
-
'noise_index': args.seed,
|
106 |
-
'negative_prompt': args.negative_prompt,
|
107 |
-
}
|
108 |
-
|
109 |
-
main(args, param)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|