import gradio as gr
from PIL import Image
import torch
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
from utils import setup, get_similarity_map, get_noun_phrase, rgb_to_hsv, hsv_to_rgb
from vpt.launch import default_argument_parser
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import models
import string
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
from nltk.tokenize import word_tokenize
import torchvision
import spacy

# Download and load the spaCy English model
spacy.cli.download("en_core_web_sm")
nlp = spacy.load("en_core_web_sm")

def extract_objects(prompt):
    doc = nlp(prompt)
    # Extract object nouns (including proper nouns and named entities)
    objects = set()
    for token in doc:
        # Keep the token if it is a noun, a proper noun, or part of a named entity
        if token.pos_ in {"NOUN", "PROPN"} or token.ent_type_:
            objects.add(token.text)
        # If the token is a compound modifier, also keep its head noun
        if token.dep_ == "compound":
            objects.add(token.head.text)
    return list(objects)
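
# Example (hypothetical output; the exact tokens and their order depend on the spaCy model version):
#   extract_objects("a boy flying a kite")  ->  ['boy', 'kite']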

args = default_argument_parser().parse_args()
cfg = setup(args)

multi_classes = True
device = "cuda" if torch.cuda.is_available() else "cpu"

Ours, preprocess = models.load("CS-ViT-B/16", device=device, cfg=cfg, train_bool=False)
state_dict = torch.load("sketch_seg_best_miou.pth", map_location=device)

# The checkpoint was trained with DataParallel on 2 GPUs, so strip the
# "module." prefix from each key before loading on a single device
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.`
    new_state_dict[name] = v

Ours.load_state_dict(new_state_dict)
Ours.eval()
print("Model loaded successfully")

def run(sketch, caption, threshold, seed=0):  # seed is currently unused; the color offset is drawn randomly below
    # Pick a random color offset in [0, 3] for the palette
    color_seed = np.random.randint(0, 4)

    # Set the candidate classes here
    caption = caption.replace('\n', ' ')
    classes = extract_objects(caption)

    # Alternative noun-phrase extraction with NLTK:
    # translator = str.maketrans('', '', string.punctuation)
    # caption = caption.translate(translator).lower()
    # words = word_tokenize(caption)
    # classes = get_noun_phrase(words)
    # print(classes)

    # Fall back to the whole caption when no objects were found or multi-class mode is off
    if len(classes) == 0 or not multi_classes:
        classes = [caption]
    # print(classes)

    colors = plt.get_cmap("Set1").colors
    classes_colors = colors[color_seed:len(classes) + color_seed]

    sketch2 = sketch['composite']

    # When the drawing tool is used, the RGB channels are empty and the strokes
    # live in the alpha channel, so rebuild an RGB sketch from alpha
    if sketch2[:, :, 0:3].sum() == 0:
        temp = sketch2[:, :, 3]
        # Invert it so strokes are black on a white background
        temp = 255 - temp
        sketch2 = np.repeat(temp[:, :, np.newaxis], 3, axis=2)
        temp2 = np.full_like(temp, 255)
        sketch2 = np.dstack((sketch2, temp2))

    sketch2 = np.array(sketch2)
    pil_img = Image.fromarray(sketch2).convert('RGB')
    sketch_tensor = preprocess(pil_img).unsqueeze(0).to(device)
    # torchvision.utils.save_image(sketch_tensor, 'sketch_tensor.png')

    with torch.no_grad():
        text_features = models.encode_text_with_prompt_ensemble(Ours, classes, device, no_module=True)
        redundant_features = models.encode_text_with_prompt_ensemble(Ours, [""], device, no_module=True)

    num_of_tokens = 3
    with torch.no_grad():
        sketch_features = Ours.encode_image(sketch_tensor, layers=[12],
                                            text_features=text_features - redundant_features, mode="test").squeeze(0)
        sketch_features = sketch_features / sketch_features.norm(dim=1, keepdim=True)

    # Similarity between every patch token and each class embedding
    similarity = sketch_features @ (text_features - redundant_features).t()
    patches_similarity = similarity[0, num_of_tokens + 1:, :]
    pixel_similarity = get_similarity_map(patches_similarity.unsqueeze(0), pil_img.size).cpu()
    # visualize_attention_maps_with_tokens(pixel_similarity, classes)

    # Zero out pixels below the confidence threshold
    pixel_similarity[pixel_similarity < threshold] = 0
    pixel_similarity_array = pixel_similarity.cpu().numpy().transpose(2, 0, 1)
    # display_segmented_sketch(pixel_similarity_array, sketch2, classes, classes_colors, live=True)

    # Find the class index with the highest similarity for each pixel
    class_indices = np.argmax(pixel_similarity_array, axis=0)

    # Create an HSV image placeholder, shape (H, W, 3)
    hsv_image = np.zeros(class_indices.shape + (3,))
    hsv_image[..., 2] = 1  # Set Value to 1 for a white base

    # Set the hue, saturation, and value channels per class
    for i, color in enumerate(classes_colors):
        rgb_color = np.array(color).reshape(1, 1, 3)
        hsv_color = rgb_to_hsv(rgb_color)
        mask = class_indices == i
        if i < len(classes):  # Color each class region according to its similarity
            hsv_image[..., 0][mask] = hsv_color[0, 0, 0]  # Hue
            hsv_image[..., 1][mask] = pixel_similarity_array[i][mask] > 0  # Saturation
            hsv_image[..., 2][mask] = pixel_similarity_array[i][mask]  # Value
        else:  # Fallback (not reached with the slicing above): paint remaining pixels black
            hsv_image[..., 0][mask] = 0  # Hue doesn't matter for black
            hsv_image[..., 1][mask] = 0  # Saturation set to 0
            hsv_image[..., 2][mask] = 0  # Value set to 0, making it black

    # Keep the white background of the original sketch white
    mask_tensor_org = sketch2[:, :, 0] / 255
    hsv_image[mask_tensor_org >= 0.5] = [0, 0, 1]

    # Convert the HSV image back to RGB to display and save
    rgb_image = hsv_to_rgb(hsv_image)

    if len(classes) > 1:
        # Compute each class mask's centroid and render the class name there
        for i, class_name in enumerate(classes):
            mask = class_indices == i
            if np.any(mask):
                y, x = np.nonzero(mask)
                centroid_x, centroid_y = np.mean(x), np.mean(y)
                plt.text(centroid_x, centroid_y, class_name, color=classes_colors[i],
                         ha='center', va='center', fontsize=10,
                         bbox=dict(facecolor='lightgrey', edgecolor='none', boxstyle='round,pad=0.2', alpha=0.8))

    # Display the image with class names
    plt.imshow(rgb_image)
    plt.axis('off')
    plt.tight_layout()
    # plt.savefig(f'poster_vis/{classes[0]}.png', bbox_inches='tight', pad_inches=0)
    plt.savefig('output.png', bbox_inches='tight', pad_inches=0)
    plt.close()

    # rgb_image = Image.open(f'poster_vis/{classes[0]}.png')
    rgb_image = Image.open('output.png')

    return rgb_image
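
# A minimal usage sketch (hypothetical values; gr.Sketchpad passes a dict whose
# 'composite' entry is an RGBA array):
#   blank = {'composite': np.zeros((512, 512, 4), dtype=np.uint8)}
#   out = run(blank, "a boy flying a kite", threshold=0.6)  # returns a PIL.Image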
scripts = """ | |
async () => { | |
// START gallery format | |
// Get all image elements with the class "image" | |
var images = document.querySelectorAll('.image_gallery'); | |
var originalParent = document.querySelector('#component-0'); | |
// Create a new parent div element | |
var parentDiv = document.createElement('div'); | |
var beforeDiv= document.querySelector('.table-wrap').parentElement; | |
parentDiv.id = "gallery_container"; | |
// Loop through each image, append it to the parent div, and remove it from its original parent | |
images.forEach(function(image , index ) { | |
// Append the image to the parent div | |
parentDiv.appendChild(image); | |
// Add click event listener to each image | |
image.addEventListener('click', function() { | |
let nth_ch = index+1 | |
document.querySelector('.tr-body:nth-child(' + nth_ch + ')').click() | |
console.log('.tr-body:nth-child(' + nth_ch + ')'); | |
}); | |
// Remove the image from its original parent | |
}); | |
// Get a reference to the original parent of the images | |
var originalParent = document.querySelector('#component-0'); | |
// Append the new parent div to the original parent | |
originalParent.insertBefore(parentDiv, beforeDiv); | |
// END gallery format | |
// START confidence span | |
// Get the selected div (replace 'selectedDivId' with the actual ID of your div) | |
var selectedDiv = document.querySelector("label[for='range_id_0'] > span") | |
// Get the text content of the div | |
var textContent = selectedDiv.textContent; | |
// Find the text before the first colon ':' | |
var colonIndex = textContent.indexOf(':'); | |
var textBeforeColon = textContent.substring(0, colonIndex); | |
// Wrap the text before colon with a span element | |
var spanElement = document.createElement('span'); | |
spanElement.textContent = textBeforeColon; | |
// Replace the original text with the modified text containing the span | |
selectedDiv.innerHTML = textContent.replace(textBeforeColon, spanElement.outerHTML); | |
// START format the column names : | |
// Get all elements with the class "test_class" | |
var elements = document.querySelectorAll('.tr-head > th'); | |
// Iterate over each element | |
elements.forEach(function(element) { | |
// Get the text content of the element | |
var text = element.textContent.trim(); | |
// Remove ":" from the text | |
var wordWithoutColon = text.replace(':', ''); | |
// Split the text into words | |
var words = wordWithoutColon.split(' '); | |
// Keep only the first word | |
var firstWord = words[0]; | |
// Set the text content of the element to the first word | |
element.textContent = firstWord; | |
}); | |
document.querySelector('input[type=number]').disabled = true; | |
} | |
""" | |
css=""" | |
gradio-app { | |
background-color: white !important; | |
} | |
.white-bg { | |
background-color: white !important; | |
} | |
.gray-border { | |
border: 1px solid dimgrey !important; | |
} | |
.border-radius { | |
border-radius: 8px !important; | |
} | |
.black-text { | |
color : black !important; | |
} | |
th { | |
color : black !important; | |
} | |
tr { | |
background-color: white !important; | |
color: black !important; | |
} | |
td { | |
border-bottom : 1px solid black !important; | |
} | |
label[data-testid="block-label"] { | |
background: white; | |
color: black; | |
font-weight: bold; | |
} | |
.controls-wrap button:disabled { | |
color: gray !important; | |
background-color: white !important; | |
} | |
.controls-wrap button:not(:disabled) { | |
color: black !important; | |
background-color: white !important; | |
} | |
.source-wrap button { | |
color: black !important; | |
} | |
.toolbar-wrap button { | |
color: black !important; | |
} | |
.empty.wrap { | |
color: black !important; | |
} | |
textarea { | |
background-color : #f7f9f8 !important; | |
color : #afb0b1 !important | |
} | |
input[data-testid="number-input"] { | |
background-color : #f7f9f8 !important; | |
color : black !important | |
} | |
tr > th { | |
border-bottom : 1px solid black !important; | |
} | |
tr:hover { | |
background: #f7f9f8 !important; | |
} | |
#component-19{ | |
justify-content: center !important; | |
} | |
#component-19 > button { | |
flex: none !important; | |
background-color : black !important; | |
font-weight: bold !important; | |
} | |
.bold { | |
font-weight: bold !important; | |
} | |
span[data-testid="block-info"]{ | |
color: black !important; | |
font-weight: bold !important; | |
} | |
#component-14 > div { | |
background-color : white !important; | |
} | |
button[aria-label="Clear"] { | |
background-color : white !important; | |
color: black !important; | |
} | |
#gallery_container { | |
display: flex; | |
flex-wrap: wrap; | |
justify-content: start; | |
} | |
.image_gallery { | |
margin-bottom: 1rem; | |
margin-right: 1rem; | |
} | |
label[for='range_id_0'] > span > span { | |
text-decoration: underline; | |
} | |
label[for='range_id_0'] > span > span { | |
font-size: normal !important; | |
} | |
.underline { | |
text-decoration: underline; | |
} | |
.mt-mb-1{ | |
margin-top: 1rem; | |
margin-bottom: 1rem; | |
} | |
#gallery_container + div { | |
visibility: hidden; | |
height: 10px; | |
} | |
input[type=number][disabled] { | |
background-color: rgb(247, 249, 248) !important; | |
color: black !important; | |
-webkit-text-fill-color: black !important; | |
} | |
#component-13 { | |
display: flex; | |
flex-direction: column; | |
align-items: center; | |
} | |
""" | |

with gr.Blocks(js=scripts, css=css, theme='gstaff/xkcd') as demo:

    gr.HTML("<h1 class='black-text' style='text-align: center;'>Open Vocabulary Scene Sketch Semantic Understanding</h1>")
    gr.HTML("<div class='black-text'></div>")
    # gr.HTML("<div class='black-text' style='text-align: center;'><a href='https://ahmedbourouis.github.io/ahmed-bourouis/'>Ahmed Bourouis</a>, <a href='https://profiles.stanford.edu/judith-fan'>Judith Ellen Fan</a>, <a href='https://yulia.gryaditskaya.com/'>Yulia Gryaditskaya</a></div>")
    gr.HTML("<div class='black-text' style='text-align: center;'>Ahmed Bourouis, Judith Ellen Fan, Yulia Gryaditskaya</div>")
    gr.HTML("<div class='black-text' style='text-align: center;'>CVPR, 2024</div>")
    gr.HTML("<div style='text-align: center;'><p><a href='https://ahmedbourouis.github.io/Scene_Sketch_Segmentation/'>Project page</a></p></div>")
# gr.Markdown( "Scene Sketch Semantic Segmentation.", elem_classes=["black-txt" , "h1"] ) | |
# gr.Markdown( "Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt" , "p"] ) | |
# gr.Markdown( "Open Vocabulary Scene Sketch Semantic Understanding", elem_classes=["black-txt" , "p"] ) | |
# gr.Markdown( "") | |

    with gr.Row():
        with gr.Column():
            # in_image = gr.Image(label="Sketch", type="pil", sources="upload", height=512)
            in_canvas_image = gr.Sketchpad(
                # value=Image.new('RGB', (512, 512), color=(255, 255, 255)),
                brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=2),
                image_mode="RGBA", elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                label="Sketch", canvas_size=(512, 512), sources=['upload'],
                interactive=True, layers=False, transforms=[]
            )
            query_selector = 'button[aria-label="Upload button"]'
            upload_draw_btn = gr.HTML(f"""
                <div id="upload_draw_group" class="svelte-15lo0d8 stretch">
                    <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="upload_btn" onclick="return document.querySelector('.source-wrap button').click()">Upload a new sketch</button>
                    <button class="sm black-text white-bg gray-border border-radius own-shadow svelte-cmf5ev bold" id="draw_btn" onclick="return document.querySelector('.controls-wrap button:nth-child(3)').click()">Draw a new sketch</button>
                </div>
            """)
        with gr.Column():
            out_image = gr.Image(value=Image.new('RGB', (512, 512), color=(255, 255, 255)),
                                 elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                                 type="pil", label="Segmented Sketch")  # , height=512, width=512)

    with gr.Row():
        with gr.Column():
            in_textbox = gr.Textbox(lines=2,
                                    elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                                    label="Caption your Sketch!",
                                    placeholder="Include the categories that you want the AI to segment. \n e.g. 'giraffe, clouds' or 'a boy flying a kite' ")
        with gr.Column():
            # gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Confidence:</span> Adjust AI agent confidence in guessing categories </h3>")
            in_slider = gr.Slider(elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow"],
                                  info="Adjust AI agent confidence in guessing categories",
                                  label="Confidence:",
                                  value=0.5, interactive=True, step=0.05, minimum=0, maximum=1)

    with gr.Row():
        segment_btn = gr.Button('Segment it¹ !',
                                elem_classes=["white-bg", "gray-border", "border-radius", "own-shadow", 'bold', 'mt-mb-1'],
                                size="sm")
        segment_btn.click(fn=run, inputs=[in_canvas_image, in_textbox, in_slider], outputs=[out_image])

    gallery_label = gr.HTML("<h3 class='black-text'> <span class='black-text underline'>Gallery:</span> <span style='color: grey;'>you can click on any of the example sketches below to start segmenting them (or even drawing over them)</span> </h3>")

    gallery_paths = [
        'demo/sketch_1.png', 'demo/sketch_2.png', 'demo/sketch_3.png',
        'demo/000000004068.png', 'demo/000000004546.png', 'demo/000000005076.png',
        'demo/000000006336.png', 'demo/000000011766.png', 'demo/000000024458.png',
        'demo/000000024931.png', 'demo/000000034214.png', 'demo/000000260974.png',
        'demo/000000268340.png', 'demo/000000305414.png', 'demo/000000484246.png',
        'demo/000000549338.png', 'demo/000000038116.png', 'demo/000000221509.png',
        'demo/000000246066.png', 'demo/000000001611.png',
    ]
    # Each gr.Image is interpolated into the HTML string; the JS in `scripts`
    # later moves these .image_gallery elements into #gallery_container
    gallery = gr.HTML("<div>" + "".join(
        f'{gr.Image(elem_classes=["image_gallery"], label="Sketch", show_download_button=False, show_label=False, type="pil", value=path, height=200, width=200)}'
        for path in gallery_paths
    ) + "</div>")

    examples = gr.Examples(
        examples_per_page=30,
        examples=[
            ['demo/sketch_1.png', 'giraffe looking at you', 0.6],
            ['demo/sketch_2.png', 'a kite flying in the sky', 0.6],
            ['demo/sketch_3.png', 'a girl playing', 0.6],
            ['demo/000000004068.png', 'car going so fast', 0.6],
            ['demo/000000004546.png', 'mountains in the background', 0.6],
            ['demo/000000005076.png', 'huge tree', 0.6],
            ['demo/000000006336.png', 'three nice sheep', 0.6],
            ['demo/000000011766.png', 'bird minding its own business', 0.6],
            ['demo/000000024458.png', 'horse with a mask on', 0.6],
            ['demo/000000024931.png', 'some random person', 0.6],
            ['demo/000000034214.png', 'a cool kid on a skateboard', 0.6],
            ['demo/000000260974.png', 'the chair on the left', 0.6],
            ['demo/000000268340.png', 'stop sign', 0.6],
            ['demo/000000305414.png', 'a lonely elephant roaming around', 0.6],
            ['demo/000000484246.png', 'giraffe with a loong neck', 0.6],
            ['demo/000000549338.png', 'two donkeys trying to be smart', 0.6],
            ['demo/000000038116.png', 'a bat next to a kid', 0.6],
            ['demo/000000221509.png', 'funny looking cow', 0.6],
            ['demo/000000246066.png', 'bench in the park', 0.6],
            ['demo/000000001611.png', 'trees in the background', 0.6]
        ],
        inputs=[in_canvas_image, in_textbox, in_slider],
        fn=run,
        # cache_examples=True,
    )
gr.HTML("<h5 class='black-text' style='text-align: left;'>¹This demo runs on a basic 2 vCPU. For instant segmentation, use a commercial Nvidia RTX 3090 GPU.</h5>") | |
gr.HTML("<h5 class='black-text' style='text-align: left;'>¹We compare the entire caption to the scene sketch and threshold most similar pixels, without extracting individual classes.</h5>") | |
demo.launch(share=False) | |