Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from PIL import Image | |
import numpy as np | |
import random | |
from einops import rearrange | |
import matplotlib.pyplot as plt | |
from torchvision.transforms import v2 | |
from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtractor | |
# Example inputs shown in the Gradio UI: one inner list per example row, with
# one value per input component (here, a single image filepath).
path = [['images/cat.jpg'], ['images/dog.jpg'], ['images/horse.jpg'], ['images/airplane.jpg'], ['images/truck.jpg']]

model_name = "vit-t-mae-pretrain.pt"
device = torch.device("cpu")

# The checkpoint is used as a full module object (.eval(), .to() and a direct
# call are made on it below), not a state_dict, so it must be fully unpickled.
# torch>=2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle arbitrary module classes; opt out explicitly.
# SECURITY: weights_only=False runs pickled code from the file -- only ever
# point model_name at a trusted checkpoint.
model = torch.load(model_name, map_location=device, weights_only=False)
model.eval()
model.to(device)
# Preprocessing pipeline for model inputs:
#   1. resize to 96x96,
#   2. PIL image -> float32 tensor scaled to [0, 1]
#      (v2.ToImage + v2.ToDtype(scale=True) is the documented replacement for
#      the deprecated v2.ToTensor),
#   3. Normalize with mean=std=0.5, i.e. map [0, 1] -> [-1, 1].
# show_image maps values back with (x + 1) / 2 for display.
transform = v2.Compose([
    v2.Resize((96, 96)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def load_image(image_path, transform):
    """Load an image file and turn it into a batched model input.

    Args:
        image_path: Path of the image file (the Gradio input component is
            declared with type="filepath").
        transform: Callable applied to the RGB PIL image; presumably returns a
            channels-first tensor (C, H, W) -- confirm against ``transform``.

    Returns:
        The transformed tensor with a leading batch dimension of size 1.
    """
    # Context manager releases the file handle deterministically; Image.open
    # is lazy and otherwise may keep the underlying file open.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0)  # add batch dimension
def show_image(img, title):
    """Draw one image tensor on the current matplotlib axes.

    Args:
        img: Channels-first tensor (C, H, W) in the normalized space produced
            by ``transform`` (roughly [-1, 1]).
        title: Caption drawn above the image.
    """
    img = rearrange(img, "c h w -> h w c")
    # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1]. Model outputs are
    # not clamped anywhere, so clip explicitly rather than letting imshow clip
    # (and log a warning) on every call.
    img = np.clip((img.cpu().detach().numpy() + 1) / 2, 0.0, 1.0)
    plt.imshow(img)
    plt.axis('off')
    plt.title(title)
def visualize_single_image(image_path):
    """Run the MAE model on one image and return a 4-panel comparison.

    Panels, left to right: original, masked input, raw reconstruction, and
    the reconstruction with the visible patches pasted back from the original.

    Args:
        image_path: Filepath of the input image (from the Gradio Image input).

    Returns:
        ``np.ndarray`` as returned by ``plt.imread`` on the rendered PNG
        (float values in [0, 1], RGBA channels).
    """
    import io  # only needed for the in-memory PNG buffer

    img = load_image(image_path, transform).to(device)

    # Run inference
    model.eval()
    with torch.no_grad():
        predicted_img, mask = model(img)

    # From the compositing below, mask is presumably 1 on hidden patches and
    # 0 on visible ones, broadcastable against img -- TODO confirm in model.py.
    im_masked = img * (1 - mask)
    # MAE reconstruction with the visible patches taken from the original
    im_paste = img * (1 - mask) + predicted_img * mask

    panels = [
        (img[0], "original"),
        (im_masked[0], "masked"),
        (predicted_img[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]

    fig = plt.figure(figsize=(18, 8))
    try:
        for idx, (tensor, title) in enumerate(panels, start=1):
            plt.subplot(1, len(panels), idx)
            show_image(tensor, title)
        plt.tight_layout()
        # Render into memory instead of a shared "output.png" on disk, which
        # concurrent requests would overwrite under each other.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
        return np.array(plt.imread(buf, format="png"))
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request leaks a figure in the long-running server process.
        plt.close(fig)
# Gradio UI: a single filepath image in, a single numpy image out.
inputs_image = [gr.components.Image(type="filepath", label="Input Image")]
outputs_image = [gr.components.Image(type="numpy", label="Output Image")]

demo = gr.Interface(
    fn=visualize_single_image,
    inputs=inputs_image,
    outputs=outputs_image,
    examples=path,
    title="MAE-ViT Image Reconstruction",
    description="This is a demo of the MAE-ViT model for image reconstruction.",
)

# Start the Gradio server.
demo.launch()