KabeerAmjad commited on
Commit
53b38ec
1 Parent(s): 9f1c117

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -16
app.py CHANGED
@@ -1,34 +1,52 @@
1
  import gradio as gr
2
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
- from PIL import Image
4
  import torch
 
 
 
5
 
6
- # Load your Hugging Face model
7
  model_id = "KabeerAmjad/food_classification_model" # Replace with your actual model ID
8
- model = AutoModelForImageClassification.from_pretrained(model_id)
 
 
 
 
9
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
10
 
11
  # Define the prediction function
12
  def classify_image(img):
13
- # Preprocess the image and extract features
14
- inputs = feature_extractor(images=img, return_tensors="pt")
15
-
16
- # Run the model in evaluation mode
 
 
 
 
 
 
 
 
17
  with torch.no_grad():
18
- outputs = model(**inputs)
19
- probs = torch.softmax(outputs.logits, dim=-1)
20
-
21
  # Get the label with the highest probability
22
- top_label = model.config.id2label[probs.argmax().item()]
23
- return top_label
 
 
 
 
 
24
 
25
  # Create the Gradio interface
26
  iface = gr.Interface(
27
  fn=classify_image,
28
- inputs=gr.Image(type="pil"), # Image input type as PIL (Pillow Image)
29
- outputs="text", # Text output will display the predicted label
30
  title="Food Image Classification",
31
- description="Upload an image to classify if it’s an apple pie, etc."
32
  )
33
 
34
  # Launch the app
 
1
  import gradio as gr
 
 
2
  import torch
3
+ from transformers import AutoFeatureExtractor
4
+ from torchvision import models, transforms
5
+ from PIL import Image
6
 
7
+ # Load your trained model from Hugging Face (if available) or load locally
8
  model_id = "KabeerAmjad/food_classification_model" # Replace with your actual model ID
9
+ model = models.resnet50() # Load ResNet50 architecture
10
+ model.load_state_dict(torch.load("path_to_trained_model_weights.pth")) # Load the trained weights
11
+ model.eval() # Set to evaluation mode
12
+
13
+ # Load the feature extractor (can be used if any custom preprocessing was applied)
14
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
15
 
16
  # Define the prediction function
17
  def classify_image(img):
18
+ # Preprocess the image to match ResNet50's expected input format
19
+ preprocess = transforms.Compose([
20
+ transforms.Resize((224, 224)),
21
+ transforms.RandomHorizontalFlip(),
22
+ transforms.RandomRotation(10),
23
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
24
+ transforms.ToTensor(),
25
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
26
+ ])
27
+ img_tensor = preprocess(img).unsqueeze(0) # Add batch dimension
28
+
29
+ # Make prediction with the model
30
  with torch.no_grad():
31
+ outputs = model(img_tensor)
32
+ probs = torch.softmax(outputs, dim=1)
33
+
34
  # Get the label with the highest probability
35
+ _, predicted_class = torch.max(probs, 1)
36
+
37
+ # If you have a list of class labels, use it
38
+ class_labels = ["Apple Pie", "Burger", "Pizza", "Tacos"] # Replace with your actual class labels
39
+ predicted_label = class_labels[predicted_class.item()]
40
+
41
+ return predicted_label
42
 
43
  # Create the Gradio interface
44
  iface = gr.Interface(
45
  fn=classify_image,
46
+ inputs=gr.Image(type="pil"),
47
+ outputs="text",
48
  title="Food Image Classification",
49
+ description="Upload an image to classify if it’s an apple pie, burger, pizza, etc."
50
  )
51
 
52
  # Launch the app