Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
@@ -78,7 +78,7 @@ class StreamMultiDiffusion(nn.Module):
|
|
78 |
self.default_mask_strength = default_mask_strength
|
79 |
self.default_prompt_strength = default_prompt_strength
|
80 |
self.register_buffer('bootstrap_steps', (
|
81 |
-
bootstrap_steps > torch.arange(len(t_index_list))).to(dtype=self.dtype, device=self.device))
|
82 |
self.bootstrap_mix_steps = bootstrap_mix_steps
|
83 |
self.register_buffer('bootstrap_mix_ratios', (
|
84 |
bootstrap_mix_steps - torch.arange(len(t_index_list), device=self.device)).clip_(0, 1).to(self.dtype))
|
@@ -1091,8 +1091,6 @@ class StreamMultiDiffusion(nn.Module):
|
|
1091 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1092 |
p = self.num_layers
|
1093 |
x_t_latent = x_t_latent.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
|
1094 |
-
print('111111111111111111111')
|
1095 |
-
|
1096 |
if self.bootstrap_steps[0] > 0:
|
1097 |
# Background bootstrapping.
|
1098 |
bootstrap_latent = self.scheduler.add_noise(
|
@@ -1100,7 +1098,6 @@ class StreamMultiDiffusion(nn.Module):
|
|
1100 |
self.stock_noise,
|
1101 |
torch.tensor(self.sub_timesteps_tensor, device=self.device),
|
1102 |
)
|
1103 |
-
print('111111111111111111111', self.bootstrap_steps)
|
1104 |
|
1105 |
x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
|
1106 |
bootstrap_mask = (
|
@@ -1109,11 +1106,9 @@ class StreamMultiDiffusion(nn.Module):
|
|
1109 |
) # (p, t, c, h, w)
|
1110 |
x_t_latent = (1.0 - bootstrap_mask) * bootstrap_latent[None] + bootstrap_mask * x_t_latent
|
1111 |
x_t_latent = rearrange(x_t_latent, 'p t c h w -> (t p) c h w')
|
1112 |
-
print('222222222222222222222')
|
1113 |
|
1114 |
# Centering.
|
1115 |
x_t_latent = shift_to_mask_bbox_center(x_t_latent, rearrange(self.masks, 'p t c h w -> (t p) c h w'), reverse=True)
|
1116 |
-
print('333333333333333333333')
|
1117 |
|
1118 |
t_list = self.sub_timesteps_tensor_ # (T * p,)
|
1119 |
if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
|
|
|
78 |
self.default_mask_strength = default_mask_strength
|
79 |
self.default_prompt_strength = default_prompt_strength
|
80 |
self.register_buffer('bootstrap_steps', (
|
81 |
+
bootstrap_steps > torch.arange(len(t_index_list))).float().to(dtype=self.dtype, device=self.device))
|
82 |
self.bootstrap_mix_steps = bootstrap_mix_steps
|
83 |
self.register_buffer('bootstrap_mix_ratios', (
|
84 |
bootstrap_mix_steps - torch.arange(len(t_index_list), device=self.device)).clip_(0, 1).to(self.dtype))
|
|
|
1091 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1092 |
p = self.num_layers
|
1093 |
x_t_latent = x_t_latent.repeat_interleave(p, dim=0) # (T * p, 4, h, w)
|
|
|
|
|
1094 |
if self.bootstrap_steps[0] > 0:
|
1095 |
# Background bootstrapping.
|
1096 |
bootstrap_latent = self.scheduler.add_noise(
|
|
|
1098 |
self.stock_noise,
|
1099 |
torch.tensor(self.sub_timesteps_tensor, device=self.device),
|
1100 |
)
|
|
|
1101 |
|
1102 |
x_t_latent = rearrange(x_t_latent, '(t p) c h w -> p t c h w', p=p)
|
1103 |
bootstrap_mask = (
|
|
|
1106 |
) # (p, t, c, h, w)
|
1107 |
x_t_latent = (1.0 - bootstrap_mask) * bootstrap_latent[None] + bootstrap_mask * x_t_latent
|
1108 |
x_t_latent = rearrange(x_t_latent, 'p t c h w -> (t p) c h w')
|
|
|
1109 |
|
1110 |
# Centering.
|
1111 |
x_t_latent = shift_to_mask_bbox_center(x_t_latent, rearrange(self.masks, 'p t c h w -> (t p) c h w'), reverse=True)
|
|
|
1112 |
|
1113 |
t_list = self.sub_timesteps_tensor_ # (T * p,)
|
1114 |
if self.guidance_scale > 1.0 and self.cfg_type == 'initialize':
|