Spaces:
Sleeping
Sleeping
File size: 2,307 Bytes
8e0b903 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import os
import torch
from pathlib import Path
from models.common import DetectMultiBackend
from utils.dataloaders import LoadImages
from utils.general import (non_max_suppression, scale_boxes, check_img_size)
from utils.plots import Annotator, colors
from utils.torch_utils import select_device
import cv2
def predict_image(image_path, weights=r"yolov9/yolov9_vinbigData.pt", conf_thres=0.25, iou_thres=0.45, output_dir='pages/output_yolov9', device='cpu'):
    """Run YOLOv9 detection on an image source and save annotated copies.

    Args:
        image_path: Source handed to ``LoadImages`` (e.g. an image file path).
        weights: Path to the weights file loaded with ``DetectMultiBackend``.
        conf_thres: Confidence threshold passed to ``non_max_suppression``.
        iou_thres: IoU threshold passed to ``non_max_suppression``.
        output_dir: Directory the annotated images are written to; it is
            created if missing. Each output keeps the source file's name.
        device: Device string passed to ``select_device`` (e.g. ``'cpu'``).

    Returns:
        List of output file paths (str) that were written, in the order
        ``LoadImages`` yielded the inputs.

    Raises:
        OSError: If ``cv2.imwrite`` reports that an image could not be saved.
    """
    # Load model
    device = select_device(device)
    model = DetectMultiBackend(weights, device=device)
    stride, names, pt = model.stride, model.names, model.pt
    imgsz = check_img_size((640, 640), s=stride)  # inference size, checked against the model stride
    dataset = LoadImages(image_path, img_size=imgsz, stride=stride, auto=pt)
    model.warmup(imgsz=(1, 3, *imgsz))  # warm up model

    # Create the output directory once instead of once per image.
    os.makedirs(output_dir, exist_ok=True)
    saved_paths = []

    # Inference only: disable autograd so no computation graph is built
    # and kept alive for every forward pass.
    with torch.no_grad():
        for path, im, im0s, _, _ in dataset:
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
            im /= 255  # 0 - 255 to 0.0 - 1.0
            if len(im.shape) == 3:
                im = im[None]  # expand for batch dim

            # Inference
            pred = model(im)
            # If the model returns a list of outputs, keep only the first one.
            if isinstance(pred, list):
                pred = pred[0]
            # Apply non-maximum suppression.
            pred = non_max_suppression(pred, conf_thres, iou_thres, max_det=1000)

            output_path = os.path.join(output_dir, Path(path).name)
            # Process predictions: one entry per image in the batch.
            for det in pred:
                im0 = im0s.copy()
                annotator = Annotator(im0, line_width=3, example=str(names))
                if len(det):
                    # Rescale boxes from inference size to original image size.
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
                    # Draw bounding boxes and labels on the image.
                    for *xyxy, conf, cls in reversed(det):
                        label = f'{names[int(cls)]} {conf:.2f}'
                        annotator.box_label(xyxy, label, color=colors(int(cls), True))
                im0 = annotator.result()
                # cv2.imwrite signals failure by returning False rather than raising.
                if not cv2.imwrite(output_path, im0):
                    raise OSError(f"cv2.imwrite failed to save {output_path}")
                saved_paths.append(output_path)
    return saved_paths