zjowowen commited on
Commit
c1806e4
·
1 Parent(s): e11ce00

init space

Browse files
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM zjowowen/dev:cnf as base
2
 
3
  ENV DEBIAN_FRONTEND=noninteractive
4
  ENV LANG en_US.UTF-8
@@ -7,9 +7,9 @@ ENV LC_ALL en_US.UTF-8
7
 
8
  RUN apt update -y \
9
  && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip unzip cmake nginx -y \
10
- && curl -fsSL https://deb.nodesource.com/setup_16.x | bash - \
11
  && apt-get install -y nodejs \
12
- && npm install -g npm@9.6.5 \
13
  && npm install -g create-react-app \
14
  && npm install typescript -g \
15
  && npm install -g vite \
 
1
+ FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime as base
2
 
3
  ENV DEBIAN_FRONTEND=noninteractive
4
  ENV LANG en_US.UTF-8
 
7
 
8
  RUN apt update -y \
9
  && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip unzip cmake nginx -y \
10
+ && curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
11
  && apt-get install -y nodejs \
12
+ && npm install -g npm@10.3.0 \
13
  && npm install -g create-react-app \
14
  && npm install typescript -g \
15
  && npm install -g vite \
LightZero/.gitignore CHANGED
@@ -1432,16 +1432,3 @@ collect_demo_data_config.py
1432
  events.*
1433
  /test_*
1434
  # LightZero special key
1435
- /zoo/board_games/**/*.c
1436
- /zoo/board_games/**/*.cpp
1437
- /lzero/mcts/**/*.cpp
1438
- /zoo/**/*.c
1439
- /lzero/mcts/**/*.so
1440
- /lzero/mcts/**/*.h
1441
- !/lzero/mcts/**/lib
1442
- !/lzero/mcts/**/lib/*.cpp
1443
- !/lzero/mcts/**/lib/*.hpp
1444
- !/lzero/mcts/**/lib/*.h
1445
- **/tb/*
1446
- **/mcts/ctree/tests_cpp/*
1447
- **/*tmp*
 
1432
  events.*
1433
  /test_*
1434
  # LightZero special key
 
 
 
 
 
 
 
 
 
 
 
 
 
