File size: 7,048 Bytes
4450790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188

import torch
import torch.nn.functional as F
from torchvision import transforms
import os
from contextlib import nullcontext

import comfy.model_management as mm
from comfy.utils import ProgressBar, load_torch_file
import folder_paths

from .depth_anything_v2.dpt import DepthAnythingV2

from contextlib import nullcontext
# Optional dependency: with `accelerate` installed, the model is created with
# empty (meta) weights and each checkpoint tensor is materialised directly on
# the target device/dtype; otherwise a plain `load_state_dict` is used.
# `except Exception` (not a bare `except`) keeps the best-effort fallback for
# missing or broken installs without swallowing KeyboardInterrupt/SystemExit.
try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device
    is_accelerate_available = True
except Exception:
    is_accelerate_available = False

class DownloadAndLoadDepthAnythingV2Model:
    """ComfyUI node that downloads (if missing) and loads a Depth Anything V2 checkpoint.

    The built model is cached on the node instance and reused for as long as
    the same checkpoint file name is requested.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": (
                    [
                        'depth_anything_v2_vits_fp16.safetensors',
                        'depth_anything_v2_vits_fp32.safetensors',
                        'depth_anything_v2_vitb_fp16.safetensors',
                        'depth_anything_v2_vitb_fp32.safetensors',
                        'depth_anything_v2_vitl_fp16.safetensors',
                        'depth_anything_v2_vitl_fp32.safetensors',
                        'depth_anything_v2_metric_hypersim_vitl_fp32.safetensors',
                        'depth_anything_v2_metric_vkitti_vitl_fp32.safetensors'
                    ],
                    {
                    "default": 'depth_anything_v2_vitl_fp32.safetensors'
                    }),
            },
        }

    RETURN_TYPES = ("DAMODEL",)
    RETURN_NAMES = ("da_v2_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "DepthAnythingV2"
    # Raw string: the Windows-style path contains backslashes that are not
    # valid escape sequences (SyntaxWarning on Python 3.12+). The value is
    # unchanged.
    DESCRIPTION = r"""
Models autodownload to `ComfyUI\models\depthanything` from   
https://huggingface.co/Kijai/DepthAnythingV2-safetensors/tree/main   
   
