Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,63 @@
|
|
1 |
-
transformers
|
2 |
-
torch
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import pipeline, SegGptImageProcessor, SegGptForImageSegmentation
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Depth-estimation pipeline used to auto-generate a prompt mask from the
# prompt image (most salient object ends up brightest in the depth map).
# device=0 pins it to the first CUDA device — assumes a GPU is present;
# TODO confirm for CPU-only deployments.
depth_anything = pipeline(task = "depth-estimation", model="nielsr/depth-anything-small", device=0)
# SegGPT: one-shot segmentation — given an example (image, mask) prompt pair
# it segments the same concept in a new input image.
checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
model = SegGptForImageSegmentation.from_pretrained(checkpoint)
def infer_seggpt(image_input, image_prompt, mask_prompt):
    """Run one-shot SegGPT segmentation and save an overlay visualization.

    Args:
        image_input: np.ndarray, (H, W, C) uint8 image to segment.
        image_prompt: np.ndarray, example image showing the target concept.
        mask_prompt: np.ndarray, mask over ``image_prompt`` marking the concept.

    Returns:
        str: path of the saved overlay image, always ``"masks.png"``.
    """
    num_labels = 100
    inputs = image_processor(
        images=image_input,
        prompt_images=image_prompt,
        prompt_masks=mask_prompt,
        return_tensors="pt",
        num_labels=num_labels,
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the predicted mask back to the input image's (H, W).
    target_sizes = [image_input.shape[:2]]
    mask = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes, num_labels=num_labels
    )[0]
    palette = image_processor.get_palette(num_labels)
    mask_rgb = image_processor.mask_to_rgb(
        mask.cpu().numpy(), palette, data_format="channels_last"
    )

    fig, ax = plt.subplots()
    ax.get_xaxis().get_major_formatter().set_useOffset(False)
    # Input image with the colorized mask alpha-blended on top.
    ax.imshow(Image.fromarray(image_input))
    ax.imshow(mask_rgb, cmap='viridis', alpha=0.6)
    ax.axis("off")
    ax.margins(0)
    # BUG FIX: save BEFORE show. On interactive backends plt.show() consumes
    # the figure, so the original show-then-savefig order could write a blank
    # file. Saving first is correct on every backend.
    fig.savefig("masks.png", bbox_inches='tight', pad_inches=0)
    plt.show()
    plt.close(fig)  # avoid leaking one figure per Gradio request
    return "masks.png"
def infer(image_input, image_prompt, mask_prompt=None):
    """Gradio entry point: segment ``image_input`` by example.

    The prompt mask is generated automatically: Depth Anything produces a
    depth map of ``image_prompt``, which SegGPT then uses as the example
    mask to find the same (most salient) object in ``image_input``.

    Args:
        image_input: PIL.Image, image to segment.
        image_prompt: PIL.Image, example image containing the target object.
        mask_prompt: ignored; kept with a default so the two-input Gradio
            Interface can call this function (BUG FIX: the original required
            three positional arguments but the UI supplies only two, so every
            call raised TypeError — the original overwrote this value anyway).

    Returns:
        str: filepath of the overlay image produced by ``infer_seggpt``.
    """
    mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")
    return infer_seggpt(
        np.asarray(image_input), np.asarray(image_prompt), np.asarray(mask_prompt)
    )
import gradio as gr

# Two image inputs only (target image + prompt image): the prompt mask is
# derived automatically inside `infer` via depth estimation, so the user
# never draws a mask.
demo = gr.Interface(
    infer,
    inputs=[gr.Image(type="pil", label="Image Input"), gr.Image(type="pil", label="Image Prompt")],
    # Output is a filepath because infer_seggpt saves the overlay to disk.
    outputs=[gr.Image(type="filepath", label="Mask Output")],
    #gr.Image(type="numpy", label="Output Mask")],
    title="SegGPT 🤝 Depth Anything: Speak to Segmentation in Image",
    description="SegGPT is a one-shot image segmentation model where one could ask model what to segment through uploading an example image and an example mask, and ask to segment the same thing in another image. In this demo, we have combined SegGPT and Depth Anything to automatically generate the mask for most outstanding object and segment the same thing in another image for you. You can see how it works by trying the example.",
    # Example pair shipped with the Space; assumes these files exist next to
    # app.py — TODO confirm in the repository.
    examples=[
        ["./cats.png", "./cat.png"],
    ])
demo.launch(debug=True)