panik's picture
Update app.py
b96eb2c
raw
history blame contribute delete
903 Bytes
import gradio as gr
import tensorflow as tf
import keras
from keras.datasets import mnist
import matplotlib.pyplot as plt
import random
# Load MNIST through Keras as (images, labels) pairs for the train and test
# splits; sample_digit below draws only from the training pair.
_mnist_train, _mnist_test = mnist.load_data()
train_images, train_labels = _mnist_train
test_images, test_labels = _mnist_test
def random_index_of(labels, digit):
    """Return a uniformly random index ``i`` such that ``labels[i] == digit``.

    Args:
        labels: Sequence of class labels (e.g. the MNIST training labels).
        digit: Label value to look for. Compared with ``==``, so a float such
            as ``5.0`` matches integer labels equal to 5.

    Returns:
        An index into ``labels`` chosen uniformly among all matching entries.

    Raises:
        ValueError: If ``digit`` does not occur in ``labels``.
    """
    matches = [i for i, label in enumerate(labels) if label == digit]
    if not matches:
        # Fail loudly instead of searching forever for a label that is absent.
        raise ValueError(f"digit {digit!r} does not occur in the labels")
    return random.choice(matches)


def sample_digit(digit):
    """Pick a random training image showing ``digit`` and plot it.

    Args:
        digit: The digit to sample, as supplied by the Gradio slider (0-9).

    Returns:
        A ``(figure, caption)`` tuple: a matplotlib figure displaying the
        chosen image in the ``binary`` colormap, and a caption naming its
        index, e.g. ``"train_images[123]"``.

    Raises:
        ValueError: If no training label equals ``digit``.
    """
    # Use a standalone Figure instead of pyplot: pyplot registers every figure
    # it creates globally, and these figures are never closed, so a
    # long-running app would accumulate one figure per request.
    from matplotlib.figure import Figure

    # Draw uniformly over every matching training index, so no index
    # (such as 0) is favoured.
    index = random_index_of(train_labels, digit)
    fig = Figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(train_images[index], cmap="binary")
    return fig, "train_images[%d]" % index
# Gradio UI: a 0-9 integer slider in; a matplotlib plot and a text caption out.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces (deprecated in 3.x, removed in 4.0) - pin the Gradio version or
# migrate to gr.Slider / gr.Plot when upgrading.
digit_slider = gr.inputs.Slider(minimum=0, maximum=9, step=1)
plot_and_caption = [gr.outputs.Image(type='plot'), 'text']

iface = gr.Interface(
    fn=sample_digit,
    inputs=[digit_slider],
    outputs=plot_and_caption,
    title='MNIST Digit Sampler',
    description='Pick a random digit from the MNIST dataset',
)
iface.launch()