File size: 3,675 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# distutils:language=c++
# cython:language_level=3
from libcpp.vector cimport vector


# Declares nothing; the `cdef extern from` makes Cython emit an #include of the
# C++ implementation file, so its definitions are compiled into this extension
# module (presumably instead of listing cminimax.cpp as a separate build source).
cdef extern from "../common_lib/cminimax.cpp":
    pass


cdef extern from "../common_lib/cminimax.h" namespace "tools":
    # tools::CMinMaxStats -- judging by its members, tracks a running minimum /
    # maximum of values and normalizes against them (TODO confirm in cminimax.h).
    # Names must match the C++ header exactly.
    cdef cppclass CMinMaxStats:
        CMinMaxStats() except +

        # Public data members exposed from the header.
        float maximum
        float minimum
        float value_delta_max

        void set_delta(float value_delta_max)
        void update(float value)
        void clear()
        float normalize(float value)

    # tools::CMinMaxStatsList -- a `num`-sized collection of CMinMaxStats held in
    # `stats_lst` (presumably one entry per root/batch element; verify in caller).
    cdef cppclass CMinMaxStatsList:
        CMinMaxStatsList() except +
        CMinMaxStatsList(int num) except +

        int num
        vector[CMinMaxStats] stats_lst

        void set_delta(float value_delta_max)

# Declares nothing; the `cdef extern from` makes Cython emit an #include of the
# C++ tree implementation, so its definitions are compiled into this extension
# module (presumably instead of listing cnode.cpp as a separate build source).
cdef extern from "lib/cnode.cpp":
    pass


# Declarations mirrored from namespace `tree` in lib/cnode.h. Every name below,
# including the `value_prefixs` spelling, must match the C++ header exactly.
# NOTE: only the constructors are declared `except +`; a C++ exception thrown
# from any method or free function declared here is NOT translated into a
# Python exception by Cython.
cdef extern from "lib/cnode.h" namespace "tree":
    # tree::CNode -- a single search-tree node.
    cdef cppclass CNode:
        CNode() except +
        CNode(float prior, vector[int] & legal_actions) except +
        int visit_count, to_play, current_latent_state_index, batch_index, best_action
        float value_prefixs, prior, value_sum, parent_value_prefix

        void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefixs,
                    vector[float] policy_logits)
        void add_exploration_noise(float exploration_fraction, vector[float] noises)
        # `isRoot` is declared here as an int flag rather than a bool.
        float compute_mean_q(int isRoot, float parent_q, float discount_factor)

        int expanded()
        float value()
        vector[int] get_trajectory()
        vector[int] get_children_distribution()
        # Returns a raw pointer. NOTE(review): presumably non-owning and tied to
        # the lifetime of the parent node -- confirm before storing it.
        CNode * get_child(int action)

    # tree::CRoots -- `root_num` root nodes stored by value in `roots`.
    cdef cppclass CRoots:
        CRoots() except +
        CRoots(int root_num, vector[vector[int]] legal_actions_list) except +
        int root_num
        vector[CNode] roots

        # `prepare` additionally takes noise inputs (root_noise_weight, noises);
        # `prepare_no_noise` takes only the per-root value prefixes and policies.
        void prepare(float root_noise_weight, const vector[vector[float]] & noises,
                     const vector[float] & value_prefixs, const vector[vector[float]] & policies,
                     vector[int] to_play_batch)
        void prepare_no_noise(const vector[float] & value_prefixs, const vector[vector[float]] & policies,
                              vector[int] to_play_batch)
        void clear()
        vector[vector[int]] get_trajectories()
        vector[vector[int]] get_distributions()
        vector[float] get_values()
        # visualize related code
        # CNode* get_root(int index)

    # tree::CSearchResults -- `num`-sized result buffers; passed by reference
    # to both cbatch_traverse and cbatch_backpropagate below.
    cdef cppclass CSearchResults:
        CSearchResults() except +
        CSearchResults(int num) except +
        int num
        vector[int] latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens
        vector[int] virtual_to_play_batchs
        vector[CNode *] nodes

    # Free functions of namespace `tree`.
    void cbackpropagate(vector[CNode *] & search_path, CMinMaxStats & min_max_stats,
                        int to_play, float value, float discount_factor)
    void cbatch_backpropagate(int current_latent_state_index, float discount_factor, vector[float] value_prefixs,
                               vector[float] values, vector[vector[float]] policies,
                               CMinMaxStatsList *min_max_stats_lst, CSearchResults & results,
                               vector[int] is_reset_list, vector[int] & to_play_batch)
    void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor,
                         CMinMaxStatsList *min_max_stats_lst, CSearchResults & results,
                         vector[int] & virtual_to_play_batch)

# Extension type wrapping a C++ tools::CMinMaxStatsList through a raw pointer.
# Allocation/release presumably happen in the matching .pyx (__cinit__ /
# __dealloc__) -- not visible here. Attribute order fixes the C struct layout
# seen by modules that cimport this type.
cdef class MinMaxStatsList:
    cdef CMinMaxStatsList *cmin_max_stats_lst

# Extension type holding a C++ tree::CSearchResults by value (constructed and
# destroyed together with the Python object).
cdef class ResultsWrapper:
    cdef CSearchResults cresults

# Extension type wrapping a C++ tree::CRoots through a raw pointer; `root_num`
# is readable (not writable) from Python. Ownership of `roots` is presumably
# handled in the matching .pyx -- confirm. Attribute order fixes the C struct
# layout seen by modules that cimport this type.
cdef class Roots:
    cdef readonly int root_num
    cdef CRoots *roots

# Extension type holding a C++ tree::CNode by value.
cdef class Node:
    cdef CNode cnode