Edit attention map selection based on num_edit_tokens

#15
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -928,11 +928,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
928
  from_where=["up", "down"],
929
  is_cross=True,
930
  select=text_cross_attention_maps.index(editing_prompt[c]),
931
- )
932
- attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
933
 
934
  # average over all tokens
935
- assert (attn_map.shape[3] == num_edit_tokens[c])
936
  attn_map = torch.sum(attn_map, dim=3)
937
 
938
  # gaussian_smoothing
 
928
  from_where=["up", "down"],
929
  is_cross=True,
930
  select=text_cross_attention_maps.index(editing_prompt[c]),
931
+ )
932
+ attn_map = out[:, :, :, 1:1 + num_edit_tokens[self.batch_size*c]] # 0 -> startoftext
933
 
934
  # average over all tokens
935
+ assert (attn_map.shape[3] == num_edit_tokens[self.batch_size*c])
936
  attn_map = torch.sum(attn_map, dim=3)
937
 
938
  # gaussian_smoothing