# dalle-6D / app.py
# neverix
# A new hope
# cd5957e
import os
os.environ["PYOPENGL_PLATFORM"] = "egl"
from tqdm.auto import trange
from PIL import Image
import gradio as gr
import numpy as np
import pyrender
import trimesh
import scipy
import torch
import cv2
class MidasDepth(object):
    """Depth estimator wrapping a MiDaS model fetched through ``torch.hub``."""

    def __init__(self, model_type="DPT_Large", device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        """Load the model and the input transform from ``intel-isl/MiDaS``.

        Args:
            model_type: Hub entry point name of the model to load.
            device: Device the model and input batches are moved to. Defaults
                to CUDA when available, otherwise CPU.
        """
        self.device = device
        # Inference only: eval mode, gradients disabled on all parameters.
        self.midas = torch.hub.load("intel-isl/MiDaS", model_type).to(self.device).eval().requires_grad_(False)
        # NOTE(review): ``dpt_transform`` is used whatever ``model_type`` is;
        # confirm this is appropriate if a non-DPT model is ever requested.
        self.transform = torch.hub.load("intel-isl/MiDaS", "transforms").dpt_transform

    def get_depth(self, image):
        """Run the model on ``image`` and return its prediction at image size.

        Args:
            image: Array-like image, either ``(H, W)`` grayscale or
                ``(H, W, C)``; only the first three channels are used.

        Returns:
            A 2-D ``numpy`` array of shape ``(H, W)`` holding the model
            prediction, bicubically resized to the input resolution.
        """
        if not isinstance(image, np.ndarray):
            image = np.asarray(image)
        if image.ndim == 2:
            # Grayscale: replicate into three channels so the channel slice
            # below selects channels rather than pixel columns.
            image = np.stack((image,) * 3, axis=-1)
        if (image > 1).any():
            # Values above 1 are treated as 0-255 data and scaled to 0-1.
            # NOTE(review): the upstream MiDaS hub transforms appear to divide
            # by 255 themselves -- confirm this pre-scaling is intended.
            image = image.astype("float64") / 255.
        height, width = image.shape[:2]
        with torch.inference_mode():
            batch = self.transform(image[..., :3]).to(self.device)
            prediction = self.midas(batch)
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=(height, width),
                mode="bicubic",
                align_corners=False,
            )
            # Index batch and channel explicitly: a bare ``squeeze()`` would
            # also drop a spatial axis when the image is one pixel high/wide.
            return prediction[0, 0].detach().cpu().numpy()
def process_depth(dep):
    """Normalise a raw depth prediction and flag depth discontinuities.

    The input is shifted and scaled to ``[0, 1]``, clipped to ``[0.2, 1]`` and
    inverted, so the returned depth lies in ``[1, 5]``. Pixels where the local
    (3x3) range of the median-blurred depth exceeds a threshold are marked as
    edges.

    Args:
        dep: 2-D array-like of raw depth predictions. It is not modified.

    Returns:
        A ``(depth, pick_edges)`` tuple: the processed float32 depth map and a
        boolean mask of the same shape that is ``True`` on detected edges.
    """
    # Work on a float32 copy: the in-place maths below needs a float dtype and
    # ``cv2.medianBlur`` with a 5x5 kernel rejects float64 input.
    depth = np.array(dep, dtype=np.float32)
    depth -= depth.min()
    peak = depth.max()
    if peak > 0:
        # A perfectly flat map has zero range; skipping the division avoids
        # filling the result with NaNs (it then ends up at the far value).
        depth /= peak
    depth = 1 / np.clip(depth, 0.2, 1)
    blurred = cv2.medianBlur(depth, 5)  # 9 not available because it requires 8-bit
    kernel = np.ones((3, 3))
    maxd = cv2.dilate(blurred, kernel)
    mind = cv2.erode(blurred, kernel)
    edges = maxd - mind
    threshold = .05  # Better to have false positives
    pick_edges = edges > threshold
    return depth, pick_edges
def _grid_faces(height, width, pick_edges):
    """Return the (N, 3) vertex-index triangles of a pixel grid, skipping edge pixels."""
    index = np.arange(height * width).reshape(height, width)
    keep = ~np.asarray(pick_edges, dtype=bool)[:height, :width]
    # Slot 0 holds the triangle towards the upper/left neighbours, slot 1 the
    # one towards the lower/right neighbours of each pixel.
    triangles = np.zeros((height, width, 2, 3), dtype=np.int64)
    valid = np.zeros((height, width, 2), dtype=bool)
    # (x, y), (x, y - 1), (x - 1, y) for pixels with x > 0 and y > 0.
    triangles[1:, 1:, 0] = np.stack(
        (index[1:, 1:], index[:-1, 1:], index[1:, :-1]), axis=-1)
    valid[1:, 1:, 0] = keep[1:, 1:]
    # (x, y), (x, y + 1), (x + 1, y) for pixels not on the last row/column.
    triangles[:-1, :-1, 1] = np.stack(
        (index[:-1, :-1], index[1:, :-1], index[:-1, 1:]), axis=-1)
    valid[:-1, :-1, 1] = keep[:-1, :-1]
    # Boolean selection walks the mask in row-major order (row, column, slot),
    # i.e. the same order a nested per-pixel loop would emit the faces in.
    return triangles[valid]


