File size: 6,357 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# distutils: language=c++
# cython:language_level=3
from libcpp.vector cimport vector

cdef class MinMaxStatsList:
    """Python-visible owner of a heap-allocated ``CMinMaxStatsList``."""
    # Raw pointer to the C++ object; created in __cinit__, deleted in __dealloc__.
    cdef CMinMaxStatsList *cmin_max_stats_lst

    def __cinit__(self, int num):
        # ``num`` is passed straight through to the C++ constructor.
        self.cmin_max_stats_lst = new CMinMaxStatsList(num)

    def __dealloc__(self):
        # Release the object allocated in __cinit__.
        del self.cmin_max_stats_lst

    def set_delta(self, float value_delta_max):
        """Forward ``value_delta_max`` to ``CMinMaxStatsList.set_delta``."""
        self.cmin_max_stats_lst.set_delta(value_delta_max)


cdef class ResultsWrapper:
    """Python-visible holder of a ``CSearchResults`` object."""
    # Stored by value rather than through a pointer.
    cdef CSearchResults cresults

    def __cinit__(self, int num):
        # Assign a ``CSearchResults`` constructed from ``num``.
        self.cresults = CSearchResults(num)

    def get_search_len(self):
        """Return the ``search_lens`` member of the wrapped ``CSearchResults``."""
        return self.cresults.search_lens


cdef class Roots:
    """Python-visible owner of a heap-allocated ``CRoots`` object."""
    # Number of roots as given to the constructor; exposed through ``num``.
    cdef int root_num
    # Raw pointer to the C++ object; created in __cinit__, deleted in __dealloc__.
    cdef CRoots *roots

    def __cinit__(self, int root_num, vector[vector[int]] legal_actions_list):
        # Both arguments are passed straight through to the C++ constructor.
        self.root_num = root_num
        self.roots = new CRoots(root_num, legal_actions_list)

    def __dealloc__(self):
        # Release the object allocated in __cinit__.
        del self.roots

    @property
    def num(self):
        """Return the ``root_num`` value stored at construction."""
        return self.root_num

    def prepare(self, float root_noise_weight, list noises, list value_prefix_pool, list value_pool, list policy_logits_pool, vector[int] &to_play_batch):
        """Forward all arguments, unchanged and in order, to ``CRoots.prepare``."""
        self.roots.prepare(
            root_noise_weight,
            noises,
            value_prefix_pool,
            value_pool,
            policy_logits_pool,
            to_play_batch
        )

    def prepare_no_noise(self, list value_prefix_pool, list value_pool, list policy_logits_pool, vector[int] &to_play_batch):
        """Forward all arguments, unchanged and in order, to ``CRoots.prepare_no_noise``."""
        self.roots.prepare_no_noise(
            value_prefix_pool,
            value_pool,
            policy_logits_pool,
            to_play_batch
        )

    def get_trajectories(self):
        """Return the result of ``CRoots.get_trajectories``."""
        return self.roots.get_trajectories()

    def get_distributions(self):
        """Return the result of ``CRoots.get_distributions``."""
        return self.roots.get_distributions()

    def get_children_values(self, float discount, int action_space_size):
        """Return the result of ``CRoots.get_children_values`` for the given arguments."""
        return self.roots.get_children_values(discount, action_space_size)

    def get_policies(self, float discount, int action_space_size):
        """Return the result of ``CRoots.get_policies`` for the given arguments."""
        return self.roots.get_policies(discount, action_space_size)

    def get_values(self):
        """Return the result of ``CRoots.get_values``."""
        return self.roots.get_values()

    def clear(self):
        """Call ``CRoots.clear`` on the wrapped object."""
        self.roots.clear()


