ironjr commited on
Commit
a98f79f
·
verified ·
1 Parent(s): 4d7f709

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -0
model.py CHANGED
@@ -1121,12 +1121,14 @@ class StreamMultiDiffusion(nn.Module):
1121
  else:
1122
  x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
1123
 
 
1124
  model_pred = self.unet(
1125
  x_t_latent_plus_uc.to(self.dtype), # (B, 4, h, w)
1126
  t_list, # (B,)
1127
  encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1128
  return_dict=False,
1129
  )[0] # (B, 4, h, w)
 
1130
 
1131
  if self.bootstrap_steps[0] > 0:
1132
  # Uncentering.
 
1121
  else:
1122
  x_t_latent_plus_uc = x_t_latent # (T * p, 4, h, w)
1123
 
1124
+ print('1111111111111111111111', x_t_latent_plus_uc.dtype, self.unet.dtype, self.prompt_embeds.dtype)
1125
  model_pred = self.unet(
1126
  x_t_latent_plus_uc.to(self.dtype), # (B, 4, h, w)
1127
  t_list, # (B,)
1128
  encoder_hidden_states=self.prompt_embeds, # (B, 77, 768)
1129
  return_dict=False,
1130
  )[0] # (B, 4, h, w)
1131
+ print('222222222222222', model_pred.dtype)
1132
 
1133
  if self.bootstrap_steps[0] > 0:
1134
  # Uncentering.