# app.py by Guy2 (Hugging Face) — commit a00dc8b, "Update app.py", 1.52 kB.
import io
import json

import gradio as gr
import numpy as np
import requests
import supervision as sv
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
# Stock DETR (ResNet-50) image processor from the Hugging Face Hub; used both
# to prepare model inputs and to post-process raw outputs into boxes.
image_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
# Detection weights from the "Guy2/AirportSec-150epoch" Hub repo — presumably a
# DETR fine-tuned on airport-security items (confirm on the model card).
# NOTE(review): both loads run at import time; from_pretrained typically
# downloads on first use, so importing this module may hit the network.
model = DetrForObjectDetection.from_pretrained("Guy2/AirportSec-150epoch")
# Maps integer class ids to display names; the list position is the id.
# NOTE(review): presumably mirrors the model config's id2label — confirm.
id2label = dict(
    enumerate(["dangerous-items", "Gun", "Knife", "Pliers", "Scissors", "Wrench"])
)
def _fetch_image(url, timeout=30):
    """Download *url* and return it as a fully decoded RGB PIL image.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status.
        requests.RequestException: connection failure or timeout.
        PIL.UnidentifiedImageError: the payload is not a readable image.
    """
    # A timeout keeps a dead host from hanging the request handler forever.
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # Decode from memory: unlike response.raw this honours Content-Encoding
    # (gzip/deflate) and leaves no half-read streamed connection behind.
    image = Image.open(io.BytesIO(response.content))
    # The DETR processor expects 3-channel input; PNG/GIF can be RGBA, L or P.
    # convert() also forces the lazy PIL decode to happen here.
    return image.convert("RGB")


def _describe(detections):
    """Turn supervision detections into ``[[x1, y1, x2, y2], score, label]`` rows.

    Reads the ``xyxy`` / ``confidence`` / ``class_id`` arrays directly rather
    than unpacking ``Detections.__iter__``, whose tuple length differs between
    supervision releases. Values are converted to built-in ``float``/``int`` so
    the rows are JSON-serialisable and their repr does not depend on the numpy
    version. Unknown class ids fall back to their string form, not KeyError.
    """
    return [
        [
            [float(v) for v in box],
            float(score),
            id2label.get(int(cls), str(int(cls))),
        ]
        for box, score, cls in zip(
            detections.xyxy, detections.confidence, detections.class_id
        )
    ]


def anylize(url):
    """Detect objects in the image at *url* and describe each detection.

    The (misspelled) name is kept because the Gradio interface refers to it.

    Args:
        url: HTTP(S) URL of the image to analyse.

    Returns:
        ``str`` of a list whose elements are
        ``[[x_min, y_min, x_max, y_max], confidence, label]`` for every
        detection scoring at least 0.8 that survives NMS at IoU 0.5. Box
        coordinates are in pixels of the downloaded image.

    Raises:
        requests.RequestException: the image could not be downloaded.
        PIL.UnidentifiedImageError: the payload is not a readable image.
    """
    image = _fetch_image(url)

    with torch.no_grad():
        inputs = image_processor(images=image, return_tensors="pt")
        outputs = model(**inputs)

    # post_process_object_detection wants (height, width) per image; PIL's
    # .size is (width, height), so swap instead of copying pixels via numpy.
    width, height = image.size
    results = image_processor.post_process_object_detection(
        outputs=outputs,
        threshold=0.8,
        target_sizes=torch.tensor([[height, width]]),
    )[0]

    detections = sv.Detections.from_transformers(
        transformers_results=results
    ).with_nms(threshold=0.5)

    labels = _describe(detections)
    # Keep the detections visible in the server log, as before.
    print(labels)
    return str(labels)
# Text box in (an image URL), text box out (the stringified detection list).
demo = gr.Interface(fn=anylize, inputs="text", outputs="text")

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing the module — for tests or Gradio's `gradio app.py` reload mode,
# which looks up `demo` — does not launch a second server as a side effect.
if __name__ == "__main__":
    demo.launch()