cdef class Node:
    """Python-visible holder of a single ``CNode`` object."""
    # Stored by value rather than through a pointer.
    cdef CNode cnode

    # NOTE(review): ``__cinit__`` is defined twice in this class with different
    # signatures. Only one definition can be in effect (presumably the later
    # one) -- confirm which constructor signature is intended and remove the other.
    def __cinit__(self):
        pass

    # Accepts ``prior`` and ``legal_actions`` but the body does nothing with them.
    def __cinit__(self, float prior, vector[int] &legal_actions):
        pass

    def expand(self, int to_play, int current_latent_state_index, int batch_index, float value_prefix, float value, list policy_logits):
        """Forward the arguments to ``CNode.expand`` on the wrapped node.

        ``policy_logits`` is first copied into a ``vector[float]``; the other
        arguments are passed through unchanged.
        """
        cdef vector[float] cpolicy = policy_logits
        self.cnode.expand(to_play, current_latent_state_index, batch_index, value_prefix, value, cpolicy)        

def batch_back_propagate(int current_latent_state_index, float discount, list value_prefixs, list values, list policies, MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list to_play_batch):
    """Convert the Python inputs and call the C++ ``cbatch_back_propagate``.

    ``value_prefixs`` and ``values`` are copied into ``vector[float]`` and
    ``policies`` into ``vector[vector[float]]`` before the call. The C++
    routine receives the pointer held by ``min_max_stats_lst`` and the
    ``CSearchResults`` held by ``results``; ``to_play_batch`` is passed as
    given. Nothing is returned.
    """
    # Explicit typed locals make the list -> C++ vector conversions happen here,
    # before the C++ call.
    cdef vector[float] cvalue_prefixs = value_prefixs
    cdef vector[float] cvalues = values
    cdef vector[vector[float]] cpolicies = policies

    cbatch_back_propagate(current_latent_state_index, discount, cvalue_prefixs, cvalues, cpolicies,
                          min_max_stats_lst.cmin_max_stats_lst, results.cresults, to_play_batch)


def batch_traverse(Roots roots, int num_simulations, int max_num_considered_actions, float discount, ResultsWrapper results, list virtual_to_play_batch):
    """Call the C++ ``cbatch_traverse`` and return four members of ``results``.

    The C++ routine receives the ``CRoots`` pointer held by ``roots`` and the
    ``CSearchResults`` held by ``results``. The return value is the tuple
    ``(latent_state_index_in_search_path, latent_state_index_in_batch,
    last_actions, virtual_to_play_batchs)`` read from ``results.cresults``
    after the call.
    """
    cbatch_traverse(
        roots.roots,
        num_simulations,
        max_num_considered_actions,
        discount,
        results.cresults,
        virtual_to_play_batch
    )

    return (
        results.cresults.latent_state_index_in_search_path,
        results.cresults.latent_state_index_in_batch,
        results.cresults.last_actions,
        results.cresults.virtual_to_play_batchs
    )

def select_root_child(Node roots, float discount, int num_simulations, int max_num_considered_actions):
    """Return the result of ``cselect_root_child`` for the node wrapped by ``roots``."""
    # The C++ routine takes the address of the wrapped CNode.
    cdef CNode *node_ptr = &roots.cnode
    return cselect_root_child(node_ptr, discount, num_simulations, max_num_considered_actions)

def select_interior_child(Node roots, float discount):
    """Return the result of ``cselect_interior_child`` for the node wrapped by ``roots``."""
    # The C++ routine takes the address of the wrapped CNode.
    cdef CNode *node_ptr = &roots.cnode
    return cselect_interior_child(node_ptr, discount)

def softmax(list py_num_list):
    """Run the C++ ``csoftmax`` over ``py_num_list`` and write the values back.

    The list is copied into a ``vector[float]``, passed to ``csoftmax``
    together with its length, and the vector's elements are then stored back
    into ``py_num_list`` position by position. The same list object is
    returned, so the caller's list is modified in place.
    """
    cdef Py_ssize_t idx
    cdef vector[float] cnum_list = py_num_list
    cdef int clength = len(py_num_list)
    csoftmax(cnum_list, clength)
    # Copy the vector contents (presumably updated in place by csoftmax) back.
    for idx in range(clength):
        py_num_list[idx] = cnum_list[idx]
    return py_num_list

