# distutils: language=c++
# cython:language_level=3
from libcpp.vector cimport vector

# Python-facing wrappers around C++ tree-search types (CMinMaxStatsList,
# CSearchResults, CRoots, CNode) and C++ free functions. The C++ declarations
# are not in this file -- presumably they come from a matching .pxd; confirm
# exact C++ signatures there.


cdef class MinMaxStatsList:
    """Owns a heap-allocated ``CMinMaxStatsList``.

    The C++ object is created with ``new`` in ``__cinit__`` and released with
    ``del`` in ``__dealloc__``.
    """
    cdef CMinMaxStatsList *cmin_max_stats_lst

    def __cinit__(self, int num):
        """Allocate the C++ object; ``num`` is passed to its constructor."""
        self.cmin_max_stats_lst = new CMinMaxStatsList(num)

    def set_delta(self, float value_delta_max):
        """Forward ``value_delta_max`` to ``CMinMaxStatsList.set_delta``."""
        self.cmin_max_stats_lst[0].set_delta(value_delta_max)

    def __dealloc__(self):
        # Cython zero-initialises C attributes and deleting NULL is a no-op in
        # C++, so this is safe even if __cinit__ failed before allocating.
        del self.cmin_max_stats_lst


cdef class ResultsWrapper:
    """Holds a ``CSearchResults`` by value."""
    cdef CSearchResults cresults

    def __cinit__(self, int num):
        """Build the results object; ``num`` is passed to its constructor."""
        self.cresults = CSearchResults(num)

    def get_search_len(self):
        """Return ``cresults.search_lens`` converted to a Python object."""
        return self.cresults.search_lens


cdef class Roots:
    """Owns a heap-allocated ``CRoots`` built for ``root_num`` roots.

    Each method forwards to the ``CRoots`` method of the same name; Python
    lists are coerced by Cython to the C++ container types that method
    declares.
    """
    cdef int root_num
    cdef CRoots *roots

    def __cinit__(self, int root_num, vector[vector[int]] legal_actions_list):
        """Allocate ``CRoots(root_num, legal_actions_list)``."""
        self.root_num = root_num
        self.roots = new CRoots(root_num, legal_actions_list)

    def prepare(self, float root_noise_weight, list noises,
                list value_prefix_pool, list value_pool,
                list policy_logits_pool, vector[int] &to_play_batch):
        """Forward to ``CRoots.prepare`` (the variant taking root noise)."""
        self.roots[0].prepare(root_noise_weight, noises, value_prefix_pool,
                              value_pool, policy_logits_pool, to_play_batch)

    def prepare_no_noise(self, list value_prefix_pool, list value_pool,
                         list policy_logits_pool, vector[int] &to_play_batch):
        """Forward to ``CRoots.prepare_no_noise``."""
        self.roots[0].prepare_no_noise(value_prefix_pool, value_pool,
                                       policy_logits_pool, to_play_batch)

    def get_trajectories(self):
        """Return ``CRoots.get_trajectories()`` converted to Python."""
        return self.roots[0].get_trajectories()

    def get_distributions(self):
        """Return ``CRoots.get_distributions()`` converted to Python."""
        return self.roots[0].get_distributions()

    def get_children_values(self, float discount, int action_space_size):
        """Return ``CRoots.get_children_values(discount, action_space_size)``."""
        return self.roots[0].get_children_values(discount, action_space_size)

    def get_policies(self, float discount, int action_space_size):
        """Return ``CRoots.get_policies(discount, action_space_size)``."""
        return self.roots[0].get_policies(discount, action_space_size)

    def get_values(self):
        """Return ``CRoots.get_values()`` converted to Python."""
        return self.roots[0].get_values()

    def clear(self):
        """Forward to ``CRoots.clear``."""
        self.roots[0].clear()

    def __dealloc__(self):
        # Safe on NULL (see MinMaxStatsList.__dealloc__).
        del self.roots

    @property
    def num(self):
        """Number of roots this object was constructed with."""
        return self.root_num


cdef class Node:
    """Holds a ``CNode`` by value.

    ``cnode`` is default-constructed by Cython. The constructor arguments
    are type-converted but otherwise unused.
    """
    cdef CNode cnode

    # Only one __cinit__ may exist: a second definition silently replaces the
    # first, so the shadowed zero-argument version has been dropped.
    def __cinit__(self, float prior, vector[int] &legal_actions):
        pass

    def expand(self, int to_play, int current_latent_state_index,
               int batch_index, float value_prefix, float value,
               list policy_logits):
        """Forward to ``CNode.expand``.

        ``policy_logits`` is copied into a ``vector[float]`` before the call.
        """
        cdef vector[float] cpolicy = policy_logits
        self.cnode.expand(to_play, current_latent_state_index, batch_index,
                          value_prefix, value, cpolicy)


def batch_back_propagate(int current_latent_state_index, float discount,
                         list value_prefixs, list values, list policies,
                         MinMaxStatsList min_max_stats_lst,
                         ResultsWrapper results, list to_play_batch):
    """Copy the per-batch inputs to C++ vectors and call ``cbatch_back_propagate``."""
    cdef vector[float] cvalue_prefixs = value_prefixs
    cdef vector[float] cvalues = values
    cdef vector[vector[float]] cpolicies = policies
    cbatch_back_propagate(current_latent_state_index, discount, cvalue_prefixs,
                          cvalues, cpolicies, min_max_stats_lst.cmin_max_stats_lst,
                          results.cresults, to_play_batch)


