# (Scraped page header: "Spaces: Running")
from color_palette.regressor.config import config_to_use
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from sklearn.linear_model import LinearRegression
from color_palette.model.GNN import ColorAttentionClassification
from color_palette.regressor.model import Color2CubeDataset
# Wildcard import kept in its original position: import order decides which
# binding wins if it exports a name also imported explicitly here.
from color_palette.regressor.config import *
from torch.utils.data import Dataset, DataLoader
from color_palette.dataset import GraphDestijlDataset
from color_palette.config import DataConfig
import random
import os
import torch.nn.functional as F
import torch
import gradio as gr

# Runtime settings taken from the project's DataConfig.
config = DataConfig()
model_name = config.model_name
dataset_root = config.dataset
feature_size = config.feature_size
device = config.device

image_folder = "img_folder"
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(image_folder, exist_ok=True)
def train_regressor(train_loader):
    """Fit a LinearRegression on every (input, target) pair yielded by *train_loader*.

    Each batch is squeezed before stacking, so the loader is expected to use
    ``batch_size=1`` (as the module-level DataLoader does). Targets are
    presumably normalised RGB colours — callers treat predictions as [0, 1] RGB.

    Args:
        train_loader: iterable of ``(input_data, target)`` batches.

    Returns:
        The fitted ``sklearn.linear_model.LinearRegression``.

    Raises:
        ValueError: if *train_loader* yields no batches.
    """
    X = []
    y = []
    for input_data, target in train_loader:
        X.append(np.squeeze(input_data))
        y.append(np.squeeze(target))
    if not X:
        raise ValueError("train_loader yielded no samples; cannot fit the regressor")
    X = np.stack(X, axis=0)
    y = np.squeeze(np.stack(y, axis=0))
    print("Before regressor train!\n")
    return LinearRegression().fit(X, y)
model_weight_path = os.path.join("models", model_name, "weights", "best.pth")
graph_test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)

model = ColorAttentionClassification(feature_size).to(device)
# map_location places the checkpoint tensors on `device`, so weights saved on a
# GPU machine still load on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(model_weight_path, map_location=device)["state_dict"])

# Regressor that maps (palette without one colour + predicted cube one-hot) to a colour.
dataset = Color2CubeDataset(config=config_to_use)
# batch_size=1: train_regressor squeezes each batch into a single sample.
train_loader = DataLoader(dataset, batch_size=1, shuffle=False)
regressor = train_regressor(train_loader=train_loader)

