File size: 2,307 Bytes
8e0b903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import torch
from pathlib import Path
from models.common import DetectMultiBackend
from utils.dataloaders import LoadImages
from utils.general import (non_max_suppression, scale_boxes, check_img_size)
from utils.plots import Annotator, colors
from utils.torch_utils import select_device
import cv2

def predict_image(image_path, weights=r"yolov9/yolov9_vinbigData.pt", conf_thres=0.25, iou_thres=0.45,
                  output_dir='pages/output_yolov9', device='cpu', *, imgsz=(640, 640), max_det=1000):
    """Run detection on ``image_path`` and save annotated copies into ``output_dir``.

    Args:
        image_path: Source handed to ``LoadImages`` (a single image path, or
            whatever else that loader accepts — confirm against utils.dataloaders).
        weights: Model weights path passed to ``DetectMultiBackend``.
        conf_thres: Confidence threshold used by non-max suppression.
        iou_thres: IoU threshold used by non-max suppression.
        output_dir: Directory the annotated images are written to; created if
            it does not exist. Each output keeps the input file's name.
        device: Device string passed to ``select_device`` (e.g. ``'cpu'``).
        imgsz: Inference size as ``(height, width)`` or a single int for a
            square size; adjusted to the model stride by ``check_img_size``.
        max_det: Maximum number of detections kept per image by NMS.

    Returns:
        List of the output image paths that were written, in loader order.

    Raises:
        OSError: If ``cv2.imwrite`` reports that an output image could not be written.
    """
    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device)
    stride, names, pt = model.stride, model.names, model.pt
    if isinstance(imgsz, int):
        imgsz = (imgsz, imgsz)  # warmup below unpacks imgsz, so it must be 2-D
    imgsz = check_img_size(imgsz, s=stride)  # inference size compatible with the stride
    dataset = LoadImages(image_path, img_size=imgsz, stride=stride, auto=pt)
    model.warmup(imgsz=(1, 3, *imgsz))  # warm up the model
    os.makedirs(output_dir, exist_ok=True)  # create once, not once per image

    saved_paths = []
    # Pure inference: without no_grad, autograd would record every forward
    # pass and keep its intermediate activations alive.
    with torch.no_grad():
        for path, im, im0s, _, _ in dataset:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # scale 0 - 255 to 0.0 - 1.0
            if im.ndim == 3:
                im = im[None]  # add a batch dimension

            # Inference
            pred = model(im)

            # If the model returns a list, keep only its first element
            if isinstance(pred, list):
                pred = pred[0]

            # Non-maximum suppression
            pred = non_max_suppression(pred, conf_thres, iou_thres, max_det=max_det)

            # Process predictions (one entry per image in the batch)
            for det in pred:
                im0 = im0s.copy()
                annotator = Annotator(im0, line_width=3, example=str(names))
                if len(det):
                    # Rescale boxes from the inference size to the original image size
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                    # Draw bounding boxes and labels on the image
                    for *xyxy, conf, cls in reversed(det):
                        label = f'{names[int(cls)]} {conf:.2f}'
                        annotator.box_label(xyxy, label, color=colors(int(cls), True))

                # Save the annotated result; imwrite signals failure by returning False
                output_path = os.path.join(output_dir, Path(path).name)
                if not cv2.imwrite(output_path, annotator.result()):
                    raise OSError(f'Failed to write annotated image to {output_path}')
                saved_paths.append(output_path)
    return saved_paths