|
|
|
|
|
// M_PI is not provided by MSVC's <cmath> unless _USE_MATH_DEFINES is set before the
// first include of <cmath>.
#define _USE_MATH_DEFINES

#include "cnode.h"

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <map>
#include <numeric> // std::accumulate
#include <random>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include <sys/timeb.h>
|
|
|
#ifdef _WIN32 |
|
#include "..\..\common_lib\utils.cpp" |
|
#else |
|
#include "../../common_lib/utils.cpp" |
|
#endif |
|
|
|
|
|
|
|
template <class T> |
|
size_t hash_combine(std::size_t &seed, const T &val) |
|
{ |
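    // Boost-style hash combiner: folds the hash of `val` into `seed` and returns the
    // updated seed. Used to build a single hash value for a multi-dimensional action.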
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::hash<T> hasher; |
|
seed ^= hasher(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2); |
|
return seed; |
|
} |
|
|
|
|
|
bool cmp(std::pair<int, double> x, std::pair<int, double> y) |
|
{ |
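    // Comparator for std::sort: orders (action index, weight) pairs by weight, descending.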
|
return x.second > y.second; |
|
} |
|
|
|
namespace tree |
|
{ |
|
|
|
|
|
CAction::CAction() |
|
{ |
|
|
|
|
|
|
|
|
|
this->is_root_action = 0; |
|
} |
|
|
|
CAction::CAction(std::vector<float> value, int is_root_action) |
|
{ |
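        /*
        Overview:
            Construct an action from its (possibly multi-dimensional) value and a flag
            indicating whether it is a root action.
        */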
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->value = value; |
|
this->is_root_action = is_root_action; |
|
} |
|
|
|
CAction::~CAction() {} |
|
|
|
std::vector<size_t> CAction::get_hash(void) |
|
{ |
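        /*
        Overview:
            Hash every dimension of the action value separately (via its string
            representation) and return the per-dimension hashes.
        */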
|
|
|
|
|
|
|
|
|
std::vector<size_t> hash; |
|
for (int i = 0; i < this->value.size(); ++i) |
|
{ |
|
std::size_t hash_i = std::hash<std::string>()(std::to_string(this->value[i])); |
|
hash.push_back(hash_i); |
|
} |
|
return hash; |
|
} |
|
size_t CAction::get_combined_hash(void) |
|
{ |
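        /*
        Overview:
            Combine the per-dimension hashes into a single hash that identifies this
            action; used as the key of the parent's children map.
        */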
|
|
|
|
|
|
|
|
|
        std::vector<size_t> hash = this->get_hash();
        if (hash.empty())
        {
            return 0;
        }
        size_t combined_hash = hash[0];
        for (size_t i = 1; i < hash.size(); ++i)
        {
            combined_hash = hash_combine(combined_hash, hash[i]);
        }
|
|
|
return combined_hash; |
|
} |
|
|
|
|
|
|
|
CSearchResults::CSearchResults() |
|
{ |
|
|
|
|
|
|
|
|
|
this->num = 0; |
|
} |
|
|
|
CSearchResults::CSearchResults(int num) |
|
{ |
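        /*
        Overview:
            Initialize the search results for a batch of `num` parallel searches and
            allocate one empty search path per root.
        */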
|
|
|
|
|
|
|
|
|
this->num = num; |
|
for (int i = 0; i < num; ++i) |
|
{ |
|
this->search_paths.push_back(std::vector<CNode *>()); |
|
} |
|
} |
|
|
|
CSearchResults::~CSearchResults() {} |
|
|
|
|
|
|
|
CNode::CNode() |
|
{ |
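        /*
        Overview:
            Default node constructor; initializes all statistics to zero and uses default
            values for the action space size and the number of sampled actions.
        */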
|
|
|
|
|
|
|
|
|
this->prior = 0; |
|
this->action_space_size = 9; |
|
this->num_of_sampled_actions = 20; |
|
this->continuous_action_space = false; |
|
|
|
this->is_reset = 0; |
|
this->visit_count = 0; |
|
this->value_sum = 0; |
|
CAction best_action; |
|
this->best_action = best_action; |
|
|
|
this->to_play = 0; |
|
        this->value_prefix = 0.0;
        this->parent_value_prefix = 0.0;
        this->current_latent_state_index = -1;
        this->batch_index = -1;
|
} |
|
|
|
CNode::CNode(float prior, std::vector<CAction> &legal_actions, int action_space_size, int num_of_sampled_actions, bool continuous_action_space) |
|
{ |
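        /*
        Overview:
            Construct a node with a given prior probability, its legal (sampled) actions,
            the action space size, the number of sampled actions, and the action-space type.
        */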
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->prior = prior; |
|
this->legal_actions = legal_actions; |
|
|
|
this->action_space_size = action_space_size; |
|
this->num_of_sampled_actions = num_of_sampled_actions; |
|
this->continuous_action_space = continuous_action_space; |
|
this->is_reset = 0; |
|
this->visit_count = 0; |
|
this->value_sum = 0; |
|
this->to_play = 0; |
|
this->value_prefix = 0.0; |
|
this->parent_value_prefix = 0.0; |
|
this->current_latent_state_index = -1; |
|
this->batch_index = -1; |
|
} |
|
|
|
CNode::~CNode() {} |
|
|
|
|
|
void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits) |
|
{ |
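        /*
        Overview:
            Expand this node: store the player to move, the position of the corresponding
            latent state in the batch, and the predicted value prefix, then sample
            `num_of_sampled_actions` child actions from the policy.
            For a continuous action space the logits are interpreted as the mean and
            standard deviation of a Gaussian policy; for a discrete action space they are
            interpreted as unnormalized log-probabilities.
        */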
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->to_play = to_play; |
|
this->current_latent_state_index = current_latent_state_index; |
|
this->batch_index = batch_index; |
|
this->value_prefix = value_prefix; |
|
int action_num = policy_logits.size(); |
|
|
|
|
|
|
std::vector<int> all_actions; |
|
for (int i = 0; i < action_num; ++i) |
|
{ |
|
all_actions.push_back(i); |
|
} |
|
std::vector<std::vector<float> > sampled_actions_after_tanh; |
|
std::vector<float> sampled_actions_log_probs_after_tanh; |
|
|
|
std::vector<int> sampled_actions; |
|
std::vector<float> sampled_actions_log_probs; |
|
std::vector<float> sampled_actions_probs; |
|
std::vector<float> probs; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (this->continuous_action_space == true) |
|
{ |
|
|
|
this->action_space_size = policy_logits.size() / 2; |
|
std::vector<float> mu; |
|
std::vector<float> sigma; |
|
for (int i = 0; i < this->action_space_size; ++i) |
|
{ |
|
mu.push_back(policy_logits[i]); |
|
sigma.push_back(policy_logits[this->action_space_size + i]); |
|
} |
|
|
|
|
|
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); |
|
|
|
|
|
std::vector<std::vector<float> > sampled_actions_before_tanh; |
|
|
|
float sampled_action_one_dim_before_tanh; |
|
std::vector<float> sampled_actions_log_probs_before_tanh; |
|
|
|
std::default_random_engine generator(seed); |
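            // Sample `num_of_sampled_actions` actions from the Gaussian policy N(mu, sigma),
            // squash each dimension with tanh, and keep the log-probability of every sampled
            // action (with a tanh-squash correction) so it can be used as the child's prior.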
|
for (int i = 0; i < this->num_of_sampled_actions; ++i) |
|
{ |
|
float sampled_action_prob_before_tanh = 1; |
|
|
|
std::vector<float> sampled_action_before_tanh; |
|
std::vector<float> sampled_action_after_tanh; |
|
std::vector<float> y; |
|
|
|
for (int j = 0; j < this->action_space_size; ++j) |
|
{ |
|
std::normal_distribution<float> distribution(mu[j], sigma[j]); |
|
sampled_action_one_dim_before_tanh = distribution(generator); |
|
|
|
sampled_action_prob_before_tanh *= exp(-pow((sampled_action_one_dim_before_tanh - mu[j]), 2) / (2 * pow(sigma[j], 2)) - log(sigma[j]) - log(sqrt(2 * M_PI))); |
|
sampled_action_before_tanh.push_back(sampled_action_one_dim_before_tanh); |
|
sampled_action_after_tanh.push_back(tanh(sampled_action_one_dim_before_tanh)); |
|
y.push_back(1 - pow(tanh(sampled_action_one_dim_before_tanh), 2) + 1e-6); |
|
} |
|
sampled_actions_before_tanh.push_back(sampled_action_before_tanh); |
|
sampled_actions_after_tanh.push_back(sampled_action_after_tanh); |
|
sampled_actions_log_probs_before_tanh.push_back(log(sampled_action_prob_before_tanh)); |
|
float y_sum = std::accumulate(y.begin(), y.end(), 0.); |
|
sampled_actions_log_probs_after_tanh.push_back(log(sampled_action_prob_before_tanh) - log(y_sum)); |
|
} |
|
} |
|
else |
|
        {
            // Discrete action space: convert the raw policy logits into a softmax
            // distribution, then sample `num_of_sampled_actions` distinct actions from it.
            std::vector<CAction> legal_actions;
|
|
|
|
|
float logits_exp_sum = 0; |
|
for (int i = 0; i < policy_logits.size(); ++i) |
|
{ |
|
logits_exp_sum += exp(policy_logits[i]); |
|
} |
|
for (int i = 0; i < policy_logits.size(); ++i) |
|
{ |
|
probs.push_back(exp(policy_logits[i]) / (logits_exp_sum + 1e-6)); |
|
} |
|
|
|
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
            std::default_random_engine generator(seed);
            std::uniform_real_distribution<double> uniform_distribution(0.0, 1.0);
|
|
|
std::vector<double> disturbed_probs; |
|
std::vector<std::pair<int, double> > disc_action_with_probs; |
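            // Weighted sampling without replacement: draw u ~ Uniform(0, 1) for every action
            // and use the key u^(1 / p); taking the k largest keys below yields k distinct
            // actions sampled according to the softmax probabilities.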
|
|
|
|
|
|
|
for (auto prob : probs) |
|
{ |
|
disturbed_probs.push_back(std::pow(uniform_distribution(generator), 1. / prob)); |
|
} |
|
|
|
|
|
|
|
for (size_t iter = 0; iter < disturbed_probs.size(); iter++) |
|
{ |
|
|
|
#ifdef __GNUC__ |
|
|
|
disc_action_with_probs.push_back(std::make_pair(iter, disturbed_probs[iter])); |
|
#else |
|
|
|
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter])); |
|
#endif |
|
} |
|
|
|
std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp); |
|
|
|
|
|
for (int k = 0; k < num_of_sampled_actions; ++k) |
|
{ |
|
sampled_actions.push_back(disc_action_with_probs[k].first); |
|
|
|
|
|
sampled_actions_probs.push_back(probs[disc_action_with_probs[k].first]); |
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
disturbed_probs.clear(); |
|
disc_action_with_probs.clear(); |
|
} |
|
|
|
|
for (int i = 0; i < this->num_of_sampled_actions; ++i) |
|
{ |
|
|
|
if (this->continuous_action_space == true) |
|
{ |
|
CAction action = CAction(sampled_actions_after_tanh[i], 0); |
|
std::vector<CAction> legal_actions; |
|
this->children[action.get_combined_hash()] = CNode(sampled_actions_log_probs_after_tanh[i], legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space); |
|
this->legal_actions.push_back(action); |
|
} |
|
else |
|
{ |
|
std::vector<float> sampled_action_tmp; |
|
for (size_t iter = 0; iter < 1; iter++) |
|
{ |
|
sampled_action_tmp.push_back(float(sampled_actions[i])); |
|
} |
|
CAction action = CAction(sampled_action_tmp, 0); |
|
std::vector<CAction> legal_actions; |
|
this->children[action.get_combined_hash()] = CNode(sampled_actions_probs[i], legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space); |
|
this->legal_actions.push_back(action); |
|
} |
|
} |
|
|
|
|
} |
|
|
|
void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises) |
|
{ |
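        /*
        Overview:
            Mix exploration noise into the priors of the sampled children of a root node.
            For a continuous action space the priors are stored as log-probabilities, so
            the mixing is done in probability space and the result is converted back.
        */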
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float noise, prior; |
|
for (int i = 0; i < this->num_of_sampled_actions; ++i) |
|
{ |
|
|
|
noise = noises[i]; |
|
CNode *child = this->get_child(this->legal_actions[i]); |
|
prior = child->prior; |
|
if (this->continuous_action_space == true) |
|
{ |
|
|
|
child->prior = log(exp(prior) * (1 - exploration_fraction) + noise * exploration_fraction + 1e-6); |
|
} |
|
else |
|
{ |
|
|
|
child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction; |
|
} |
|
} |
|
} |
|
|
|
float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor) |
|
{ |
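        /*
        Overview:
            Compute the mean Q-value over the visited children of this node. For non-root
            nodes the parent's Q estimate is mixed in as one extra pseudo-visit.
        */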
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float total_unsigned_q = 0.0; |
|
int total_visits = 0; |
|
float parent_value_prefix = this->value_prefix; |
|
for (auto a : this->legal_actions) |
|
{ |
|
CNode *child = this->get_child(a); |
|
if (child->visit_count > 0) |
|
{ |
|
float true_reward = child->value_prefix - parent_value_prefix; |
|
if (this->is_reset == 1) |
|
{ |
|
true_reward = child->value_prefix; |
|
} |
|
float qsa = true_reward + discount_factor * child->value(); |
|
total_unsigned_q += qsa; |
|
total_visits += 1; |
|
} |
|
} |
|
|
|
float mean_q = 0.0; |
|
if (isRoot && total_visits > 0) |
|
{ |
|
mean_q = (total_unsigned_q) / (total_visits); |
|
} |
|
else |
|
{ |
|
mean_q = (parent_q + total_unsigned_q) / (total_visits + 1); |
|
} |
|
return mean_q; |
|
} |
|
|
|
void CNode::print_out() |
|
{ |
|
return; |
|
} |
|
|
|
int CNode::expanded() |
|
{ |
|
|
|
|
|
|
|
|
|
return this->children.size() > 0; |
|
} |
|
|
|
float CNode::value() |
|
{ |
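        /*
        Overview:
            Return the mean value of this node (value_sum / visit_count), or 0 if the
            node has never been visited.
        */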
|
|
|
|
|
|
|
|
|
float true_value = 0.0; |
|
if (this->visit_count == 0) |
|
{ |
|
return true_value; |
|
} |
|
else |
|
{ |
|
true_value = this->value_sum / this->visit_count; |
|
return true_value; |
|
} |
|
} |
|
|
|
std::vector<std::vector<float> > CNode::get_trajectory() |
|
{ |
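        /*
        Overview:
            Follow best_action from this node downwards and return the visited action
            values, i.e. the best trajectory found by the search.
        */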
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<CAction> traj; |
|
|
|
CNode *node = this; |
|
CAction best_action = node->best_action; |
|
while (best_action.is_root_action != 1) |
|
{ |
|
traj.push_back(best_action); |
|
node = node->get_child(best_action); |
|
best_action = node->best_action; |
|
} |
|
|
|
std::vector<std::vector<float> > traj_return; |
|
for (int i = 0; i < traj.size(); ++i) |
|
{ |
|
traj_return.push_back(traj[i].value); |
|
} |
|
return traj_return; |
|
} |
|
|
|
std::vector<int> CNode::get_children_distribution() |
|
{ |
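        /*
        Overview:
            Return the visit counts of the sampled children, in the order of legal_actions.
        */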
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int> distribution; |
|
if (this->expanded()) |
|
{ |
|
for (auto a : this->legal_actions) |
|
{ |
|
CNode *child = this->get_child(a); |
|
distribution.push_back(child->visit_count); |
|
} |
|
} |
|
return distribution; |
|
} |
|
|
|
CNode *CNode::get_child(CAction action) |
|
{ |
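        /*
        Overview:
            Return a pointer to the child reached by `action`, looked up by the action's
            combined hash (a default child is created if the key is not present yet).
        */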
|
|
|
|
|
|
|
|
|
|
|
|
|
return &(this->children[action.get_combined_hash()]); |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
CRoots::CRoots() |
|
{ |
|
this->root_num = 0; |
|
this->num_of_sampled_actions = 20; |
|
} |
|
|
|
CRoots::CRoots(int root_num, std::vector<std::vector<float> > legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space) |
|
{ |
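        /*
        Overview:
            Build `root_num` root nodes. If the action space is discrete, or the
            legal-action list is a placeholder (first entry -1), the roots start with an
            empty action list; otherwise the provided legal actions are attached.
        */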
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
this->root_num = root_num; |
|
this->legal_actions_list = legal_actions_list; |
|
this->continuous_action_space = continuous_action_space; |
|
|
|
|
|
this->num_of_sampled_actions = num_of_sampled_actions; |
|
this->action_space_size = action_space_size; |
|
|
|
for (int i = 0; i < this->root_num; ++i) |
|
{ |
|
            if (this->continuous_action_space == true && this->legal_actions_list[0][0] == -1)
|
{ |
|
|
|
std::vector<CAction> legal_actions; |
|
this->roots.push_back(CNode(0, legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space)); |
|
} |
|
            else if (this->continuous_action_space == false || this->legal_actions_list[0][0] == -1)
|
{ |
|
|
|
|
|
std::vector<CAction> legal_actions; |
|
this->roots.push_back(CNode(0, legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space)); |
|
} |
|
|
|
else |
|
{ |
|
|
|
                std::vector<CAction> c_legal_actions;
                for (size_t j = 0; j < this->legal_actions_list.size(); ++j)
                {
                    CAction c_legal_action = CAction(legal_actions_list[j], 0);
                    c_legal_actions.push_back(c_legal_action);
                }
|
this->roots.push_back(CNode(0, c_legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space)); |
|
} |
|
} |
|
} |
|
|
|
CRoots::~CRoots() {} |
|
|
|
void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch) |
|
{ |
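        /*
        Overview:
            Expand every root with its policy and value prefix, add exploration noise to
            the root priors, and count the initial visit.
        */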
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < this->root_num; ++i) |
|
{ |
|
this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]); |
|
this->roots[i].add_exploration_noise(root_noise_weight, noises[i]); |
|
this->roots[i].visit_count += 1; |
|
} |
|
} |
|
|
|
void CRoots::prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch) |
|
{ |
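        /*
        Overview:
            Same as prepare(), but without adding exploration noise to the root priors.
        */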
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < this->root_num; ++i) |
|
{ |
|
this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]); |
|
|
|
this->roots[i].visit_count += 1; |
|
} |
|
} |
|
|
|
void CRoots::clear() |
|
{ |
|
this->roots.clear(); |
|
} |
|
|
|
std::vector<std::vector<std::vector<float> > > CRoots::get_trajectories() |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<std::vector<float> > > trajs; |
|
trajs.reserve(this->root_num); |
|
|
|
for (int i = 0; i < this->root_num; ++i) |
|
{ |
|
trajs.push_back(this->roots[i].get_trajectory()); |
|
} |
|
return trajs; |
|
} |
|
|
|
std::vector<std::vector<int> > CRoots::get_distributions() |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<int> > distributions; |
|
distributions.reserve(this->root_num); |
|
|
|
for (int i = 0; i < this->root_num; ++i) |
|
{ |
|
distributions.push_back(this->roots[i].get_children_distribution()); |
|
} |
|
return distributions; |
|
} |
|
|
|
|
|
std::vector<std::vector<std::vector<float> > > CRoots::get_sampled_actions() |
|
{ |
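        /*
        Overview:
            Return, for every root, the values of its sampled actions (legal_actions),
            converted to plain float vectors so they can be passed back to Python.
        */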
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<CAction> > sampled_actions; |
|
std::vector<std::vector<std::vector<float> > > python_sampled_actions; |
|
|
|
|
|
|
|
for (int i = 0; i < this->root_num; ++i) |
|
{ |
|
std::vector<CAction> sampled_action; |
|
sampled_action = this->roots[i].legal_actions; |
|
std::vector<std::vector<float> > python_sampled_action; |
|
|
|
for (int j = 0; j < this->roots[i].legal_actions.size(); ++j) |
|
{ |
|
python_sampled_action.push_back(sampled_action[j].value); |
|
} |
|
python_sampled_actions.push_back(python_sampled_action); |
|
} |
|
|
|
return python_sampled_actions; |
|
} |
|
|
|
std::vector<float> CRoots::get_values() |
|
{ |
|
|
|
|
|
|
|
|
|
std::vector<float> values; |
|
for (int i = 0; i < this->root_num; ++i) |
|
{ |
|
values.push_back(this->roots[i].value()); |
|
} |
|
return values; |
|
} |
|
|
|
|
|
|
|
void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players) |
|
{ |
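        /*
        Overview:
            Traverse the expanded part of the tree below `root` and refresh the min-max
            statistics with the Q-value of every visited edge, deriving each step reward
            from consecutive value prefixes (for players == 2 the child value is negated).
        */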
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::stack<CNode *> node_stack; |
|
node_stack.push(root); |
|
float parent_value_prefix = 0.0; |
|
int is_reset = 0; |
|
while (node_stack.size() > 0) |
|
{ |
|
CNode *node = node_stack.top(); |
|
node_stack.pop(); |
|
|
|
if (node != root) |
|
{ |
|
|
|
|
|
|
|
float true_reward = node->value_prefix - node->parent_value_prefix; |
|
|
|
if (is_reset == 1) |
|
{ |
|
true_reward = node->value_prefix; |
|
} |
|
float qsa; |
|
if (players == 1) |
|
qsa = true_reward + discount_factor * node->value(); |
|
else if (players == 2) |
|
|
|
qsa = true_reward + discount_factor * (-1) * node->value(); |
|
|
|
min_max_stats.update(qsa); |
|
} |
|
|
|
for (auto a : node->legal_actions) |
|
{ |
|
CNode *child = node->get_child(a); |
|
if (child->expanded()) |
|
{ |
|
child->parent_value_prefix = node->value_prefix; |
|
node_stack.push(child); |
|
} |
|
} |
|
|
|
is_reset = node->is_reset; |
|
} |
|
} |
|
|
|
void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor) |
|
{ |
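        /*
        Overview:
            Propagate the leaf value back along the search path, updating value sums,
            visit counts and the min-max statistics. to_play == -1 denotes the
            single-player case; otherwise the value is negated for the opponent's nodes.
            Step rewards are recovered from consecutive value prefixes, taking the
            parent's is_reset flag into account.
        */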
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert(to_play == -1 || to_play == 1 || to_play == 2); |
|
if (to_play == -1) |
|
{ |
|
|
|
float bootstrap_value = value; |
|
int path_len = search_path.size(); |
|
for (int i = path_len - 1; i >= 0; --i) |
|
{ |
|
CNode *node = search_path[i]; |
|
node->value_sum += bootstrap_value; |
|
node->visit_count += 1; |
|
|
|
float parent_value_prefix = 0.0; |
|
int is_reset = 0; |
|
if (i >= 1) |
|
{ |
|
CNode *parent = search_path[i - 1]; |
|
parent_value_prefix = parent->value_prefix; |
|
is_reset = parent->is_reset; |
|
} |
|
|
|
float true_reward = node->value_prefix - parent_value_prefix; |
|
min_max_stats.update(true_reward + discount_factor * node->value()); |
|
|
|
if (is_reset == 1) |
|
{ |
|
|
|
true_reward = node->value_prefix; |
|
} |
|
|
|
bootstrap_value = true_reward + discount_factor * bootstrap_value; |
|
} |
|
} |
|
else |
|
{ |
|
|
|
float bootstrap_value = value; |
|
int path_len = search_path.size(); |
|
for (int i = path_len - 1; i >= 0; --i) |
|
{ |
|
CNode *node = search_path[i]; |
|
if (node->to_play == to_play) |
|
node->value_sum += bootstrap_value; |
|
else |
|
node->value_sum += -bootstrap_value; |
|
node->visit_count += 1; |
|
|
|
float parent_value_prefix = 0.0; |
|
int is_reset = 0; |
|
if (i >= 1) |
|
{ |
|
CNode *parent = search_path[i - 1]; |
|
parent_value_prefix = parent->value_prefix; |
|
is_reset = parent->is_reset; |
|
} |
|
|
|
|
|
|
|
float true_reward = node->value_prefix - parent_value_prefix; |
|
|
|
min_max_stats.update(true_reward + discount_factor * node->value()); |
|
|
|
if (is_reset == 1) |
|
{ |
|
|
|
true_reward = node->value_prefix; |
|
} |
|
if (node->to_play == to_play) |
|
bootstrap_value = -true_reward + discount_factor * bootstrap_value; |
|
else |
|
bootstrap_value = true_reward + discount_factor * bootstrap_value; |
|
} |
|
} |
|
} |
|
|
|
void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch) |
|
{ |
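        /*
        Overview:
            For every search in the batch: expand the reached leaf with the freshly
            predicted policy and value prefix, set its is_reset flag, and backpropagate
            the predicted value along the stored search path.
        */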
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < results.num; ++i) |
|
{ |
|
results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]); |
|
|
|
results.nodes[i]->is_reset = is_reset_list[i]; |
|
|
|
cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor); |
|
} |
|
} |
|
|
|
CAction cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, bool continuous_action_space) |
|
{ |
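        /*
        Overview:
            Select the child of `root` with the highest UCB score; ties (within a small
            epsilon) are broken uniformly at random.
        */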
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float max_score = FLOAT_MIN; |
|
const float epsilon = 0.000001; |
|
std::vector<CAction> max_index_lst; |
|
for (auto a : root->legal_actions) |
|
{ |
|
|
|
CNode *child = root->get_child(a); |
|
|
|
float temp_score = cucb_score(root, child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players, continuous_action_space); |
|
|
|
if (max_score < temp_score) |
|
{ |
|
max_score = temp_score; |
|
|
|
max_index_lst.clear(); |
|
max_index_lst.push_back(a); |
|
} |
|
else if (temp_score >= max_score - epsilon) |
|
{ |
|
max_index_lst.push_back(a); |
|
} |
|
} |
|
|
|
|
|
CAction action; |
|
if (max_index_lst.size() > 0) |
|
{ |
|
int rand_index = rand() % max_index_lst.size(); |
|
action = max_index_lst[rand_index]; |
|
} |
|
return action; |
|
} |
|
|
|
|
|
float cucb_score(CNode *parent, CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players, bool continuous_action_space) |
|
{ |
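        /*
        Overview:
            Compute the (sampled) pUCT score of a child: an exploration term based on the
            child's prior, normalized by the empirical probability mass of all sampled
            children, plus a min-max-normalized value term clipped to [0, 1]. Unvisited
            children use the parent's mean Q as their value estimate.
        */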
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; |
|
pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; |
|
pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); |
|
|
|
|
|
|
|
|
|
|
|
std::string empirical_distribution_type = "density"; |
|
if (empirical_distribution_type.compare("density")) |
|
{ |
|
if (continuous_action_space == true) |
|
{ |
|
float empirical_prob_sum = 0; |
|
for (int i = 0; i < parent->children.size(); ++i) |
|
{ |
|
empirical_prob_sum += exp(parent->get_child(parent->legal_actions[i])->prior); |
|
} |
|
prior_score = pb_c * exp(child->prior) / (empirical_prob_sum + 1e-6); |
|
} |
|
else |
|
{ |
|
float empirical_prob_sum = 0; |
|
for (int i = 0; i < parent->children.size(); ++i) |
|
{ |
|
empirical_prob_sum += parent->get_child(parent->legal_actions[i])->prior; |
|
} |
|
prior_score = pb_c * child->prior / (empirical_prob_sum + 1e-6); |
|
} |
|
} |
|
else if (empirical_distribution_type.compare("uniform")) |
|
{ |
|
prior_score = pb_c * 1 / parent->children.size(); |
|
} |
|
|
|
if (child->visit_count == 0) |
|
{ |
|
value_score = parent_mean_q; |
|
} |
|
else |
|
{ |
|
float true_reward = child->value_prefix - parent_value_prefix; |
|
if (is_reset == 1) |
|
{ |
|
true_reward = child->value_prefix; |
|
} |
|
|
|
if (players == 1) |
|
value_score = true_reward + discount_factor * child->value(); |
|
else if (players == 2) |
|
value_score = true_reward + discount_factor * (-child->value()); |
|
} |
|
|
|
value_score = min_max_stats.normalize(value_score); |
|
|
|
if (value_score < 0) |
|
value_score = 0; |
|
if (value_score > 1) |
|
value_score = 1; |
|
|
|
float ucb_value = prior_score + value_score; |
|
return ucb_value; |
|
} |
|
|
|
void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, bool continuous_action_space) |
|
{ |
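        /*
        Overview:
            Run one selection phase for every root in the batch: starting at each root,
            repeatedly pick the best child via cselect_child until an unexpanded node is
            reached, recording the search path, the last action, the parent's latent-state
            indices and the (virtual) player to move at the leaf.
        */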
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
get_time_and_set_rand_seed(); |
|
|
|
        std::vector<float> last_action;
|
float parent_q = 0.0; |
|
results.search_lens = std::vector<int>(); |
|
|
|
int players = 0; |
|
        int largest_element = *std::max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end());
|
if (largest_element == -1) |
|
players = 1; |
|
else |
|
players = 2; |
|
|
|
for (int i = 0; i < results.num; ++i) |
|
{ |
|
CNode *node = &(roots->roots[i]); |
|
int is_root = 1; |
|
int search_len = 0; |
|
results.search_paths[i].push_back(node); |
|
|
|
while (node->expanded()) |
|
{ |
|
float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor); |
|
is_root = 0; |
|
parent_q = mean_q; |
|
|
|
CAction action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, continuous_action_space); |
|
if (players > 1) |
|
{ |
|
assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); |
|
if (virtual_to_play_batch[i] == 1) |
|
virtual_to_play_batch[i] = 2; |
|
else |
|
virtual_to_play_batch[i] = 1; |
|
} |
|
|
|
node->best_action = action; |
|
|
|
node = node->get_child(action); |
|
last_action = action.value; |
|
|
|
results.search_paths[i].push_back(node); |
|
search_len += 1; |
|
} |
|
|
|
CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; |
|
|
|
results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); |
|
results.latent_state_index_in_batch.push_back(parent->batch_index); |
|
|
|
results.last_actions.push_back(last_action); |
|
results.search_lens.push_back(search_len); |
|
results.nodes.push_back(node); |
|
results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); |
|
} |
|
} |
|
|
|
} |
|
|