mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 02:47:45 +00:00
Clean up previous patch
This commit is contained in:
+10
-6
@@ -53,16 +53,20 @@ void dbg_hit_on(bool c, bool b);
|
|||||||
void dbg_mean_of(int v);
|
void dbg_mean_of(int v);
|
||||||
void dbg_print();
|
void dbg_print();
|
||||||
|
|
||||||
|
/// Debug macro to write to std::cerr if the NDEBUG flag is set, and do nothing otherwise
|
||||||
#if defined(NDEBUG)
|
#if defined(NDEBUG)
|
||||||
template <typename... Ts>
|
#define debug 1 && std::cerr
|
||||||
void debug_print(const Ts&...) {}
|
|
||||||
#else
|
#else
|
||||||
template <typename... Ts>
|
#define debug 0 && std::cerr
|
||||||
void debug_print(const Ts&... v) {
|
|
||||||
((std::cerr << v), ...);
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
inline void hit_any_key() {
|
||||||
|
#ifndef NDEBUG
|
||||||
|
debug << "Hit any key to continue..." << std::endl << std::flush;
|
||||||
|
system("read"); // on Windows, should be system("pause");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
typedef std::chrono::milliseconds::rep TimePoint; // A value in milliseconds
|
typedef std::chrono::milliseconds::rep TimePoint; // A value in milliseconds
|
||||||
static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits");
|
static_assert(sizeof(TimePoint) == sizeof(int64_t), "TimePoint should be 64 bits");
|
||||||
inline TimePoint now() {
|
inline TimePoint now() {
|
||||||
|
|||||||
+305
-163
@@ -47,6 +47,7 @@ namespace Search {
|
|||||||
using std::string;
|
using std::string;
|
||||||
using Eval::evaluate;
|
using Eval::evaluate;
|
||||||
using namespace Search;
|
using namespace Search;
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
bool Search::prune_at_shallow_depth = true;
|
bool Search::prune_at_shallow_depth = true;
|
||||||
|
|
||||||
@@ -1970,7 +1971,10 @@ namespace Search
|
|||||||
// Zero initialization of the number of search nodes
|
// Zero initialization of the number of search nodes
|
||||||
th->nodes = 0;
|
th->nodes = 0;
|
||||||
|
|
||||||
// Clear all history types. This initialization takes a little time, and the accuracy of the search is rather low, so the good and bad are not well understood.
|
// Clear all history types. This initialization takes a little time, and
|
||||||
|
// the accuracy of the search is rather low, so the good and bad are
|
||||||
|
// not well understood.
|
||||||
|
|
||||||
// th->clear();
|
// th->clear();
|
||||||
|
|
||||||
int ct = int(Options["Contempt"]) * PawnValueEg / 100; // From centipawns
|
int ct = int(Options["Contempt"]) * PawnValueEg / 100; // From centipawns
|
||||||
@@ -2200,12 +2204,14 @@ namespace Search
|
|||||||
return ValueAndPV(bestValue, pvs);
|
return ValueAndPV(bestValue, pvs);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace MCTS {
|
|
||||||
/*
|
|
||||||
The implementation of the MCTS is heavily based on Stephane Nicolet's work here
|
// This implementation of the MCTS is heavily based on Stephane Nicolet's work here
|
||||||
https://github.com/snicolet/Stockfish/commit/28501872a1e7ce84dd1f38ab9e59c5adb0d24b41
|
// https://github.com/snicolet/Stockfish/commit/28501872a1e7ce84dd1f38ab9e59c5adb0d24b41
|
||||||
and the adjusted implementation of it in ShashChess https://github.com/amchess/ShashChess
|
// and the adjusted implementation of it in ShashChess https://github.com/amchess/ShashChess
|
||||||
*/
|
|
||||||
|
namespace MCTS
|
||||||
|
{
|
||||||
static constexpr float sigmoidScale = 600.0f;
|
static constexpr float sigmoidScale = 600.0f;
|
||||||
|
|
||||||
static inline float fast_sigmoid(float x) {
|
static inline float fast_sigmoid(float x) {
|
||||||
@@ -2231,24 +2237,31 @@ namespace Search
|
|||||||
return fast_sigmoid(static_cast<float>(v) * (1.0f / sigmoidScale));
|
return fast_sigmoid(static_cast<float>(v) * (1.0f / sigmoidScale));
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MctsNode {
|
// struct MCTSNode : store info at one node of the MCTS algorithm
|
||||||
Key posKey = 0; // for consistency checks.
|
|
||||||
MctsNode* parent = nullptr; // only nullptr for the root node
|
struct MCTSNode {
|
||||||
std::unique_ptr<MctsNode[]> children = nullptr; // only nullptr for nodes that have not been expanded
|
|
||||||
std::uint64_t numVisits = 0; // the number of playouts for this node and all descendants
|
Key posKey = 0; // for consistency checks
|
||||||
|
MCTSNode* parent = nullptr; // only nullptr for the root node
|
||||||
|
unique_ptr<MCTSNode[]> children = nullptr; // only nullptr for nodes that have not been expanded
|
||||||
|
uint64_t numVisits = 0; // the number of playouts for this node and all descendants
|
||||||
Value leafSearchEval = VALUE_NONE; // the evaluation from AB playout
|
Value leafSearchEval = VALUE_NONE; // the evaluation from AB playout
|
||||||
float prior = 0.0f; // the policy, currently a rough estimation based on the playout of the parent
|
float prior = 0.0f; // the policy, currently a rough estimation based on the playout of the parent
|
||||||
float actionValue = 0.0f; // the accumulated rewards
|
float actionValue = 0.0f; // the accumulated rewards
|
||||||
float actionValueWeight = 0.0f; // the maximum value for the accumulater rewards
|
float actionValueWeight = 0.0f; // the maximum value for the accumulater rewards
|
||||||
Move prevMove = MOVE_NONE; // the move on the edge from the parent
|
Move prevMove = MOVE_NONE; // the move on the edge from the parent
|
||||||
std::uint16_t numChildren = 0; // the number of legal moves, filled on expansion
|
int numChildren = 0; // the number of legal moves, filled on expansion
|
||||||
std::uint16_t childId = 0; // the index of this node in the parent's children array
|
int childId = 0; // the index of this node in the parent's children array
|
||||||
Depth leafSearchDepth = DEPTH_NONE; // the depth with which the AB playout was done
|
Depth leafSearchDepth = DEPTH_NONE; // the depth with which the AB playout was done
|
||||||
bool isTerminal = false; // whether the node is terminal. Terminal nodes are always "expanded" immediately.
|
bool isTerminal = false; // whether the node is terminal. Terminal nodes are always "expanded" immediately.
|
||||||
|
|
||||||
// The upper confidence bound. When searching for the node to expand/playout we take
|
|
||||||
// one with the highest ucb.
|
// ucb_value() calculates the upper confidence bound of a child.
|
||||||
float ucb_value(MctsNode& child, float explorationFactor, bool flipPerspective = false) const {
|
// When searching for the node to expand/playout we take one with the highest ucb.
|
||||||
|
|
||||||
|
float ucb_value(MCTSNode& child, float explorationFactor, bool flipPerspective = false) const
|
||||||
|
{
|
||||||
|
|
||||||
assert(explorationFactor >= 0.0f);
|
assert(explorationFactor >= 0.0f);
|
||||||
assert(child.actionValue >= 0.0f);
|
assert(child.actionValue >= 0.0f);
|
||||||
assert(child.actionValueWeight >= 0.0f);
|
assert(child.actionValueWeight >= 0.0f);
|
||||||
@@ -2258,9 +2271,7 @@ namespace Search
|
|||||||
|
|
||||||
// For the nodes which have not been played-out we use the prior.
|
// For the nodes which have not been played-out we use the prior.
|
||||||
// Otherwise we have some averaged score or the eval already.
|
// Otherwise we have some averaged score or the eval already.
|
||||||
float reward =
|
float reward = child.numVisits == 0 ? child.prior
|
||||||
child.numVisits == 0
|
|
||||||
? child.prior
|
|
||||||
: child.actionValue / child.actionValueWeight;
|
: child.actionValue / child.actionValueWeight;
|
||||||
|
|
||||||
if (flipPerspective)
|
if (flipPerspective)
|
||||||
@@ -2269,11 +2280,11 @@ namespace Search
|
|||||||
// The exploration factor.
|
// The exploration factor.
|
||||||
// In theory unplayed nodes should have priority, but we
|
// In theory unplayed nodes should have priority, but we
|
||||||
// add 1 to avoid div by 0 so they might not always be prioritized.
|
// add 1 to avoid div by 0 so they might not always be prioritized.
|
||||||
|
|
||||||
if (explorationFactor != 0.0f)
|
if (explorationFactor != 0.0f)
|
||||||
reward +=
|
reward +=
|
||||||
explorationFactor
|
explorationFactor
|
||||||
* std::sqrt(std::log(static_cast<float>(1 + numVisits))
|
* std::sqrt(std::log(1.0 + numVisits) / (1.0 + child.numVisits));
|
||||||
/ static_cast<float>(1 + child.numVisits));
|
|
||||||
|
|
||||||
assert(!std::isnan(reward));
|
assert(!std::isnan(reward));
|
||||||
assert(reward >= 0.0f);
|
assert(reward >= 0.0f);
|
||||||
@@ -2281,24 +2292,32 @@ namespace Search
|
|||||||
return reward;
|
return reward;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the reference to the best child node.
|
|
||||||
const MctsNode& get_best_child(float explorationFactor) const {
|
// get_best_child() returns a const reference to the best child node,
|
||||||
|
// according to the UCB value.
|
||||||
|
|
||||||
|
const MCTSNode& get_best_child(float explorationFactor) const
|
||||||
|
{
|
||||||
|
|
||||||
assert(!is_leaf());
|
assert(!is_leaf());
|
||||||
assert(numChildren > 0);
|
assert(numChildren > 0);
|
||||||
|
|
||||||
if (numChildren == 1) {
|
if (numChildren == 1)
|
||||||
|
{
|
||||||
assert(children[0].childId == 0);
|
assert(children[0].childId == 0);
|
||||||
return children[0];
|
return children[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
int bestIdx = -1;
|
int bestIdx = -1;
|
||||||
float bestValue = std::numeric_limits<float>::lowest();
|
float bestValue = std::numeric_limits<float>::lowest();
|
||||||
for (int i = 0; i < static_cast<int>(numChildren); ++i) {
|
for (int i = 0 ; i < numChildren ; ++i)
|
||||||
MctsNode& child = children[i];
|
{
|
||||||
|
MCTSNode& child = children[i];
|
||||||
// The "best" is the one with the best UCB.
|
// The "best" is the one with the best UCB.
|
||||||
// Child values are with opposite signs.
|
// Child values are with opposite signs.
|
||||||
const float r = ucb_value(child, explorationFactor, true);
|
const float r = ucb_value(child, explorationFactor, true);
|
||||||
if (r > bestValue) {
|
if (r > bestValue)
|
||||||
|
{
|
||||||
bestIdx = i;
|
bestIdx = i;
|
||||||
bestValue = r;
|
bestValue = r;
|
||||||
}
|
}
|
||||||
@@ -2311,22 +2330,32 @@ namespace Search
|
|||||||
return children[bestIdx];
|
return children[bestIdx];
|
||||||
}
|
}
|
||||||
|
|
||||||
MctsNode& get_best_child(float explorationFactor) {
|
|
||||||
return const_cast<MctsNode&>(static_cast<const MctsNode*>(this)->get_best_child(explorationFactor));
|
// get_best_child() : like the previous one, but does not return a const reference
|
||||||
|
|
||||||
|
MCTSNode& get_best_child(float explorationFactor) {
|
||||||
|
return const_cast<MCTSNode&>(static_cast<const MCTSNode*>(this)->get_best_child(explorationFactor));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// get_best_move() returns a pair (move,value) leading to the best child,
|
||||||
|
// according to the action value heuristic.
|
||||||
|
|
||||||
std::pair<Move, Value> get_best_move() const {
|
std::pair<Move, Value> get_best_move() const {
|
||||||
|
|
||||||
assert(!is_leaf());
|
assert(!is_leaf());
|
||||||
assert(numChildren > 0);
|
assert(numChildren > 0);
|
||||||
|
|
||||||
int bestIdx = -1;
|
int bestIdx = -1;
|
||||||
float bestValue = std::numeric_limits<float>::lowest();
|
float bestValue = std::numeric_limits<float>::lowest();
|
||||||
for (int i = 0; i < static_cast<int>(numChildren); ++i) {
|
for (int i = 0; i < numChildren ; ++i)
|
||||||
MctsNode& child = children[i];
|
{
|
||||||
// The "best" is the one with the best UCB.
|
MCTSNode& child = children[i];
|
||||||
|
// The "best" is the one with the best action value.
|
||||||
// Child values are with opposite signs.
|
// Child values are with opposite signs.
|
||||||
const float r = 1.0f - (child.actionValue / child.actionValueWeight);
|
const float r = 1.0f - (child.actionValue / child.actionValueWeight);
|
||||||
if (r > bestValue) {
|
if (r > bestValue)
|
||||||
|
{
|
||||||
bestIdx = i;
|
bestIdx = i;
|
||||||
bestValue = r;
|
bestValue = r;
|
||||||
}
|
}
|
||||||
@@ -2339,9 +2368,13 @@ namespace Search
|
|||||||
return { children[bestIdx].prevMove, reward_to_value(bestValue) };
|
return { children[bestIdx].prevMove, reward_to_value(bestValue) };
|
||||||
}
|
}
|
||||||
|
|
||||||
const MctsNode* get_child_by_move(Move move) const {
|
|
||||||
for (int i = 0; i < static_cast<int>(numChildren); ++i) {
|
// get_child_by_move() finds a child, given the move that leads to it
|
||||||
MctsNode& child = children[i];
|
|
||||||
|
const MCTSNode* get_child_by_move(Move move) const {
|
||||||
|
for (int i = 0; i < numChildren ; ++i)
|
||||||
|
{
|
||||||
|
MCTSNode& child = children[i];
|
||||||
if (child.prevMove == move)
|
if (child.prevMove == move)
|
||||||
return &child;
|
return &child;
|
||||||
}
|
}
|
||||||
@@ -2349,22 +2382,34 @@ namespace Search
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
MctsNode* get_child_by_move(Move move) {
|
|
||||||
return const_cast<MctsNode*>(static_cast<const MctsNode*>(this)->get_child_by_move(move));
|
// get_child_by_move() : like the previous one, but does not return a const pointer
|
||||||
|
|
||||||
|
MCTSNode* get_child_by_move(Move move) {
|
||||||
|
return const_cast<MCTSNode*>(static_cast<const MCTSNode*>(this)->get_child_by_move(move));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// is_root() returns true when node is the root
|
||||||
|
|
||||||
bool is_root() const {
|
bool is_root() const {
|
||||||
return parent == nullptr;
|
return parent == nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// is_leaf() returns true when node is a leaf
|
||||||
|
|
||||||
bool is_leaf() const {
|
bool is_leaf() const {
|
||||||
return children == nullptr;
|
return children == nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// The stuff that needs to be backpropagated down the tree.
|
|
||||||
|
// struct BackpropValues is a structure to manipulate the kind of stuff
|
||||||
|
// that needs to be back-propagated down and up the tree by MCTS.
|
||||||
|
|
||||||
struct BackpropValues {
|
struct BackpropValues {
|
||||||
std::uint64_t numVisits = 0;
|
|
||||||
|
uint64_t numVisits = 0;
|
||||||
float actionValue = 0.0f;
|
float actionValue = 0.0f;
|
||||||
float actionValueWeight = 0.0f;
|
float actionValueWeight = 0.0f;
|
||||||
|
|
||||||
@@ -2379,7 +2424,18 @@ namespace Search
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// struct MonteCarloTreeSearch implements the methods for the MCTS algorithm
|
||||||
|
|
||||||
struct MonteCarloTreeSearch {
|
struct MonteCarloTreeSearch {
|
||||||
|
|
||||||
|
// IMPORTANT:
|
||||||
|
// The position is stateful so we always have one.
|
||||||
|
// It has to match certain expectations in different functions.
|
||||||
|
// For example when looking for the node to expand the pos must correspond
|
||||||
|
// to the root mcts node. When expanding the node it must correspond to the
|
||||||
|
// node being expanded, etc.
|
||||||
|
|
||||||
static constexpr Depth terminalEvalDepth = Depth(255);
|
static constexpr Depth terminalEvalDepth = Depth(255);
|
||||||
|
|
||||||
// We add a lot of stuff to the actionValue, but the weights differ.
|
// We add a lot of stuff to the actionValue, but the weights differ.
|
||||||
@@ -2393,9 +2449,11 @@ namespace Search
|
|||||||
static_assert(normalWeight > 0.0f);
|
static_assert(normalWeight > 0.0f);
|
||||||
|
|
||||||
MonteCarloTreeSearch() {}
|
MonteCarloTreeSearch() {}
|
||||||
|
|
||||||
MonteCarloTreeSearch(const MonteCarloTreeSearch&) = delete;
|
MonteCarloTreeSearch(const MonteCarloTreeSearch&) = delete;
|
||||||
|
|
||||||
|
|
||||||
|
// search_new() : let's start the search !
|
||||||
|
|
||||||
ValueAndPV search_new(
|
ValueAndPV search_new(
|
||||||
Position& pos,
|
Position& pos,
|
||||||
std::uint64_t maxPlayouts,
|
std::uint64_t maxPlayouts,
|
||||||
@@ -2406,37 +2464,42 @@ namespace Search
|
|||||||
return search_continue(pos, maxPlayouts, leafDepth, explorationFactor);
|
return search_continue(pos, maxPlayouts, leafDepth, explorationFactor);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Continue after a move and reuse the relevant part of the tree.
|
|
||||||
// The prevMove is the move that led to pos
|
// search_continue_after_move() : continue after a move and reuse the relevant
|
||||||
// TODO: Make the node limit be the total.
|
// part of the tree. The prevMove is the move that led to position 'pos'.
|
||||||
ValueAndPV search_continue_after_move(
|
//
|
||||||
Position& pos,
|
// TODO: make the node limit be the total.
|
||||||
|
|
||||||
|
ValueAndPV search_continue_after_move( Position& pos,
|
||||||
Move prevMove,
|
Move prevMove,
|
||||||
std::uint64_t maxPlayouts,
|
std::uint64_t maxPlayouts,
|
||||||
Depth leafDepth,
|
Depth leafDepth,
|
||||||
float explorationFactor = 0.25f) {
|
float explorationFactor = 0.25f) {
|
||||||
|
|
||||||
do_move_at_root(pos, prevMove);
|
do_move_at_root(pos, prevMove);
|
||||||
return search_continue(pos, maxPlayouts, leafDepth, explorationFactor);
|
return search_continue(pos, maxPlayouts, leafDepth, explorationFactor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// get_all_continuations() returns one continuation per child of the root node (visit count, action value, value and PV), sorted by decreasing value
|
||||||
|
|
||||||
std::vector<MctsContinuation> get_all_continuations() const {
|
std::vector<MctsContinuation> get_all_continuations() const {
|
||||||
|
|
||||||
std::vector<MctsContinuation> continuations;
|
std::vector<MctsContinuation> continuations;
|
||||||
continuations.resize(rootNode.numChildren);
|
continuations.resize(rootNode.numChildren);
|
||||||
|
|
||||||
for (int i = 0; i < rootNode.numChildren; ++i) {
|
for (int i = 0; i < rootNode.numChildren; ++i)
|
||||||
MctsNode& child = rootNode.children[i];
|
{
|
||||||
|
MCTSNode& child = rootNode.children[i];
|
||||||
|
|
||||||
auto& cont = continuations[i];
|
auto& cont = continuations[i];
|
||||||
|
|
||||||
cont.numVisits = child.numVisits;
|
cont.numVisits = child.numVisits;
|
||||||
// child value is with opposite sign
|
|
||||||
cont.actionValue = 1.0f - (child.actionValue / child.actionValueWeight);
|
|
||||||
cont.value = reward_to_value(cont.actionValue);
|
cont.value = reward_to_value(cont.actionValue);
|
||||||
cont.pv = get_pv(child);
|
cont.pv = get_pv(child);
|
||||||
|
cont.actionValue = 1.0f - (child.actionValue / child.actionValueWeight); // child value is with opposite sign
|
||||||
}
|
}
|
||||||
|
|
||||||
std::stable_sort(
|
std::stable_sort( continuations.begin(),
|
||||||
continuations.begin(),
|
|
||||||
continuations.end(),
|
continuations.end(),
|
||||||
[](const auto& lhs, const auto& rhs) { return lhs.value > rhs.value; }
|
[](const auto& lhs, const auto& rhs) { return lhs.value > rhs.value; }
|
||||||
);
|
);
|
||||||
@@ -2444,9 +2507,10 @@ namespace Search
|
|||||||
return continuations;
|
return continuations;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Continues with the same tree.
|
|
||||||
ValueAndPV search_continue(
|
// search_continue() : continues with the same tree
|
||||||
Position& pos,
|
|
||||||
|
ValueAndPV search_continue( Position& pos,
|
||||||
std::uint64_t maxPlayouts,
|
std::uint64_t maxPlayouts,
|
||||||
Depth leafDepth,
|
Depth leafDepth,
|
||||||
float explorationFactor = 0.25f) {
|
float explorationFactor = 0.25f) {
|
||||||
@@ -2454,17 +2518,18 @@ namespace Search
|
|||||||
if (rootNode.leafSearchDepth == DEPTH_NONE)
|
if (rootNode.leafSearchDepth == DEPTH_NONE)
|
||||||
do_playout(pos, rootNode, leafDepth);
|
do_playout(pos, rootNode, leafDepth);
|
||||||
|
|
||||||
while (numPlayouts < maxPlayouts) {
|
while (numPlayouts < maxPlayouts)
|
||||||
debug_print("Starting iteration ", numPlayouts, '\n');
|
{
|
||||||
|
debug << "Starting iteration " << numPlayouts << endl;
|
||||||
do_search_iteration(pos, leafDepth, explorationFactor);
|
do_search_iteration(pos, leafDepth, explorationFactor);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (rootNode.is_leaf())
|
if (rootNode.is_leaf())
|
||||||
return {};
|
return {};
|
||||||
else {
|
else
|
||||||
return { rootNode.get_best_move().second, get_pv() };
|
return { rootNode.get_best_move().second, get_pv() };
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Stack stackBuffer [MAX_PLY + 10];
|
Stack stackBuffer [MAX_PLY + 10];
|
||||||
StateInfo statesBuffer[MAX_PLY + 10];
|
StateInfo statesBuffer[MAX_PLY + 10];
|
||||||
@@ -2472,25 +2537,25 @@ namespace Search
|
|||||||
Stack* stack = stackBuffer + 7;
|
Stack* stack = stackBuffer + 7;
|
||||||
StateInfo* states = statesBuffer + 7;
|
StateInfo* states = statesBuffer + 7;
|
||||||
|
|
||||||
MctsNode rootNode;
|
MCTSNode rootNode;
|
||||||
|
|
||||||
int ply = 1;
|
int ply = 1;
|
||||||
int maximumPly = ply; // Effectively the selective depth.
|
int maximumPly = ply; // Effectively the selective depth.
|
||||||
std::uint64_t numPlayouts = 0;
|
std::uint64_t numPlayouts = 0;
|
||||||
|
|
||||||
private :
|
private :
|
||||||
// IMPORTANT:
|
|
||||||
// The position is stateful so we always have one.
|
// reset_stats(), recalculate_stats() and accumulate_stats_recursively()
|
||||||
// It has to match certain expectations in different functions.
|
// are used to recalculate the number of playouts in our MCTS tree. Note
|
||||||
// For example when looking for the node to expand the pos must correspond
|
// that at the moment we call recalculate_stats() each time we play a move
|
||||||
// to the root mcts node. When expanding the node it must correspond to the
|
// at root, to recalculate the stats in the subtree.
|
||||||
// node being expanded, etc.
|
|
||||||
|
|
||||||
void reset_stats() {
|
void reset_stats() {
|
||||||
numPlayouts = 0;
|
numPlayouts = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void accumulate_stats_recursively(MctsNode& node) {
|
void accumulate_stats_recursively(MCTSNode& node) {
|
||||||
|
|
||||||
if (node.leafSearchDepth != DEPTH_NONE)
|
if (node.leafSearchDepth != DEPTH_NONE)
|
||||||
numPlayouts += 1;
|
numPlayouts += 1;
|
||||||
|
|
||||||
@@ -2504,13 +2569,18 @@ namespace Search
|
|||||||
accumulate_stats_recursively(rootNode);
|
accumulate_stats_recursively(rootNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tree reuse.
|
|
||||||
|
// do_move_at_root() makes the child reached by 'move' the new root node, or creates a new root when no such child exists
|
||||||
|
// Tree reuse (?)
|
||||||
// pos is the position after move.
|
// pos is the position after move.
|
||||||
|
|
||||||
void do_move_at_root(Position& pos, Move move) {
|
void do_move_at_root(Position& pos, Move move) {
|
||||||
MctsNode* child = rootNode.get_child_by_move(move);
|
|
||||||
|
MCTSNode* child = rootNode.get_child_by_move(move);
|
||||||
if (child == nullptr)
|
if (child == nullptr)
|
||||||
create_new_root(pos);
|
create_new_root(pos);
|
||||||
else {
|
else
|
||||||
|
{
|
||||||
rootNode = std::move(*child);
|
rootNode = std::move(*child);
|
||||||
rootNode.parent = nullptr;
|
rootNode.parent = nullptr;
|
||||||
rootNode.childId = 0;
|
rootNode.childId = 0;
|
||||||
@@ -2522,25 +2592,32 @@ namespace Search
|
|||||||
assert(rootNode.posKey == pos.key());
|
assert(rootNode.posKey == pos.key());
|
||||||
}
|
}
|
||||||
|
|
||||||
// One iteration of the search. Basically:
|
|
||||||
|
// do_search_iteration() does one iteration of the search.
|
||||||
|
//
|
||||||
|
// Basically:
|
||||||
// 1. find a node to expand/playout
|
// 1. find a node to expand/playout
|
||||||
// 2. if the node is a terminal then we just get the stuff and backprop
|
// 2. if the node is a terminal then we just get the stuff and backprop
|
||||||
// 3. if we only have prior for the node then do a playout
|
// 3. if we only have prior for the node then do a playout
|
||||||
// 4. otherwise we expand the children and do at least one playout from the best child (chosen by prior)
|
// 4. otherwise we expand the children and do at least one playout from the best child (chosen by prior)
|
||||||
// 4.1. a terminal node counts as a playout. All terminal nodes are played out.
|
// 4.1. a terminal node counts as a playout. All terminal nodes are played out.
|
||||||
// 5. Backpropagate all changes down the tree.
|
// 5. Backpropagate all changes down the tree.
|
||||||
|
|
||||||
void do_search_iteration(Position& pos, Depth leafDepth, float explorationFactor) {
|
void do_search_iteration(Position& pos, Depth leafDepth, float explorationFactor) {
|
||||||
MctsNode& node = find_node_to_expand_or_playout(pos, explorationFactor);
|
|
||||||
|
MCTSNode& node = find_node_to_expand_or_playout(pos, explorationFactor);
|
||||||
BackpropValues backprops{};
|
BackpropValues backprops{};
|
||||||
if (node.isTerminal) {
|
if (node.isTerminal)
|
||||||
debug_print("Root is terminal\n");
|
{
|
||||||
|
debug << "Root is terminal" << endl;
|
||||||
backprops.numVisits = 1;
|
backprops.numVisits = 1;
|
||||||
backprops.actionValue += node.actionValue;
|
backprops.actionValue += node.actionValue;
|
||||||
backprops.actionValueWeight += node.actionValueWeight;
|
backprops.actionValueWeight += node.actionValueWeight;
|
||||||
|
|
||||||
numPlayouts += 1;
|
numPlayouts += 1;
|
||||||
}
|
}
|
||||||
else if (node.leafSearchDepth == DEPTH_NONE) {
|
else if (node.leafSearchDepth == DEPTH_NONE)
|
||||||
|
{
|
||||||
// The node is considered the best but it only has a prior value.
|
// The node is considered the best but it only has a prior value.
|
||||||
// We don't really want to expand nodes based just on the prior, so
|
// We don't really want to expand nodes based just on the prior, so
|
||||||
// first do a playout to get a better estimate, and expand only in the
|
// first do a playout to get a better estimate, and expand only in the
|
||||||
@@ -2549,7 +2626,8 @@ namespace Search
|
|||||||
// playout can put it below another move.
|
// playout can put it below another move.
|
||||||
backprops = do_playout(pos, node, leafDepth);
|
backprops = do_playout(pos, node, leafDepth);
|
||||||
}
|
}
|
||||||
else {
|
else
|
||||||
|
{
|
||||||
// We have done leaf evaluation with AB search so we know that
|
// We have done leaf evaluation with AB search so we know that
|
||||||
// this node is *actually good* and not just *prior good*, so we
|
// this node is *actually good* and not just *prior good*, so we
|
||||||
// can now expand it and do an immediate playout for the node with the best prior.
|
// can now expand it and do an immediate playout for the node with the best prior.
|
||||||
@@ -2559,22 +2637,28 @@ namespace Search
|
|||||||
backpropagate(pos, node, backprops);
|
backpropagate(pos, node, backprops);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backpropagates the changes after an expand/playout all the way to the root.
|
|
||||||
// pos is expected to be at the node from which we start backpropagating.
|
// backpropagate() is the function we use to back-propagate the changes
|
||||||
void backpropagate(Position& pos, MctsNode& node, BackpropValues backprops) {
|
// after an expand/playout, all the way to the root. The position 'pos'
|
||||||
|
// is expected to be at the node from which we start backpropagating.
|
||||||
|
|
||||||
|
void backpropagate(Position& pos, MCTSNode& node, BackpropValues backprops) {
|
||||||
|
|
||||||
assert(node.posKey == pos.key());
|
assert(node.posKey == pos.key());
|
||||||
assert(ply >= 1);
|
assert(ply >= 1);
|
||||||
|
|
||||||
debug_print("Backpropagating: ", pos.fen(), '\n');
|
debug << "Backpropagating: " << pos.fen() << endl;
|
||||||
|
|
||||||
|
MCTSNode* currentNode = &node;
|
||||||
|
while (!currentNode->is_root())
|
||||||
|
{
|
||||||
|
// On each descent we switch the side to move
|
||||||
|
|
||||||
MctsNode* currentNode = &node;
|
|
||||||
while (!currentNode->is_root()) {
|
|
||||||
// On each descent we switch the side to move.
|
|
||||||
undo_move(pos);
|
undo_move(pos);
|
||||||
currentNode = currentNode->parent;
|
currentNode = currentNode->parent;
|
||||||
backprops.flip_side();
|
backprops.flip_side();
|
||||||
|
|
||||||
debug_print("Backprop step: ", pos.fen(), '\n');
|
debug << "Backprop step: " << pos.fen() << endl;
|
||||||
assert(currentNode->posKey == pos.key());
|
assert(currentNode->posKey == pos.key());
|
||||||
|
|
||||||
currentNode->numVisits += backprops.numVisits;
|
currentNode->numVisits += backprops.numVisits;
|
||||||
@@ -2582,19 +2666,24 @@ namespace Search
|
|||||||
currentNode->actionValueWeight += backprops.actionValueWeight;
|
currentNode->actionValueWeight += backprops.actionValueWeight;
|
||||||
}
|
}
|
||||||
|
|
||||||
// At the end we must be at the root.
|
// At the end we must be at the root
|
||||||
|
|
||||||
assert(currentNode == &rootNode);
|
assert(currentNode == &rootNode);
|
||||||
assert(rootNode.posKey == pos.key());
|
assert(rootNode.posKey == pos.key());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Navigate with pos to the node expand/playout
|
|
||||||
MctsNode& find_node_to_expand_or_playout(Position& pos, float explorationFactor) {
|
// find_node_to_expand_or_playout() navigates from pos to the node to expand/playout,
|
||||||
|
// according to the get_best_child() heuristics.
|
||||||
|
|
||||||
|
MCTSNode& find_node_to_expand_or_playout(Position& pos, float explorationFactor) {
|
||||||
assert(rootNode.posKey == pos.key());
|
assert(rootNode.posKey == pos.key());
|
||||||
|
|
||||||
// Find a node that has not yet been expanded
|
// Find a node that has not yet been expanded
|
||||||
MctsNode* currentNode = &rootNode;
|
MCTSNode* currentNode = &rootNode;
|
||||||
while (!currentNode->is_leaf()) {
|
while (!currentNode->is_leaf())
|
||||||
MctsNode& bestChild = currentNode->get_best_child(explorationFactor);
|
{
|
||||||
|
MCTSNode& bestChild = currentNode->get_best_child(explorationFactor);
|
||||||
|
|
||||||
do_move(pos, *currentNode, bestChild);
|
do_move(pos, *currentNode, bestChild);
|
||||||
|
|
||||||
@@ -2604,6 +2693,9 @@ namespace Search
|
|||||||
return *currentNode;
|
return *currentNode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// generate_moves_unordered() generates the legal moves in the order given by the move generator, without any sorting
|
||||||
|
|
||||||
int generate_moves_unordered(Position& pos, Move* out) const {
|
int generate_moves_unordered(Position& pos, Move* out) const {
|
||||||
int moveCount = 0;
|
int moveCount = 0;
|
||||||
for (auto move : MoveList<LEGAL>(pos))
|
for (auto move : MoveList<LEGAL>(pos))
|
||||||
@@ -2612,11 +2704,14 @@ namespace Search
|
|||||||
return moveCount;
|
return moveCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate moves with some reasonable ordering. Using this we can assume some reasonable priors.
|
|
||||||
int generate_moves_ordered(Position& pos, MctsNode& node, Depth leafDepth, Move* out) const {
|
// generate_moves_ordered() generates moves with some reasonable ordering.
|
||||||
|
// Using this function, we can assume some reasonable priors.
|
||||||
|
|
||||||
|
int generate_moves_ordered(Position& pos, MCTSNode& node, Depth leafDepth, Move* out) const {
|
||||||
assert(ply >= 1);
|
assert(ply >= 1);
|
||||||
|
|
||||||
debug_print("Generating moves: ", pos.fen(), '\n');
|
debug << "Generating moves: " << pos.fen() << endl;
|
||||||
|
|
||||||
Thread* const thread = pos.this_thread();
|
Thread* const thread = pos.this_thread();
|
||||||
const Square prevSq = to_sq(node.prevMove);
|
const Square prevSq = to_sq(node.prevMove);
|
||||||
@@ -2650,9 +2745,10 @@ namespace Search
|
|||||||
);
|
);
|
||||||
|
|
||||||
int moveCount = 0;
|
int moveCount = 0;
|
||||||
for (;;) {
|
while (true)
|
||||||
|
{
|
||||||
const Move move = mp.next_move();
|
const Move move = mp.next_move();
|
||||||
debug_print("Generated move ", UCI::move(move, false), ": ", pos.fen(), '\n');
|
debug << "Generated move " << UCI::move(move, false) << ": " << pos.fen() << endl;
|
||||||
|
|
||||||
if (move == MOVE_NONE)
|
if (move == MOVE_NONE)
|
||||||
break;
|
break;
|
||||||
@@ -2661,12 +2757,19 @@ namespace Search
|
|||||||
out[moveCount++] = move;
|
out[moveCount++] = move;
|
||||||
}
|
}
|
||||||
|
|
||||||
debug_print("Generated ", moveCount, " legal moves: ", pos.fen(), '\n');
|
debug << "Generated " << moveCount << " legal moves: " << pos.fen() << endl;
|
||||||
|
|
||||||
return moveCount;
|
return moveCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// init_for_leaf_search() prepares some global variables in the thread of the
|
||||||
|
// given position, for compatibility with the normal AB search of Stockfish.
|
||||||
|
// This allows us to use that AB search to get an estimated value of the leaf,
|
||||||
|
// if necessary.
|
||||||
|
|
||||||
void init_for_leaf_search(Position& pos) {
|
void init_for_leaf_search(Position& pos) {
|
||||||
|
|
||||||
auto th = pos.this_thread();
|
auto th = pos.this_thread();
|
||||||
|
|
||||||
th->completedDepth = 0;
|
th->completedDepth = 0;
|
||||||
@@ -2677,9 +2780,12 @@ namespace Search
|
|||||||
th->nodes = 0;
|
th->nodes = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks whether the position is terminal and return the right value if it is.
|
|
||||||
// Otherwise returns VALUE_NONE
|
// terminal_value() checks whether the position is terminal. We return
|
||||||
|
// the right value if position is terminal, otherwise we return VALUE_NONE.
|
||||||
|
|
||||||
Value terminal_value(Position& pos) const {
|
Value terminal_value(Position& pos) const {
|
||||||
|
|
||||||
if (MoveList<LEGAL>(pos).size() == 0)
|
if (MoveList<LEGAL>(pos).size() == 0)
|
||||||
return pos.checkers() ? VALUE_MATE : -VALUE_MATE;;
|
return pos.checkers() ? VALUE_MATE : -VALUE_MATE;;
|
||||||
|
|
||||||
@@ -2689,13 +2795,16 @@ namespace Search
|
|||||||
return VALUE_NONE;
|
return VALUE_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Does AB search on the position to get the value.
|
|
||||||
Value evaluate_leaf(Position& pos, MctsNode& node, Depth leafDepth) {
|
// evaluate_leaf() does AB search on the position to get its value
|
||||||
|
|
||||||
|
Value evaluate_leaf(Position& pos, MCTSNode& node, Depth leafDepth) {
|
||||||
|
|
||||||
assert(node.posKey == pos.key());
|
assert(node.posKey == pos.key());
|
||||||
assert(node.leafSearchDepth == DEPTH_NONE);
|
assert(node.leafSearchDepth == DEPTH_NONE);
|
||||||
assert(node.leafSearchEval == VALUE_NONE);
|
assert(node.leafSearchEval == VALUE_NONE);
|
||||||
|
|
||||||
debug_print("Evaluating leaf: ", pos.fen(), '\n');
|
debug << "Evaluating leaf: " << pos.fen() << endl;
|
||||||
|
|
||||||
init_for_leaf_search(pos);
|
init_for_leaf_search(pos);
|
||||||
|
|
||||||
@@ -2704,19 +2813,23 @@ namespace Search
|
|||||||
stack[ply].currentMove = MOVE_NONE;
|
stack[ply].currentMove = MOVE_NONE;
|
||||||
stack[ply].excludedMove = MOVE_NONE;
|
stack[ply].excludedMove = MOVE_NONE;
|
||||||
|
|
||||||
if (!node.is_root() && node.parent->leafSearchEval != VALUE_NONE) {
|
if (!node.is_root() && node.parent->leafSearchEval != VALUE_NONE)
|
||||||
|
{
|
||||||
// If we have some parent score then use an aspiration window.
|
// If we have some parent score then use an aspiration window.
|
||||||
// We know what to expect.
|
// We know what to expect.
|
||||||
Value delta = Value(18);
|
Value delta = Value(18);
|
||||||
Value alpha = std::max(node.parent->leafSearchEval - delta, -VALUE_INFINITE);
|
Value alpha = std::max(node.parent->leafSearchEval - delta, -VALUE_INFINITE);
|
||||||
Value beta = std::min(node.parent->leafSearchEval + delta, VALUE_INFINITE);
|
Value beta = std::min(node.parent->leafSearchEval + delta, VALUE_INFINITE);
|
||||||
for (;;) {
|
while (true)
|
||||||
|
{
|
||||||
const Value value = Stockfish::search<PV>(pos, stack + ply, alpha, beta, leafDepth, false);
|
const Value value = Stockfish::search<PV>(pos, stack + ply, alpha, beta, leafDepth, false);
|
||||||
if (value <= alpha) {
|
if (value <= alpha)
|
||||||
|
{
|
||||||
beta = (alpha + beta) / 2;
|
beta = (alpha + beta) / 2;
|
||||||
alpha = std::max(value - delta, -VALUE_INFINITE);
|
alpha = std::max(value - delta, -VALUE_INFINITE);
|
||||||
}
|
}
|
||||||
else if (value >= beta)
|
else
|
||||||
|
if (value >= beta)
|
||||||
beta = std::min(value + delta, VALUE_INFINITE);
|
beta = std::min(value + delta, VALUE_INFINITE);
|
||||||
else
|
else
|
||||||
return value;
|
return value;
|
||||||
@@ -2724,22 +2837,26 @@ namespace Search
|
|||||||
delta += delta / 4 + 5;
|
delta += delta / 4 + 5;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
|
else
|
||||||
// If no parent score then do infinite aspiration window.
|
// If no parent score then do infinite aspiration window.
|
||||||
return Stockfish::search<PV>(pos, stack + ply, -VALUE_INFINITE, VALUE_INFINITE, leafDepth, false);
|
return Stockfish::search<PV>(pos, stack + ply, -VALUE_INFINITE, VALUE_INFINITE, leafDepth, false);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<Move> get_pv(const MctsNode& node) const {
|
|
||||||
|
// get_pv(node) tries to get a pv, starting from the given node
|
||||||
|
|
||||||
|
std::vector<Move> get_pv(const MCTSNode& node) const {
|
||||||
std::vector<Move> pv;
|
std::vector<Move> pv;
|
||||||
|
|
||||||
const MctsNode* currentNode = &node;
|
const MCTSNode* currentNode = &node;
|
||||||
if (!currentNode->is_root())
|
if (!currentNode->is_root())
|
||||||
pv.emplace_back(currentNode->prevMove);
|
pv.emplace_back(currentNode->prevMove);
|
||||||
|
|
||||||
while (!currentNode->is_leaf()) {
|
while (!currentNode->is_leaf())
|
||||||
|
{
|
||||||
// No exploration factor for choosing the PV.
|
// No exploration factor for choosing the PV.
|
||||||
const MctsNode& bestChild = currentNode->get_best_child(0.0f);
|
const MCTSNode& bestChild = currentNode->get_best_child(0.0f);
|
||||||
pv.emplace_back(bestChild.prevMove);
|
pv.emplace_back(bestChild.prevMove);
|
||||||
currentNode = &bestChild;
|
currentNode = &bestChild;
|
||||||
}
|
}
|
||||||
@@ -2747,12 +2864,18 @@ namespace Search
|
|||||||
return pv;
|
return pv;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// get_pv() tries to get the pv, starting from the root
|
||||||
|
|
||||||
std::vector<Move> get_pv() const {
|
std::vector<Move> get_pv() const {
|
||||||
return get_pv(rootNode);
|
return get_pv(rootNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Does a single playout and returns what is needed to backprop.
|
|
||||||
BackpropValues do_playout(Position& pos, MctsNode& node, Depth leafDepth) {
|
// do_playout() does a single playout and returns what is needed to backprop
|
||||||
|
|
||||||
|
BackpropValues do_playout(Position& pos, MCTSNode& node, Depth leafDepth) {
|
||||||
|
|
||||||
assert(node.posKey == pos.key());
|
assert(node.posKey == pos.key());
|
||||||
assert(node.numVisits == 0);
|
assert(node.numVisits == 0);
|
||||||
assert(node.is_leaf());
|
assert(node.is_leaf());
|
||||||
@@ -2760,7 +2883,7 @@ namespace Search
|
|||||||
assert(node.leafSearchDepth == DEPTH_NONE);
|
assert(node.leafSearchDepth == DEPTH_NONE);
|
||||||
assert(!node.isTerminal);
|
assert(!node.isTerminal);
|
||||||
|
|
||||||
debug_print("Doing playout ", numPlayouts, ": ", pos.fen(), '\n');
|
debug << "Doing playout " << numPlayouts << ": " << pos.fen() << endl;
|
||||||
|
|
||||||
numPlayouts += 1;
|
numPlayouts += 1;
|
||||||
|
|
||||||
@@ -2771,12 +2894,12 @@ namespace Search
|
|||||||
backprops.actionValue += value_to_reward(v);
|
backprops.actionValue += value_to_reward(v);
|
||||||
backprops.actionValueWeight += normalWeight;
|
backprops.actionValueWeight += normalWeight;
|
||||||
|
|
||||||
// Bookkeep raw eval.
|
// Bookkeeping for raw eval
|
||||||
node.leafSearchEval = v;
|
node.leafSearchEval = v;
|
||||||
node.leafSearchDepth = leafDepth;
|
node.leafSearchDepth = leafDepth;
|
||||||
|
|
||||||
// Local backprop because normal backprop handles only the
|
// Local backprop because normal backprop handles only the
|
||||||
// nodes starting from the parent of this one
|
// nodes starting from the parent of this one.
|
||||||
node.numVisits = backprops.numVisits;
|
node.numVisits = backprops.numVisits;
|
||||||
node.actionValue += backprops.actionValue;
|
node.actionValue += backprops.actionValue;
|
||||||
node.actionValueWeight += backprops.actionValueWeight;
|
node.actionValueWeight += backprops.actionValueWeight;
|
||||||
@@ -2784,56 +2907,60 @@ namespace Search
|
|||||||
return backprops;
|
return backprops;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expand a node and do at least one playout.
|
|
||||||
|
// expand_node_and_do_playout() expands a node and does at least one playout.
|
||||||
// May do more "playouts" if there are terminals as those are "played out" immediately.
|
// May do more "playouts" if there are terminals as those are "played out" immediately.
|
||||||
// Returns what needs to be backpropagated.
|
// Returns what needs to be backpropagated.
|
||||||
BackpropValues expand_node_and_do_playout(
|
|
||||||
Position& pos,
|
|
||||||
MctsNode& node,
|
|
||||||
Depth leafDepth,
|
|
||||||
float explorationFactor) {
|
|
||||||
|
|
||||||
|
BackpropValues expand_node_and_do_playout( Position& pos,
|
||||||
|
MCTSNode& node,
|
||||||
|
Depth leafDepth,
|
||||||
|
float explorationFactor)
|
||||||
|
{
|
||||||
assert(node.posKey == pos.key()); // node must match the position
|
assert(node.posKey == pos.key()); // node must match the position
|
||||||
assert(node.is_leaf()); // otherwise already expanded
|
assert(node.is_leaf()); // otherwise already expanded
|
||||||
assert(node.numChildren == 0);
|
assert(node.numChildren == 0); // leafs have no children
|
||||||
assert(!node.isTerminal); // terminals cannot be expanded
|
assert(!node.isTerminal); // terminals cannot be expanded
|
||||||
assert(node.numVisits == 1); // we expect it to have the "playout visit". Fake visit for the root.
|
assert(node.numVisits == 1); // we expect it to have the "playout visit". Fake visit for the root.
|
||||||
assert(node.leafSearchDepth != DEPTH_NONE);
|
assert(node.leafSearchDepth != DEPTH_NONE);
|
||||||
assert(node.leafSearchEval != VALUE_NONE);
|
assert(node.leafSearchEval != VALUE_NONE);
|
||||||
|
|
||||||
debug_print("Expanding and playing out: ", pos.fen(), '\n');
|
debug << "Expanding and playing out: " << pos.fen() << endl;
|
||||||
|
|
||||||
Move moves[MAX_MOVES];
|
Move moves[MAX_MOVES];
|
||||||
const int moveCount = generate_moves_ordered(pos, node, leafDepth, moves);
|
const int moveCount = generate_moves_ordered(pos, node, leafDepth, moves);
|
||||||
|
|
||||||
assert(moveCount > 0);
|
assert(moveCount > 0);
|
||||||
|
|
||||||
node.children = std::make_unique<MctsNode[]>(moveCount);
|
node.children = std::make_unique<MCTSNode[]>(moveCount);
|
||||||
node.numChildren = moveCount;
|
node.numChildren = moveCount;
|
||||||
|
|
||||||
int numTerminals = 0;
|
int numTerminals = 0;
|
||||||
BackpropValues backprops{};
|
BackpropValues backprops{};
|
||||||
|
|
||||||
float prior = value_to_reward(node.leafSearchEval);
|
float prior = value_to_reward(node.leafSearchEval);
|
||||||
// Prior is attenuated for later moves - we rely on move ordering.
|
|
||||||
// Attenuate more at higher plies where we have more move ordering.
|
// Note that prior is attenuated for later moves - we rely on move ordering.
|
||||||
|
// Attenuate more at higher plies, where we have better move ordering.
|
||||||
|
|
||||||
const float priorAttenuation = 1.0f - std::min((ply - 1) / 100.0f, 0.05f);
|
const float priorAttenuation = 1.0f - std::min((ply - 1) / 100.0f, 0.05f);
|
||||||
for (int i = 0; i < moveCount; ++i) {
|
for (int i = 0; i < moveCount; ++i)
|
||||||
// setup the child
|
{
|
||||||
MctsNode& child = node.children[i];
|
// Setup the child
|
||||||
|
MCTSNode& child = node.children[i];
|
||||||
child.prevMove = moves[i];
|
child.prevMove = moves[i];
|
||||||
child.childId = i;
|
child.childId = i;
|
||||||
child.parent = &node;
|
child.parent = &node;
|
||||||
|
|
||||||
{
|
debug << "Expanding move " << i+1 << " out of " << moveCount << ": " << pos.fen() << endl;
|
||||||
debug_print("Expanding move ", i+1, " out of ", moveCount, ": ", pos.fen(), '\n');
|
|
||||||
|
|
||||||
// we enter the child's position
|
// We enter the child's position
|
||||||
do_move(pos, node, child);
|
do_move(pos, node, child);
|
||||||
child.posKey = pos.key();
|
child.posKey = pos.key();
|
||||||
|
|
||||||
const Value terminalValue = terminal_value(pos);
|
const Value terminalValue = terminal_value(pos);
|
||||||
if (terminalValue != VALUE_NONE) {
|
if (terminalValue != VALUE_NONE)
|
||||||
|
{
|
||||||
// if it's a terminal then "play it out"
|
// If it's a terminal then "play it out"
|
||||||
child.isTerminal = true;
|
child.isTerminal = true;
|
||||||
child.prior = value_to_reward(terminalValue);
|
child.prior = value_to_reward(terminalValue);
|
||||||
@@ -2846,27 +2973,28 @@ namespace Search
|
|||||||
numTerminals += 1;
|
numTerminals += 1;
|
||||||
numPlayouts += 1;
|
numPlayouts += 1;
|
||||||
}
|
}
|
||||||
else {
|
else
|
||||||
// otherwise we just note the prior (policy)
|
{
|
||||||
|
// Otherwise we just note the prior (policy)
|
||||||
child.prior = 1.0f - prior;
|
child.prior = 1.0f - prior;
|
||||||
child.actionValue = child.prior * priorWeight;
|
child.actionValue = child.prior * priorWeight;
|
||||||
child.actionValueWeight = priorWeight;
|
child.actionValueWeight = priorWeight;
|
||||||
}
|
}
|
||||||
|
|
||||||
undo_move(pos);
|
undo_move(pos);
|
||||||
}
|
|
||||||
|
|
||||||
// accumulate the policies to backprop.
|
// Accumulate the policies to backprop
|
||||||
backprops.actionValue += child.actionValue;
|
backprops.actionValue += child.actionValue;
|
||||||
backprops.actionValueWeight += child.actionValueWeight;
|
backprops.actionValueWeight += child.actionValueWeight;
|
||||||
|
|
||||||
// Reduce the prior for the next move.
|
// Reduce the prior for the next move
|
||||||
prior *= priorAttenuation;
|
prior *= priorAttenuation;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (numTerminals == 0) {
|
if (numTerminals == 0)
|
||||||
// If no terminals then we do one playout on the best child.
|
{
|
||||||
MctsNode& bestChild = node.get_best_child(explorationFactor);
|
// If no terminals then we do one playout on the best child
|
||||||
|
MCTSNode& bestChild = node.get_best_child(explorationFactor);
|
||||||
do_move(pos, node, bestChild);
|
do_move(pos, node, bestChild);
|
||||||
|
|
||||||
backprops.numVisits += 1;
|
backprops.numVisits += 1;
|
||||||
@@ -2878,7 +3006,8 @@ namespace Search
|
|||||||
|
|
||||||
undo_move(pos);
|
undo_move(pos);
|
||||||
}
|
}
|
||||||
else {
|
else
|
||||||
|
{
|
||||||
// If there are any terminals we don't do more playouts
|
// If there are any terminals we don't do more playouts
|
||||||
backprops.numVisits += numTerminals;
|
backprops.numVisits += numTerminals;
|
||||||
}
|
}
|
||||||
@@ -2894,8 +3023,11 @@ namespace Search
|
|||||||
return backprops;
|
return backprops;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Does a move and updates the stack.
|
|
||||||
void do_move(Position& pos, MctsNode& parentNode, MctsNode& childNode) {
|
// do_move() does a move and updates the stack
|
||||||
|
|
||||||
|
void do_move(Position& pos, MCTSNode& parentNode, MCTSNode& childNode) {
|
||||||
|
|
||||||
assert(ply < MAX_PLY);
|
assert(ply < MAX_PLY);
|
||||||
assert(!parentNode.is_leaf());
|
assert(!parentNode.is_leaf());
|
||||||
assert(&parentNode.children[childNode.childId] == &childNode);
|
assert(&parentNode.children[childNode.childId] == &childNode);
|
||||||
@@ -2923,11 +3055,14 @@ namespace Search
|
|||||||
assert(childNode.posKey == 0 || childNode.posKey == pos.key());
|
assert(childNode.posKey == 0 || childNode.posKey == pos.key());
|
||||||
|
|
||||||
ply += 1;
|
ply += 1;
|
||||||
|
|
||||||
if (ply > maximumPly)
|
if (ply > maximumPly)
|
||||||
maximumPly = ply;
|
maximumPly = ply;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Undoes a move, pops the stack.
|
|
||||||
|
// undo_move() undoes a move and pops the stack
|
||||||
|
|
||||||
void undo_move(Position& pos) {
|
void undo_move(Position& pos) {
|
||||||
assert(ply > 1);
|
assert(ply > 1);
|
||||||
|
|
||||||
@@ -2936,12 +3071,16 @@ namespace Search
|
|||||||
pos.undo_move(stack[ply].currentMove);
|
pos.undo_move(stack[ply].currentMove);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// create_new_root() initializes a fresh root node from the given position
|
||||||
|
|
||||||
void create_new_root(Position& pos) {
|
void create_new_root(Position& pos) {
|
||||||
rootNode = MctsNode{};
|
rootNode = MCTSNode{};
|
||||||
rootNode.posKey = pos.key();
|
rootNode.posKey = pos.key();
|
||||||
rootNode.isTerminal = MoveList<LEGAL>(pos).size() == 0;
|
rootNode.isTerminal = MoveList<LEGAL>(pos).size() == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void init_for_mcts_search(Position& pos) {
|
void init_for_mcts_search(Position& pos) {
|
||||||
std::memset(stack - 7, 0, 10 * sizeof(Stack));
|
std::memset(stack - 7, 0, 10 * sizeof(Stack));
|
||||||
|
|
||||||
@@ -2966,10 +3105,8 @@ namespace Search
|
|||||||
: ct;
|
: ct;
|
||||||
|
|
||||||
// Evaluation score is from the white point of view
|
// Evaluation score is from the white point of view
|
||||||
th->contempt =
|
th->contempt = (us == WHITE ? make_score(ct, ct / 2)
|
||||||
us == WHITE
|
: -make_score(ct, ct / 2));
|
||||||
? make_score(ct, ct / 2)
|
|
||||||
: -make_score(ct, ct / 2);
|
|
||||||
|
|
||||||
create_new_root(pos);
|
create_new_root(pos);
|
||||||
|
|
||||||
@@ -2979,24 +3116,29 @@ namespace Search
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
ValueAndPV search_mcts(
|
|
||||||
Position& pos,
|
|
||||||
std::uint64_t numPlayouts,
|
|
||||||
Depth leafDepth,
|
|
||||||
float explorationFactor) {
|
|
||||||
|
|
||||||
|
// search_mcts() is the main entry point: it creates a fresh MonteCarloTreeSearch and returns the result of search_new()
|
||||||
|
|
||||||
|
ValueAndPV search_mcts( Position& pos,
|
||||||
|
uint64_t numPlayouts,
|
||||||
|
Depth leafDepth,
|
||||||
|
float explorationFactor)
|
||||||
|
{
|
||||||
MonteCarloTreeSearch mcts{};
|
MonteCarloTreeSearch mcts{};
|
||||||
return mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor);
|
return mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<MctsContinuation> search_mcts_multipv(
|
|
||||||
Position& pos,
|
|
||||||
std::uint64_t numPlayouts,
|
|
||||||
Depth leafDepth,
|
|
||||||
float explorationFactor) {
|
|
||||||
|
|
||||||
|
// search_mcts_multipv() runs a fresh MonteCarloTreeSearch and returns all its continuations (use this for multiPV)
|
||||||
|
|
||||||
|
std::vector<MctsContinuation> search_mcts_multipv( Position& pos,
|
||||||
|
uint64_t numPlayouts,
|
||||||
|
Depth leafDepth,
|
||||||
|
float explorationFactor)
|
||||||
|
{
|
||||||
MonteCarloTreeSearch mcts{};
|
MonteCarloTreeSearch mcts{};
|
||||||
mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor);
|
mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor);
|
||||||
|
|
||||||
return mcts.get_all_continuations();
|
return mcts.get_all_continuations();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user