def batch_traverse(Roots roots, int num_simulations,
                   int max_num_considered_actions, float discount,
                   ResultsWrapper results, list virtual_to_play_batch):
    """Run ``cbatch_traverse`` and return four fields of ``results.cresults``.

    Returns:
        ``(latent_state_index_in_search_path, latent_state_index_in_batch,
        last_actions, virtual_to_play_batchs)`` converted to Python objects.
    """
    cbatch_traverse(roots.roots, num_simulations, max_num_considered_actions,
                    discount, results.cresults, virtual_to_play_batch)
    return (results.cresults.latent_state_index_in_search_path,
            results.cresults.latent_state_index_in_batch,
            results.cresults.last_actions,
            results.cresults.virtual_to_play_batchs)


def select_root_child(Node roots, float discount, int num_simulations,
                      int max_num_considered_actions):
    """Return ``cselect_root_child`` applied to the node wrapped by ``roots``."""
    return cselect_root_child(&roots.cnode, discount, num_simulations,
                              max_num_considered_actions)


def select_interior_child(Node roots, float discount):
    """Return ``cselect_interior_child`` applied to the node wrapped by ``roots``."""
    return cselect_interior_child(&roots.cnode, discount)


def softmax(list py_num_list):
    """Apply ``csoftmax`` and write the results back into ``py_num_list``.

    The caller's list is mutated element by element and also returned.
    This relies on ``csoftmax`` updating its vector argument in place
    (presumably it takes it by reference -- confirm in the .pxd).
    """
    cdef vector[float] cnum_list = py_num_list
    cdef int clength = len(py_num_list)
    cdef Py_ssize_t i
    csoftmax(cnum_list, clength)
    for i in range(clength):
        py_num_list[i] = cnum_list[i]
    return py_num_list


def pcompute_mixed_value(float raw_value, list py_q_values,
                         list py_child_visit, list py_child_prior):
    """Copy the inputs to C++ vectors and return ``compute_mixed_value(...)``."""
    cdef vector[float] cq_values = py_q_values
    cdef vector[int] cchild_visit = py_child_visit
    cdef vector[float] cchild_prior = py_child_prior
    return compute_mixed_value(raw_value, cq_values, cchild_visit, cchild_prior)


def prescale_qvalues(list py_value, float epsilon):
    """Apply ``rescale_qvalues`` and write the results back into ``py_value``.

    The caller's list is mutated element by element and also returned.
    This relies on ``rescale_qvalues`` updating its vector in place.
    """
    cdef vector[float] cvalue = py_value
    cdef Py_ssize_t i
    rescale_qvalues(cvalue, epsilon)
    for i in range(len(py_value)):
        py_value[i] = cvalue[i]
    return py_value


def pqtransform_completed_by_mix_value(Node roots, list py_child_visit,
                                       list py_child_prior, float discount,
                                       int maxvisit_init, float value_scale,
                                       bool rescale_values, float epsilon):
    """Return ``qtransform_completed_by_mix_value`` as a Python list.

    At most ``len(py_child_visit)`` values are returned. The count is also
    capped at the size of the C++ result, because ``vector::operator[]`` does
    no bounds checking.
    """
    cdef vector[int] cchild_visit = py_child_visit
    cdef vector[float] cchild_prior = py_child_prior
    cdef vector[float] cmix_value
    cdef Py_ssize_t count
    cdef Py_ssize_t i
    cmix_value = qtransform_completed_by_mix_value(
        &roots.cnode, cchild_visit, cchild_prior, discount, maxvisit_init,
        value_scale, rescale_values, epsilon)
    count = min(<Py_ssize_t> len(py_child_visit), <Py_ssize_t> cmix_value.size())
    py_mix_value = []
    for i in range(count):
        py_mix_value.append(cmix_value[i])
    return py_mix_value


def pget_sequence_of_considered_visits(int max_num_considered_actions,
                                       int num_simulations):
    """Return ``get_sequence_of_considered_visits(...)`` converted to Python."""
    return get_sequence_of_considered_visits(max_num_considered_actions,
                                             num_simulations)


def pget_table_of_considered_visits(int max_num_considered_actions,
                                    int num_simulations):
    """Return the table from ``get_table_of_considered_visits`` as a list of lists.

    At most ``max_num_considered_actions + 1`` rows are returned. The count is
    also capped at the table's actual size to avoid unchecked out-of-range
    reads.
    """
    cdef vector[vector[int]] table = get_table_of_considered_visits(
        max_num_considered_actions, num_simulations)
    cdef Py_ssize_t rows = min(<Py_ssize_t> (max_num_considered_actions + 1),
                               <Py_ssize_t> table.size())
    cdef Py_ssize_t i
    result = []
    for i in range(rows):
        # Each vector[int] row is coerced to a Python list on append.
        result.append(table[i])
    return result


def pscore_considered(int considered_visit, list py_gumbel, list py_logits,
                      list py_normalized_qvalues, list py_visit_counts):
    """Copy the inputs to C++ vectors and return ``score_considered(...)``."""
    cdef vector[float] cgumbel = py_gumbel
    cdef vector[float] clogits = py_logits
    cdef vector[float] cnormalized_qvalues = py_normalized_qvalues
    cdef vector[int] cvisit_counts = py_visit_counts
    return score_considered(considered_visit, cgumbel, clogits,
                            cnormalized_qvalues, cvisit_counts)


def pgenerate_gumbel(float gumbel_scale, float gumbel_rng, int shape):
    """Return ``generate_gumbel(gumbel_scale, gumbel_rng, shape)`` converted to Python."""
    return generate_gumbel(gumbel_scale, gumbel_rng, shape)