Spaces:
Build error
Build error
import torch | |
import torchvision.models as models | |
from fastai.vision.all import * | |
import gradio as gr | |
# Relative path of the saved weights file; because it is relative it is
# resolved against the process's current working directory when loaded.
model_path: str = "model.pth"
# Model architecture used by this app.
class MyResNet34Model(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a randomly initialised ResNet-34.

    The torchvision backbone is stored in the ``model`` attribute, so every
    key of this module's ``state_dict`` is prefixed with ``"model."``; any
    checkpoint loaded into it must use the same prefix.

    Args:
        num_classes: number of outputs of the final fully-connected layer.
            Defaults to 4, the value this app's label mapping expects.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # No ``pretrained`` / ``weights`` argument: both the legacy and the
        # current torchvision APIs default to random initialisation, and
        # omitting it avoids the ``pretrained`` deprecation warning emitted by
        # torchvision >= 0.13. Trained weights are restored afterwards with
        # ``load_state_dict``.
        self.model = models.resnet34(num_classes=num_classes)

    def forward(self, x):
        """Return the backbone's raw (un-normalised) class scores for ``x``.

        ``x`` is a batch of 3-channel images in channels-first layout
        ``(N, 3, H, W)``, as required by ResNet's first convolution.
        """
        return self.model(x)
# Build the (randomly initialised) architecture, then overwrite its
# parameters with the trained weights read from ``model_path``.
# ``map_location="cpu"`` lets a checkpoint saved on a GPU load on a CPU-only
# host.
# NOTE(review): MyResNet34Model keeps the backbone in ``self.model``, so the
# checkpoint's keys must start with ``"model."``; a state_dict saved from a
# bare torchvision resnet34 would fail this strict load - confirm how
# model.pth was produced.
model = MyResNet34Model()
model.load_state_dict(
    torch.load(model_path, map_location="cpu")
)
# eval() returns the module itself and switches layers such as BatchNorm to
# inference behaviour.
model = model.eval()
# Per-channel ImageNet statistics (the usual torchvision / fastai defaults for
# ResNet training pipelines).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def transform(img, size=(224, 224), mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Convert an image into a normalised ``(1, 3, H, W)`` float tensor.

    Args:
        img: anything ``PILImage.create`` accepts (e.g. a path, bytes, a PIL
            image or a NumPy array). It is converted to RGB so the result
            always has exactly 3 channels.
        size: ``(width, height)`` passed to ``Image.resize``.
        mean: per-channel mean subtracted after scaling pixels to [0, 1].
        std: per-channel standard deviation divided out after the mean.
            Pass ``mean=(0, 0, 0), std=(1, 1, 1)`` to skip normalisation.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 3, height, width)``.
    """
    # RGBA / grayscale inputs are not converted by PILImage.create for
    # array or PIL sources, so force 3 channels explicitly.
    pil_img = PILImage.create(img).convert("RGB").resize(size)
    # fastai's ``tensor`` yields channels-last (H, W, 3) uint8 data.
    x = tensor(pil_img).float() / 255.0
    # ResNet convolutions expect channels-first (3, H, W); feeding (H, W, 3)
    # makes the first conv layer fail on the channel dimension.
    x = x.permute(2, 0, 1)
    # NOTE(review): assumes the model was trained on ImageNet-normalised input
    # (the common torchvision/fastai convention) - confirm against training.
    mean_t = torch.tensor(mean, dtype=x.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=x.dtype).view(-1, 1, 1)
    # Add the batch dimension last.
    return ((x - mean_t) / std_t).unsqueeze(0)
# Output index -> human-readable label. Order must match the class order the
# model was trained with.
IDX_TO_CLASS = {0: "Bike", 1: "Car", 2: "Cat", 3: "Dog"}


def predict_image(img):
    """Classify a single image and return its predicted label.

    Args:
        img: the value delivered by the Gradio ``"image"`` input (anything
            ``transform`` accepts). It may be ``None`` when the user submits
            without uploading - verify for the installed Gradio version.

    Returns:
        The label string of the highest-scoring class.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a traceback from the image loader.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    img_tensor = transform(img)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(img_tensor)
    # Batch of one, so the single argmax over the class dimension is the answer.
    pred_idx = outputs.argmax(dim=1).item()
    return IDX_TO_CLASS[pred_idx]
# Web UI: one image in, one label out.
iface = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="label",
    description="Upload an image to classify!",
)

# Start the server only when executed as a script (e.g. ``python app.py``),
# so importing this module for tests or reuse has no launch side effect.
if __name__ == "__main__":
    iface.launch()