import torch
import random
import numpy as np
from collections import defaultdict
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from scrl.model import labels_to_summary
from nltk import word_tokenize
from pprint import pprint


def sample_from_policy(
    input_ids,
    probs,
    device="cuda",
    force_diff=True,
    diff_trials=1000,
):
    # Sample binary keep/delete labels for each token from the policy
    # distribution `probs` (batch x seq_len x 2) and return their log-probs.
    # Note: `device` is currently unused.
    m = Categorical(probs)
    argmax_labels = torch.argmax(probs, dim=2)
    sample_labels = m.sample()
    if force_diff:
        # Re-sample up to `diff_trials` times so that the sampled labels
        # differ from the greedy (argmax) labels; if no differing sample
        # is found within the budget, the last sample is used anyway.
        for _ in range(diff_trials):
            if (argmax_labels == sample_labels).all():
                sample_labels = m.sample()
            else:
                break
    sample_probs = m.log_prob(sample_labels)
    return sample_probs, sample_labels


def best_of_k_samples(
    args,
    manager,
    tokenizer,
    reward_generator,
    input_ids,
    batch,
    probs,
    k_samples=50,
    return_all=False
):
    # Draw `k_samples` label sequences from the policy, turn each into a
    # summary, score the summaries with the reward generator, and keep the
    # highest-reward sample for every item in the batch.
    batch_size = probs.size(0)
    prob_batches = []
    summary_batches = []
    reward_batches = []
    detail_batches = []
    label_batches = []
    for _ in range(k_samples):
        sample_probs, sample_labels = sample_from_policy(
            input_ids,
            probs,
            device=args.device
        )
        sample_summaries = labels_to_summary(
            input_ids, sample_labels, tokenizer
        )
        sample_rewards, sample_details = reward_generator(
            batch["document"], sample_summaries
        )
        prob_batches.append(sample_probs)
        summary_batches.append(sample_summaries)
        reward_batches.append(sample_rewards)
        detail_batches.append(sample_details)
        label_batches.append(sample_labels)

    # For each batch item, find the index of the sample with the highest reward.
    best_indices = []
    for i in range(batch_size):
        rewards = [reward_batches[j][i] for j in range(k_samples)]
        scored = sorted(enumerate(rewards), key=lambda x: x[1], reverse=True)
        best_idx = scored[0][0]
        best_indices.append(best_idx)

    # Gather the best sample per batch item.
    sample_probs = torch.stack([prob_batches[j][i] for i, j in enumerate(best_indices)])
    sample_summaries = [summary_batches[j][i] for i, j in enumerate(best_indices)]
    sample_rewards = [reward_batches[j][i] for i, j in enumerate(best_indices)]
    sample_labels = torch.stack([label_batches[j][i] for i, j in enumerate(best_indices)])
    sample_details = []
    for i, j in enumerate(best_indices):
        detail_keys = sorted(detail_batches[0].keys())
        details = defaultdict(list)
        for k in detail_keys:
            details[k].append(detail_batches[j][k][i])
        sample_details.append(details)

    # All drawn samples are also returned, so callers can inspect the full
    # distribution of rewards rather than just the best one.
    sample_data = {
        "probs": prob_batches,
        "rewards": reward_batches,
        "summaries": summary_batches,
        "details": detail_batches,
        "labels": label_batches,
    }
    return sample_probs, sample_summaries, sample_rewards, sample_details, sample_labels, sample_data
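

# --- Minimal usage sketch (not part of the original module; shapes are
# assumptions based on how the functions above index their inputs). ---
# Demonstrates sample_from_policy on a dummy policy distribution of shape
# (batch, seq_len, 2), i.e. keep/delete probabilities per token.
# best_of_k_samples additionally needs the tokenizer, reward generator and
# data batch from the surrounding training setup, so it is not exercised here.
if __name__ == "__main__":
    batch_size, seq_len = 2, 8
    dummy_input_ids = torch.randint(0, 100, (batch_size, seq_len))
    dummy_probs = torch.softmax(torch.rand(batch_size, seq_len, 2), dim=-1)
    log_probs, labels = sample_from_policy(dummy_input_ids, dummy_probs, device="cpu")
    print("sampled labels:", labels)
    print("log-prob shape:", tuple(log_probs.shape))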