from color_palette.regressor.config import config_to_use
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from sklearn.linear_model import LinearRegression
from color_palette.model.GNN import ColorAttentionClassification
from color_palette.regressor.model import Color2CubeDataset
from color_palette.regressor.config import *
from torch.utils.data import Dataset, DataLoader
from color_palette.dataset import GraphDestijlDataset
from color_palette.config import DataConfig
import random
import os
import torch.nn.functional as F
import torch
import gradio as gr

config = DataConfig()
model_name = config.model_name
dataset_root = config.dataset
feature_size = config.feature_size
device = config.device

image_folder = "img_folder"
os.makedirs(image_folder, exist_ok=True)

# Number of palette slots shown in the UI (one swatch, textbox and button each).
NUM_PALETTE_SLOTS = 5
# Where the rendered design is written and re-read by the UI callbacks.
DESIGN_IMAGE_PATH = "deneme.png"


def train_regressor(train_loader):
    """Fit a linear regressor on every (input, target) pair of ``train_loader``.

    Each sample is squeezed (the loader below uses ``batch_size=1``) and the
    samples are stacked into a design matrix ``X`` and a target matrix ``y``.

    Args:
        train_loader: iterable yielding ``(input_data, target)`` pairs.

    Returns:
        The fitted ``sklearn.linear_model.LinearRegression``.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    X = []
    y = []
    for input_data, target in train_loader:
        X.append(np.squeeze(input_data))
        y.append(np.squeeze(target))
    if not X:
        raise ValueError("train_loader yielded no samples; cannot fit the regressor")
    X = np.stack(X, axis=0)
    y = np.squeeze(np.stack(y, axis=0))
    print("Before regressor train!\n")
    return LinearRegression().fit(X, y)


model_weight_path = os.path.join("models", model_name, "weights", "best.pth")

graph_test_dataset = GraphDestijlDataset(root=dataset_root, test=True, cube_mapping=True)
model = ColorAttentionClassification(feature_size).to(device)
# map_location lets a checkpoint saved on another device (e.g. GPU) load here.
model.load_state_dict(torch.load(model_weight_path, map_location=device)["state_dict"])

dataset = Color2CubeDataset(config=config_to_use)
train_loader = DataLoader(dataset, batch_size=1, shuffle=False)
regressor = train_regressor(train_loader=train_loader)

# Unique colours of the currently rendered design (read by the UI).
palette_of_the_design = [[0, 0, 0] for _ in range(NUM_PALETTE_SLOTS)]
# Per-node colours of the current design, scaled to 0-1 between UI callbacks.
all_node_colors = None


def _load_font(size):
    """Return Arial at ``size`` points, or Pillow's default font if Arial is missing."""
    try:
        return ImageFont.truetype("Arial.ttf", size)
    except OSError:
        # Arial is not installed on many Linux hosts; degrade instead of crashing.
        return ImageFont.load_default()


def _text_size(draw, text, font):
    """Return the integer ``(width, height)`` of ``text`` drawn with ``font``.

    ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox`` is used
    when available. Results are cast to ``int`` because they feed
    ``random.randint``, which rejects floats.
    """
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        return int(right - left), int(bottom - top)
    width, height = draw.textsize(text, font=font)
    return int(width), int(height)


def _format_rgb(color):
    """Format the first three channels of ``color`` as ``"[r, g, b]"``."""
    return "[" + ", ".join(str(channel) for channel in color[:3]) + "]"


def _build_gradio_outputs(colors):
    """Build the UI refresh tuple: the design image, then (swatch, textbox) per colour."""
    outputs = [gr.Image(Image.open(DESIGN_IMAGE_PATH), height=256, width=256)]
    for color in colors:
        swatch = Image.new("RGB", (512, 512), color=tuple(color))
        outputs.append(gr.Image(swatch, height=64, width=64))
        outputs.append(gr.Textbox(value=_format_rgb(color), min_width=64))
    return tuple(outputs)


