Train parameters exclusively in specific ranges (#1390)
* Train parameters exclusively in specific ranges
* Fix the style and update docs
* Update yaml example

Files changed:

- examples/mistral/mixtral.yml (+6 -6)
- src/axolotl/train.py (+2 -2)
- src/axolotl/utils/freeze.py (+199 -11)
- tests/test_freeze.py (+285 -0)
examples/mistral/mixtral.yml (CHANGED)

```diff
@@ -16,12 +16,12 @@ output_dir: ./qlora-out
 
 ## You can optionally freeze the entire model and unfreeze a subset of parameters
 unfrozen_parameters:
-# - lm_head
-# - model.embed_tokens
-# - model.layers.2[0-9]+.block_sparse_moe.gate
-# - model.layers.2[0-9]+.block_sparse_moe.experts
-# - model.layers.3[0-9]+.block_sparse_moe.gate
-# - model.layers.3[0-9]+.block_sparse_moe.experts
+# - ^lm_head.weight$
+# - ^model.embed_tokens.weight$[:32000]
+# - model.layers.2[0-9]+.block_sparse_moe.gate
+# - model.layers.2[0-9]+.block_sparse_moe.experts
+# - model.layers.3[0-9]+.block_sparse_moe.gate
+# - model.layers.3[0-9]+.block_sparse_moe.experts
 
 model_config:
   output_router_logits: true
```
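Compared with the old entries, the new ones anchor the full parameter name (`^...$`) and may append an index range: `^model.embed_tokens.weight$[:32000]` keeps only the first 32000 rows of the embedding matrix trainable, and gradients for the remaining rows are zeroed by a backward hook. A minimal sketch of that behavior on a toy module (the module and names below are illustrative, not part of this PR; it assumes axolotl is installed):

```python
import torch
from torch import nn

from axolotl.utils.freeze import freeze_layers_except


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)


model = ToyModel()
# Keep only rows 0..3 of the embedding trainable, analogous to
# "^model.embed_tokens.weight$[:32000]" in the mixtral example above.
# (Here the parameter is named "embed_tokens.weight", without a "model." prefix.)
freeze_layers_except(model, ["^embed_tokens.weight$[:4]"])

loss = model.embed_tokens(torch.arange(8)).sum()
loss.backward()
# Rows 4..7 of the gradient were zeroed by the registered backward hook.
print(model.embed_tokens.weight.grad)
```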
src/axolotl/train.py (CHANGED)

```diff
@@ -19,7 +19,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.freeze import freeze_parameters_except
+from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
@@ -99,7 +99,7 @@ def train(
     safe_serialization = cfg.save_safetensors is True
 
     if cfg.unfrozen_parameters:
-        freeze_parameters_except(model, cfg.unfrozen_parameters)
+        freeze_layers_except(model, cfg.unfrozen_parameters)
 
     trainer = setup_trainer(
         cfg,
```
src/axolotl/utils/freeze.py (CHANGED)

The former `freeze_parameters_except(model, regex_patterns)` escaped and compiled each pattern with `re.compile(pattern.replace(".", "\\."))`, froze every parameter up front (`param.requires_grad = False` over `model.parameters()`), and then re-enabled `requires_grad` on any parameter whose name matched. It is renamed to `freeze_layers_except` and rewritten to support an optional index-range suffix per pattern, shown below in pieces: the entry point, the two range helpers, the gradient hook, and the pattern class.

```python
"""
module to freeze/unfreeze parameters by name
"""
import logging
import re
from typing import Callable, List, Tuple

from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger("axolotl.utils.freeze")


def freeze_layers_except(model, regex_patterns):
    """
    Freezes all layers of the given model except for the layers that match given regex patterns.
    Periods in the patterns are treated as literal periods, not as wildcard characters.

    Parameters:
    - model (nn.Module): The PyTorch model to be modified.
    - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
      Note that you cannot use a dot as a wildcard character in the patterns since it is reserved for separating layer names.
      Also, to match the entire layer name, the pattern should start with "^" and end with "$"; otherwise it will match any part of the layer name.
      The range pattern part is optional and is not compiled as a regex pattern, which means you must put "$" before the range pattern if you want to match the entire layer name.
      E.g., ["^model.embed_tokens.weight$[:32000]", "layers.2[0-9]+.block_sparse_moe.gate.[a-z]+$"]

    Returns:
    None; the model is modified in place.
    """
    if isinstance(regex_patterns, str):
        regex_patterns = [regex_patterns]

    patterns = [LayerNamePattern(pattern) for pattern in regex_patterns]

    # Unfreeze layers that match the regex patterns
    for name, param in model.named_parameters():
        param.requires_grad = False
        unfrozen_ranges = []
        for pattern in patterns:
            if not pattern.match(name):
                continue

            param.requires_grad = True

            if pattern.range is not None:
                unfrozen_ranges.append(pattern.range)

        merged_unfrozen_ranges = _merge_ranges(unfrozen_ranges, len(param))

        if param.requires_grad and is_main_process():
            unfrozen_ranges = (
                f" with ranges {merged_unfrozen_ranges}"
                if merged_unfrozen_ranges
                else ""
            )
            LOG.debug(f"Unfrozen {name}{unfrozen_ranges}")

        if not merged_unfrozen_ranges:
            continue

        # The range list we need is actually the inverse of the merged ranges
        ranges_to_freeze = _invert_ranges(merged_unfrozen_ranges, len(param))

        param.register_hook(_create_freeze_parameters_hook(ranges_to_freeze))

    if is_main_process() and all(
        not param.requires_grad for param in model.parameters()
    ):
        LOG.warning("All parameters are frozen. Model will not be trained.")
```
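`requires_grad` is all-or-nothing per tensor, so a sub-tensor range cannot be frozen with it directly; instead the whole parameter stays trainable and the gradient slices outside the unfrozen ranges are zeroed in a backward hook. A plain-PyTorch sketch of that idea in isolation (names are illustrative; the actual hook in this diff mutates the gradient in place, which behaves equivalently):

```python
import torch

weight = torch.randn(6, 3, requires_grad=True)


def zero_frozen_rows(grad):
    # Return a copy of the gradient with rows 2..5 zeroed; rows 0..1 stay trainable.
    grad = grad.clone()
    grad[2:] = 0
    return grad


weight.register_hook(zero_frozen_rows)
weight.sum().backward()
print(weight.grad)  # rows 0-1 are all ones, rows 2-5 are all zeros
```

The two range helpers below decide which slices those are.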
```python
def _invert_ranges(
    given_ranges: List[Tuple[int, int]], layer_size: int
) -> List[Tuple[int, int]]:
    """
    Inverts a list of ranges to obtain the ranges not covered by the given ranges.

    Parameters:
    - given_ranges (List[Tuple[int, int]]): List of ranges to invert. Each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
    - layer_size (int): The length of the layer. E.g., len(model.layer.weight)
    Returns:
    - List[Tuple[int, int]]: List of inverted ranges, where each range is represented as a tuple of start (inclusive) and end (exclusive) indices.
    """
    if not given_ranges:
        return [(0, layer_size)]

    inverted_ranges = []
    current_start = 0

    for start, end in sorted(given_ranges):
        if start > current_start:
            inverted_ranges.append((current_start, start))
        current_start = max(current_start, end)

    # Handle the case where the last given range does not reach the end of the layer
    if current_start < layer_size:
        inverted_ranges.append((current_start, layer_size))

    return inverted_ranges
```
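A worked example with made-up values (the call site passes ranges that `_merge_ranges` has already merged and clipped):

```python
# Unfrozen ranges (2, 4) and (6, 8) in a layer of size 10 leave three gaps,
# which are exactly the slices the gradient hook must zero.
assert _invert_ranges([(2, 4), (6, 8)], layer_size=10) == [(0, 2), (4, 6), (8, 10)]
```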
```python
def _merge_ranges(
    given_ranges: List[Tuple[int, int | None]], layer_size: int
) -> List[Tuple[int, int]]:
    """
    Merges overlapping ranges and sorts the given ranges.

    This function takes a list of ranges and merges any overlapping ranges. The ranges are represented
    as tuples, where the first element is the start index (inclusive) and the second element is the end
    index (exclusive). The end index can be None, indicating that the range extends to the end of the
    sequence.

    Parameters:
    - given_ranges (List[Tuple[int, int | None]]): List of ranges to merge.
    - layer_size (int): The length of the layer. E.g., len(model.layer.weight)

    Returns:
    - List[Tuple[int, int]]: List of merged ranges, as start (inclusive) and end (exclusive) indices.
    """
    # End of each range can be determined now since we have the total size
    processed_ranges = [
        (start, end if end is not None else layer_size) for start, end in given_ranges
    ]

    # No need to merge if there's only one or no ranges
    if len(processed_ranges) <= 1:
        return processed_ranges

    sorted_ranges = sorted(processed_ranges)

    merged_ranges = [sorted_ranges[0]]
    for start, end in sorted_ranges[1:]:
        prev_start, prev_end = merged_ranges[-1]
        if start <= prev_end:
            merged_ranges[-1] = (prev_start, max(prev_end, end))
        else:
            merged_ranges.append((start, end))

    return merged_ranges
```
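For illustration, with hypothetical inputs: an open-ended range is first clipped to the layer size, then overlapping ranges collapse:

```python
# (4, None) becomes (4, 10); (5, 6) is swallowed by it; (0, 2) stays separate.
assert _merge_ranges([(4, None), (5, 6), (0, 2)], layer_size=10) == [(0, 2), (4, 10)]
```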
````python
def _create_freeze_parameters_hook(ranges_to_freeze: List[Tuple[int, int]]) -> Callable:
    """
    Create a hook to freeze parameters in specified ranges by setting their gradients to zero.

    This function takes a list of tuples representing the ranges of indices to freeze. Each tuple should contain
    two integers representing the start and end indices of the range.

    Parameters:
    - ranges_to_freeze (List[Tuple[int, int]]): Ranges of indices to freeze.

    Returns:
    - Callable: A hook function to be used with `register_hook` on parameters.

    Example usage:
    ```
    ranges_to_freeze = [(0, 10), (20, 30)]
    hook = _create_freeze_parameters_hook(ranges_to_freeze)
    param.register_hook(hook)
    ```
    """

    def freeze_parameters_hook(gradients):
        for start, end in ranges_to_freeze:
            gradients[start:end].zero_()

    return freeze_parameters_hook
````
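The hook can be sanity-checked on a bare tensor; this snippet is illustrative, not part of the diff:

```python
import torch

param = torch.ones(10, requires_grad=True)
param.register_hook(_create_freeze_parameters_hook([(0, 3), (7, 10)]))
param.sum().backward()
print(param.grad)  # tensor([0., 0., 0., 1., 1., 1., 1., 0., 0., 0.])
```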
```python
class LayerNamePattern:
    """
    Represents a regex pattern for layer names, potentially including a parameter index range.
    """

    def __init__(self, pattern: str):
        """
        Initializes a new instance of the LayerNamePattern class.

        Parameters:
        - pattern (str): The regex pattern for layer names, potentially including a parameter index range.
        """
        self.raw_pattern = pattern
        name_pattern, self.range = self._parse_pattern(pattern)
        self.name_regex = re.compile(name_pattern.replace(".", "\\."))

    def match(self, name: str) -> bool:
        """
        Checks if the given layer name matches the regex pattern.

        Parameters:
        - name (str): The layer name to check.

        Returns:
        - bool: True if the layer name matches the pattern, False otherwise.
        """
        return self.name_regex.match(name) is not None

    def _parse_pattern(self, pattern: str) -> Tuple[str, Tuple[int, int | None] | None]:
        """
        Extracts the range pattern from the given pattern.

        Parameters:
        - pattern (str): The pattern to extract the range from.

        Returns:
        - Tuple[str, Tuple[int, int | None] | None]: A tuple containing the regex pattern to match the layer name without the range pattern and the range of layer indices to match, if specified.
        """
        match = re.match(r"^(.+)\[([0-9]*)(?::([0-9]*))?\]$", pattern)
        if not match:
            return pattern, None

        base_pattern, start_part, end_part = match.groups()

        if end_part is None and start_part.isdecimal():
            index = int(start_part)
            return base_pattern, (index, index + 1)

        # [:end] or [start:] or [start:end]
        start = int(start_part) if start_part else 0
        end = int(end_part) if end_part else None

        if end is not None and start >= end:
            raise ValueError(
                f"Invalid range in layer name pattern: {pattern}. "
                "End of range must be greater than start."
            )
        return base_pattern, (start, end)
```
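A few parses for illustration, using patterns that appear in the docs and tests above:

```python
LayerNamePattern("^lm_head.weight$").range                     # None (no range suffix)
LayerNamePattern("^model.embed_tokens.weight$[:32000]").range  # (0, 32000)
LayerNamePattern("features.layer[5]").range                    # (5, 6): a single index
LayerNamePattern("features.layer[4:]").range                   # (4, None): open-ended
```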
tests/test_freeze.py (ADDED)

```python
"""
This module contains unit tests for the `freeze_layers_except` function.

The `freeze_layers_except` function is used to freeze layers in a model, except for the specified layers.
The unit tests in this module verify the behavior of the `freeze_layers_except` function in different scenarios.
"""

import unittest

import torch
from torch import nn

from axolotl.utils.freeze import freeze_layers_except

ZERO = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
ONE_TO_TEN = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]


class TestFreezeLayersExcept(unittest.TestCase):
    """
    A test case class for the `freeze_layers_except` function.
    """

    def setUp(self):
        self.model = _TestModel()

    def test_freeze_layers_with_dots_in_name(self):
        freeze_layers_except(self.model, ["features.layer"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

    def test_freeze_layers_without_dots_in_name(self):
        freeze_layers_except(self.model, ["classifier"])
        self.assertFalse(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be frozen.",
        )
        self.assertTrue(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be trainable.",
        )

    def test_freeze_layers_regex_patterns(self):
        # The second pattern cannot match: only the characters 'a' to 'c' may follow
        # the word 'class', but the layer name continues with 'i' ("classifier").
        freeze_layers_except(self.model, [r"^features.[a-z]+.weight$", r"class[a-c]+"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

    def test_all_layers_frozen(self):
        freeze_layers_except(self.model, [])
        self.assertFalse(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be frozen.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

    def test_all_layers_unfrozen(self):
        freeze_layers_except(self.model, ["features.layer", "classifier"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertTrue(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be trainable.",
        )

    def test_freeze_layers_with_range_pattern_start_end(self):
        freeze_layers_except(self.model, ["features.layer[1:5]"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

        self._assert_gradient_output(
            [
                ZERO,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ZERO,
                ZERO,
                ZERO,
                ZERO,
                ZERO,
            ]
        )

    def test_freeze_layers_with_range_pattern_single_index(self):
        freeze_layers_except(self.model, ["features.layer[5]"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

        self._assert_gradient_output(
            [ZERO, ZERO, ZERO, ZERO, ZERO, ONE_TO_TEN, ZERO, ZERO, ZERO, ZERO]
        )

    def test_freeze_layers_with_range_pattern_start_omitted(self):
        freeze_layers_except(self.model, ["features.layer[:5]"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

        self._assert_gradient_output(
            [
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ZERO,
                ZERO,
                ZERO,
                ZERO,
                ZERO,
            ]
        )

    def test_freeze_layers_with_range_pattern_end_omitted(self):
        freeze_layers_except(self.model, ["features.layer[4:]"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

        self._assert_gradient_output(
            [
                ZERO,
                ZERO,
                ZERO,
                ZERO,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
            ]
        )

    def test_freeze_layers_with_range_pattern_merge_included(self):
        freeze_layers_except(self.model, ["features.layer[4:]", "features.layer[5:6]"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

        self._assert_gradient_output(
            [
                ZERO,
                ZERO,
                ZERO,
                ZERO,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
            ]
        )

    def test_freeze_layers_with_range_pattern_merge_intersect(self):
        freeze_layers_except(self.model, ["features.layer[4:7]", "features.layer[6:8]"])
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

        self._assert_gradient_output(
            [
                ZERO,
                ZERO,
                ZERO,
                ZERO,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ONE_TO_TEN,
                ZERO,
                ZERO,
            ]
        )

    def test_freeze_layers_with_range_pattern_merge_separate(self):
        freeze_layers_except(
            self.model,
            ["features.layer[1:2]", "features.layer[3:4]", "features.layer[5:6]"],
        )
        self.assertTrue(
            self.model.features.layer.weight.requires_grad,
            "model.features.layer should be trainable.",
        )
        self.assertFalse(
            self.model.classifier.weight.requires_grad,
            "model.classifier should be frozen.",
        )

        self._assert_gradient_output(
            [
                ZERO,
                ONE_TO_TEN,
                ZERO,
                ONE_TO_TEN,
                ZERO,
                ONE_TO_TEN,
                ZERO,
                ZERO,
                ZERO,
                ZERO,
            ]
        )

    def _assert_gradient_output(self, expected):
        input_tensor = torch.tensor([ONE_TO_TEN], dtype=torch.float32)

        self.model.features.layer.weight.grad = None  # Reset gradients
        output = self.model.features.layer(input_tensor)
        loss = output.sum()
        loss.backward()

        expected_grads = torch.tensor(expected)
        torch.testing.assert_close(
            self.model.features.layer.weight.grad, expected_grads
        )


class _SubLayerModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)


class _TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = _SubLayerModule()
        self.classifier = nn.Linear(10, 2)


if __name__ == "__main__":
    unittest.main()
```
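The new tests run standalone through the `unittest.main()` entry point (`python tests/test_freeze.py`) or, assuming pytest is available in the dev environment, via `python -m pytest tests/test_freeze.py`.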