import gc
import json
import warnings

import matplotlib.pyplot as plt
import numpy as np
import optuna
import torch
from datasets import load_dataset
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, Matern
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, TrainerCallback, DataCollatorForLanguageModeling
)

warnings.filterwarnings('ignore', category=UserWarning)

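# Learning-rate search for fine-tuning Zyphra/Zamba2-1.2B on a 1,000-example
# slice of BramVanroy/dolly-15k-dutch: run an Optuna TPE study over the
# learning rate, then refine the recommendation with a Gaussian process fit
# to the trial results.
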
num_trials = 10
DATASET = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft[:1000]")
CONTEXT_WINDOW = 1024

# Tokenizer setup: reuse the EOS token for padding if no pad token is defined.
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"


def prepare_chat_format(examples):
    """Tokenize each conversation with the model's chat template."""
    chats = []
    for messages in examples['messages']:
        try:
            chat = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )
            chats.append(chat)
        except Exception as e:
            print(f"Error applying chat template: {e}")
            # Fall back to a simple manual format if the chat template fails.
            text = ""
            for message in messages:
                role = message["role"]
                content = message["content"]
                text += f"<|{role}|>\n{content}</s>\n"

            chat = tokenizer(
                text,
                max_length=CONTEXT_WINDOW,
                truncation=True,
                return_tensors=None
            )["input_ids"]

            chats.append(chat)
    return {"input_ids": chats}


tokenized_dataset = DATASET.map(
    prepare_chat_format,
    batched=True,
    remove_columns=DATASET.column_names
)


def clear_memory():
    """Clear GPU memory between trials."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


class LossCallback(TrainerCallback):
    """Record the training loss reported at each logging step."""

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])


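# Optuna objective: each trial reloads Zamba2-1.2B from the checkpoint, fine-tunes
# it for one epoch at the sampled learning rate, and returns the mean loss over
# the last 20% of logged steps (lower is better).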
def objective(trial):

    clear_memory()

    # Sample the learning rate log-uniformly between 1e-6 and 1e-4.
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)

    # Re-initialize the model from the same checkpoint and seed for every trial.
    torch.manual_seed(42)
    model = AutoModelForCausalLM.from_pretrained(
        "Zyphra/Zamba2-1.2B",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    model.config.pad_token_id = tokenizer.pad_token_id

    batch_size = 4
    grad_accum_steps = 8
    effective_batch_size = batch_size * grad_accum_steps
    total_steps = len(tokenized_dataset) // effective_batch_size

    training_args = TrainingArguments(
        output_dir=f"./optuna_runs/trial_{trial.number}",
        num_train_epochs=1,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum_steps,
        logging_steps=max(total_steps // 20, 1),
        learning_rate=lr,
        weight_decay=0.01,
        fp16=False,
        bf16=True,
        warmup_steps=total_steps // 10,
        save_steps=1000000,  # effectively disables checkpointing during the search
        save_total_limit=None,
        report_to="none",
        seed=42,
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        max_grad_norm=1.0
    )

    print(f"\nTrial {trial.number}:")
    print(f"Learning rate: {lr}")
    print(f"Total steps: {total_steps}")
    print(f"Logging every {training_args.logging_steps} steps")

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # device_map="auto" has already placed the model, so keep the Trainer from
    # moving it to a single device again.
    class CustomTrainer(Trainer):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def _move_model_to_device(self, model, device):
            pass

    loss_callback = LossCallback()

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[loss_callback]
    )

    try:
        train_result = trainer.train()

        # Score the trial by the mean loss over the last 20% of logged steps.
        losses = loss_callback.losses
        n_losses = max(len(losses) // 5, 1)
        final_losses = losses[-n_losses:]
        mean_loss = np.mean(final_losses) if final_losses else float('inf')

        del model
        del trainer
        clear_memory()

        return mean_loss

    except Exception as e:
        print(f"Trial failed with error: {e}")

        del model
        del trainer
        clear_memory()
        return float('inf')


study = optuna.create_study(
    direction="minimize",
    sampler=optuna.samplers.TPESampler(seed=42),
    study_name="learning_rate_optimization"
)

study.optimize(objective, n_trials=num_trials)

print(f"\nOptimization Results ({num_trials} trials):")
print("Best learning rate:", study.best_params["learning_rate"])
print("Best loss:", study.best_value)
print("\nAll trials:")
for trial in study.trials:
    print(f"Learning rate: {trial.params['learning_rate']:.2e}, Loss: {trial.value:.4f}")

# Persist the raw search results; the GPR-based refinement below appends to this.
results = {
    "best_learning_rate": study.best_params["learning_rate"],
    "best_loss": study.best_value,
    "all_trials": [(trial.params["learning_rate"], trial.value) for trial in study.trials]
}
with open("lr_optimization_results.json", "w") as f:
    json.dump(results, f, indent=4)

# Optuna's built-in visualization requires plotly; skip it gracefully if unavailable.
try:
    fig = optuna.visualization.plot_optimization_history(study)
    fig.show()
except Exception as e:
    print(f"Could not create visualization: {e}")


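# Post-hoc refinement: fit a Gaussian process to (log10 learning rate -> loss)
# over the finished trials, then propose a learning rate two ways:
#   * the minimizer of the GPR posterior mean, and
#   * the maximizer of expected improvement, EI(x) = sigma(x) * (Z * Phi(Z) + phi(Z))
#     with Z = (f_best - mu(x)) / sigma(x) for minimization,
# where mu and sigma are the GPR posterior mean/std and f_best is the best observed loss.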
def optimize_final_lr(study):
    try:
        # Collect (learning rate, loss) pairs from the finished trials.
        X = np.array([[trial.params['learning_rate']] for trial in study.trials])
        y = np.array([trial.value for trial in study.trials])

        # Drop failed trials (recorded as inf).
        valid_mask = np.isfinite(y)
        if not np.any(valid_mask):
            print("No valid trials found. Returning default learning rate.")
            return {
                'gpr_optimal_lr': 2e-5,
                'ei_optimal_lr': 2e-5,
                'predicted_loss': float('inf'),
                'uncertainty': float('inf')
            }

        X = X[valid_mask]
        y = y[valid_mask]

        if len(X) < 2:
            print("Not enough valid trials for GPR. Returning best observed value.")
            best_idx = np.argmin(y)
            return {
                'gpr_optimal_lr': float(X[best_idx][0]),
                'ei_optimal_lr': float(X[best_idx][0]),
                'predicted_loss': float(y[best_idx]),
                'uncertainty': float('inf')
            }

        # Model the loss as a function of log10(learning rate).
        X_log = np.log10(X)

        # Standardize the targets for a better-conditioned GPR fit.
        y_mean = np.mean(y)
        y_std = np.std(y)
        if y_std == 0:
            y_std = 1
        y_normalized = (y - y_mean) / y_std

        # Matern 5/2 kernel with a constant amplitude term.
        kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)

        gpr = GaussianProcessRegressor(
            kernel=kernel,
            n_restarts_optimizer=10,
            random_state=42,
            normalize_y=False
        )

        try:
            gpr.fit(X_log, y_normalized)
        except np.linalg.LinAlgError:
            print("GPR fitting failed. Returning best observed value.")
            best_idx = np.argmin(y)
            return {
                'gpr_optimal_lr': float(X[best_idx][0]),
                'ei_optimal_lr': float(X[best_idx][0]),
                'predicted_loss': float(y[best_idx]),
                'uncertainty': float('inf')
            }

        # Dense prediction grid spanning the explored learning-rate range (in log space).
        X_pred_log = np.linspace(np.log10(X.min()), np.log10(X.max()), 1000).reshape(-1, 1)

        y_pred_normalized, sigma = gpr.predict(X_pred_log, return_std=True)

        # Undo the target standardization.
        y_pred = y_pred_normalized * y_std + y_mean
        sigma = sigma * y_std

        # Candidate 1: minimizer of the GPR posterior mean.
        best_idx = np.argmin(y_pred)
        optimal_lr = 10 ** X_pred_log[best_idx, 0]

        # Candidate 2: maximizer of expected improvement over the best observed loss.
        best_f = np.min(y)
        Z = (best_f - y_pred) / (sigma + 1e-9)
        ei = sigma * (Z * norm.cdf(Z) + norm.pdf(Z))

        ei_best_idx = np.argmax(ei)
        ei_optimal_lr = 10 ** X_pred_log[ei_best_idx, 0]

        return {
            'gpr_optimal_lr': float(optimal_lr),
            'ei_optimal_lr': float(ei_optimal_lr),
            'predicted_loss': float(y_pred[best_idx]),
            'uncertainty': float(sigma[best_idx])
        }

    except Exception as e:
        print(f"Optimization failed with error: {e}")
        return {
            'gpr_optimal_lr': 2e-5,
            'ei_optimal_lr': 2e-5,
            'predicted_loss': float('inf'),
            'uncertainty': float('inf')
        }


try:
    final_optimization = optimize_final_lr(study)
    print("\nAdvanced Optimization Results:")
    print(f"GPR Optimal Learning Rate: {final_optimization['gpr_optimal_lr']:.2e}")
    print(f"Expected Improvement Optimal Learning Rate: {final_optimization['ei_optimal_lr']:.2e}")
    print(f"Predicted Loss: {final_optimization['predicted_loss']:.4f}")
    print(f"Uncertainty: {final_optimization['uncertainty']:.4f}")
except Exception as e:
    print(f"Final optimization failed: {e}")
    final_optimization = {
        'gpr_optimal_lr': 2e-5,
        'ei_optimal_lr': 2e-5,
        'predicted_loss': float('inf'),
        'uncertainty': float('inf')
    }

results.update({
    "gpr_optimal_lr": float(final_optimization['gpr_optimal_lr']),
    "ei_optimal_lr": float(final_optimization['ei_optimal_lr']),
    "predicted_loss": float(final_optimization['predicted_loss']),
    "uncertainty": float(final_optimization['uncertainty'])
})


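# Visualization: scatter of the valid trials, the GPR mean with a +/-2 sigma band,
# and vertical lines at the GPR and EI suggestions; saved to lr_optimization_plot.png.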
def plot_gpr_results(study, final_optimization):
    X = np.array([[trial.params['learning_rate']] for trial in study.trials])
    y = np.array([trial.value for trial in study.trials])

    finite_mask = np.isfinite(y)
    X = X[finite_mask]
    y = y[finite_mask]

    if len(X) < 2:
        print("Not enough valid points for GPR visualization")
        return

    X_pred = np.logspace(np.log10(X.min()), np.log10(X.max()), 100).reshape(-1, 1)
    X_pred_log = np.log10(X_pred)

    # Refit a GPR on the raw (unstandardized) losses purely for plotting.
    kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
    gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10, random_state=42)
    gpr.fit(np.log10(X), y)

    y_pred, sigma = gpr.predict(X_pred_log, return_std=True)

    plt.figure(figsize=(12, 6))
    plt.semilogx(X, y, 'ko', label='Valid Trials', markersize=8)
    plt.semilogx(X_pred, y_pred, 'b-', label='GPR Mean')
    plt.fill_between(X_pred.ravel(),
                     y_pred - 2 * sigma,
                     y_pred + 2 * sigma,
                     color='blue',
                     alpha=0.2,
                     label='95% Confidence')

    if np.isfinite(final_optimization['gpr_optimal_lr']):
        plt.axvline(final_optimization['gpr_optimal_lr'], color='r', linestyle='--',
                    label='GPR Optimal LR')
    if np.isfinite(final_optimization['ei_optimal_lr']):
        plt.axvline(final_optimization['ei_optimal_lr'], color='g', linestyle='--',
                    label='EI Optimal LR')

    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Optimization Results with GPR')
    plt.legend()
    plt.grid(True)
    plt.savefig('lr_optimization_plot.png', dpi=300, bbox_inches='tight')
    plt.close()


plot_gpr_results(study, final_optimization)

# Rewrite the results file with the GPR-based recommendations included.
with open("lr_optimization_results.json", "w") as f:
    json.dump(results, f, indent=4)
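
# --- Optional follow-up (illustrative sketch, not part of the original search) ---
# One reasonable way to consume the results for a final full fine-tuning run:
# prefer the EI suggestion when the GPR step produced a finite prediction,
# otherwise fall back to the best Optuna trial. This preference order is an
# assumption, not something the search above mandates.
recommended_lr = (
    final_optimization["ei_optimal_lr"]
    if np.isfinite(final_optimization["predicted_loss"])
    else study.best_params["learning_rate"]
)
print(f"\nSuggested learning rate for a full fine-tuning run: {recommended_lr:.2e}")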