# import the essentials
from demos.foodvision_mini.model import create_vit_b_16_model
import torch
import torchvision
import time
import gradio as gr
import numpy as np
from pathlib import Path

class_names = ['pizza','steak','sushi']
# torch.cuda.is_available must be called; the bare function object is always truthy
device = 'cuda' if torch.cuda.is_available() else 'cpu'

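# create the ViT-B/16 model and load the state_dict of our trained model
vit_b_16_model, vit_b_16_transform = create_vit_b_16_model(num_classes=3)
# map_location='cpu' lets weights saved on a GPU load on a CPU-only machine;
# the model is moved to the target device inside predict()
vit_b_16_model.load_state_dict(torch.load(f='vit_b_16_20_percent_data.pth',
                                          map_location='cpu'))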

# create the predict function
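def predict(img):
  """
  Classify a single image with the trained ViT-B/16 model.

  Args:
    img: the input image (a PIL image, as passed by gr.Image(type='pil')).

  Returns:
    The predicted class name, the prediction probability, and the time
    taken (in seconds) to make the prediction.
  """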

  # transforming the image
  tr_img = vit_b_16_transform(img).unsqueeze(dim=0).to(device)

  # make prediction with vit_b_16
  model = vit_b_16_model.to(device)

  # starting the time
  start_time = time.perf_counter()

  model.eval()
  with torch.inference_mode():
    pred_logit = model(tr_img)
    pred_label = torch.argmax(pred_logit, dim=1).item()  # plain int index into class_names
    pred_prob = torch.max(torch.softmax(pred_logit,dim=1)).cpu().item()

  # ending the time
  end_time = time.perf_counter()
  # pred_dict = {str(class_names[i]):float(pred_prob[0][i].item()) for i in range(len(class_names))}
  pred_prob  = float(np.round(pred_prob,3))
  pred_class = class_names[pred_label]
  time_taken = float(np.round(end_time-start_time,3))



  return pred_class,pred_prob,time_taken


# create example list (one single-item list per example image)
example_dir = Path('demos/foodvision_mini/examples')
example_list = [[str(filepath)] for filepath in example_dir.iterdir()]

# create Gradio interface
description = 'A machine learning model that classifies images as pizza, steak, or sushi'
title = 'Image Classifier'


demo = gr.Interface(fn=predict, # this function maps the inputs to the output
                    inputs=gr.Image(type='pil'), # pillow image
                    outputs=[gr.Label(num_top_classes=1,label='Prediction'),
                             gr.Number(label='prediction probability'),
                             gr.Number(label='prediction time(s)')],
                    examples=example_list,
                    description=description,
                    title=title
                    )

demo.launch(debug=False, # print errors locally?
            share=True) # share to the public?