Sifal commited on
Commit
37ad7bc
·
verified ·
1 Parent(s): e7f48d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from utils.inference_utils import preprocess_image, predict
4
+ from utils.train_utils import initialize_model
5
+ from utils.data import CLASS_NAMES
6
+
7
+ # Load the model once during app initialization
8
+ model_name = "resnet"
9
+ model_weights = "./pokemon_resnet.pth"
10
+ num_classes = 150
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ # Initialize and load the model
14
+ model = initialize_model(model_name, num_classes).to(device)
15
+ model.load_state_dict(torch.load(model_weights, map_location=device))
16
+ model.eval() # Set the model to evaluation mode
17
+
18
+
19
+ def classify_image(image):
20
+ """Function to preprocess the image and classify it."""
21
+ try:
22
+ # Preprocess the uploaded image
23
+ image_tensor = preprocess_image(image, (224, 224)).to(device)
24
+
25
+ # Perform inference
26
+ preds = torch.max(predict(model, image_tensor), 1)[1]
27
+ predicted_class = CLASS_NAMES[preds.item()]
28
+
29
+ return f"Predicted class: {predicted_class}"
30
+
31
+ except Exception as e:
32
+ return f"Error: {str(e)}"
33
+
34
+
35
+ # Create a Gradio interface
36
+ demo = gr.Interface(
37
+ fn=classify_image,
38
+ inputs=gr.inputs.Image(type="pil", label="Upload Image"),
39
+ outputs="text",
40
+ title="Pokemon Classifier",
41
+ description="Upload an image of a Pokemon, and the model will predict its class.",
42
+ )
43
+
44
+ if __name__ == "__main__":
45
+ # Launch the Gradio app
46
+ demo.launch()