ironjr commited on
Commit
2b068ec
·
verified ·
1 Parent(s): 2ffb19d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -14
model.py CHANGED
@@ -1122,25 +1122,12 @@ class StreamMultiDiffusion(nn.Module):
1122
  else:
1123
  x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
1124
 
1125
- ns = []
1126
- c1, c2, c3 = 0, 0, 0
1127
- for n, p in self.unet.named_parameters():
1128
- if p.data.dtype == torch.float:
1129
- c1 += 1
1130
- ns.append(n)
1131
- elif p.data.dtype == torch.half:
1132
- c2 += 1
1133
- else:
1134
- c3 += 1
1135
- print(c1, c2, c3)
1136
- print(ns)
1137
  model_pred = self.unet(
1138
  x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
1139
  t_list, # (B,)
1140
  encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1141
  return_dict=False,
1142
  )[0] # (B, 4, h, w)
1143
- print('222222222222222', model_pred.dtype)
1144
 
1145
  if self.bootstrap_steps[0] > 0:
1146
  # Uncentering.
@@ -1151,6 +1138,7 @@ class StreamMultiDiffusion(nn.Module):
1151
  bootstrap_mask_ = torch.concat([bootstrap_mask, bootstrap_mask], dim=0)
1152
  else:
1153
  bootstrap_mask_ = bootstrap_mask
 
1154
  model_pred = shift_to_mask_bbox_center(model_pred, bootstrap_mask_)
1155
  x_t_latent = shift_to_mask_bbox_center(x_t_latent, bootstrap_mask)
1156
 
@@ -1235,7 +1223,7 @@ class StreamMultiDiffusion(nn.Module):
1235
  self.stock_noise_ = self.stock_noise.repeat_interleave(self.num_layers, dim=0) # (T * p, 77, 768)
1236
 
1237
  x_0_pred_batch = self.unet_step(latent)
1238
-
1239
  latent = x_0_pred_batch[-1:]
1240
  self.x_t_latent_buffer = (
1241
  self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]
 
1122
  else:
1123
  x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
1124
 
 
 
 
 
 
 
 
 
 
 
 
 
1125
  model_pred = self.unet(
1126
  x_t_latent_plus_uc.to(self.unet.dtype), # (B, 4, h, w)
1127
  t_list, # (B,)
1128
  encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1129
  return_dict=False,
1130
  )[0] # (B, 4, h, w)
 
1131
 
1132
  if self.bootstrap_steps[0] > 0:
1133
  # Uncentering.
 
1138
  bootstrap_mask_ = torch.concat([bootstrap_mask, bootstrap_mask], dim=0)
1139
  else:
1140
  bootstrap_mask_ = bootstrap_mask
1141
+ print('2222222222222222222222222222222222222', model_pred.shape, bootstrap_mask_)
1142
  model_pred = shift_to_mask_bbox_center(model_pred, bootstrap_mask_)
1143
  x_t_latent = shift_to_mask_bbox_center(x_t_latent, bootstrap_mask)
1144
 
 
1223
  self.stock_noise_ = self.stock_noise.repeat_interleave(self.num_layers, dim=0) # (T * p, 77, 768)
1224
 
1225
  x_0_pred_batch = self.unet_step(latent)
1226
+ print('111111111111111111111111111111111')
1227
  latent = x_0_pred_batch[-1:]
1228
  self.x_t_latent_buffer = (
1229
  self.alpha_prod_t_sqrt[1:] * x_0_pred_batch[:-1]