#ifndef _NNUE_TRAINER_H_ #define _NNUE_TRAINER_H_ #include "nnue/nnue_common.h" #include "nnue/features/index_list.h" #include #if defined(USE_BLAS) static_assert(std::is_same::value, ""); #include #endif // Common header of class template for learning NNUE evaluation function namespace Eval::NNUE { // Ponanza constant used in the relation between evaluation value and winning percentage constexpr double kPonanzaConstant = 600.0; // Class that represents one index of learning feature class TrainingFeature { using StorageType = std::uint32_t; static_assert(std::is_unsigned::value, ""); public: static constexpr std::uint32_t kIndexBits = 24; static_assert(kIndexBits < std::numeric_limits::digits, ""); static constexpr std::uint32_t kCountBits = std::numeric_limits::digits - kIndexBits; explicit TrainingFeature(IndexType index) : index_and_count_((index << kCountBits) | 1) { assert(index < (1 << kIndexBits)); } TrainingFeature& operator+=(const TrainingFeature& other) { assert(other.get_index() == get_index()); assert(other.get_count() + get_count() < (1 << kCountBits)); index_and_count_ += other.get_count(); return *this; } IndexType get_index() const { return static_cast(index_and_count_ >> kCountBits); } void shift_index(IndexType offset) { assert(get_index() + offset < (1 << kIndexBits)); index_and_count_ += offset << kCountBits; } IndexType get_count() const { return static_cast(index_and_count_ & ((1 << kCountBits) - 1)); } bool operator<(const TrainingFeature& other) const { return index_and_count_ < other.index_and_count_; } private: StorageType index_and_count_; }; // Structure that represents one sample of training data struct Example { std::vector training_features[2]; Learner::PackedSfenValue psv; Value discrete_nn_eval; int sign; double weight; }; // Message used for setting hyperparameters struct Message { Message(const std::string& message_name, const std::string& message_value = "") : name(message_name), value(message_value), num_peekers(0), num_receivers(0) { } const std::string name; const std::string value; std::uint32_t num_peekers; std::uint32_t num_receivers; }; // determine whether to accept the message bool receive_message(const std::string& name, Message* message) { const auto subscript = "[" + std::to_string(message->num_peekers) + "]"; if (message->name.substr(0, name.size() + 1) == name + "[") { ++message->num_peekers; } if (message->name == name || message->name == name + subscript) { ++message->num_receivers; return true; } return false; } // round a floating point number to an integer template IntType round(double value) { return static_cast(std::floor(value + 0.5)); } // make_shared with alignment template std::shared_ptr make_aligned_shared_ptr(ArgumentTypes&&... arguments) { const auto ptr = new(std_aligned_alloc(alignof(T), sizeof(T))) T(std::forward(arguments)...); return std::shared_ptr(ptr, AlignedDeleter()); } } // namespace Eval::NNUE #endif