Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from PIL import Image | |
import torch.nn.functional as F | |
import numpy as np | |
import pickle | |
import json | |
import requests | |
from transformers import CLIPProcessor, AutoModelForSemanticSegmentation, AutoFeatureExtractor, CLIPModel | |
from torch import nn | |
import io | |
# Initialize the models using huggingface
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the CLIP model (image/text encoder) and its pre-processor from hugging face
clip_hg = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
processor_hg = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Load the clothes segmentation model and its feature extractor
seg_hg = AutoModelForSemanticSegmentation.from_pretrained('mattmdjaga/segformer_b2_clothes').to(device).eval()
extractor_hg = AutoFeatureExtractor.from_pretrained('mattmdjaga/segformer_b2_clothes', reduce_labels=False)
# Load the data and normalize the embeddings just in case.
# map_location lets tensors that were saved from a GPU load on a CPU-only host.
features = torch.load('features.pt', map_location=device)
features_main = F.normalize(features)
item_embeddings = F.normalize(torch.load('item_embeds.pt', map_location=device))
# NOTE: pickle.load can execute arbitrary code; only ever point it at trusted local files.
with open('new_url_list.pt', 'rb') as fh:
    url_list_main = pickle.load(fh)
with open('clothes_tree_new_data.json', encoding='utf-8') as fh:
    clothes_tree = json.load(fh)
with open('top5_mini_new.json', encoding='utf-8') as fh:
    rec_dic = json.load(fh)
# URL for an image if no image is selected
url = 'https://bitsofco.de/content/images/2018/12/Screenshot-2018-12-16-at-21.06.29.png'

# Class names of the clothes segmentation model, presumably in class-id order
# (0 = Background), matching the numeric ids used when masking images.
label = ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
         'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']

clothing_type = ['top', 'bottom', 'dress']
top_type = ['t-shirt', 'tank top', 'blouse', 'sweater', 'hoodie', 'cardigan', 'turtleneck', 'blazer', 'polo',
            'collar shirt', 'knitwear',
            'tuxedo', 'Compression top', 'duffle coat', 'peacoat', 'long coat', 'trench coat',
            'biker jacket', 'blazer', 'bomber jacket', 'hooded jacket', 'leather jacket', 'military jacket',
            'down jacket', 'shirt jacket',
            'suit jacket', 'dinner jacket', 'gillet', 'track jacket']
bottom_type = ['skirt', 'leggings', 'sweatpants', 'skinny pants', 'tailored pants', 'track pants', 'wide-leg pants',
               'cargo shorts', 'denim shorts', 'track shorts', 'compression shorts', 'cycling shorts', 'denim pants',
               'cargo pants', 'chino pants', 'chino shorts']
dress_type = ['casual dress', 'cocktail dress', 'evening dress', 'maxi dress', 'mini dress', 'party dress', 'sundress']
styles = ['plain', 'polka dot', 'striped', 'floral', 'checkered', 'zebra print', 'leopard print', 'plaid', 'paisley']
colors_list = ['blue', 'red', 'pink', 'orange', 'yellow', 'purple', 'gold', 'white', 'off white', 'black', 'grey',
               'green', 'brown', 'beige', 'cream', 'navy', 'maroon']

# Every "type, colour, style" caption, with colour varying fastest, then style,
# then type. get_label() maps an index into item_embeddings straight into
# all_items, so this ordering must stay exactly as is. That is presumably also
# why the duplicated 'blazer' entry must be kept.
top_list = [f"{t}, {color}, {style}" for t in top_type for style in styles for color in colors_list]
bottom_list = [f"{t}, {color}, {style}" for t in bottom_type for style in styles for color in colors_list]
dress_list = [f"{t}, {color}, {style}" for t in dress_type for style in styles for color in colors_list]
all_items = top_list + bottom_list + dress_list

all_types = {
    'top': top_type,
    'bottom': bottom_type,
    'dress': dress_type,
}
patterns_list = styles.copy()

