import argparse
import shutil
import logging
import random
import time
from pprint import pprint
from collections import defaultdict
from pathlib import Path

from scrl.rewards import load_rewards
from scrl.data import load_data_for_training
from scrl.config import load_config
from scrl.model import load_model, LinearTokenSelector, labels_to_summary
import scrl.utils as utils
import scrl.sampling as sampling

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, AutoTokenizer
from sklearn import preprocessing
from nltk import word_tokenize


def print_if(x, do_print=True):
    if do_print:
        print(x)


class TrainingManager:
    """
    Object for saving/loading model checkpoints and for tracking and saving
    metrics measured during training, e.g. loss, rewards.

    The following directory structure is built around one training run:

    dir/
        val_scores.json
        checkpoints/
            latest-model-500/
                classifier.bin
                encoder.bin
            best-model-200/
                [...]
        series/
            loss.npy
            [...]
        totals/
            loss.npy
            [...]
    """
    def __init__(self, dir):
        self.step = 0
        self.total_seconds = 0
        self.start_time = None
        self.series = defaultdict(list)
        self.totals = defaultdict(float)
        self.dir = dir
        dir.mkdir(exist_ok=True)
        for subdir_name in ("checkpoints", "series", "totals"):
            (dir / subdir_name).mkdir(exist_ok=True)

    def start_clock(self):
        self.start_time = time.time() - self.total_seconds

    def load(self):
        # load tracked data, e.g. loss, rewards etc.
        for p in (self.dir / "series").iterdir():
            k = p.name.split(".npy")[0]
            self.series[k] = list(utils.load_numpy(p))
        for p in (self.dir / "totals").iterdir():
            k = p.name.split(".npy")[0]
            self.totals[k] = utils.load_numpy(p)
        # read latest training step
        latest_model_dir = self.find_old_model("latest-model")
        self.total_seconds = utils.read_json(self.dir / "time.json")["total_seconds"]
        last_step = int(latest_model_dir.name.split("-")[-1])
        self.step = last_step + 1

    def update_metric(self, key, value):
        self.totals[key] += value
        self.series[key].append(value)

    def mean_metric(self, key):
        return self.totals[key] / (self.step + 1)

    def save_latest_model(self, model, checkpoint_id):
        self.save_model(model, checkpoint_id, prefix="latest-model")

    def save_model(self, model, checkpoint_id, prefix):
        old_model_dir = self.find_old_model(prefix)
        model_dir = self.dir / "checkpoints" / f"{prefix}-{checkpoint_id}"
        model_dir.mkdir()
        model.save(
            classifier_path=model_dir / "classifier.bin",
            encoder_path=model_dir / "encoder.bin"
        )
        if old_model_dir:
            shutil.rmtree(old_model_dir)

    def find_old_model(self, prefix):
        model_path = None
        for p in (self.dir / "checkpoints").iterdir():
            if p.name.startswith(f"{prefix}"):
                model_path = p
        return model_path

    def is_empty(self):
        latest_model_dir = self.find_old_model("latest-model")
        return latest_model_dir is None

    def save_data(self):
        for k, v in self.series.items():
            utils.save_numpy(v, self.dir / "series" / f"{k}.npy")
        for k, v in self.totals.items():
            utils.save_numpy(v, self.dir / "totals" / f"{k}.npy")
        utils.write_json({
            "step": self.step,
            "total_seconds": self.total_seconds
        }, self.dir / "time.json")


def label_variance(probs):
    # probs shape: (batch, seq, 2)
    variances = []
    for i in range(probs.size(0)):
        distrib = probs[i, :, 0]
        var = torch.var(distrib)
        variances.append(var)
    return torch.stack(variances).mean().item()


def check_gradient(model):
    is_zero = []
    is_none = []
    for name, param in list(model.named_parameters()):
        if param.requires_grad:
            grad = param.grad
            if grad is None:
                is_none.append(name)
            else:
                gradsum = param.grad.sum().item()
                if gradsum == 0:
                    is_zero.append(name)
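    # report parameters with missing or all-zero gradients; useful for spotting
    # parts of the model that are detached from the loss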
print("zero-grad:", len(is_zero), is_zero) print("none-grad:", len(is_none), is_none) print() def get_mean_max_prob(probs): return probs.max(dim=2).values.mean().item() def print_training_progress(args, manager, model, probs, argmax_summaries, sample_summaries, batch, argmax_details): print(f"[step: {manager.step}] [duration(s): {round(manager.total_seconds)}]") print(f"[example/s: {(args.batch_size * (manager.step + 1)) / manager.total_seconds:.3f}]") print(f"[s/step: {manager.total_seconds / (manager.step+1):.3f}]") print(f"[avg-loss: {manager.mean_metric('loss')}]") print(f"[avg-max-prob: {manager.mean_metric('mean_max_prob'):.3f}]") print(f"[avg-a-reward: {manager.mean_metric('argmax_reward'):.3f}]") print(f"[avg-s-reward: {manager.mean_metric('sample_reward'):.3f}]") print(f"[avg-len: {manager.mean_metric('argmax_len'):.1f}]") print() print(f"[a-reward: {manager.series['argmax_reward'][-1]:.3f}]") print(f"[s-reward: {manager.series['sample_reward'][-1]:.3f}]") print(f"[max-prob: {manager.series['mean_max_prob'][-1]:.3f}]") print() print("[sentences]") print("\n".join(batch["document"])) print("\n[current policy summaries]") print("\n".join(argmax_summaries)) print("\n[sampled summaries]") print("\n".join(sample_summaries)) print() print("Reward Breakdown:") pprint(argmax_details) print() check_gradient(model) print("="*100) def setup_model(args): # setup/load model manager object model_dir = Path(args.model_dir) if args.fresh and model_dir.exists(): utils.ask_rmdir(model_dir) manager = TrainingManager(model_dir) if not manager.is_empty(): manager.load() if not (model_dir / "config.json").exists(): shutil.copy(args.config, model_dir / "config.json") # initialize new or load existing model if manager.step == 0: encoder = AutoModel.from_pretrained(args.encoder_model_id) embedding_size = encoder.state_dict()["embeddings.word_embeddings.weight"].shape[1] model = LinearTokenSelector(encoder, embedding_size).to(args.device) else: print("loading latest model from step", manager.step - 1) model = load_model( model_dir, prefix="latest", device=args.device ) return manager, model def setup_dataset_indices(args, step): """ Load pre-built indices that determine in which order we traverse a dataset. If we continue interrupted training state, we move indices accordingly. 
""" dataset_indices = utils.batchify( utils.load_numpy(args.indices), args.batch_size ) if step > 0: utils.move_generator(dataset_indices, step) return dataset_indices def train( args, manager, model, tokenizer, reward_generator, dataset, dataset_indices, eval_func ): optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) n_train = len(dataset["train"]) device = args.device model.train() manager.start_clock() for indices in dataset_indices: step = manager.step manager.total_seconds = time.time() - manager.start_time if args.max_train_steps and step >= args.max_train_steps + 1: break if args.max_train_seconds and manager.total_seconds >= args.max_train_seconds: break optimizer.zero_grad() batch = dataset["train"][indices] input_ids = pad_sequence( [torch.tensor(ids) for ids in batch["input_ids"]], batch_first=True ).to(device) logits = model(input_ids) probs = torch.softmax(logits, dim=2) argmax_labels = torch.argmax(logits, dim=2).to(device) argmax_summaries = labels_to_summary(input_ids, argmax_labels, tokenizer) argmax_rewards, argmax_details = reward_generator(batch["document"], argmax_summaries) a_reward = np.mean(argmax_rewards) (sample_probs, sample_summaries, sample_rewards, sample_details, sample_labels, sample_data) = sampling.best_of_k_samples( args, manager, tokenizer, reward_generator, input_ids, batch, probs, k_samples=args.k_samples, ) s_reward = np.mean(sample_rewards) if args.sample_aggregation == "max": loss = (a_reward - s_reward) * sample_probs.sum(1).mean() else: loss = 0. for sample_probs_i, s_rewards_i in zip(sample_data["probs"], sample_data["rewards"]): s_reward_i = np.mean(s_rewards_i) loss_i = (a_reward_i - s_reward_i) * sample_probs_i.sum(1).mean() loss += loss_i loss /= len(sample_data["rewards"]) if args.sample_aggregation == "mean" or a_reward != s_reward: # not updating model if no reward difference, in case of single sample loss.backward() optimizer.step() argmax_len = np.mean([len(word_tokenize(s)) for s in argmax_summaries]) manager.update_metric("time", time.time()) manager.update_metric("loss", loss.item()) manager.update_metric("argmax_reward", a_reward) manager.update_metric("sample_reward", s_reward) manager.update_metric("sample_prob", sample_probs.detach().cpu().numpy().mean()) manager.update_metric("mean_max_prob", get_mean_max_prob(probs)) manager.update_metric("label_variance", label_variance(probs)) manager.update_metric("argmax_len", argmax_len) for rname, rvalues in argmax_details.items(): manager.update_metric(f"reward|{rname}", np.mean(rvalues)) if args.eval_every != None and (step > 0 and step % args.eval_every == 0): eval_func( args, manager, model, tokenizer, reward_generator, dataset["validation"] ) model.train() if args.save_every != None and (step % args.save_every == 0): manager.save_latest_model(model, step) manager.save_data() if args.print_every != None and (args.verbose and step % args.print_every == 0): print_training_progress( args, manager, model, probs, argmax_summaries, sample_summaries, batch, argmax_details ) manager.step += 1 def setup_and_train(args, eval_func): print_if("loading model", args.verbose) manager, model = setup_model(args) print_if("loading tokenizer", args.verbose) tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_id) print_if("loading rewards", args.verbose) reward_generator = load_rewards(args) print_if("rewards:", reward_generator.reward_names) print_if("loading dataset", args.verbose) dataset = load_data_for_training(tokenizer, args.loader, args.dataset) dataset_indices = 
    train(
        args, manager, model, tokenizer, reward_generator,
        dataset, dataset_indices, eval_func
    )
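
# Usage sketch (hypothetical, not part of this module): the real entry point is
# expected to build an args namespace via scrl.config.load_config / argparse and
# to supply a validation callback matching the call made inside train().
# `my_eval` and the way `args` is constructed are assumptions for illustration.
#
#   def my_eval(args, manager, model, tokenizer, reward_generator, val_dataset):
#       ...  # compute and store validation scores, e.g. in dir/val_scores.json
#
#   args = load_config(...)  # hypothetical call; exact signature not shown here
#   setup_and_train(args, eval_func=my_eval)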