def pcompute_mixed_value(float raw_value, list py_q_values, list py_child_visit, list py_child_prior):
    """Return the result of the C++ ``compute_mixed_value``.

    ``py_q_values`` and ``py_child_prior`` are copied into ``vector[float]``
    and ``py_child_visit`` into ``vector[int]`` before the call.
    """
    cdef vector[float] q_vec = py_q_values
    cdef vector[int] visit_vec = py_child_visit
    cdef vector[float] prior_vec = py_child_prior
    return compute_mixed_value(
        raw_value,
        q_vec,
        visit_vec,
        prior_vec
    )

def prescale_qvalues(list py_value, float epsilon):
    """Run the C++ ``rescale_qvalues`` over ``py_value`` and write the values back.

    The list is copied into a ``vector[float]``, passed to ``rescale_qvalues``
    with ``epsilon``, and the vector's elements are then stored back into
    ``py_value`` position by position. The same list object is returned, so
    the caller's list is modified in place.
    """
    cdef Py_ssize_t idx
    cdef vector[float] value_vec = py_value
    rescale_qvalues(value_vec, epsilon)
    # Copy the vector contents (presumably updated in place by rescale_qvalues) back.
    for idx in range(len(py_value)):
        py_value[idx] = value_vec[idx]
    return py_value

def pqtransform_completed_by_mix_value(Node roots, list py_child_visit, list py_child_prior, float discount, int maxvisit_init, float value_scale, bool rescale_values, float epsilon):
    """Call the C++ ``qtransform_completed_by_mix_value`` and return a Python list.

    ``py_child_visit`` is copied into ``vector[int]`` and ``py_child_prior``
    into ``vector[float]``; the C++ routine also receives the address of the
    node wrapped by ``roots``. The returned list holds the first
    ``len(py_child_visit)`` elements of the resulting vector.
    """
    cdef vector[int] visit_vec = py_child_visit
    cdef vector[float] prior_vec = py_child_prior
    cdef vector[float] mix_vec = qtransform_completed_by_mix_value(
        &roots.cnode,
        visit_vec,
        prior_vec,
        discount,
        maxvisit_init,
        value_scale,
        rescale_values,
        epsilon
    )
    return [mix_vec[k] for k in range(len(py_child_visit))]

def pget_sequence_of_considered_visits(int max_num_considered_actions, int num_simulations):
    """Return the result of the C++ ``get_sequence_of_considered_visits``."""
    return get_sequence_of_considered_visits(
        max_num_considered_actions,
        num_simulations
    )

def pget_table_of_considered_visits(int max_num_considered_actions, int num_simulations):
    """Call the C++ ``get_table_of_considered_visits`` and return a Python list.

    The returned list holds rows ``0`` through ``max_num_considered_actions``
    (inclusive) of the resulting ``vector[vector[int]]``.
    """
    cdef vector[vector[int]] table = get_table_of_considered_visits(
        max_num_considered_actions,
        num_simulations
    )
    return [table[row] for row in range(max_num_considered_actions + 1)]

def pscore_considered(int considered_visit, list py_gumbel, list py_logits, list py_normalized_qvalues, list py_visit_counts):
    """Return the result of the C++ ``score_considered``.

    ``py_gumbel``, ``py_logits`` and ``py_normalized_qvalues`` are copied into
    ``vector[float]`` and ``py_visit_counts`` into ``vector[int]`` before the
    call.
    """
    cdef vector[float] gumbel_vec = py_gumbel
    cdef vector[float] logits_vec = py_logits
    cdef vector[float] qvalues_vec = py_normalized_qvalues
    cdef vector[int] visits_vec = py_visit_counts
    return score_considered(
        considered_visit,
        gumbel_vec,
        logits_vec,
        qvalues_vec,
        visits_vec
    )

def pgenerate_gumbel(float gumbel_scale, float gumbel_rng, int shape):
    """Return the result of the C++ ``generate_gumbel`` for the given arguments."""
    return generate_gumbel(
        gumbel_scale,
        gumbel_rng,
        shape
    )