liuyuan-pal committed
Commit 36a325d
Parent(s): ab287b7
update
Files changed:
- app.py (+7 -9)
- ckpt/sam_vit_h_4b8939.pth (+3 -0)
- requirements.txt (+1 -0)
- sam_utils.py (+50 -0)
app.py CHANGED
@@ -9,6 +9,7 @@ import fire
 from omegaconf import OmegaConf
 
 from ldm.util import add_margin, instantiate_from_config
+from sam_utils import sam_init, sam_out_nosave
 
 _TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image'''
 _DESCRIPTION = '''
@@ -31,12 +32,6 @@ _USER_GUIDE3 = "Generated multiview images are shown below!"
 
 deployed = True
 
-def mask_prediction(mask_predictor, image_in: Image.Image):
-    if image_in.mode=='RGBA':
-        return image_in
-    else:
-        raise NotImplementedError
-
 def resize_inputs(image_input, crop_size):
     alpha_np = np.asarray(image_input)[:, :, 3]
     coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
@@ -58,6 +53,8 @@ def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, el
     # prepare data
     image_input = np.asarray(image_input)
     image_input = image_input.astype(np.float32) / 255.0
+    alpha_values = image_input[:,:, 3:]
+    image_input[:, :, :3] = alpha_values * image_input[:,:, :3] + 1 - alpha_values # white background
     image_input = image_input[:, :, :3] * 2.0 - 1.0
     image_input = torch.from_numpy(image_input.astype(np.float32))
     elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
@@ -103,7 +100,8 @@ def run_demo():
     model = None
 
     # init sam model
-    mask_predictor =
+    mask_predictor = sam_init()
+    mask_predict_fn = lambda x: sam_out_nosave(mask_predictor, x)
 
     # with open('instructions_12345.md', 'r') as f:
     #     article = f.read()
@@ -144,7 +142,7 @@ def run_demo():
                 fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
             with gr.Column(scale=1):
-                input_block = gr.Image(type='pil', image_mode='
+                input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
                 elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
                 cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
                 # sample_num = gr.Slider(1, 2, 2, step=1, label='Sample Num', interactive=True, info='How many instance (16 images per instance)')
@@ -156,7 +154,7 @@ def run_demo():
            output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
 
         update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
-        image_block.change(fn=
+        image_block.change(fn=mask_predict_fn, inputs=[image_block], outputs=[sam_block], queue=False)\
             .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
 
         crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
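In summary, the app.py changes replace the unimplemented mask_prediction stub with the SAM helpers from the new sam_utils.py (sam_init builds the predictor once at startup, and image_block.change now routes uploads through sam_out_nosave into sam_block, which resize_inputs then crops into input_block), and generate() now alpha-composites the RGBA input onto a white background before normalizing to [-1, 1]. A standalone sketch of that new preprocessing step, assuming a hypothetical 256x256 RGBA PIL image called image_pil (variable names mirror the diff):

```python
# Minimal sketch of the preprocessing added to generate(): composite the
# RGBA input onto white, then map RGB to [-1, 1] as a float32 tensor.
import numpy as np
import torch
from PIL import Image

def preprocess(image_pil: Image.Image) -> torch.Tensor:
    x = np.asarray(image_pil).astype(np.float32) / 255.0  # H x W x 4, in [0, 1]
    alpha = x[:, :, 3:]                                    # H x W x 1
    rgb = alpha * x[:, :, :3] + (1.0 - alpha)              # white where transparent
    rgb = rgb * 2.0 - 1.0                                  # scale to [-1, 1]
    return torch.from_numpy(rgb.astype(np.float32))        # H x W x 3
```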
ckpt/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
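The checkpoint itself is stored through Git LFS, so the three lines above are only a pointer; the ~2.4 GB weights are fetched separately (e.g. via git lfs pull). A small sketch, assuming the file has already been downloaded to ckpt/sam_vit_h_4b8939.pth, that checks it against the size and sha256 oid recorded in the pointer:

```python
# Minimal sketch: verify a downloaded checkpoint against the Git LFS
# pointer above (expected size and sha256 are copied from the pointer;
# the local path is an assumption about where the file was placed).
import hashlib
import os

CKPT_PATH = "ckpt/sam_vit_h_4b8939.pth"
EXPECTED_SHA256 = "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e"
EXPECTED_SIZE = 2564550879  # bytes, from the pointer's "size" field

digest = hashlib.sha256()
with open(CKPT_PATH, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        digest.update(chunk)

assert os.path.getsize(CKPT_PATH) == EXPECTED_SIZE, "size mismatch"
assert digest.hexdigest() == EXPECTED_SHA256, "sha256 mismatch"
print("checkpoint matches the LFS pointer")
```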
requirements.txt CHANGED
@@ -20,4 +20,5 @@ easydict
 nerfacc
 imageio-ffmpeg==0.4.7
 fire
+segment_anything
 git+https://github.com/openai/CLIP.git
sam_utils.py ADDED
@@ -0,0 +1,50 @@
+import os
+import numpy as np
+import torch
+from PIL import Image
+import time
+
+from segment_anything import sam_model_registry, SamPredictor
+
+def sam_init(device_id=0):
+    sam_checkpoint = os.path.join(os.path.dirname(__file__), "ckpt/sam_vit_h_4b8939.pth")
+    model_type = "vit_h"
+
+    device = "cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu"
+
+    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
+    predictor = SamPredictor(sam)
+    return predictor
+
+def sam_out_nosave(predictor, input_image, bbox_sliders=(0,0,255,255)):
+    bbox = np.array(bbox_sliders)
+    image = np.asarray(input_image)
+
+    start_time = time.time()
+    predictor.set_image(image)
+
+    h, w, _ = image.shape
+    input_point = np.array([[h//2, w//2]])
+    input_label = np.array([1])
+
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        multimask_output=True,
+    )
+
+    masks_bbox, scores_bbox, logits_bbox = predictor.predict(
+        box=bbox,
+        multimask_output=True
+    )
+
+    print(f"SAM Time: {time.time() - start_time:.3f}s")
+    opt_idx = np.argmax(scores)
+    mask = masks[opt_idx]
+    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
+    out_image[:, :, :3] = image
+    out_image_bbox = out_image.copy()
+    out_image[:, :, 3] = mask.astype(np.uint8) * 255
+    out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
+    torch.cuda.empty_cache()
+    return Image.fromarray(out_image_bbox, mode='RGBA')
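sam_utils.py exposes the two entry points that app.py now uses: sam_init() loads the ViT-H SAM checkpoint onto the GPU when available (CPU otherwise) and returns a SamPredictor, while sam_out_nosave() runs SAM with a center-point prompt and a box prompt and returns an RGBA copy of the input whose alpha channel comes from the box-prompted mask. A minimal standalone usage sketch, assuming a test image named input.png (the filename is hypothetical; the 256x256 resize matches the default bbox_sliders=(0, 0, 255, 255) box prompt):

```python
# Minimal usage sketch for the new helpers, outside the Gradio app.
# "input.png" is a hypothetical test image.
from PIL import Image

from sam_utils import sam_init, sam_out_nosave

predictor = sam_init()                               # loads ckpt/sam_vit_h_4b8939.pth
image = Image.open("input.png").convert("RGB").resize((256, 256))
segmented = sam_out_nosave(predictor, image)         # RGBA; alpha channel = SAM mask
segmented.save("segmented_rgba.png")
```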