LightZero/lzero/mcts/ctree/common_lib/cminimax.cpp ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // C++11
2
+
3
+ #include "cminimax.h"
4
+
5
+ namespace tools{
6
+
7
+ CMinMaxStats::CMinMaxStats(){
8
+ this->maximum = FLOAT_MIN;
9
+ this->minimum = FLOAT_MAX;
10
+ this->value_delta_max = 0.;
11
+ }
12
+
13
+ CMinMaxStats::~CMinMaxStats(){}
14
+
15
+ void CMinMaxStats::set_delta(float value_delta_max){
16
+ this->value_delta_max = value_delta_max;
17
+ }
18
+
19
+ void CMinMaxStats::update(float value){
20
+ if(value > this->maximum){
21
+ this->maximum = value;
22
+ }
23
+ if(value < this->minimum){
24
+ this->minimum = value;
25
+ }
26
+ }
27
+
28
+ void CMinMaxStats::clear(){
29
+ this->maximum = FLOAT_MIN;
30
+ this->minimum = FLOAT_MAX;
31
+ }
32
+
33
+ float CMinMaxStats::normalize(float value){
34
+ float norm_value = value;
35
+ float delta = this->maximum - this->minimum;
36
+ if(delta > 0){
37
+ if(delta < this->value_delta_max){
38
+ norm_value = (norm_value - this->minimum) / this->value_delta_max;
39
+ }
40
+ else{
41
+ norm_value = (norm_value - this->minimum) / delta;
42
+ }
43
+ }
44
+ return norm_value;
45
+ }
46
+
47
+ //*********************************************************
48
+
49
+ CMinMaxStatsList::CMinMaxStatsList(){
50
+ this->num = 0;
51
+ }
52
+
53
+ CMinMaxStatsList::CMinMaxStatsList(int num){
54
+ this->num = num;
55
+ for(int i = 0; i < num; ++i){
56
+ this->stats_lst.push_back(CMinMaxStats());
57
+ }
58
+ }
59
+
60
+ CMinMaxStatsList::~CMinMaxStatsList(){}
61
+
62
+ void CMinMaxStatsList::set_delta(float value_delta_max){
63
+ for(int i = 0; i < this->num; ++i){
64
+ this->stats_lst[i].set_delta(value_delta_max);
65
+ }
66
+ }
67
+
68
+ }
LightZero/lzero/mcts/ctree/common_lib/cminimax.h ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // C++11
2
+
3
+ #ifndef CMINIMAX_H
4
+ #define CMINIMAX_H
5
+
6
+ #include <iostream>
7
+ #include <vector>
8
+
9
+ const float FLOAT_MAX = 1000000.0;
10
+ const float FLOAT_MIN = -FLOAT_MAX;
11
+
12
+ namespace tools {
13
+
14
+ class CMinMaxStats {
15
+ public:
16
+ float maximum, minimum, value_delta_max;
17
+
18
+ CMinMaxStats();
19
+ ~CMinMaxStats();
20
+
21
+ void set_delta(float value_delta_max);
22
+ void update(float value);
23
+ void clear();
24
+ float normalize(float value);
25
+ };
26
+
27
+ class CMinMaxStatsList {
28
+ public:
29
+ int num;
30
+ std::vector<CMinMaxStats> stats_lst;
31
+
32
+ CMinMaxStatsList();
33
+ CMinMaxStatsList(int num);
34
+ ~CMinMaxStatsList();
35
+
36
+ void set_delta(float value_delta_max);
37
+ };
38
+ }
39
+
40
+ #endif
LightZero/lzero/mcts/ctree/common_lib/utils.cpp ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // C++11
2
+
3
+ #include <iostream>
4
+ #include <algorithm>
5
+
6
+ #ifdef _WIN32
7
+ #include <Windows.h>
8
+ #else
9
+ #include <sys/time.h>
10
+ #endif
11
+
12
+ void get_time_and_set_rand_seed()
13
+ {
14
+ #ifdef _WIN32
15
+ FILETIME ft;
16
+ GetSystemTimeAsFileTime(&ft);
17
+ ULARGE_INTEGER uli;
18
+ uli.LowPart = ft.dwLowDateTime;
19
+ uli.HighPart = ft.dwHighDateTime;
20
+ uint64_t timestamp = (uli.QuadPart - 116444736000000000ULL) / 10000000ULL;
21
+ srand(timestamp % RAND_MAX);
22
+ #else
23
+ timeval tv;
24
+ gettimeofday(&tv, nullptr);
25
+ srand(tv.tv_usec);
26
+ #endif
27
+ }
LightZero/lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // This code is a Python extension implemented in C++ using the pybind11 library.
2
+ // It's a Monte Carlo Tree Search (MCTS) algorithm with modifications based on Google's AlphaZero paper.
3
+ // MCTS is an algorithm for making optimal decisions in a certain class of combinatorial problems.
4
+ // It's most famously used in board games like chess, Go, and shogi.
5
+
6
+ // The following lines include the necessary headers to facilitate the implementation of the MCTS algorithm.
7
+ #include "node_alphazero.h"
8
+ #include <cmath>
9
+ #include <map>
10
+ #include <random>
11
+ #include <vector>
12
+ #include <pybind11/pybind11.h>
13
+ #include <pybind11/stl.h>
14
+ #include <functional>
15
+ #include <iostream>
16
+ #include <memory>
17
+ #include <numeric>
18
+
19
+ // This line creates an alias for the pybind11 namespace, making it easier to reference in the code.
20
+ namespace py = pybind11;
21
+
22
+ // This part defines the MCTS class and its member variables.
23
+ // The MCTS class implements the MCTS algorithm, and its member variables store configuration values used in the algorithm.
24
+ class MCTS {
25
+ int max_moves;
26
+ int num_simulations;
27
+ double pb_c_base;
28
+ double pb_c_init;
29
+ double root_dirichlet_alpha;
30
+ double root_noise_weight;
31
+ py::object simulate_env;
32
+
33
+ // This part defines the constructor of the MCTS class.
34
+ // The constructor initializes the member variables with the provided arguments or with their default values.
35
+ public:
36
+ MCTS(int max_moves=512, int num_simulations=800,
37
+ double pb_c_base=19652, double pb_c_init=1.25,
38
+ double root_dirichlet_alpha=0.3, double root_noise_weight=0.25, py::object simulate_env=py::none())
39
+ : max_moves(max_moves), num_simulations(num_simulations),
40
+ pb_c_base(pb_c_base), pb_c_init(pb_c_init),
41
+ root_dirichlet_alpha(root_dirichlet_alpha),
42
+ root_noise_weight(root_noise_weight),
43
+ simulate_env(simulate_env) {}
44
+
45
+ // This function calculates the Upper Confidence Bound (UCB) score for a given node in the MCTS tree based on the parent node's visit count,
46
+ // the child node's visit count, and the child node's prior probability.
47
+ double _ucb_score(Node* parent, Node* child) {
48
+ double pb_c = std::log((parent->visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init;
49
+ pb_c *= std::sqrt(parent->visit_count) / (child->visit_count + 1);
50
+
51
+ double prior_score = pb_c * child->prior_p;
52
+ double value_score = child->get_value();
53
+ return prior_score + value_score;
54
+ }
55
+
56
+ // This function adds Dirichlet noise to the prior probabilities of the actions of a given node to encourage exploration.
57
+ void _add_exploration_noise(Node* node) {
58
+ std::vector<int> actions;
59
+ for (const auto& kv : node->children) {
60
+ actions.push_back(kv.first);
61
+ }
62
+
63
+ std::default_random_engine generator;
64
+ std::gamma_distribution<double> distribution(root_dirichlet_alpha, 1.0);
65
+
66
+ std::vector<double> noise;
67
+ double sum = 0;
68
+ for (size_t i = 0; i < actions.size(); ++i) {
69
+ double sample = distribution(generator);
70
+ noise.push_back(sample);
71
+ sum += sample;
72
+ }
73
+
74
+ // Normalize the samples to simulate a Dirichlet distribution
75
+ for (size_t i = 0; i < noise.size(); ++i) {
76
+ noise[i] /= sum;
77
+ }
78
+
79
+ double frac = root_noise_weight;
80
+ for (size_t i = 0; i < actions.size(); ++i) {
81
+ node->children[actions[i]]->prior_p = node->children[actions[i]]->prior_p * (1 - frac) + noise[i] * frac;
82
+ }
83
+ }
84
+ // This function selects the child of a given node that has the highest UCB score among the legal actions.
85
+ std::pair<int, Node*> _select_child(Node* node, py::object simulate_env) {
86
+ int action = -1;
87
+ Node* child = nullptr;
88
+ double best_score = -9999999;
89
+ for (const auto& kv : node->children) {
90
+ int action_tmp = kv.first;
91
+ Node* child_tmp = kv.second;
92
+
93
+ py::list legal_actions_py = simulate_env.attr("legal_actions").cast<py::list>();
94
+
95
+ std::vector<int> legal_actions;
96
+ for (py::handle h : legal_actions_py) {
97
+ legal_actions.push_back(h.cast<int>());
98
+ }
99
+
100
+ if (std::find(legal_actions.begin(), legal_actions.end(), action_tmp) != legal_actions.end()) {
101
+ double score = _ucb_score(node, child_tmp);
102
+ if (score > best_score) {
103
+ best_score = score;
104
+ action = action_tmp;
105
+ child = child_tmp;
106
+ }
107
+ }
108
+
109
+ }
110
+ if (child == nullptr) {
111
+ child = node;
112
+ }
113
+ return std::make_pair(action, child);
114
+ }
115
+
116
+ // This function expands a leaf node by generating its children based on the legal actions and their prior probabilities.
117
+ double _expand_leaf_node(Node* node, py::object simulate_env, py::object policy_value_func) {
118
+
119
+ std::map<int, double> action_probs_dict;
120
+ double leaf_value;
121
+ py::tuple result = policy_value_func(simulate_env);
122
+
123
+ action_probs_dict = result[0].cast<std::map<int, double>>();
124
+ leaf_value = result[1].cast<double>();
125
+
126
+
127
+ py::list legal_actions_list = simulate_env.attr("legal_actions").cast<py::list>();
128
+ std::vector<int> legal_actions = legal_actions_list.cast<std::vector<int>>();
129
+
130
+
131
+ for (const auto& kv : action_probs_dict) {
132
+ int action = kv.first;
133
+ double prior_p = kv.second;
134
+ if (std::find(legal_actions.begin(), legal_actions.end(), action) != legal_actions.end()) {
135
+ node->children[action] = new Node(node, prior_p);
136
+ }
137
+ }
138
+
139
+ return leaf_value;
140
+ }
141
+
142
+ // This function returns the next action to take and the probabilities of each action based on the current state and the policy-value function.
143
+ std::pair<int, std::vector<double>> get_next_action(py::object state_config_for_env_reset, py::object policy_value_func, double temperature, bool sample) {
144
+ Node* root = new Node();
145
+
146
+ py::object init_state = state_config_for_env_reset["init_state"];
147
+ if (!init_state.is_none()) {
148
+ init_state = py::bytes(init_state.attr("tobytes")());
149
+ }
150
+ py::object katago_game_state = state_config_for_env_reset["katago_game_state"];
151
+ if (!katago_game_state.is_none()) {
152
+ // TODO(pu): polish efficiency
153
+ katago_game_state = py::module::import("pickle").attr("dumps")(katago_game_state);
154
+ }
155
+ simulate_env.attr("reset")(
156
+ state_config_for_env_reset["start_player_index"].cast<int>(),
157
+ init_state,
158
+ state_config_for_env_reset["katago_policy_init"].cast<bool>(),
159
+ katago_game_state
160
+ );
161
+
162
+ _expand_leaf_node(root, simulate_env, policy_value_func);
163
+ if (sample) {
164
+ _add_exploration_noise(root);
165
+ }
166
+ for (int n = 0; n < num_simulations; ++n) {
167
+ simulate_env.attr("reset")(
168
+ state_config_for_env_reset["start_player_index"].cast<int>(),
169
+ init_state,
170
+ state_config_for_env_reset["katago_policy_init"].cast<bool>(),
171
+ katago_game_state
172
+ );
173
+ simulate_env.attr("battle_mode") = simulate_env.attr("battle_mode_in_simulation_env");
174
+ _simulate(root, simulate_env, policy_value_func);
175
+ }
176
+
177
+ std::vector<std::pair<int, int>> action_visits;
178
+ for (int action = 0; action < simulate_env.attr("action_space").attr("n").cast<int>(); ++action) {
179
+ if (root->children.count(action)) {
180
+ action_visits.push_back(std::make_pair(action, root->children[action]->visit_count));
181
+ } else {
182
+ action_visits.push_back(std::make_pair(action, 0));
183
+ }
184
+ }
185
+
186
+ // Convert 'action_visits' into two separate arrays.
187
+ std::vector<int> actions;
188
+ std::vector<int> visits;
189
+ for (const auto& av : action_visits) {
190
+ actions.push_back(av.first);
191
+ visits.push_back(av.second);
192
+ }
193
+
194
+
195
+ std::vector<double> visits_d(visits.begin(), visits.end());
196
+ std::vector<double> action_probs = visit_count_to_action_distribution(visits_d, temperature);
197
+
198
+ int action;
199
+ if (sample) {
200
+ action = random_choice(actions, action_probs);
201
+ } else {
202
+ action = actions[std::distance(action_probs.begin(), std::max_element(action_probs.begin(), action_probs.end()))];
203
+ }
204
+
205
+
206
+ return std::make_pair(action, action_probs);
207
+ }
208
+
209
+ // This function performs a simulation from a given node until a leaf node is reached or a terminal state is reached.
210
+ void _simulate(Node* node, py::object simulate_env, py::object policy_value_func) {
211
+ while (!node->is_leaf()) {
212
+ int action;
213
+ std::tie(action, node) = _select_child(node, simulate_env);
214
+ if (action == -1) {
215
+ break;
216
+ }
217
+ simulate_env.attr("step")(action);
218
+ }
219
+
220
+ bool done;
221
+ int winner;
222
+ py::tuple result = simulate_env.attr("get_done_winner")();
223
+ done = result[0].cast<bool>();
224
+ winner = result[1].cast<int>();
225
+
226
+ double leaf_value;
227
+ if (!done) {
228
+ leaf_value = _expand_leaf_node(node, simulate_env, policy_value_func);
229
+ }
230
+ else {
231
+ if (simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>() == "self_play_mode") {
232
+ if (winner == -1) {
233
+ leaf_value = 0;
234
+ } else {
235
+ leaf_value = (simulate_env.attr("current_player").cast<int>() == winner) ? 1 : -1;
236
+ }
237
+ }
238
+ else if (simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>() == "play_with_bot_mode") {
239
+ if (winner == -1) {
240
+ leaf_value = 0;
241
+ } else if (winner == 1) {
242
+ leaf_value = 1;
243
+ } else if (winner == 2) {
244
+ leaf_value = -1;
245
+ }
246
+ }
247
+ }
248
+ if (simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>() == "play_with_bot_mode") {
249
+ node->update_recursive(leaf_value, simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>());
250
+ }
251
+ else if (simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>() == "self_play_mode") {
252
+ node->update_recursive(-leaf_value, simulate_env.attr("battle_mode_in_simulation_env").cast<std::string>());
253
+ }
254
+ }
255
+
256
+
257
+
258
+
259
+
260
+ private:
261
+ static std::vector<double> visit_count_to_action_distribution(const std::vector<double>& visits, double temperature) {
262
+ // Check if temperature is 0
263
+ if (temperature == 0) {
264
+ throw std::invalid_argument("Temperature cannot be 0");
265
+ }
266
+
267
+ // Check if all visit counts are 0
268
+ if (std::all_of(visits.begin(), visits.end(), [](double v){ return v == 0; })) {
269
+ throw std::invalid_argument("All visit counts cannot be 0");
270
+ }
271
+
272
+ std::vector<double> normalized_visits(visits.size());
273
+
274
+ // Divide visit counts by temperature
275
+ for (size_t i = 0; i < visits.size(); i++) {
276
+ normalized_visits[i] = visits[i] / temperature;
277
+ }
278
+
279
+ // Calculate the sum of all normalized visit counts
280
+ double sum = std::accumulate(normalized_visits.begin(), normalized_visits.end(), 0.0);
281
+
282
+ // Normalize the visit counts
283
+ for (double& visit : normalized_visits) {
284
+ visit /= sum;
285
+ }
286
+
287
+ return normalized_visits;
288
+ }
289
+
290
+ static std::vector<double> softmax(const std::vector<double>& values, double temperature) {
291
+ std::vector<double> exps;
292
+ double sum = 0.0;
293
+ // Compute the maximum value
294
+ double max_value = *std::max_element(values.begin(), values.end());
295
+
296
+ // Subtract the maximum value before exponentiation, for numerical stability
297
+ for (double v : values) {
298
+ double exp_v = std::exp((v - max_value) / temperature);
299
+ exps.push_back(exp_v);
300
+ sum += exp_v;
301
+ }
302
+
303
+ for (double& exp_v : exps) {
304
+ exp_v /= sum;
305
+ }
306
+
307
+ return exps;
308
+ }
309
+
310
+ static int random_choice(const std::vector<int>& actions, const std::vector<double>& probs) {
311
+ std::random_device rd;
312
+ std::mt19937 gen(rd());
313
+ std::discrete_distribution<> d(probs.begin(), probs.end());
314
+ return actions[d(gen)];
315
+ }
316
+
317
+ };
318
+
319
+ // This function uses pybind11 to expose the Node and MCTS classes to Python.
320
+ // This allows Python code to create and manipulate instances of these classes.
321
+ PYBIND11_MODULE(mcts_alphazero, m) {
322
+ py::class_<Node>(m, "Node")
323
+ .def(py::init([](Node* parent, float prior_p){
324
+ return new Node(parent ? parent : nullptr, prior_p);
325
+ }), py::arg("parent")=nullptr, py::arg("prior_p")=1.0)
326
+ .def_property_readonly("value", &Node::get_value)
327
+ .def("update", &Node::update)
328
+ .def("update_recursive", &Node::update_recursive)
329
+ .def("is_leaf", &Node::is_leaf)
330
+ .def("is_root", &Node::is_root)
331
+ .def("parent", &Node::get_parent)
332
+ .def_readwrite("prior_p", &Node::prior_p)
333
+ .def_readwrite("children", &Node::children)
334
+ .def("add_child", &Node::add_child)
335
+ .def_readwrite("visit_count", &Node::visit_count);
336
+
337
+ py::class_<MCTS>(m, "MCTS")
338
+ .def(py::init<int, int, double, double, double, double, py::object>(),
339
+ py::arg("max_moves")=512, py::arg("num_simulations")=800,
340
+ py::arg("pb_c_base")=19652, py::arg("pb_c_init")=1.25,
341
+ py::arg("root_dirichlet_alpha")=0.3, py::arg("root_noise_weight")=0.25, py::arg("simulate_env"))
342
+ .def("_ucb_score", &MCTS::_ucb_score)
343
+ .def("_add_exploration_noise", &MCTS::_add_exploration_noise)
344
+ .def("_select_child", &MCTS::_select_child)
345
+ .def("_expand_leaf_node", &MCTS::_expand_leaf_node)
346
+ .def("get_next_action", &MCTS::get_next_action)
347
+ .def("_simulate", &MCTS::_simulate);
348
+ }
LightZero/lzero/mcts/ctree/ctree_alphazero/node_alphazero.cpp ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "node_alphazero.h"
2
+ #include <pybind11/pybind11.h>
3
+ #include <pybind11/stl.h>
4
+
5
+ namespace py = pybind11;
6
+
7
+ PYBIND11_MODULE(node_alphazero, m) {
8
+ py::class_<Node>(m, "Node")
9
+ .def(py::init([](Node* parent, float prior_p){
10
+ return new Node(parent ? parent : nullptr, prior_p);
11
+ }), py::arg("parent")=nullptr, py::arg("prior_p")=1.0)
12
+ .def("value", &Node::get_value)
13
+ .def("update", &Node::update)
14
+ .def("update_recursive", &Node::update_recursive)
15
+ .def("is_leaf", &Node::is_leaf)
16
+ .def("is_root", &Node::is_root)
17
+ .def("parent", &Node::get_parent)
18
+ .def("children", &Node::get_children)
19
+ .def_readwrite("children", &Node::children)
20
+ .def("add_child", &Node::add_child)
21
+ .def("visit_count", &Node::get_visit_count);
22
+ }
LightZero/lzero/mcts/ctree/ctree_alphazero/node_alphazero.h ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <map>
2
+ #include <string>
3
+ #include <iostream>
4
+ #include <memory>
5
+ #include <mutex>
6
+
7
+ class Node {
8
+ public:
9
+ // Constructor, initializes a Node with a parent pointer and a prior probability
10
+ Node(Node* parent = nullptr, float prior_p = 1.0)
11
+ : parent(parent), prior_p(prior_p), visit_count(0), value_sum(0.0) {}
12
+
13
+ // Destructor, deletes all child nodes when a node is deleted to prevent memory leaks
14
+ ~Node() {
15
+ for (auto& pair : children) {
16
+ delete pair.second;
17
+ }
18
+ }
19
+
20
+ // Returns the average value of the node
21
+ float get_value() {
22
+ return visit_count == 0 ? 0.0 : value_sum / visit_count;
23
+ }
24
+
25
+ // Updates the visit count and value sum of the node
26
+ void update(float value) {
27
+ visit_count++;
28
+ value_sum += value;
29
+ }
30
+
31
+ // Recursively updates the value and visit count of the node and its parent nodes
32
+ void update_recursive(float leaf_value, std::string battle_mode_in_simulation_env) {
33
+ // If the mode is "self_play_mode", the leaf_value is subtracted from the parent's value
34
+ if (battle_mode_in_simulation_env == "self_play_mode") {
35
+ update(leaf_value);
36
+ if (!is_root()) {
37
+ parent->update_recursive(-leaf_value, battle_mode_in_simulation_env);
38
+ }
39
+ }
40
+ // If the mode is "play_with_bot_mode", the leaf_value is added to the parent's value
41
+ else if (battle_mode_in_simulation_env == "play_with_bot_mode") {
42
+ update(leaf_value);
43
+ if (!is_root()) {
44
+ parent->update_recursive(leaf_value, battle_mode_in_simulation_env);
45
+ }
46
+ }
47
+ }
48
+
49
+ // Returns true if the node has no children
50
+ bool is_leaf() {
51
+ return children.empty();
52
+ }
53
+
54
+ // Returns true if the node has no parent
55
+ bool is_root() {
56
+ return parent == nullptr;
57
+ }
58
+
59
+ // Returns a pointer to the node's parent
60
+ Node* get_parent() {
61
+ return parent;
62
+ }
63
+
64
+ // Returns a map of the node's children
65
+ std::map<int, Node*> get_children() {
66
+ return children;
67
+ }
68
+
69
+ // Returns the node's visit count
70
+ int get_visit_count() {
71
+ return visit_count;
72
+ }
73
+
74
+ // Adds a child to the node
75
+ void add_child(int action, Node* node) {
76
+ children[action] = node;
77
+ }
78
+
79
+ public:
80
+ Node* parent; // Pointer to the parent node
81
+ float prior_p; // Prior probability of the node
82
+ int visit_count; // Count of visits to the node
83
+ float value_sum; // Sum of values of the node
84
+ std::map<int, Node*> children; // Map of child nodes
85
+ };