JerryLiJinyi's picture
Upload 127 files
10b912d verified
raw
history blame
2.97 kB
import torch
import random
import numpy as np
from collections import defaultdict
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from scrl.model import labels_to_summary
from nltk import word_tokenize
from pprint import pprint
def sample_from_policy(
    input_ids,
    probs,
    device="cuda",
    force_diff=True,
    diff_trials=1000,
):
    """Draw one label per position from a categorical policy.

    Args:
        input_ids: Unused here; kept so existing callers keep working.
        probs: Tensor whose last dimension holds the per-label
            probabilities. Any number of leading dimensions is accepted.
        device: Unused here; kept so existing callers keep working.
        force_diff: When true, redraw while the sample is identical to the
            greedy (argmax) labelling of the whole tensor.
        diff_trials: Upper bound on the number of redraws. If every redraw
            still equals the greedy labelling, the last draw is returned.

    Returns:
        A ``(sample_probs, sample_labels)`` tuple. ``sample_probs`` holds the
        *log*-probabilities of the drawn labels under ``probs``;
        ``sample_labels`` holds the drawn label indices. Both have the shape
        of ``probs`` without its last dimension.
    """
    policy = Categorical(probs)
    sample_labels = policy.sample()

    if force_diff:
        # Categorical treats the last dimension as the label axis, so the
        # greedy labelling must be taken over that same axis.
        greedy_labels = torch.argmax(probs, dim=-1)
        for _ in range(diff_trials):
            if not (greedy_labels == sample_labels).all():
                break
            sample_labels = policy.sample()

    sample_probs = policy.log_prob(sample_labels)
    return sample_probs, sample_labels
def best_of_k_samples(
    args,
    manager,
    tokenizer,
    reward_generator,
    input_ids,
    batch,
    probs,
    k_samples=50,
    return_all=False,
):
    """Sample ``k_samples`` labellings and keep the best-rewarded one per example.

    For each of the ``k_samples`` rounds the whole batch is sampled with
    ``sample_from_policy``, converted to summaries with ``labels_to_summary``
    and scored with ``reward_generator``. Then, independently for every batch
    position, the round with the highest reward is selected (the earliest
    round wins ties).

    Args:
        args: Object providing ``args.device``, forwarded to
            ``sample_from_policy``.
        manager: Unused here; kept so existing callers keep working.
        tokenizer: Forwarded to ``labels_to_summary``.
        reward_generator: Callable invoked as
            ``reward_generator(batch["document"], summaries)`` and expected to
            return ``(rewards, details)``. ``rewards`` is indexed by batch
            position; ``details`` is a mapping whose values are indexed by
            batch position.
        input_ids: Forwarded to ``sample_from_policy`` and
            ``labels_to_summary``.
        batch: Mapping with a ``"document"`` entry.
        probs: Policy probabilities; ``probs.size(0)`` is the batch size.
        k_samples: Number of sampling rounds; must be at least 1.
        return_all: Unused here; kept so existing callers keep working.

    Returns:
        A 6-tuple ``(sample_probs, sample_summaries, sample_rewards,
        sample_details, sample_labels, sample_data)``. The first five hold the
        selected sample for every batch position (``sample_probs`` and
        ``sample_labels`` are stacked tensors, ``sample_details`` is a list of
        ``defaultdict(list)`` with one single-element list per detail key).
        ``sample_data`` maps ``"probs"``, ``"rewards"``, ``"summaries"``,
        ``"details"`` and ``"labels"`` to the per-round results of all rounds.

    Raises:
        ValueError: If ``k_samples`` is smaller than 1.
    """
    # Fail fast: with no rounds there is nothing to pick a best sample from.
    if k_samples < 1:
        raise ValueError(f"k_samples must be at least 1, got {k_samples}")

    batch_size = probs.size(0)

    prob_batches = []
    summary_batches = []
    reward_batches = []
    detail_batches = []
    label_batches = []

    for _ in range(k_samples):
        round_probs, round_labels = sample_from_policy(
            input_ids,
            probs,
            device=args.device,
        )
        round_summaries = labels_to_summary(
            input_ids, round_labels, tokenizer
        )
        round_rewards, round_details = reward_generator(
            batch["document"], round_summaries
        )

        prob_batches.append(round_probs)
        summary_batches.append(round_summaries)
        reward_batches.append(round_rewards)
        detail_batches.append(round_details)
        label_batches.append(round_labels)

    # For every batch position pick the round with the highest reward.
    # max() returns the first maximal item, i.e. the earliest round on ties.
    best_indices = []
    for i in range(batch_size):
        rewards = [round_rewards[i] for round_rewards in reward_batches]
        best_indices.append(max(range(k_samples), key=rewards.__getitem__))

    # (batch position, winning round) pairs, reused for every gathered field.
    chosen = list(enumerate(best_indices))

    sample_probs = torch.stack([prob_batches[j][i] for i, j in chosen])
    sample_summaries = [summary_batches[j][i] for i, j in chosen]
    sample_rewards = [reward_batches[j][i] for i, j in chosen]
    sample_labels = torch.stack([label_batches[j][i] for i, j in chosen])

    # The detail keys are taken from the first round and do not depend on the
    # example, so compute them once.
    detail_keys = sorted(detail_batches[0].keys())
    sample_details = []
    for i, j in chosen:
        details = defaultdict(list)
        for key in detail_keys:
            details[key].append(detail_batches[j][key][i])
        sample_details.append(details)

    sample_data = {
        "probs": prob_batches,
        "rewards": reward_batches,
        "summaries": summary_batches,
        "details": detail_batches,
        "labels": label_batches,
    }
    return (
        sample_probs,
        sample_summaries,
        sample_rewards,
        sample_details,
        sample_labels,
        sample_data,
    )