File size: 2,404 Bytes
cb70b4c
249653c
3eeeced
48b3a1c
7a5b6f5
 
a254dcb
3eeeced
 
a254dcb
 
7c4c927
8503f33
3eeeced
115b639
9e8b807
115b639
 
 
 
 
 
d52eb29
a254dcb
6882504
dd6daed
115b639
 
7c4c927
9e8b807
7c4c927
562ce81
079f211
7c4c927
 
 
 
 
 
a254dcb
7b7b81a
7c4c927
 
a254dcb
7c4c927
 
 
fffc3a2
115b639
d1879c0
b872349
d1879c0
6882504
 
115b639
 
 
 
 
f17ecc5
11dac02
67adcd0
a254dcb
5264799
cb70b4c
1f780c8
 
2b6e66a
 
 
cb70b4c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
import requests
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from transformers import AutoFeatureExtractor, ResNetForImageClassification
import timm

# Load a pretrained ResNet-101 image classifier once at import time.
# `model` is module-global state: the functions below read AND mutate its
# BatchNorm buffers in place, which is the whole point of this demo.
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-101")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-101")
# eval() freezes BN so inference uses the (possibly overwritten) running stats.
model.eval()
import os 

def print_bn():
    """Collect the running statistics of every BatchNorm2d layer in `model`.

    Returns:
        A flat list containing, for each ``nn.BatchNorm2d`` module in
        module-traversal order: all ``running_mean`` values, then all
        ``running_var`` values, then the layer's ``momentum``.

    Side effects:
        Prints the total number of collected values.
    """
    bn_data = []
    for m in model.modules():
        # isinstance is the idiomatic, subclass-safe type check
        # (the original used `type(m) is nn.BatchNorm2d`).
        if isinstance(m, nn.BatchNorm2d):
            bn_data.extend(m.running_mean.data.numpy().tolist())
            bn_data.extend(m.running_var.data.numpy().tolist())
            bn_data.append(m.momentum)
    print(len(bn_data))
    return bn_data

def update_bn(image):
    """Overwrite BatchNorm running means in `model` with pixel values.

    The input tensor is resized to 90x90, flattened, and its values are
    written sequentially into each ``nn.BatchNorm2d`` layer's
    ``running_mean`` buffer until the flattened image is exhausted.

    Args:
        image: tensor accepted by ``torchvision.transforms.Resize``
            (presumably NCHW float from `greet` — TODO confirm).

    Side effects:
        Mutates `model` in place; prints each written index range.
    """
    cursor_im = 0
    image = T.Resize((90, 90))(image)
    image = image.reshape(-1)
    for m in model.modules():
        if not isinstance(m, nn.BatchNorm2d):
            continue
        if cursor_im >= image.shape[0]:
            # Whole image consumed — no point scanning remaining layers
            # (the original kept looping with a dead guard).
            break
        M = m.running_mean.data.shape[0]
        if cursor_im + M < image.shape[0]:
            # clone(): the original stored a *view* of `image`, so several BN
            # buffers aliased slices of one temporary tensor.
            m.running_mean.data = image[cursor_im:cursor_im + M].clone()
            # Print the range just written (the original printed after the
            # increment, reporting a range shifted by M).
            print(cursor_im, ':', cursor_im + M)
            cursor_im += M
        else:
            # Tail: fewer values remain than this layer's width — partial fill.
            m.running_mean.data[:image.shape[0] - cursor_im] = image[cursor_im:]
            break
    return
    

def greet(image):
    """Gradio handler.

    With no upload, serialize the current BatchNorm statistics as a
    comma-separated string. With an image, normalize it, push its pixels
    into the BN running means via `update_bn`, run the model once, and
    return a fixed greeting.
    """
    # Guard clause: no image uploaded -> dump BN stats as CSV text.
    if image is None:
        stats = print_bn()
        return ','.join([f'{x:.2f}' for x in stats])

    # An image arrived: convert HWC uint8 -> NCHW float in [0, 1].
    print(type(image))
    pixels = torch.tensor(image).float()
    print(pixels.min(), pixels.max())
    pixels = (pixels / 255.0).unsqueeze(0)
    pixels = torch.permute(pixels, [0, 3, 1, 2])
    update_bn(pixels)
    print(pixels.shape)
    out = model(pixel_values=pixels)
    return "Hello world!"



# NOTE(review): `gr.inputs.*` and the `shape=` argument are the legacy
# Gradio 1.x/2.x API, removed in Gradio 3+ — confirm the pinned gradio version.
image = gr.inputs.Image(label="Upload a photo for beauty", shape=(224,224))
# NOTE(review): `out_image` is never passed to the Interface below
# (outputs is the string 'text') — looks like dead code; verify intent.
out_image = gr.inputs.Image(label='Yes, it becomes better.')
iface = gr.Interface(fn=greet, inputs=image, outputs='text')
iface.launch()