Skip to content

Commit

Permalink
polish(pu): polish alphazero ctree unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Jan 20, 2025
1 parent aa122d0 commit 5dbe736
Show file tree
Hide file tree
Showing 10 changed files with 558 additions and 314 deletions.
5 changes: 3 additions & 2 deletions lzero/mcts/ctree/ctree_alphazero/make.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
# navigating into it, running cmake to generate build files suitable for the arm64 architecture,
# and running make to compile the project.

# Navigate to the project directory
cd /Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/ || exit
# Navigate to the project directory.
# ========= NOTE: PLEASE MODIFY THE FOLLOWING DIRECTORY TO YOUR OWN. =========
cd /YOUR_LightZero_DIR/LightZero/lzero/mcts/ctree/ctree_alphazero/ || exit

# Create a new directory named "build" (if it does not already exist). The build directory is where the compiled files will be stored.
mkdir -p build
Expand Down
80 changes: 54 additions & 26 deletions lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// mcts_alphazero.cpp

#include "node_alphazero.h"
#include <cmath>
#include <map>
Expand All @@ -14,20 +12,20 @@

namespace py = pybind11;

// Nodes are managed with shared_ptr
// The MCTS class implements Monte Carlo Tree Search (MCTS) for AlphaZero-like algorithms.
class MCTS {

private:
int max_moves;
int num_simulations;
double pb_c_base;
double pb_c_init;
double root_dirichlet_alpha;
double root_noise_weight;
py::object simulate_env;

int max_moves; // Maximum allowed moves in a game
int num_simulations; // Number of MCTS simulations
double pb_c_base; // Coefficient for UCB exploration term (base)
double pb_c_init; // Coefficient for UCB exploration term (initial value)
double root_dirichlet_alpha; // Alpha parameter for Dirichlet noise
double root_noise_weight; // Weight for exploration noise added to root node
py::object simulate_env; // Python object representing the simulation environment

public:
// Constructor to initialize MCTS with optional parameters
MCTS(int max_moves=512, int num_simulations=800,
double pb_c_base=19652, double pb_c_init=1.25,
double root_dirichlet_alpha=0.3, double root_noise_weight=0.25, py::object simulate_env=py::none())
Expand All @@ -37,35 +35,43 @@ class MCTS {
root_noise_weight(root_noise_weight),
simulate_env(simulate_env) {}

// Getter for simulate_env
// Returns the Python object currently stored as the simulation environment.
py::object get_simulate_env() const { return simulate_env; }

// Setter for simulate_env
// Stores `env` as the simulation environment, replacing the previous one.
void set_simulate_env(py::object env) { simulate_env = env; }

// Getter methods for pb_c_base and pb_c_init
// Read-only access to the pb_c_base coefficient.
double get_pb_c_base() const {
    return pb_c_base;
}

// Read-only access to the pb_c_init coefficient.
double get_pb_c_init() const {
    return pb_c_init;
}

// Scores `child` relative to `parent` for selection.
//
// The result is the sum of two terms:
//   * an exploration term:
//       (log((N_parent + pb_c_base + 1) / pb_c_base) + pb_c_init)
//         * sqrt(N_parent) / (N_child + 1) * child->prior_p
//   * the child's own value estimate, child->get_value().
// N_parent and N_child are the visit_count fields of the two nodes.
double _ucb_score(std::shared_ptr<Node> parent, std::shared_ptr<Node> child) {
    // Exploration coefficient: a log term in the parent's visits plus pb_c_init.
    double exploration_coeff =
        std::log((parent->visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init;

    // Scale by parent visits and damp by how often this child was already visited.
    exploration_coeff *= std::sqrt(parent->visit_count) / (child->visit_count + 1);

    // Weight the exploration coefficient by the child's prior, then add its value.
    const double exploration_term = exploration_coeff * child->prior_p;
    const double exploitation_term = child->get_value();
    return exploration_term + exploitation_term;
}

// Add Dirichlet noise to the root node for exploration
void _add_exploration_noise(std::shared_ptr<Node> node) {
std::vector<int> actions;
// Collect all child actions of the root node
for (const auto& kv : node->children) {
actions.push_back(kv.first);
}

// Generate Dirichlet noise
std::default_random_engine generator;
std::gamma_distribution<double> distribution(root_dirichlet_alpha, 1.0);

Expand All @@ -77,31 +83,38 @@ class MCTS {
sum += sample;
}

// Normalize the noise
for (size_t i = 0; i < noise.size(); ++i) {
noise[i] /= sum;
}

// Mix noise with prior probabilities
double frac = root_noise_weight;
for (size_t i = 0; i < actions.size(); ++i) {
node->children[actions[i]]->prior_p = node->children[actions[i]]->prior_p * (1 - frac) + noise[i] * frac;
}
}

// Select the best child node based on UCB score
std::pair<int, std::shared_ptr<Node>> _select_child(std::shared_ptr<Node> node, py::object simulate_env) {
int action = -1;
std::shared_ptr<Node> child = nullptr;
double best_score = -9999999;

// Iterate through all children
for (const auto& kv : node->children) {
int action_tmp = kv.first;
std::shared_ptr<Node> child_tmp = kv.second;

// Get legal actions from the simulation environment
py::list legal_actions_py = simulate_env.attr("legal_actions").cast<py::list>();

std::vector<int> legal_actions;
for (py::handle h : legal_actions_py) {
legal_actions.push_back(h.cast<int>());
}

// Check if the action is legal and calculate UCB score
if (std::find(legal_actions.begin(), legal_actions.end(), action_tmp) != legal_actions.end()) {
double score = _ucb_score(node, child_tmp);
if (score > best_score) {
Expand All @@ -111,23 +124,28 @@ class MCTS {
}
}
}
// If no valid child is found, return the current node
if (child == nullptr) {
child = node;
}
return std::make_pair(action, child);
}

// Expand a leaf node by adding its children based on policy probabilities
double _expand_leaf_node(std::shared_ptr<Node> node, py::object simulate_env, py::object policy_value_func) {
std::map<int, double> action_probs_dict;
double leaf_value;
py::tuple result = policy_value_func(simulate_env);

// Call the policy-value function to get action probabilities and leaf value
py::tuple result = policy_value_func(simulate_env);
action_probs_dict = result[0].cast<std::map<int, double>>();
leaf_value = result[1].cast<double>();

// Get the legal actions from the simulation environment
py::list legal_actions_list = simulate_env.attr("legal_actions").cast<py::list>();
std::vector<int> legal_actions = legal_actions_list.cast<std::vector<int>>();

// Add child nodes for legal actions
for (const auto& kv : action_probs_dict) {
int action = kv.first;
double prior_p = kv.second;
Expand All @@ -139,9 +157,11 @@ class MCTS {
return leaf_value;
}

std::pair<int, std::vector<double>> get_next_action(py::object state_config_for_env_reset, py::object policy_value_func, double temperature, bool sample) {
// Main function to get the next action from MCTS
std::tuple<int, std::vector<double>, std::shared_ptr<Node>> get_next_action(py::object state_config_for_env_reset, py::object policy_value_func, double temperature, bool sample) {
std::shared_ptr<Node> root = std::make_shared<Node>();

// Configure initial environment state
py::object init_state = state_config_for_env_reset["init_state"];
if (!init_state.is_none()) {
init_state = py::bytes(init_state.attr("tobytes")());
Expand All @@ -157,10 +177,13 @@ class MCTS {
katago_game_state
);

// Expand the root node
_expand_leaf_node(root, simulate_env, policy_value_func);
if (sample) {
_add_exploration_noise(root);
}

// Run MCTS simulations
for (int n = 0; n < num_simulations; ++n) {
simulate_env.attr("reset")(
state_config_for_env_reset["start_player_index"].cast<int>(),
Expand All @@ -172,6 +195,7 @@ class MCTS {
_simulate(root, simulate_env, policy_value_func);
}

// Collect visit counts from the root's children
std::vector<std::pair<int, int>> action_visits;
for (int action = 0; action < simulate_env.attr("action_space").attr("n").cast<int>(); ++action) {
if (root->children.count(action)) {
Expand Down Expand Up @@ -200,9 +224,11 @@ class MCTS {
action_selected = actions[std::distance(action_probs.begin(), std::max_element(action_probs.begin(), action_probs.end()))];
}

return std::make_pair(action_selected, action_probs);
// Return the selected action, action probabilities, and root node
return std::make_tuple(action_selected, action_probs, root);
}

// Simulate a game starting from a given node
void _simulate(std::shared_ptr<Node> node, py::object simulate_env, py::object policy_value_func) {
while (!node->is_leaf()) {
int action;
Expand Down Expand Up @@ -257,6 +283,7 @@ class MCTS {
}

private:
// Helper: Convert visit counts to action probabilities using temperature
static std::vector<double> visit_count_to_action_distribution(const std::vector<double>& visits, double temperature) {
if (temperature == 0) {
throw std::invalid_argument("Temperature cannot be 0");
Expand All @@ -281,6 +308,7 @@ class MCTS {
return normalized_visits;
}

// Helper: Softmax function to normalize values
static std::vector<double> softmax(const std::vector<double>& values, double temperature) {
std::vector<double> exps;
double sum = 0.0;
Expand All @@ -299,6 +327,7 @@ class MCTS {
return exps;
}

// Helper: Randomly choose an action based on probabilities
static int random_choice(const std::vector<int>& actions, const std::vector<double>& probs) {
std::random_device rd;
std::mt19937 gen(rd());
Expand All @@ -307,8 +336,9 @@ class MCTS {
}
};

// 绑定 Node MCTS 到同一个 pybind11 模块
// Bind Node and MCTS to the same pybind11 module
PYBIND11_MODULE(mcts_alphazero, m) {
// Bind the Node class
py::class_<Node, std::shared_ptr<Node>>(m, "Node")
.def(py::init<std::shared_ptr<Node>, float>(),
py::arg("parent")=nullptr, py::arg("prior_p")=1.0)
Expand All @@ -317,13 +347,13 @@ PYBIND11_MODULE(mcts_alphazero, m) {
.def("update_recursive", &Node::update_recursive)
.def("is_leaf", &Node::is_leaf)
.def("is_root", &Node::is_root)
// 绑定 parent 和 children 为只读属性
.def_property_readonly("parent", &Node::get_parent)
.def_property_readonly("children", &Node::get_children)
.def("add_child", &Node::add_child)
.def_property_readonly("visit_count", &Node::get_visit_count)
.def_readwrite("prior_p", &Node::prior_p);

// Bind the MCTS class
py::class_<MCTS>(m, "MCTS")
.def(py::init<int, int, double, double, double, double, py::object>(),
py::arg("max_moves")=512, py::arg("num_simulations")=800,
Expand All @@ -335,14 +365,12 @@ PYBIND11_MODULE(mcts_alphazero, m) {
.def("_expand_leaf_node", &MCTS::_expand_leaf_node)
.def("get_next_action", &MCTS::get_next_action)
.def("_simulate", &MCTS::_simulate)
// 绑定 simulate_env 通过 getter 和 setter
.def_property("simulate_env", &MCTS::get_simulate_env, &MCTS::set_simulate_env)
// 绑定 pb_c_base 和 pb_c_init 为只读属性
.def_property_readonly("pb_c_base", &MCTS::get_pb_c_base)
.def_property_readonly("pb_c_init", &MCTS::get_pb_c_init)
.def("get_next_action", &MCTS::get_next_action,
py::arg("state_config_for_env_reset"),
py::arg("policy_value_func"),
py::arg("temperature"),
py::arg("sample"));
.def("get_next_action", &MCTS::get_next_action,
py::arg("state_config_for_env_reset"),
py::arg("policy_value_func"),
py::arg("temperature"),
py::arg("sample"));
}
3 changes: 0 additions & 3 deletions lzero/mcts/ctree/ctree_alphazero/node_alphazero.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// node_alphazero.cpp

#include "node_alphazero.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand All @@ -24,7 +22,6 @@ PYBIND11_MODULE(node_alphazero, m) {
})
.def("add_child", &Node::add_child)
.def_property_readonly("visit_count", &Node::get_visit_count)
// 添加 prior_p 属性的绑定
.def_readwrite("prior_p", &Node::prior_p)
.def("get_child", [](const Node &self, int action) -> std::shared_ptr<Node> {
auto it = self.children.find(action);
Expand Down
Loading

0 comments on commit 5dbe736

Please sign in to comment.