File size: 2,118 Bytes
cb70b4c
249653c
3eeeced
48b3a1c
7a5b6f5
 
249653c
3eeeced
 
 
7c4c927
fffc3a2
8503f33
3eeeced
115b639
 
 
 
 
 
 
d52eb29
115b639
 
7c4c927
 
2aaad1e
f17ecc5
7a5b6f5
079f211
2aaad1e
7c4c927
 
 
 
 
 
 
7b7b81a
7c4c927
 
 
 
 
 
fffc3a2
3eeeced
 
 
fffc3a2
115b639
d1879c0
4aa43f9
d1879c0
115b639
 
 
 
 
f17ecc5
11dac02
67adcd0
115b639
fffc3a2
5264799
cb70b4c
1f780c8
 
fffc3a2
 
cb70b4c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
import requests
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

import timm

# Load the `nateraw/resnet18-random` checkpoint from the Hugging Face Hub and
# put it in inference mode. Module.eval() returns the module itself, so the
# call can be chained.
model = timm.create_model(
    "hf_hub:nateraw/resnet18-random",
    pretrained=True,
).eval()

import os 

def print_bn(module=None):
    """Collect the running statistics of every ``nn.BatchNorm2d`` layer.

    Despite its name, this function prints nothing. It returns a flat list
    that ``greet`` formats for display.

    Layers are visited in ``module.modules()`` order, and only those whose
    type is exactly ``nn.BatchNorm2d`` are used (subclasses do not match).
    For each one, the list receives:

    - all ``running_mean`` values,
    - then all ``running_var`` values,
    - then the layer's ``momentum`` attribute, appended as-is.

    Layers built with ``track_running_stats=False`` have no running buffers
    and are skipped.

    Args:
        module: Network to inspect. Defaults to the global ``model``.

    Returns:
        A flat ``list`` of the values described above.
    """
    net = model if module is None else module
    bn_data = []
    for m in net.modules():
        if type(m) is not nn.BatchNorm2d:
            continue
        # track_running_stats=False leaves these buffers as None.
        if m.running_mean is None or m.running_var is None:
            continue
        # detach().cpu() keeps this working when the model lives on a GPU;
        # Tensor.tolist() yields plain Python floats like numpy().tolist().
        bn_data.extend(m.running_mean.detach().cpu().tolist())
        bn_data.extend(m.running_var.detach().cpu().tolist())
        bn_data.append(m.momentum)
    return bn_data

def update_bn(image, module=None, size=(40, 40)):
    """Overwrite BatchNorm running means with the pixel values of ``image``.

    How it works:

    1. ``image`` is resized to ``size`` with torchvision's ``T.Resize`` and
       flattened.
    2. The flat values are consumed in consecutive chunks. Each layer whose
       type is exactly ``nn.BatchNorm2d`` (visited in ``module.modules()``
       order) gets its ``running_mean`` overwritten by the next
       ``running_mean.shape[0]`` values.
    3. When the remaining values fit within a layer (the same number or
       fewer than it holds), they overwrite that layer's leading entries
       and the walk stops.

    Layers after the stopping point, and layers without running stats
    (``track_running_stats=False``), are left untouched.

    Args:
        image: Tensor accepted by ``T.Resize``. ``greet`` passes a
            rank-4 float tensor with a leading batch axis of size 1.
        module: Network to modify. Defaults to the global ``model``.
        size: Target ``(height, width)`` passed to ``T.Resize``.

    Returns:
        None. The BatchNorm buffers are modified in place.
    """
    net = model if module is None else module
    print("Before Resize:", image.shape)
    flat = T.Resize(size)(image).reshape(-1)
    print("After Resize:", flat.shape)

    total = flat.shape[0]
    cursor = 0
    with torch.no_grad():
        for m in net.modules():
            # Exact type match, as in print_bn: BatchNorm2d subclasses are skipped.
            if type(m) is not nn.BatchNorm2d or m.running_mean is None:
                continue
            width = m.running_mean.shape[0]
            end = cursor + width
            if end < total:
                # copy_ writes into the buffer's own storage (casting dtype and
                # device as needed) instead of aliasing it to the image tensor.
                m.running_mean.copy_(flat[cursor:end])
                # Log the range that was just written, not the next one.
                print(cursor, ':', end)
                cursor = end
            else:
                # The remaining values fit in this layer: fill its prefix, stop.
                m.running_mean[:total - cursor] = flat[cursor:]
                break
    return
    

def greet(image):
    """Gradio callback: dump BatchNorm statistics, or write an image into them.

    Args:
        image: Value from the Gradio image input, or ``None`` when the form
            is submitted without an image. Presumably an ``H x W x C`` array
            of 0-255 values (it is divided by 255 and its last axis is
            treated as channels below) — TODO confirm against the pinned
            Gradio version.

    Returns:
        - With no image: every value from ``print_bn()``, each formatted to
          10 decimal places and joined with commas.
        - With an image: the fixed string ``"Hello world!"``.

    NOTE(review): ``update_bn`` writes a 40x40 resized copy of the uploaded
    image into the running means of the module-level ``model``. A later
    call with no image returns those running means via ``print_bn``. That
    makes one caller's upload readable by any subsequent caller — confirm
    this is intended.
    """
    if image is None:
        bn_data = print_bn()
        return ','.join([f'{x:.10f}' for x in bn_data])

    print(type(image))
    image = torch.tensor(image).float()
    print(image.min(), image.max())
    image = image / 255.0
    # Add a batch axis, then move the last axis to position 1
    # (H x W x C -> 1 x C x H x W under the layout assumed above).
    image = image.unsqueeze(0)
    image = torch.permute(image, [0, 3, 1, 2])
    update_bn(image)
    print(image.shape)
    # The forward output is never used. no_grad avoids building an autograd
    # graph that would only be discarded.
    with torch.no_grad():
        model(image)
    return "Hello world!"



# Build a single-image Gradio UI whose string result is shown as text, then
# start the server.
# NOTE(review): `gr.inputs` is Gradio's legacy component namespace; confirm
# the pinned Gradio version still provides it (and its `shape=` argument).
image = gr.inputs.Image(
    shape=(224, 224),
    label="Upload a photo for beautify",
)
iface = gr.Interface(
    outputs="text",
    inputs=image,
    fn=greet,
)
iface.launch()