File size: 6,220 Bytes
43b7e92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from ..utils import logging


if TYPE_CHECKING:
    # Imported only for static type checking so that no circular import happens at runtime;
    # annotations in this module refer to the class by its string name.
    from ..models import UNet2DConditionModel

# Module-level logger named after this module.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _translate_into_actual_layer_name(name):
    """
    Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0').

    Parameters:
        name (`str`):
            Either `"mid"` or a name made of three dot-separated parts, e.g. `"down.block_1.0"` or
            `"up.block_0.2"`.

    Returns:
        `str`: The translated name, e.g. `"down.block_1.0"` becomes `"down_blocks.1.attentions.0"`.

    Raises:
        ValueError: If `name` is not `"mid"` and does not consist of exactly three dot-separated parts.
    """
    if name == "mid":
        return "mid_block.attentions.0"

    parts = name.split(".")
    if len(parts) != 3:
        # Fail with a message that names the offending key instead of a bare tuple-unpacking error.
        raise ValueError(
            f"Invalid layer name '{name}'. Expected 'mid' or a name of the form "
            "'<down|up>.block_<i>.<j>' (e.g. 'down.block_1.0')."
        )
    updown, block, attn = parts

    # "down" -> "down_blocks", "up" -> "up_blocks"
    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
    # "block_1" -> "1"
    block = block.replace("block_", "")
    attn = "attentions." + attn

    return ".".join((updown, block, attn))


def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
    """
    Expand the LoRA scales of every adapter into per-attention-layer scales for `unet`.

    Parameters:
        unet (`UNet2DConditionModel`):
            Model whose `down_blocks`, `up_blocks`, `config.layers_per_block` and `state_dict()` are read.
        weight_scales (`List[Union[float, Dict]]`):
            One entry per adapter; each entry is expanded with `_maybe_expand_lora_scales_for_one_adapter`.
        default_scale (`float`, defaults to 1.0):
            Scale forwarded to the per-adapter expansion for parts that are not given explicitly.

    Returns:
        `list`: One expanded entry per element of `weight_scales`, in the same order.
    """
    # Indices of the down/up blocks that expose an `attentions` attribute.
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
    # Up blocks are given one more layer than down blocks; this arithmetic assumes
    # `unet.config.layers_per_block` is an int.
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

    # The state dict does not depend on the adapter, so build it once instead of once per adapter.
    state_dict = unet.state_dict()

    return [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            state_dict,
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]


def _maybe_expand_lora_scales_for_one_adapter(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
    state_dict: Dict,
    default_scale: float = 1.0,
):
    """
    Expands the inputs into a more granular dictionary. See the example below for more details.

    Parameters:
        scales (`Union[float, Dict]`):
            Scales dict to expand. Anything that is not a dict is returned unchanged.
        blocks_with_transformer (`Dict[str, List[int]]`):
            Dict with keys 'up' and 'down', listing the indices of the blocks that have transformer layers
        transformer_per_block (`Dict[str, int]`):
            Dict with keys 'up' and 'down', showing how many transformer layers each block has
        state_dict (`Dict`):
            Dict whose keys are checked to make sure every expanded layer exists. Only the keys are read.
        default_scale (`float`, defaults to 1.0):
            Scale used for 'mid', 'up', 'down' or individual blocks that are missing from `scales`.

    Returns:
        `scales` itself if it is not a dict, otherwise a new flat dict that maps actual layer names to their
        scale. The `scales` argument is deep-copied first and never modified.

    Raises:
        ValueError: If `blocks_with_transformer` or `transformer_per_block` do not have exactly the keys 'down'
            and 'up', if a list of scales has the wrong length, or if an expanded layer name is not contained in
            any key of `state_dict`.

    E.g. turns
    ```python
    scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
    transformer_per_block = {"down": 2, "up": 3}
    ```
    into
    ```python
    {
        "mid_block.attentions.0": 3,
        "up_blocks.0.attentions.0": 4,
        "up_blocks.0.attentions.1": 4,
        "up_blocks.0.attentions.2": 4,
        "up_blocks.1.attentions.0": 5,
        "up_blocks.1.attentions.1": 6,
        "up_blocks.1.attentions.2": 7,
        "down_blocks.1.attentions.0": 2,
        "down_blocks.1.attentions.1": 2,
        "down_blocks.2.attentions.0": 2,
        "down_blocks.2.attentions.1": 2,
    }
    ```
    """
    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")

    if sorted(transformer_per_block.keys()) != ["down", "up"]:
        raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")

    if not isinstance(scales, dict):
        # don't expand if scales is a single number
        return scales

    # work on a private copy so the caller's dict is never mutated
    scales = copy.deepcopy(scales)

    if "mid" not in scales:
        scales["mid"] = default_scale
    elif isinstance(scales["mid"], list):
        # a one-element list is accepted for 'mid' and unwrapped to its only value
        if len(scales["mid"]) == 1:
            scales["mid"] = scales["mid"][0]
        else:
            raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")

    for updown in ["up", "down"]:
        block_indices = blocks_with_transformer[updown]
        num_layers = transformer_per_block[updown]

        if updown not in scales:
            scales[updown] = default_scale

        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in block_indices}

        per_block = scales[updown]

        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
        for i in block_indices:
            block = f"block_{i}"
            # set not assigned blocks to default scale
            if block not in per_block:
                per_block[block] = default_scale
            if not isinstance(per_block[block], list):
                per_block[block] = [per_block[block] for _ in range(num_layers)]
            elif len(per_block[block]) == 1:
                # a one-element list is repeated for every transformer layer of the block
                per_block[block] = per_block[block] * num_layers
            elif len(per_block[block]) != num_layers:
                raise ValueError(f"Expected {num_layers} scales for {updown}.{block}, got {len(per_block[block])}.")

        # eg {"down": "block_1": [1, 1]}}  to {"down.block_1.0": 1, "down.block_1.1": 1}
        for i in block_indices:
            block = f"block_{i}"
            for tf_idx, value in enumerate(per_block[block]):
                scales[f"{updown}.{block}.{tf_idx}"] = value

        # the nested entry is now fully flattened; entries for blocks that are not listed in
        # `blocks_with_transformer` are dropped together with it
        del scales[updown]

    # Translate every user-facing name once and, in the same pass, check that it matches
    # (as a substring) at least one key of the state dict.
    expanded_scales = {}
    for layer, weight in scales.items():
        actual_name = _translate_into_actual_layer_name(layer)
        if not any(actual_name in module for module in state_dict.keys()):
            raise ValueError(
                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
            )
        expanded_scales[actual_name] = weight

    return expanded_scales