---
license: llama3.2
datasets:
- WillHeld/top_v2
language:
- en
base_model:
- meta-llama/Llama-3.2-1B
pipeline_tag: text-generation
library_name: transformers
tags:
- trl
- sft
---

SFT with LayerSkip: supervised fine-tuning where the loss is computed from an early-exit (intermediate) hidden state, with the exit layer cycled across training steps.

```python
from trl import SFTTrainer


class LayerSkipSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Index into `hidden_states` used as the early-exit point; cycled each step.
        self.early_exit_layer = 1

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        # Request all hidden states so an intermediate layer can serve as the exit.
        outputs = model(**inputs, output_hidden_states=True)

        # Project the hidden state at the current exit layer through the shared LM head
        # and compute the causal LM loss on the resulting early-exit logits.
        hidden_state = outputs["hidden_states"][self.early_exit_layer]
        logits = model.lm_head(hidden_state)
        loss = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)

        # Rotate the exit layer for the next step (wraps modulo the number of hidden layers).
        self.early_exit_layer = (self.early_exit_layer + 1) % model.config.num_hidden_layers

        return loss
```
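
For context, a minimal sketch of how this trainer could be wired up with the base model and dataset listed in the metadata above. It assumes a recent TRL/Transformers version; the TOPv2 formatting function, column names, prompt template, split name, output directory, and hyperparameters are illustrative assumptions, not the exact training setup.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig

model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Map TOPv2 rows to a single "text" field for SFT; the column names and
# prompt template here are assumptions, not the exact preprocessing.
def to_text(example):
    return {"text": f"Parse: {example['utterance']}\n{example['semantic_parse']}"}

train_dataset = load_dataset("WillHeld/top_v2", split="train").map(to_text)

trainer = LayerSkipSFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    args=SFTConfig(output_dir="llama3.2-1b-sft-layerskip"),  # assumed output path
)
trainer.train()
```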