class Demo:
    """Interactive palette-recommendation demo for one design from a graph dataset.

    Holds the current sample (``input_data``, ``target_color``), the grouping
    of node indices that share a colour (``same_indices``), and the callbacks
    wired to the Gradio UI.
    """

    def __init__(self, graph_dataset):
        self.dataset = graph_dataset
        self.same_indices = None
        self._load_random_sample()

    def _load_random_sample(self):
        """Pick a random sample, make it the current design, and render it to disk."""
        global all_node_colors
        # random.randint is inclusive on both ends, hence len - 1.
        sample_idx = random.randint(0, len(self.dataset) - 1)
        self.input_data, self.target_color, _node_to_mask, node_colors = self.dataset.get(sample_idx)
        all_node_colors = node_colors
        self.generate_img_from_palette(
            [color.detach().numpy() * 255 for color in node_colors], is_first=True
        )

    def demo_reset(self):
        """Replace the current design with a new random sample."""
        self._load_random_sample()

    def generate_img_from_palette(self, palette, canvas_size=512, is_first=False):
        """Render a poster-like design from per-node colours and save it.

        Args:
            palette: 8 colours (0-255 per channel), one per node, in the order
                background, text, text, circle, main image, image 1-3.
            canvas_size: width and height of the square canvas in pixels.
            is_first: when True, also recompute ``self.same_indices`` (the
                groups of nodes sharing a colour) from this palette.

        Side effects: sets the global ``palette_of_the_design`` to the unique
        colours of ``palette`` and writes the image to ``DESIGN_IMAGE_PATH``.
        """
        # Round rather than truncate: values arrive as float (0-1 * 255) and
        # truncation turns e.g. 99.99999 into 99.
        palette = np.rint(np.array(palette, dtype=float)).astype(int)
        # NOTE(review): nodes 1 and 2 are both "text"; node 1's colour is
        # discarded and node 2's is used for both lines. Confirm whether the
        # title was meant to use node 1.
        (rgb_bg, _rgb_text_unused, rgb_text, rgb_circle,
         rgb_main_img, rgb_img1, rgb_img2, rgb_img3) = [tuple(color) for color in palette]

        same_indices, unique_colors, _ = self.return_all_same_colors(palette=palette)
        if is_first:
            self.same_indices = same_indices

        global palette_of_the_design
        palette_of_the_design = unique_colors

        image = Image.new("RGB", (canvas_size, canvas_size), color=rgb_bg)
        draw = ImageDraw.Draw(image)

        title = "Lorem Ipsum Dolor"
        undertitle = "Neque porro quisquam est qui dolorem ipsum quia dolor sit amet, \n consectetur, adipisci velit..."

        # Centre both text lines horizontally, above the vertical middle.
        font_title = _load_font(32)
        title_width, title_height = _text_size(draw, title, font_title)
        title_x = (canvas_size - title_width) // 2
        title_y = (canvas_size - title_height) // 2 - 100

        font_undertitle = _load_font(15)
        text_width, text_height = _text_size(draw, undertitle, font_undertitle)
        undertitle_x = (canvas_size - text_width) // 2
        undertitle_y = (canvas_size - text_height) // 2 - 50

        draw.text((title_x, title_y), title, fill=rgb_text, font=font_title)
        draw.text((undertitle_x, undertitle_y), undertitle, fill=rgb_text, font=font_undertitle)

        # Random circle near the right edge, kept above the title.
        rad = random.randint(30, 70)
        x = random.randint(canvas_size - 112, canvas_size - (rad + 10))
        y = random.randint(10, title_y - (rad + 10))
        draw.ellipse((x, y, x + rad, y + rad), fill=rgb_circle)

        # Four squares stepping diagonally up-left from the bottom-right corner
        # in 60px steps; the first (main image) is 80px, the others 40px.
        for j, color in enumerate([rgb_main_img, rgb_img1, rgb_img2, rgb_img3]):
            offset = canvas_size - (j + 1) * 60
            side = 80 if j == 0 else 40
            draw.rectangle((offset, offset, offset + side, offset + side), fill=color)

        image.save(DESIGN_IMAGE_PATH)

    def run_model(self, input_data, target_color, node_to_mask, updated_color):
        """Set the user's colour group to ``updated_color`` and re-recommend every group.

        Groups come from ``self.same_indices`` (fixed when the design was
        loaded). Pass 0 writes the user's colour into its group. Each later
        pass zeroes the cube features of one group's first node. The GNN then
        predicts a cube for that node, and the regressor maps (other colours +
        cube) to a colour, which is written back into the palette and into
        ``input_data``.

        Args:
            input_data: graph sample; ``input_data.x[:, 4:]`` holds the one-hot
                cube features and is modified in place.
            target_color: unused; kept for interface compatibility.
            node_to_mask: index of a node in the group the user edited.
            updated_color: the user's colour as three 0-255 ints.

        Returns:
            ``np.ndarray`` of ints in 0-255, one row per group of
            ``self.same_indices``, in the same order.

        Raises:
            ValueError: if ``node_to_mask`` is not in any group.
        """
        palette = np.rint(
            np.array([color.detach().numpy() * 255 for color in all_node_colors])
        ).astype(int)

        groups = self.same_indices
        # Take each group's colour from its first node so that row k here is
        # always UI slot k. Re-grouping with np.unique would re-sort groups by
        # their *current* colours; after the first edit that order differs
        # from self.same_indices and the wrong groups get overwritten.
        unique_colors = np.array([palette[group[0]] for group in groups]) / 255
        selected_color = torch.Tensor(updated_color) / 255

        user_group = next(
            (k for k, group in enumerate(groups) if node_to_mask in group), None
        )
        if user_group is None:
            raise ValueError(f"node {node_to_mask} is not in any colour group")

        for step in range(len(groups)):
            if step == 0:
                # Pass 0: write the user's colour (features + palette row).
                cube_num_of_selected = self.rgb2cube(selected_color * 255)
                one_hot = np.zeros((64,))
                one_hot[int(cube_num_of_selected)] = 1.0
                group_to_recommend = user_group
                input_data.x[groups[user_group], 4:] = torch.Tensor(one_hot)
                unique_colors[user_group] = selected_color
            else:
                # The user's group was handled in pass 0, so its pass handles
                # group 0 instead; every group is processed exactly once.
                group_to_recommend = 0 if step == user_group else step
                first_node = groups[group_to_recommend][0]
                # Hide this node's colour from the model before predicting it.
                input_data.x[first_node, 4:] = torch.zeros(input_data.x.shape[1] - 4)
                node_to_mask = first_node

            # NOTE(review): this also runs for pass 0, so the user's own group is
            # re-predicted and its colour replaced by the regressor output.
            # Confirm this is intended rather than keeping the user's colour.
            out = self.forward_pass(model, input_data)
            if torch.is_tensor(node_to_mask):
                node_to_mask = node_to_mask.item()
            _, top_cubes = torch.topk(F.softmax(out[node_to_mask, :], dim=0), k=3, dim=0)
            # NOTE(review): the third most likely cube (index 2) is used, not
            # the top-1 — confirm this is intended.
            prediction = top_cubes.detach().cpu().numpy()[2]
            feature_vector = self.create_rgb_and_one_hot_cube_vector(
                unique_colors, prediction, group_to_recommend
            )
            recommendation = regressor.predict(feature_vector)[0]
            input_data, unique_colors = self.update_palette(
                input_data, unique_colors, recommendation, groups, group_to_recommend
            )

        # The regressor output is unbounded; clip to the displayable range.
        return np.clip(np.rint(unique_colors * 255), 0, 255).astype(int)

    def rgb2cube(self, color):
        """Map a colour (0-255 per channel) to one of 64 colour-cube bins.

        Each of the first three channels is quantised into 4 bins of width 64.
        The bin number is ``bin0 + 4 * bin1 + 16 * bin2``.
        """
        intervals = np.arange(0, 256, 256 // 4)
        cube_coordinates = []
        for channel in color:
            # Index of the last interval start strictly below the channel.
            # NOTE(review): a value exactly on a boundary (e.g. 64) falls into
            # the lower bin; keep in sync with the mapping used for training.
            i = 0
            for j, value in enumerate(intervals):
                if value < channel:
                    i = j
            cube_coordinates.append(i)
        return cube_coordinates[0] + cube_coordinates[1] * 4 + cube_coordinates[2] * 16

    def cube2rgb(self, cube_num):
        """Return the lower corner (start of each channel range) of bin ``cube_num``.

        Inverts the numbering used by ``rgb2cube``.
        """
        cube_num = int(cube_num)
        intervals = np.arange(0, 256, 256 // 4)
        coor2 = cube_num // 16
        coor1 = (cube_num - coor2 * 16) // 4
        coor0 = cube_num - coor2 * 16 - coor1 * 4
        return [intervals[coor0], intervals[coor1], intervals[coor2]]

    def return_all_same_colors(self, palette):
        """Group node indices by identical colour.

        Args:
            palette: 2-D array-like, one colour row per node.

        Returns:
            indices_list: ``NUM_PALETTE_SLOTS`` lists; list ``k`` holds the
                node indices whose colour equals ``unique_colors[k]``. Trailing
                lists stay empty when there are fewer unique colours.
            unique_colors: distinct rows from ``np.unique(axis=0)``
                (lexicographically sorted).
            first_indices: index of each unique colour's first occurrence.

        Raises:
            ValueError: if there are more unique colours than palette slots.
        """
        all_colors = np.array(palette)
        unique_colors, first_indices = np.unique(all_colors, axis=0, return_index=True)
        if len(unique_colors) > NUM_PALETTE_SLOTS:
            raise ValueError(
                f"design has {len(unique_colors)} unique colours; "
                f"at most {NUM_PALETTE_SLOTS} are supported"
            )
        indices_list = [[] for _ in range(NUM_PALETTE_SLOTS)]
        for group, color in zip(indices_list, unique_colors):
            group.extend(np.flatnonzero((all_colors == color).all(axis=1)).tolist())
        return indices_list, unique_colors, first_indices

    def update_palette(self, input_data, unique_rgb_palette, recommendation, indices_list, idx_to_idxs):
        """Write ``recommendation`` into colour group ``idx_to_idxs``.

        Every node of that group gets the one-hot cube of ``recommendation`` in
        ``input_data.x[:, 4:]``, and the group's row of ``unique_rgb_palette``
        is replaced. Both are modified in place and also returned.

        ``recommendation`` is expected on the same 0-1 scale as
        ``unique_rgb_palette`` (it is multiplied by 255 before binning).
        """
        cube_num_of_the_changed_color = self.rgb2cube(recommendation * 255)
        one_hot = np.zeros((64,))
        one_hot[int(cube_num_of_the_changed_color)] = 1.0
        for idx in indices_list[idx_to_idxs]:
            input_data.x[idx, 4:] = torch.Tensor(one_hot)
        unique_rgb_palette[idx_to_idxs] = recommendation
        return input_data, unique_rgb_palette

    def create_rgb_and_one_hot_cube_vector(self, rgb_palette, cube_num, node_to_mask):
        """Build the regressor input: the palette without row ``node_to_mask``, then a one-hot of ``cube_num``.

        Returns:
            Array of shape ``(1, n_features)``.
        """
        one_hot = np.zeros((64,))
        one_hot[int(cube_num)] = 1.0
        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
        feature_vector = np.concatenate((removed_palette.flatten(), one_hot), axis=0)
        return feature_vector.reshape(1, -1)

    def create_all_one_hot_vector(self, rgb_palette, cube_num, node_to_mask):
        """Like ``create_rgb_and_one_hot_cube_vector`` but encodes the remaining colours as one-hot cubes.

        ``rgb_palette`` rows are expected on a 0-1 scale (they are multiplied
        by 255 before binning).
        """
        one_hot = np.zeros((64,))
        one_hot[int(cube_num)] = 1.0
        removed_palette = np.delete(rgb_palette, node_to_mask, axis=0)
        new_input_data = []
        for color in removed_palette:
            color_one_hot = np.zeros((64,))
            color_one_hot[int(self.rgb2cube(color * 255))] = 1.0
            new_input_data.append(color_one_hot)
        feature_vector = np.concatenate((np.array(new_input_data).flatten(), one_hot), axis=0)
        return feature_vector.reshape(1, -1)

    def forward_pass(self, model, data):
        """Run ``model`` in eval mode on the graph ``data`` without tracking gradients."""
        model.eval()
        with torch.no_grad():
            return model(data.x, data.edge_index.long(), data.edge_weight)

    def rearrange_indices_list(self, indices_list, node_to_mask, unique_rgb_palette):
        """Move the group containing ``node_to_mask`` (and its colour row) to the front.

        Returns new objects and does not modify ``indices_list``.

        Raises:
            ValueError: if ``node_to_mask`` is not in any group.
        """
        for pos, idxs in enumerate(indices_list):
            if node_to_mask in idxs:
                break
        else:
            raise ValueError(f"node {node_to_mask} is not in any colour group")
        reordered = [indices_list[pos]] + indices_list[:pos] + indices_list[pos + 1:]
        temp_palette = np.delete(unique_rgb_palette, pos, axis=0)
        unique_rgb_palette = np.concatenate(([unique_rgb_palette[pos]], temp_palette), axis=0)
        return reordered, unique_rgb_palette

    def update_color(self, updated_color, idx):
        """Apply a user-entered colour to palette slot ``idx`` and refresh the UI.

        Args:
            updated_color: text such as ``"[12, 34, 56]"`` from the slot's
                textbox. Each value is clamped to 0-255.
            idx: palette slot index (a number from a hidden ``gr.Number``).

        Returns:
            Tuple of Gradio components: the re-rendered design image, then a
            swatch image and a value textbox for every slot.

        Raises:
            ValueError: if ``updated_color`` does not contain exactly 3 integers.
        """
        idx = int(idx)
        parts = updated_color.strip().strip("[]()").split(",")
        if len(parts) != 3:
            raise ValueError(f"expected a colour like [r, g, b], got {updated_color!r}")
        color = [min(max(int(part), 0), 255) for part in parts]

        # Pick one node of the slot's group to act as the edited node.
        idx_to_update = random.choice(self.same_indices[idx])
        unique_colors = self.run_model(self.input_data, self.target_color, idx_to_update, color)

        global palette_of_the_design, all_node_colors
        palette_of_the_design = unique_colors
        if torch.is_tensor(all_node_colors):
            # Copy: .numpy() shares memory with the tensor, which may belong to the dataset.
            all_node_colors = all_node_colors.detach().numpy().copy()
        for group, group_color in zip(self.same_indices, unique_colors):
            for node in group:
                all_node_colors[node] = group_color

        self.generate_img_from_palette(palette=list(all_node_colors))
        gradio_elements = _build_gradio_outputs(unique_colors[:len(self.same_indices)])
        # Store colours back on the 0-1 scale expected by run_model.
        all_node_colors = torch.Tensor(all_node_colors) / 255
        return gradio_elements


def perform_reset(button_input):
    """Gradio callback for the reset button: load a new random design.

    ``button_input`` (the button's value) is unused.
    """
    demo.demo_reset()
    return _build_gradio_outputs(palette_of_the_design)


demo = Demo(graph_dataset=graph_test_dataset)
# Form a gradio template to display images and update the colors.
def _palette_label(color):
    """Render the first three channels of ``color`` as ``"[r, g, b]"``."""
    return "[" + ", ".join(str(channel) for channel in color[:3]) + "]"


# Layout: the design preview on top, then one column per palette slot
# (swatch, editable value, update button), then the reset button.
# Every callback refreshes the same components: the preview followed by
# (swatch, textbox) for each slot.
with gr.Blocks() as project_demo:
    with gr.Row():
        design = gr.Image(Image.open("deneme.png"), height=256, width=256)

    slot_swatches = []
    slot_textboxes = []
    slot_buttons = []
    with gr.Row():
        for slot in range(5):
            slot_color = palette_of_the_design[slot]
            with gr.Column(min_width=100):
                swatch_picture = Image.new("RGB", (512, 512), color=tuple(slot_color))
                slot_swatches.append(gr.Image(swatch_picture, height=64, width=64))
                slot_textboxes.append(gr.Textbox(value=_palette_label(slot_color), min_width=64))
                slot_buttons.append(gr.Button(value=f"Update Color {slot + 1}", min_width=64))

    with gr.Row():
        reset_button = gr.Button(value="Reset the palette", min_width=64)

    # Hidden constants that tell demo.update_color which slot was clicked.
    slot_numbers = [gr.Number(value=slot, visible=False) for slot in range(5)]

    refreshed_outputs = [design]
    for swatch_component, textbox_component in zip(slot_swatches, slot_textboxes):
        refreshed_outputs += [swatch_component, textbox_component]

    for button, textbox_component, slot_number in zip(slot_buttons, slot_textboxes, slot_numbers):
        button.click(fn=demo.update_color, inputs=[textbox_component, slot_number], outputs=refreshed_outputs)
    reset_button.click(fn=perform_reset, inputs=[reset_button], outputs=refreshed_outputs)

project_demo.launch()