# Spaces: Runtime error
import gc
import os
import time
from pathlib import Path

import gradio as gr
import imageio
import numpy as np
import torch

from my.utils import tqdm
from my.utils.seed import seed_everything
from run_img_sampling import SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig
from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
from run_sjc import render_one_view, tsr_stats
from highres_final_vis import highres_render_one_view
# Device the voxel grid is moved to before optimisation and rendering.
device_glb = torch.device("cuda")
def vis_routine(y, depth):
    """Build the visualisation artefacts for one rendered view.

    Args:
        y: rendered sample batch; passed to ``nerf_vis`` and
            ``torch_samps_to_imgs``.
        depth: depth tensor for the same view.

    Returns:
        A ``(pane, image, depth_array)`` tuple: the ``nerf_vis`` pane
        (requested with ``final_H=256``), the first image produced by
        ``torch_samps_to_imgs(y)``, and ``depth`` copied to a NumPy array on
        the CPU.
    """
    pane = nerf_vis(y, depth, final_H=256)
    first_image = torch_samps_to_imgs(y)[0]
    depth_array = depth.cpu().numpy()
    return pane, first_image, depth_array
# Custom stylesheet handed to gr.Blocks(css=...).
css = '''
    .instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
    .arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
    #component-4, #component-3, #component-10{min-height: 0}
    .duplicate-button img{margin: 0}
'''
with gr.Blocks(css=css) as demo:
    # title
    gr.Markdown('# [Score Jacobian Chaining](https://github.com/pals-ttic/sjc): Lifting Pretrained 2D Diffusion Models for 3D Generation')
    gr.HTML('''
    <div class="gr-prose" style="max-width: 80%">
    <h2>Attention - This Space takes over 30min to run!</h2>
    <p>If the Queue is too long you can run locally or duplicate the Space and run it on your own profile using a (paid) private T4 GPU for training. As each T4 costs US$0.60/h, it should cost < US$1 to train most models using default settings! <a style='display:inline-block' href='https://huggingface.co/spaces/MirageML/sjc?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a></p>
    </div>
    ''')

    # inputs
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
    iters = gr.Slider(label="Iters", minimum=100, maximum=20000, value=10000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # outputs
    image = gr.Image(label="image", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    def submit(prompt, iters, seed):
        """Optimise a voxel NeRF for ``prompt`` and stream progress to the UI.

        This is a generator used as the click handler of the "Generate"
        button. It first runs ``iters`` optimisation steps, yielding a preview
        after every step, then renders a turntable of test poses, yielding
        every frame, and finally writes the frames to an MP4 and yields it.

        Args:
            prompt: text prompt forwarded to the ``SD`` model config.
            iters: number of optimisation steps (value of the "Iters" slider).
            seed: seed handed to ``seed_everything`` (value of the "Seed"
                slider).

        Yields:
            Dicts keyed by the ``image``, ``video`` and ``logs`` components
            holding the update for each of them.
        """
        start_t = time.time()

        # Slider values may arrive as floats depending on the Gradio version;
        # range() below requires an int, so normalise both up front.
        n_steps = int(iters)
        seed_everything(int(seed))

        # cfgs = {'gddpm': {'model': 'm_lsun_256', 'lsun_cat': 'bedroom', 'imgnet_cat': -1}, 'sd': {'variant': 'v1', 'v2_highres': False, 'prompt': 'A high quality photo of a delicious burger', 'scale': 100.0, 'precision': 'autocast'}, 'lr': 0.05, 'n_steps': 10000, 'emptiness_scale': 10, 'emptiness_weight': 10000, 'emptiness_step': 0.5, 'emptiness_multiplier': 20.0, 'depth_weight': 0, 'var_red': True}
        poser = PoseConfig(rend_hw=64, FoV=60.0, R=1.5).make()
        model = SD(
            variant='v1', v2_highres=False, prompt=prompt,
            scale=100.0, precision='autocast',
        ).make()
        vox = VoxConfig(
            model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
            blend_bg_texture=True, bg_texture_hw=4,
            bbox_len=1.0,
        ).make()

        # Fixed hyper-parameters (same values as the commented-out cfgs above).
        lr = 0.05
        emptiness_scale = 10
        emptiness_weight = 10000
        emptiness_step = 0.5
        emptiness_multiplier = 20.0
        depth_weight = 0
        var_red = True

        assert model.samps_centered()
        _, target_H, target_W = model.data_shape()
        bs = 1
        aabb = vox.aabb.T.cpu().numpy()
        vox = vox.to(device_glb)
        opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

        H, W = poser.H, poser.W
        Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

        # Candidate noise levels: model.us with the first 30 and last 10
        # entries dropped.
        ts = model.us[30:-10]

        # NOTE(review): never read afterwards, but the draw advances the RNG,
        # so it is kept to leave results for a given seed unchanged.
        same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)

        is_sd = isinstance(model, StableDiffusion)

        # ---- stage 1: optimise the voxel grid -------------------------------
        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                p = f"{prompt_prefixes[i]} {model.prompt}"
                score_conds = model.prompts_emb([p])

                y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)

                # Non-SD models get the render resized to the model's data
                # shape; for StableDiffusion the render is used as-is.
                if not is_sd:
                    y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

                opt.zero_grad()

                # The gradient w.r.t. the render is computed without autograd
                # and injected through y.backward() below.
                with torch.no_grad():
                    sigmas = np.random.choice(ts, bs, replace=False)
                    sigmas = sigmas.reshape(-1, 1, 1, 1)
                    sigmas = torch.as_tensor(sigmas, device=model.device, dtype=torch.float32)

                    noise = torch.randn(bs, *y.shape[1:], device=model.device)
                    zs = y + sigmas * noise
                    Ds = model.denoise(zs, sigmas, **score_conds)

                    if var_red:
                        grad = (Ds - y) / sigmas
                    else:
                        grad = (Ds - zs) / sigmas
                    grad = grad.mean(0, keepdim=True)

                # retain_graph: further backward passes follow in this step.
                y.backward(-grad, retain_graph=True)

                if depth_weight > 0:
                    # NOTE(review): the divisor hard-codes a 64x64 depth map
                    # with a 7-pixel border (50x50 centre); keep it in sync
                    # with rend_hw above.
                    center_depth = depth[7:-7, 7:-7]
                    border_depth_mean = (depth.sum() - center_depth.sum()) / (64*64-50*50)
                    center_depth_mean = center_depth.mean()
                    depth_diff = center_depth_mean - border_depth_mean
                    depth_loss = - torch.log(depth_diff + 1e-12)
                    depth_loss = depth_weight * depth_loss
                    depth_loss.backward(retain_graph=True)

                emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
                emptiness_loss = emptiness_weight * emptiness_loss
                # Once the emptiness_step fraction of the run has passed, the
                # emptiness loss is scaled up by emptiness_multiplier.
                if emptiness_step * n_steps <= i:
                    emptiness_loss *= emptiness_multiplier
                emptiness_loss.backward()

                opt.step()

                # Stream a preview of the current view to the UI on every step.
                with torch.no_grad():
                    if is_sd:
                        y = model.decode(y)
                    _, img, depth = vis_routine(y, depth)
                    yield {
                        image: gr.update(value=img, visible=True),
                        video: gr.update(visible=False),
                        logs: f"Steps: {i}/{n_steps}: \n" + str(tsr_stats(y)),
                    }

                # TODO: Output pane, img and depth to Gradio

                pbar.update()
                pbar.set_description(p)

        # TODO: Save Checkpoint

        # ---- stage 2: render the final turntable ----------------------------
        with torch.no_grad():
            n_frames = 200
            factor = 4
            H, W = poser.H, poser.W
            vox.eval()
            K, poses = poser.sample_test(n_frames)
            poses = poses[60:]  # skip the full overhead view; not interesting
            aabb = vox.aabb.T.cpu().numpy()
            vox = vox.to(device_glb)

            all_images = []
            for frame_pose in tqdm(poses):
                y, depth = highres_render_one_view(vox, aabb, H, W, K, frame_pose, f=factor)
                if is_sd:
                    y = model.decode(y)
                _, img, depth = vis_routine(y, depth)

                # Collect the frame for the video and show it right away.
                all_images.append(img)
                yield {
                    image: gr.update(value=img, visible=True),
                    video: gr.update(visible=False),
                    logs: str(tsr_stats(y)),
                }

        output_video = "/tmp/tmp.mp4"
        imageio.mimwrite(output_video, all_images, quality=8, fps=10)

        end_t = time.time()

        # Swap the still image for the finished video.
        yield {
            image: gr.update(value=img, visible=False),
            video: gr.update(value=output_video, visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }

    button.click(
        submit,
        [prompt, iters, seed],
        [image, video, logs]
    )
# concurrency_count=1: only ONE job may run at a time, otherwise the GPU
# will run out of memory (OOM).
demo.queue(concurrency_count=1)
demo.launch()