HuggingDavid commited on
Commit
ea81160
1 Parent(s): c3ada85

Upload with huggingface_hub

Browse files
.env ADDED
@@ -0,0 +1 @@
 
 
1
+ HF_TOKEN="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"
.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
Untitled.ipynb ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "dd03eb44",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "token\n",
14
+ "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\n"
15
+ ]
16
+ },
17
+ {
18
+ "name": "stderr",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "/Users/david/Documents/python-env-test/venv/lib/python3.10/site-packages/huggingface_hub/hf_api.py:101: FutureWarning: `name` and `organization` input arguments are deprecated and will be removed in v0.10. Pass `repo_id` instead.\n",
22
+ " warnings.warn(\n",
23
+ "Cloning https://huggingface.co/datasets/HuggingDavid/simple-mnist-flagging into local empty directory.\n"
24
+ ]
25
+ },
26
+ {
27
+ "data": {
28
+ "application/vnd.jupyter.widget-view+json": {
29
+ "model_id": "38d85f20bb7d48f8934048f520b5125f",
30
+ "version_major": 2,
31
+ "version_minor": 0
32
+ },
33
+ "text/plain": [
34
+ "Download file img/tmp7qxdqjtl.png: 46%|####5 | 8.28k/18.1k [00:00<?, ?B/s]"
35
+ ]
36
+ },
37
+ "metadata": {},
38
+ "output_type": "display_data"
39
+ },
40
+ {
41
+ "data": {
42
+ "application/vnd.jupyter.widget-view+json": {
43
+ "model_id": "ea599ef7307c42c7b1f3db8e453aadc9",
44
+ "version_major": 2,
45
+ "version_minor": 0
46
+ },
47
+ "text/plain": [
48
+ "Clean file img/tmp7qxdqjtl.png: 6%|5 | 1.00k/18.1k [00:00<?, ?B/s]"
49
+ ]
50
+ },
51
+ "metadata": {},
52
+ "output_type": "display_data"
53
+ },
54
+ {
55
+ "data": {
56
+ "application/vnd.jupyter.widget-view+json": {
57
+ "model_id": "4badbc924d4a4d04b5469682a6837c9a",
58
+ "version_major": 2,
59
+ "version_minor": 0
60
+ },
61
+ "text/plain": [
62
+ "Download file img/tmpb9pmlzsj.png: 100%|##########| 15.4k/15.4k [00:00<?, ?B/s]"
63
+ ]
64
+ },
65
+ "metadata": {},
66
+ "output_type": "display_data"
67
+ },
68
+ {
69
+ "data": {
70
+ "application/vnd.jupyter.widget-view+json": {
71
+ "model_id": "665fb3a23cc843cba87cdaad930b645a",
72
+ "version_major": 2,
73
+ "version_minor": 0
74
+ },
75
+ "text/plain": [
76
+ "Clean file img/tmpb9pmlzsj.png: 7%|6 | 1.00k/15.4k [00:00<?, ?B/s]"
77
+ ]
78
+ },
79
+ "metadata": {},
80
+ "output_type": "display_data"
81
+ },
82
+ {
83
+ "name": "stdout",
84
+ "output_type": "stream",
85
+ "text": [
86
+ "Running on local URL: http://127.0.0.1:7880\n",
87
+ "\n",
88
+ "To create a public link, set `share=True` in `launch()`.\n"
89
+ ]
90
+ },
91
+ {
92
+ "data": {
93
+ "text/html": [
94
+ "<div><iframe src=\"http://127.0.0.1:7880/\" width=\"900\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
95
+ ],
96
+ "text/plain": [
97
+ "<IPython.core.display.HTML object>"
98
+ ]
99
+ },
100
+ "metadata": {},
101
+ "output_type": "display_data"
102
+ },
103
+ {
104
+ "data": {
105
+ "text/plain": [
106
+ "(<gradio.routes.App at 0x162231e40>, 'http://127.0.0.1:7880/', None)"
107
+ ]
108
+ },
109
+ "execution_count": 2,
110
+ "metadata": {},
111
+ "output_type": "execute_result"
112
+ },
113
+ {
114
+ "data": {
115
+ "application/vnd.jupyter.widget-view+json": {
116
+ "model_id": "2ecf20840bb14b4f96671ee323d83734",
117
+ "version_major": 2,
118
+ "version_minor": 0
119
+ },
120
+ "text/plain": [
121
+ "Upload file img/tmpjuysmmri.png: 100%|##########| 17.6k/17.6k [00:00<?, ?B/s]"
122
+ ]
123
+ },
124
+ "metadata": {},
125
+ "output_type": "display_data"
126
+ },
127
+ {
128
+ "name": "stderr",
129
+ "output_type": "stream",
130
+ "text": [
131
+ "remote: Scanning LFS files for validity, may be slow... \n",
132
+ "remote: LFS file scan complete. \n",
133
+ "To https://huggingface.co/datasets/HuggingDavid/simple-mnist-flagging\n",
134
+ " 4b19b7d..458cf22 main -> main\n",
135
+ "\n"
136
+ ]
137
+ }
138
+ ],
139
+ "source": [
140
+ "import torch\n",
141
+ "import gradio as gr\n",
142
+ "from torchvision import transforms\n",
143
+ "from PIL import ImageOps\n",
144
+ "import os\n",
145
+ "from dotenv import load_dotenv\n",
146
+ "\n",
147
+ "load_dotenv()\n",
148
+ "\n",
149
+ "hf_writer = gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), \"simple-mnist-flagging\")\n",
150
+ "\n",
151
+ "def load_model():\n",
152
+ " model_dict = torch.load('linear_model.pt')\n",
153
+ " return model_dict\n",
154
+ "\n",
155
+ "model = load_model()\n",
156
+ "convert_tensor = transforms.ToTensor()\n",
157
+ "\n",
158
+ "def predict(img):\n",
159
+ " img = ImageOps.grayscale(img).resize((28,28))\n",
160
+ " image_tensor = convert_tensor(img).view(28*28)\n",
161
+ " res = image_tensor @ model['weights'] + model['bias']\n",
162
+ " res = res.sigmoid()\n",
163
+ " return {\"It's 3\": float(res), \"It's 7\": float(1-res)}\n",
164
+ "\n",
165
+ "title = \"Is it 7 or 3\"\n",
166
+ "description = '<p><center>Write a number, 7 or 3, in the middle.</center></p>'\n",
167
+ "\n",
168
+ "gr.Interface(fn=predict, \n",
169
+ " inputs=gr.Paint(type=\"pil\", invert_colors=True),\n",
170
+ " outputs=gr.Label(num_top_classes=2),\n",
171
+ " title=title,\n",
172
+ " flagging_options=[\"incorrect\",\"ambiguous\"],\n",
173
+ " flagging_callback=hf_writer,\n",
174
+ " description=description,\n",
175
+ " allow_flagging='manual').launch()"
176
+ ]
177
+ }
178
+ ],
179
+ "metadata": {
180
+ "kernelspec": {
181
+ "display_name": "Python 3 (ipykernel)",
182
+ "language": "python",
183
+ "name": "python3"
184
+ },
185
+ "language_info": {
186
+ "codemirror_mode": {
187
+ "name": "ipython",
188
+ "version": 3
189
+ },
190
+ "file_extension": ".py",
191
+ "mimetype": "text/x-python",
192
+ "name": "python",
193
+ "nbconvert_exporter": "python",
194
+ "pygments_lexer": "ipython3",
195
+ "version": "3.10.6"
196
+ }
197
+ },
198
+ "nbformat": 4,
199
+ "nbformat_minor": 5
200
+ }
app.py CHANGED
@@ -2,6 +2,12 @@ import torch
2
  import gradio as gr
