Update model.py
model.py CHANGED
@@ -382,7 +382,6 @@ class StreamMultiDiffusion(nn.Module):
         inputs = self.i2t_processor(image, question, return_tensors='pt')
         out = self.i2t_model.generate(**{k: v.to(self.i2t_model.device) for k, v in inputs.items()}, max_new_tokens=77)
         prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
-        print(prompt)
         return prompt
 
     @torch.no_grad()
@@ -474,7 +473,6 @@ class StreamMultiDiffusion(nn.Module):
         if self.white is None:
             self.white = self.encode_imgs(torch.ones(1, 3, self.height, self.width, dtype=self.dtype, device=self.device))
         mix_ratio = self.bootstrap_mix_ratios[:, None, None, None]
-        print(mix_ratio, mix_ratio.dtype, self.white.dtype, self.white.device, self.state['background'].latent.dtype, self.state['background'].latent.device)
         self.bootstrap_latent = mix_ratio * self.white + (1.0 - mix_ratio) * self.state['background'].latent
 
         self.ready_checklist['background_registered'] = True
@@ -1093,6 +1091,7 @@ class StreamMultiDiffusion(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         p = self.num_layers
         x_t_latent = x_t_latent.repeat_interleave(p, dim=0)  # (T * p, 4, h, w)
+        print('111111111111111111111')
 
         if self.bootstrap_steps[0] > 0:
             # Background bootstrapping.
@@ -1101,6 +1100,8 @@ class StreamMultiDiffusion(nn.Module):
                 self.stock_noise,
                 torch.tensor(self.sub_timesteps_tensor, device=self.device),
             )
+            print('111111111111111111111', bootstrap_steps)
+
            x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
            bootstrap_mask = (
                self.masks * self.bootstrap_steps[None, :, None, None, None]
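For context on the first hunk: `i2t_processor` and `i2t_model` implement an image-to-text pass that turns the input image into a usable prompt, capped at 77 new tokens (the CLIP prompt length). A minimal standalone sketch of the same call pattern, assuming a BLIP captioner from `transformers`; the Space's actual i2t checkpoint is not named in this diff, so the model id below is purely illustrative:

```python
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# Illustrative checkpoint only; the diff does not say which i2t model the Space loads.
processor = BlipProcessor.from_pretrained('Salesforce/blip-image-captioning-base')
model = BlipForConditionalGeneration.from_pretrained('Salesforce/blip-image-captioning-base')

image = Image.new('RGB', (512, 512), 'white')  # stand-in for the user's image
question = 'a photography of'                  # stand-in text prefix, like `question` in model.py

inputs = processor(image, question, return_tensors='pt')
with torch.no_grad():
    out = model.generate(**{k: v.to(model.device) for k, v in inputs.items()}, max_new_tokens=77)
prompt = processor.decode(out[0], skip_special_tokens=True).strip()
print(prompt)
```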
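The surviving line of the second hunk is the background-bootstrapping blend: a per-layer mix ratio linearly interpolates between the latent of an all-white image and the registered background latent. A shape-level sketch of that broadcast, with illustrative sizes in place of the Space's real configuration:

```python
import torch

# Illustrative sizes: p prompt layers over a 4-channel, 64x64 latent grid.
p, c, h, w = 3, 4, 64, 64
bootstrap_mix_ratios = torch.tensor([1.0, 0.5, 0.0])  # one blend weight per layer
white = torch.randn(1, c, h, w)              # stands in for encode_imgs(all-white image)
background_latent = torch.randn(1, c, h, w)  # stands in for state['background'].latent

# Lift the per-layer ratios to (p, 1, 1, 1) so they broadcast over (c, h, w),
# exactly as mix_ratio = self.bootstrap_mix_ratios[:, None, None, None] does.
mix_ratio = bootstrap_mix_ratios[:, None, None, None]
bootstrap_latent = mix_ratio * white + (1.0 - mix_ratio) * background_latent
print(bootstrap_latent.shape)  # torch.Size([3, 4, 64, 64]): one latent per layer
```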
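The last two hunks sit in the denoising step, which duplicates each timestep's latent once per prompt layer, processes the flat batch, and regroups it by layer. A small sketch of that batching round trip, assuming einops is installed (as the rearrange call in the hunk implies); sizes are illustrative:

```python
import torch
from einops import rearrange

# Illustrative sizes: T denoising timesteps, p prompt layers.
T, p, c, h, w = 2, 3, 4, 8, 8
x_t_latent = torch.randn(T, c, h, w)

# Duplicate each timestep's latent once per layer: index order is t-major,
# so the flat batch reads [t0 x p copies, t1 x p copies, ...].
x_t_latent = x_t_latent.repeat_interleave(p, dim=0)  # (T * p, c, h, w)

# ... per-layer conditioning would act on the flat batch here ...

# Undo the interleave: '(t p)' matches the t-major order produced above,
# so x_t_latent[l, t] is timestep t's copy for layer l.
x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
print(x_t_latent.shape)  # torch.Size([3, 2, 4, 8, 8])
```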