# Mutable state shared between the Gradio callbacks.
clicks = 0          # current stage of the filter wizard
c_types = []        # selected clothing categories
types = []          # selected (garment type, category) pairs
colors = []         # selected colours
patterns = []       # selected patterns
new_files = []      # dataset indices selected by a filter or theme
out = []            # images shown in the last search result
clothes_click = 0   # not referenced elsewhere in this file
global_mask = None  # segmentation mask of the current query image
mask_choice = 'Clothes'
# Define all needed functions | |
def find_closest(target_feature, features):
    '''
    Purpose: Rank the dataset embeddings by similarity to a query embedding
    Inputs:
        target_feature (tensor): query embedding, one row per query (normally a single row)
        features (tensor): one embedding per dataset item, one item per row
    Outputs:
        group_sorted_indices (list): row indices of features, most similar first
    '''
    # Plain dot products. For one query this ranks exactly like cosine
    # similarity as long as the rows of `features` are unit-normalised.
    scores = torch.matmul(features.float(), target_feature.float().t())
    order = scores.argsort(dim=0, descending=True)
    return order.squeeze(1).cpu().tolist()
def filter_function(choices):
    '''
    Purpose: Advance the step-by-step filter wizard by one stage
    The module-level `clicks` counter records the current stage, and each press
    goes one level deeper into clothes_tree:
        stage 0 - clothing category (top / bottom / dress)
        stage 1 - garment type inside the chosen categories
        stage 2 - colour
        stage 3 - pattern; matching dataset indices are collected in new_files
        stage 4 - every selection is cleared and the wizard restarts
    At stages 0-3 an empty selection means "accept everything at this level".
    Inputs:
        choices (list): values currently ticked in the filter checkbox group
    Outputs:
        gr.update with the options for the next stage (and a new label at stages 3 and 4)
    '''
    global clicks, c_types, types, colors, patterns, new_files
    stage = clicks
    next_options = []

    if stage == 0:
        picked = [opt for opt in choices if opt in clothing_type] or clothing_type
        for category in picked:
            c_types.append(category)
            next_options.extend(list(clothes_tree[category].keys()))
    elif stage == 1:
        picked = [opt for category in c_types for opt in choices if opt in all_types[category]]
        if not picked:
            types = []
            for category in c_types:
                types.extend([(garment, category) for garment in clothes_tree[category].keys()])
        for garment in picked:
            if garment in clothes_tree['top']:
                owner = 'top'
            elif garment in clothes_tree['bottom']:
                owner = 'bottom'
            else:
                owner = 'dress'
            types.append((garment, owner))
        # Assumes every garment offers the same colours as 't-shirt'.
        next_options = list(clothes_tree['top']['t-shirt'].keys())
    elif stage == 2:
        picked = [opt for opt in choices if opt in colors_list]
        if not picked:
            colors = colors_list.copy()
        colors.extend(picked)
        # Assumes every colour offers the same patterns as a red t-shirt.
        next_options = list(clothes_tree['top']['t-shirt']['red'].keys())
    elif stage == 3:
        picked = [opt for opt in choices if opt in patterns_list]
        if not picked:
            patterns = patterns_list.copy()
        patterns.extend(picked)
        for garment, category in types:
            for colour in colors:
                for pattern in patterns:
                    new_files.extend(clothes_tree[category][garment][colour][pattern])
        clicks = stage + 1
        return gr.update(choices=['Press Search to use the set filter. Dont press this button'],
                         label='Press Search to use the filter or press filter to reset the filter')
    elif stage == 4:
        for selection in (c_types, types, colors, patterns, new_files):
            selection.clear()
        clicks = 0
        return gr.update(choices=['top', 'bottom', 'dress'],
                         label='Select the type of clothing you want to search for')

    clicks = stage + 1
    return gr.update(choices=next_options)
def set_theme(theme):
    '''
    Purpose: Replace the active filter with a predefined theme
    Inputs:
        theme (string): name of the theme; anything unknown (including None) clears the filter
    Outputs:
        gr.update relabelling the theme radio with the chosen theme
    '''
    global new_files
    new_files.clear()
    formal_wear = [('evening dress', 'dress'), ('tuxedo', 'top'), ('suit jacket', 'top'),
                   ('dinner jacket', 'top'), ('maxi dress', 'dress')]
    sports_wear = [('track shorts', 'bottom'), ('track pants', 'bottom'), ('track jacket', 'top'),
                   ('Compression top', 'top'), ('cycling shorts', 'bottom'), ('compression shorts', 'bottom'),
                   ('tank top', 'top')]
    # theme name -> (garments, colours, patterns) that make up the theme
    presets = {
        'Red carpet': (formal_wear,
                       ['red', 'purple', 'gold', 'white', 'off white', 'black', 'beige', 'cream', 'navy', 'maroon'],
                       ['plain']),
        'Sports': (sports_wear, colors_list.copy(), patterns_list.copy()),
        'My preference': (formal_wear, ['red', 'purple', 'gold'], ['plain', 'zebra print']),
    }
    if theme not in presets:
        return gr.update(label='Chosen theme: None')

    theme_types, theme_colors, theme_patterns = presets[theme]
    for garment, category in theme_types:
        for colour in theme_colors:
            for pattern in theme_patterns:
                new_files.extend(clothes_tree[category][garment][colour][pattern])
    return gr.update(label='Chosen theme: ' + theme)