3
  from torchvision import transforms
4
  from PIL import ImageOps
 
 
 
 
 
 
5
 
6
  def load_model():
7
  model_dict = torch.load('linear_model.pt')
@@ -24,5 +30,7 @@ gr.Interface(fn=predict,
24
  inputs=gr.Paint(type="pil", invert_colors=True),
25
  outputs=gr.Label(num_top_classes=2),
26
  title=title,
 
 
27
  description=description,
28
- allow_flagging='never').launch()
 
2
  import gradio as gr
3
  from torchvision import transforms
4
  from PIL import ImageOps
5
+ import os
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+
10
+ hf_writer = gr.HuggingFaceDatasetSaver(os.getenv('HF_TOKEN'), "simple-mnist-flagging")
11
 
12
  def load_model():
13
  model_dict = torch.load('linear_model.pt')
 
30
  inputs=gr.Paint(type="pil", invert_colors=True),
31
  outputs=gr.Label(num_top_classes=2),
32
  title=title,
33
+ flagging_options=["incorrect","ambiguous"],
34
+ flagging_callback=hf_writer,
35
  description=description,
36
+ allow_flagging='manual').launch()
requirements.in CHANGED
@@ -1,3 +1,4 @@
1
  torch
2
  torchvision
3
- Pillow
 
 
1
  torch
