{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f3c9a01",
   "metadata": {},
   "source": [
    "# Is it 7 or 3?\n",
    "\n",
    "A small Gradio app that scores a hand-drawn digit as **3** or **7** with a saved linear model (matrix product with `weights`, plus `bias`, then a sigmoid).\n",
    "Users can flag predictions as *incorrect* or *ambiguous*; flagged samples are saved through Gradio's `HuggingFaceDatasetSaver` to the `simple-mnist-flagging` dataset.\n",
    "\n",
    "**Requirements**\n",
    "\n",
    "- `linear_model.pt` in the working directory: a dict holding `weights` and `bias` tensors.\n",
    "- `HF_TOKEN` set in the environment or in a local `.env` file. Never print the token or commit it with notebook outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd03eb44",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import gradio as gr\n",
    "import torch\n",
    "from dotenv import load_dotenv\n",
    "from PIL import ImageOps\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b21e4c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "load_dotenv()  # read HF_TOKEN from a local .env file, if one exists\n",
    "\n",
    "HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
    "if not HF_TOKEN:\n",
    "    # Fail fast instead of handing None to the flagging callback.\n",
    "    raise RuntimeError(\"HF_TOKEN is not set; add it to your environment or to a .env file.\")\n",
    "\n",
    "MODEL_PATH = \"linear_model.pt\"\n",
    "FLAGGING_DATASET = \"simple-mnist-flagging\"\n",
    "IMG_SIZE = 28  # input side length expected by the model (28*28 = 784 features)\n",
    "\n",
    "hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, FLAGGING_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4e8d915",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(path=MODEL_PATH):\n",
    "    \"\"\"Load the saved linear model from ``path``.\n",
    "\n",
    "    ``predict`` expects the result to be a dict with ``'weights'`` and ``'bias'`` tensors.\n",
    "    \"\"\"\n",
    "    # map_location=\"cpu\" lets tensors saved from a GPU load on CPU-only machines.\n",
    "    # NOTE: torch.load unpickles the file; only load checkpoints you trust.\n",
    "    return torch.load(path, map_location=\"cpu\")\n",
    "\n",
    "\n",
    "model = load_model()\n",
    "convert_tensor = transforms.ToTensor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e93a0f6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(img):\n",
    "    \"\"\"Score a drawn digit as 3 vs. 7.\n",
    "\n",
    "    Args:\n",
    "        img: PIL image from the paint input.\n",
    "\n",
    "    Returns:\n",
    "        Mapping of label to confidence for ``gr.Label``; the two scores are complementary.\n",
    "    \"\"\"\n",
    "    gray = ImageOps.grayscale(img).resize((IMG_SIZE, IMG_SIZE))\n",
    "    # Flatten to a single IMG_SIZE*IMG_SIZE feature vector for the linear layer.\n",
    "    features = convert_tensor(gray).view(IMG_SIZE * IMG_SIZE)\n",
    "    with torch.no_grad():  # inference only; skip building an autograd graph\n",
    "        prob_3 = (features @ model[\"weights\"] + model[\"bias\"]).sigmoid()\n",
    "    return {\"It's 3\": float(prob_3), \"It's 7\": float(1 - prob_3)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a6b7c3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "title = \"Is it 7 or 3\"\n",
    "description = \"Write a number, 7 or 3, in the middle.\"\n",
    "\n",
    "demo = gr.Interface(\n",
    "    fn=predict,\n",
    "    inputs=gr.Paint(type=\"pil\", invert_colors=True),\n",
    "    outputs=gr.Label(num_top_classes=2),\n",
    "    title=title,\n",
    "    description=description,\n",
    "    flagging_options=[\"incorrect\", \"ambiguous\"],\n",
    "    flagging_callback=hf_writer,\n",
    "    allow_flagging=\"manual\",\n",
    ")\n",
    "demo.launch();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}