import torch
import random
import numpy as np
from collections import defaultdict
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from scrl.model import labels_to_summary
from nltk import word_tokenize
from pprint import pprint
def sample_from_policy(
    input_ids,
    probs,
    device="cuda",
    force_diff=True,
    diff_trials=1000,
):
    # Treat the per-token label probabilities as a categorical policy.
    m = Categorical(probs)
    argmax_labels = torch.argmax(probs, dim=2)
    sample_labels = m.sample()
    # Optionally resample (up to diff_trials times) until the sampled labels
    # differ from the greedy argmax labels, so the sample is a real alternative
    # to the argmax baseline.
    if force_diff:
        for _ in range(diff_trials):
            if (argmax_labels == sample_labels).all():
                sample_labels = m.sample()
            else:
                break
    sample_probs = m.log_prob(sample_labels)
    return sample_probs, sample_labels
def best_of_k_samples(
    args,
    manager,
    tokenizer,
    reward_generator,
    input_ids,
    batch,
    probs,
    k_samples=50,
    return_all=False,
):
    batch_size = probs.size(0)
    prob_batches = []
    summary_batches = []
    reward_batches = []
    detail_batches = []
    label_batches = []
    # Draw k label samples from the policy and score each resulting summary.
    for _ in range(k_samples):
        sample_probs, sample_labels = sample_from_policy(
            input_ids,
            probs,
            device=args.device,
        )
        sample_summaries = labels_to_summary(
            input_ids, sample_labels, tokenizer
        )
        sample_rewards, sample_details = reward_generator(
            batch["document"], sample_summaries
        )
        prob_batches.append(sample_probs)
        summary_batches.append(sample_summaries)
        reward_batches.append(sample_rewards)
        detail_batches.append(sample_details)
        label_batches.append(sample_labels)

    # For each batch item, find the index of the sample with the highest reward.
    best_indices = []
    for i in range(batch_size):
        rewards = [reward_batches[j][i] for j in range(k_samples)]
        scored = sorted(enumerate(rewards), key=lambda x: x[1], reverse=True)
        best_idx = scored[0][0]
        best_indices.append(best_idx)

    # Gather the best sample for each batch item.
    sample_probs = torch.stack(
        [prob_batches[j][i] for i, j in enumerate(best_indices)]
    )
    sample_summaries = [summary_batches[j][i] for i, j in enumerate(best_indices)]
    sample_rewards = [reward_batches[j][i] for i, j in enumerate(best_indices)]
    sample_labels = torch.stack(
        [label_batches[j][i] for i, j in enumerate(best_indices)]
    )
    sample_details = []
    for i, j in enumerate(best_indices):
        detail_keys = sorted(detail_batches[0].keys())
        details = defaultdict(list)
        for k in detail_keys:
            details[k].append(detail_batches[j][k][i])
        sample_details.append(details)

    # Also keep all k samples, e.g. for logging or analysis.
    sample_data = {
        "probs": prob_batches,
        "rewards": reward_batches,
        "summaries": summary_batches,
        "details": detail_batches,
        "labels": label_batches,
    }
    return (
        sample_probs,
        sample_summaries,
        sample_rewards,
        sample_details,
        sample_labels,
        sample_data,
    )
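

# --- Illustrative usage (an added sketch, not part of the original file) ---
# A minimal sanity check of sample_from_policy on random inputs. It only needs
# torch (plus this module's imports resolving): dummy token ids and per-token
# keep/drop probabilities of shape (batch, seq_len, 2) stand in for a policy
# model's output. best_of_k_samples additionally requires a tokenizer, a
# reward_generator, and a batch dict, so it is not exercised here.
if __name__ == "__main__":
    torch.manual_seed(0)
    dummy_input_ids = torch.randint(0, 100, (2, 6))           # 2 sequences of 6 tokens
    dummy_probs = torch.softmax(torch.rand(2, 6, 2), dim=2)   # per-token label probabilities
    log_probs, labels = sample_from_policy(
        dummy_input_ids, dummy_probs, device="cpu"
    )
    print(labels)      # sampled 0/1 extraction labels, shape (2, 6)
    print(log_probs)   # log-probabilities of the sampled labels, shape (2, 6)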