mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 12:07:43 +00:00
Add gather_statistics command that allows gathering statistics from a .bin or .binpack file. Initially only support position count.
This commit is contained in:
@@ -0,0 +1,15 @@
|
|||||||
|
# Stats
|
||||||
|
|
||||||
|
`gather_statistics` command allows gathering various statistics from a .bin or a .binpack file. The syntax is `gather_statistics (GROUP)* input_file FILENAME`. There can be many groups specified. Any statistic gatherer that belongs to at least one of the specified groups will be used.
|
||||||
|
|
||||||
|
Simplest usage: `stockfish.exe gather_statistics all input_file a.binpack`
|
||||||
|
|
||||||
|
## Groups
|
||||||
|
|
||||||
|
`all`
|
||||||
|
|
||||||
|
- A special group designating all statistics gatherers available.
|
||||||
|
|
||||||
|
`position_count`
|
||||||
|
|
||||||
|
- `struct PositionCounter` - the total number of positions in the file.
|
||||||
+2
-1
@@ -66,7 +66,8 @@ SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp
|
|||||||
learn/gensfen_nonpv.cpp \
|
learn/gensfen_nonpv.cpp \
|
||||||
learn/opening_book.cpp \
|
learn/opening_book.cpp \
|
||||||
learn/convert.cpp \
|
learn/convert.cpp \
|
||||||
learn/transform.cpp
|
learn/transform.cpp \
|
||||||
|
learn/stats.cpp
|
||||||
|
|
||||||
OBJS = $(notdir $(SRCS:.cpp=.o))
|
OBJS = $(notdir $(SRCS:.cpp=.o))
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,209 @@
|
|||||||
|
#include "stats.h"
|
||||||
|
|
||||||
|
#include "sfen_stream.h"
|
||||||
|
#include "packed_sfen.h"
|
||||||
|
#include "sfen_writer.h"
|
||||||
|
|
||||||
|
#include "thread.h"
|
||||||
|
#include "position.h"
|
||||||
|
#include "evaluate.h"
|
||||||
|
#include "search.h"
|
||||||
|
|
||||||
|
#include "nnue/evaluate_nnue.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <iostream>
|
||||||
|
#include <cmath>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <limits>
|
||||||
|
#include <mutex>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
namespace Learner::Stats
|
||||||
|
{
|
||||||
|
struct StatisticGathererBase
|
||||||
|
{
|
||||||
|
virtual void on_position(const Position&) {}
|
||||||
|
virtual void on_move(const Move&) {}
|
||||||
|
virtual void reset() = 0;
|
||||||
|
[[nodiscard]] virtual std::map<std::string, std::string> get_formatted_stats() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PositionCounter : StatisticGathererBase
|
||||||
|
{
|
||||||
|
PositionCounter() :
|
||||||
|
m_num_positions(0)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_position(const Position&) override
|
||||||
|
{
|
||||||
|
m_num_positions += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() override
|
||||||
|
{
|
||||||
|
m_num_positions = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::map<std::string, std::string> get_formatted_stats() const override
|
||||||
|
{
|
||||||
|
return {
|
||||||
|
{ "Number of positions", std::to_string(m_num_positions) }
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::uint64_t m_num_positions;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StatisticGathererFactoryBase
|
||||||
|
{
|
||||||
|
[[nodiscard]] virtual std::unique_ptr<StatisticGathererBase> create() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct StatisticGathererFactory : StatisticGathererFactoryBase
|
||||||
|
{
|
||||||
|
[[nodiscard]] std::unique_ptr<StatisticGathererBase> create() const override
|
||||||
|
{
|
||||||
|
return std::make_unique<T>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StatisticGathererRegistry
|
||||||
|
{
|
||||||
|
void add_statistic_gatherers_by_group(
|
||||||
|
std::vector<std::unique_ptr<StatisticGathererBase>>& gatherers,
|
||||||
|
const std::string& group) const
|
||||||
|
{
|
||||||
|
auto it = m_gatherers_by_group.find(group);
|
||||||
|
if (it != m_gatherers_by_group.end())
|
||||||
|
{
|
||||||
|
for (auto& factory : it->second)
|
||||||
|
{
|
||||||
|
gatherers.emplace_back(factory->create());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void add(const std::string& group)
|
||||||
|
{
|
||||||
|
m_gatherers_by_group[group].emplace_back(std::make_unique<StatisticGathererFactory<T>>());
|
||||||
|
|
||||||
|
// Always add to the special group "all".
|
||||||
|
m_gatherers_by_group["all"].emplace_back(std::make_unique<StatisticGathererFactory<T>>());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::map<std::string, std::vector<std::unique_ptr<StatisticGathererFactoryBase>>> m_gatherers_by_group;
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto& get_statistics_gatherers_registry()
|
||||||
|
{
|
||||||
|
static StatisticGathererRegistry s_reg = [](){
|
||||||
|
StatisticGathererRegistry reg;
|
||||||
|
|
||||||
|
reg.add<PositionCounter>("position_count");
|
||||||
|
|
||||||
|
return reg;
|
||||||
|
}();
|
||||||
|
|
||||||
|
return s_reg;
|
||||||
|
}
|
||||||
|
|
||||||
|
void do_gather_statistics(
|
||||||
|
const std::string& filename,
|
||||||
|
std::vector<std::unique_ptr<StatisticGathererBase>>& statistic_gatherers)
|
||||||
|
{
|
||||||
|
Thread* th = Threads.main();
|
||||||
|
Position& pos = th->rootPos;
|
||||||
|
StateInfo si;
|
||||||
|
|
||||||
|
auto in = Learner::open_sfen_input_file(filename);
|
||||||
|
|
||||||
|
auto on_move = [&](Move move) {
|
||||||
|
for (auto&& s : statistic_gatherers)
|
||||||
|
{
|
||||||
|
s->on_move(move);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto on_position = [&](const Position& position) {
|
||||||
|
for (auto&& s : statistic_gatherers)
|
||||||
|
{
|
||||||
|
s->on_position(position);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (in == nullptr)
|
||||||
|
{
|
||||||
|
std::cerr << "Invalid input file type.\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t num_processed = 0;
|
||||||
|
for (;;)
|
||||||
|
{
|
||||||
|
auto v = in->next();
|
||||||
|
if (!v.has_value())
|
||||||
|
break;
|
||||||
|
|
||||||
|
auto& ps = v.value();
|
||||||
|
|
||||||
|
pos.set_from_packed_sfen(ps.sfen, &si, th);
|
||||||
|
|
||||||
|
on_position(pos);
|
||||||
|
on_move((Move)ps.move);
|
||||||
|
|
||||||
|
num_processed += 1;
|
||||||
|
if (num_processed % 1'000'000 == 0)
|
||||||
|
{
|
||||||
|
std::cout << "Processed " << num_processed << " positions.\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "Finished gathering statistics.\n\n";
|
||||||
|
std::cout << "Results:\n\n";
|
||||||
|
|
||||||
|
for (auto&& s : statistic_gatherers)
|
||||||
|
{
|
||||||
|
for (auto&& [name, value] : s->get_formatted_stats())
|
||||||
|
{
|
||||||
|
std::cout << name << ": " << value << '\n';
|
||||||
|
}
|
||||||
|
std::cout << '\n';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void gather_statistics(std::istringstream& is)
|
||||||
|
{
|
||||||
|
Eval::NNUE::init();
|
||||||
|
|
||||||
|
auto& registry = get_statistics_gatherers_registry();
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<StatisticGathererBase>> statistic_gatherers;
|
||||||
|
|
||||||
|
std::string input_file;
|
||||||
|
|
||||||
|
while(true)
|
||||||
|
{
|
||||||
|
std::string token;
|
||||||
|
is >> token;
|
||||||
|
|
||||||
|
if (token == "")
|
||||||
|
break;
|
||||||
|
|
||||||
|
if (token == "input_file")
|
||||||
|
is >> input_file;
|
||||||
|
else
|
||||||
|
registry.add_statistic_gatherers_by_group(statistic_gatherers, token);
|
||||||
|
}
|
||||||
|
|
||||||
|
do_gather_statistics(input_file, statistic_gatherers);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
#ifndef _STATS_H_
|
||||||
|
#define _STATS_H_
|
||||||
|
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace Learner::Stats {
|
||||||
|
|
||||||
|
void gather_statistics(std::istringstream& is);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -40,6 +40,7 @@
|
|||||||
#include "learn/learn.h"
|
#include "learn/learn.h"
|
||||||
#include "learn/convert.h"
|
#include "learn/convert.h"
|
||||||
#include "learn/transform.h"
|
#include "learn/transform.h"
|
||||||
|
#include "learn/stats.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
@@ -349,6 +350,7 @@ void UCI::loop(int argc, char* argv[]) {
|
|||||||
else if (token == "convert_plain") Learner::convert_plain(is);
|
else if (token == "convert_plain") Learner::convert_plain(is);
|
||||||
else if (token == "convert_bin_from_pgn_extract") Learner::convert_bin_from_pgn_extract(is);
|
else if (token == "convert_bin_from_pgn_extract") Learner::convert_bin_from_pgn_extract(is);
|
||||||
else if (token == "transform") Learner::transform(is);
|
else if (token == "transform") Learner::transform(is);
|
||||||
|
else if (token == "gather_statistics") Learner::Stats::gather_statistics(is);
|
||||||
|
|
||||||
// Command to call qsearch(),search() directly for testing
|
// Command to call qsearch(),search() directly for testing
|
||||||
else if (token == "qsearch") qsearch_cmd(pos);
|
else if (token == "qsearch") qsearch_cmd(pos);
|
||||||
|
|||||||
Reference in New Issue
Block a user