Deduplicate statistic gatherers. Fix King square counter compilation errors.

This commit is contained in:
Tomasz Sobczyk
2021-04-05 16:36:27 +02:00
parent eda51f19a2
commit b2a5bf4171
+106 -18
View File
@@ -14,6 +14,7 @@
#include <array> #include <array>
#include <string> #include <string>
#include <map> #include <map>
#include <set>
#include <iostream> #include <iostream>
#include <cmath> #include <cmath>
#include <algorithm> #include <algorithm>
@@ -31,12 +32,14 @@ namespace Learner::Stats
virtual void on_position(const Position&) {} virtual void on_position(const Position&) {}
virtual void on_move(const Move&) {} virtual void on_move(const Move&) {}
virtual void reset() = 0; virtual void reset() = 0;
[[nodiscard]] virtual const std::string& get_name() const = 0;
[[nodiscard]] virtual std::map<std::string, std::string> get_formatted_stats() const = 0; [[nodiscard]] virtual std::map<std::string, std::string> get_formatted_stats() const = 0;
}; };
struct StatisticGathererFactoryBase struct StatisticGathererFactoryBase
{ {
[[nodiscard]] virtual std::unique_ptr<StatisticGathererBase> create() const = 0; [[nodiscard]] virtual std::unique_ptr<StatisticGathererBase> create() const = 0;
[[nodiscard]] virtual const std::string& get_name() const = 0;
}; };
template <typename T> template <typename T>
@@ -46,12 +49,84 @@ namespace Learner::Stats
{ {
return std::make_unique<T>(); return std::make_unique<T>();
} }
[[nodiscard]] const std::string& get_name() const override
{
return T::name;
}
};
struct StatisticGathererSet : StatisticGathererBase
{
void add(const StatisticGathererFactoryBase& factory)
{
const std::string name = factory.get_name();
if (m_gatherers_names.count(name) == 0)
{
m_gatherers_names.insert(name);
m_gatherers.emplace_back(factory.create());
}
}
void add(std::unique_ptr<StatisticGathererBase>&& gatherer)
{
const std::string name = gatherer->get_name();
if (m_gatherers_names.count(name) == 0)
{
m_gatherers_names.insert(name);
m_gatherers.emplace_back(std::move(gatherer));
}
}
void on_position(const Position& position) override
{
for (auto& g : m_gatherers)
{
g->on_position(position);
}
}
void on_move(const Move& move) override
{
for (auto& g : m_gatherers)
{
g->on_move(move);
}
}
void reset() override
{
for (auto& g : m_gatherers)
{
g->reset();
}
}
[[nodiscard]] virtual const std::string& get_name() const override
{
static std::string name = "SET";
return name;
}
[[nodiscard]] virtual std::map<std::string, std::string> get_formatted_stats() const override
{
std::map<std::string, std::string> parts;
for (auto&& s : m_gatherers)
{
parts.merge(s->get_formatted_stats());
}
return parts;
}
private:
std::vector<std::unique_ptr<StatisticGathererBase>> m_gatherers;
std::set<std::string> m_gatherers_names;
}; };
struct StatisticGathererRegistry struct StatisticGathererRegistry
{ {
void add_statistic_gatherers_by_group( void add_statistic_gatherers_by_group(
std::vector<std::unique_ptr<StatisticGathererBase>>& gatherers, StatisticGathererSet& gatherers,
const std::string& group) const const std::string& group) const
{ {
auto it = m_gatherers_by_group.find(group); auto it = m_gatherers_by_group.find(group);
@@ -59,7 +134,7 @@ namespace Learner::Stats
{ {
for (auto& factory : it->second) for (auto& factory : it->second)
{ {
gatherers.emplace_back(factory->create()); gatherers.add(*factory);
} }
} }
} }
@@ -109,6 +184,7 @@ namespace Learner::Stats
if ((i + 1) % 8 == 0) if ((i + 1) % 8 == 0)
ss << '\n'; ss << '\n';
} }
return ss.str();
} }
private: private:
@@ -121,6 +197,8 @@ namespace Learner::Stats
struct PositionCounter : StatisticGathererBase struct PositionCounter : StatisticGathererBase
{ {
static inline std::string name = "PositionCounter";
PositionCounter() : PositionCounter() :
m_num_positions(0) m_num_positions(0)
{ {
@@ -136,6 +214,11 @@ namespace Learner::Stats
m_num_positions = 0; m_num_positions = 0;
} }
[[nodiscard]] const std::string& get_name() const override
{
return name;
}
[[nodiscard]] std::map<std::string, std::string> get_formatted_stats() const override [[nodiscard]] std::map<std::string, std::string> get_formatted_stats() const override
{ {
return { return {
@@ -149,6 +232,8 @@ namespace Learner::Stats
struct KingSquareCounter : StatisticGathererBase struct KingSquareCounter : StatisticGathererBase
{ {
static inline std::string name = "KingSquareCounter";
KingSquareCounter() : KingSquareCounter() :
m_white{}, m_white{},
m_black{} m_black{}
@@ -168,6 +253,19 @@ namespace Learner::Stats
m_black = StatPerSquare<std::uint64_t>{}; m_black = StatPerSquare<std::uint64_t>{};
} }
[[nodiscard]] const std::string& get_name() const override
{
return name;
}
[[nodiscard]] std::map<std::string, std::string> get_formatted_stats() const override
{
return {
{ "White king squares", m_white.get_formatted_stats() },
{ "Black king squares", m_black.get_formatted_stats() }
};
}
private: private:
StatPerSquare<std::uint64_t> m_white; StatPerSquare<std::uint64_t> m_white;
StatPerSquare<std::uint64_t> m_black; StatPerSquare<std::uint64_t> m_black;
@@ -195,7 +293,7 @@ namespace Learner::Stats
void do_gather_statistics( void do_gather_statistics(
const std::string& filename, const std::string& filename,
std::vector<std::unique_ptr<StatisticGathererBase>>& statistic_gatherers, StatisticGathererSet& statistic_gatherers,
std::uint64_t max_count) std::uint64_t max_count)
{ {
Thread* th = Threads.main(); Thread* th = Threads.main();
@@ -205,17 +303,11 @@ namespace Learner::Stats
auto in = Learner::open_sfen_input_file(filename); auto in = Learner::open_sfen_input_file(filename);
auto on_move = [&](Move move) { auto on_move = [&](Move move) {
for (auto&& s : statistic_gatherers) statistic_gatherers.on_move(move);
{
s->on_move(move);
}
}; };
auto on_position = [&](const Position& position) { auto on_position = [&](const Position& position) {
for (auto&& s : statistic_gatherers) statistic_gatherers.on_position(position);
{
s->on_position(position);
}
}; };
if (in == nullptr) if (in == nullptr)
@@ -248,13 +340,9 @@ namespace Learner::Stats
std::cout << "Finished gathering statistics.\n\n"; std::cout << "Finished gathering statistics.\n\n";
std::cout << "Results:\n\n"; std::cout << "Results:\n\n";
for (auto&& s : statistic_gatherers) for (auto&& [name, value] : statistic_gatherers.get_formatted_stats())
{ {
for (auto&& [name, value] : s->get_formatted_stats()) std::cout << name << ": " << value << '\n';
{
std::cout << name << ": " << value << '\n';
}
std::cout << '\n';
} }
} }
@@ -264,7 +352,7 @@ namespace Learner::Stats
auto& registry = get_statistics_gatherers_registry(); auto& registry = get_statistics_gatherers_registry();
std::vector<std::unique_ptr<StatisticGathererBase>> statistic_gatherers; StatisticGathererSet statistic_gatherers;
std::string input_file; std::string input_file;
std::uint64_t max_count = std::numeric_limits<std::uint64_t>::max(); std::uint64_t max_count = std::numeric_limits<std::uint64_t>::max();