Initial commit
app.py
ADDED
@@ -0,0 +1,22 @@
+from depth import MidasDepth
+import gradio as gr
+
+
+depth_estimator = MidasDepth()
+
+
+def get_depth(rgb):
+    depth = depth_estimator.get_depth(rgb)
+    # Pack depth into a 16-bit image: clip to [0, 64] and scale by 1024
+    # so the values spread across the uint16 range.
+    return rgb, (depth.clip(0, 64) * 1024).astype("uint16")
+
+
+starter = gr.Interface(fn=get_depth, inputs=[
+    gr.components.Image(label="rgb", type="pil"),
+], outputs=[
+    gr.components.Image(type="pil", label="image"),
+    gr.components.Image(type="numpy", label="depth"),
+])
+
+starter.launch(share=True)
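One note on the depth output above: it is packed, not metric. A client that receives the uint16 "depth" image can undo the packing by reversing the scale factor. A minimal sketch, where depth_u16 is a stand-in name for the array the interface returns:

import numpy as np

depth_u16 = np.zeros((480, 640), dtype=np.uint16)  # stand-in for the returned image
depth = depth_u16.astype(np.float32) / 1024.0      # undo the *1024 packing
# Values that exceeded 64 before encoding were clipped and cannot be recovered.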
depth.py
ADDED
@@ -0,0 +1,56 @@
+from PIL import Image
+from torch import nn
+import numpy as np
+import torch
+
+
+class MidasDepth(nn.Module):
+    def __init__(self, model_type="DPT_Large",
+                 device=torch.device(
+                     "cuda" if torch.cuda.is_available() else "cpu"),
+                 is_inpainting=False):
+        super().__init__()
+        self.device = device
+        # The torch.hub MiDaS models are unreliable on Apple's MPS
+        # backend, so fall back to CPU there.
+        if self.device.type == "mps":
+            self.device = torch.device("cpu")
+        self.model = torch.hub.load(
+            "intel-isl/MiDaS", model_type).to(self.device).eval().requires_grad_(False)
+        self.transform = torch.hub.load(
+            "intel-isl/MiDaS", "transforms").dpt_transform
+
+    @torch.no_grad()
+    def forward(self, image):
+        if torch.is_tensor(image):
+            image = image.cpu().detach()
+        if not isinstance(image, np.ndarray):
+            image = np.asarray(image)
+        image = image.squeeze()
+        batch = self.transform(image).to(self.device)
+        prediction = self.model(batch)
+        # Upsample the network output back to the input resolution.
+        prediction = torch.nn.functional.interpolate(
+            prediction.unsqueeze(1),
+            size=image.shape[-3:-1],
+            mode="bicubic",
+            align_corners=False,
+        )[:, 0]
+        return prediction
+
+    @torch.no_grad()
+    def get_depth(self, img):
+        im = torch.from_numpy(np.asarray(img)).float().to(self.device)
+        og_depth = self(im.unsqueeze(0))[0]
+        # MiDaS predicts relative inverse depth; rescale it to [3, 10]
+        # and invert so near pixels map to small depth values.
+        d = og_depth
+        d = (d - d.min()) / (d.max() - d.min()) * (10 - 3) + 3
+        d = 30 / d
+        return d.detach().cpu().numpy()
+
+
+if __name__ == "__main__":
+    from matplotlib import pyplot as plt
+    plt.imshow(MidasDepth().get_depth(Image.open("horse.jpg")))
+    plt.show()
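Note that get_depth does not return metric depth: MiDaS predicts relative inverse depth, and the rescale-then-invert step pins every output into the fixed range [3, 10] regardless of the scene. A small self-contained check of that mapping, using made-up inverse-depth values:

import numpy as np

raw = np.array([0.0, 0.5, 1.0])  # stand-in relative inverse depth (far -> near)
d = (raw - raw.min()) / (raw.max() - raw.min()) * (10 - 3) + 3
print(30 / d)  # [10.    4.62  3.  ] -- farthest pixels get 10, nearest get 3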
horse.jpg
ADDED
serve_modal.py
ADDED
@@ -0,0 +1,9 @@
+import modal
+
+
+stub = modal.Stub()
+
+
+@stub.function()
+def estimate_depth(image):
+    pass
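serve_modal.py is only a stub so far; estimate_depth has no body. One way it might eventually be filled in, sketched against the older modal.Stub-era API this file uses. The container image, GPU request, package list, and the bytes-in/array-out signature are all assumptions, not part of the commit:

import io

import modal

stub = modal.Stub()

# Hypothetical container image; the commit pins no dependencies.
inference_image = modal.Image.debian_slim().pip_install(
    "torch", "timm", "opencv-python-headless", "pillow")


@stub.function(image=inference_image, gpu="any")
def estimate_depth(image_bytes):
    # Sketch only: decode the uploaded bytes, run MidasDepth from
    # depth.py (assumed to be mounted alongside this file), and
    # return the depth map as a numpy array.
    from PIL import Image
    from depth import MidasDepth

    rgb = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return MidasDepth().get_depth(rgb)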