Add stats: ply_discontinuities, material_imbalance, results

This commit is contained in:
Tomasz Sobczyk
2021-05-19 12:55:14 +02:00
parent 640ec5706e
commit a4b598060c
+197 -29
View File
@@ -57,6 +57,18 @@ namespace Stockfish::Tools::Stats
return digits; return digits;
} }
[[nodiscard]] std::string left_pad_to_length(const std::string& str, char ch, int length)
{
if (str.size() < length)
{
return std::string(length - static_cast<int>(str.size()), ch) + str;
}
else
{
return str;
}
}
[[nodiscard]] std::string indent_text(const std::string& text, Indentation indent) [[nodiscard]] std::string indent_text(const std::string& text, Indentation indent)
{ {
std::string delimiter = "\n"; std::string delimiter = "\n";
@@ -258,8 +270,7 @@ namespace Stockfish::Tools::Stats
struct StatisticGathererBase struct StatisticGathererBase
{ {
virtual void on_position(const Position&) {} virtual void on_entry(const Position&, const Move&, const PackedSfenValue&) {}
virtual void on_move(const Position&, const Move&) {}
virtual void reset() = 0; virtual void reset() = 0;
[[nodiscard]] virtual const std::string& get_name() const = 0; [[nodiscard]] virtual const std::string& get_name() const = 0;
[[nodiscard]] virtual StatisticOutput get_output() const = 0; [[nodiscard]] virtual StatisticOutput get_output() const = 0;
@@ -309,19 +320,11 @@ namespace Stockfish::Tools::Stats
} }
} }
void on_position(const Position& position) override void on_entry(const Position& pos, const Move& move, const PackedSfenValue& psv) override
{ {
for (auto& g : m_gatherers) for (auto& g : m_gatherers)
{ {
g->on_position(position); g->on_entry(pos, move, psv);
}
}
void on_move(const Position& pos, const Move& move) override
{
for (auto& g : m_gatherers)
{
g->on_move(pos, move);
} }
} }
@@ -458,7 +461,7 @@ namespace Stockfish::Tools::Stats
{ {
} }
void on_position(const Position&) override void on_entry(const Position&, const Move&, const PackedSfenValue&) override
{ {
m_num_positions += 1; m_num_positions += 1;
} }
@@ -495,7 +498,7 @@ namespace Stockfish::Tools::Stats
} }
void on_position(const Position& pos) override void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override
{ {
m_white[pos.square<KING>(WHITE)] += 1; m_white[pos.square<KING>(WHITE)] += 1;
m_black[pos.square<KING>(BLACK)] += 1; m_black[pos.square<KING>(BLACK)] += 1;
@@ -537,7 +540,7 @@ namespace Stockfish::Tools::Stats
} }
void on_move(const Position& pos, const Move& move) override void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override
{ {
if (pos.side_to_move() == WHITE) if (pos.side_to_move() == WHITE)
m_white[from_sq(move)] += 1; m_white[from_sq(move)] += 1;
@@ -581,7 +584,7 @@ namespace Stockfish::Tools::Stats
} }
void on_move(const Position& pos, const Move& move) override void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override
{ {
if (pos.side_to_move() == WHITE) if (pos.side_to_move() == WHITE)
m_white[to_sq(move)] += 1; m_white[to_sq(move)] += 1;
@@ -629,7 +632,7 @@ namespace Stockfish::Tools::Stats
} }
void on_move(const Position& pos, const Move& move) override void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override
{ {
m_total += 1; m_total += 1;
@@ -692,7 +695,7 @@ namespace Stockfish::Tools::Stats
reset(); reset();
} }
void on_position(const Position& pos) override void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override
{ {
m_piece_count_hist[popcount(pos.pieces())] += 1; m_piece_count_hist[popcount(pos.pieces())] += 1;
} }
@@ -740,7 +743,7 @@ namespace Stockfish::Tools::Stats
reset(); reset();
} }
void on_move(const Position& pos, const Move& move) override void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override
{ {
m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1; m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1;
} }
@@ -773,6 +776,170 @@ namespace Stockfish::Tools::Stats
std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB]; std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB];
}; };
struct PlyDiscontinuitiesCounter : StatisticGathererBase
{
static inline std::string name = "PlyDiscontinuitiesCounter";
PlyDiscontinuitiesCounter()
{
reset();
}
void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override
{
const int current_ply = pos.game_ply();
if (m_prev_ply != -1)
{
const bool is_discontinuity = (current_ply != (m_prev_ply + 1));
if (is_discontinuity)
{
m_num_discontinuities += 1;
}
}
m_prev_ply = current_ply;
}
void reset() override
{
m_num_discontinuities = 0;
m_prev_ply = -1;
}
[[nodiscard]] const std::string& get_name() const override
{
return name;
}
[[nodiscard]] StatisticOutput get_output() const override
{
StatisticOutput out;
out.emplace_node<StatisticOutputEntryValue<std::uint64_t>>("Number of ply discontinuities (usually games)", m_num_discontinuities);
return out;
}
private:
std::uint64_t m_num_discontinuities;
int m_prev_ply;
};
struct MaterialImbalanceDistribution : StatisticGathererBase
{
static inline std::string name = "MaterialImbalanceDistribution";
static constexpr int max_imbalance = 64;
MaterialImbalanceDistribution()
{
reset();
}
void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override
{
const int imbalance = get_simple_material(pos, WHITE) - get_simple_material(pos, BLACK);
const int imbalance_idx = std::clamp(imbalance, -max_imbalance, max_imbalance) + max_imbalance;
m_num_imbalances[imbalance_idx] += 1;
}
void reset() override
{
for (auto& imb : m_num_imbalances)
imb = 0;
}
[[nodiscard]] const std::string& get_name() const override
{
return name;
}
[[nodiscard]] StatisticOutput get_output() const override
{
StatisticOutput out;
auto& header = out.emplace_node<StatisticOutputEntryHeader>("Number of \"simple eval\" imbalances for white's perspective:");
const int key_length = get_num_base_10_digits(max_imbalance) + 1;
for (int i = -max_imbalance; i <= max_imbalance; ++i)
{
header.emplace_child<StatisticOutputEntryValue<std::uint64_t>>(
left_pad_to_length(std::to_string(i), ' ', key_length),
m_num_imbalances[i + max_imbalance]
);
}
return out;
}
private:
std::uint64_t m_num_imbalances[max_imbalance + 1 + max_imbalance];
[[nodiscard]] int get_simple_material(const Position& pos, Color c)
{
return
9 * pos.count<QUEEN>(c)
+ 5 * pos.count<ROOK>(c)
+ 3 * pos.count<BISHOP>(c)
+ 3 * pos.count<KNIGHT>(c)
+ pos.count<PAWN>(c);
}
};
struct ResultDistribution : StatisticGathererBase
{
static inline std::string name = "ResultDistribution";
ResultDistribution()
{
reset();
}
void on_entry(const Position& pos, const Move&, const PackedSfenValue& psv) override
{
const Color stm = pos.side_to_move();
if (psv.game_result == 0)
{
m_draws += 1;
}
else if (psv.game_result == 1)
{
m_stm_wins += 1;
m_wins[stm] += 1;
}
else
{
m_stm_loses += 1;
m_wins[~stm] += 1;
}
}
void reset() override
{
m_wins[WHITE] = 0;
m_wins[BLACK] = 0;
m_draws = 0;
m_stm_wins = 0;
m_stm_loses = 0;
}
[[nodiscard]] const std::string& get_name() const override
{
return name;
}
[[nodiscard]] StatisticOutput get_output() const override
{
StatisticOutput out;
auto& header = out.emplace_node<StatisticOutputEntryHeader>("Distribution of results:");
header.emplace_child<StatisticOutputEntryValue<std::uint64_t>>("White wins", m_wins[WHITE]);
header.emplace_child<StatisticOutputEntryValue<std::uint64_t>>("Black wins", m_wins[BLACK]);
header.emplace_child<StatisticOutputEntryValue<std::uint64_t>>("Draws", m_draws);
header.emplace_child<StatisticOutputEntryValue<std::uint64_t>>("Side to move wins", m_stm_wins);
header.emplace_child<StatisticOutputEntryValue<std::uint64_t>>("Side to move loses", m_stm_loses);
return out;
}
private:
std::uint64_t m_wins[COLOR_NB];
std::uint64_t m_draws;
std::uint64_t m_stm_wins;
std::uint64_t m_stm_loses;
};
/* /*
This function provides factories for all possible statistic gatherers. This function provides factories for all possible statistic gatherers.
Each new statistic gatherer needs to be added there. Each new statistic gatherer needs to be added there.
@@ -791,6 +958,12 @@ namespace Stockfish::Tools::Stats
reg.add<MoveTypeCounter>("move", "move_type"); reg.add<MoveTypeCounter>("move", "move_type");
reg.add<MovedPieceTypeCounter>("move", "moved_piece_type"); reg.add<MovedPieceTypeCounter>("move", "moved_piece_type");
reg.add<PlyDiscontinuitiesCounter>("ply_discontinuities");
reg.add<MaterialImbalanceDistribution>("material_imbalance");
reg.add<ResultDistribution>("results");
reg.add<PieceCountCounter>("piece_count"); reg.add<PieceCountCounter>("piece_count");
return reg; return reg;
@@ -810,12 +983,8 @@ namespace Stockfish::Tools::Stats
auto in = Tools::open_sfen_input_file(filename); auto in = Tools::open_sfen_input_file(filename);
auto on_move = [&](const Position& position, const Move& move) { auto on_entry = [&](const Position& position, const Move& move, const PackedSfenValue& psv) {
statistic_gatherers.on_move(position, move); statistic_gatherers.on_entry(position, move, psv);
};
auto on_position = [&](const Position& position) {
statistic_gatherers.on_position(position);
}; };
if (in == nullptr) if (in == nullptr)
@@ -831,12 +1000,11 @@ namespace Stockfish::Tools::Stats
if (!v.has_value()) if (!v.has_value())
break; break;
auto& ps = v.value(); auto& psv = v.value();
pos.set_from_packed_sfen(ps.sfen, &si, th); pos.set_from_packed_sfen(psv.sfen, &si, th);
on_position(pos); on_entry(pos, (Move)psv.move, psv);
on_move(pos, (Move)ps.move);
num_processed += 1; num_processed += 1;
if (num_processed % 1'000'000 == 0) if (num_processed % 1'000'000 == 0)