"""
Utilities to freeze/unfreeze model parameters by name.
"""
import logging | |
import re | |
from axolotl.utils.distributed import is_main_process | |
# Named logger shared by the helpers in this module.
LOG = logging.getLogger(name="axolotl.utils.freeze")
def _escape_literal_periods(pattern):
    """Return ``pattern`` with every not-yet-escaped period backslash-escaped."""
    # A period is already escaped when it is preceded by an odd number of
    # backslashes. The lookbehind pins the start of the backslash run, the
    # group consumes an even number of backslashes (i.e. escaped backslashes),
    # and only a bare period following that run is rewritten to "\.".
    return re.sub(r"(?<!\\)((?:\\\\)*)\.", r"\1\\.", pattern)


def freeze_parameters_except(model, regex_patterns):
    """
    Freezes all parameters of the given model except those whose names match
    one of the given regex patterns.

    Periods in the patterns are treated as literal periods, not as wildcard
    characters; a period the caller already backslash-escaped is left as-is
    rather than being escaped a second time. All other regex syntax keeps its
    usual meaning.

    Patterns are applied with ``re.match``: they are anchored at the start of
    the parameter name but not at the end (append ``$`` for an exact match).

    Parameters:
    - model (nn.Module): The PyTorch model to be modified.
    - regex_patterns (list of str, or str): Pattern(s) selecting the parameter
      names to keep unfrozen. A single string is treated as a one-item list.

    Returns:
        None; ``requires_grad`` is updated in place on the model's parameters.
    """
    # Without this, a bare string would be iterated character by character.
    if isinstance(regex_patterns, str):
        regex_patterns = [regex_patterns]

    # Keep the caller's original text next to each compiled regex for logging.
    compiled_patterns = [
        (pattern, re.compile(_escape_literal_periods(pattern)))
        for pattern in regex_patterns
    ]

    # Only one rank needs to log; query it once rather than per parameter.
    log_on_this_rank = is_main_process()

    # Freeze everything first so the second pass only has to switch matching
    # parameters back on (a parameter needs to match just one pattern).
    for param in model.parameters():
        param.requires_grad = False

    matched_patterns = set()
    for name, param in model.named_parameters():
        hits = [pattern for pattern, regex in compiled_patterns if regex.match(name)]
        if not hits:
            continue
        matched_patterns.update(hits)
        if log_on_this_rank:
            LOG.debug("unfreezing %s", name)
        param.requires_grad = True

    # A typo in a pattern would otherwise silently leave those layers frozen.
    if log_on_this_rank:
        for pattern in dict.fromkeys(p for p, _ in compiled_patterns):
            if pattern not in matched_patterns:
                LOG.warning(
                    "unfrozen parameter pattern %r did not match any parameter",
                    pattern,
                )