def segment(img):
    '''
    Purpose: Segment the image into clothing/body-part classes
    Inputs:
        img(pil image): image to be segmented
    Outputs:
        img(pil image): the original image, unchanged
        arr(numpy array): np.array(img)
        pred_seg(tensor): per-pixel class ids with the same height and width as img (on CPU)
    '''
    encoding = extractor_hg(img.convert('RGB'), return_tensors="pt")
    pixel_values = encoding.pixel_values.to(device)
    # Inference only: skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        outputs = seg_hg(pixel_values=pixel_values)
    logits = outputs.logits.cpu()
    # The model predicts at reduced resolution; upsample back to the image size.
    # PIL's size is (width, height), so it is reversed to (height, width).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    pred_seg = upsampled_logits.argmax(dim=1)[0]
    return img, np.array(img), pred_seg
def clean_img(img):
    '''
    Purpose: White out every pixel whose segmentation class is excluded by mask_choice
    Uses the module-level global_mask (class ids per pixel) and mask_choice.
    Inputs:
        img(numpy array): H x W (x C) image; modified in place
    Outputs:
        img(numpy array): cleaned image (the same array). It is returned untouched
            when there is no mask, the mask size differs, or mask_choice is unknown.
    '''
    # Class ids to remove for each segmentation option.
    removed_by_choice = {
        'Person': [0],
        'Clothes': [0, 2, 15, 14, 13, 12, 11],
        'Upper Body/Dress': [0, 5, 6, 9, 10, 12, 13, 16],
        'Lower Body': [0, 1, 2, 3, 4, 7, 8, 11, 14, 15, 16],
        'Upper Body/Dress, no person': [0, 1, 2, 15, 11, 14, 5, 6, 9, 10, 12, 13, 16, 3],
    }
    if global_mask is None:
        return img
    if tuple(global_mask.shape) != tuple(img.shape[:2]):
        return img
    bad = removed_by_choice.get(mask_choice, [])
    if bad:
        # Build a boolean selection instead of overwriting the shared mask with
        # a sentinel value, so global_mask stays intact for later use.
        img[np.isin(np.asarray(global_mask), bad)] = 255
    return img
def label_to_rec_lables(label):
    '''
    Purpose: Turn an item label into up to five recommendation targets
    Inputs:
        label(string): label of the image, formatted "type, colour, style"
    Outputs:
        chosen(list): [clothing category, garment type, colour] triples. Entries
            whose garment type is not a known top/bottom/dress type are skipped.
    '''
    # rec_dic is keyed by "type, colour" (the first two comma-separated fields).
    parts = label.split(',')
    new_label = rec_dic[','.join(parts[:2])]
    print('Reccomendation label: ', new_label)
    chosen = []
    for entry in new_label[:5]:
        fields = entry[0].split(',')
        label_type = fields[0]
        label_color = fields[1].strip()
        item_type = next((c for c in ('top', 'bottom', 'dress') if label_type in all_types[c]), None)
        if item_type is None:
            # Previously this reused the category of the preceding entry (or
            # raised NameError on the first one); skip unknown types instead.
            continue
        chosen.append([item_type, label_type, label_color])
    print('Chosen: ', chosen)
    return chosen
