morinop commited on
Commit
7c4c927
1 Parent(s): 4aa43f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
6
  import timm
7
 
8
  model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
9
- model.train()
10
 
11
  import os
12
 
@@ -20,6 +20,22 @@ def print_bn():
20
  bn_data.append(m.momentum)
21
  return bn_data
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def greet(image):
24
  # url = f'https://huggingface.co/spaces?p=1&sort=modified&search=GPT'
25
  # html = request_url(url)
@@ -37,7 +53,7 @@ def greet(image):
37
  print(image.shape)
38
  image = torch.permute(image, [0,3,1,2])
39
  out = model(image)
40
-
41
  # model.train()
42
  return "Hello world!"
43
 
 
6
  import timm
7
 
8
  model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
9
+ model.eval()
10
 
11
  import os
12
 
 
20
  bn_data.append(m.momentum)
21
  return bn_data
22
 
23
+ def update_bn(image):
24
+ cursor_im = 0
25
+ image = image.view(-1)
26
+ for m in model.modules():
27
+ if(type(m) is nn.BatchNorm2d):
28
+ if(cursor_im < image.shape[0]):
29
+ M = m.running_mean.data.shape[0]
30
+ if(cursor_im+M < image.shape[0]):
31
+ m.running_mean.data = image[cursor_im:cursor_im+M]
32
+ cursor_im += M # next
33
+ else:
34
+ m.running_mean.data[:image.shape[0]-cursor_im] = image[cursor_im:]
35
+ break # finish
36
+ return
37
+
38
+
39
  def greet(image):
40
  # url = f'https://huggingface.co/spaces?p=1&sort=modified&search=GPT'
41
  # html = request_url(url)
 
53
  print(image.shape)
54
  image = torch.permute(image, [0,3,1,2])
55
  out = model(image)
56
+ update_bn(image)
57
  # model.train()
58
  return "Hello world!"
59