import torch
from diffusers.models.attention_processor import LoRAAttnProcessor


def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
    """
    Add tokens to the tokenizer and set the initial value of token embeddings
    """
    tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data
    placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
    if initializer_token:
        # Initialize each placeholder vector from the (possibly multi-token) initializer embedding,
        # spreading the initializer tokens proportionally across the placeholder vectors.
        token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
        for i, placeholder_token_id in enumerate(placeholder_token_ids):
            token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
    else:
        # No initializer given: start each placeholder embedding from random noise.
        for i, placeholder_token_id in enumerate(placeholder_token_ids):
            token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])
    return placeholder_token_ids


def tokenize_prompt(tokenizer, prompt, replace_token=False):
    # `replace_token` is a kwarg understood by the custom tokenizer this example relies on;
    # it is simply forwarded here together with the usual padding/truncation settings.
    text_inputs = tokenizer(
        prompt,
        replace_token=replace_token,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids
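

# Usage sketch (not executed here): `add_tokens` and `tokenize_prompt` assume a tokenizer
# subclass that implements `add_placeholder_tokens` and accepts the `replace_token` kwarg;
# a stock `CLIPTokenizer` does not. `MultiTokenTokenizer`, `model_id`, and the placeholder
# string below are illustrative stand-ins, not names defined in this file.
#
#     tokenizer = MultiTokenTokenizer.from_pretrained(model_id, subfolder="tokenizer")
#     text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
#     add_tokens(tokenizer, text_encoder, "<cat-toy>", num_vec_per_token=4, initializer_token="toy")
#     input_ids = tokenize_prompt(tokenizer, "a photo of <cat-toy>", replace_token=True)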


def get_processor(self, return_deprecated_lora: bool = False):
    r"""
    Get the attention processor in use.

    Args:
        return_deprecated_lora (`bool`, *optional*, defaults to `False`):
            Set to `True` to return the deprecated LoRA attention processor.

    Returns:
        "AttentionProcessor": The attention processor in use.
    """
    if not return_deprecated_lora:
        return self.processor

    # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
    # serialization format for LoRA Attention Processors. It should be deleted once the integration
    # with PEFT is completed.
    is_lora_activated = {
        name: module.lora_layer is not None
        for name, module in self.named_modules()
        if hasattr(module, "lora_layer")
    }

    # 1. if no layer has a LoRA activated we can return the processor as usual
    if not any(is_lora_activated.values()):
        return self.processor

    # `add_k_proj` and `add_v_proj` may not have LoRA applied, so leave them out of the check
    is_lora_activated.pop("add_k_proj", None)
    is_lora_activated.pop("add_v_proj", None)
    # 2. else it is not possible that only some layers have LoRA activated
    if not all(is_lora_activated.values()):
        raise ValueError(
            f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
        )

    # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
    # non_lora_processor_cls_name = self.processor.__class__.__name__
    # lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)

    hidden_size = self.inner_dim

    # now create a LoRA attention processor from the LoRA layers
    kwargs = {
        "cross_attention_dim": self.cross_attention_dim,
        "rank": self.to_q.lora_layer.rank,
        "network_alpha": self.to_q.lora_layer.network_alpha,
        "q_rank": self.to_q.lora_layer.rank,
        "q_hidden_size": self.to_q.lora_layer.out_features,
        "k_rank": self.to_k.lora_layer.rank,
        "k_hidden_size": self.to_k.lora_layer.out_features,
        "v_rank": self.to_v.lora_layer.rank,
        "v_hidden_size": self.to_v.lora_layer.out_features,
        "out_rank": self.to_out[0].lora_layer.rank,
        "out_hidden_size": self.to_out[0].lora_layer.out_features,
    }

    if hasattr(self.processor, "attention_op"):
        kwargs["attention_op"] = self.processor.attention_op

    lora_processor = LoRAAttnProcessor(hidden_size, **kwargs)
    lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
    lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
    lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
    lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())

    return lora_processor


def get_attn_processors(self):
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model,
        indexed by their weight names.
    """
    # collect the processors recursively
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
        if hasattr(module, "get_processor"):
            # Call the module-level `get_processor` helper defined above (not the attention
            # module's own method) so deprecated LoRA processors are reconstructed as well.
            processors[f"{name}.processor"] = get_processor(module, return_deprecated_lora=True)

        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors
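

# Usage sketch (not executed here): these helpers take the module as an explicit first
# argument, so they can be applied to a loaded UNet without patching its class. `model_id`
# below is an illustrative checkpoint identifier.
#
#     from diffusers import UNet2DConditionModel
#
#     unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
#     processors = get_attn_processors(unet)
#     # e.g. {"down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": <processor>, ...}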