def filter_features(labels, rec=False, rec_items=None):
    '''
    Purpose: Restrict the dataset to items of one garment type and colour (any pattern)
    Inputs:
        labels(str): label string "type, colour, style"; used when rec is False
        rec(bool): if the function is called from the recommendation function
        rec_items(list): [clothing category, garment type, colour]; used when rec is True
    Outputs:
        url_list(list): urls of the selected items
        features(tensor): float32 embeddings of the selected items, row-aligned with url_list
    Raises:
        ValueError: if rec is False and the garment type in labels is unknown
    '''
    if rec:
        item_type, label_type, label_color = rec_items[0], rec_items[1], rec_items[2]
    else:
        parts = labels.split(',')
        label_type = parts[0]
        label_color = parts[1].strip()
        item_type = next((c for c in ('top', 'bottom', 'dress') if label_type in all_types[c]), None)
        if item_type is None:
            raise ValueError(f'Unknown clothing type in label: {labels!r}')

    colour_branch = clothes_tree[item_type][label_type][label_color]
    selected = set()
    for pattern in patterns_list:
        selected.update(colour_branch[pattern])
    new_files = list(selected)

    url_list = [url_list_main[i] for i in new_files]
    # One vectorised gather instead of a per-row copy; the embedding width
    # comes from features_main instead of being hard-coded.
    index = torch.tensor(new_files, dtype=torch.long, device=features_main.device)
    features = features_main[index].to(torch.float32)
    return url_list, features
def get_image_from_url(idx, url_list, items=5, max_tries=15):
    '''
    Purpose: Download images for the given indices, best effort
    Inputs:
        idx(list): indices into url_list, best match first
        url_list(list): list of urls
        items(int): maximum number of images to return
        max_tries(int): only the first max_tries indices are attempted
    Outputs:
        res(list): downloaded images as RGB numpy arrays (may be fewer than items)
    '''
    res = []
    for position in idx[:max_tries]:
        if len(res) >= items:
            break
        try:
            image_url = url_list[position]
        except IndexError:
            print(f'Error with: index {position} is out of range')
            continue
        try:
            # The context manager releases the connection even when decoding fails.
            with requests.get(image_url, stream=True, timeout=5) as req:
                img = Image.open(req.raw).convert('RGB')
            res.append(np.array(img))
        except Exception as err:  # best effort: a dead link just means one fewer result
            # Report the URL that actually failed (previously url_list[i] was printed).
            print(f'Error with: {image_url} ({err})')
    return res
