# amigov1 / medalpaca / train.py
# Uploaded by asach via huggingface_hub (commit d727a17).
import os
import sys
from typing import Tuple, Union
import fire
import torch
from datasets import load_dataset
from handler import DataHandler
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
)
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
LlamaForCausalLM,
LlamaTokenizer,
Trainer,
TrainingArguments,
)
def _resolve_gradient_accumulation_steps(
    global_batch_size: int, per_device_batch_size: int, world_size: int = 1
) -> int:
    """Return the number of micro-batches to accumulate per optimizer step.

    ``global_batch_size`` is floor-divided by the per-device batch size and then
    by the number of data-parallel processes (``world_size``). The result is
    clamped to at least 1: with more processes than micro-batches the division
    yields 0, which is not a usable ``gradient_accumulation_steps`` value.
    """
    steps = global_batch_size // per_device_batch_size // world_size
    return max(1, steps)


def main(
    model: str,  # e.g. "decapoda-research/llama-7b-hf"
    val_set_size: Union[int, float] = 0.1,
    prompt_template: str = "/home/ubuntu/LLM/.conda/om/medAlpaca/medalpaca/prompts/medalpaca.json",
    model_max_length: int = 256,  # should not exceed 2048, as LLaMA is trained with this
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    data_path: str = "medical_meadow_small.json",
    train_in_8bit: bool = True,
    use_lora: bool = True,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    lora_target_modules: Tuple[str, ...] = ("q_proj", "v_proj"),
    per_device_batch_size: int = 8,
    num_epochs: int = 2,
    learning_rate: float = 2e-5,
    global_batch_size: int = 128,
    output_dir: str = "./output",
    save_total_limit: int = 3,
    eval_steps: int = 200,
    device_map: str = "auto",
    group_by_length: bool = False,
    wandb_run_name: str = "test",
    use_wandb: bool = False,
    wandb_project: str = "medalpaca",
    optim: str = "adamw_torch",
    lr_scheduler_type: str = "cosine",
    fp16: bool = True,
    bf16: bool = False,
    gradient_checkpointing: bool = False,
    warmup_steps: int = 100,
    fsdp: str = "full_shard auto_wrap",
    fsdp_transformer_layer_cls_to_wrap: str = "LlamaDecoderLayer",
    **kwargs,
):
    """
    Fine-tune a causal language model with HuggingFace Transformers, optionally
    loading it in 8-bit and training LoRA adapters.

    Args:
        model (str):
            The model identifier on the HuggingFace Model Hub. Required.
        val_set_size (Union[int, float], optional):
            Passed as ``test_size`` to ``train_test_split``: a float is a proportion,
            an int an absolute number of samples. If 0, no validation split is
            created and evaluation / best-model loading are disabled. Default is 0.1.
        prompt_template (str, optional):
            The path to the JSON file containing prompt templates, forwarded to
            ``DataHandler``. Default is
            "/home/ubuntu/LLM/.conda/om/medAlpaca/medalpaca/prompts/medalpaca.json".
        model_max_length (int, optional):
            The maximum length for model inputs, forwarded to ``DataHandler``.
            Should not exceed 2048, the length LLaMA was trained with. Default is 256.
        train_on_inputs (bool, optional):
            Whether to train on input tokens; if False, inputs are masked out of
            the loss. Default is True.
        data_path (str, optional):
            The path to the JSON dataset file. Default is "medical_meadow_small.json".
        train_in_8bit (bool, optional):
            Whether to load the model in 8-bit and prepare it for int8 training.
            Requires ``use_lora``. Default is True.
        use_lora (bool, optional):
            Whether to train LoRA adapters instead of the full model. Default is True.
        lora_r (int, optional):
            The rank of the LoRA update matrices. Default is 8.
        lora_alpha (int, optional):
            The LoRA alpha (scaling) parameter. Default is 16.
        lora_dropout (float, optional):
            The dropout rate used by the LoRA layers. Default is 0.1.
        lora_target_modules (Tuple[str, ...], optional):
            Names of the modules LoRA adapters are attached to.
            Default is ("q_proj", "v_proj").
        per_device_batch_size (int, optional):
            The batch size per device. Default is 8.
        num_epochs (int, optional):
            The number of epochs for training. Default is 2.
        learning_rate (float, optional):
            The learning rate for the optimizer. Default is 2e-5.
        global_batch_size (int, optional):
            The number of samples the model sees per weight update; determines
            the gradient accumulation steps (never fewer than 1). Default is 128.
        output_dir (str, optional):
            The directory for checkpoints and the final model. Default is "./output".
        save_total_limit (int, optional):
            The maximum number of saved checkpoints. Default is 3.
        eval_steps (int, optional):
            The number of steps between evaluations; also used as the checkpoint
            interval. Default is 200.
        device_map (str, optional):
            The device placement strategy. Under DDP it is replaced by the
            device given by LOCAL_RANK. Default is "auto".
        group_by_length (bool, optional):
            Whether to group samples by length for batch construction. Default is False.
        wandb_run_name (str, optional):
            The run name for Weights & Biases logging. Default is "test".
        use_wandb (bool, optional):
            Whether to report to Weights & Biases; when False no reporting
            integration is enabled. Default is False.
        wandb_project (str, optional):
            The Weights & Biases project name, exported as WANDB_PROJECT when
            non-empty. Default is "medalpaca".
        optim (str, optional):
            The optimizer to use. Default is "adamw_torch".
        lr_scheduler_type (str, optional):
            The learning rate scheduler type. Default is "cosine".
        fp16 (bool, optional):
            Whether to use mixed precision training (FP16). Default is True.
        bf16 (bool, optional):
            Whether to use mixed precision training (BF16); the model weights are
            then loaded in bfloat16. Default is False.
        gradient_checkpointing (bool, optional):
            Whether to use gradient checkpointing to reduce memory footprint.
            Not supported together with LoRA. Default is False.
        warmup_steps (int, optional):
            The number of learning-rate warmup steps. Default is 100.
        fsdp (str, optional):
            Fully Sharded Data Parallel strategy. Only active with distributed
            training without LoRA. Default is "full_shard auto_wrap".
        fsdp_transformer_layer_cls_to_wrap (str, optional):
            The model layer class to wrap for FSDP. Default is "LlamaDecoderLayer".
        **kwargs:
            Additional arguments passed to ``transformers.TrainingArguments``.

    Raises:
        ValueError: If both fp16 and bf16 are set, if 8-bit training is requested
            without LoRA, or if gradient checkpointing is combined with LoRA.
    """
    # perform some checks first, to raise errors before any model/data loading
    if fp16 and bf16:
        raise ValueError("At most one of fp16 and bf16 can be True, but not both.")
    if train_in_8bit and not use_lora:
        raise ValueError("8bit training without LoRA is not supported")
    if use_lora and gradient_checkpointing:
        raise ValueError("gradient_checkpointing with LoRA training is not implemented")

    # adapt arguments
    model_name = model
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    gradient_accumulation_steps = _resolve_gradient_accumulation_steps(
        global_batch_size, per_device_batch_size, world_size
    )
    if ddp:
        # place the whole model on this process's local device
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        if use_lora:
            # integer and mixed dtypes are not supported with fsdp
            fsdp, fsdp_transformer_layer_cls_to_wrap = "", None
    else:
        # FSDP is only active with distributed training
        fsdp, fsdp_transformer_layer_cls_to_wrap = "", None

    if wandb_project:
        os.environ["WANDB_PROJECT"] = wandb_project

    # init model
    is_llama = "llama" in model_name.lower()
    # The LLaMA config on HF is not up to date with the library,
    # leading to errors when using AutoModelForCausalLM
    load_model = LlamaForCausalLM if is_llama else AutoModelForCausalLM
    # Loading the model with torch_dtype=torch.float16 with only fp16 and no LoRA
    # leads to `ValueError: Attempting to unscale FP16 gradients.`, so full fp16
    # fine-tuning keeps fp32 weights. bf16 training loads bfloat16 weights so the
    # stored weights match the mixed-precision format.
    if bf16:
        torch_dtype = torch.bfloat16
    elif use_lora:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32
    model = load_model.from_pretrained(
        model_name,
        load_in_8bit=train_in_8bit,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )

    if train_in_8bit:
        model = prepare_model_for_int8_training(model)

    if use_lora:
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # init tokenizer and tokenize function
    tokenizer_cls = LlamaTokenizer if is_llama else AutoTokenizer
    tokenizer = tokenizer_cls.from_pretrained(model_name)
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    # load and tokenize data
    data_handler = DataHandler(
        tokenizer=tokenizer,
        prompt_template=prompt_template,
        model_max_length=model_max_length,
        train_on_inputs=train_on_inputs,
    )
    data = load_dataset("json", data_files=data_path)
    has_eval = val_set_size > 0
    if has_eval:
        data = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
    else:
        data = data.shuffle(seed=42)
    data = data.map(data_handler.generate_and_tokenize_prompt)

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    # init trainer
    training_args = TrainingArguments(
        per_device_train_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=gradient_checkpointing,
        warmup_steps=warmup_steps,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        fp16=fp16,
        bf16=bf16,
        logging_steps=10,
        optim=optim,
        lr_scheduler_type=lr_scheduler_type,
        evaluation_strategy="steps" if has_eval else "no",
        save_strategy="steps",
        eval_steps=eval_steps if has_eval else None,
        save_steps=eval_steps,
        output_dir=output_dir,
        save_total_limit=save_total_limit,
        load_best_model_at_end=has_eval,
        ddp_find_unused_parameters=False if ddp else None,
        group_by_length=group_by_length,
        # "none" disables reporting; report_to=None means "all installed integrations"
        report_to="wandb" if use_wandb else "none",
        run_name=wandb_run_name if use_wandb else None,
        fsdp=fsdp,
        fsdp_transformer_layer_cls_to_wrap=fsdp_transformer_layer_cls_to_wrap,
        **kwargs,
    )
    trainer = Trainer(
        model=model,
        train_dataset=data["train"],
        # no "test" split exists when val_set_size == 0
        eval_dataset=data["test"] if has_eval else None,
        args=training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    # for whatever reason, it is important that this is executed after trainer
    # is initialized. Otherwise you run into data indexing error, as the
    # trainer drops all columns in the dataset
    model.config.use_cache = False

    if use_lora:
        # Route state_dict() through get_peft_model_state_dict so saved
        # checkpoints presumably hold only the adapter weights.
        # NOTE(review): newer peft versions apply get_peft_model_state_dict
        # inside save_pretrained() as well; applying it twice may drop the
        # adapter keys — confirm against the installed peft version.
        old_state_dict = model.state_dict
        model.state_dict = (
            lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
        ).__get__(model, type(model))

    # NOTE(review): the trainer already holds the uncompiled module, so this
    # compile likely does not affect trainer.train(); kept for parity.
    if hasattr(torch, "compile") and sys.platform != "win32":
        model = torch.compile(model)

    # finally, train
    trainer.train()
    model.save_pretrained(output_dir)
if __name__ == "__main__":
    # Expose `main` as a command-line interface via python-fire.
    fire.Fire(main)