AyushKumar5771's picture
Update app.py
3901092 verified
raw
history blame
1.38 kB
import torch
import torchvision.models as models
from fastai.vision.all import *
import gradio as gr
# Location of the saved weights file that is passed to torch.load below.
# The path is relative, so it is resolved against the process's working directory.
model_path = "model.pth"
# Network definition: a ResNet-34 held inside a small wrapper module
class MyResNet34Model(torch.nn.Module):
    """Wrapper module around a torchvision ResNet-34 with a 4-output head.

    The network is stored on the ``model`` attribute, so the parameters of
    this module appear in its state dict under the ``model.`` prefix.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised network; trained weights are loaded separately.
        backbone = models.resnet34(pretrained=False, num_classes=4)
        self.model = backbone

    def forward(self, x):
        """Pass ``x`` through the wrapped network and return its output."""
        return self.model(x)
# Build the network, restore its saved parameters on the CPU, and freeze it
# in inference mode (affects layers such as batch norm and dropout).
_state_dict = torch.load(model_path, map_location="cpu")
model = MyResNet34Model()
model.load_state_dict(_state_dict)
model.eval()
# Convert an input image into the batched, channels-first tensor the model consumes
def transform(img):
    """Return ``img`` as a float tensor of shape ``(1, 3, 224, 224)``.

    Args:
        img: Any value accepted by ``PILImage.create``.

    Returns:
        A ``torch.float32`` tensor with a leading batch dimension of size 1,
        laid out as (batch, channel, height, width).
    """
    # Force three channels so grayscale or RGBA inputs do not break the
    # 3-channel first convolution, then resize to the model's input size.
    pil_img = PILImage.create(img).convert("RGB").resize((224, 224))

    # tensor() on a PIL image yields a channels-last (H, W, C) array, whereas
    # convolution layers expect channels-first (C, H, W); reorder before batching.
    chw = tensor(pil_img).permute(2, 0, 1)

    # NOTE(review): pixel values are left in the 0-255 range, as before. If the
    # weights were trained on scaled/normalised inputs, apply the same
    # preprocessing here -- confirm against the training pipeline.
    return chw.unsqueeze(0).float()
# Classify a single image and report the winning label
def predict_image(img):
    """Return the class label the model scores highest for ``img``.

    The image is preprocessed with ``transform``, run through ``model``
    without gradient tracking, and the index of the largest output score is
    mapped to its label name.
    """
    class_names = {0: "Bike", 1: "Car", 2: "Cat", 3: "Dog"}

    batch = transform(img)
    with torch.no_grad():
        scores = model(batch)

    # Index of the largest score along the class dimension.
    best = torch.max(scores, 1).indices
    return class_names[best.item()]
# Wire the classifier into a Gradio UI (image in, label out) and start serving it
iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="label",
    description="Upload an image to classify!",
)
iface.launch()