mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 12:07:43 +00:00
@@ -4,6 +4,14 @@
|
|||||||
|
|
||||||
Simplest usage: `stockfish.exe gather_statistics all input_file a.binpack`
|
Simplest usage: `stockfish.exe gather_statistics all input_file a.binpack`
|
||||||
|
|
||||||
|
Any name that doesn't designate an argument name or is not an argument will be interpreted as a group name.
|
||||||
|
|
||||||
|
## Parameters
|
||||||
|
|
||||||
|
`input_file` - the path to the .bin or .binpack input file to read
|
||||||
|
|
||||||
|
`max_count` - the maximum number of positions to process. Default: no limit.
|
||||||
|
|
||||||
## Groups
|
## Groups
|
||||||
|
|
||||||
`all`
|
`all`
|
||||||
@@ -13,3 +21,25 @@ Simplest usage: `stockfish.exe gather_statistics all input_file a.binpack`
|
|||||||
`position_count`
|
`position_count`
|
||||||
|
|
||||||
- `struct PositionCounter` - the total number of positions in the file.
|
- `struct PositionCounter` - the total number of positions in the file.
|
||||||
|
|
||||||
|
|
||||||
|
reg.add<KingSquareCounter>("king", "king_square_count");
|
||||||
|
|
||||||
|
reg.add<MoveFromCounter>("move", "move_from_count");
|
||||||
|
reg.add<MoveToCounter>("move", "move_to_count");
|
||||||
|
reg.add<MoveTypeCounter>("move", "move_type");
|
||||||
|
reg.add<MovedPieceTypeCounter>("move", "moved_piece_type");
|
||||||
|
|
||||||
|
reg.add<PieceCountCounter>("piece_count");
|
||||||
|
|
||||||
|
`king`, `king_square_count` - the number of times a king was on each square. Output is layed out as a chessboard, with the 8th rank being the topmost. Separate values for white and black kings.
|
||||||
|
|
||||||
|
`move`, `move_from_count` - same as `king_square_count` but for from_sq(move)
|
||||||
|
|
||||||
|
`move`, `move_to_count` - same as `king_square_count` but for to_sq(move)
|
||||||
|
|
||||||
|
`move`, `move_type` - the number of moves with each type. Includes normal, captures, castling, promotions, enpassant. The groups are not disjoint.
|
||||||
|
|
||||||
|
`move`, `moved_piece_type` - the number of times a piece of each type was moved
|
||||||
|
|
||||||
|
`piece_count` - the histogram of the number of pieces on the board
|
||||||
|
|||||||
+489
-53
@@ -11,12 +11,16 @@
|
|||||||
|
|
||||||
#include "nnue/evaluate_nnue.h"
|
#include "nnue/evaluate_nnue.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <sstream>
|
||||||
|
#include <iomanip>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
@@ -26,13 +30,190 @@ namespace Learner::Stats
|
|||||||
struct StatisticGathererBase
|
struct StatisticGathererBase
|
||||||
{
|
{
|
||||||
virtual void on_position(const Position&) {}
|
virtual void on_position(const Position&) {}
|
||||||
virtual void on_move(const Move&) {}
|
virtual void on_move(const Position&, const Move&) {}
|
||||||
virtual void reset() = 0;
|
virtual void reset() = 0;
|
||||||
[[nodiscard]] virtual std::map<std::string, std::string> get_formatted_stats() const = 0;
|
[[nodiscard]] virtual const std::string& get_name() const = 0;
|
||||||
|
[[nodiscard]] virtual std::vector<std::pair<std::string, std::string>> get_formatted_stats() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct StatisticGathererFactoryBase
|
||||||
|
{
|
||||||
|
[[nodiscard]] virtual std::unique_ptr<StatisticGathererBase> create() const = 0;
|
||||||
|
[[nodiscard]] virtual const std::string& get_name() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct StatisticGathererFactory : StatisticGathererFactoryBase
|
||||||
|
{
|
||||||
|
static inline std::string name = T::name;
|
||||||
|
|
||||||
|
[[nodiscard]] std::unique_ptr<StatisticGathererBase> create() const override
|
||||||
|
{
|
||||||
|
return std::make_unique<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StatisticGathererSet : StatisticGathererBase
|
||||||
|
{
|
||||||
|
void add(const StatisticGathererFactoryBase& factory)
|
||||||
|
{
|
||||||
|
const std::string name = factory.get_name();
|
||||||
|
if (m_gatherers_names.count(name) == 0)
|
||||||
|
{
|
||||||
|
m_gatherers_names.insert(name);
|
||||||
|
m_gatherers.emplace_back(factory.create());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(std::unique_ptr<StatisticGathererBase>&& gatherer)
|
||||||
|
{
|
||||||
|
const std::string name = gatherer->get_name();
|
||||||
|
if (m_gatherers_names.count(name) == 0)
|
||||||
|
{
|
||||||
|
m_gatherers_names.insert(name);
|
||||||
|
m_gatherers.emplace_back(std::move(gatherer));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_position(const Position& position) override
|
||||||
|
{
|
||||||
|
for (auto& g : m_gatherers)
|
||||||
|
{
|
||||||
|
g->on_position(position);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_move(const Position& pos, const Move& move) override
|
||||||
|
{
|
||||||
|
for (auto& g : m_gatherers)
|
||||||
|
{
|
||||||
|
g->on_move(pos, move);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() override
|
||||||
|
{
|
||||||
|
for (auto& g : m_gatherers)
|
||||||
|
{
|
||||||
|
g->reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] virtual const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
static std::string name = "SET";
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] virtual std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||||
|
{
|
||||||
|
std::vector<std::pair<std::string, std::string>> parts;
|
||||||
|
for (auto&& s : m_gatherers)
|
||||||
|
{
|
||||||
|
auto part = s->get_formatted_stats();
|
||||||
|
parts.insert(parts.end(), part.begin(), part.end());
|
||||||
|
}
|
||||||
|
return parts;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<std::unique_ptr<StatisticGathererBase>> m_gatherers;
|
||||||
|
std::set<std::string> m_gatherers_names;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StatisticGathererRegistry
|
||||||
|
{
|
||||||
|
void add_statistic_gatherers_by_group(
|
||||||
|
StatisticGathererSet& gatherers,
|
||||||
|
const std::string& group) const
|
||||||
|
{
|
||||||
|
auto it = m_gatherers_by_group.find(group);
|
||||||
|
if (it != m_gatherers_by_group.end())
|
||||||
|
{
|
||||||
|
for (auto& factory : it->second)
|
||||||
|
{
|
||||||
|
gatherers.add(*factory);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename... ArgsTs>
|
||||||
|
void add(const ArgsTs&... group)
|
||||||
|
{
|
||||||
|
auto dummy = {(add_single<T>(group), 0)...};
|
||||||
|
(void)dummy;
|
||||||
|
add_single<T>("all");
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::map<std::string, std::vector<std::unique_ptr<StatisticGathererFactoryBase>>> m_gatherers_by_group;
|
||||||
|
std::map<std::string, std::set<std::string>> m_gatherers_names_by_group;
|
||||||
|
|
||||||
|
template <typename T, typename ArgT>
|
||||||
|
void add_single(const ArgT& group)
|
||||||
|
{
|
||||||
|
using FactoryT = StatisticGathererFactory<T>;
|
||||||
|
|
||||||
|
if (m_gatherers_names_by_group[group].count(FactoryT::name) == 0)
|
||||||
|
{
|
||||||
|
m_gatherers_by_group[group].emplace_back(std::make_unique<FactoryT>());
|
||||||
|
m_gatherers_names_by_group[group].insert(FactoryT::name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
Statistic gatherer helpers
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct StatPerSquare
|
||||||
|
{
|
||||||
|
StatPerSquare()
|
||||||
|
{
|
||||||
|
for (int i = 0; i < SQUARE_NB; ++i)
|
||||||
|
m_squares[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] T& operator[](Square sq)
|
||||||
|
{
|
||||||
|
return m_squares[sq];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const T& operator[](Square sq) const
|
||||||
|
{
|
||||||
|
return m_squares[sq];
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::string get_formatted_stats() const
|
||||||
|
{
|
||||||
|
std::stringstream ss;
|
||||||
|
for (int i = 0; i < SQUARE_NB; ++i)
|
||||||
|
{
|
||||||
|
ss << std::setw(8) << m_squares[i ^ (int)SQ_A8] << ' ';
|
||||||
|
if ((i + 1) % 8 == 0)
|
||||||
|
ss << '\n';
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::array<T, SQUARE_NB> m_squares;
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
Definitions for specific statistic gatherers follow:
|
||||||
|
*/
|
||||||
|
|
||||||
struct PositionCounter : StatisticGathererBase
|
struct PositionCounter : StatisticGathererBase
|
||||||
{
|
{
|
||||||
|
static inline std::string name = "PositionCounter";
|
||||||
|
|
||||||
PositionCounter() :
|
PositionCounter() :
|
||||||
m_num_positions(0)
|
m_num_positions(0)
|
||||||
{
|
{
|
||||||
@@ -48,7 +229,12 @@ namespace Learner::Stats
|
|||||||
m_num_positions = 0;
|
m_num_positions = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
[[nodiscard]] std::map<std::string, std::string> get_formatted_stats() const override
|
[[nodiscard]] const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||||
{
|
{
|
||||||
return {
|
return {
|
||||||
{ "Number of positions", std::to_string(m_num_positions) }
|
{ "Number of positions", std::to_string(m_num_positions) }
|
||||||
@@ -59,49 +245,296 @@ namespace Learner::Stats
|
|||||||
std::uint64_t m_num_positions;
|
std::uint64_t m_num_positions;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct StatisticGathererFactoryBase
|
struct KingSquareCounter : StatisticGathererBase
|
||||||
{
|
{
|
||||||
[[nodiscard]] virtual std::unique_ptr<StatisticGathererBase> create() const = 0;
|
static inline std::string name = "KingSquareCounter";
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
KingSquareCounter() :
|
||||||
struct StatisticGathererFactory : StatisticGathererFactoryBase
|
m_white{},
|
||||||
{
|
m_black{}
|
||||||
[[nodiscard]] std::unique_ptr<StatisticGathererBase> create() const override
|
|
||||||
{
|
{
|
||||||
return std::make_unique<T>();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct StatisticGathererRegistry
|
|
||||||
{
|
|
||||||
void add_statistic_gatherers_by_group(
|
|
||||||
std::vector<std::unique_ptr<StatisticGathererBase>>& gatherers,
|
|
||||||
const std::string& group) const
|
|
||||||
{
|
|
||||||
auto it = m_gatherers_by_group.find(group);
|
|
||||||
if (it != m_gatherers_by_group.end())
|
|
||||||
{
|
|
||||||
for (auto& factory : it->second)
|
|
||||||
{
|
|
||||||
gatherers.emplace_back(factory->create());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
void on_position(const Position& pos) override
|
||||||
void add(const std::string& group)
|
|
||||||
{
|
{
|
||||||
m_gatherers_by_group[group].emplace_back(std::make_unique<StatisticGathererFactory<T>>());
|
m_white[pos.square<KING>(WHITE)] += 1;
|
||||||
|
m_black[pos.square<KING>(BLACK)] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Always add to the special group "all".
|
void reset() override
|
||||||
m_gatherers_by_group["all"].emplace_back(std::make_unique<StatisticGathererFactory<T>>());
|
{
|
||||||
|
m_white = StatPerSquare<std::uint64_t>{};
|
||||||
|
m_black = StatPerSquare<std::uint64_t>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||||
|
{
|
||||||
|
return {
|
||||||
|
{ "White king squares", '\n' + m_white.get_formatted_stats() },
|
||||||
|
{ "Black king squares", '\n' + m_black.get_formatted_stats() }
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<std::string, std::vector<std::unique_ptr<StatisticGathererFactoryBase>>> m_gatherers_by_group;
|
StatPerSquare<std::uint64_t> m_white;
|
||||||
|
StatPerSquare<std::uint64_t> m_black;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MoveFromCounter : StatisticGathererBase
|
||||||
|
{
|
||||||
|
static inline std::string name = "MoveFromCounter";
|
||||||
|
|
||||||
|
MoveFromCounter() :
|
||||||
|
m_white{},
|
||||||
|
m_black{}
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_move(const Position& pos, const Move& move) override
|
||||||
|
{
|
||||||
|
if (pos.side_to_move() == WHITE)
|
||||||
|
m_white[from_sq(move)] += 1;
|
||||||
|
else
|
||||||
|
m_black[from_sq(move)] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() override
|
||||||
|
{
|
||||||
|
m_white = StatPerSquare<std::uint64_t>{};
|
||||||
|
m_black = StatPerSquare<std::uint64_t>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||||
|
{
|
||||||
|
return {
|
||||||
|
{ "White move from squares", '\n' + m_white.get_formatted_stats() },
|
||||||
|
{ "Black move from squares", '\n' + m_black.get_formatted_stats() }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
StatPerSquare<std::uint64_t> m_white;
|
||||||
|
StatPerSquare<std::uint64_t> m_black;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MoveToCounter : StatisticGathererBase
|
||||||
|
{
|
||||||
|
static inline std::string name = "MoveToCounter";
|
||||||
|
|
||||||
|
MoveToCounter() :
|
||||||
|
m_white{},
|
||||||
|
m_black{}
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_move(const Position& pos, const Move& move) override
|
||||||
|
{
|
||||||
|
if (pos.side_to_move() == WHITE)
|
||||||
|
m_white[to_sq(move)] += 1;
|
||||||
|
else
|
||||||
|
m_black[to_sq(move)] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() override
|
||||||
|
{
|
||||||
|
m_white = StatPerSquare<std::uint64_t>{};
|
||||||
|
m_black = StatPerSquare<std::uint64_t>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||||
|
{
|
||||||
|
return {
|
||||||
|
{ "White move to squares", '\n' + m_white.get_formatted_stats() },
|
||||||
|
{ "Black move to squares", '\n' + m_black.get_formatted_stats() }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
StatPerSquare<std::uint64_t> m_white;
|
||||||
|
StatPerSquare<std::uint64_t> m_black;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MoveTypeCounter : StatisticGathererBase
|
||||||
|
{
|
||||||
|
static inline std::string name = "MoveTypeCounter";
|
||||||
|
|
||||||
|
MoveTypeCounter() :
|
||||||
|
m_total(0),
|
||||||
|
m_normal(0),
|
||||||
|
m_capture(0),
|
||||||
|
m_promotion(0),
|
||||||
|
m_castling(0),
|
||||||
|
m_enpassant(0)
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_move(const Position& pos, const Move& move) override
|
||||||
|
{
|
||||||
|
m_total += 1;
|
||||||
|
|
||||||
|
if (!pos.empty(to_sq(move)))
|
||||||
|
m_capture += 1;
|
||||||
|
|
||||||
|
if (type_of(move) == CASTLING)
|
||||||
|
m_castling += 1;
|
||||||
|
else if (type_of(move) == PROMOTION)
|
||||||
|
m_promotion += 1;
|
||||||
|
else if (type_of(move) == ENPASSANT)
|
||||||
|
m_enpassant += 1;
|
||||||
|
else if (type_of(move) == NORMAL)
|
||||||
|
m_normal += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() override
|
||||||
|
{
|
||||||
|
m_total = 0;
|
||||||
|
m_normal = 0;
|
||||||
|
m_capture = 0;
|
||||||
|
m_promotion = 0;
|
||||||
|
m_castling = 0;
|
||||||
|
m_enpassant = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||||
|
{
|
||||||
|
return {
|
||||||
|
{ "Total moves", std::to_string(m_total) },
|
||||||
|
{ "Normal moves", std::to_string(m_normal) },
|
||||||
|
{ "Capture moves", std::to_string(m_capture) },
|
||||||
|
{ "Promotion moves", std::to_string(m_promotion) },
|
||||||
|
{ "Castling moves", std::to_string(m_castling) },
|
||||||
|
{ "En-passant moves", std::to_string(m_enpassant) }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::uint64_t m_total;
|
||||||
|
std::uint64_t m_normal;
|
||||||
|
std::uint64_t m_capture;
|
||||||
|
std::uint64_t m_promotion;
|
||||||
|
std::uint64_t m_castling;
|
||||||
|
std::uint64_t m_enpassant;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PieceCountCounter : StatisticGathererBase
|
||||||
|
{
|
||||||
|
static inline std::string name = "PieceCountCounter";
|
||||||
|
|
||||||
|
PieceCountCounter()
|
||||||
|
{
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_position(const Position& pos) override
|
||||||
|
{
|
||||||
|
m_piece_count_hist[popcount(pos.pieces())] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() override
|
||||||
|
{
|
||||||
|
for (int i = 0; i < SQUARE_NB; ++i)
|
||||||
|
m_piece_count_hist[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||||
|
{
|
||||||
|
std::vector<std::pair<std::string, std::string>> result;
|
||||||
|
bool do_write = false;
|
||||||
|
for (int i = SQUARE_NB - 1; i >= 0; --i)
|
||||||
|
{
|
||||||
|
if (m_piece_count_hist[i] != 0)
|
||||||
|
do_write = true;
|
||||||
|
|
||||||
|
// Start writing when the first non-zero number pops up.
|
||||||
|
if (do_write)
|
||||||
|
{
|
||||||
|
result.emplace_back(
|
||||||
|
std::string("Number of positions with ") + std::to_string(i) + " pieces",
|
||||||
|
std::to_string(m_piece_count_hist[i])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::uint64_t m_piece_count_hist[SQUARE_NB];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MovedPieceTypeCounter : StatisticGathererBase
|
||||||
|
{
|
||||||
|
static inline std::string name = "MovedPieceTypeCounter";
|
||||||
|
|
||||||
|
MovedPieceTypeCounter()
|
||||||
|
{
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_move(const Position& pos, const Move& move) override
|
||||||
|
{
|
||||||
|
m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() override
|
||||||
|
{
|
||||||
|
for (int i = 0; i < PIECE_TYPE_NB; ++i)
|
||||||
|
m_moved_piece_type_hist[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] const std::string& get_name() const override
|
||||||
|
{
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::vector<std::pair<std::string, std::string>> get_formatted_stats() const override
|
||||||
|
{
|
||||||
|
return {
|
||||||
|
{ "Pawn moves", std::to_string(m_moved_piece_type_hist[PAWN]) },
|
||||||
|
{ "Knight moves", std::to_string(m_moved_piece_type_hist[KNIGHT]) },
|
||||||
|
{ "Bishop moves", std::to_string(m_moved_piece_type_hist[BISHOP]) },
|
||||||
|
{ "Rook moves", std::to_string(m_moved_piece_type_hist[ROOK]) },
|
||||||
|
{ "Queen moves", std::to_string(m_moved_piece_type_hist[QUEEN]) },
|
||||||
|
{ "King moves", std::to_string(m_moved_piece_type_hist[KING]) }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB];
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
This function provides factories for all possible statistic gatherers.
|
||||||
|
Each new statistic gatherer needs to be added there.
|
||||||
|
*/
|
||||||
const auto& get_statistics_gatherers_registry()
|
const auto& get_statistics_gatherers_registry()
|
||||||
{
|
{
|
||||||
static StatisticGathererRegistry s_reg = [](){
|
static StatisticGathererRegistry s_reg = [](){
|
||||||
@@ -109,6 +542,15 @@ namespace Learner::Stats
|
|||||||
|
|
||||||
reg.add<PositionCounter>("position_count");
|
reg.add<PositionCounter>("position_count");
|
||||||
|
|
||||||
|
reg.add<KingSquareCounter>("king", "king_square_count");
|
||||||
|
|
||||||
|
reg.add<MoveFromCounter>("move", "move_from_count");
|
||||||
|
reg.add<MoveToCounter>("move", "move_to_count");
|
||||||
|
reg.add<MoveTypeCounter>("move", "move_type");
|
||||||
|
reg.add<MovedPieceTypeCounter>("move", "moved_piece_type");
|
||||||
|
|
||||||
|
reg.add<PieceCountCounter>("piece_count");
|
||||||
|
|
||||||
return reg;
|
return reg;
|
||||||
}();
|
}();
|
||||||
|
|
||||||
@@ -117,7 +559,8 @@ namespace Learner::Stats
|
|||||||
|
|
||||||
void do_gather_statistics(
|
void do_gather_statistics(
|
||||||
const std::string& filename,
|
const std::string& filename,
|
||||||
std::vector<std::unique_ptr<StatisticGathererBase>>& statistic_gatherers)
|
StatisticGathererSet& statistic_gatherers,
|
||||||
|
std::uint64_t max_count)
|
||||||
{
|
{
|
||||||
Thread* th = Threads.main();
|
Thread* th = Threads.main();
|
||||||
Position& pos = th->rootPos;
|
Position& pos = th->rootPos;
|
||||||
@@ -125,18 +568,12 @@ namespace Learner::Stats
|
|||||||
|
|
||||||
auto in = Learner::open_sfen_input_file(filename);
|
auto in = Learner::open_sfen_input_file(filename);
|
||||||
|
|
||||||
auto on_move = [&](Move move) {
|
auto on_move = [&](const Position& position, const Move& move) {
|
||||||
for (auto&& s : statistic_gatherers)
|
statistic_gatherers.on_move(position, move);
|
||||||
{
|
|
||||||
s->on_move(move);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
auto on_position = [&](const Position& position) {
|
auto on_position = [&](const Position& position) {
|
||||||
for (auto&& s : statistic_gatherers)
|
statistic_gatherers.on_position(position);
|
||||||
{
|
|
||||||
s->on_position(position);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (in == nullptr)
|
if (in == nullptr)
|
||||||
@@ -146,7 +583,7 @@ namespace Learner::Stats
|
|||||||
}
|
}
|
||||||
|
|
||||||
uint64_t num_processed = 0;
|
uint64_t num_processed = 0;
|
||||||
for (;;)
|
while (num_processed < max_count)
|
||||||
{
|
{
|
||||||
auto v = in->next();
|
auto v = in->next();
|
||||||
if (!v.has_value())
|
if (!v.has_value())
|
||||||
@@ -157,7 +594,7 @@ namespace Learner::Stats
|
|||||||
pos.set_from_packed_sfen(ps.sfen, &si, th);
|
pos.set_from_packed_sfen(ps.sfen, &si, th);
|
||||||
|
|
||||||
on_position(pos);
|
on_position(pos);
|
||||||
on_move((Move)ps.move);
|
on_move(pos, (Move)ps.move);
|
||||||
|
|
||||||
num_processed += 1;
|
num_processed += 1;
|
||||||
if (num_processed % 1'000'000 == 0)
|
if (num_processed % 1'000'000 == 0)
|
||||||
@@ -169,13 +606,9 @@ namespace Learner::Stats
|
|||||||
std::cout << "Finished gathering statistics.\n\n";
|
std::cout << "Finished gathering statistics.\n\n";
|
||||||
std::cout << "Results:\n\n";
|
std::cout << "Results:\n\n";
|
||||||
|
|
||||||
for (auto&& s : statistic_gatherers)
|
for (auto&& [name, value] : statistic_gatherers.get_formatted_stats())
|
||||||
{
|
{
|
||||||
for (auto&& [name, value] : s->get_formatted_stats())
|
std::cout << name << ": " << value << '\n';
|
||||||
{
|
|
||||||
std::cout << name << ": " << value << '\n';
|
|
||||||
}
|
|
||||||
std::cout << '\n';
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,9 +618,10 @@ namespace Learner::Stats
|
|||||||
|
|
||||||
auto& registry = get_statistics_gatherers_registry();
|
auto& registry = get_statistics_gatherers_registry();
|
||||||
|
|
||||||
std::vector<std::unique_ptr<StatisticGathererBase>> statistic_gatherers;
|
StatisticGathererSet statistic_gatherers;
|
||||||
|
|
||||||
std::string input_file;
|
std::string input_file;
|
||||||
|
std::uint64_t max_count = std::numeric_limits<std::uint64_t>::max();
|
||||||
|
|
||||||
while(true)
|
while(true)
|
||||||
{
|
{
|
||||||
@@ -199,11 +633,13 @@ namespace Learner::Stats
|
|||||||
|
|
||||||
if (token == "input_file")
|
if (token == "input_file")
|
||||||
is >> input_file;
|
is >> input_file;
|
||||||
|
else if (token == "max_count")
|
||||||
|
is >> max_count;
|
||||||
else
|
else
|
||||||
registry.add_statistic_gatherers_by_group(statistic_gatherers, token);
|
registry.add_statistic_gatherers_by_group(statistic_gatherers, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
do_gather_statistics(input_file, statistic_gatherers);
|
do_gather_statistics(input_file, statistic_gatherers, max_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user