def get_label(img):
    '''
    Purpose: Predict the "type, colour, style" label of an image
    Inputs:
        img(numpy array or pil image): image to classify
    Outputs:
        label(string): the entry of all_items whose embedding is closest to the image
    '''
    inputs = processor_hg(images=img, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        embedding = clip_hg.get_image_features(**inputs)
    best_match = find_closest(embedding, item_embeddings)[0]
    return all_items[best_match]
def resize_img(img, thresh=384):
    '''
    Purpose: Downscale the image so its largest dimension is at most thresh, keeping the aspect ratio
    Inputs:
        img(pil image): image to resize
        thresh(int): maximum size of the largest dimension
    Outputs:
        img(pil image): resized image, or the input itself if it is already small enough
    '''
    width, height = img.size
    longest = max(width, height)
    if longest <= thresh:
        return img
    # Same arithmetic order as before (multiply, then divide), but never let the
    # short side collapse to 0 pixels, which PIL's resize rejects.
    new_size = (max(1, int(width * thresh / longest)), max(1, int(height * thresh / longest)))
    return img.resize(new_size)
def segment_function(choice):
    '''
    Purpose: Remember which segmentation the user picked so search can apply it
    Inputs:
        choice(string): segmentation option
    Outputs:
        gr.update relabelling the radio with the current selection
    '''
    global mask_choice
    mask_choice = choice
    selection_label = f'Selection: {choice}'
    return gr.update(label=selection_label)
def _rec_background_labels(label_type):
    '''
    Purpose: Segmentation class ids to blank out on recommended items for a query of label_type
    '''
    if label_type in top_type or label_type in dress_type:
        # Query is a top/dress: drop background, head, upper body, arms and bag.
        return [0, 1, 2, 3, 4, 7, 8, 11, 14, 15, 16]
    if label_type in bottom_type:
        # Query is a bottom: drop background, skirt/pants, shoes, legs and bag.
        return [0, 5, 6, 9, 10, 12, 13, 16]
    return []


def _crop_to_content(masked, original, pad=10):
    '''
    Purpose: Crop original to the bounding box of non-white pixels in masked,
        padded by pad pixels and clamped to the image borders
    Returns original unchanged when masked is entirely white.
    '''
    non_white = np.any(masked != 255, axis=-1) if masked.ndim == 3 else masked != 255
    rows = np.flatnonzero(non_white.any(axis=1))
    cols = np.flatnonzero(non_white.any(axis=0))
    if rows.size == 0:
        return original
    h, w = non_white.shape
    top = max(int(rows[0]) - pad, 0)
    bottom = min(int(rows[-1]) + 1 + pad, h)
    left = max(int(cols[0]) - pad, 0)
    right = min(int(cols[-1]) + 1 + pad, w)
    return original[top:bottom, left:right]


def rec_function(option):
    '''
    Purpose: Recommend items that go with one image from the last search
    The chosen image is labelled, the labels to recommend come from rec_dic, the
    closest matching items are downloaded, and each one is masked and cropped so
    that only the recommended garment is visible.
    Inputs:
        option(int): index into the last search result (`out`)
    Outputs:
        rec_out(list): cropped recommendation images
        temp_out(list): the chosen image, shown in the outputs gallery (left
            unchanged when there is no valid selection)
    '''
    if not out or option is None or option >= len(out):
        # Nothing to recommend for: show the placeholder image. Two values are
        # returned because the event updates two components.
        with requests.get(url, stream=True, timeout=5) as req:
            placeholder = np.array(Image.open(req.raw).convert('RGB'))
        return [placeholder] * 5, gr.update()

    choice_img = resize_img(Image.fromarray(out[option]))
    label = get_label(choice_img)
    target_labels = label_to_rec_lables(label)

    inputs = processor_hg(images=choice_img, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        img_features = clip_hg.get_image_features(**inputs)

    # Spread roughly five results across the recommended labels.
    per_label = {1: 5, 2: 3, 3: 2}.get(len(target_labels), 1)
    candidates = []
    for item in target_labels:
        url_list, features = filter_features(label, rec=True, rec_items=item)
        idx = find_closest(img_features, features)[:5]
        candidates.extend(get_image_from_url(idx, url_list, items=per_label))

    background = _rec_background_labels(label.split(',')[0].strip())
    rec_out = []
    for candidate in candidates:
        pil_img = resize_img(Image.fromarray(candidate))
        _, arr_img, seg_mask = segment(pil_img)
        masked = arr_img.copy()
        masked[np.isin(np.asarray(seg_mask), background)] = 255
        rec_out.append(_crop_to_content(masked, arr_img))
    return rec_out, [choice_img]
def reset_values():
    '''
    Purpose: Forget the previous query when a new image is uploaded
    Clears the mask, the last result, the mask choice and all filter selections.
    Inputs:
        None
    Outputs:
        list of gr.update objects for: filter checkbox, segmentation radio,
        theme radio, outputs gallery, recommendations gallery, rotation radio
    '''
    global global_mask, out, mask_choice, clicks
    global c_types, types, colors, patterns, new_files
    global_mask = None
    out = None
    mask_choice = None
    clicks = 0
    for selection in (c_types, types, colors, patterns, new_files):
        selection.clear()
    segmentation_options = ['Person', 'Clothes', 'Upper Body/Dress', 'Upper Body/Dress, no person', 'Lower Body']
    return [
        gr.update(choices=['top', 'bottom', 'dress'], value=[]),
        gr.update(choices=segmentation_options, value=None),
        gr.update(value=None),
        gr.update(value=[]),
        gr.update(value=[]),
        gr.update(value=0),
    ]
def search_function(img, text, use_choice, use_label):
    '''
    Purpose: search for images based on the text input or image input
    Inputs:
        img(numpy array): segmented-ready query image (ignored for text search)
        text(string): text input
        use_choice(string): 'Use Image' or 'Use Text'
        use_label(boolean): restrict an image search to items sharing its predicted label
            (only when no filter/theme is active)
    Outputs:
        out(list): result images; for image search the (cleaned) query is shown first
    '''
    global new_files
    global out
    use_img = use_choice == 'Use Image'
    use_text = use_choice == 'Use Text'

    if new_files:
        # A filter or theme is active: search only inside the selected items.
        new_files = list(set(new_files))
        url_list = [url_list_main[i] for i in new_files]
        index = torch.tensor(new_files, dtype=torch.long, device=features_main.device)
        features = features_main[index].to(torch.float32)
    else:
        features = features_main.clone()
        url_list = url_list_main.copy()

    if use_text and not use_img:
        text_inputs = processor_hg(text=text, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            text_features = clip_hg.get_text_features(**text_inputs)
        idx = find_closest(text_features, features)[:15]
        out = get_image_from_url(idx, url_list)
        return out

    seg_img = clean_img(img) if global_mask is not None else img
    query = Image.fromarray(seg_img)
    label = get_label(query)
    print(label)
    if not new_files and use_label:
        url_list, features = filter_features(label)
    img_inputs = processor_hg(images=query, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        img_features = clip_hg.get_image_features(**img_inputs)
    idx = find_closest(img_features, features)[:15]
    out = get_image_from_url(idx, url_list)
    if use_img:
        # Make room for the query image at the front; guard against the case
        # where no result could be downloaded at all.
        if out:
            out.pop()
        out.insert(0, seg_img)
    return out
def search(img, text, choice, use_label, rotation):
    '''
    Purpose: Entry point of the Search button: prepare the query image and run the search
    Inputs:
        img(numpy array or None): uploaded image; None when nothing was uploaded
        text(string): text query
        choice(string): 'Use Image' or 'Use Text'
        use_label(boolean): restrict the search to the predicted label
        rotation(int): counter-clockwise rotation in degrees
    Outputs:
        res(list): images for the outputs gallery
    '''
    global global_mask
    try:
        pil_img = Image.fromarray(img).convert('RGB')
    except (AttributeError, TypeError, ValueError):
        # Nothing usable was uploaded: fall back to the default example image.
        with requests.get(url, stream=True, timeout=5) as req:
            pil_img = Image.open(req.raw).convert('RGB')
    if rotation:
        # expand=True keeps the whole picture for 90/270 degree turns; without it
        # PIL crops the rotated image back to the original canvas size.
        pil_img = pil_img.rotate(rotation, expand=True)
    pil_img = resize_img(pil_img)
    _, arr_img, out_mask = segment(pil_img)
    global_mask = out_mask
    return search_function(arr_img, text, choice, use_label)
# Define the app layout
SEGMENTATION_CHOICES = ['Person', 'Clothes', 'Upper Body/Dress', 'Upper Body/Dress, no person', 'Lower Body']
THEME_CHOICES = ['None', 'Red carpet', 'Sports']
ROTATION_CHOICES = [0, 90, 180, 270]
CLOTHING_CHOICES = ['top', 'bottom', 'dress']

with gr.Blocks() as demo:
    gr.Markdown("Search using image segmentation")
    with gr.Tab("Search"):
        # Query row: image upload, free-text box and search options.
        with gr.Row():
            search_image = gr.Image()
            search_text = gr.Textbox(lines=2, label="Search Text")
            search_input = [search_image, search_text]
            with gr.Column():
                search_type = gr.Radio(
                    choices=['Use Image', 'Use Text'],
                    label='Select the type of search you want to perform',
                    value='Use Image',
                )
                use_label = gr.Checkbox(label="Use Label", value=True)
        image_output = [gr.Gallery(label='Outputs')]
        rec_out = [gr.Gallery(label='Recommendations', interactive=True)]
        with gr.Row():
            rec_selector = gr.Radio(
                label='Select which item you want a recommendation for',
                choices=[1, 2, 3, 4],
                value=1,
            )
            rec_button = gr.Button("Get Recommendation")
        with gr.Row():
            clothes_selector = gr.Radio(
                label='Choose a segmentation',
                choices=SEGMENTATION_CHOICES,
                interactive=True,
            )
            theme_radio = gr.Radio(label='Choose a theme', choices=THEME_CHOICES, interactive=True)
            rotation_radio = gr.Radio(label='Choose a rotation', choices=ROTATION_CHOICES, interactive=True, value=0)
        with gr.Row():
            filter_checkbox = gr.CheckboxGroup(
                label='Choose the clothing types',
                choices=CLOTHING_CHOICES,
                interactive=True,
                value=['top'],
            )
            filter_button = gr.Button("Filter Button")
        search_button = gr.Button("Search Button")

    # Event wiring.
    clothes_selector.change(segment_function, inputs=[clothes_selector], outputs=clothes_selector)
    search_image.change(
        reset_values,
        inputs=None,
        outputs=[filter_checkbox, clothes_selector, theme_radio, image_output[0], rec_out[0], rotation_radio],
    )
    theme_radio.change(set_theme, inputs=theme_radio, outputs=theme_radio)
    rec_button.click(rec_function, inputs=rec_selector, outputs=[rec_out[0], image_output[0]])
    filter_button.click(filter_function, inputs=filter_checkbox, outputs=filter_checkbox)
    search_button.click(
        search,
        inputs=search_input + [search_type, use_label, rotation_radio],
        outputs=image_output,
    )

demo.launch(share=False)