Spaces: Running on Zero
bugfix
app.py CHANGED
@@ -27,11 +27,34 @@ if torch.cuda.get_device_properties(0).major >= 8:
 FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
 SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
 
+class calculateDuration:
+    def __init__(self, activity_name=""):
+        self.activity_name = activity_name
+
+    def __enter__(self):
+        self.start_time = time.time()
+        self.start_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.start_time))
+        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end_time = time.time()
+        self.elapsed_time = self.end_time - self.start_time
+        self.end_time_formatted = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(self.end_time))
+
+        if self.activity_name:
+            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+        else:
+            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")
+
+
 @spaces.GPU(duration=20)
 @torch.inference_mode()
 @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False) -> Optional[Image.Image]:
+def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=0, merge_masks=False, return_rectangles=False, progress=gr.Progress(track_tqdm=True)) -> Optional[Image.Image]:
     if not image_input:
         gr.Info("Please upload an image.")
         return None
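The centerpiece of this commit is the calculateDuration context manager above, which logs wall-clock start time, elapsed time, and end time around each pipeline stage. A minimal usage sketch, assuming the class as defined in this diff (time.sleep stands in for real work such as model inference; timestamps in the comments are illustrative):

    import time

    with calculateDuration("demo activity"):
        time.sleep(0.5)  # stand-in workload

    # Prints something like:
    # Activity: demo activity, Start time: ...
    # Elapsed time for demo activity: 0.500123 seconds
    # Activity: demo activity, End time: ...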
@@ -40,62 +63,69 @@ def process_image(image_input, image_url, task_prompt, text_prompt=None, dilate=
         return None
 
     if image_url:
+        with calculateDuration("Download Image"):
+            print("start to fetch image from url", image_url)
+            response = requests.get(image_url)
+            response.raise_for_status()
+            image_input = PIL.Image.open(BytesIO(response.content))
+            print("fetch image success")
     # start to parse prompt
+    with calculateDuration("run_florence_inference"):
+        _, result = run_florence_inference(
+            model=FLORENCE_MODEL,
+            processor=FLORENCE_PROCESSOR,
+            device=DEVICE,
+            image=image_input,
+            task=task_prompt,
+            text=text_prompt
+        )
+    with calculateDuration("sv.Detections"):
+        # start to detect
+        detections = sv.Detections.from_lmm(
+            lmm=sv.LMM.FLORENCE_2,
+            result=result,
+            resolution_wh=image_input.size
+        )
+        # json_result = json.dumps([])
+        # print(detections)
     images = []
     if return_rectangles:
+        with calculateDuration("generate rectangle mask"):
+            # create masks from rectangles
+            (image_width, image_height) = image_input.size
+            bboxes = detections.xyxy
+            merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8)
+
+            for bbox in bboxes:
+                x1, y1, x2, y2 = map(int, bbox)
+                cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
+                clip_mask = np.zeros((image_height, image_width), dtype=np.uint8)
+                cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
+                images.append(clip_mask)
+            if merge_masks:
+                images = [merge_mask_image] + images
     else:
+        with calculateDuration("generate segment mask"):
+            # use SAM to generate segmentation masks
+            detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
+            if len(detections) == 0:
+                gr.Info("No objects detected.")
+                return None
+            print("mask generated:", len(detections.mask))
+            kernel_size = dilate
+            kernel = np.ones((kernel_size, kernel_size), np.uint8)
+
+            for i in range(len(detections.mask)):
+                mask = detections.mask[i].astype(np.uint8) * 255
+                if dilate > 0:
+                    mask = cv2.dilate(mask, kernel, iterations=1)
+                images.append(mask)
+
+            if merge_masks:
+                merged_mask = np.zeros_like(images[0], dtype=np.uint8)
+                for mask in images:
+                    merged_mask = cv2.bitwise_or(merged_mask, mask)
+                images = [merged_mask]
 
     return [images, json_result]
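In the return_rectangles branch above, each detected box is rasterized into its own binary mask, and merge_masks prepends a single union mask; the SAM branch instead optionally dilates each mask and merges them with bitwise OR. A standalone sketch of both mask operations, with hypothetical boxes standing in for detections.xyxy:

    import cv2
    import numpy as np

    image_width, image_height = 640, 480                  # stand-in for image_input.size
    bboxes = [(50, 40, 200, 180), (300, 220, 460, 400)]   # hypothetical detections.xyxy

    images = []
    merge_mask_image = np.zeros((image_height, image_width), dtype=np.uint8)
    for x1, y1, x2, y2 in bboxes:
        # draw each box into both the shared union mask and its own mask
        cv2.rectangle(merge_mask_image, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
        clip_mask = np.zeros((image_height, image_width), dtype=np.uint8)
        cv2.rectangle(clip_mask, (x1, y1), (x2, y2), 255, thickness=cv2.FILLED)
        images.append(clip_mask)
    images = [merge_mask_image] + images                  # merge_masks=True puts the union first

    # The SAM branch's dilate/merge steps, applied to the same data:
    kernel = np.ones((5, 5), np.uint8)                    # kernel built from the dilate value, e.g. 5
    dilated = cv2.dilate(images[1], kernel, iterations=1)
    merged = np.zeros_like(images[1])
    for m in images[1:]:
        merged = cv2.bitwise_or(merged, m)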
@@ -115,13 +145,12 @@ with gr.Blocks() as demo:
             submit_button = gr.Button(value='Submit', variant='primary')
         with gr.Column():
             image_gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
-            json_result = gr.Code(label="JSON Result", language="json")
+            # json_result = gr.Code(label="JSON Result", language="json")
 
-    print(image, image_url, task_prompt, text_prompt, image_gallery)
     submit_button.click(
         fn=process_image,
         inputs=[image, image_url, task_prompt, text_prompt, dilate, merge_masks, return_rectangles],
-        outputs=[image_gallery, json_result],
+        outputs=[image_gallery],
         show_api=False
     )
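One inconsistency the commit leaves behind: json_result is commented out both as a variable and as an output component, yet process_image still ends with return [images, json_result], so that return would raise a NameError when reached. The wiring contract this implies (one output component, one return value) in a self-contained sketch; the handler and component names here are illustrative, not part of the commit:

    import gradio as gr

    def fake_process(image):
        # One output component below, so return exactly one value;
        # a Gallery accepts a list of images (empty here for brevity).
        return []

    with gr.Blocks() as demo:
        image = gr.Image(type="pil")
        gallery = gr.Gallery(label="Generated images")
        button = gr.Button("Submit")
        button.click(fn=fake_process, inputs=[image], outputs=[gallery], show_api=False)

    # demo.launch()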