"""
module for LISA

Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0
"""

import logging
from functools import reduce
from typing import TYPE_CHECKING

import numpy as np
from transformers import TrainerCallback

if TYPE_CHECKING:
    from axolotl.core.trainer_builder import AxolotlTrainer

LOG = logging.getLogger("axolotl.callbacks.lisa")


def lisa_callback_factory(trainer: "AxolotlTrainer"):
    """Build a LISA layer-switching callback bound to ``trainer``.

    The callback is configured from ``trainer.args.lisa_n_layers``,
    ``trainer.args.lisa_step_interval`` and
    ``trainer.args.lisa_layers_attribute``.

    Args:
        trainer: trainer whose ``model`` holds the layers to freeze/unfreeze.

    Returns:
        A ``LISACallback`` instance (a ``TrainerCallback`` subclass).

    Raises:
        ValueError: if the configured values cannot work with the model (see
            ``LISACallback.__init__``).
    """

    class LISACallback(TrainerCallback):
        """trainer callback for lisa layer switching"""

        def __init__(
            self, n_layers, step_interval, trainer, layers_attribute="model.layers"
        ):
            """Store the configuration and validate it against the model.

            Args:
                n_layers: number of layers made trainable at each switch.
                step_interval: a new layer subset is sampled whenever
                    ``global_step`` is a multiple of this value.
                trainer: trainer whose ``model`` is inspected and modified.
                layers_attribute: dotted attribute path, resolved from
                    ``trainer.model``, leading to the sequence of layers.

            Raises:
                ValueError: if ``step_interval`` is not positive, if no layers
                    are found, or if ``n_layers`` exceeds the layer count.
                AttributeError: if ``layers_attribute`` cannot be resolved.
            """
            super().__init__()
            self.n_layers = n_layers
            self.step_interval = step_interval
            self.layers_attribute = layers_attribute
            self.trainer = trainer

            self.total_layers = len(self._get_layers())
            self.active_layers_indices = []

            # Fail at construction instead of at the first training step,
            # where these would otherwise surface as a ZeroDivisionError
            # (modulo by zero) or an opaque error from np.random.choice.
            if self.step_interval < 1:
                raise ValueError(
                    f"lisa step_interval must be a positive integer, got {self.step_interval}"
                )
            if self.total_layers == 0:
                raise ValueError(
                    f"no layers found at attribute path '{self.layers_attribute}'"
                )
            if self.n_layers > self.total_layers:
                raise ValueError(
                    f"lisa n_layers ({self.n_layers}) exceeds the number of available layers ({self.total_layers})"
                )

            LOG.info(
                f"LISA will activate {self.n_layers}/{self.total_layers} layers ({self.n_layers*100/self.total_layers}%) every {self.step_interval} steps"
            )

        def _get_layers(self):
            """Resolve ``layers_attribute`` against the trainer's current model.

            The path is walked on every call rather than cached, so the result
            always comes from whatever ``self.trainer.model`` is at that time.
            """
            return reduce(
                getattr, self.layers_attribute.split("."), self.trainer.model
            )

        def freeze_all_layers(self):
            """Set ``requires_grad = False`` on every parameter of every layer."""
            for layer in self._get_layers():
                for param in layer.parameters():
                    param.requires_grad = False

        def on_step_begin(
            self, args, state, control, **kwargs
        ):  # pylint: disable=unused-argument
            """Re-sample the active layers when the step interval is reached."""
            # global_step 0 already satisfies the modulo test; step 1 is
            # additionally forced, so a fresh subset is drawn there as well.
            if state.global_step % self.step_interval == 0 or state.global_step == 1:
                self.switch_active_layers()

        def switch_active_layers(self):
            """Freeze all layers, then unfreeze a random subset of ``n_layers``."""
            # First, disable gradients for all layers
            self.freeze_all_layers()

            # Randomly select n_layers distinct indices to activate
            layers = self._get_layers()
            self.active_layers_indices = np.random.choice(
                self.total_layers, self.n_layers, replace=False
            )
            LOG.info(
                f"Activating layers at indices: {self.active_layers_indices} for the next steps."
            )

            # Enable gradients only for the selected layers
            for idx in self.active_layers_indices:
                for param in layers[idx].parameters():
                    param.requires_grad = True

    lisa_callback = LISACallback(
        n_layers=trainer.args.lisa_n_layers,
        step_interval=trainer.args.lisa_step_interval,
        trainer=trainer,
        layers_attribute=trainer.args.lisa_layers_attribute,
    )

    return lisa_callback