mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 13:17:44 +00:00
Add primitive MCTS search.
This commit is contained in:
committed by
Stéphane Nicolet
parent
a44b1115c4
commit
9094255f50
+800
@@ -2200,6 +2200,806 @@ namespace Search
|
||||
return ValueAndPV(bestValue, pvs);
|
||||
}
|
||||
|
||||
namespace MCTS {
|
||||
/*
|
||||
The implementation of the MCTS is heavily based on Stephane Nicolet's work here
|
||||
https://github.com/snicolet/Stockfish/commit/28501872a1e7ce84dd1f38ab9e59c5adb0d24b41
|
||||
and the adjusted implementation of it in ShashChess https://github.com/amchess/ShashChess
|
||||
*/
|
||||
static constexpr float sigmoidScale = 600.0f;
|
||||
|
||||
static inline float fast_sigmoid(float x) {
|
||||
bool negative = x < 0.0f;
|
||||
if (negative)
|
||||
x = -x;
|
||||
const float xx = x*x;
|
||||
const float v = 1.0f / (1.0f + 1.0f / (1.0f + x + xx*(0.555f + xx*0.143f)));
|
||||
if (negative)
|
||||
return 1.0f - v;
|
||||
else
|
||||
return v;
|
||||
}
|
||||
|
||||
static inline Value reward_to_value(float r) {
|
||||
if (r > 0.99f) return VALUE_KNOWN_WIN;
|
||||
if (r < 0.01f) return -VALUE_KNOWN_WIN;
|
||||
|
||||
return Value(-sigmoidScale * std::log(1.0f/r - 1.0f));
|
||||
}
|
||||
|
||||
static inline float value_to_reward(Value v) {
|
||||
return fast_sigmoid(static_cast<float>(v) * (1.0f / sigmoidScale));
|
||||
}
|
||||
|
||||
struct MctsNode {
|
||||
Key posKey = 0; // for consistency checks.
|
||||
MctsNode* parent = nullptr; // only nullptr for the root node
|
||||
std::unique_ptr<MctsNode[]> children = nullptr; // only nullptr for nodes that have not been expanded
|
||||
std::uint64_t numVisits = 0; // the number of playouts for this node and all descendants
|
||||
Value leafSearchEval = VALUE_NONE; // the evaluation from AB playout
|
||||
float prior = 0.0f; // the policy, currently a rough estimation based on the playout of the parent
|
||||
float actionValue = 0.0f; // the accumulated rewards
|
||||
float actionValueWeight = 0.0f; // the maximum value for the accumulater rewards
|
||||
Move prevMove = MOVE_NONE; // the move on the edge from the parent
|
||||
std::uint16_t numChildren = 0; // the number of legal moves, filled on expansion
|
||||
std::uint16_t childId = 0; // the index of this node in the parent's children array
|
||||
Depth leafSearchDepth = DEPTH_NONE; // the depth with which the AB playout was done
|
||||
bool isTerminal = false; // whether the node is terminal. Terminal nodes are always "expanded" immediately.
|
||||
|
||||
// The upper confidence bound. When searching for the node to expand/playout we take
|
||||
// one with the highest ucb.
|
||||
float ucb_value(MctsNode& child, float explorationFactor, bool flipPerspective = false) const {
|
||||
assert(explorationFactor >= 0.0f);
|
||||
assert(child.actionValue >= 0.0f);
|
||||
assert(child.actionValueWeight >= 0.0f);
|
||||
assert(child.actionValue <= child.actionValueWeight);
|
||||
assert(child.prior >= 0.0f);
|
||||
assert(child.prior <= 1.0f);
|
||||
|
||||
// For the nodes which have not been played-out we use the prior.
|
||||
// Otherwise we have some averaged score or the eval already.
|
||||
float reward =
|
||||
child.numVisits == 0
|
||||
? child.prior
|
||||
: child.actionValue / child.actionValueWeight;
|
||||
|
||||
if (flipPerspective)
|
||||
reward = 1.0f - reward;
|
||||
|
||||
// The exploration factor.
|
||||
// In theory unplayed nodes should have priority, but we
|
||||
// add 1 to avoid div by 0 so they might not always be prioritized.
|
||||
if (explorationFactor != 0.0f)
|
||||
reward +=
|
||||
explorationFactor
|
||||
* std::sqrt(std::log(static_cast<float>(1 + numVisits))
|
||||
/ static_cast<float>(1 + child.numVisits));
|
||||
|
||||
assert(!std::isnan(reward));
|
||||
assert(reward >= 0.0f);
|
||||
|
||||
return reward;
|
||||
}
|
||||
|
||||
// Returns the reference to the best child node.
|
||||
const MctsNode& get_best_child(float explorationFactor) const {
|
||||
assert(!is_leaf());
|
||||
assert(numChildren > 0);
|
||||
|
||||
if (numChildren == 1) {
|
||||
assert(children[0].childId == 0);
|
||||
return children[0];
|
||||
}
|
||||
|
||||
int bestIdx = -1;
|
||||
float bestValue = std::numeric_limits<float>::lowest();
|
||||
for (int i = 0; i < static_cast<int>(numChildren); ++i) {
|
||||
MctsNode& child = children[i];
|
||||
// The "best" is the one with the best UCB.
|
||||
// Child values are with opposite signs.
|
||||
const float r = ucb_value(child, explorationFactor, true);
|
||||
if (r > bestValue) {
|
||||
bestIdx = i;
|
||||
bestValue = r;
|
||||
}
|
||||
}
|
||||
|
||||
assert(bestIdx >= 0);
|
||||
assert(bestIdx < numChildren);
|
||||
assert(children[bestIdx].childId == bestIdx);
|
||||
|
||||
return children[bestIdx];
|
||||
}
|
||||
|
||||
MctsNode& get_best_child(float explorationFactor) {
|
||||
return const_cast<MctsNode&>(static_cast<const MctsNode*>(this)->get_best_child(explorationFactor));
|
||||
}
|
||||
|
||||
std::pair<Move, Value> get_best_move() const {
|
||||
assert(!is_leaf());
|
||||
assert(numChildren > 0);
|
||||
|
||||
int bestIdx = -1;
|
||||
float bestValue = std::numeric_limits<float>::lowest();
|
||||
for (int i = 0; i < static_cast<int>(numChildren); ++i) {
|
||||
MctsNode& child = children[i];
|
||||
// The "best" is the one with the best UCB.
|
||||
// Child values are with opposite signs.
|
||||
const float r = 1.0f - (child.actionValue / child.actionValueWeight);
|
||||
if (r > bestValue) {
|
||||
bestIdx = i;
|
||||
bestValue = r;
|
||||
}
|
||||
}
|
||||
|
||||
assert(bestIdx >= 0);
|
||||
assert(bestIdx < numChildren);
|
||||
assert(children[bestIdx].childId == bestIdx);
|
||||
|
||||
return { children[bestIdx].prevMove, reward_to_value(bestValue) };
|
||||
}
|
||||
|
||||
const MctsNode* get_child_by_move(Move move) const {
|
||||
for (int i = 0; i < static_cast<int>(numChildren); ++i) {
|
||||
MctsNode& child = children[i];
|
||||
if (child.prevMove == move)
|
||||
return &child;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
MctsNode* get_child_by_move(Move move) {
|
||||
return const_cast<MctsNode*>(static_cast<const MctsNode*>(this)->get_child_by_move(move));
|
||||
}
|
||||
|
||||
bool is_root() const {
|
||||
return parent == nullptr;
|
||||
}
|
||||
|
||||
bool is_leaf() const {
|
||||
return children == nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// The stuff that needs to be backpropagated down the tree.
|
||||
struct BackpropValues {
|
||||
std::uint64_t numVisits = 0;
|
||||
float actionValue = 0.0f;
|
||||
float actionValueWeight = 0.0f;
|
||||
|
||||
// We always keep everything for the side to move perspective.
|
||||
// When changing the side the score flips.
|
||||
void flip_side() {
|
||||
assert(actionValueWeight >= actionValue);
|
||||
assert(actionValue >= 0.0f);
|
||||
assert(actionValueWeight >= 0.0f);
|
||||
|
||||
actionValue = actionValueWeight - actionValue;
|
||||
}
|
||||
};
|
||||
|
||||
struct MonteCarloTreeSearch {
|
||||
static constexpr Depth terminalEvalDepth = Depth(255);
|
||||
|
||||
// We add a lot of stuff to the actionValue, but the weights differ.
|
||||
// The prior is currently bad so low weight,
|
||||
static constexpr float priorWeight = 0.01f;
|
||||
static constexpr float terminalWeight = 1.0f; // could be increased? Different for wins/draws?
|
||||
static constexpr float normalWeight = 1.0f;
|
||||
|
||||
static_assert(priorWeight > 0.0f);
|
||||
static_assert(terminalWeight > 0.0f);
|
||||
static_assert(normalWeight > 0.0f);
|
||||
|
||||
MonteCarloTreeSearch() {}
|
||||
|
||||
MonteCarloTreeSearch(const MonteCarloTreeSearch&) = delete;
|
||||
|
||||
ValueAndPV search_new(
|
||||
Position& pos,
|
||||
std::uint64_t maxPlayouts,
|
||||
Depth leafDepth,
|
||||
float explorationFactor = 0.25f) {
|
||||
|
||||
init_for_mcts_search(pos);
|
||||
return search_continue(pos, maxPlayouts, leafDepth, explorationFactor);
|
||||
}
|
||||
|
||||
// Continue after a move and reuse the relevant part of the tree.
|
||||
// The prevMove is the move that lead to pos
|
||||
// TODO: Make the node limit be the total.
|
||||
ValueAndPV search_continue_after_move(
|
||||
Position& pos,
|
||||
Move prevMove,
|
||||
std::uint64_t maxPlayouts,
|
||||
Depth leafDepth,
|
||||
float explorationFactor = 0.25f) {
|
||||
|
||||
do_move_at_root(pos, prevMove);
|
||||
return search_continue(pos, maxPlayouts, leafDepth, explorationFactor);
|
||||
}
|
||||
|
||||
std::vector<MctsContinuation> get_all_continuations() const {
|
||||
std::vector<MctsContinuation> continuations;
|
||||
continuations.resize(rootNode.numChildren);
|
||||
|
||||
for (int i = 0; i < rootNode.numChildren; ++i) {
|
||||
MctsNode& child = rootNode.children[i];
|
||||
|
||||
auto& cont = continuations[i];
|
||||
cont.numVisits = child.numVisits;
|
||||
// child value is with opposite sign
|
||||
cont.actionValue = 1.0f - (child.actionValue / child.actionValueWeight);
|
||||
cont.value = reward_to_value(cont.actionValue);
|
||||
cont.pv = get_pv(child);
|
||||
}
|
||||
|
||||
std::stable_sort(
|
||||
continuations.begin(),
|
||||
continuations.end(),
|
||||
[](const auto& lhs, const auto& rhs) { return lhs.value > rhs.value; }
|
||||
);
|
||||
|
||||
return continuations;
|
||||
}
|
||||
|
||||
// Continues with the same tree.
|
||||
ValueAndPV search_continue(
|
||||
Position& pos,
|
||||
std::uint64_t maxPlayouts,
|
||||
Depth leafDepth,
|
||||
float explorationFactor = 0.25f) {
|
||||
|
||||
if (rootNode.leafSearchDepth == DEPTH_NONE)
|
||||
do_playout(pos, rootNode, leafDepth);
|
||||
|
||||
while (numPlayouts < maxPlayouts) {
|
||||
debug_print("Starting iteration ", numPlayouts, '\n');
|
||||
do_search_iteration(pos, leafDepth, explorationFactor);
|
||||
}
|
||||
|
||||
if (rootNode.is_leaf())
|
||||
return {};
|
||||
else {
|
||||
return { rootNode.get_best_move().second, get_pv() };
|
||||
}
|
||||
}
|
||||
|
||||
Stack stackBuffer[MAX_PLY + 10];
|
||||
StateInfo statesBuffer[MAX_PLY + 10];
|
||||
|
||||
Stack* stack = stackBuffer + 7;
|
||||
StateInfo* states = statesBuffer + 7;
|
||||
|
||||
MctsNode rootNode;
|
||||
|
||||
int ply = 1;
|
||||
int maximumPly = ply; // Effectively the selective depth.
|
||||
std::uint64_t numPlayouts = 0;
|
||||
|
||||
private:
|
||||
// IMPORTANT:
|
||||
// The position is stateful so we always have one.
|
||||
// It has to match certain expectations in different functions.
|
||||
// For example when looking for the node to expand the pos must correspond
|
||||
// to the root mcts node. When expanding the node it must correspond to the
|
||||
// node being expanded, etc.
|
||||
|
||||
void reset_stats() {
|
||||
numPlayouts = 0;
|
||||
}
|
||||
|
||||
void accumulate_stats_recursively(MctsNode& node) {
|
||||
if (node.leafSearchDepth != DEPTH_NONE)
|
||||
numPlayouts += 1;
|
||||
|
||||
if (!node.is_leaf())
|
||||
for (int i = 0; i < node.numChildren; ++i)
|
||||
accumulate_stats_recursively(node.children[i]);
|
||||
}
|
||||
|
||||
void recalculate_stats() {
|
||||
reset_stats();
|
||||
accumulate_stats_recursively(rootNode);
|
||||
}
|
||||
|
||||
// Tree reuse.
|
||||
// pos is the position after move.
|
||||
void do_move_at_root(Position& pos, Move move) {
|
||||
MctsNode* child = rootNode.get_child_by_move(move);
|
||||
if (child == nullptr)
|
||||
create_new_root(pos);
|
||||
else {
|
||||
rootNode = std::move(*child);
|
||||
rootNode.parent = nullptr;
|
||||
rootNode.childId = 0;
|
||||
// keep rootNode.prevMove for move ordering heuristics
|
||||
}
|
||||
|
||||
recalculate_stats();
|
||||
|
||||
assert(rootNode.posKey == pos.key());
|
||||
}
|
||||
|
||||
// One iteration of the search. Basically:
|
||||
// 1. find a node to expand/playout
|
||||
// 2. if the node is a terminal then we just get the stuff and backprop
|
||||
// 3. if we only have prior for the node then do a playout
|
||||
// 4. otherwise we expand the children and do at least one playout from the best child (chosen by prior)
|
||||
// 4.1. a terminal node counts as a playout. All terminal nodes are played out.
|
||||
// 5. Backpropagate all changes down the tree.
|
||||
void do_search_iteration(Position& pos, Depth leafDepth, float explorationFactor) {
|
||||
MctsNode& node = find_node_to_expand_or_playout(pos, explorationFactor);
|
||||
BackpropValues backprops{};
|
||||
if (node.isTerminal) {
|
||||
debug_print("Root is terminal\n");
|
||||
backprops.numVisits = 1;
|
||||
backprops.actionValue += node.actionValue;
|
||||
backprops.actionValueWeight += node.actionValueWeight;
|
||||
|
||||
numPlayouts += 1;
|
||||
}
|
||||
else if (node.leafSearchDepth == DEPTH_NONE) {
|
||||
// The node is considered the best but it only has a prior value.
|
||||
// We don't really want to expand nodes based just on the prior, so
|
||||
// first do a playout to get a better estimate, and expand only in the
|
||||
// next iteration.
|
||||
// Normally we playout immediately the move with the best prior, but that
|
||||
// playout can put it below another move.
|
||||
backprops = do_playout(pos, node, leafDepth);
|
||||
}
|
||||
else {
|
||||
// We have done leaf evaluation with AB search so we know that
|
||||
// this node is *actually good* and not just *prior good*, so we
|
||||
// can now expand it and do an immediate playout for the node with the best prior.
|
||||
backprops = expand_node_and_do_playout(pos, node, leafDepth, explorationFactor);
|
||||
}
|
||||
|
||||
backpropagate(pos, node, backprops);
|
||||
}
|
||||
|
||||
// Backpropagates the changes after an expand/playout all the way to the root.
|
||||
// pos is expected to be at the node from which we start backpropagating.
|
||||
void backpropagate(Position& pos, MctsNode& node, BackpropValues backprops) {
|
||||
assert(node.posKey == pos.key());
|
||||
assert(ply >= 1);
|
||||
|
||||
debug_print("Backpropagating: ", pos.fen(), '\n');
|
||||
|
||||
MctsNode* currentNode = &node;
|
||||
while (!currentNode->is_root()) {
|
||||
// On each descent we switch the side to move.
|
||||
undo_move(pos);
|
||||
currentNode = currentNode->parent;
|
||||
backprops.flip_side();
|
||||
|
||||
debug_print("Backprop step: ", pos.fen(), '\n');
|
||||
assert(currentNode->posKey == pos.key());
|
||||
|
||||
currentNode->numVisits += backprops.numVisits;
|
||||
currentNode->actionValue += backprops.actionValue;
|
||||
currentNode->actionValueWeight += backprops.actionValueWeight;
|
||||
}
|
||||
|
||||
// At the end we must be at the root.
|
||||
assert(currentNode == &rootNode);
|
||||
assert(rootNode.posKey == pos.key());
|
||||
}
|
||||
|
||||
// Navigate with pos to the node expand/playout
|
||||
MctsNode& find_node_to_expand_or_playout(Position& pos, float explorationFactor) {
|
||||
assert(rootNode.posKey == pos.key());
|
||||
|
||||
// Find a node that has not yet been expanded
|
||||
MctsNode* currentNode = &rootNode;
|
||||
while (!currentNode->is_leaf()) {
|
||||
MctsNode& bestChild = currentNode->get_best_child(explorationFactor);
|
||||
|
||||
do_move(pos, *currentNode, bestChild);
|
||||
|
||||
currentNode = &bestChild;
|
||||
}
|
||||
|
||||
return *currentNode;
|
||||
}
|
||||
|
||||
int generate_moves_unordered(Position& pos, Move* out) const {
|
||||
int moveCount = 0;
|
||||
for (auto move : MoveList<LEGAL>(pos))
|
||||
out[moveCount++] = move;
|
||||
|
||||
return moveCount;
|
||||
}
|
||||
|
||||
// Generate moves with some reasonable ordering. Using this we can assume some reasonable priors.
|
||||
int generate_moves_ordered(Position& pos, MctsNode& node, Depth leafDepth, Move* out) const {
|
||||
assert(ply >= 1);
|
||||
|
||||
debug_print("Generating moves: ", pos.fen(), '\n');
|
||||
|
||||
Thread* const thread = pos.this_thread();
|
||||
const Square prevSq = to_sq(node.prevMove);
|
||||
const Move countermove = thread->counterMoves[pos.piece_on(prevSq)][prevSq];
|
||||
const Move ttMove = MOVE_NONE; // TODO: retrieve tt move
|
||||
const Move* const killers = stack[ply].killers;
|
||||
const Depth depth = leafDepth + 1;
|
||||
|
||||
const PieceToHistory* contHist[] = {
|
||||
stack[ply-1].continuationHistory, stack[ply-2].continuationHistory,
|
||||
nullptr , stack[ply-4].continuationHistory,
|
||||
nullptr , stack[ply-6].continuationHistory
|
||||
};
|
||||
|
||||
assert(contHist[0] != nullptr);
|
||||
assert(contHist[1] != nullptr);
|
||||
assert(contHist[3] != nullptr);
|
||||
assert(contHist[5] != nullptr);
|
||||
|
||||
MovePicker mp(
|
||||
pos,
|
||||
ttMove,
|
||||
depth,
|
||||
&(thread->mainHistory),
|
||||
&(thread->lowPlyHistory),
|
||||
&(thread->captureHistory),
|
||||
contHist,
|
||||
countermove,
|
||||
killers,
|
||||
ply
|
||||
);
|
||||
|
||||
int moveCount = 0;
|
||||
for (;;) {
|
||||
const Move move = mp.next_move();
|
||||
debug_print("Generated move ", UCI::move(move, false), ": ", pos.fen(), '\n');
|
||||
|
||||
if (move == MOVE_NONE)
|
||||
break;
|
||||
|
||||
if (pos.legal(move))
|
||||
out[moveCount++] = move;
|
||||
}
|
||||
|
||||
debug_print("Generated ", moveCount, " legal moves: ", pos.fen(), '\n');
|
||||
|
||||
return moveCount;
|
||||
}
|
||||
|
||||
void init_for_leaf_search(Position& pos) {
|
||||
auto th = pos.this_thread();
|
||||
|
||||
th->completedDepth = 0;
|
||||
th->selDepth = 0;
|
||||
th->rootDepth = 0;
|
||||
th->nmpMinPly = th->bestMoveChanges = th->failedHighCnt = 0;
|
||||
th->ttHitAverage = TtHitAverageWindow * TtHitAverageResolution / 2;
|
||||
th->nodes = 0;
|
||||
}
|
||||
|
||||
// Checks whether the position is terminal and return the right value if it is.
|
||||
// Otherwise returns VALUE_NONE
|
||||
Value terminal_value(Position& pos) const {
|
||||
if (MoveList<LEGAL>(pos).size() == 0)
|
||||
return pos.checkers() ? VALUE_MATE : -VALUE_MATE;;
|
||||
|
||||
if (ply >= MAX_PLY - 2 || pos.is_draw(ply - 1))
|
||||
return VALUE_DRAW;
|
||||
|
||||
return VALUE_NONE;
|
||||
}
|
||||
|
||||
// Does AB search on the position to get the value.
|
||||
Value evaluate_leaf(Position& pos, MctsNode& node, Depth leafDepth) {
|
||||
assert(node.posKey == pos.key());
|
||||
assert(node.leafSearchDepth == DEPTH_NONE);
|
||||
assert(node.leafSearchEval == VALUE_NONE);
|
||||
|
||||
debug_print("Evaluating leaf: ", pos.fen(), '\n');
|
||||
|
||||
init_for_leaf_search(pos);
|
||||
|
||||
Move pv[MAX_PLY + 1];
|
||||
stack[ply].pv = pv;
|
||||
stack[ply].currentMove = MOVE_NONE;
|
||||
stack[ply].excludedMove = MOVE_NONE;
|
||||
|
||||
if (!node.is_root() && node.parent->leafSearchEval != VALUE_NONE) {
|
||||
// If we have some parent score then use an aspiration window.
|
||||
// We know what to expect.
|
||||
Value delta = Value(18);
|
||||
Value alpha = std::max(node.parent->leafSearchEval - delta, -VALUE_INFINITE);
|
||||
Value beta = std::min(node.parent->leafSearchEval + delta, VALUE_INFINITE);
|
||||
for (;;) {
|
||||
const Value value = Stockfish::search<PV>(pos, stack + ply, alpha, beta, leafDepth, false);
|
||||
if (value <= alpha) {
|
||||
beta = (alpha + beta) / 2;
|
||||
alpha = std::max(value - delta, -VALUE_INFINITE);
|
||||
}
|
||||
else if (value >= beta)
|
||||
beta = std::min(value + delta, VALUE_INFINITE);
|
||||
else
|
||||
return value;
|
||||
|
||||
delta += delta / 4 + 5;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// If no parent score then do infinite aspiration window.
|
||||
return Stockfish::search<PV>(pos, stack + ply, -VALUE_INFINITE, VALUE_INFINITE, leafDepth, false);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Move> get_pv(const MctsNode& node) const {
|
||||
std::vector<Move> pv;
|
||||
|
||||
const MctsNode* currentNode = &node;
|
||||
if (!currentNode->is_root())
|
||||
pv.emplace_back(currentNode->prevMove);
|
||||
|
||||
while (!currentNode->is_leaf()) {
|
||||
// No exploration factor for choosing the PV.
|
||||
const MctsNode& bestChild = currentNode->get_best_child(0.0f);
|
||||
pv.emplace_back(bestChild.prevMove);
|
||||
currentNode = &bestChild;
|
||||
}
|
||||
|
||||
return pv;
|
||||
}
|
||||
|
||||
std::vector<Move> get_pv() const {
|
||||
return get_pv(rootNode);
|
||||
}
|
||||
|
||||
// Does a single playout and returns what is needed to backprop.
|
||||
BackpropValues do_playout(Position& pos, MctsNode& node, Depth leafDepth) {
|
||||
assert(node.posKey == pos.key());
|
||||
assert(node.numVisits == 0);
|
||||
assert(node.is_leaf());
|
||||
assert(node.numChildren == 0);
|
||||
assert(node.leafSearchDepth == DEPTH_NONE);
|
||||
assert(!node.isTerminal);
|
||||
|
||||
debug_print("Doing playout ", numPlayouts, ": ", pos.fen(), '\n');
|
||||
|
||||
numPlayouts += 1;
|
||||
|
||||
const Value v = evaluate_leaf(pos, node, leafDepth);
|
||||
|
||||
BackpropValues backprops{};
|
||||
backprops.numVisits = 1; // playout counts as a visit
|
||||
backprops.actionValue += value_to_reward(v);
|
||||
backprops.actionValueWeight += normalWeight;
|
||||
|
||||
// Bookkeep raw eval.
|
||||
node.leafSearchEval = v;
|
||||
node.leafSearchDepth = leafDepth;
|
||||
|
||||
// Local backprop because normal backprop handles only the
|
||||
// nodes starting from the parent of this one
|
||||
node.numVisits = backprops.numVisits;
|
||||
node.actionValue += backprops.actionValue;
|
||||
node.actionValueWeight += backprops.actionValueWeight;
|
||||
|
||||
return backprops;
|
||||
}
|
||||
|
||||
// Expand a node and do at least one playout.
|
||||
// May do more "playouts" if there are terminals as those are "played out" immediately.
|
||||
// Returns what needs to be backpropagated.
|
||||
BackpropValues expand_node_and_do_playout(
|
||||
Position& pos,
|
||||
MctsNode& node,
|
||||
Depth leafDepth,
|
||||
float explorationFactor) {
|
||||
|
||||
assert(node.posKey == pos.key()); // node must match the position
|
||||
assert(node.is_leaf()); // otherwise already expanded
|
||||
assert(node.numChildren == 0);
|
||||
assert(!node.isTerminal); // terminals cannot be expanded
|
||||
assert(node.numVisits == 1); // we expect it to have the "playout visit". Fake visit for the root.
|
||||
assert(node.leafSearchDepth != DEPTH_NONE);
|
||||
assert(node.leafSearchEval != VALUE_NONE);
|
||||
|
||||
debug_print("Expanding and playing out: ", pos.fen(), '\n');
|
||||
|
||||
Move moves[MAX_MOVES];
|
||||
const int moveCount = generate_moves_ordered(pos, node, leafDepth, moves);
|
||||
|
||||
assert(moveCount > 0);
|
||||
|
||||
node.children = std::make_unique<MctsNode[]>(moveCount);
|
||||
node.numChildren = moveCount;
|
||||
|
||||
int numTerminals = 0;
|
||||
BackpropValues backprops{};
|
||||
|
||||
float prior = value_to_reward(node.leafSearchEval);
|
||||
// Prior is attenuated for later moves - we rely on move ordering.
|
||||
// Attenuate more at higher plies where we have more move ordering.
|
||||
const float priorAttenuation = 1.0f - std::min((ply - 1) / 100.0f, 0.05f);
|
||||
for (int i = 0; i < moveCount; ++i) {
|
||||
// setup the child
|
||||
MctsNode& child = node.children[i];
|
||||
child.prevMove = moves[i];
|
||||
child.childId = i;
|
||||
child.parent = &node;
|
||||
|
||||
{
|
||||
debug_print("Expanding move ", i+1, " out of ", moveCount, ": ", pos.fen(), '\n');
|
||||
|
||||
// we enter the child's position
|
||||
do_move(pos, node, child);
|
||||
child.posKey = pos.key();
|
||||
|
||||
const Value terminalValue = terminal_value(pos);
|
||||
if (terminalValue != VALUE_NONE) {
|
||||
// if it's a terminal then "play it out"
|
||||
child.isTerminal = true;
|
||||
child.prior = value_to_reward(terminalValue);
|
||||
child.numVisits = 1;
|
||||
child.actionValue = child.prior * terminalWeight;
|
||||
child.actionValueWeight = terminalWeight;
|
||||
child.leafSearchEval = terminalValue;
|
||||
child.leafSearchDepth = terminalEvalDepth;
|
||||
|
||||
numTerminals += 1;
|
||||
numPlayouts += 1;
|
||||
}
|
||||
else {
|
||||
// otherwise we just note the prior (policy)
|
||||
child.prior = 1.0f - prior;
|
||||
child.actionValue = child.prior * priorWeight;
|
||||
child.actionValueWeight = priorWeight;
|
||||
}
|
||||
|
||||
undo_move(pos);
|
||||
}
|
||||
|
||||
// accumulate the policies to backprop.
|
||||
backprops.actionValue += child.actionValue;
|
||||
backprops.actionValueWeight += child.actionValueWeight;
|
||||
|
||||
// Reduce the prior for the next move.
|
||||
prior *= priorAttenuation;
|
||||
}
|
||||
|
||||
if (numTerminals == 0) {
|
||||
// If no terminals then we do one playout on the best child.
|
||||
MctsNode& bestChild = node.get_best_child(explorationFactor);
|
||||
do_move(pos, node, bestChild);
|
||||
|
||||
backprops.numVisits += 1;
|
||||
|
||||
auto playoutBackprops = do_playout(pos, bestChild, leafDepth);
|
||||
backprops.numVisits += playoutBackprops.numVisits;
|
||||
backprops.actionValue += playoutBackprops.actionValue;
|
||||
backprops.actionValueWeight += playoutBackprops.actionValueWeight;
|
||||
|
||||
undo_move(pos);
|
||||
}
|
||||
else {
|
||||
// If there are any terminals we don't do more playouts
|
||||
backprops.numVisits += numTerminals;
|
||||
}
|
||||
|
||||
// Local backprop because normal backprop handles only the
|
||||
// nodes starting from the parent of this one
|
||||
backprops.flip_side();
|
||||
|
||||
node.actionValue = backprops.actionValue;
|
||||
node.actionValueWeight = backprops.actionValueWeight;
|
||||
node.numVisits = backprops.numVisits;
|
||||
|
||||
return backprops;
|
||||
}
|
||||
|
||||
// Does a move and updates the stack.
|
||||
void do_move(Position& pos, MctsNode& parentNode, MctsNode& childNode) {
|
||||
assert(ply < MAX_PLY);
|
||||
assert(!parentNode.is_leaf());
|
||||
assert(&parentNode.children[childNode.childId] == &childNode);
|
||||
assert(parentNode.posKey == pos.key());
|
||||
|
||||
const Move move = childNode.prevMove;
|
||||
|
||||
stack[ply].currentMove = move;
|
||||
stack[ply].inCheck = pos.checkers();
|
||||
stack[ply].continuationHistory =
|
||||
&(
|
||||
pos.this_thread()->continuationHistory
|
||||
[stack[ply].inCheck]
|
||||
[pos.capture_or_promotion(move)]
|
||||
[pos.moved_piece(move)]
|
||||
[to_sq(move)]
|
||||
);
|
||||
stack[ply].staticEval = parentNode.leafSearchEval;
|
||||
stack[ply].moveCount = childNode.childId + 1;
|
||||
|
||||
pos.do_move(move, states[ply]);
|
||||
|
||||
// The first time around we don't have posKey set yet,
|
||||
// because we need to do the move first.
|
||||
assert(childNode.posKey == 0 || childNode.posKey == pos.key());
|
||||
|
||||
ply += 1;
|
||||
if (ply > maximumPly)
|
||||
maximumPly = ply;
|
||||
}
|
||||
|
||||
// Undoes a move, pops the stack.
|
||||
void undo_move(Position& pos) {
|
||||
assert(ply > 1);
|
||||
|
||||
ply -= 1;
|
||||
|
||||
pos.undo_move(stack[ply].currentMove);
|
||||
}
|
||||
|
||||
void create_new_root(Position& pos) {
|
||||
rootNode = MctsNode{};
|
||||
rootNode.posKey = pos.key();
|
||||
rootNode.isTerminal = MoveList<LEGAL>(pos).size() == 0;
|
||||
}
|
||||
|
||||
void init_for_mcts_search(Position& pos) {
|
||||
std::memset(stack - 7, 0, 10 * sizeof(Stack));
|
||||
|
||||
auto th = pos.this_thread();
|
||||
|
||||
// stack + 0 also needs to be initialized because we start from ply=1
|
||||
for (int i = 7; i >= 0; --i)
|
||||
(stack - i)->continuationHistory = &th->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel
|
||||
|
||||
for (int i = 1; i <= MAX_PLY; ++i)
|
||||
(stack + i)->ply = i;
|
||||
|
||||
int ct = int(Options["Contempt"]) * PawnValueEg / 100; // From centipawns
|
||||
Color us = pos.side_to_move();
|
||||
|
||||
// In analysis mode, adjust contempt in accordance with user preference
|
||||
if (Limits.infinite || Options["UCI_AnalyseMode"])
|
||||
ct = Options["Analysis Contempt"] == "Off" ? 0
|
||||
: Options["Analysis Contempt"] == "Both" ? ct
|
||||
: Options["Analysis Contempt"] == "White" && us == BLACK ? -ct
|
||||
: Options["Analysis Contempt"] == "Black" && us == WHITE ? -ct
|
||||
: ct;
|
||||
|
||||
// Evaluation score is from the white point of view
|
||||
th->contempt =
|
||||
us == WHITE
|
||||
? make_score(ct, ct / 2)
|
||||
: -make_score(ct, ct / 2);
|
||||
|
||||
create_new_root(pos);
|
||||
|
||||
ply = 1;
|
||||
maximumPly = ply;
|
||||
numPlayouts = 0;
|
||||
}
|
||||
};
|
||||
|
||||
ValueAndPV search_mcts(
|
||||
Position& pos,
|
||||
std::uint64_t numPlayouts,
|
||||
Depth leafDepth,
|
||||
float explorationFactor) {
|
||||
|
||||
MonteCarloTreeSearch mcts{};
|
||||
return mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor);
|
||||
}
|
||||
|
||||
std::vector<MctsContinuation> search_mcts_multipv(
|
||||
Position& pos,
|
||||
std::uint64_t numPlayouts,
|
||||
Depth leafDepth,
|
||||
float explorationFactor) {
|
||||
|
||||
MonteCarloTreeSearch mcts{};
|
||||
mcts.search_new(pos, numPlayouts, leafDepth, explorationFactor);
|
||||
return mcts.get_all_continuations();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace Stockfish
|
||||
|
||||
Reference in New Issue
Block a user