Spaces:
Running
Running
Realcat
committed on
Commit
·
c6c23ae
1
Parent(s):
6534305
fix: cpu roma
Browse files
third_party/Roma/roma/models/encoders.py
CHANGED
@@ -24,7 +24,10 @@ class ResNet50(nn.Module):
|
|
24 |
self.freeze_bn = freeze_bn
|
25 |
self.early_exit = early_exit
|
26 |
self.amp = amp
|
27 |
-
|
|
|
|
|
|
|
28 |
|
29 |
def forward(self, x, **kwargs):
|
30 |
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
@@ -60,7 +63,10 @@ class VGG19(nn.Module):
|
|
60 |
super().__init__()
|
61 |
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
|
62 |
self.amp = amp
|
63 |
-
|
|
|
|
|
|
|
64 |
|
65 |
def forward(self, x, **kwargs):
|
66 |
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
@@ -94,7 +100,10 @@ class CNNandDinov2(nn.Module):
|
|
94 |
else:
|
95 |
self.cnn = VGG19(**cnn_kwargs)
|
96 |
self.amp = amp
|
97 |
-
|
|
|
|
|
|
|
98 |
if self.amp:
|
99 |
dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
|
100 |
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
|
|
|
24 |
self.freeze_bn = freeze_bn
|
25 |
self.early_exit = early_exit
|
26 |
self.amp = amp
|
27 |
+
if not torch.cuda.is_available():
|
28 |
+
self.amp_dtype = torch.float32
|
29 |
+
else:
|
30 |
+
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
31 |
|
32 |
def forward(self, x, **kwargs):
|
33 |
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
|
|
63 |
super().__init__()
|
64 |
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
|
65 |
self.amp = amp
|
66 |
+
if not torch.cuda.is_available():
|
67 |
+
self.amp_dtype = torch.float32
|
68 |
+
else:
|
69 |
+
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
70 |
|
71 |
def forward(self, x, **kwargs):
|
72 |
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
|
|
100 |
else:
|
101 |
self.cnn = VGG19(**cnn_kwargs)
|
102 |
self.amp = amp
|
103 |
+
if not torch.cuda.is_available():
|
104 |
+
self.amp_dtype = torch.float32
|
105 |
+
else:
|
106 |
+
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
107 |
if self.amp:
|
108 |
dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
|
109 |
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
|
third_party/Roma/roma/models/matcher.py
CHANGED
@@ -71,8 +71,12 @@ class ConvRefiner(nn.Module):
|
|
71 |
self.disable_local_corr_grad = disable_local_corr_grad
|
72 |
self.is_classifier = is_classifier
|
73 |
self.sample_mode = sample_mode
|
74 |
-
self.
|
75 |
-
|
|
|
|
|
|
|
|
|
76 |
def create_block(
|
77 |
self,
|
78 |
in_dim,
|
@@ -109,8 +113,8 @@ class ConvRefiner(nn.Module):
|
|
109 |
if self.has_displacement_emb:
|
110 |
im_A_coords = torch.meshgrid(
|
111 |
(
|
112 |
-
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=
|
113 |
-
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=
|
114 |
)
|
115 |
)
|
116 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
@@ -296,8 +300,11 @@ class Decoder(nn.Module):
|
|
296 |
self.displacement_dropout_p = displacement_dropout_p
|
297 |
self.gm_warp_dropout_p = gm_warp_dropout_p
|
298 |
self.flow_upsample_mode = flow_upsample_mode
|
299 |
-
|
300 |
-
|
|
|
|
|
|
|
301 |
def get_placeholder_flow(self, b, h, w, device):
|
302 |
coarse_coords = torch.meshgrid(
|
303 |
(
|
@@ -615,8 +622,8 @@ class RegressionMatcher(nn.Module):
|
|
615 |
# Create im_A meshgrid
|
616 |
im_A_coords = torch.meshgrid(
|
617 |
(
|
618 |
-
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=
|
619 |
-
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=
|
620 |
)
|
621 |
)
|
622 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
71 |
self.disable_local_corr_grad = disable_local_corr_grad
|
72 |
self.is_classifier = is_classifier
|
73 |
self.sample_mode = sample_mode
|
74 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
75 |
+
if not torch.cuda.is_available():
|
76 |
+
self.amp_dtype = torch.float32
|
77 |
+
else:
|
78 |
+
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
79 |
+
|
80 |
def create_block(
|
81 |
self,
|
82 |
in_dim,
|
|
|
113 |
if self.has_displacement_emb:
|
114 |
im_A_coords = torch.meshgrid(
|
115 |
(
|
116 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=self.device),
|
117 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=self.device),
|
118 |
)
|
119 |
)
|
120 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
300 |
self.displacement_dropout_p = displacement_dropout_p
|
301 |
self.gm_warp_dropout_p = gm_warp_dropout_p
|
302 |
self.flow_upsample_mode = flow_upsample_mode
|
303 |
+
if not torch.cuda.is_available():
|
304 |
+
self.amp_dtype = torch.float32
|
305 |
+
else:
|
306 |
+
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
307 |
+
|
308 |
def get_placeholder_flow(self, b, h, w, device):
|
309 |
coarse_coords = torch.meshgrid(
|
310 |
(
|
|
|
622 |
# Create im_A meshgrid
|
623 |
im_A_coords = torch.meshgrid(
|
624 |
(
|
625 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
626 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
627 |
)
|
628 |
)
|
629 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|