fp16 reduces quality by a LOT, not recommended.
"""

    def loadmodel(self, model):
        """Return ``(da_model,)`` for the checkpoint file name ``model``.

        ``da_model`` is a dict with keys:
          - ``"model"``: the DepthAnythingV2 module, in eval mode;
          - ``"dtype"``: ``torch.float16`` if the file name contains "fp16",
            otherwise ``torch.float32``;
          - ``"is_metric"``: the module's ``is_metric`` attribute.

        The module is cached on ``self`` and rebuilt only when a different
        checkpoint name is requested (or nothing is cached yet).

        Raises:
            ValueError: if the file name contains none of "vitl", "vitb", "vits".
        """
        dtype = torch.float16 if "fp16" in model else torch.float32
        custom_config = {
            'model_name': model,
        }
        if getattr(self, 'model', None) is None or custom_config != getattr(self, 'current_config', None):
            # Build into a local first and only then update the cache, so a
            # failed download/load cannot leave a half-initialised model cached
            # under this checkpoint name.
            net = self._build_model(model, dtype)
            self.model = net
            self.current_config = custom_config

        # Assembled on every call (not only after a fresh load): previously a
        # cache hit reached `return` with `da_model` unbound.
        da_model = {
            "model": self.model,
            "dtype": dtype,
            "is_metric": self.model.is_metric
        }
        return (da_model,)

    def _build_model(self, model, dtype):
        """Download (if needed) and instantiate checkpoint ``model``; return the module in eval mode."""
        device = mm.get_torch_device()
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        }
        download_path = os.path.join(folder_paths.models_dir, "depthanything")
        model_path = os.path.join(download_path, model)

        if not os.path.exists(model_path):
            print(f"Downloading model to: {model_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id="Kijai/DepthAnythingV2-safetensors",
                              allow_patterns=[f"*{model}*"],
                              local_dir=download_path,
                              local_dir_use_symlinks=False)

        print(f"Loading model from: {model_path}")

        # Same precedence as the name check order: vitl, then vitb, then vits.
        encoder = next((name for name in ("vitl", "vitb", "vits") if name in model), None)
        if encoder is None:
            raise ValueError(f"Cannot determine encoder (vits/vitb/vitl) from model name: {model}")

        # max_depth is only passed to metric checkpoints; hypersim uses 20.0,
        # every other checkpoint 80.0.
        max_depth = 20.0 if "hypersim" in model else 80.0

        with (init_empty_weights() if is_accelerate_available else nullcontext()):
            if 'metric' in model:
                net = DepthAnythingV2(**{**model_configs[encoder], 'is_metric': True, 'max_depth': max_depth})
            else:
                net = DepthAnythingV2(**model_configs[encoder])

        state_dict = load_torch_file(model_path)
        if is_accelerate_available:
            # Parameters were created empty; place each checkpoint tensor
            # directly on `device` in `dtype`.
            for key in state_dict:
                set_module_tensor_to_device(net, key, device=device, dtype=dtype, value=state_dict[key])
        else:
            # NOTE(review): unlike the accelerate path, this neither casts the
            # module to `dtype` nor moves it to `device` — confirm intended.
            net.load_state_dict(state_dict)

        net.eval()
        return net
    
class DepthAnything_V2:
    """ComfyUI node that runs a loaded Depth Anything V2 model over a batch of images.

    The output is a 3-channel IMAGE whose values are depth rescaled to [0, 1]
    over the whole batch (inverted when the model is a metric model).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "da_model": ("DAMODEL", ),
            "images": ("IMAGE", ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "process"
    CATEGORY = "DepthAnythingV2"
    DESCRIPTION = """
https://depth-anything-v2.github.io
"""

    @staticmethod
    def _min_max_normalize(t):
        """Linearly rescale ``t`` to [0, 1]; a constant tensor maps to zeros instead of NaN."""
        t_min = t.min()
        t_range = t.max() - t_min
        if t_range <= 0:
            return torch.zeros_like(t)
        return (t - t_min) / t_range

    def process(self, da_model, images):
        """Estimate depth for ``images`` with ``da_model``; return ``(depth_image,)``.

        ``images`` is indexed as (B, H, W, C) (see the unpacking below); only
        the first three channels are used. Each frame is resized down so H and
        W are multiples of 14 before inference, and the result is resized to
        the original size rounded down to even dimensions.

        Raises:
            ValueError: if either image side is smaller than 14 pixels.
        """
        device = mm.get_torch_device()
        offload_device = mm.unet_offload_device()
        model = da_model['model']
        dtype = da_model['dtype']

        B, H, W, C = images.shape
        # Normalize below uses 3-channel mean/std; drop any extra (e.g. alpha) channel.
        images = images[..., :3].permute(0, 3, 1, 2)

        # Snap spatial size down to a multiple of 14, as the original code does.
        orig_H, orig_W = H, W
        H = H - (H % 14)
        W = W - (W % 14)
        if H == 0 or W == 0:
            raise ValueError(f"Input images must be at least 14x14 pixels, got {orig_W}x{orig_H}")
        if (H, W) != (orig_H, orig_W):
            images = F.interpolate(images, size=(H, W), mode="bilinear")

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        normalized_images = normalize(images)
        pbar = ProgressBar(B)
        out = []
        autocast_condition = (dtype != torch.float32) and not mm.is_device_mps(device)
        autocast_ctx = torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext()
        try:
            model.to(device)
            with autocast_ctx:
                for img in normalized_images:
                    depth = model(img.unsqueeze(0).to(device))
                    out.append(self._min_max_normalize(depth).cpu())
                    pbar.update(1)
        finally:
            # Always release the device, even if inference fails part-way.
            model.to(offload_device)

        depth_out = torch.cat(out, dim=0)
        depth_out = depth_out.unsqueeze(-1).repeat(1, 1, 1, 3).float()

        final_H = (orig_H // 2) * 2
        final_W = (orig_W // 2) * 2
        if depth_out.shape[1] != final_H or depth_out.shape[2] != final_W:
            depth_out = F.interpolate(depth_out.permute(0, 3, 1, 2), size=(final_H, final_W), mode="bilinear").permute(0, 2, 3, 1)
        depth_out = torch.clamp(self._min_max_normalize(depth_out), 0, 1)
        if da_model['is_metric']:
            depth_out = 1 - depth_out
        return (depth_out,)
    
# Node registration tables exported by this module (presumably read by the
# ComfyUI custom-node loader): internal node name -> implementing class, and
# internal node name -> label shown in the UI.
NODE_CLASS_MAPPINGS = dict(
    DepthAnything_V2=DepthAnything_V2,
    DownloadAndLoadDepthAnythingV2Model=DownloadAndLoadDepthAnythingV2Model,
)
NODE_DISPLAY_NAME_MAPPINGS = dict(
    DepthAnything_V2="Depth Anything V2",
    DownloadAndLoadDepthAnythingV2Model="DownloadAndLoadDepthAnythingV2Model",
)