Commit e6cff95 by jasonmcaffee (1 parent: 2dc1457)

Update README.md

Files changed (1): README.md +135 -0
---
license: mit
---
# Overview
A LoRA adapter created by fine-tuning the flan-t5-large model on the [SAMsum training dataset](https://huggingface.co/datasets/samsum).

SAMsum is a corpus of 16k dialogues and their corresponding summaries.

Example entry:
- Dialogue - "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)"
- Summary - "Amanda baked cookies and will bring Jerry some tomorrow."

[LoRA](https://github.com/microsoft/LoRA) is a parameter-efficient technique for fine-tuning models so they perform better at specific tasks.
> An important paradigm of natural language processing consists of large-scale pre-training on general domain data and adaptation to particular tasks or domains. As we pre-train larger models, full fine-tuning, which retrains all model parameters, becomes less feasible. Using GPT-3 175B as an example -- deploying independent instances of fine-tuned models, each with 175B parameters, is prohibitively expensive. We propose Low-Rank Adaptation, or LoRA, which freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks. Compared to GPT-3 175B fine-tuned with Adam, LoRA can reduce the number of trainable parameters by 10,000 times and the GPU memory requirement by 3 times. LoRA performs on-par or better than fine-tuning in model quality on RoBERTa, DeBERTa, GPT-2, and GPT-3, despite having fewer trainable parameters, a higher training throughput, and, unlike adapters, no additional inference latency.
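
To make the quoted idea concrete: instead of updating a frozen weight matrix `W`, LoRA trains two small matrices `A` and `B` of rank `r` and adds their scaled product to the frozen path. A minimal conceptual sketch, not the PEFT implementation (all names and sizes here are illustrative):

```
import torch

d, r, alpha = 1024, 16, 32     # hidden size, LoRA rank (r << d), scaling factor
W = torch.randn(d, d)          # frozen pre-trained weight (never updated)
A = torch.randn(r, d) * 0.01   # trainable low-rank factor
B = torch.zeros(d, r)          # trainable low-rank factor (starts at zero, so the update starts at zero)

def lora_linear(x):
    # frozen path plus scaled low-rank update: x W^T + (alpha / r) * x A^T B^T
    return x @ W.T + (alpha / r) * (x @ A.T @ B.T)

# only 2 * d * r parameters are trained instead of d * d
print(f"trainable: {2 * d * r:,} vs frozen: {d * d:,}")
```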

In this case we fine-tune flan-t5-large on the SAMsum dataset to create a model that is better at dialogue summarization.

# Code

## Notebook Source
[Notebook used to create LoRA adapter](https://colab.research.google.com/drive/1z_mZL6CIRRA4AeF6GXe-zpfEGqqdMk-f?usp=sharing)

## Load the SAMsum dataset that we will use to fine-tune the flan-t5-large model
```
from datasets import load_dataset
dataset = load_dataset("samsum")
```
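
Each SAMsum record has `id`, `dialogue`, and `summary` fields. An optional quick look at the first training example (not part of the original notebook):

```
# inspect one raw example to see the dialogue/summary structure
print(dataset["train"][0]["dialogue"])
print(dataset["train"][0]["summary"])
```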

## Prepare the dataset
```
... see notebook
# save the tokenized datasets to disk for easy loading later
tokenized_dataset["train"].save_to_disk("data/train")
tokenized_dataset["test"].save_to_disk("data/eval")
```
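
The elided preprocessing lives in the notebook. A rough sketch of the kind of tokenization it stands in for (the "summarize:" prefix and the max lengths of 512/128 are assumptions, not taken from the notebook):

```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")

def preprocess(batch):
    # prepend an instruction-style prefix and tokenize the dialogues (inputs)
    inputs = ["summarize: " + d for d in batch["dialogue"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    # tokenize the summaries (labels)
    labels = tokenizer(text_target=batch["summary"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# keep only the tokenized columns so the data collator receives clean features
tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=["id", "dialogue", "summary"])
```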

## Load the flan-t5-large model
Loading the model in 8-bit greatly reduces the amount of GPU memory required.

When combined with the accelerate library, device_map="auto" spreads the model across all available GPUs for training.
```
import torch
from transformers import AutoModelForSeq2SeqLM

model_id = "google/flan-t5-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16)
```
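
As an optional sanity check (not in the original notebook), the approximate memory taken up by the quantized weights can be printed:

```
# rough memory footprint of the 8-bit model's parameters, in GB
print(f"{model.get_memory_footprint() / 1e9:.2f} GB")
```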

## Define the LoRA config and prepare the model for training
```
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

# prepare the int-8 model for training
model = prepare_model_for_int8_training(model)

# add the LoRA adapter
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

## Create the data collator
Data collators form a batch from a list of dataset elements, padding them to a uniform length.
```
from transformers import DataCollatorForSeq2Seq

# ignore the tokenizer pad token when computing the loss
label_pad_token_id = -100

# data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)
```
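
To illustrate the batching behaviour, the collator can be applied directly to a couple of tokenized examples (an optional check, not in the original notebook, assuming the tokenized dataset keeps only the input_ids, attention_mask, and labels columns):

```
# pad two variable-length examples into one rectangular batch of tensors
features = [tokenized_dataset["train"][i] for i in range(2)]
batch = data_collator(features)
print(batch["input_ids"].shape, batch["labels"].shape)
```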

## Create the training arguments and trainer
```
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

output_dir = "lora-flan-t5-large"

# define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,  # higher learning rate than for full fine-tuning
    num_train_epochs=5,
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=500,
    save_strategy="no",
    report_to="tensorboard",
)

# create the Trainer instance
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
)

model.config.use_cache = False  # disable the cache during training; re-enable for inference!
```

## Train the model!
Training takes about 5-6 hours on a single T4 GPU.
```
trainer.train()
```

| Step | Training Loss |
|------|---------------|
| 500  | 1.302200 |
| 1000 | 1.306300 |
| 1500 | 1.341500 |
| 2000 | 1.278500 |
| 2500 | 1.237000 |
| 3000 | 1.239200 |
| 3500 | 1.250900 |
| 4000 | 1.202100 |
| 4500 | 1.165300 |
| 5000 | 1.178900 |
| 5500 | 1.181700 |
| 6000 | 1.100600 |
| 6500 | 1.119800 |
| 7000 | 1.105700 |
| 7500 | 1.097900 |
| 8000 | 1.059500 |
| 8500 | 1.047400 |
| 9000 | 1.046100 |

Final training output:
TrainOutput(global_step=9210, training_loss=1.1780610539108094, metrics={'train_runtime': 19217.7668, 'train_samples_per_second': 3.833, 'train_steps_per_second': 0.479, 'total_flos': 8.541847343333376e+16, 'train_loss': 1.1780610539108094, 'epoch': 5.0})
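
Because save_strategy="no" writes no checkpoints during training, the adapter weights must be saved explicitly once training finishes. A minimal sketch of saving the adapter and reusing it for inference (the paths are illustrative, the "summarize:" prefix matches the assumed preprocessing above, and the tokenizer is the one loaded there):

```
# persist only the small LoRA adapter weights (and the tokenizer for convenience)
model.save_pretrained("lora-flan-t5-large")
tokenizer.save_pretrained("lora-flan-t5-large")

# later: load the base model and attach the adapter for inference
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", load_in_8bit=True, device_map="auto")
model = PeftModel.from_pretrained(base, "lora-flan-t5-large")
model.config.use_cache = True  # re-enable the cache that was disabled for training
tokenizer = AutoTokenizer.from_pretrained("lora-flan-t5-large")

dialogue = "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)"
inputs = tokenizer("summarize: " + dialogue, return_tensors="pt").to(model.device)
summary_ids = model.generate(**inputs, max_new_tokens=60)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```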