# Module-level UI state shared with Demo and the Gradio handlers.
palette_of_the_design = [[0, 0, 0] for _ in range(5)]  # colours shown in the five swatches
all_node_colors = None  # per-node colours of the current design; set by Demo
class Demo:
    """Interactive palette editor for one design sample at a time.

    Nodes of the current design that share a colour form a group
    (``self.same_indices``). When the user edits one group's colour,
    ``run_model`` pins that colour and asks the module-level GNN ``model`` and
    ``regressor`` for new colours for every other group.

    Reads and writes the module-level globals ``all_node_colors`` (per-node
    colours of the current sample) and ``palette_of_the_design`` (colours shown
    in the UI swatches), and renders the design to ``deneme.png``.
    """

    # The UI has exactly five colour slots, so a design may use at most five colours.
    NUM_PALETTE_COLORS = 5

    def __init__(self, graph_dataset):
        """Load a random sample from *graph_dataset* and render it."""
        self.dataset = graph_dataset
        self.same_indices = None
        self._show_random_sample()

    def demo_reset(self):
        """Replace the current sample with a new random one and render it."""
        self._show_random_sample()

    def _load_random_sample(self):
        """Pick a random sample, keep its graph/target and publish its node colours.

        Returns:
            The per-node colours of the sample (4th item returned by ``dataset.get``).
        """
        global all_node_colors
        # randrange(n) yields 0..n-1, so the index is always in range.
        sample_idx = random.randrange(len(self.dataset))
        self.input_data, self.target_color, _node_to_mask, node_colors = self.dataset.get(sample_idx)
        all_node_colors = node_colors
        return node_colors

    def _show_random_sample(self):
        """Load a random sample, regroup its colours and render it."""
        node_colors = self._load_random_sample()
        # Colours are scaled by 255 here (presumably stored in [0, 1] — confirm in the dataset).
        palette = [color.detach().cpu().numpy() * 255 for color in node_colors]
        self.generate_img_from_palette(palette, is_first=True)

    @staticmethod
    def _text_size(draw, text, font):
        """Return the (width, height) of *text* measured from the origin.

        Pillow 10 removed ``ImageDraw.textsize``; ``textbbox`` (Pillow >= 8) is
        used when available, with ``textsize`` as a fallback for older Pillow.
        """
        if hasattr(draw, "textbbox"):
            _, _, width, height = draw.textbbox((0, 0), text, font=font)
            return int(width), int(height)
        return draw.textsize(text, font=font)

    def generate_img_from_palette(self, palette, canvas_size=512, is_first=False):
        """Render a synthetic design from per-node colours and save it as ``deneme.png``.

        Args:
            palette: eight RGB triples on a 0..255 scale, one per node, used as
                background, title text, undertitle text, circle, main image,
                image 1, image 2 and image 3.
            canvas_size: side of the square canvas in pixels. The circle and the
                squares are positioned for a 512-pixel canvas.
            is_first: when True, recompute ``self.same_indices`` (node groups
                sharing a colour) from this palette.

        Side effects:
            Sets the global ``palette_of_the_design`` to the unique colours of
            *palette* and writes ``deneme.png`` in the working directory.
        """
        global palette_of_the_design
        # Round instead of truncating so values like 199.99998 map back to 200.
        palette = np.rint(np.asarray(palette, dtype=float)).astype(int)
        (rgb_bg, rgb_title, rgb_text, rgb_circle,
         rgb_main_img, rgb_img1, rgb_img2, rgb_img3) = [tuple(int(v) for v in color) for color in palette]

        group_indices, unique_colors, _ = self.return_all_same_colors(palette=palette)
        if is_first:
            self.same_indices = group_indices
        palette_of_the_design = unique_colors

        image = Image.new("RGB", (canvas_size, canvas_size), color=rgb_bg)
        draw = ImageDraw.Draw(image)

        title = "Lorem Ipsum Dolor"
        undertitle = "Neque porro quisquam est qui dolorem ipsum quia dolor sit amet, \n consectetur, adipisci velit..."
        font_title = ImageFont.truetype("Arial.ttf", 32)
        font_undertitle = ImageFont.truetype("Arial.ttf", 15)

        # Both text blocks are horizontally centred, above the vertical middle.
        title_width, title_height = self._text_size(draw, title, font_title)
        title_x = (canvas_size - title_width) // 2
        title_y = (canvas_size - title_height) // 2 - 100
        text_width, text_height = self._text_size(draw, undertitle, font_undertitle)
        undertitle_x = (canvas_size - text_width) // 2
        undertitle_y = (canvas_size - text_height) // 2 - 50

        draw.text((title_x, title_y), title, fill=rgb_title, font=font_title)
        draw.text((undertitle_x, undertitle_y), undertitle, fill=rgb_text, font=font_undertitle)

        # Circle of random diameter in the top-right area, above the title.
        rad = random.randint(30, 70)
        x = random.randint(400, 512 - (rad + 10))
        y = random.randint(10, title_y - (rad + 10))
        draw.ellipse((x, y, x + rad, y + rad), fill=rgb_circle)

        # Squares stepped along the diagonal toward the bottom-right: one 80px, then three 40px.
        for j, color in enumerate([rgb_main_img, rgb_img1, rgb_img2, rgb_img3]):
            corner = 512 - (j + 1) * 60
            side = 80 if j == 0 else 40
            draw.rectangle((corner, corner, corner + side, corner + side), fill=color)

        image.save("deneme.png")

    def run_model(self, input_data, target_color, node_to_mask, updated_color):
        """Pin one colour group to the user's colour and recommend colours for the others.

        Groups are processed in ``self.same_indices`` order. Step 0 writes the
        user's colour into the edited group. Each later step masks one node of a
        group, predicts its colour cube with the GNN, and maps that cube to a
        colour with the regressor. The edited group's slot is reused to
        recommend group 0.

        Args:
            input_data: graph sample; ``x[:, 4:]`` is overwritten with colour-cube
                one-hots (presumably 64 columns — confirm against the dataset).
                Mutated in place.
            target_color: unused; kept so existing callers keep working.
            node_to_mask: a node index belonging to the group the user edited.
            updated_color: the user's colour as three numbers on a 0..255 scale.

        Returns:
            int array with one RGB row (0..255) per group, in ``self.same_indices`` order.

        Raises:
            ValueError: if a group is empty or *node_to_mask* is in no group.
        """
        global all_node_colors
        if torch.is_tensor(node_to_mask):
            node_to_mask = int(node_to_mask.item())
        groups = self.same_indices
        if any(len(indices) == 0 for indices in groups):
            raise ValueError(
                f"every colour group needs at least one node; the design has fewer than "
                f"{self.NUM_PALETTE_COLORS} distinct colours")

        palette = np.rint(np.array([c.detach().cpu().numpy() * 255 for c in all_node_colors])).astype(int)
        # One colour per group, in self.same_indices order. Regrouping with np.unique here
        # would re-sort groups by colour after every edit and misalign them with
        # self.same_indices; all nodes of a group always share a colour.
        group_colors = np.array([palette[indices[0]] for indices in groups]) / 255

        edited_group = next((g for g, indices in enumerate(groups) if node_to_mask in indices), None)
        if edited_group is None:
            raise ValueError(f"node {node_to_mask} does not belong to any colour group")
        selected_color = np.asarray(updated_color, dtype=float) / 255

        for step in range(len(groups)):
            if step == 0:
                # Pin the user's choice: every node of the edited group gets its cube one-hot.
                one_hot = self._cube_one_hot(self.rgb2cube(updated_color))
                input_data.x[groups[edited_group], 4:] = torch.Tensor(one_hot)
                group_colors[edited_group] = selected_color
                continue
            # Step 0 handled the edited group, so its slot is used for group 0 instead.
            node_to_recommend = 0 if step == edited_group else step
            masked_node = groups[node_to_recommend][0]
            # Hide this node's cube features and let the GNN predict them.
            input_data.x[masked_node, 4:] = torch.zeros(input_data.x.shape[1] - 4)
            out = self.forward_pass(model, input_data)
            probs = F.softmax(out[masked_node, :], dim=0)
            _, top_cubes = torch.topk(probs, k=3, dim=0)
            # Takes the third most likely cube (index 2 of the top 3); the rationale is
            # not documented here.
            prediction = top_cubes.detach().cpu().numpy()[2]
            feature_vector = self.create_rgb_and_one_hot_cube_vector(group_colors, prediction, node_to_recommend)
            # A linear regressor can extrapolate outside [0, 1]; keep the colour displayable.
            recommendation = np.clip(regressor.predict(feature_vector)[0], 0.0, 1.0)
            # Feed the recommendation back so later groups are conditioned on it.
            input_data, group_colors = self.update_palette(
                input_data, group_colors, recommendation, groups, node_to_recommend)

        return np.rint(group_colors * 255).astype(int)

    @staticmethod
    def _cube_one_hot(cube_num):
        """Return a length-64 float vector with 1.0 at index *cube_num*."""
        one_hot = np.zeros((64,))
        one_hot[int(cube_num)] = 1.0
        return one_hot

    def rgb2cube(self, color):
        """Map a colour on a 0..255 scale to one of 64 cubes (4 bins per channel).

        A channel value v falls in the largest bin j with ``64 * j < v``, so bin
        tops are inclusive (64 -> bin 0, 65 -> bin 1). The cube number is
        ``bin0 + 4 * bin1 + 16 * bin2`` over the channels in input order.
        """
        edges = np.arange(0, 256, 256 // 4)  # [0, 64, 128, 192]
        bins = []
        for channel in color:
            bin_idx = 0
            for j, edge in enumerate(edges):
                if edge < channel:
                    bin_idx = j
            bins.append(bin_idx)
        return bins[0] * 1 + bins[1] * 4 + bins[2] * 4 * 4

    def cube2rgb(self, cube_num):
        """Return the lower edge (0, 64, 128 or 192) of each channel's bin for *cube_num*.

        Note: rgb2cube maps a lower edge 64*j (j > 0) to bin j-1, so this is not
        an exact inverse of rgb2cube.
        """
        cube_num = int(cube_num)
        edges = np.arange(0, 256, 256 // 4)
        coor2 = cube_num // 16
        coor1 = (cube_num - coor2 * 16) // 4
        coor0 = cube_num - coor2 * 16 - coor1 * 4
        return [edges[coor0], edges[coor1], edges[coor2]]

    def return_all_same_colors(self, palette):
        """Group node indices by colour.

        Args:
            palette: array-like with one colour row per node.

        Returns:
            ``(indices_list, unique_colors, first_indices)``. ``unique_colors`` and
            ``first_indices`` come from ``np.unique(palette, axis=0,
            return_index=True)`` (colours sorted lexicographically, first node of
            each). ``indices_list`` has NUM_PALETTE_COLORS lists; list k holds the
            ascending node indices whose colour is ``unique_colors[k]``, and
            unused slots are empty.

        Raises:
            ValueError: if the palette has more than NUM_PALETTE_COLORS distinct colours.
        """
        all_colors = np.array(palette)
        unique_colors, first_indices = np.unique(all_colors, axis=0, return_index=True)
        if len(unique_colors) > self.NUM_PALETTE_COLORS:
            raise ValueError(
                f"design has {len(unique_colors)} distinct colours; at most "
                f"{self.NUM_PALETTE_COLORS} are supported")
        indices_list = [[] for _ in range(self.NUM_PALETTE_COLORS)]
        for group, color in enumerate(unique_colors):
            indices_list[group] = np.flatnonzero((all_colors == color).all(axis=1)).tolist()
        return indices_list, unique_colors, first_indices

    def update_palette(self, input_data, unique_rgb_palette, recommendation, indices_list, idx_to_idxs):
        """Write a recommended colour into group *idx_to_idxs*.

        Sets the cube one-hot of *recommendation* (scaled by 255) in ``x[:, 4:]``
        for every node of the group, and stores *recommendation* in
        ``unique_rgb_palette[idx_to_idxs]``. Both inputs are mutated in place.

        Returns:
            ``(input_data, unique_rgb_palette)``.
        """
        one_hot = torch.Tensor(self._cube_one_hot(self.rgb2cube(recommendation * 255)))
        for node in indices_list[idx_to_idxs]:
            input_data.x[node, 4:] = one_hot
        unique_rgb_palette[idx_to_idxs] = recommendation
        return input_data, unique_rgb_palette

    def create_rgb_and_one_hot_cube_vector(self, rgb_palette, cube_num, node_to_mask):
        """Build the regressor input as one row.

        The row is the flattened *rgb_palette* without row *node_to_mask*,
        followed by the 64-way one-hot of *cube_num*.
        """
        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
        feature_vector = np.concatenate((removed_palette.flatten(), self._cube_one_hot(cube_num)), axis=0)
        return feature_vector.reshape(1, -1)

    def create_all_one_hot_vector(self, rgb_palette, cube_num, node_to_mask):
        """Like create_rgb_and_one_hot_cube_vector, but encode every remaining colour as a cube one-hot.

        Colours in *rgb_palette* are scaled by 255 before binning. This method
        is not called elsewhere in this class.
        """
        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
        color_one_hots = [self._cube_one_hot(self.rgb2cube(color * 255)) for color in removed_palette]
        feature_vector = np.concatenate(
            (np.array(color_one_hots).flatten(), self._cube_one_hot(cube_num)), axis=0)
        return feature_vector.reshape(1, -1)

    def forward_pass(self, model, data):
        """Run *model* in eval mode on the graph *data* without tracking gradients."""
        model.eval()
        with torch.no_grad():
            return model(data.x, data.edge_index.long(), data.edge_weight)

    def rearrange_indices_list(self, indices_list, node_to_mask, unique_rgb_palette):
        """Move the group containing *node_to_mask* (and its colour) to the front.

        The input list is not modified.

        Returns:
            ``(reordered_indices_list, reordered_palette)``.

        Raises:
            ValueError: if *node_to_mask* is in no group.
        """
        group = None
        for g, indices in enumerate(indices_list):
            if node_to_mask in indices:
                group = g
        if group is None:
            raise ValueError(f"node {node_to_mask} does not belong to any colour group")
        remaining = indices_list[:group] + indices_list[group + 1:]
        moved_color = unique_rgb_palette[group]
        other_colors = np.delete(unique_rgb_palette, group, axis=0)
        unique_rgb_palette = np.concatenate(([moved_color], other_colors), axis=0)
        return [indices_list[group]] + remaining, unique_rgb_palette

    @staticmethod
    def _parse_rgb(text):
        """Parse a colour typed as "[r, g, b]" (brackets optional) into three ints.

        Components are rounded and clamped to 0..255.

        Raises:
            ValueError: if *text* does not hold exactly three numeric components.
        """
        parts = text.strip().strip("[]()").split(",")
        if len(parts) != 3:
            raise ValueError(f"expected a colour like [12, 34, 56], got {text!r}")
        return [min(255, max(0, round(float(part)))) for part in parts]

    def update_color(self, updated_color, idx):
        """Gradio handler: apply the typed colour to group *idx* and re-recommend the rest.

        Args:
            updated_color: textbox contents, e.g. "[12, 34, 56]".
            idx: index of the colour group being edited (Gradio passes a number).

        Returns:
            Tuple of Gradio components: the re-rendered design, then a
            (swatch image, "[r, g, b]" textbox) pair per colour group.
        """
        global palette_of_the_design
        global all_node_colors
        group = int(idx)
        color = self._parse_rgb(updated_color)
        # Any node of the edited group can anchor the edit; pick one at random.
        idx_to_update = random.choice(self.same_indices[group])
        unique_colors = self.run_model(self.input_data, self.target_color, idx_to_update, color)

        if torch.is_tensor(all_node_colors):
            # .numpy() shares memory with the tensor (possibly the dataset sample's own
            # storage), so copy before writing 0..255 values into it.
            all_node_colors = all_node_colors.detach().cpu().numpy().copy()
        for group_idx, node_indices in enumerate(self.same_indices):
            for node in node_indices:
                all_node_colors[node] = unique_colors[group_idx]
        self.generate_img_from_palette(palette=list(all_node_colors))
        # Set after rendering, which overwrites it with colour-sorted order.
        palette_of_the_design = unique_colors

        gradio_elements = [gr.Image(Image.open("deneme.png"), height=256, width=256)]
        for swatch_color in unique_colors:
            rgb = tuple(int(c) for c in swatch_color)
            gradio_elements.append(gr.Image(Image.new("RGB", (512, 512), color=rgb), height=64, width=64))
            gradio_elements.append(gr.Textbox(value=f"[{rgb[0]}, {rgb[1]}, {rgb[2]}]", min_width=64))
        # Back to the [0, 1] tensor form that run_model reads on the next edit.
        all_node_colors = torch.Tensor(np.asarray(all_node_colors, dtype=float)) / 255
        return tuple(gradio_elements)
def perform_reset(button_input):
    """Gradio handler for the reset button: load a new random design and rebuild the outputs.

    Args:
        button_input: value Gradio passes from the reset button (unused).

    Returns:
        Tuple of Gradio components: the rendered design, then a
        (swatch image, "[r, g, b]" textbox) pair for each colour in the global
        ``palette_of_the_design``.
    """
    demo.demo_reset()
    gradio_elements = [gr.Image(Image.open("deneme.png"), height=256, width=256)]
    for color in palette_of_the_design:
        rgb = tuple(int(c) for c in color)
        gradio_elements.append(gr.Image(Image.new("RGB", (512, 512), color=rgb), height=64, width=64))
        gradio_elements.append(gr.Textbox(value=f"[{rgb[0]}, {rgb[1]}, {rgb[2]}]", min_width=64))
    return tuple(gradio_elements)
demo = Demo(graph_dataset=graph_test_dataset)

# Form a gradio template to display images and update the colors.
with gr.Blocks() as project_demo:
    with gr.Row():
        design = gr.Image(Image.open("deneme.png"), height=256, width=256)

    # One column per palette slot: swatch image, editable "[r, g, b]" textbox, update button.
    swatch_images = []
    color_textboxes = []
    color_buttons = []
    with gr.Row():
        for slot in range(5):
            rgb = tuple(palette_of_the_design[slot])
            with gr.Column(min_width=100):
                swatch_images.append(
                    gr.Image(Image.new("RGB", (512, 512), color=rgb), height=64, width=64))
                color_textboxes.append(
                    gr.Textbox(value=f"[{rgb[0]}, {rgb[1]}, {rgb[2]}]", min_width=64))
                color_buttons.append(
                    gr.Button(value=f"Update Color {slot + 1}", min_width=64))

    with gr.Row():
        reset_button = gr.Button(value="Reset the palette", min_width=64)
        # Hidden constants telling Demo.update_color which slot was clicked.
        slot_numbers = [gr.Number(value=slot, visible=False) for slot in range(5)]

    # Every handler refreshes the design plus each (swatch, textbox) pair.
    outputs = [design]
    for swatch, textbox in zip(swatch_images, color_textboxes):
        outputs.extend([swatch, textbox])

    for button, textbox, slot_number in zip(color_buttons, color_textboxes, slot_numbers):
        button.click(fn=demo.update_color, inputs=[textbox, slot_number], outputs=outputs)
    reset_button.click(fn=perform_reset, inputs=[reset_button], outputs=outputs)

if __name__ == "__main__":
    project_demo.launch()