import argparse
import shutil
import logging
import random
import time
from pprint import pprint
from collections import defaultdict
from pathlib import Path

from scrl.rewards import load_rewards
from scrl.data import load_data_for_training
from scrl.config import load_config
from scrl.model import load_model, LinearTokenSelector, labels_to_summary
import scrl.utils as utils
import scrl.sampling as sampling

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, AutoTokenizer
from sklearn import preprocessing
from nltk import word_tokenize


def print_if(x, do_print=True):
    if do_print:
        print(x)


class TrainingManager:
    """
    Object for saving/loading model checkpoints and for tracking and saving
    metrics measured during training, e.g. loss, rewards.

    The following directory structure is built around one training run:

    dir/
        val_scores.json
        checkpoints/
            latest-model-500/
                classifier.bin
                encoder.bin
            best-model-200/
                [...]
        series/
            loss.npy
            [...]
        totals/
            loss.npy
            [...]
    """
    def __init__(self, dir):
        self.step = 0
        self.total_seconds = 0
        self.start_time = None
        self.series = defaultdict(list)
        self.totals = defaultdict(float)
        self.dir = dir
        dir.mkdir(exist_ok=True)
        for subdir_name in ("checkpoints", "series", "totals"):
            (dir / subdir_name).mkdir(exist_ok=True)

    def start_clock(self):
        self.start_time = time.time() - self.total_seconds

    def load(self):
        # load tracked data, e.g. loss, rewards etc.
        for p in (self.dir / "series").iterdir():
            k = p.name.split(".npy")[0]
            self.series[k] = list(utils.load_numpy(p))
        for p in (self.dir / "totals").iterdir():
            k = p.name.split(".npy")[0]
            self.totals[k] = utils.load_numpy(p)
        # read latest training step
        latest_model_dir = self.find_old_model("latest-model")
        self.total_seconds = utils.read_json(self.dir / "time.json")["total_seconds"]
        last_step = int(latest_model_dir.name.split("-")[-1])
        self.step = last_step + 1

    def update_metric(self, key, value):
        self.totals[key] += value
        self.series[key].append(value)

    def mean_metric(self, key):
        return self.totals[key] / (self.step + 1)

    def save_latest_model(self, model, checkpoint_id):
        self.save_model(model, checkpoint_id, prefix="latest-model")

    def save_model(self, model, checkpoint_id, prefix):
        old_model_dir = self.find_old_model(prefix)
        model_dir = self.dir / "checkpoints" / f"{prefix}-{checkpoint_id}"
        model_dir.mkdir()
        model.save(
            classifier_path=model_dir / "classifier.bin",
            encoder_path=model_dir / "encoder.bin"
        )
        if old_model_dir:
            shutil.rmtree(old_model_dir)

    def find_old_model(self, prefix):
        model_path = None
        for p in (self.dir / "checkpoints").iterdir():
            if p.name.startswith(f"{prefix}"):
                model_path = p
        return model_path

    def is_empty(self):
        latest_model_dir = self.find_old_model("latest-model")
        return latest_model_dir is None

    def save_data(self):
        for k, v in self.series.items():
            utils.save_numpy(v, self.dir / "series" / f"{k}.npy")
        for k, v in self.totals.items():
            utils.save_numpy(v, self.dir / "totals" / f"{k}.npy")
        utils.write_json({
            "step": self.step,
            "total_seconds": self.total_seconds
        }, self.dir / "time.json")
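

# Illustrative usage of TrainingManager (a minimal sketch; the run directory
# name and metric values below are hypothetical, not part of this module):
#
#   manager = TrainingManager(Path("runs/example-run"))
#   if not manager.is_empty():
#       manager.load()                      # resume step count, metrics, wall-clock time
#   manager.start_clock()
#   manager.update_metric("loss", 0.123)    # tracked both as a series and as a running total
#   manager.save_latest_model(model, checkpoint_id=manager.step)
#   manager.save_data()                     # persists series/, totals/ and time.json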


def label_variance(probs):
    # probs: (batch, seq, 2); measure how spread out the keep/delete
    # probabilities are across token positions, averaged over the batch
    variances = []
    for i in range(probs.size(0)):
        distrib = probs[i, :, 0]
        var = torch.var(distrib)
        variances.append(var)
    return torch.stack(variances).mean().item()


def check_gradient(model):
    # report parameters whose gradients are zero or missing after backward()
    is_zero = []
    is_none = []
    for name, param in list(model.named_parameters()):
        if param.requires_grad:
            grad = param.grad
            if grad is None:
                is_none.append(name)
            else:
                gradsum = param.grad.sum().item()
                if gradsum == 0:
                    is_zero.append(name)
    print("zero-grad:", len(is_zero), is_zero)
    print("none-grad:", len(is_none), is_none)
    print()


def get_mean_max_prob(probs):
    return probs.max(dim=2).values.mean().item()


def print_training_progress(args, manager, model, probs, argmax_summaries, sample_summaries, batch, argmax_details):
    print(f"[step: {manager.step}] [duration(s): {round(manager.total_seconds)}]")
    print(f"[example/s: {(args.batch_size * (manager.step + 1)) / manager.total_seconds:.3f}]")
    print(f"[s/step: {manager.total_seconds / (manager.step + 1):.3f}]")
    print(f"[avg-loss: {manager.mean_metric('loss')}]")
    print(f"[avg-max-prob: {manager.mean_metric('mean_max_prob'):.3f}]")
    print(f"[avg-a-reward: {manager.mean_metric('argmax_reward'):.3f}]")
    print(f"[avg-s-reward: {manager.mean_metric('sample_reward'):.3f}]")
    print(f"[avg-len: {manager.mean_metric('argmax_len'):.1f}]")
    print()
    print(f"[a-reward: {manager.series['argmax_reward'][-1]:.3f}]")
    print(f"[s-reward: {manager.series['sample_reward'][-1]:.3f}]")
    print(f"[max-prob: {manager.series['mean_max_prob'][-1]:.3f}]")
    print()
    print("[sentences]")
    print("\n".join(batch["document"]))
    print("\n[current policy summaries]")
    print("\n".join(argmax_summaries))
    print("\n[sampled summaries]")
    print("\n".join(sample_summaries))
    print()
    print("Reward Breakdown:")
    pprint(argmax_details)
    print()
    check_gradient(model)
    print("=" * 100)


def setup_model(args):
    # setup/load model manager object
    model_dir = Path(args.model_dir)
    if args.fresh and model_dir.exists():
        utils.ask_rmdir(model_dir)
    manager = TrainingManager(model_dir)
    if not manager.is_empty():
        manager.load()
    if not (model_dir / "config.json").exists():
        shutil.copy(args.config, model_dir / "config.json")
    # initialize new or load existing model
    if manager.step == 0:
        encoder = AutoModel.from_pretrained(args.encoder_model_id)
        embedding_size = encoder.state_dict()["embeddings.word_embeddings.weight"].shape[1]
        model = LinearTokenSelector(encoder, embedding_size).to(args.device)
    else:
        print("loading latest model from step", manager.step - 1)
        model = load_model(
            model_dir, prefix="latest", device=args.device
        )
    return manager, model
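

# A quick way to sanity-check the embedding-size lookup in setup_model
# (illustrative; assumes an encoder whose state dict exposes
# "embeddings.word_embeddings.weight", e.g. a RoBERTa/DistilRoBERTa-style model):
#
#   enc = AutoModel.from_pretrained("distilroberta-base")
#   enc.state_dict()["embeddings.word_embeddings.weight"].shape  # -> (vocab_size, hidden_size)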


def setup_dataset_indices(args, step):
    """
    Load pre-built indices that determine in which order we traverse a dataset.
    If we resume an interrupted training run, we advance the indices to the
    last completed step.
    """
    dataset_indices = utils.batchify(
        utils.load_numpy(args.indices),
        args.batch_size
    )
    if step > 0:
        utils.move_generator(dataset_indices, step)
    return dataset_indices
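

# Illustrative sketch of how an index file for args.indices could be prepared
# ahead of training (an assumed workflow; the file name and seed are hypothetical):
#
#   rng = np.random.default_rng(seed=1)
#   shuffled = rng.permutation(len(dataset["train"]))
#   utils.save_numpy(shuffled, "data/train-indices.npy")
#
# Storing the shuffled order on disk keeps the traversal reproducible, so an
# interrupted run can be resumed at the exact batch where it stopped.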


def train(
    args,
    manager,
    model,
    tokenizer,
    reward_generator,
    dataset,
    dataset_indices,
    eval_func
):
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    n_train = len(dataset["train"])
    device = args.device
    model.train()
    manager.start_clock()
    for indices in dataset_indices:
        step = manager.step
        manager.total_seconds = time.time() - manager.start_time
        if args.max_train_steps and step >= args.max_train_steps + 1:
            break
        if args.max_train_seconds and manager.total_seconds >= args.max_train_seconds:
            break

        optimizer.zero_grad()
        batch = dataset["train"][indices]
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in batch["input_ids"]],
            batch_first=True
        ).to(device)

        logits = model(input_ids)
        probs = torch.softmax(logits, dim=2)

        # greedy (argmax) summary of the current policy and its reward
        argmax_labels = torch.argmax(logits, dim=2).to(device)
        argmax_summaries = labels_to_summary(input_ids, argmax_labels, tokenizer)
        argmax_rewards, argmax_details = reward_generator(batch["document"], argmax_summaries)
        a_reward = np.mean(argmax_rewards)

        # sampled summaries (best of k) and their rewards
        (sample_probs, sample_summaries, sample_rewards, sample_details,
         sample_labels, sample_data) = sampling.best_of_k_samples(
            args, manager, tokenizer, reward_generator,
            input_ids, batch, probs,
            k_samples=args.k_samples,
        )
        s_reward = np.mean(sample_rewards)
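
        # Self-critical-style objective, as this loop appears to implement it:
        # the greedy (argmax) reward acts as a baseline for the sampled summary,
        # so minimizing
        #   (a_reward - s_reward) * sample_probs.sum(1).mean()
        # shifts probability mass toward samples that score above the greedy
        # baseline and away from samples that score below it.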
        if args.sample_aggregation == "max":
            loss = (a_reward - s_reward) * sample_probs.sum(1).mean()
        else:
            # average the loss over all k samples instead of using only the best one
            loss = 0.
            for sample_probs_i, s_rewards_i in zip(sample_data["probs"], sample_data["rewards"]):
                s_reward_i = np.mean(s_rewards_i)
                loss_i = (a_reward - s_reward_i) * sample_probs_i.sum(1).mean()
                loss += loss_i
            loss /= len(sample_data["rewards"])

        if args.sample_aggregation == "mean" or a_reward != s_reward:
            # with a single sample, skip the update if there is no reward
            # difference between the greedy and the sampled summary
            loss.backward()
            optimizer.step()

        argmax_len = np.mean([len(word_tokenize(s)) for s in argmax_summaries])
        manager.update_metric("time", time.time())
        manager.update_metric("loss", loss.item())
        manager.update_metric("argmax_reward", a_reward)
        manager.update_metric("sample_reward", s_reward)
        manager.update_metric("sample_prob", sample_probs.detach().cpu().numpy().mean())
        manager.update_metric("mean_max_prob", get_mean_max_prob(probs))
        manager.update_metric("label_variance", label_variance(probs))
        manager.update_metric("argmax_len", argmax_len)
        for rname, rvalues in argmax_details.items():
            manager.update_metric(f"reward|{rname}", np.mean(rvalues))

        if args.eval_every is not None and (step > 0 and step % args.eval_every == 0):
            eval_func(
                args, manager, model, tokenizer, reward_generator,
                dataset["validation"]
            )
            model.train()

        if args.save_every is not None and (step % args.save_every == 0):
            manager.save_latest_model(model, step)
            manager.save_data()

        if args.print_every is not None and (args.verbose and step % args.print_every == 0):
            print_training_progress(
                args, manager, model, probs,
                argmax_summaries, sample_summaries, batch,
                argmax_details
            )

        manager.step += 1


def setup_and_train(args, eval_func):
    print_if("loading model", args.verbose)
    manager, model = setup_model(args)

    print_if("loading tokenizer", args.verbose)
    tokenizer = AutoTokenizer.from_pretrained(args.encoder_model_id)

    print_if("loading rewards", args.verbose)
    reward_generator = load_rewards(args)
    print_if(f"rewards: {reward_generator.reward_names}", args.verbose)

    print_if("loading dataset", args.verbose)
    dataset = load_data_for_training(tokenizer, args.loader, args.dataset)
    dataset_indices = setup_dataset_indices(args, manager.step)

    train(
        args,
        manager,
        model,
        tokenizer,
        reward_generator,
        dataset,
        dataset_indices,
        eval_func
    )
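

# Illustrative entry point (a minimal sketch; the config path and the eval
# callback are hypothetical and stand in for the project's own training script,
# and load_config is assumed to accept a path to a JSON config):
#
#   if __name__ == "__main__":
#       args = load_config("config/example.json")
#       setup_and_train(args, eval_func=my_validation_func)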