Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -1122,25 +1122,12 @@ class StreamMultiDiffusion(nn.Module):
|
|
1122 |
else:
|
1123 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
1124 |
|
1125 |
-
ns = []
|
1126 |
-
c1, c2, c3 = 0, 0, 0
|
1127 |
-
for n, p in self.unet.named_parameters():
|
1128 |
-
if p.data.dtype == torch.float:
|
1129 |
-
c1 += 1
|
1130 |
-
ns.append(n)
|
1131 |
-
elif p.data.dtype == torch.half:
|
1132 |
-
c2 += 1
|
1133 |
-
else:
|
1134 |
-
c3 += 1
|
1135 |
-
print(c1, c2, c3)
|
1136 |
-
print(ns)
|
1137 |
model_pred = self.unet(
|
1138 |
x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
|
1139 |
t_list, # (B,)
|
1140 |
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
1141 |
return_dict=False,
|
1142 |
)[0] # (B, 4, h, w)
|
1143 |
-
print('222222222222222', model_pred.dtype)
|
1144 |
|
1145 |
if self.bootstrap_steps[0] > 0:
|
1146 |
# Uncentering.
|
@@ -1151,6 +1138,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
1151 |
bootstrap_mask_ = torch.concat([bootstrap_mask, bootstrap_mask], dim=0)
|
1152 |
else:
|
1153 |
bootstrap_mask_ = bootstrap_mask
|
|
|
1154 |
model_pred = shift_to_mask_bbox_center(model_pred, bootstrap_mask_)
|
1155 |
x_t_latent = shift_to_mask_bbox_center(x_t_latent, bootstrap_mask)
|
1156 |
|
@@ -1235,7 +1223,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
1235 |
self.stock_noise_ = self.stock_noise.repeat_interleave(self.num_layers, dim=0) # (T * p, 77, 768)
|
1236 |
|
1237 |
x_0_pred_batch = self.unet_step(latent)
|
1238 |
-
|
1239 |
latent = x_0_pred_batch[-1:]
|
1240 |
self.x_t_latent_buffer = (
|
1241 |
self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
|
|
|
1122 |
else:
|
1123 |
x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
|
1124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1125 |
model_pred = self.unet(
|
1126 |
x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
|
1127 |
t_list, # (B,)
|
1128 |
encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
|
1129 |
return_dict=False,
|
1130 |
)[0] # (B, 4, h, w)
|
|
|
1131 |
|
1132 |
if self.bootstrap_steps[0] > 0:
|
1133 |
# Uncentering.
|
|
|
1138 |
bootstrap_mask_ = torch.concat([bootstrap_mask, bootstrap_mask], dim=0)
|
1139 |
else:
|
1140 |
bootstrap_mask_ = bootstrap_mask
|
1141 |
+
print('2222222222222222222222222222222222222', model_pred.shape, bootstrap_mask_)
|
1142 |
model_pred = shift_to_mask_bbox_center(model_pred, bootstrap_mask_)
|
1143 |
x_t_latent = shift_to_mask_bbox_center(x_t_latent, bootstrap_mask)
|
1144 |
|
|
|
1223 |
self.stock_noise_ = self.stock_noise.repeat_interleave(self.num_layers, dim=0) # (T * p, 77, 768)
|
1224 |
|
1225 |
x_0_pred_batch = self.unet_step(latent)
|
1226 |
+
print('111111111111111111111111111111111')
|
1227 |
latent = x_0_pred_batch[-1:]
|
1228 |
self.x_t_latent_buffer = (
|
1229 |
self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
|