# mnist / app.py
# Author: karynaur. Last commit 59ba40b ("Update app.py"); 683 bytes at that commit.
import gradio as gr
import pickle
# Load the pre-trained classifier that recognize_digit uses for predictions.
filename = 'model.sav'
# SECURITY: pickle.load can execute arbitrary code embedded in the file, so
# only load a trusted model file shipped with this app, never user uploads.
with open(filename, 'rb') as model_file:  # `with` closes the handle once loaded
    model = pickle.load(model_file)
def recognize_digit(image):
image = image.reshape(1, -1) # add a batch dimension
prediction = model.predict(image)[0]
ret = {str(i): 0 for i in range(10)}
ret[str(prediction)] = 1
return ret
# Assemble the web UI: a sketchpad input feeding recognize_digit, with a
# Label output listing the three highest-confidence digits, then serve it.
output_component = gr.outputs.Label(num_top_classes=3)
demo = gr.Interface(
    fn=recognize_digit,
    inputs="sketchpad",
    outputs=output_component,
    title="MNIST Sketchpad",
    description="Draw a number 0 through 9 on the sketchpad, and click submit to see the model's predictions. Model trained on the MNIST dataset.",
)
demo.launch()