import torch
from diffusers.models.attention import CrossAttention


class MyCrossAttnProcessor:
    """Attention processor that mirrors the default cross-attention forward pass
    but additionally stores the attention probabilities on the module as
    `attn.attn_probs`, so they can be read out after a forward pass."""

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        # fall back to self-attention when no encoder states are given
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # new bookkeeping to save the attn probs
        attn.attn_probs = attention_probs

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


def prep_unet(unet):
    """
    Prepare a U-Net model for training by enabling gradient computation only for
    the cross-attention parameters and routing every cross-attention forward pass
    through the custom MyCrossAttnProcessor defined above.

    Parameters:
        unet: A U-Net model.

    Returns:
        unet: The prepared U-Net model.
    """
    # enable gradients only for the cross-attention (attn2) parameters
    for name, params in unet.named_parameters():
        if 'attn2' in name:
            params.requires_grad = True
        else:
            params.requires_grad = False

    # replace the forward pass: every CrossAttention module now uses the
    # custom processor, which records its attention maps on each call
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention":
            module.set_processor(MyCrossAttnProcessor())

    return unet
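

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original code). It assumes a diffusers
# version in which CrossAttention and set_processor exist, i.e. the same
# versions the import above targets. The model id, prompt, and step count are
# illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import StableDiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
    pipe.unet = prep_unet(pipe.unet)

    # Run a short generation; every CrossAttention module now stores its
    # attention probabilities in .attn_probs after the forward pass.
    _ = pipe("a photo of a cat", num_inference_steps=5)

    # Collect the saved maps (modules named 'attn2' attend to the text tokens).
    attn_maps = {
        name: module.attn_probs
        for name, module in pipe.unet.named_modules()
        if type(module).__name__ == "CrossAttention" and hasattr(module, "attn_probs")
    }
    for name, probs in attn_maps.items():
        print(name, tuple(probs.shape))  # (batch * heads, query_len, key_len)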