def make_mesh(pic, depth, pick_edges):
    """Build a coloured triangle mesh from an image and its depth map.

    Every pixel becomes a vertex: its image coordinates are centred,
    normalised by the image width, multiplied by the pixel depth, and the
    vertex is placed at ``z = -depth``. Neighbouring pixels are joined by
    triangles, except that pixels flagged in ``pick_edges`` emit no triangles.
    Each face takes the colour of its first vertex and is fully opaque.

    Args:
        pic: Image, array-like of shape ``(H, W, C)`` with ``C >= 3``; only
            the first three channels are used as colour.
        depth: Array with ``H * W`` depth values in row-major order.
        pick_edges: Boolean ``(H, W)`` mask of pixels to skip.

    Returns:
        A ``trimesh.Trimesh`` with per-face RGBA colours.
    """
    im = np.asarray(pic)
    height, width = im.shape[:2]

    rows, cols = np.mgrid[0:height, 0:width]
    # Centre x on the width and y on the height (both scaled by the width so
    # pixels stay square); mixing the two up shifts non-square images.
    x = (cols.reshape(-1) - width / 2) / width * 2
    y = (rows.reshape(-1) - height / 2) / width * 2
    z = np.asarray(depth).reshape(-1)
    # Flip y so image rows grow downwards while mesh y grows upwards, and
    # negate z so the surface lies along the negative z axis.
    vertices = np.stack((x * z, -(y * z), -z), axis=-1)

    # Drop any alpha channel so each pixel contributes exactly one RGB row.
    colors = im[..., :3].reshape(-1, 3)
    faces = _grid_faces(height, width, pick_edges)
    rgb = colors[faces[:, 0]]
    alpha = np.full((len(faces), 1), 255, dtype=rgb.dtype)

    tri_mesh = trimesh.Trimesh(vertices=vertices,
                               faces=faces,
                               face_colors=np.concatenate((rgb, alpha), axis=-1),
                               smooth=False,
                               )
    return tri_mesh
def args_to_mat(tx, ty, tz, rx, ry, rz):
    """Build a 4x4 homogeneous transform from a translation and Euler angles.

    Args:
        tx, ty, tz: Translation components, written to the last column.
        rx, ry, rz: Euler angles in radians, applied as intrinsic rotations
            in ``"XYZ"`` order.

    Returns:
        A ``(4, 4)`` float ``numpy`` array with the rotation in the upper-left
        3x3 block and the translation in the upper-right column.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.spatial.transform`` is reachable as an attribute.
    from scipy.spatial.transform import Rotation

    mat = np.eye(4)
    mat[:3, :3] = Rotation.from_euler("XYZ", (rx, ry, rz)).as_matrix()
    mat[:3, 3] = tx, ty, tz
    return mat
def render(mesh, mat, size=1024):
    """Render ``mesh`` from camera pose ``mat`` into a square RGBA image.

    The scene uses full white ambient light, flat shading and a perspective
    camera with a 90 degree vertical field of view and aspect ratio 1.
    Pixels whose rendered depth is 0 are made black and fully transparent;
    all other pixels are opaque.

    Args:
        mesh: ``trimesh`` mesh to render.
        mat: 4x4 camera pose matrix.
        size: Width and height of the output in pixels.

    Returns:
        A ``PIL.Image.Image`` built from a ``(size, size, 4)`` uint8 array.
    """
    scene = pyrender.Scene(ambient_light=np.array([1.0, 1.0, 1.0]))
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 2, aspectRatio=1.0)
    scene.add(camera, pose=mat)
    scene.add(pyrender.mesh.Mesh.from_trimesh(mesh, smooth=False))

    renderer = pyrender.OffscreenRenderer(size, size)
    try:
        rgb, d = renderer.render(scene, pyrender.constants.RenderFlags.FLAT)
    finally:
        # Release the renderer explicitly; otherwise every call leaves one
        # more offscreen renderer behind.
        renderer.delete()

    empty = d == 0
    rgb = rgb.copy()  # copy before writing into the rendered buffer
    rgb[empty] = 0
    alpha = (~empty)[..., np.newaxis].astype(np.uint8) * 255
    return Image.fromarray(np.concatenate((rgb, alpha), axis=-1))
def main():
    """Start a virtual display, load the depth model and serve the Gradio UI.

    The virtual display is stopped when the interface returns and also when
    model loading or the launch fails.
    """
    from pyvirtualdisplay import Display

    disp = Display()
    disp.start()
    try:
        midas = MidasDepth()

        def fn(pic, *args):
            """Warp ``pic`` in 3D; ``args`` are tx, ty, tz, rx, ry, rz."""
            depth, pick_edges = process_depth(midas.get_depth(pic))
            mesh = make_mesh(pic, depth, pick_edges)
            frame = render(mesh, args_to_mat(*args))
            # Second output is the depth map as an 8-bit image (255 / depth);
            # no video is produced, so the third output is None.
            return np.asarray(frame), (255 / np.asarray(depth)).astype(np.uint8), None

        interface = gr.Interface(fn=fn, inputs=[
            gr.inputs.Image(label="src", type="numpy"),
            gr.inputs.Number(label="tx", default=0.0),
            gr.inputs.Number(label="ty", default=0.0),
            gr.inputs.Number(label="tz", default=0.0),
            gr.inputs.Number(label="rx", default=0.0),
            gr.inputs.Number(label="ry", default=0.0),
            gr.inputs.Number(label="rz", default=0.0)
        ], outputs=[
            gr.outputs.Image(type="numpy", label="result"),
            gr.outputs.Image(type="numpy", label="depth"),
            gr.outputs.Video(label="interpolated")
        ], title="DALL·E 6D", description="Lift DALL·E 2 (or any other model) into 3D!")
        gr.TabbedInterface([interface], ["Warp 3D images"]).launch()
    finally:
        # Always tear the virtual display down, even if startup failed.
        disp.stop()
# Launch the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()