2
  torchvision
3
+ Pillow
4
+ python-dotenv
requirements.txt CHANGED
@@ -2,7 +2,7 @@
2
  # This file is autogenerated by pip-compile with python 3.10
3
  # To update, run:
4
  #
5
- # pip-compile --generate-hashes requirements.in
6
  #
7
  certifi==2022.9.24 \
8
  --hash=sha256:0d9c601124e5a6ba9712dbc60d9c53c21e34f5f641fe83002317394311bdce14 \
@@ -106,8 +106,12 @@ pillow==9.2.0 \
106
  --hash=sha256:fa768eff5f9f958270b081bb33581b4b569faabf8774726b283edb06617101dc \
107
  --hash=sha256:fac2d65901fb0fdf20363fbd345c01958a742f2dc62a8dd4495af66e3ff502a4
108
  # via
109
- # -r requirements.in
110
  # torchvision
 
 
 
 
111
  requests==2.28.1 \
112
  --hash=sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983 \
113
  --hash=sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349
@@ -134,7 +138,7 @@ torch==1.12.1 \
134
  --hash=sha256:f00c721f489089dc6364a01fd84906348fe02243d0af737f944fddb36003400d \
135
  --hash=sha256:f3b52a634e62821e747e872084ab32fbcb01b7fa7dbb7471b6218279f02a178a
136
  # via
137
- # -r requirements.in
138
  # torchvision
139
  torchvision==0.13.1 \
140
  --hash=sha256:0298bae3b09ac361866088434008d82b99d6458fe8888c8df90720ef4b347d44 \
@@ -156,7 +160,7 @@ torchvision==0.13.1 \
156
  --hash=sha256:e9a563894f9fa40692e24d1aa58c3ef040450017cfed3598ff9637f404f3fe3b \
157
  --hash=sha256:ef5fe3ec1848123cd0ec74c07658192b3147dcd38e507308c790d5943e87b88c \
158
  --hash=sha256:f230a1a40ed70d51e463ce43df243ec520902f8725de2502e485efc5eea9d864
159
- # via -r requirements.in
160
  typing-extensions==4.4.0 \
161
  --hash=sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa \
162
  --hash=sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e
 
2
  # This file is autogenerated by pip-compile with python 3.10
3
  # To update, run:
4
  #
5
+ # pip-compile --generate-hashes --output-file=gradio-app/requirements.txt gradio-app/requirements.in
6
  #
7
  certifi==2022.9.24 \
8
  --hash=sha256:0d9c601124e5a6ba9712dbc60d9c53c21e34f5f641fe83002317394311bdce14 \
 
106
  --hash=sha256:fa768eff5f9f958270b081bb33581b4b569faabf8774726b283edb06617101dc \
107
  --hash=sha256:fac2d65901fb0fdf20363fbd345c01958a742f2dc62a8dd4495af66e3ff502a4
108
  # via
109
+ # -r gradio-app/requirements.in
110
  # torchvision
111
+ python-dotenv==0.21.0 \
112
+ --hash=sha256:1684eb44636dd462b66c3ee016599815514527ad99965de77f43e0944634a7e5 \
113
+ --hash=sha256:b77d08274639e3d34145dfa6c7008e66df0f04b7be7a75fd0d5292c191d79045
114
+ # via -r gradio-app/requirements.in
115
  requests==2.28.1 \
116
  --hash=sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983 \
117
  --hash=sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349
 
138
  --hash=sha256:f00c721f489089dc6364a01fd84906348fe02243d0af737f944fddb36003400d \
139
  --hash=sha256:f3b52a634e62821e747e872084ab32fbcb01b7fa7dbb7471b6218279f02a178a
140
  # via
141
+ # -r gradio-app/requirements.in
142
  # torchvision
143
  torchvision==0.13.1 \
144
  --hash=sha256:0298bae3b09ac361866088434008d82b99d6458fe8888c8df90720ef4b347d44 \
 
160
  --hash=sha256:e9a563894f9fa40692e24d1aa58c3ef040450017cfed3598ff9637f404f3fe3b \
161
  --hash=sha256:ef5fe3ec1848123cd0ec74c07658192b3147dcd38e507308c790d5943e87b88c \
162
  --hash=sha256:f230a1a40ed70d51e463ce43df243ec520902f8725de2502e485efc5eea9d864
163
+ # via -r gradio-app/requirements.in
164
  typing-extensions==4.4.0 \
165
  --hash=sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa \
166
  --hash=sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e