// C++11 #ifndef CNODE_H #define CNODE_H #include "../../common_lib/cminimax.h" #include #include #include #include #include #include #include #include #include const int DEBUG_MODE = 0; namespace tree { // sampled related core code class CAction { public: std::vector value; std::vector hash; int is_root_action; CAction(); CAction(std::vector value, int is_root_action); ~CAction(); std::vector get_hash(void); std::size_t get_combined_hash(void); }; class CNode { public: int visit_count, to_play, current_latent_state_index, batch_index, is_reset, action_space_size; // sampled related core code CAction best_action; int num_of_sampled_actions; float value_prefix, prior, value_sum; float parent_value_prefix; bool continuous_action_space; std::vector children_index; std::map children; std::vector legal_actions; CNode(); // sampled related core code CNode(float prior, std::vector &legal_actions, int action_space_size, int num_of_sampled_actions, bool continuous_action_space); ~CNode(); void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector &policy_logits); void add_exploration_noise(float exploration_fraction, const std::vector &noises); float compute_mean_q(int isRoot, float parent_q, float discount_factor); void print_out(); int expanded(); float value(); // sampled related core code std::vector > get_trajectory(); std::vector get_children_distribution(); CNode *get_child(CAction action); }; class CRoots { public: int root_num; int num_of_sampled_actions; int action_space_size; std::vector roots; std::vector > legal_actions_list; bool continuous_action_space; CRoots(); CRoots(int root_num, std::vector > legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space); ~CRoots(); void prepare(float root_noise_weight, const std::vector > &noises, const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch); void prepare_no_noise(const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch); void clear(); // sampled related core code std::vector > > get_trajectories(); std::vector > > get_sampled_actions(); std::vector > get_distributions(); std::vector get_values(); }; class CSearchResults { public: int num; std::vector latent_state_index_in_search_path, latent_state_index_in_batch, search_lens; std::vector virtual_to_play_batchs; std::vector > last_actions; std::vector nodes; std::vector > search_paths; CSearchResults(); CSearchResults(int num); ~CSearchResults(); }; //********************************************************* void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players); void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor); void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch); CAction cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, bool continuous_action_space); float cucb_score(CNode *parent, CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players, bool continuous_action_space); void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, bool continuous_action_space); } #endif