# app.py (Hugging Face file view: uploaded by taroii, commit 46e65a1 "Update app.py", 1.36 kB)
import gradio as gr
import torch
from transformers import DetrForObjectDetection, DetrImageProcessor, AutoModel
import supervision as sv
from supervision.detection.annotate import BoxAnnotator
from supervision.utils.notebook import plot_image
# Base DETR checkpoint whose image processor is used by `query` both to
# preprocess input images and to post-process the model outputs.
og_model = 'facebook/detr-resnet-50'
image_processor = DetrImageProcessor.from_pretrained(og_model)

# DetrForObjectDetection (not AutoModel) is required here: AutoModel maps a
# DETR config to the bare DetrModel, which has no class/box heads, so its
# outputs lack the `logits` / `pred_boxes` that
# `post_process_object_detection` reads.
# NOTE(review): assumes this checkpoint was saved from a
# DetrForObjectDetection model (i.e. it contains head weights) — confirm.
model = DetrForObjectDetection.from_pretrained("taroii/notfinetuned-detr-50")
def query(image, confidence_threshold=0.5, iou_threshold=0.5):
    """Run DETR object detection on an image and draw the detections.

    Args:
        image: Input image as a NumPy array. ``image.shape[:2]`` is used as
            the (height, width) target size when rescaling boxes (assumes an
            HxWxC array as delivered by a Gradio image input — confirm).
            ``None`` (no image submitted) yields an empty result.
        confidence_threshold: Minimum score a prediction needs to be kept by
            ``post_process_object_detection``.
        iou_threshold: IoU threshold passed to ``Detections.with_nms``.

    Returns:
        A tuple ``(labels, frame)``: one ``"<class name> <score>"`` string per
        detection surviving NMS, and a copy of ``image`` with boxes and labels
        drawn on it (``frame`` is ``None`` when ``image`` is ``None``).
    """
    if image is None:
        return [], None

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        inputs = image_processor(images=image, return_tensors='pt')
        outputs = model(**inputs)

    # Rescale predicted boxes back to the input image's (height, width).
    target_sizes = torch.tensor([image.shape[:2]])
    results = image_processor.post_process_object_detection(
        outputs=outputs,
        threshold=confidence_threshold,
        target_sizes=target_sizes,
    )[0]

    detections = sv.Detections.from_transformers(
        transformers_results=results
    ).with_nms(threshold=iou_threshold)

    # The label map comes from the model config instead of an undefined
    # global. Zipping the attribute arrays avoids depending on the tuple
    # layout produced by iterating a Detections object, which differs
    # between supervision versions.
    id2label = model.config.id2label
    labels = [
        f"{id2label.get(int(class_id), str(class_id))} {confidence:.2f}"
        for confidence, class_id in zip(detections.confidence, detections.class_id)
    ]

    # Draw on a copy so the caller's array is left untouched.
    box_annotator = BoxAnnotator()
    frame = box_annotator.annotate(
        scene=image.copy(), detections=detections, labels=labels
    )
    return labels, frame
# Serve the local `query` pipeline. The previous
# gr.Interface.load("models/...") call proxied the hosted inference API and
# never invoked `query`. The outputs mirror query's (labels, frame) return.
demo = gr.Interface(
    fn=query,
    inputs=gr.Image(type="numpy"),
    outputs=[gr.JSON(label="Detections"), gr.Image(label="Annotated image")],
)
demo.launch()