Spaces:
Build error
Build error
AyushKumar5771
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
# app.py
|
2 |
-
|
3 |
import torch
|
4 |
import torchvision.models as models
|
5 |
from fastai.vision.all import *
|
@@ -9,20 +7,30 @@ import gradio as gr
|
|
9 |
model_path = "model.pth"
|
10 |
|
11 |
# Load the model architecture
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # Load weights
|
14 |
model.eval() # Set to evaluation mode
|
15 |
|
16 |
-
# Define a simple transform function
|
17 |
def transform(img):
|
18 |
-
img = PILImage.create(img).resize((224, 224)) # Resize
|
19 |
-
return tensor(img).unsqueeze(0) # Convert to tensor and add batch dimension
|
20 |
|
21 |
# Define the prediction function
|
22 |
def predict_image(img):
|
23 |
img_tensor = transform(img)
|
24 |
-
|
25 |
-
|
|
|
26 |
idx_to_class = {0: "Bike", 1: "Car", 2: "Cat", 3: "Dog"} # Map indices to your labels
|
27 |
return idx_to_class[pred_idx.item()]
|
28 |
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import torchvision.models as models
|
3 |
from fastai.vision.all import *
|
|
|
7 |
model_path = "model.pth"
|
8 |
|
9 |
# Load the model architecture
|
10 |
+
class MyResNet34Model(torch.nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super(MyResNet34Model, self).__init__()
|
13 |
+
self.model = models.resnet34(pretrained=False, num_classes=4) # Adjust num_classes as needed
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
return self.model(x)
|
17 |
+
|
18 |
+
# Load the model and weights
|
19 |
+
model = MyResNet34Model()
|
20 |
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # Load weights
|
21 |
model.eval() # Set to evaluation mode
|
22 |
|
23 |
+
# Define a simple transform function
|
24 |
def transform(img):
|
25 |
+
img = PILImage.create(img).resize((224, 224)) # Resize for the model
|
26 |
+
return tensor(img).unsqueeze(0).float() # Convert to tensor and add batch dimension
|
27 |
|
28 |
# Define the prediction function
|
29 |
def predict_image(img):
|
30 |
img_tensor = transform(img)
|
31 |
+
with torch.no_grad(): # Disable gradient calculation
|
32 |
+
outputs = model(img_tensor)
|
33 |
+
_, pred_idx = outputs.max(1)
|
34 |
idx_to_class = {0: "Bike", 1: "Car", 2: "Cat", 3: "Dog"} # Map indices to your labels
|
35 |
return idx_to_class[pred_idx.item()]
|
36 |
|