AyushKumar5771 commited on
Commit
3901092
·
verified ·
1 Parent(s): edfc3b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # app.py
2
-
3
  import torch
4
  import torchvision.models as models
5
  from fastai.vision.all import *
@@ -9,20 +7,30 @@ import gradio as gr
9
  model_path = "model.pth"
10
 
11
  # Load the model architecture
12
- model = models.resnet34(pretrained=False) # Match the architecture you used
 
 
 
 
 
 
 
 
 
13
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # Load weights
14
  model.eval() # Set to evaluation mode
15
 
16
- # Define a simple transform function, adjust if necessary
17
  def transform(img):
18
- img = PILImage.create(img).resize((224, 224)) # Resize if needed for your model
19
- return tensor(img).unsqueeze(0) # Convert to tensor and add batch dimension
20
 
21
  # Define the prediction function
22
  def predict_image(img):
23
  img_tensor = transform(img)
24
- outputs = model(img_tensor)
25
- _, pred_idx = outputs.max(1)
 
26
  idx_to_class = {0: "Bike", 1: "Car", 2: "Cat", 3: "Dog"} # Map indices to your labels
27
  return idx_to_class[pred_idx.item()]
28
 
 
 
 
1
  import torch
2
  import torchvision.models as models
3
  from fastai.vision.all import *
 
7
  model_path = "model.pth"
8
 
9
  # Load the model architecture
10
+ class MyResNet34Model(torch.nn.Module):
11
+ def __init__(self):
12
+ super(MyResNet34Model, self).__init__()
13
+ self.model = models.resnet34(pretrained=False, num_classes=4) # Adjust num_classes as needed
14
+
15
+ def forward(self, x):
16
+ return self.model(x)
17
+
18
+ # Load the model and weights
19
+ model = MyResNet34Model()
20
  model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # Load weights
21
  model.eval() # Set to evaluation mode
22
 
23
+ # Define a simple transform function
24
  def transform(img):
25
+ img = PILImage.create(img).resize((224, 224)) # Resize for the model
26
+ return tensor(img).unsqueeze(0).float() # Convert to tensor and add batch dimension
27
 
28
  # Define the prediction function
29
  def predict_image(img):
30
  img_tensor = transform(img)
31
+ with torch.no_grad(): # Disable gradient calculation
32
+ outputs = model(img_tensor)
33
+ _, pred_idx = outputs.max(1)
34
  idx_to_class = {0: "Bike", 1: "Car", 2: "Cat", 3: "Dog"} # Map indices to your labels
35
  return idx_to_class[pred_idx.item()]
36