import torch
from collections import defaultdict
from torch.distributions import Categorical
from scrl.model import labels_to_summary


def sample_from_policy(
        input_ids,
        probs,
        device="cuda",
        force_diff=True,
        diff_trials=1000,
    ):
    """Sample per-token labels from the policy distribution.

    probs has shape (batch, seq_len, 2); the returned log-probs and labels
    both have shape (batch, seq_len). If force_diff is set, the whole batch
    is resampled (up to diff_trials times) until the sample differs from
    the greedy argmax labels in at least one position, so the RL update
    sees a non-greedy trajectory. input_ids and device are accepted for
    interface compatibility but are not used here.
    """
    m = Categorical(probs)
    argmax_labels = torch.argmax(probs, dim=2)
    sample_labels = m.sample()

    if force_diff:
        # Resample until the sample is not identical to the greedy labels,
        # giving up after diff_trials attempts.
        for _ in range(diff_trials):
            if (argmax_labels == sample_labels).all():
                sample_labels = m.sample()
            else:
                break

    sample_probs = m.log_prob(sample_labels)
    return sample_probs, sample_labels


def best_of_k_samples(
        args,
        manager,
        tokenizer,
        reward_generator,
        input_ids,
        batch,
        probs,
        k_samples=50,
        return_all=False
    ):
    """Draw k_samples label sequences per document and keep the best one.

    For each item in the batch, the sample with the highest reward is
    selected; its log-probs, summary, reward, reward details and labels
    are returned, together with sample_data holding all k samples. Note
    that return_all is currently unused: sample_data is always returned.
    """
    batch_size = probs.size(0)
    prob_batches = []
    summary_batches = []
    reward_batches = []
    detail_batches = []
    label_batches = []
    # Draw k independent samples for the whole batch and score each one.
    for _ in range(k_samples):
        sample_probs, sample_labels = sample_from_policy(
            input_ids,
            probs,
            device=args.device
        )
        sample_summaries = labels_to_summary(
            input_ids, sample_labels, tokenizer
        )
        sample_rewards, sample_details = reward_generator(
            batch["document"], sample_summaries
        )

        prob_batches.append(sample_probs)
        summary_batches.append(sample_summaries)
        reward_batches.append(sample_rewards)
        detail_batches.append(sample_details)
        label_batches.append(sample_labels)

    # For each batch item, find the index of its highest-reward sample.
    best_indices = []
    for i in range(batch_size):
        rewards = [reward_batches[j][i] for j in range(k_samples)]
        best_idx = max(enumerate(rewards), key=lambda x: x[1])[0]
        best_indices.append(best_idx)

    # Gather the winning sample for each batch item
    # (i = batch index, j = index of its best sample).
    sample_probs = torch.stack([prob_batches[j][i] for i, j in enumerate(best_indices)])
    sample_summaries = [summary_batches[j][i] for i, j in enumerate(best_indices)]
    sample_rewards = [reward_batches[j][i] for i, j in enumerate(best_indices)]
    sample_labels = torch.stack([label_batches[j][i] for i, j in enumerate(best_indices)])

    # Regroup the reward details of the winning samples per batch item.
    detail_keys = sorted(detail_batches[0].keys())
    sample_details = []
    for i, j in enumerate(best_indices):
        details = defaultdict(list)
        for k in detail_keys:
            details[k].append(detail_batches[j][k][i])
        sample_details.append(details)

    # sample_data retains all k samples (not just the winners), e.g. for
    # logging or analysis.
    sample_data = {
        "probs": prob_batches,
        "rewards": reward_batches,
        "summaries": summary_batches,
        "details": detail_batches,
        "labels": label_batches,
    }
    return sample_probs, sample_summaries, sample_rewards, sample_details, sample_labels, sample_data
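

if __name__ == "__main__":
    # Minimal smoke test for sample_from_policy, runnable without the rest
    # of the training pipeline. The shapes and the binary label space below
    # are assumptions for illustration; in training, probs comes from the
    # policy model. best_of_k_samples is not exercised here because it
    # needs a tokenizer, a reward generator and a data batch.
    batch_size, seq_len = 2, 8
    logits = torch.randn(batch_size, seq_len, 2)
    probs = torch.softmax(logits, dim=2)
    input_ids = torch.randint(0, 1000, (batch_size, seq_len))

    log_probs, labels = sample_from_policy(input_ids, probs, device="cpu")
    assert labels.shape == (batch_size, seq_len)
    assert log_probs.shape == (batch_size, seq_len)
    print("sampled labels:", labels.tolist())
    print("mean log-prob:", log_probs.mean().item())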