Blazer007 commited on
Commit
8cfd16a
·
1 Parent(s): e3902fa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import gradio as gr
4
+ from huggingface_hub import from_pretrained_keras
5
+
6
+ teacher_model = from_pretrained_keras("Blazer007/consistency_training_with_supervision_teacher_model")
7
+
8
+ student_model = from_pretrained_keras("Blazer007/consistency_training_with_supervision_student_model")
9
+
10
+ class_names = [
11
+ "Airplane",
12
+ "Automobile",
13
+ "Bird",
14
+ "Cat",
15
+ "Deer",
16
+ "Dog",
17
+ "Frog",
18
+ "Horse",
19
+ "Ship",
20
+ "Truck",
21
+ ]
22
+
23
+ IMG_SIZE = 72
24
+
25
+ def infer(input_image):
26
+ print('#$$$$$$$$$$$$$$$$$$$$$$$$$ IN INFER $$$$$$$$$$$$$$$$$$$$$$$')
27
+ # image_tensor = read_image(input_image)
28
+ image_tensor = tf.convert_to_tensor(input_image)
29
+ image_tensor.set_shape([None, None, 3])
30
+ image_tensor = tf.image.resize(image_tensor, (IMG_SIZE, IMG_SIZE))
31
+ print(image_tensor.shape)
32
+ predictions = teacher_model.predict(np.expand_dims((image_tensor), axis=0))
33
+ print(predictions)
34
+ predictions = np.squeeze(predictions)
35
+ print(predictions)
36
+ predictions = np.argmax(predictions) # , axis=2
37
+ print(predictions)
38
+ predicted_label = class_names[predictions.item()]
39
+ print(predictions.item())
40
+ print(predicted_label)
41
+ return str(predicted_label)
42
+
43
+ input = gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE))
44
+ output = [gr.outputs.Label()]
45
+ examples = [[], []]
46
+ title = "Image Classification using "
47
+ description = "Upload an image or select from examples to classify it.<br>The allowed classes are - Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck.<br><p><b>Space author: Vivek Rai</b> <br><b> Keras example author: Sayak Paul </b></p>"
48
+
49
+ gr_interface = gr.Interface(
50
+ infer,
51
+ input,
52
+ output,
53
+ examples=examples,
54
+ allow_flagging=False,
55
+ analytics_enabled=False,
56
+ title=title,
57
+ description=description).launch(enable_queue=True, debug=True)