# Hugging Face Spaces app (Space status when captured: Sleeping)
import gradio as gr
import requests
import timm
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from transformers import AutoFeatureExtractor, ResNetForImageClassification
import os

# Single source of truth for the Hub checkpoint used by both objects below.
_RESNET_CHECKPOINT = "microsoft/resnet-101"

# NOTE(review): feature_extractor is not referenced anywhere else in this view.
feature_extractor = AutoFeatureExtractor.from_pretrained(_RESNET_CHECKPOINT)
model = ResNetForImageClassification.from_pretrained(_RESNET_CHECKPOINT)
# Evaluation mode: BatchNorm layers use their stored running statistics
# (which update_bn overwrites) instead of per-batch statistics.
model.eval()
def print_bn(net=None):
    """Collect the running statistics of every BatchNorm2d layer into one flat list.

    For each ``nn.BatchNorm2d`` found by ``net.modules()`` (in that order), the
    list receives the layer's ``running_mean`` values, then its
    ``running_var`` values, then its ``momentum``. The total length is printed
    once as a debug aid before returning.

    Args:
        net: module to inspect; defaults to the module-level ``model``.

    Returns:
        list: flat Python list of numbers. ``momentum`` is appended as
        stored, so it is ``None`` for layers built with ``momentum=None``.
    """
    # Conditional expression: the global ``model`` is only looked up when needed.
    net = model if net is None else net
    bn_data = []
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d):
            # detach().cpu() keeps this working when the model lives on a GPU;
            # Tensor.numpy() only accepts CPU tensors.
            bn_data.extend(m.running_mean.detach().cpu().tolist())
            bn_data.extend(m.running_var.detach().cpu().tolist())
            bn_data.append(m.momentum)
    print(len(bn_data))
    return bn_data
def update_bn(image, net=None, size=(90, 90)):
    """Overwrite BatchNorm running means with pixel values taken from *image*.

    The image is resized to *size* and flattened. Its values are then copied
    sequentially into the ``running_mean`` buffer of each ``nn.BatchNorm2d``
    layer of *net*, in ``net.modules()`` order, until the pixels run out. The
    last layer reached may be only partially overwritten. Later layers keep
    their original statistics.

    Args:
        image: tensor accepted by ``torchvision.transforms.Resize``; the caller
            in this file passes a (1, C, H, W) float tensor.
        net: module whose BatchNorm layers are modified; defaults to the
            module-level ``model``.
        size: target (height, width) passed to ``Resize``.
    """
    net = model if net is None else net
    pixels = T.Resize(size)(image).reshape(-1)
    total = pixels.shape[0]
    cursor_im = 0
    with torch.no_grad():
        for m in net.modules():
            if not isinstance(m, nn.BatchNorm2d):
                continue
            if cursor_im >= total:
                break
            # Take a full channel's worth of values, or whatever is left.
            count = min(m.running_mean.shape[0], total - cursor_im)
            # copy_ writes into the existing buffer (keeping its dtype/device)
            # rather than rebinding it to a view that aliases the image tensor.
            m.running_mean[:count].copy_(pixels[cursor_im:cursor_im + count])
            # Debug trace of the slice that was just written.
            print(cursor_im, ':', cursor_im + count)
            cursor_im += count
def greet(image):
    """Gradio callback: report BatchNorm statistics, or imprint an image into them.

    Args:
        image: ``None`` when no photo was uploaded; otherwise pixel data
            convertible with ``torch.tensor``. The permute below assumes a
            channels-last (H, W, C) layout with 0-255 values.
            NOTE(review): confirm against the Gradio version in use.

    Returns:
        str: for ``None``, the values from ``print_bn()`` formatted to two
        decimals and joined by commas (a ``None`` momentum is rendered as
        ``"None"``); otherwise the fixed string ``"Hello world!"``.
    """
    if image is None:
        bn_data = print_bn()
        # Guard against BatchNorm layers built with momentum=None, which would
        # make the ".2f" format spec raise TypeError.
        return ','.join(['None' if x is None else f'{x:.2f}' for x in bn_data])

    print(type(image))
    image = torch.tensor(image).float()
    print(image.min(), image.max())
    # Scale to [0, 1] and reorder (H, W, C) -> (1, C, H, W).
    image = (image / 255.0).unsqueeze(0).permute(0, 3, 1, 2)
    update_bn(image)
    print(image.shape)
    # The forward pass's result is not used. no_grad stops autograd from
    # recording a graph (and holding its activations) on every request.
    with torch.no_grad():
        model(pixel_values=image)
    return "Hello world!"
# Upload widget; shape=(224, 224) asks Gradio to resize/crop the upload.
image = gr.inputs.Image(
    label="Upload a photo for beauty",
    shape=(224, 224),
)
# NOTE(review): out_image is created but never passed to the Interface below.
out_image = gr.inputs.Image(label='Yes, it becomes better.')

iface = gr.Interface(
    fn=greet,
    inputs=image,
    outputs='text',
)
iface.launch()