morinop commited on
Commit
a254dcb
·
1 Parent(s): 2b6e66a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -4,10 +4,11 @@ import torch
4
  import torch.nn as nn
5
  import torchvision
6
  import torchvision.transforms as T
7
-
8
  import timm
9
 
10
- model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
 
11
  model.eval()
12
 
13
  import os
@@ -20,11 +21,12 @@ def print_bn():
20
  bn_data.extend(m.running_mean.data.numpy().tolist())
21
  bn_data.extend(m.running_var.data.numpy().tolist())
22
  bn_data.append(m.momentum)
 
23
  return bn_data
24
 
25
  def update_bn(image):
26
  cursor_im = 0
27
- image = T.Resize((40,40))(image)
28
  image = image.reshape(-1)
29
  for m in model.modules():
30
  if(type(m) is nn.BatchNorm2d):
@@ -32,11 +34,11 @@ def update_bn(image):
32
  M = m.running_mean.data.shape[0]
33
  if(cursor_im+M < image.shape[0]):
34
  m.running_mean.data = image[cursor_im:cursor_im+M]
35
- cursor_im += M # next
36
  print(cursor_im,':',cursor_im+M)
37
  else:
38
  m.running_mean.data[:image.shape[0]-cursor_im] = image[cursor_im:]
39
- break # finish
40
  return
41
 
42
 
@@ -53,7 +55,7 @@ def greet(image):
53
  image = torch.permute(image, [0,3,1,2])
54
  update_bn(image)
55
  print(image.shape)
56
- out = model(image)
57
  return "Hello world!"
58
 
59
 
 
4
  import torch.nn as nn
5
  import torchvision
6
  import torchvision.transforms as T
7
+ from transformers import AutoFeatureExtractor, ResNetForImageClassification
8
  import timm
9
 
10
+ feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-101")
11
+ model = ResNetForImageClassification.from_pretrained("microsoft/resnet-101")
12
  model.eval()
13
 
14
  import os
 
21
  bn_data.extend(m.running_mean.data.numpy().tolist())
22
  bn_data.extend(m.running_var.data.numpy().tolist())
23
  bn_data.append(m.momentum)
24
+ print(len(bn_data))
25
  return bn_data
26
 
27
  def update_bn(image):
28
  cursor_im = 0
29
+ image = T.Resize((100,100))(image)
30
  image = image.reshape(-1)
31
  for m in model.modules():
32
  if(type(m) is nn.BatchNorm2d):
 
34
  M = m.running_mean.data.shape[0]
35
  if(cursor_im+M < image.shape[0]):
36
  m.running_mean.data = image[cursor_im:cursor_im+M]
37
+ cursor_im += M
38
  print(cursor_im,':',cursor_im+M)
39
  else:
40
  m.running_mean.data[:image.shape[0]-cursor_im] = image[cursor_im:]
41
+ break
42
  return
43
 
44
 
 
55
  image = torch.permute(image, [0,3,1,2])
56
  update_bn(image)
57
  print(image.shape)
58
+ out = model(pixel_values=image)
59
  return "Hello world!"
60
 
61