# File size: 5,034 Bytes (revision b0e4fff)
import optuna
import torch
import random
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer
import time
# Seed Python's and PyTorch's random number generators with one shared value
# so that repeated runs start from the same state.
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)

# Fetch the "train" split of the tatsu-lab/alpaca dataset.
dataset = load_dataset("tatsu-lab/alpaca", split="train")
def chatml_format(example):
    """Project a dataset row onto the four text columns used by this script.

    Args:
        example: Mapping from column name to value for a single row.

    Returns:
        A dict with exactly the keys "instruction", "input", "system" and
        "output", in that order. A column that is absent from ``example``,
        or whose value is ``None``, is replaced by a placeholder string
        (one space followed by a newline). Every other value, including
        the empty string, is passed through unchanged.
    """
    placeholder = " \n"
    formatted = {}
    for column in ("instruction", "input", "system", "output"):
        # A null cell is as unusable for text training as a missing column,
        # so both cases fall back to the placeholder.
        value = example[column] if column in example else None
        formatted[column] = placeholder if value is None else value
    return formatted
# Run every row through chatml_format. The columns that exist before the map
# are passed as remove_columns, so the result carries the keys returned by
# chatml_format.
dataset = dataset.map(
    chatml_format,
    remove_columns=dataset.column_names,
)
# Define the model initialization function
def model_init(trial=None):
    """Load the causal LM and its tokenizer and configure them for training.

    Args:
        trial: Optional Optuna trial. It is accepted so the function can be
            called from the objective, but none of the settings below are
            read from it.

    Returns:
        The loaded model, switched to training mode, with the tokenizer
        attached as ``model.tokenizer`` and the fixed run settings stored as
        attributes on the model.
    """
    # Fixed (non-tuned) settings. They are bound unconditionally so that the
    # default ``trial=None`` call works as well instead of failing on
    # unbound local names.
    original = False
    params = {}
    n_ahead = 1
    n_ahead_talk = 1
    n_passes = 1
    gumbel_temperature = 1
    use_start_thought_token = True
    use_end_thought_token = True
    include_policy_loss = True
    gumbel_detach = True
    merged_talk_heads = True
    residual_think_head = False

    model_id = "Crystalcareai/Quiet-Star-Custom"
    tokenizer_id = model_id

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # bfloat16 only when CUDA is available; full precision otherwise.
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        max_thoughts=n_ahead + n_ahead_talk + 1,
        merged_talk_heads=merged_talk_heads,
        merged_lm_and_talk_heads=False,
        merged_lm_and_think_heads=True,
        use_concat_talk_head=True,
        use_shallow_think=True,
        use_shallow_talk=False,
        use_complex_think_head=False,
        use_complex_talk_head=True,
        use_weighted_talk_head=True,
        trust_remote_code=True,
        device_map="auto",
    )

    # NOTE(review): `truncation` and `padding` look like call-time tokenizer
    # arguments; if left padding is intended this probably wants
    # padding_side="left" - confirm before changing.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, truncation=True, padding="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # The thought-token flags are read from the freshly loaded model here,
    # i.e. before they are overwritten with the local settings further down.
    special_tokens_to_add = []
    if model.use_start_thought_token:
        special_tokens_to_add.append("<|startthought|>")
    if model.use_end_thought_token:
        special_tokens_to_add.append("<|endthought|>")
    if special_tokens_to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
        # Keep the embedding table in step with the enlarged vocabulary.
        model.resize_token_embeddings(len(tokenizer))
    model.tokenizer = tokenizer

    # Print every module whose name contains "embed".
    for name, module in model.named_modules():
        if "embed" in name:
            print(module, flush=True)

    # Store the fixed run settings on the model object.
    model.gumbel_detach = gumbel_detach
    model.include_policy_loss = include_policy_loss
    model.use_end_thought_token = use_end_thought_token
    model.use_start_thought_token = use_start_thought_token
    model.n_ahead = n_ahead
    model.n_ahead_talk = n_ahead_talk
    model.n_passes = n_passes
    model.residual_think_head = residual_think_head
    model.gumbel_temperature = gumbel_temperature
    model.original_mode = original
    model.config_params = params
    model.run_start = int(time.time())
    model.train()
    return model
# Define the objective function for Optuna
def objective(trial):
    """Train one model with trial-suggested hyperparameters and return its loss.

    Args:
        trial: Optuna trial used to sample the learning rate, gradient
            clipping norm, warmup steps and gradient accumulation steps.

    Returns:
        The ``training_loss`` reported by ``trainer.train()`` for this trial.
    """
    import gc  # local import: only needed for the per-trial cleanup below

    # Hyperparameters to be optimized
    learning_rate = trial.suggest_float("learning_rate", 1e-07, 1e-06, log=True)
    max_grad_norm = trial.suggest_float("max_grad_norm", 0.3, 1.0)
    warmup_steps = trial.suggest_int("warmup_steps", 0, 20)
    gradient_accumulation_steps = trial.suggest_int("gradient_accumulation_steps", 4, 8)

    model = model_init(trial)
    trainer = None
    try:
        training_args = TrainingArguments(
            output_dir="./out",
            # NOTE(review): a positive max_steps normally takes precedence
            # over num_train_epochs - confirm 30 steps is the intended budget.
            num_train_epochs=3,
            max_steps=30,
            per_device_train_batch_size=1,
            logging_steps=1,
            optim="lion_32bit",
            save_strategy="steps",
            # save_steps is larger than max_steps, so no step-interval
            # checkpoint is reached within a trial.
            save_steps=3000,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            max_grad_norm=max_grad_norm,
            warmup_steps=warmup_steps,
            lr_scheduler_type="cosine",
            report_to="none",  # Disable reporting to avoid errors related to WandB in this context
        )
        trainer = SFTTrainer(
            args=training_args,
            train_dataset=dataset,
            model=model,
            tokenizer=model.tokenizer,
            max_seq_length=1024,
            dataset_text_field="output",
        )
        # Train the model and get the training loss
        train_result = trainer.train()
        return train_result.training_loss
    finally:
        # Every trial loads a fresh model; drop this trial's references and
        # release cached accelerator memory so trials do not accumulate
        # allocations until the process runs out of memory.
        del trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Create the study, stored in a local SQLite file, and run 100 trials of the
# objective. No direction is passed, so Optuna's default applies.
study = optuna.create_study(storage="sqlite:///db.sqlite3")
study.optimize(objective, n_trials=100)

# Report the best trial: its objective value followed by each sampled
# hyperparameter.
print("Best trial:")
trial = study.best_trial
print(f" Loss: {trial.value}")
print(" Params: ")
for key, value in trial.params.items():
    print(f" {key}: {value}")
|