Spaces:
Runtime error
Runtime error
File size: 9,337 Bytes
4c022fe d0745b6 4c022fe d0745b6 4c022fe d0745b6 4c022fe d0745b6 4c022fe d0745b6 4c022fe d0745b6 4c022fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import os
import json
import torch
import random
import numpy as np
# Named reference palette as 0-255 RGB triplets. find_nearest_color compares
# a user-picked colour against these entries in insertion order.
COLORS = dict(
    brown=[165, 42, 42],
    red=[255, 0, 0],
    pink=[253, 108, 158],
    orange=[255, 165, 0],
    yellow=[255, 255, 0],
    purple=[128, 0, 128],
    green=[0, 128, 0],
    blue=[0, 0, 255],
    white=[255, 255, 255],
    gray=[128, 128, 128],
    black=[0, 0, 0],
)
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def hex_to_rgb(hex_string, return_nearest_color=False, device='cuda'):
    r"""
    Convert a hex colour triplet to a normalised RGB tensor.

    Args:
        hex_string: Hex colour with or without a leading ``'#'``, e.g.
            ``'#ff8800'`` or the CSS shorthand ``'#f80'``. Surrounding
            whitespace is ignored, as are any digits after the sixth.
        return_nearest_color: If True, also return the name of the closest
            entry in ``COLORS`` (via ``find_nearest_color``).
        device: Device the returned tensor is moved to.

    Returns:
        A float tensor of shape ``(1, 3, 1, 1)`` with channels scaled to
        ``[0, 1]``. When ``return_nearest_color`` is True, returns the tuple
        ``(rgb, nearest_color_name)`` instead.

    Raises:
        ValueError: If ``hex_string`` has fewer than six hex digits (after
            shorthand expansion) or contains non-hex characters.
    """
    hex_digits = hex_string.strip().lstrip('#')
    # Expand CSS shorthand: 'f80' -> 'ff8800'.
    if len(hex_digits) == 3:
        hex_digits = ''.join(ch * 2 for ch in hex_digits)
    if len(hex_digits) < 6:
        raise ValueError(f'Invalid hex colour: {hex_string!r}')
    red, green, blue = (int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
    rgb = torch.FloatTensor((red, green, blue))[None, :, None, None] / 255.
    if return_nearest_color:
        # Looked up on the CPU tensor, before the move to ``device``.
        nearest_color = find_nearest_color(rgb)
        return rgb.to(device), nearest_color
    return rgb.to(device)
def find_nearest_color(rgb):
    r"""
    Return the name of the ``COLORS`` entry closest to ``rgb``.

    Args:
        rgb: Either a list/tuple of three 0-255 channel values, or a tensor
            (or array) already scaled to ``[0, 1]`` that broadcasts against a
            ``(1, 3, 1, 1)`` tensor, as produced by ``hex_to_rgb``.

    Returns:
        The ``COLORS`` key with the smallest Euclidean distance to ``rgb``;
        among equally near entries the one listed first wins.
    """
    if isinstance(rgb, (list, tuple)):
        rgb = torch.FloatTensor(rgb)[None, :, None, None] / 255.
    # Work in torch on rgb's own device and outside any autograd graph:
    # converting to NumPy (as np.linalg.norm would) fails for CUDA tensors
    # and for tensors that require grad.
    rgb = torch.as_tensor(rgb).detach()

    def _distance(color_name):
        palette_rgb = torch.tensor(COLORS[color_name], dtype=torch.float32,
                                   device=rgb.device)[None, :, None, None] / 255.
        # 2-norm over all elements, matching np.linalg.norm's default.
        return torch.sqrt(((rgb - palette_rgb) ** 2).sum()).item()

    # min() returns the first minimal key, so ties resolve to the earliest
    # entry of COLORS.
    return min(COLORS, key=_distance)
def font2style(font, device='cuda'):
    r"""
    Map an editor font name to the artistic style injected into the prompt.

    Args:
        font: Font name taken from a span's attributes (case-sensitive).
        device: Not used by this function.

    Returns:
        The style description associated with ``font``.

    Raises:
        KeyError: If ``font`` has no associated style.
    """
    font_styles = {
        'mirza': 'Claud Monet, impressionism, oil on canvas',
        'roboto': 'Ukiyoe',
        'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
        'sofia': 'Pop Art, masterpiece, andy warhol',
        'slabo': 'Vincent Van Gogh',
        'inconsolata': 'Pixel Art, 8 bits, 16 bits',
        'ubuntu': 'Rembrandt',
        'Monoton': 'neon art, colorful light, highly details, octane render',
        'Akronim': 'Abstract Cubism, Pablo Picasso',
    }
    return font_styles[font]
def parse_json(json_str, device):
    r"""
    Convert the rich-text JSON document into per-attribute prompt lists.

    Args:
        json_str: Already-decoded JSON object (despite the name, not a
            string) holding an ``'ops'`` list. Each op has an ``'insert'``
            text and an optional ``'attributes'`` dict; the attributes read
            here are ``font``, ``link``, ``size``, ``strike`` and ``color``.
        device: Device the colour tensors from ``hex_to_rgb`` are placed on.

    Returns:
        Tuple ``(base_text_prompt, style_text_prompts, footnote_text_prompts,
        footnote_target_tokens, color_text_prompts, color_names, color_rgbs,
        size_text_prompts_and_sizes, use_grad_guidance)``:

        - ``base_text_prompt``: every op's text, trailing newlines removed,
          concatenated.
        - ``style_text_prompts``: ``'<text> in the style of <style>'`` per run
          of attributed spans sharing a font.
        - ``footnote_text_prompts`` / ``footnote_target_tokens``: each link
          text and the span text it annotates.
        - ``color_text_prompts`` / ``color_names`` / ``color_rgbs``: one entry
          per run of attributed spans sharing a colour.
        - ``size_text_prompts_and_sizes``: ``[text, font_size]`` for spans
          whose computed size differs from 1.
        - ``use_grad_guidance``: True if any span carries a colour.

        Lone-space spans and spans without attributes do not break a font or
        colour run.
    """
    # Region-based attributes collected while walking the ops.
    base_text_prompt = ''
    style_text_prompts = []
    footnote_text_prompts = []
    footnote_target_tokens = []
    color_text_prompts = []
    color_rgbs = []
    color_names = []
    size_text_prompts_and_sizes = []

    prev_style = None
    prev_color_rgb = None
    use_grad_guidance = False
    for span in json_str['ops']:
        text_prompt = span['insert'].rstrip('\n')
        base_text_prompt += text_prompt
        if text_prompt == ' ' or 'attributes' not in span:
            continue
        attributes = span['attributes']

        if 'font' in attributes:
            style = font2style(attributes['font'])
            if prev_style == style:
                # Extend the previous styled prompt instead of opening a new one.
                prev_text_prompt = style_text_prompts[-1].split('in the style of')[0]
                style_text_prompts[-1] = (prev_text_prompt + ' ' + text_prompt
                                          + f' in the style of {style}')
            else:
                style_text_prompts.append(text_prompt + f' in the style of {style}')
            prev_style = style
        else:
            prev_style = None

        if 'link' in attributes:
            footnote_text_prompts.append(attributes['link'])
            footnote_target_tokens.append(text_prompt)

        # The size string is assumed to end in a two-character unit such as
        # 'px' (the last two characters are dropped). Strikethrough makes the
        # weight negative; strikethrough without a size leaves it at 1.
        font_size = 1
        if 'size' in attributes:
            font_size = float(attributes['size'][:-2]) / 3.
            if 'strike' in attributes:
                font_size = -font_size

        if 'color' in attributes:
            use_grad_guidance = True
            color_rgb, nearest_color = hex_to_rgb(attributes['color'], True, device=device)
            # Tensor ``==`` is element-wise and cannot serve as a truth value,
            # so compare by value with torch.equal.
            if prev_color_rgb is not None and torch.equal(prev_color_rgb, color_rgb):
                # Same colour as the previous attributed span: merge the text.
                color_text_prompts[-1] = color_text_prompts[-1] + ' ' + text_prompt
            else:
                color_rgbs.append(color_rgb)
                color_names.append(nearest_color)
                color_text_prompts.append(text_prompt)
            prev_color_rgb = color_rgb
        else:
            prev_color_rgb = None

        if font_size != 1:
            size_text_prompts_and_sizes.append([text_prompt, font_size])
    return (base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,
            color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes,
            use_grad_guidance)
def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
                               footnote_target_tokens, color_text_prompts, color_names):
    r"""
    Build per-region prompts and target token ids (Algorithm 1 in the paper).

    Every styled, footnoted and coloured span becomes its own region, in that
    order. A region's target ids are the 1-based positions of its tokens in
    the tokenized base prompt. A final region pairs the base prompt with every
    position no earlier region claimed.

    Returns:
        ``(region_text_prompts, region_target_token_ids, base_tokens)``, where
        ``region_target_token_ids`` is a list of ``torch.LongTensor``.
    """
    tokenize = model.tokenizer._tokenize
    base_tokens = tokenize(base_text_prompt)

    def positions_of(text):
        # NOTE(review): ``list.index`` returns the FIRST occurrence, so a word
        # repeated in the base prompt always maps to its first position. The
        # +1 presumably skips a start-of-sequence token — confirm.
        return [base_tokens.index(token) + 1 for token in tokenize(text)]

    region_text_prompts = []
    region_target_token_ids = []

    # Styled spans: only the text before 'in the style of' is located.
    for styled_prompt in style_text_prompts:
        region_text_prompts.append(styled_prompt)
        region_target_token_ids.append(positions_of(styled_prompt.split('in the style of')[0]))

    # Footnotes: the link text is the prompt, the annotated span the target.
    for footnote_prompt, target_text in zip(footnote_text_prompts, footnote_target_tokens):
        region_text_prompts.append(footnote_prompt)
        region_target_token_ids.append(positions_of(target_text))

    # Coloured spans: the colour name is prepended to the span text.
    for colored_prompt, name in zip(color_text_prompts, color_names):
        region_text_prompts.append(name + ' ' + colored_prompt)
        region_target_token_ids.append(positions_of(colored_prompt))

    # Remaining tokens without attributes fall into the base-prompt region.
    region_text_prompts.append(base_text_prompt)
    claimed = {pos for positions in region_target_token_ids for pos in positions}
    region_target_token_ids.append(
        [pos for pos in range(1, len(base_tokens) + 1) if pos not in claimed])

    token_id_tensors = [torch.LongTensor(positions) for positions in region_target_token_ids]
    return region_text_prompts, token_id_tensors, base_tokens
