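# Hugging Face Spaces page header captured together with this file
# (Space status at capture time: "Build error"). Not part of the program.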
import os | |
import numpy as np | |
import torch | |
from einops import rearrange | |
from imageio import imwrite | |
from pydantic import validator | |
import imageio | |
import tempfile | |
import gradio as gr | |
from PIL import Image | |
from my.utils import ( | |
tqdm, EventStorage, HeartBeat, EarlyLoopBreak, | |
get_event_storage, get_heartbeat, read_stats | |
) | |
from my.config import BaseConf, dispatch, optional_load_config | |
from my.utils.seed import seed_everything | |
from adapt import ScoreAdapter | |
from run_img_sampling import SD | |
from misc import torch_samps_to_imgs | |
from pose import PoseConfig | |
from run_nerf import VoxConfig | |
from voxnerf.utils import every | |
from voxnerf.render import ( | |
as_torch_tsrs, rays_from_img, ray_box_intersect, render_ray_bundle | |
) | |
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis | |
from pytorch3d.renderer import PointsRasterizationSettings | |
from semantic_coding import semantic_coding, semantic_karlo, semantic_sd | |
from pc_project import point_e, render_depth_from_cloud | |
# Module-wide device handle: the voxel grid is moved to it, and it is passed to
# the pose sampler and the point-cloud depth renderer further down in this file.
device_glb = torch.device("cuda")
def tsr_stats(tsr):
    """Summarise a tensor as plain Python scalars.

    Returns a dict with the keys ``"mean"``, ``"std"`` and ``"max"`` (in that
    order), each obtained by calling the same-named reduction on ``tsr`` and
    converting the result with ``.item()``.
    """
    stat_names = ("mean", "std", "max")
    return {name: getattr(tsr, name)().item() for name in stat_names}
class SJC_3DFuse(BaseConf):
    """Configuration and Gradio entry point for the 3DFuse demo.

    The fields below are the tunable settings; ``run_gradio`` turns them into a
    score model, a voxel grid and a pose sampler and then streams the progress
    updates produced by ``fuse_3d``.

    NOTE(review): ``BaseConf`` is assumed to be a pydantic-style model
    (``self.dict()`` is used and ``validator`` is imported) -- confirm.
    """

    # Name of the attribute on this config that builds the score model
    # (looked up with ``getattr(self, family)`` in ``run_gradio``).
    family: str = "sd"
    sd: SD = SD(
        variant="v1",
        prompt="a comfortable bed",
        scale=100.0,
        dir="./results",
        alpha=0.3
    )
    # Learning rate handed to the Adamax optimiser in ``fuse_3d``.
    lr: float = 0.05
    # Number of training poses requested / progress-bar total in ``fuse_3d``.
    n_steps: int = 10000
    vox: VoxConfig = VoxConfig(
        model_type="V_SD", grid_size=100, density_shift=-1.0, c=3,
        blend_bg_texture=False, bg_texture_hw=4,
        bbox_len=1.0
    )
    pose: PoseConfig = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)

    # Emptiness loss in ``fuse_3d``: weight * mean(log(1 + scale * ws)); the loss
    # is multiplied by ``emptiness_multiplier`` once the step index reaches
    # ``emptiness_step * n_steps``.
    emptiness_scale: int = 10
    # NOTE(review): annotated ``int`` but the default is the float 1e4; left
    # unchanged so the default value and its coercion behaviour stay the same.
    emptiness_weight: int = 1e4
    emptiness_step: float = 0.5
    emptiness_multiplier: float = 20.0

    # Weight of the centre-vs-border depth loss in ``fuse_3d``; 0 disables it.
    depth_weight: int = 0
    # Selects which of the two gradient formulas ``fuse_3d`` uses.
    var_red: bool = True

    # Base output directory; ``run_gradio`` joins it with ``initial``.
    exp_dir: str = "./results"

    # The remaining fields are not read in this file's optimisation loop; they
    # travel inside the config dict that ``run_gradio`` passes to
    # ``semantic_coding`` (and are discarded by ``fuse_3d``'s ``**kwargs``).
    ti_step: int = 800
    pt_step: int = 800
    initial: str = ""
    random_seed: int = 0
    semantic_model: str = "Karlo"
    bg_preprocess: bool = True
    num_initial_image: int = 4

    # The method has the (cls, value, values) signature of a pydantic validator
    # and ``validator`` is imported at the top of the file, but without this
    # decorator it was never registered and therefore never called.
    # NOTE(review): pydantic v1 skips validators for fields left at their
    # default unless ``always=True`` -- confirm how the config is constructed.
    @validator("vox")
    def check_vox(cls, vox_cfg, values):
        """Set the voxel channel count to 4 when ``family`` is ``"sd"``.

        ``values`` holds the already-validated earlier fields; ``family`` is
        declared before ``vox`` so it is available here.
        Returns the (possibly modified) ``vox_cfg``.
        """
        family = values['family']
        if family == "sd":
            # presumably matches the channel count of the SD model's samples
            # -- TODO confirm
            vox_cfg.c = 4
        return vox_cfg

    def run(self):
        """CLI entry point; always raises because this build is demo-only."""
        raise Exception("This version is for huggingface demo, which doesn't support CLI. Please visit https://github.com/KU-CVLAB/3DFuse")

    def run_gradio(self, points, images):
        """Run the full pipeline as a generator of Gradio updates.

        Args:
            points: forwarded to ``fuse_3d`` (used there for point-cloud depth
                rendering).
            images: forwarded to ``semantic_coding``.

        Yields:
            3-tuples of (first output update, status text, last output update),
            first a "tuning is starting" message and then everything yielded
            by ``fuse_3d``.
        """
        cfgs = self.dict()
        initial = cfgs.pop('initial')
        exp_dir = os.path.join(cfgs.pop('exp_dir'), initial)

        # Optimization and pivotal tuning for LoRA
        yield gr.update(value=None), "Tuning for the LoRA layer is starting now. It will take approximately ~10 mins.", gr.update(value=None)
        # ``cfgs`` still contains "family", "vox" and "pose" at this point.
        state = semantic_coding(images, cfgs, self.sd, initial)
        # The value returned by semantic_coding becomes the SD config's ``dir``
        # before the model is built below.
        self.sd.dir = state

        # Load SD with Consistency Injection Module
        family = cfgs.pop("family")
        model = getattr(self, family).make()
        print(model.prompt)

        # Nested configs are built via ``make()``; their dict copies are dropped
        # so they are not forwarded to ``fuse_3d``.
        cfgs.pop("vox")
        vox = self.vox.make()

        cfgs.pop("pose")
        poser = self.pose.make()

        # Score distillation
        yield from fuse_3d(**cfgs, poser=poser, model=model, vox=vox, exp_dir=exp_dir, points=points, is_gradio=True)
def fuse_3d(
    poser, vox, model: ScoreAdapter,
    lr, n_steps, emptiness_scale, emptiness_weight, emptiness_step, emptiness_multiplier,
    depth_weight, var_red, exp_dir, points, is_gradio, **kwargs
):
    """Optimise ``vox`` by score distillation, yielding progress along the way.

    Args:
        poser: pose sampler; provides ``H``, ``W`` and
            ``sample_train(n_steps, device)`` (intrinsics, poses, prompt
            prefixes and angles, indexed per step).
        vox: voxel model; moved to ``device_glb`` and optimised in place.
        model: score model used for prompt embedding, denoising and decoding.
        lr: learning rate of the Adamax optimiser.
        n_steps: number of training poses requested; also the progress total.
        emptiness_scale, emptiness_weight: the emptiness loss is
            ``emptiness_weight * mean(log(1 + emptiness_scale * ws))``.
        emptiness_step, emptiness_multiplier: once the step index reaches
            ``emptiness_step * n_steps`` the emptiness loss is multiplied by
            ``emptiness_multiplier``.
        depth_weight: weight of the centre-vs-border depth loss; the loss is
            skipped when this is not positive.
        var_red: if true the gradient is ``(Ds - y) / sigma``, otherwise
            ``(Ds - zs) / sigma``.
        exp_dir: only used in the final status message.
        points: forwarded to ``render_depth_from_cloud`` every step.
        is_gradio: when true, yield 3-tuples for the Gradio UI.
        **kwargs: accepted and ignored, so a whole config dict can be splatted.

    Yields:
        With ``is_gradio``: (image or update, status text, update) tuples -- a
        start message, periodic previews and a final tuple whose last element
        carries the path returned by ``evaluate``. Otherwise a single ``None``
        after training and evaluation have finished.
    """
    del kwargs

    if is_gradio:
        yield gr.update(visible=True), "LoRA layers tuning has just finished. \nScore distillation has started.", gr.update(visible=True)

    assert model.samps_centered()
    # The unpacked values are not used below; the call is kept because
    # ``data_shape`` is opaque from this file.
    _, target_H, target_W = model.data_shape()
    bs = 1
    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)
    opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

    H, W = poser.H, poser.W
    Ks_, poses_, prompt_prefixes_, angles_list = poser.sample_train(n_steps, device_glb)

    # Noise levels to sample from: the model's schedule without its first 30
    # and last 10 entries.
    ts = model.us[30:-10]

    fuse = EarlyLoopBreak(5)

    raster_settings = PointsRasterizationSettings(
        image_size=800,
        radius=0.02,
        points_per_pixel=10
    )
    calibration_value = 0.0

    with tqdm(total=n_steps) as pbar:
        for i, pose in enumerate(poses_):
            if fuse.on_break():
                break

            depth_map = render_depth_from_cloud(
                points, angles_list[i], raster_settings, device_glb, calibration_value
            )
            y, depth, ws = render_one_view(vox, aabb, H, W, Ks_[i], pose, return_w=True)

            p = f"{prompt_prefixes_[i]} {model.prompt}"
            score_conds = model.prompts_emb([p])
            score_conds['c'] = score_conds['c'].repeat(bs, 1, 1)
            score_conds['uc'] = score_conds['uc'].repeat(bs, 1, 1)

            opt.zero_grad()

            # The score-model pass is excluded from autograd; its output is
            # injected below as the upstream gradient of the rendered view.
            with torch.no_grad():
                sigmas = np.random.choice(ts, bs, replace=False)
                sigmas = sigmas.reshape(-1, 1, 1, 1)
                sigmas = torch.as_tensor(sigmas, device=model.device, dtype=torch.float32)

                noise = torch.randn(bs, *y.shape[1:], device=model.device)
                zs = y + sigmas * noise
                Ds = model.denoise(zs, sigmas, depth_map.unsqueeze(dim=0), **score_conds)

                if var_red:
                    grad = (Ds - y) / sigmas
                else:
                    grad = (Ds - zs) / sigmas
                grad = grad.mean(0, keepdim=True)

            # retain_graph: further losses below backpropagate through the
            # same render.
            y.backward(-grad, retain_graph=True)

            if depth_weight > 0:
                center_depth = depth[7:-7, 7:-7]
                # Border pixel count derived from the actual depth map instead
                # of the former hard-coded 64*64 - 50*50 (identical for a
                # 64x64 render, correct for any other render size).
                n_border = depth.numel() - center_depth.numel()
                border_depth_mean = (depth.sum() - center_depth.sum()) / n_border
                center_depth_mean = center_depth.mean()
                depth_diff = center_depth_mean - border_depth_mean
                depth_loss = - torch.log(depth_diff + 1e-12)
                depth_loss = depth_weight * depth_loss
                depth_loss.backward(retain_graph=True)

            emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
            emptiness_loss = emptiness_weight * emptiness_loss
            if emptiness_step * n_steps <= i:
                emptiness_loss *= emptiness_multiplier
            emptiness_loss.backward()

            opt.step()

            if every(pbar, percent=2):
                with torch.no_grad():
                    y = model.decode(y)
                    preview = torch_samps_to_imgs(y)[0] if is_gradio else None
                # Yield only after leaving no_grad: suspending a generator
                # inside the context would leave gradients disabled for
                # whatever code runs between two resumptions.
                if is_gradio:
                    yield preview, f"Progress: {pbar.n}/{pbar.total} \nAfter the generation is complete, the video results will be displayed below.", gr.update(value=None)

            pbar.update()
            pbar.set_description(p)

        out = evaluate(model, vox, poser)

        if is_gradio:
            yield gr.update(visible=True), f"Generation complete. Please check the video below. \nThe result files and logs are located at {exp_dir}", gr.update(value=out)
        else:
            yield None
def evaluate(score_model, vox, poser):
    """Render a test trajectory of ``vox`` and save it as a temporary video.

    Args:
        score_model: model whose ``decode`` turns a rendered view into the
            sample passed to ``torch_samps_to_imgs``.
        vox: voxel model; switched to eval mode and moved to ``device_glb``.
        poser: pose sampler; provides ``H``, ``W`` and ``sample_test(100)``
            (one shared intrinsics value plus a sequence of poses).

    Returns:
        Path of the written ``.mp4`` file (10 fps). The file is created with
        ``delete=False``, so removing it is up to the caller.
    """
    H, W = poser.H, poser.W
    vox.eval()
    K, poses = poser.sample_test(100)

    fuse = EarlyLoopBreak(5)
    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)

    frames = []
    # Pure inference: no backward pass follows, so autograd bookkeeping is
    # disabled, in line with how fuse_3d decodes its previews.
    with torch.no_grad():
        for pose in tqdm(poses):
            if fuse.on_break():
                break
            y, depth = render_one_view(vox, aabb, H, W, K, pose)
            y = score_model.decode(y)
            frames.append(torch_samps_to_imgs(y)[0])

    # Only a unique path is needed; close the handle right away so it is not
    # leaked and does not stay open while the video writer uses the file.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
        video_path = out_file.name

    # The context manager closes the writer even if appending a frame fails.
    with imageio.get_writer(video_path, fps=10) as writer:
        for frame in frames:
            writer.append_data(frame)

    return video_path
def render_one_view(vox, aabb, H, W, K, pose, return_w=False):
    """Render one ``H`` x ``W`` view of ``vox`` from camera ``K`` / ``pose``.

    Rays are generated for every pixel, intersected with the box ``aabb``,
    converted to tensors on ``vox.device`` and rendered as one bundle.

    Returns:
        ``(rgbs, depth)`` where ``rgbs`` is rearranged to ``1 x c x H x W`` and
        ``depth`` to ``H x W``; with ``return_w`` true the weights returned by
        ``render_ray_bundle`` are appended as a third element.

    Raises:
        AssertionError: if the number of rays differs from ``H * W``.
    """
    n_pixels = H * W
    origins, directions = rays_from_img(H, W, K, pose)
    origins, directions, near, far = scene_box_filter_(origins, directions, aabb)
    assert len(origins) == n_pixels, "for now all pixels must be in"

    origins, directions, near, far = as_torch_tsrs(
        vox.device, origins, directions, near, far
    )
    colors, depth_flat, weights = render_ray_bundle(vox, origins, directions, near, far)

    # Fold the flat per-ray outputs back into image layout.
    image = rearrange(colors, "(h w) c -> 1 c h w", h=H, w=W)
    depth_map = rearrange(depth_flat, "(h w) 1 -> h w", h=H, w=W)

    if not return_w:
        return image, depth_map
    return image, depth_map, weights
def scene_box_filter_(ro, rd, aabb):
    """Compute per-ray entry/exit distances for the box ``aabb``.

    Despite the name, no ray is dropped: ``ro`` and ``rd`` are returned
    unchanged together with the two distance arrays, each clamped to be
    non-negative.
    """
    # The first value returned by ray_box_intersect is not needed here.
    _, near, far = ray_box_intersect(ro, rd, aabb)
    # do not render what's behind the ray origin
    near = np.maximum(near, 0)
    far = np.maximum(far, 0)
    return ro, rd, near, far
def vis_routine(metric, y, depth, prompt, depth_map):
    """Store visualisation artifacts for one rendered view through ``metric``.

    Args:
        metric: object exposing ``put_artifact(name, suffix, text, save_fn)``,
            where ``save_fn`` receives the target file name.
        y: sample passed to ``nerf_vis`` and ``torch_samps_to_imgs``.
        depth: depth tensor; moved to CPU and saved as ``.npy``.
        prompt: text attached to the "img" and "PC_depth" artifacts.
        depth_map: optional point-cloud depth tensor; when ``None`` the
            "PC_depth" artifact is skipped.
    """
    pane = nerf_vis(y, depth, final_H=256)
    im = torch_samps_to_imgs(y)[0]
    depth = depth.cpu().numpy()

    metric.put_artifact("view", ".png", "", lambda fn: imwrite(fn, pane))
    metric.put_artifact("img", ".png", prompt, lambda fn: imwrite(fn, im))
    # Identity check: ``!= None`` on an array-like is an element-wise (or
    # otherwise overloaded) comparison, not a presence test.
    if depth_map is not None:
        metric.put_artifact("PC_depth", ".png", prompt, lambda fn: imwrite(fn, depth_map.cpu().squeeze()))
    metric.put_artifact("depth", ".npy", "", lambda fn: np.save(fn, depth))
# Script entry point: hand the config class to the project's ``dispatch`` helper.
if __name__ == "__main__":
    dispatch(SJC_3DFuse)