|
from transformers import AutoFeatureExtractor, AutoModel |
|
import torch |
|
from torchvision.transforms.functional import to_pil_image |
|
from einops import rearrange, reduce |
|
from skops import hub_utils |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import gradio as gr |
|
|
|
import os |
|
import pickle |
|
|
|
|
|
# Parallel tables: entry i of each list describes the same setup, i.e. its
# display name, the Hub id of its embedder, and the repo name of its GAM.
setups = [
    'ResNet-50',
    'ViT',
    'DINO-ResNet-50',
    'DINO-ViT',
]
embedder_names = [
    'microsoft/resnet-50',
    'google/vit-base-patch16-224',
    'Ramos-Ramos/dino-resnet-50',
    'facebook/dino-vitb16',
]
gam_names = [
    'emb-gam-resnet',
    'emb-gam-vit',
    'emb-gam-dino-resnet',
    'emb-gam-dino',
]

# Reverse lookups from an embedder id / GAM repo name to its display name.
embedder_to_setup = {
    embedder_name: setup
    for embedder_name, setup in zip(embedder_names, setups)
}
gam_to_setup = {
    gam_name: setup
    for gam_name, setup in zip(gam_names, setups)
}
|
|
|
# Run the embedders on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def _resnet_patch_embeddings(outputs):
    """Flatten a (b, d, h, w) feature map into a (b, h*w, d) patch sequence."""
    return rearrange(outputs.last_hidden_state, 'b d h w -> b (h w) d')


def _vit_patch_embeddings(outputs):
    """Return the hidden states without the first token of each sequence."""
    # presumably index 0 is the [CLS] token -- confirm against the model docs
    return outputs.last_hidden_state[:, 1:]


# Maps each setup's display name to its feature extractor, model, patch grid
# side length, and the function that turns model outputs into patch embeddings.
embedders = {}
for name in embedder_names:
    feature_extractor = AutoFeatureExtractor.from_pretrained(name)
    model = AutoModel.from_pretrained(name).eval().to(device)

    if 'resnet-50' in name:
        # The ResNet-50 checkpoints use a fixed 7x7 grid here.
        num_patches_side = 7
        embedding_postprocess = _resnet_patch_embeddings
    else:
        # Other checkpoints derive the grid from their own config.
        num_patches_side = model.config.image_size // model.config.patch_size
        embedding_postprocess = _vit_patch_embeddings

    embedders[embedder_to_setup[name]] = {
        'feature_extractor': feature_extractor,
        'model': model,
        'num_patches_side': num_patches_side,
        'embedding_postprocess': embedding_postprocess,
    }
|
|
|
# Maps each setup's display name to its unpickled GAM.
gams = {}
for name in gam_names:
    # A directory named after the repo doubles as the "already downloaded"
    # marker: it is created and filled only when it does not exist yet.
    already_downloaded = os.path.exists(name)
    if not already_downloaded:
        os.mkdir(name)
        hub_utils.download(repo_id=f'Ramos-Ramos/{name}', dst=name)

    # NOTE(review): pickle.load executes arbitrary code from the file; this is
    # only safe as long as the downloaded repositories are trusted.
    with open(f'{name}/model.pkl', 'rb') as model_file:
        loaded_gam = pickle.load(model_file)
    gams[gam_to_setup[name]] = loaded_gam
|
|
|
# Class names used as column titles: labels[j] titles the heatmap of output j.
labels = [
    'tench', 'English springer', 'cassette player', 'chain saw', 'church',
    'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute',
]
|
|
|
def _patch_contributions(input_img, setup):
    """Compute the per-patch contribution maps of one Emb-GAM setup.

    Args:
        input_img: Image handed to the setup's feature extractor.
        setup: Display name of the setup; a key of ``embedders`` and ``gams``.

    Returns:
        Array of shape ``(-1, num_patches_side, num_patches_side)`` with one
        map per row of ``gam.coef_``.
    """
    embedder_setup = embedders[setup]
    num_patches_side = embedder_setup['num_patches_side']
    gam = gams[setup]

    # Preprocess the image and move every tensor to the model's device.
    inputs = {
        k: v.to(device)
        for k, v
        in embedder_setup['feature_extractor'](input_img, return_tensors='pt').items()
    }

    # Embed the image and keep the patch embeddings of the single batch item.
    with torch.no_grad():
        patch_embeddings = embedder_setup['embedding_postprocess'](
            embedder_setup['model'](**inputs)
        ).cpu()[0]

    # Each patch gets its linear term plus an equal share of the intercept, so
    # summing a map over all patches yields coef @ sum(embeddings) + intercept.
    return (
        gam.coef_
        @ patch_embeddings.T.numpy()
        + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    ).reshape(-1, num_patches_side, num_patches_side)


def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    """Visualize the patch contributions to all labels of one or more visual
    Emb-GAMs.

    Args:
        input_img: Image to explain; also drawn in the first column.
        visual_emb_gam_setups: Display names of the setups to show, one figure
            row per setup.
        show_scores: If true, write the sum of each heatmap under it.
        show_cbars: If true, draw a color bar next to each heatmap.

    Returns:
        A single matplotlib figure. It is empty when no image or no setup was
        provided.
    """
    # Nothing to draw. Exactly one figure is returned so that this path agrees
    # with the normal return value and with the single plot output.
    if input_img is None or not visual_emb_gam_setups:
        return plt.Figure()

    patch_contributions = {
        setup: _patch_contributions(input_img, setup)
        for setup in visual_emb_gam_setups
    }

    num_setups = len(visual_emb_gam_setups)

    # One row per setup; column 0 is reserved for the image, followed by one
    # heatmap column per label. squeeze=False keeps `axs` two-dimensional even
    # for a single row, so no special-casing is needed below.
    fig, axs = plt.subplots(
        num_setups,
        len(labels) + 1,
        figsize=(20, round(10/4 * num_setups)),
        squeeze=False
    )

    # Replace the whole first column by one tall axis holding the input image.
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')

    for i, setup in enumerate(visual_emb_gam_setups):
        contributions = patch_contributions[setup]
        # Share one color scale across all heatmaps of a setup.
        vmin = contributions.min()
        vmax = contributions.max()
        for j, label in enumerate(labels):
            ax = axs[i, j + 1]
            sns.heatmap(
                contributions[j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars
            )
            if show_scores:
                ax.set_xlabel(f'{contributions[j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(label)
            ax.set_xticks([])
            ax.set_yticks([])

    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()

    # Unregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself stays usable by the caller.
    plt.close(fig)

    return fig
|
|
|
# Input widgets, in the order of visualize()'s parameters.
input_components = [
    gr.Image(shape=(224, 224), type='pil', label='Input image'),
    gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM'),
    gr.Checkbox(label='Show scores'),
    gr.Checkbox(label='Show color bars'),
]

# A single plot output receives the figure returned by visualize().
output_components = [
    gr.Plot(label='Patch contributions'),
]

demo = gr.Interface(
    fn=visualize,
    inputs=input_components,
    outputs=output_components,
)

demo.launch(debug=True)