zliang commited on
Commit
743990b
1 Parent(s): 6c215ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -9,7 +9,15 @@ model_path = 'best.pt' # Replace with the path to your trained .pt file
9
  model = YOLO(model_path)
10
 
11
  # Function to perform inference on an image
12
- #@spaces.GPU
 
 
 
 
 
 
 
 
13
  def infer_image(image):
14
  # Convert the image from BGR to RGB
15
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -24,11 +32,14 @@ def infer_image(image):
24
  cls = int(box.cls[0])
25
  conf = float(box.conf[0])
26
 
 
 
 
27
  # Draw bounding box
28
- cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
29
  # Draw label
30
  label = f'{model.names[cls]} {conf:.2f}'
31
- cv2.putText(image, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
32
 
33
  return image
34
 
 
9
  model = YOLO(model_path)
10
 
11
  # Function to perform inference on an image
12
+ colors = {
13
+ 0: (255, 0, 0), # Red for category 0
14
+ 1: (0, 255, 0), # Green for category 1
15
+ 2: (0, 0, 255), # Blue for category 2
16
+ 3: (255, 255, 0), # Cyan for category 3
17
+ 4: (255, 0, 255) # Magenta for category 4
18
+ }
19
+
20
+ # Function to perform inference on an image
21
  def infer_image(image):
22
  # Convert the image from BGR to RGB
23
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
32
  cls = int(box.cls[0])
33
  conf = float(box.conf[0])
34
 
35
+ # Get the color for the current class
36
+ color = colors.get(cls, (0, 255, 0)) # Default to green if class not found
37
+
38
  # Draw bounding box
39
+ cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
40
  # Draw label
41
  label = f'{model.names[cls]} {conf:.2f}'
42
+ cv2.putText(image, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
43
 
44
  return image
45