File size: 1,979 Bytes
cb70b4c
249653c
3eeeced
48b3a1c
7a5b6f5
 
249653c
3eeeced
 
 
7c4c927
fffc3a2
8503f33
3eeeced
115b639
 
 
 
 
 
 
d52eb29
115b639
 
7c4c927
 
a800e1f
7a5b6f5
7c4c927
 
 
 
 
 
 
 
 
 
 
 
 
fffc3a2
3eeeced
 
 
fffc3a2
115b639
d1879c0
4aa43f9
d1879c0
115b639
 
 
 
 
67adcd0
4de4ec7
115b639
7c4c927
fffc3a2
5264799
cb70b4c
1f780c8
 
fffc3a2
 
cb70b4c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
import requests
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

import timm

# Build the "hf_hub:nateraw/resnet18-random" model through timm with
# pretrained=True and switch it to inference mode. nn.Module.eval() returns the
# module itself, so it can be chained directly onto the constructor call.
model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True).eval()

import os 

def print_bn(net=None):
    """Collect BatchNorm2d statistics from a model as one flat list.

    Despite its name, this function prints nothing; it returns the values so
    the caller can format them.

    For every module whose type is exactly ``nn.BatchNorm2d`` (subclasses are
    not matched), visited in ``net.modules()`` order, the list receives that
    layer's ``running_mean`` values, then its ``running_var`` values, then its
    ``momentum``. Layers created with ``track_running_stats=False`` have no
    running buffers and are skipped.

    Args:
        net: Module to inspect. Defaults to the module-level ``model``.

    Returns:
        list: Python floats, plus each layer's ``momentum`` exactly as stored
        on the layer.
    """
    if net is None:
        net = model
    bn_data = []
    for m in net.modules():
        if type(m) is not nn.BatchNorm2d or m.running_mean is None:
            continue
        # detach().cpu().tolist() instead of .data.numpy().tolist(): same
        # values, but also works when the buffers live on a GPU.
        bn_data.extend(m.running_mean.detach().cpu().tolist())
        bn_data.extend(m.running_var.detach().cpu().tolist())
        bn_data.append(m.momentum)
    return bn_data

def update_bn(image, net=None, size=(40, 40)):
    """Overwrite BatchNorm2d running means with pixel values from an image.

    The image is resized to ``size`` and flattened. The resulting values are
    written, in order, into the ``running_mean`` buffer of each module whose
    type is exactly ``nn.BatchNorm2d`` (in ``net.modules()`` order).

    - Writing stops once the pixels run out.
    - The layer being filled at that point receives only a prefix.
    - Later layers are left untouched.
    - Pixels beyond the total buffer capacity are ignored.
    - Layers without running stats (``track_running_stats=False``) are skipped.

    Args:
        image: Tensor shaped (C, H, W) or (N, C, H, W).
        net: Module whose buffers are modified in place. Defaults to the
            module-level ``model``.
        size: (height, width) to resize the image to before flattening.
            NOTE(review): the default 3 * 40 * 40 = 4800 values presumably
            matches the combined running_mean length of the model's
            BatchNorm2d layers for a 3-channel image. Confirm.

    Returns:
        None. ``net`` is modified in place.
    """
    if net is None:
        net = model
    batched = image.unsqueeze(0) if image.dim() == 3 else image
    # Resize BEFORE flattening: a spatial resize needs the (N, C, H, W) layout.
    # Bilinear with align_corners=False is torchvision's default Resize
    # interpolation for tensors.
    resized = torch.nn.functional.interpolate(
        batched.float(), size=tuple(size), mode="bilinear", align_corners=False
    )
    pixels = resized.reshape(-1)
    total = pixels.numel()
    cursor = 0
    with torch.no_grad():
        for m in net.modules():
            if type(m) is not nn.BatchNorm2d or m.running_mean is None:
                continue
            if cursor >= total:
                break  # every pixel has been placed
            count = min(m.running_mean.numel(), total - cursor)
            # copy_ writes into the existing buffer (converting dtype/device
            # as needed) rather than rebinding it to a view of the image.
            m.running_mean[:count].copy_(pixels[cursor:cursor + count])
            cursor += count
    

def greet(image):
    """Gradio handler: dump BatchNorm statistics, or imprint an image into them.

    Args:
        image: ``None`` to request the statistics dump; otherwise the image
            array from the Gradio input. It is assumed to be H x W x C with
            values in [0, 255] (TODO confirm against the pinned gradio
            version).

    Returns:
        str: With no image, the values from ``print_bn()`` joined by commas,
        each formatted to 10 decimal places. Otherwise ``"Hello world!"``,
        returned after ``update_bn`` has rewritten the running means.
    """
    if image is None:
        bn_data = print_bn()
        return ','.join([f'{x:.10f}' for x in bn_data])

    # HWC values in [0, 255] -> NCHW float in [0, 1].
    tensor = torch.tensor(image).float() / 255.0
    tensor = tensor.unsqueeze(0).permute(0, 3, 1, 2)
    # The forward output is unused. Run it without autograd so no graph is
    # built over the model's parameters.
    with torch.no_grad():
        model(tensor)
    update_bn(tensor)
    return "Hello world!"



# Build the Gradio UI: a single image input wired to greet(), with text output.
# NOTE(review): `gr.inputs.Image(..., shape=...)` is the legacy Gradio input
# namespace; confirm it exists in the pinned gradio version.
image = gr.inputs.Image(label="Upload a photo for beautify", shape=(224,224))
iface = gr.Interface(fn=greet, inputs=image, outputs="text")
# Called at module level with no __main__ guard, so importing this module
# launches the app.
iface.launch()