import os
import pickle

import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel

# Fashion-CLIP: a CLIP checkpoint fine-tuned on fashion product data.
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")

# The full dataset can be pulled from the Hub instead of the local pickles:
# hf_token = os.environ.get("HF_API_TOKEN")
# dataset = load_dataset("pratyush19/cyborg", use_auth_token=hf_token, split="train")
# dir_path = "train/"

# Precomputed artifacts: image paths, Fashion-CLIP image embeddings, and the
# decoded PIL images shown in the result grid.
with open("valid_images_sample.pkl", "rb") as f:
    valid_images = pickle.load(f)
with open("image_encodings_sample.pkl", "rb") as f:
    image_encodings = pickle.load(f)
valid_images = np.array(valid_images)
with open("PIL_images.pkl", "rb") as f:
    PIL_images = pickle.load(f)


def softmax(x):
    """Numerically stable row-wise softmax."""
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / e_x.sum(axis=1, keepdims=True)


def find_similar_images(caption, image_encodings):
    """Score every image embedding against the caption's text embedding."""
    inputs = processor(text=[caption], return_tensors="pt")
    text_features = model.get_text_features(**inputs).detach().numpy()
    return softmax(np.dot(text_features, image_encodings.T))


def find_relevant_images(caption):
    """Return the 16 images whose embeddings best match the caption."""
    similarity_scores = find_similar_images(caption, image_encodings)[0]
    top_indices = np.argsort(similarity_scores)[::-1][:16]
    # top_path = valid_images[top_indices]
    return [PIL_images[idx] for idx in top_indices]


def gradio_interface(input_text):
    # with open("user_inputs.txt", "a") as file:
    #     file.write(input_text + "\n")
    return find_relevant_images(input_text)


def clear_inputs():
    # One value per cleared component: the textbox plus the 16 image slots.
    return [None] * 17


with gr.Blocks(title="MirrAI") as demo:
    gr.Markdown("# MirrAI: GenAI-based Fashion Search")
") gr.Markdown("Enter a text to find the most relevant images from our dataset.") text_input = gr.Textbox(lines=1, label="Input Text", placeholder="Enter your text here...") with gr.Row(): cancel_button = gr.Button("Cancel") submit_button = gr.Button("Submit") examples = gr.Examples(["high-rise flare jean", "a-line dress with floral", "men colorful blazers", "jumpsuit with puffed sleeve", "sleeveless sweater", "floral shirt", "blue asymmetrical wedding dress with one sleeve", "women long coat", "cardigan sweater"], inputs=[text_input]) with gr.Row(): outputs[0] = gr.Image() outputs[1] = gr.Image() outputs[2] = gr.Image() outputs[3] = gr.Image() with gr.Row(): outputs[4] = gr.Image() outputs[5] = gr.Image() outputs[6] = gr.Image() outputs[7] = gr.Image() with gr.Row(): outputs[8] = gr.Image() outputs[9] = gr.Image() outputs[10] = gr.Image() outputs[11] = gr.Image() with gr.Row(): outputs[12] = gr.Image() outputs[13] = gr.Image() outputs[14] = gr.Image() outputs[15] = gr.Image() submit_button.click( fn=gradio_interface, inputs=text_input, outputs=outputs ) cancel_button.click( fn=clear_inputs, inputs=None, outputs=[text_input] + outputs ) demo.launch(share=True)