HuggingDavid commited on
Commit
0f7c18e
1 Parent(s): ea81160

Upload with huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +33 -10
  2. model.pt +3 -0
app.py CHANGED
@@ -5,32 +5,55 @@ from PIL import ImageOps
5
  import os
6
  from dotenv import load_dotenv
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  load_dotenv()
9
 
10
  hf_writer = gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), "simple-mnist-flagging")
11
 
12
  def load_model():
13
- model_dict = torch.load('linear_model.pt')
14
- return model_dict
 
 
15
 
16
  model = load_model()
17
  convert_tensor = transforms.ToTensor()
18
 
19
  def predict(img):
20
  img = ImageOps.grayscale(img).resize((28,28))
21
- image_tensor = convert_tensor(img).view(28*28)
22
- res = image_tensor @ model['weights'] + model['bias']
23
- res = res.sigmoid()
24
- return {"It's 3": float(res), "It's 7": float(1-res)}
25
 
26
- title = "Is it 7 or 3"
27
- description = '<p><center>Write a number, 7 or 3, in the middle.</center></p>'
28
 
29
  gr.Interface(fn=predict,
30
  inputs=gr.Paint(type="pil", invert_colors=True),
31
- outputs=gr.Label(num_top_classes=2),
32
  title=title,
33
  flagging_options=["incorrect","ambiguous"],
34
  flagging_callback=hf_writer,
35
  description=description,
36
- allow_flagging='manual').launch()
 
 
5
  import os
6
  from dotenv import load_dotenv
7
 
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+
11
+ class SimpleLenet(nn.Module):
12
+ def __init__(self, args=None):
13
+ super().__init__()
14
+ self.conv1 = nn.Conv2d(1, 6, 5, padding=2) # -> 6 channels, 28x28
15
+ self.pool = nn.MaxPool2d(2) # -> 6 channels, 14x14
16
+ self.conv2 = nn.Conv2d(6, 120, 14) #-> 120 channels, 1x1
17
+ self.fc1 = nn.Linear(120, 10)
18
+ self.fc2 = nn.Linear(10, 10)
19
+
20
+ def __call__(self, x):
21
+ xx = F.relu(self.conv1(x))
22
+ xx = F.relu(self.pool(xx))
23
+ xx = F.relu(self.conv2(xx))
24
+ xx = xx.flatten(1)
25
+ xx = F.relu(self.fc1(xx))
26
+ return self.fc2(xx)
27
+
28
  load_dotenv()
29
 
30
  hf_writer = gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), "simple-mnist-flagging")
31
 
32
  def load_model():
33
+ model = SimpleLenet()
34
+ model.load_state_dict(torch.load('model.pt'))
35
+ model.eval()
36
+ return model
37
 
38
  model = load_model()
39
  convert_tensor = transforms.ToTensor()
40
 
41
  def predict(img):
42
  img = ImageOps.grayscale(img).resize((28,28))
43
+ image_tensor = convert_tensor(img).view(1, 1, 28, 28)
44
+ logits = model(image_tensor)
45
+ pred = torch.argmax(logits, dim=1)
46
+ return pred.tolist()[0]
47
 
48
+ title = "Handwritten digit recognition"
49
+ description = '<p><center>Write a single digit in the middle of the canvas</center></p>'
50
 
51
  gr.Interface(fn=predict,
52
  inputs=gr.Paint(type="pil", invert_colors=True),
53
+ outputs="text",
54
  title=title,
55
  flagging_options=["incorrect","ambiguous"],
56
  flagging_callback=hf_writer,
57
  description=description,
58
+ allow_flagging='manual').launch()
59
+
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:922abc05756421cbdffc98dc27a3eef6abf476761170e77656e7bb5477a45b10
3
+ size 573263