import os
import sys
import argparse
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from transformers import (
    Seq2SeqTrainer,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from peft import get_peft_model, prepare_model_for_kbit_training
path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, path)
from utils import *
# from model.models import load_model
from model.model import load_model
from data.preprocessing import preprocessing_data
from data.ingest_data import ingest_data
import evaluate


def training_pipeline(args: argparse.Namespace):
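    """End-to-end fine-tuning pipeline.

    Loads the checkpoint and tokenizer, optionally wraps the base model with a
    LoRA adapter (with or without 4-bit quantization), ingests and pre-processes
    the dataset, trains with Seq2SeqTrainer or ContrastiveLearningTrainer using
    a ROUGE-based compute_metrics, and pushes the trained model to the Hub.
    """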
    try:
        print("=========================================")
        print('\n'.join(f' + {k}={v}' for k, v in vars(args).items()))
        print("=========================================")

        import torch
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = load_model(args.checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(args.checkpoint)
        print(type(tokenizer))
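
        # If LoRA is disabled, fine-tune the full model in full precision;
        # otherwise build a LoRA (optionally 4-bit quantized) variant below.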
        if not args.lora:
            print("lora=False, quantize=False")
            base_model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint).to(device)
            # model.base_model = model.get_model()
            # model.base_model.to(device)
        else:
            from peft import LoraConfig, TaskType
            from transformers import BitsAndBytesConfig

            # Define LoRA config
            lora_config = LoraConfig(
                r=args.lora_rank,
                lora_alpha=args.lora_alpha,
                target_modules=args.target_modules.split(","),
                lora_dropout=args.lora_dropout,
                bias="none",
                task_type=TaskType.SEQ_2_SEQ_LM,
            )
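
            # Quantized path follows the QLoRA-style recipe: load the base
            # weights in 4-bit NF4 with double quantization and bfloat16
            # compute, then prepare_model_for_kbit_training freezes the
            # quantized weights and upcasts the remaining layers so that only
            # the LoRA adapter added below ends up trainable.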
            if args.quantize:
                print("Quantize=True, lora=True")
                bnb_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
                base_model = AutoModelForSeq2SeqLM.from_pretrained(
                    args.checkpoint,
                    quantization_config=bnb_config,
                    device_map={"": 0},
                    trust_remote_code=True,
                )
                base_model = prepare_model_for_kbit_training(base_model)
            else:
                print("Quantize=False, lora=True")
                base_model = AutoModelForSeq2SeqLM.from_pretrained(args.checkpoint).to(device)

            # Add LoRA adapter
            print("Base model:", base_model)
            base_model = get_peft_model(base_model, lora_config)
            base_model.print_trainable_parameters()
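
            # print_trainable_parameters() reports how few parameters the LoRA
            # adapter adds relative to the frozen base model.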

        # Load data from datapath
        data = ingest_data(args.datapath)
        print("\033[92m[+] Complete loading dataset!\033[00m")

        # Pre-process data
        data = preprocessing_data(
            data,
            tokenizer,
            use_contrastive_loss=args.use_contrastive_loss,
            tokenizing_strategy=args.tokenizing_strategy,
        )
        print("\033[92m[+] Complete pre-processing dataset!\033[00m")

        # Load training arguments
        training_args = load_training_arguments(args)
        print("\033[92m[+] Complete loading training arguments!\033[00m")

        # Load metric
        metric = evaluate.load("rouge")
        nltk.download("punkt")

        def postprocess_text(preds, labels):
            preds = [pred.strip() for pred in preds]
            labels = [label.strip() for label in labels]
            preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
            labels = ["\n".join(sent_tokenize(label)) for label in labels]
            return preds, labels
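
        # ROUGE-Lsum scores newline-separated sentences, so predictions and
        # references are sentence-split with sent_tokenize and re-joined with
        # "\n" before computing the metric.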

        def compute_metric(eval_preds):
            preds, labels = eval_preds
            if isinstance(preds, tuple):
                preds = preds[0]
            decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            # Some simple post-processing
            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

            # metric = evaluate.load("rouge")
            rouge_results = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
            rouge_results = {k: round(v * 100, 4) for k, v in rouge_results.items()}
            results = {
                "rouge1": rouge_results["rouge1"],
                "rouge2": rouge_results["rouge2"],
                "rougeL": rouge_results["rougeL"],
                "rougeLsum": rouge_results["rougeLsum"],
                "gen_len": np.mean([np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]),
            }
            return results

        # Load trainer
        if args.use_contrastive_loss:
            trainer = ContrastiveLearningTrainer(
                model=base_model,
                args=training_args,
                train_dataset=data["train"],
                eval_dataset=data["validation"],
                tokenizer=tokenizer,
                compute_metrics=compute_metric,
            )
        else:
            trainer = Seq2SeqTrainer(
                model=base_model,
                args=training_args,
                train_dataset=data["train"],
                eval_dataset=data["validation"],
                tokenizer=tokenizer,
                compute_metrics=compute_metric,
            )
print("\033[92m[+] Complete loading trainer!\033[00m")
# Train model
trainer.train()
print("\033[92m[+] Complete training!\033[00m")
# Push to Huggingface Hub
trainer.push_to_hub()
print("\033[92m [+] Complete pushing model to hub!\033[00m")
    except Exception as e:
        print(f"\033[31m\nError while training: {e}\033[00m")
        raise e
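

# Illustrative sketch: one possible command-line entry point for
# training_pipeline. The flag names mirror the attributes read from `args`
# above; the defaults are placeholders, and utils.load_training_arguments
# presumably consumes additional flags (e.g. output_dir, learning rate,
# epochs) that the repository's real entry point would also have to define.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune a seq2seq checkpoint")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--datapath", type=str, required=True)
    parser.add_argument("--lora", action="store_true")
    parser.add_argument("--quantize", action="store_true")
    parser.add_argument("--lora_rank", type=int, default=16)          # placeholder default
    parser.add_argument("--lora_alpha", type=int, default=32)         # placeholder default
    parser.add_argument("--lora_dropout", type=float, default=0.05)   # placeholder default
    parser.add_argument("--target_modules", type=str, default="q,v")  # placeholder default
    parser.add_argument("--use_contrastive_loss", action="store_true")
    parser.add_argument("--tokenizing_strategy", type=str, default="default")  # placeholder default
    training_pipeline(parser.parse_args())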