# color_demo_v1 / app.py
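# Gradio demo: the user edits one color of a design palette; a GNN classifier predicts a color
# cube for the remaining palette slots and a linear regressor maps the predicted cube back to
# an RGB color, which is used to re-render the example design.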
import os
import random

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageFont, ImageDraw
from sklearn.linear_model import LinearRegression
from torch.utils.data import DataLoader

from color_palette.config import DataConfig
from color_palette.dataset import GraphDestijlDataset
from color_palette.model.GNN import ColorAttentionClassification
from color_palette.regressor.config import config_to_use
from color_palette.regressor.model import Color2CubeDataset
config = DataConfig()
model_name = config.model_name
dataset_root = config.dataset
feature_size = config.feature_size
device = config.device
image_folder = "img_folder"
if not os.path.exists(image_folder):
os.mkdir(image_folder)
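# Fit a plain linear regression on the Color2CubeDataset pairs. At inference time it maps a
# flattened unique RGB palette plus a one-hot color-cube index to an RGB recommendation
# (see create_rgb_and_one_hot_cube_vector below).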
def train_regressor(train_loader):
X = []
y = []
for i, (input_data, target) in enumerate(train_loader):
input_data = np.squeeze(input_data)
target = np.squeeze(target)
X.append(input_data)
y.append(target)
X = np.stack(X, axis=0)
y = np.squeeze(np.stack(y, axis=0))
print("Before regressor train!\n")
reg = LinearRegression().fit(X, y)
return reg
model_weight_path = "models/" + model_name + "/weights/best.pth"
# palettes = np.load(config_to_use.save_folder+'/new_palettes_purple.npy')
# original_palettes = np.load(config_to_use.save_folder+'/original_palettes_purple.npy')
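# Load the test split of the design-graph dataset and the pretrained GNN classifier, and fit
# the linear regressor that turns a predicted color cube (plus the current palette) back into
# an RGB value.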
graph_test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
model = ColorAttentionClassification(feature_size).to(device)
model.load_state_dict(torch.load(model_weight_path)["state_dict"])
dataset = Color2CubeDataset(config=config_to_use)
train_loader = DataLoader(dataset, batch_size=1, shuffle=False)
regressor = train_regressor(train_loader=train_loader)
palette_of_the_design = [[0, 0, 0] for i in range(5)]
all_node_colors = None
class Demo:
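    """Wraps one randomly chosen test design: renders it as an image and, when the user edits a
    color, asks the GNN and the regressor for matching updates to the rest of the palette."""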
def __init__(self, graph_dataset):
self.dataset = graph_dataset
first_sample_idx = random.randint(0, len(self.dataset)-1)
self.input_data, self.target_color, node_to_mask, also_normal_values = self.dataset.get(first_sample_idx)
global all_node_colors
all_node_colors = also_normal_values
self.same_indices = None
self.generate_img_from_palette([color.detach().numpy()*255 for color in also_normal_values], is_first=True)
def demo_reset(self):
        first_sample_idx = random.randint(0, len(self.dataset) - 1)
self.input_data, self.target_color, node_to_mask, also_normal_values = self.dataset.get(first_sample_idx)
global all_node_colors
all_node_colors = also_normal_values
self.generate_img_from_palette([color.detach().numpy()*255 for color in also_normal_values], is_first=True)
def generate_img_from_palette(self, palette, canvas_size=512, is_first=False):
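        """Render a canvas_size x canvas_size mock design (background, titles, a circle, and
        nested squares) from an 8-color palette and save it as deneme.png."""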
palette = np.array(palette).astype('int')
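        # NOTE: the two text entries unpack into the same name, so both titles use the second text color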
rgb_bg, rgb_text, rgb_text, rgb_circle, rgb_main_img, rgb_img1, rgb_img2, rgb_img3 = [tuple(color) for color in palette]
if is_first:
self.same_indices, unique_colors, _ = self.return_all_same_colors(palette=palette)
else:
_, unique_colors, _ = self.return_all_same_colors(palette=palette)
# assign the current palette using global keyword
global palette_of_the_design
palette_of_the_design = unique_colors
# Set the background color and create an empty PIL Image to fill with shapes and text
image = Image.new("RGB", (canvas_size, canvas_size), color=rgb_bg)
# Save background image
title = "Lorem Ipsum Dolor"
undertitle = "Neque porro quisquam est qui dolorem ipsum quia dolor sit amet, \n consectetur, adipisci velit..."
draw = ImageDraw.Draw(image)
# Set settings for the fonts
font_title = ImageFont.truetype("Arial.ttf", 32)
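        # NOTE: ImageDraw.textsize() was removed in Pillow 10; on newer Pillow use textbbox() instead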
title_width, title_height = draw.textsize(title, font=font_title)
title_x = (canvas_size - title_width) // 2
title_y = (canvas_size - title_height) // 2 - 100
font_undertitle = ImageFont.truetype("Arial.ttf", 15)
text_width, text_height = draw.textsize(undertitle, font=font_undertitle)
undertitle_x = (canvas_size - text_width) // 2
undertitle_y = (canvas_size - text_height) // 2 - 50
# Draw titles
draw.text((title_x, title_y), title, fill=rgb_text, font=font_title)
draw.text((undertitle_x, undertitle_y), undertitle, fill=rgb_text, font=font_undertitle)
# Draw the circle
rad = random.randint(30, 70)
x = random.randint(400, 512-(rad+10))
y = random.randint(10, title_y-(rad+10))
draw.ellipse((x, y, x+rad, y+rad), fill=rgb_circle)
# Draw the image
for j, color in enumerate([rgb_main_img, rgb_img1, rgb_img2, rgb_img3]):
x = 512-((j+1)*60)
y = 512-((j+1)*60)
if j == 0:
rad = 80
draw.rectangle((x, y, x+rad, y+rad), fill=color)
else:
rad = 40
draw.rectangle((x, y, x+rad, y+rad), fill=color)
        image.save("deneme.png")
def run_model(self, input_data, target_color, node_to_mask, updated_color):
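        """Write the user-selected color into the masked node group's features, run the GNN once to
        predict a color cube for the masked node, map that cube to RGB with the linear regressor,
        and return the updated unique palette scaled to 0-255."""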
global all_node_colors
palette = np.array([color.detach().numpy()*255 for color in all_node_colors]).astype('int')
same_indices_list, unique_colors, first_indices = self.return_all_same_colors(palette)
unique_colors = unique_colors/255
selected_color = torch.Tensor(updated_color)/255
map_node_to_mask = -1
print("same indices list")
print(self.same_indices)
print("node to mask")
print(node_to_mask)
for i, idxs in enumerate(self.same_indices):
if node_to_mask in idxs:
map_node_to_mask = i
print("map node to mask: ", i)
for i, indices in enumerate(self.same_indices):
if i == 0:
# update the color [0.15, 0.4908, 0.73]
cube_num_of_selected = self.rgb2cube(selected_color*255)
one_hot = np.zeros((64,))
one_hot[int(cube_num_of_selected)] = 1.0
node_to_recommend = map_node_to_mask
input_data.x[same_indices_list[map_node_to_mask], 4:] = torch.Tensor(one_hot)
unique_colors[map_node_to_mask] = selected_color
else:
if i == map_node_to_mask:
zeroth_bin = 0
indices = same_indices_list[zeroth_bin]
node_to_recommend = 0
input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
node_to_mask = indices[0]
else:
node_to_recommend = i
input_data.x[indices[0], 4:] = torch.zeros((input_data.x.shape[1]-4))
node_to_mask = indices[0]
out = self.forward_pass(model, input_data) # input data has one-hot color features
if torch.is_tensor(node_to_mask):
node_to_mask = node_to_mask.item()
values, values_indices = torch.topk(F.softmax(out[node_to_mask, :], dim=0), k=3, dim=0) # predict the color cube of the recommendation
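        # use the third-ranked cube from the top-3 rather than the argmax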
prediction = values_indices.detach().numpy()[2]
# construct a palette using unique RGB palette and one-hot representation of the prediction cube.
feature_vector = self.create_rgb_and_one_hot_cube_vector(unique_colors, prediction, node_to_recommend)
# map cube to rgb color space using the regressor
recommendation = regressor.predict(feature_vector)[0]
# we now have the first set of recommendations. Now, we need to update the colors and input_data to propagate information.
# update the color in the palette and run the algorithm for rest of the palette.
# for that, first map the color to cube and convert to one_hot
input_data, unique_colors = self.update_palette(input_data, unique_colors, recommendation, same_indices_list, node_to_recommend)
# recursively do this here.
# save the results.
return np.array(unique_colors*255).astype(int)
def rgb2cube(self, color):
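        """Map an RGB color (0-255 per channel) to the index of one of 4*4*4 = 64 color cubes."""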
intervals = np.arange(0, 256, 256//4)
cube_coordinates = []
for channel in color:
i = 0
for j, value in enumerate(intervals):
if value < channel:
i = j
cube_coordinates.append(i)
cube_num = cube_coordinates[0]*1 + cube_coordinates[1]*4 + cube_coordinates[2]*4*4
return cube_num
def cube2rgb(self, cube_num):
"""
Return the start of the ranges
"""
cube_num = int(cube_num)
intervals = np.arange(0, 256, 256//4)
coor2 = cube_num // 16
coor1 = (cube_num - coor2*4*4) // 4
coor0 = cube_num - coor2*4*4 - coor1*4
return [intervals[coor0], intervals[coor1], intervals[coor2]]
def return_all_same_colors(self, palette):
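        """Group identical palette colors: return, for each unique color, the node indices that share
        it, along with the unique colors and the index of their first occurrence."""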
indices_list = [[],[],[],[],[]]
unique_colors, first_indices = np.unique(palette, axis=0, return_index=True)
unique_colors = np.array(unique_colors)
all_colors = np.array(palette)
for idx, color in enumerate(unique_colors):
for node_num, element in enumerate(all_colors):
if np.equal(color, element).all():
indices_list[idx].append(node_num)
# these palettes and indices also include the masked color
return indices_list, unique_colors, first_indices
def update_palette(self, input_data, unique_rgb_palette, recommendation, indices_list, idx_to_idxs):
# convert prediction to one-hot vector
cube_num_of_the_changed_color = self.rgb2cube(recommendation*255)
one_hot = np.zeros((64,))
one_hot[int(cube_num_of_the_changed_color)] = 1.0
# update the feature vector accordingly for all the same colors
for idx in indices_list[idx_to_idxs]:
input_data.x[idx, 4:] = torch.Tensor(one_hot)
# update the unique color vector
unique_rgb_palette[idx_to_idxs] = recommendation
return input_data, unique_rgb_palette
def create_rgb_and_one_hot_cube_vector(self, rgb_palette, cube_num, node_to_mask):
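        """Build the regressor input: the unique RGB palette with the recommended node removed,
        flattened and concatenated with a 64-dim one-hot of the predicted color cube."""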
one_hot = np.zeros((64,))
one_hot[int(cube_num)] = 1.0
removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
feature_vector = np.concatenate((removed_palette.flatten(), one_hot), axis=0)
return feature_vector.reshape(1, -1)
def create_all_one_hot_vector(self, rgb_palette, cube_num, node_to_mask):
one_hot = np.zeros((64,))
one_hot[int(cube_num)] = 1.0
removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
new_input_data = []
for color in removed_palette:
color_cube_num = self.rgb2cube(color*255)
empty_arr = np.zeros((64,))
empty_arr[int(color_cube_num)] = 1.0
new_input_data.append(empty_arr)
feature_vector = np.concatenate((np.array(new_input_data).flatten(), one_hot), axis=0)
return feature_vector.reshape(1, -1)
def forward_pass(self, model, data):
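        """Run the classifier in eval mode over the graph's node features, edge index, and edge weights."""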
model.eval()
out = model(data.x, data.edge_index.long(), data.edge_weight)
return out
def rearrange_indices_list(self, indices_list, node_to_mask, unique_rgb_palette):
# take the node_to_mask indices to the beginning of the list
for i in range(len(indices_list)):
if node_to_mask in indices_list[i]:
index_to_pop = i
idxs = indices_list.pop(index_to_pop)
palette = unique_rgb_palette[index_to_pop]
temp_palette = np.delete(unique_rgb_palette, index_to_pop, axis=0)
unique_rgb_palette = np.concatenate(([palette], temp_palette), axis=0)
return [idxs] + indices_list, unique_rgb_palette
def update_color(self, updated_color, idx):
"""
        Parses the "[r, g, b]" textbox value, assigns it to palette slot idx, runs the model to
        update the remaining colors, re-renders the design, and returns the refreshed Gradio components.
"""
idx = int(idx)
color = updated_color[1:-1].split(",")
color = [int(num) for num in color]
index_list = self.same_indices[idx]
which_one = random.randint(0, len(index_list)-1)
idx_to_update = index_list[which_one]
unique_colors = self.run_model(self.input_data, self.target_color, idx_to_update, color)
global palette_of_the_design
palette_of_the_design = unique_colors
global all_node_colors
if torch.is_tensor(all_node_colors):
all_node_colors = all_node_colors.detach().numpy()
for i, index_list in enumerate(self.same_indices):
for index in index_list:
all_node_colors[index] = unique_colors[i]
self.generate_img_from_palette(palette=[color for color in all_node_colors])
main_image = Image.open("deneme.png")
gradio_elements = []
gradio_elements.append(gr.Image(main_image, height=256, width=256))
for i in range(len(self.same_indices)):
color = unique_colors[i]
image = Image.new("RGB", (512, 512), color=tuple(color))
gradio_elements.append(gr.Image(image, height=64, width=64))
string_version = "["+str(color[0])+", "+ str(color[1])+", " + str(color[2])+"]"
gradio_elements.append(gr.Textbox(value=string_version, min_width=64))
all_node_colors = torch.Tensor(all_node_colors) / 255
return tuple(gradio_elements)
def perform_reset(button_input):
global demo
global all_node_colors
    demo.demo_reset()
    main_image = Image.open("deneme.png")
    gradio_elements = []
gradio_elements.append(gr.Image(main_image, height=256, width=256))
for color in palette_of_the_design:
image = Image.new("RGB", (512, 512), color=tuple(color))
gradio_elements.append(gr.Image(image, height=64, width=64))
string_version = "["+str(color[0])+", "+ str(color[1])+", " + str(color[2])+"]"
gradio_elements.append(gr.Textbox(value=string_version, min_width=64))
return tuple(gradio_elements)
demo = Demo(graph_dataset=graph_test_dataset)
# Form a gradio template to display images and update the colors.
with gr.Blocks() as project_demo:
with gr.Row():
image = Image.open("deneme.png")
design = gr.Image(image, height=256, width=256)
with gr.Row():
with gr.Column(min_width=100):
image1 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[0]))
image1_gr = gr.Image(image1, height=64, width=64)
string1 = "["+str(palette_of_the_design[0][0])+", "+ str(palette_of_the_design[0][1])+", " + str(palette_of_the_design[0][2])+"]"
color1_update = gr.Textbox(value=string1, min_width=64)
color1_button = gr.Button(value="Update Color 1", min_width=64)
with gr.Column(min_width=100):
image2 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[1]))
image2_gr = gr.Image(image2, height=64, width=64)
string2 = "["+str(palette_of_the_design[1][0])+", "+ str(palette_of_the_design[1][1])+", " + str(palette_of_the_design[1][2])+"]"
color2_update = gr.Textbox(value=string2, min_width=64)
color2_button = gr.Button(value="Update Color 2", min_width=64)
with gr.Column(min_width=100):
image3 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[2]))
image3_gr = gr.Image(image3, height=64, width=64)
string3 = "["+str(palette_of_the_design[2][0])+", "+ str(palette_of_the_design[2][1])+", " + str(palette_of_the_design[2][2])+"]"
color3_update = gr.Textbox(value=string3, min_width=64)
color3_button = gr.Button(value="Update Color 3", min_width=64)
with gr.Column(min_width=100):
image4 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[3]))
image4_gr = gr.Image(image4, height=64, width=64)
string4 = "["+str(palette_of_the_design[3][0])+", "+ str(palette_of_the_design[3][1])+", " + str(palette_of_the_design[3][2])+"]"
color4_update = gr.Textbox(value=string4, min_width=64)
color4_button = gr.Button(value="Update Color 4", min_width=64)
with gr.Column(min_width=100):
image5 = Image.new("RGB", (512, 512), color=tuple(palette_of_the_design[4]))
image5_gr = gr.Image(image5, height=64, width=64)
string5 = "["+str(palette_of_the_design[4][0])+", "+ str(palette_of_the_design[4][1])+", " + str(palette_of_the_design[4][2])+"]"
color5_update = gr.Textbox(value=string5, min_width=64)
color5_button = gr.Button(value="Update Color 5", min_width=64)
with gr.Row():
reset_button = gr.Button(value="Reset the palette", min_width=64)
zero = gr.Number(value=0, visible=False)
one = gr.Number(value=1, visible=False)
two = gr.Number(value=2, visible=False)
three = gr.Number(value=3, visible=False)
four = gr.Number(value=4, visible=False)
color1_button.click(fn=demo.update_color, inputs=[color1_update, zero], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
color2_button.click(fn=demo.update_color, inputs=[color2_update, one], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
color3_button.click(fn=demo.update_color, inputs=[color3_update, two], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
color4_button.click(fn=demo.update_color, inputs=[color4_update, three], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
color5_button.click(fn=demo.update_color, inputs=[color5_update, four], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
reset_button.click(fn=perform_reset, inputs=[reset_button], outputs=[design, image1_gr, color1_update, image2_gr, color2_update, image3_gr, color3_update, image4_gr, color4_update, image5_gr, color5_update])
project_demo.launch()