def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
    r"""
    Build the font-size attention-control inputs.

    Each token of every sized span is located in ``base_tokens`` (first
    occurrence, 1-based) and paired with that span's font size.

    Args:
        model: Object exposing ``tokenizer._tokenize`` and ``device``.
        base_tokens: Tokens of the full base prompt.
        size_text_prompts_and_sizes: ``[text, font_size]`` pairs.

    Returns:
        ``{'word_pos': LongTensor, 'font_size': FloatTensor}`` on
        ``model.device``, or both values ``None`` when nothing was located.
    """
    located = [
        (base_tokens.index(token) + 1, font_size)
        for text_prompt, font_size in size_text_prompts_and_sizes
        for token in model.tokenizer._tokenize(text_prompt)
    ]
    if not located:
        return {'word_pos': None, 'font_size': None}
    positions, sizes = zip(*located)
    return {
        'word_pos': torch.LongTensor(list(positions)).to(model.device),
        'font_size': torch.FloatTensor(list(sizes)).to(model.device),
    }
def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
                                guidance_start_step=999, color_guidance_weight=1):
    r"""
    Prepare the token groups and settings for colour gradient guidance.

    Each coloured span's tokens are located in ``base_tokens`` (first
    occurrence, 1-based); one extra group, appended last, collects every
    position not claimed by a coloured span.

    Args:
        model: Object exposing ``tokenizer._tokenize``.
        base_tokens: Tokens of the full base prompt.
        color_text_prompts: Text of each coloured span.
        color_rgbs: Target colours, stored as-is under ``'target_RGB'``.
        text_format_dict: Dict updated IN PLACE with the guidance settings.
        guidance_start_step: Stored under ``'guidance_start_step'``.
        color_guidance_weight: Stored under ``'color_guidance_weight'``.

    Returns:
        ``(text_format_dict, color_target_token_ids)``: the same dict object
        that was passed in, and a list of ``torch.LongTensor`` with one entry
        per coloured span plus the remaining-tokens group.
    """
    tokenize = model.tokenizer._tokenize
    color_target_token_ids = [
        [base_tokens.index(token) + 1 for token in tokenize(text_prompt)]
        for text_prompt in color_text_prompts
    ]
    # Set membership keeps the "remaining positions" pass linear.
    claimed = {pos for positions in color_target_token_ids for pos in positions}
    color_target_token_ids.append(
        [pos for pos in range(1, len(base_tokens) + 1) if pos not in claimed])
    color_target_token_ids = [torch.LongTensor(positions) for positions in color_target_token_ids]

    text_format_dict['target_RGB'] = color_rgbs
    text_format_dict['guidance_start_step'] = guidance_start_step
    text_format_dict['color_guidance_weight'] = color_guidance_weight
    return text_format_dict, color_target_token_ids
|