File size: 6,220 Bytes
43b7e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict, List, Union
from ..utils import logging
if TYPE_CHECKING:
# import here to avoid circular imports
from ..models import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _translate_into_actual_layer_name(name):
"""Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
if name == "mid":
return "mid_block.attentions.0"
updown, block, attn = name.split(".")
updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
block = block.replace("block_", "")
attn = "attentions." + attn
return ".".join((updown, block, attn))
def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
    """
    Expands the LoRA scales of every adapter by calling `_maybe_expand_lora_scales_for_one_adapter` on each entry.

    Parameters:
        unet (`UNet2DConditionModel`):
            Model whose `down_blocks`, `up_blocks`, `config.layers_per_block` and `state_dict()` describe the
            available layers.
        weight_scales (`List[Union[float, Dict]]`):
            One scale specification per adapter.
        default_scale (`float`, defaults to 1.0):
            Forwarded to `_maybe_expand_lora_scales_for_one_adapter` as `default_scale`.

    Returns:
        `list`: one expanded entry per element of `weight_scales`, in the same order.
    """
    # Only blocks exposing an `attentions` attribute are considered to have transformer layers.
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

    # The state dict does not depend on the adapter, so build it once instead of once per adapter.
    state_dict = unet.state_dict()

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            state_dict,
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]

    return expanded_weight_scales
def _maybe_expand_lora_scales_for_one_adapter(
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
state_dict: None,
default_scale: float = 1.0,
):
"""
Expands the inputs into a more granular dictionary. See the example below for more details.
Parameters:
scales (`Union[float, Dict]`):
Scales dict to expand.
blocks_with_transformer (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing which blocks have transformer layers
transformer_per_block (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing how many transformer layers each block has
E.g. turns
```python
scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
transformer_per_block = {"down": 2, "up": 3}
```
into
```python
{
"down.block_1.0": 2,
"down.block_1.1": 2,
"down.block_2.0": 2,
"down.block_2.1": 2,
"mid": 3,
"up.block_0.0": 4,
"up.block_0.1": 4,
"up.block_0.2": 4,
"up.block_1.0": 5,
"up.block_1.1": 6,
"up.block_1.2": 7,
}
```
"""
if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")
if sorted(transformer_per_block.keys()) != ["down", "up"]:
raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")
if not isinstance(scales, dict):
# don't expand if scales is a single number
return scales
scales = copy.deepcopy(scales)
if "mid" not in scales:
scales["mid"] = default_scale
elif isinstance(scales["mid"], list):
if len(scales["mid"]) == 1:
scales["mid"] = scales["mid"][0]
else:
raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")
for updown in ["up", "down"]:
if updown not in scales:
scales[updown] = default_scale
# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
if not isinstance(scales[updown], dict):
scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}
# eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
# set not assigned blocks to default scale
if block not in scales[updown]:
scales[updown][block] = default_scale
if not isinstance(scales[updown][block], list):
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
elif len(scales[updown][block]) == 1:
# a list specifying scale to each masked IP input
scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
elif len(scales[updown][block]) != transformer_per_block[updown]:
raise ValueError(
f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
)
# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
for tf_idx, value in enumerate(scales[updown][block]):
scales[f"{updown}.{block}.{tf_idx}"] = value
del scales[updown]
for layer in scales.keys():
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
)
return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
|