ironjr commited on
Commit
aad81b4
·
verified ·
1 Parent(s): 5540e7d

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -2
model.py CHANGED
@@ -382,7 +382,6 @@ class StreamMultiDiffusion(nn.Module):
382
  inputs = self.i2t_processor(image, question, return_tensors='pt')
383
  out = self.i2t_model.generate(**{k: v.to(self.i2t_model.device) for k, v in inputs.items()}, max_new_tokens=77)
384
  prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
385
- print(prompt)
386
  return prompt
387
 
388
  @torch.no_grad()
@@ -474,7 +473,6 @@ class StreamMultiDiffusion(nn.Module):
474
  if self.white is None:
475
  self.white = self.encode_imgs(torch.ones(1, 3, self.height, self.width, dtype=self.dtype, device=self.device))
476
  mix_ratio = self.bootstrap_mix_ratios[:, None, None, None]
477
- print(mix_ratio, mix_ratio.dtype, self.white.dtype, self.white.device, self.state['background'].latent.dtype, self.state['background'].latent.device)
478
  self.bootstrap_latent = mix_ratio * self.white + (1.0 - mix_ratio) * self.state['background'].latent
479
 
480
  self.ready_checklist['background_registered'] = True
@@ -1093,6 +1091,7 @@ class StreamMultiDiffusion(nn.Module):
1093
  ) -> Tuple[torch.Tensor, torch.Tensor]:
1094
  p = self.num_layers
1095
  x_t_latent = x_t_latent.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
 
1096
 
1097
  if self.bootstrap_steps[0] > 0:
1098
  # Background bootstrapping.
@@ -1101,6 +1100,8 @@ class StreamMultiDiffusion(nn.Module):
1101
  self.stock_noise,
1102
  torch.tensor(self.sub_timesteps_tensor, device=self.device),
1103
  )
 
 
1104
  x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
1105
  bootstrap_mask = (
1106
  self.masks * self.bootstrap_steps[None, :, None, None, None]
 
382
  inputs = self.i2t_processor(image, question, return_tensors='pt')
383
  out = self.i2t_model.generate(**{k: v.to(self.i2t_model.device) for k, v in inputs.items()}, max_new_tokens=77)
384
  prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
 
385
  return prompt
386
 
387
  @torch.no_grad()
 
473
  if self.white is None:
474
  self.white = self.encode_imgs(torch.ones(1, 3, self.height, self.width, dtype=self.dtype, device=self.device))
475
  mix_ratio = self.bootstrap_mix_ratios[:, None, None, None]
 
476
  self.bootstrap_latent = mix_ratio * self.white + (1.0 - mix_ratio) * self.state['background'].latent
477
 
478
  self.ready_checklist['background_registered'] = True
 
1091
  ) -> Tuple[torch.Tensor, torch.Tensor]:
1092
  p = self.num_layers
1093
  x_t_latent = x_t_latent.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
1094
+ print('111111111111111111111')
1095
 
1096
  if self.bootstrap_steps[0] > 0:
1097
  # Background bootstrapping.
 
1100
  self.stock_noise,
1101
  torch.tensor(self.sub_timesteps_tensor, device=self.device),
1102
  )
1103
+ print('111111111111111111111', bootstrap_steps)
1104
+
1105
  x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
1106
  bootstrap_mask = (
1107
  self.masks * self.bootstrap_steps[None, :, None, None, None]