Add support for the binpack format to the learner's SfenReader. The reader (bin or binpack) is now chosen automatically from the file extension.

This commit is contained in:
Tomasz Sobczyk
2020-09-10 00:31:38 +02:00
committed by nodchip
parent 020e66d2e6
commit 6b76ebc2ca
3 changed files with 328 additions and 304 deletions
+65 -166
View File
@@ -2745,9 +2745,6 @@ namespace chess
0x000200282410A102ull, 0x000200282410A102ull,
0x4048240043802106ull 0x4048240043802106ull
} }; } };
alignas(64) extern EnumArray<Square, Bitboard> g_rookMasks;
alignas(64) extern EnumArray<Square, std::uint8_t> g_rookShifts;
alignas(64) extern EnumArray<Square, const Bitboard*> g_rookAttacks;
alignas(64) constexpr EnumArray<Square, std::uint64_t> g_bishopMagics{ { alignas(64) constexpr EnumArray<Square, std::uint64_t> g_bishopMagics{ {
0x40106000A1160020ull, 0x40106000A1160020ull,
@@ -2815,9 +2812,17 @@ namespace chess
0x0300404822C08200ull, 0x0300404822C08200ull,
0x48081010008A2A80ull 0x48081010008A2A80ull
} }; } };
alignas(64) extern EnumArray<Square, Bitboard> g_bishopMasks;
alignas(64) extern EnumArray<Square, std::uint8_t> g_bishopShifts; alignas(64) static EnumArray<Square, Bitboard> g_rookMasks;
alignas(64) extern EnumArray<Square, const Bitboard*> g_bishopAttacks; alignas(64) static EnumArray<Square, std::uint8_t> g_rookShifts;
alignas(64) static EnumArray<Square, const Bitboard*> g_rookAttacks;
alignas(64) static EnumArray<Square, Bitboard> g_bishopMasks;
alignas(64) static EnumArray<Square, std::uint8_t> g_bishopShifts;
alignas(64) static EnumArray<Square, const Bitboard*> g_bishopAttacks;
alignas(64) static std::array<Bitboard, 102400> g_allRookAttacks;
alignas(64) static std::array<Bitboard, 5248> g_allBishopAttacks;
inline Bitboard bishopAttacks(Square s, Bitboard occupied) inline Bitboard bishopAttacks(Square s, Bitboard occupied)
{ {
@@ -3402,17 +3407,6 @@ namespace chess
Bishop Bishop
}; };
alignas(64) EnumArray<Square, Bitboard> g_rookMasks;
alignas(64) EnumArray<Square, std::uint8_t> g_rookShifts;
alignas(64) EnumArray<Square, const Bitboard*> g_rookAttacks;
alignas(64) EnumArray<Square, Bitboard> g_bishopMasks;
alignas(64) EnumArray<Square, std::uint8_t> g_bishopShifts;
alignas(64) EnumArray<Square, const Bitboard*> g_bishopAttacks;
alignas(64) static std::array<Bitboard, 102400> g_allRookAttacks;
alignas(64) static std::array<Bitboard, 5248> g_allBishopAttacks;
template <MagicsType TypeV> template <MagicsType TypeV>
[[nodiscard]] inline Bitboard slidingAttacks(Square sq, Bitboard occupied) [[nodiscard]] inline Bitboard slidingAttacks(Square sq, Bitboard occupied)
{ {
@@ -3857,7 +3851,7 @@ namespace chess
return true; return true;
} }
[[nodiscard]] std::string fen() const; [[nodiscard]] inline std::string fen() const;
[[nodiscard]] inline bool trySet(std::string_view boardState) [[nodiscard]] inline bool trySet(std::string_view boardState)
{ {
@@ -4093,7 +4087,7 @@ namespace chess
// returns captured piece // returns captured piece
// doesn't check validity // doesn't check validity
FORCEINLINE constexpr Piece doMove(Move move) inline constexpr Piece doMove(Move move)
{ {
if (move.type == MoveType::Normal) if (move.type == MoveType::Normal)
{ {
@@ -4132,7 +4126,7 @@ namespace chess
return doMoveColdPath(move); return doMoveColdPath(move);
} }
NOINLINE constexpr Piece doMoveColdPath(Move move) inline constexpr Piece doMoveColdPath(Move move)
{ {
if (move.type == MoveType::Promotion) if (move.type == MoveType::Promotion)
{ {
@@ -4333,38 +4327,38 @@ namespace chess
// Returns whether a given square is attacked by any piece // Returns whether a given square is attacked by any piece
// of `attackerColor` side. // of `attackerColor` side.
[[nodiscard]] bool isSquareAttacked(Square sq, Color attackerColor) const; [[nodiscard]] inline bool isSquareAttacked(Square sq, Color attackerColor) const;
// Returns whether a given square is attacked by any piece // Returns whether a given square is attacked by any piece
// of `attackerColor` side after `move` is made. // of `attackerColor` side after `move` is made.
// Move must be pseudo legal. // Move must be pseudo legal.
[[nodiscard]] bool isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const; [[nodiscard]] inline bool isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const;
// Move must be pseudo legal. // Move must be pseudo legal.
// Must not be a king move. // Must not be a king move.
[[nodiscard]] bool createsDiscoveredAttackOnOwnKing(Move move) const; [[nodiscard]] inline bool createsDiscoveredAttackOnOwnKing(Move move) const;
// Returns whether a piece on a given square is attacked // Returns whether a piece on a given square is attacked
// by any enemy piece. False if square is empty. // by any enemy piece. False if square is empty.
[[nodiscard]] bool isPieceAttacked(Square sq) const; [[nodiscard]] inline bool isPieceAttacked(Square sq) const;
// Returns whether a piece on a given square is attacked // Returns whether a piece on a given square is attacked
// by any enemy piece after `move` is made. False if square is empty. // by any enemy piece after `move` is made. False if square is empty.
// Move must be pseudo legal. // Move must be pseudo legal.
[[nodiscard]] bool isPieceAttackedAfterMove(Move move, Square sq) const; [[nodiscard]] inline bool isPieceAttackedAfterMove(Move move, Square sq) const;
// Returns whether the king of the moving side is attacked // Returns whether the king of the moving side is attacked
// by any enemy piece after a move is made. // by any enemy piece after a move is made.
// Move must be pseudo legal. // Move must be pseudo legal.
[[nodiscard]] bool isOwnKingAttackedAfterMove(Move move) const; [[nodiscard]] inline bool isOwnKingAttackedAfterMove(Move move) const;
// Return a bitboard with all (pseudo legal) attacks by the piece on // Return a bitboard with all (pseudo legal) attacks by the piece on
// the given square. Empty if no piece on the square. // the given square. Empty if no piece on the square.
[[nodiscard]] Bitboard attacks(Square sq) const; [[nodiscard]] inline Bitboard attacks(Square sq) const;
// Returns a bitboard with all squares that have pieces // Returns a bitboard with all squares that have pieces
// that attack a given square (pseudo legally) // that attack a given square (pseudo legally)
[[nodiscard]] Bitboard attackers(Square sq, Color attackerColor) const; [[nodiscard]] inline Bitboard attackers(Square sq, Color attackerColor) const;
[[nodiscard]] constexpr Piece pieceAt(Square sq) const [[nodiscard]] constexpr Piece pieceAt(Square sq) const
{ {
@@ -4438,20 +4432,6 @@ namespace chess
struct Position; struct Position;
// Caches check-related data for one position so that the legality of many
// pseudo-legal moves can be decided without recomputing it each time.
// The constructor keeps a pointer to the position it is given and
// isPseudoLegalMoveLegal dereferences it, so that position has to stay alive
// for as long as the checker is used.
struct MoveLegalityChecker
{
// Captures the checkers, king blockers and king square of the side to move.
MoveLegalityChecker(const Position& position);
// Returns whether `move` is legal in the captured position.
// NOTE(review): the name implies `move` must already be pseudo legal - confirm with callers.
[[nodiscard]] bool isPseudoLegalMoveLegal(const Move& move) const;
private:
// Position the checker was built from (not owned).
const Position* m_position;
// Result of position.checkers() at construction time.
Bitboard m_checkers;
// blockersForKing(side to move) restricted to the side to move's own pieces.
Bitboard m_ourBlockersForKing;
// With exactly one checker: the squares a non-king move may land on to
// answer the check. Otherwise set to Bitboard::none().
Bitboard m_potentialCheckRemovals;
// King square of the side to move.
Square m_ksq;
};
struct CompressedPosition; struct CompressedPosition;
struct PositionHash128 struct PositionHash128
@@ -4484,20 +4464,20 @@ namespace chess
{ {
} }
void set(std::string_view fen); inline void set(std::string_view fen);
// Returns false if the fen was not valid // Returns false if the fen was not valid
// If the returned value was false the position // If the returned value was false the position
// is in unspecified state. // is in unspecified state.
[[nodiscard]] bool trySet(std::string_view fen); [[nodiscard]] inline bool trySet(std::string_view fen);
[[nodiscard]] static Position fromFen(std::string_view fen); [[nodiscard]] static inline Position fromFen(std::string_view fen);
[[nodiscard]] static std::optional<Position> tryFromFen(std::string_view fen); [[nodiscard]] static inline std::optional<Position> tryFromFen(std::string_view fen);
[[nodiscard]] static Position startPosition(); [[nodiscard]] static inline Position startPosition();
[[nodiscard]] std::string fen() const; [[nodiscard]] inline std::string fen() const;
constexpr void setEpSquareUnchecked(Square sq) constexpr void setEpSquareUnchecked(Square sq)
{ {
@@ -4535,7 +4515,7 @@ namespace chess
m_ply = ply; m_ply = ply;
} }
ReverseMove doMove(const Move& move); inline ReverseMove doMove(const Move& move);
constexpr void undoMove(const ReverseMove& reverseMove) constexpr void undoMove(const ReverseMove& reverseMove)
{ {
@@ -4559,49 +4539,44 @@ namespace chess
return m_sideToMove; return m_sideToMove;
} }
[[nodiscard]] std::uint8_t rule50Counter() const [[nodiscard]] inline std::uint8_t rule50Counter() const
{ {
return m_rule50Counter; return m_rule50Counter;
} }
[[nodiscard]] std::uint16_t ply() const [[nodiscard]] inline std::uint16_t ply() const
{ {
return m_ply; return m_ply;
} }
[[nodiscard]] std::uint16_t halfMove() const [[nodiscard]] inline std::uint16_t halfMove() const
{ {
return (m_ply + 1) / 2; return (m_ply + 1) / 2;
} }
void setHalfMove(std::uint16_t hm) inline void setHalfMove(std::uint16_t hm)
{ {
m_ply = 2 * hm - 1 + (m_sideToMove == Color::Black); m_ply = 2 * hm - 1 + (m_sideToMove == Color::Black);
} }
[[nodiscard]] bool isCheck() const; [[nodiscard]] inline bool isCheck() const;
[[nodiscard]] Bitboard checkers() const; [[nodiscard]] inline Bitboard checkers() const;
[[nodiscard]] bool isCheckAfterMove(Move move) const; [[nodiscard]] inline bool isCheckAfterMove(Move move) const;
// Checks whether ANY `move` is legal. // Checks whether ANY `move` is legal.
[[nodiscard]] bool isMoveLegal(Move move) const; [[nodiscard]] inline bool isMoveLegal(Move move) const;
[[nodiscard]] bool isPseudoLegalMoveLegal(Move move) const; [[nodiscard]] inline bool isPseudoLegalMoveLegal(Move move) const;
[[nodiscard]] bool isMovePseudoLegal(Move move) const; [[nodiscard]] inline bool isMovePseudoLegal(Move move) const;
// Returns all pieces that block a slider // Returns all pieces that block a slider
// from attacking our king. When two or more // from attacking our king. When two or more
// pieces block a single slider then none // pieces block a single slider then none
// of these pieces are included. // of these pieces are included.
[[nodiscard]] Bitboard blockersForKing(Color color) const; [[nodiscard]] inline Bitboard blockersForKing(Color color) const;
// Builds a MoveLegalityChecker from this position. The checker keeps a
// pointer to this object, so it must not be used after the position is gone.
[[nodiscard]] MoveLegalityChecker moveLegalityChecker() const
{
    return MoveLegalityChecker{ *this };
}
[[nodiscard]] constexpr Square epSquare() const [[nodiscard]] constexpr Square epSquare() const
{ {
@@ -4637,7 +4612,7 @@ namespace chess
return cpy; return cpy;
} }
[[nodiscard]] Position afterMove(Move move) const; [[nodiscard]] inline Position afterMove(Move move) const;
[[nodiscard]] constexpr bool isEpPossible() const [[nodiscard]] constexpr bool isEpPossible() const
{ {
@@ -4655,11 +4630,11 @@ namespace chess
static_assert(sizeof(Color) + sizeof(Square) + sizeof(CastlingRights) + sizeof(std::uint8_t) == 4); static_assert(sizeof(Color) + sizeof(Square) + sizeof(CastlingRights) + sizeof(std::uint8_t) == 4);
[[nodiscard]] FORCEINLINE bool isEpPossible(Square epSquare, Color sideToMove) const; [[nodiscard]] inline bool isEpPossible(Square epSquare, Color sideToMove) const;
[[nodiscard]] NOINLINE bool isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const; [[nodiscard]] inline bool isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const;
void nullifyEpSquareIfNotPossible(); inline void nullifyEpSquareIfNotPossible();
}; };
struct CompressedPosition struct CompressedPosition
@@ -5302,7 +5277,7 @@ namespace chess
return allAttackers; return allAttackers;
} }
const Piece* Board::piecesRaw() const inline const Piece* Board::piecesRaw() const
{ {
return m_pieces.data(); return m_pieces.data();
} }
@@ -5330,7 +5305,7 @@ namespace chess
}(); }();
} }
[[nodiscard]] std::string Board::fen() const [[nodiscard]] inline std::string Board::fen() const
{ {
std::string fen; std::string fen;
fen.reserve(96); // longest fen is probably in range of around 88 fen.reserve(96); // longest fen is probably in range of around 88
@@ -5382,79 +5357,6 @@ namespace chess
return fen; return fen;
} }
// Captures, for the side to move in `position`: the pieces giving check,
// that side's own pieces among the king blockers, and the king square.
// Then derives the set of squares a non-king move may land on to answer
// a single check.
MoveLegalityChecker::MoveLegalityChecker(const Position& position) :
    m_position(&position),
    m_checkers(position.checkers()),
    m_ourBlockersForKing(
        position.blockersForKing(position.sideToMove())
        & position.piecesBB(position.sideToMove())
    ),
    m_ksq(position.kingSquare(position.sideToMove()))
{
    if (!m_checkers.exactlyOne())
    {
        // Not exactly one checker. For a double check only the king can
        // move, so no square removes the check. This branch is also taken
        // when there is no check at all; the value is not read in that case.
        m_potentialCheckRemovals = Bitboard::none();
        return;
    }

    const Bitboard checkingKnights = m_checkers & bb::pseudoAttacks<PieceType::Knight>(m_ksq);
    if (checkingKnights.any())
    {
        // A knight check cannot be blocked: capture the knight or move the king.
        m_potentialCheckRemovals = checkingKnights;
    }
    else
    {
        // Any other single check can be blocked on the squares between the
        // king and the checker, or answered by capturing the checker.
        m_potentialCheckRemovals = bb::between(m_ksq, m_checkers.first()) | m_checkers;
    }
}
// Decides whether `move` is legal in the position this checker was built
// from, using the data cached by the constructor.
//
// King moves (and en passant captures while in check) are delegated to
// Position::isPseudoLegalMoveLegal; everything else is answered from the
// cached bitboards.
[[nodiscard]] bool MoveLegalityChecker::isPseudoLegalMoveLegal(const Move& move) const
{
    const bool isKingMove = move.from == m_ksq;
    const bool isEnPassant = move.type == MoveType::EnPassant;

    if (m_checkers.any())
    {
        if (isKingMove || isEnPassant)
        {
            return m_position->isPseudoLegalMoveLegal(move);
        }

        // This means there's only one check and we either
        // blocked it or removed the piece that attacked
        // our king. So the only threat is if it's a discovered check.
        return
            m_potentialCheckRemovals.isSet(move.to)
            && !m_ourBlockersForKing.isSet(move.from);
    }

    if (isKingMove)
    {
        return m_position->isPseudoLegalMoveLegal(move);
    }

    if (isEnPassant)
    {
        return !m_position->createsDiscoveredAttackOnOwnKing(move);
    }

    if (m_ourBlockersForKing.isSet(move.from))
    {
        // If it was a blocker it may have only moved in line with our king.
        // Otherwise it's a discovered check.
        return bb::line(m_ksq, move.from).isSet(move.to);
    }

    return true;
}
void Position::set(std::string_view fen) void Position::set(std::string_view fen)
{ {
(void)trySet(fen); (void)trySet(fen);
@@ -5611,7 +5513,7 @@ namespace chess
}(); }();
} }
ReverseMove Position::doMove(const Move& move) inline ReverseMove Position::doMove(const Move& move)
{ {
assert(move.from.isOk() && move.to.isOk()); assert(move.from.isOk() && move.to.isOk());
@@ -5649,12 +5551,12 @@ namespace chess
return { move, captured, oldEpSquare, oldCastlingRights }; return { move, captured, oldEpSquare, oldCastlingRights };
} }
[[nodiscard]] bool Position::isCheck() const [[nodiscard]] inline bool Position::isCheck() const
{ {
return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove); return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove);
} }
[[nodiscard]] Bitboard Position::checkers() const [[nodiscard]] inline Bitboard Position::checkers() const
{ {
return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove); return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove);
} }
@@ -5664,7 +5566,7 @@ namespace chess
return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove); return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove);
} }
[[nodiscard]] Bitboard Position::blockersForKing(Color color) const [[nodiscard]] inline Bitboard Position::blockersForKing(Color color) const
{ {
const Color attackerColor = !color; const Color attackerColor = !color;
@@ -5700,7 +5602,7 @@ namespace chess
return allBlockers; return allBlockers;
} }
[[nodiscard]] Position Position::afterMove(Move move) const [[nodiscard]] inline Position Position::afterMove(Move move) const
{ {
Position cpy(*this); Position cpy(*this);
auto pc = cpy.doMove(move); auto pc = cpy.doMove(move);
@@ -5711,7 +5613,7 @@ namespace chess
return cpy; return cpy;
} }
[[nodiscard]] FORCEINLINE bool Position::isEpPossible(Square epSquare, Color sideToMove) const [[nodiscard]] inline bool Position::isEpPossible(Square epSquare, Color sideToMove) const
{ {
const Bitboard pawnsAttackingEpSquare = const Bitboard pawnsAttackingEpSquare =
bb::pawnAttacks(Bitboard::square(epSquare), !sideToMove) bb::pawnAttacks(Bitboard::square(epSquare), !sideToMove)
@@ -5725,7 +5627,7 @@ namespace chess
return isEpPossibleColdPath(epSquare, pawnsAttackingEpSquare, sideToMove); return isEpPossibleColdPath(epSquare, pawnsAttackingEpSquare, sideToMove);
} }
[[nodiscard]] NOINLINE bool Position::isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const [[nodiscard]] inline bool Position::isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const
{ {
// only set m_epSquare when it matters, ie. when // only set m_epSquare when it matters, ie. when
// the opposite side can actually capture // the opposite side can actually capture
@@ -5772,7 +5674,7 @@ namespace chess
return false; return false;
} }
void Position::nullifyEpSquareIfNotPossible() inline void Position::nullifyEpSquareIfNotPossible()
{ {
if (m_epSquare != Square::none() && !isEpPossible(m_epSquare, m_sideToMove)) if (m_epSquare != Square::none() && !isEpPossible(m_epSquare, m_sideToMove))
{ {
@@ -5782,12 +5684,12 @@ namespace chess
namespace uci namespace uci
{ {
[[nodiscard]] std::string moveToUci(const Position& pos, const Move& move); [[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move);
[[nodiscard]] Move uciToMove(const Position& pos, std::string_view sv); [[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv);
[[nodiscard]] std::optional<Move> tryUciToMove(const Position& pos, std::string_view sv); [[nodiscard]] inline std::optional<Move> tryUciToMove(const Position& pos, std::string_view sv);
[[nodiscard]] std::string moveToUci(const Position& pos, const Move& move) [[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move)
{ {
std::string s; std::string s;
@@ -5814,7 +5716,7 @@ namespace chess
return s; return s;
} }
[[nodiscard]] Move uciToMove(const Position& pos, std::string_view sv) [[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv)
{ {
const Square from = parser_bits::parseSquare(sv.data()); const Square from = parser_bits::parseSquare(sv.data());
const Square to = parser_bits::parseSquare(sv.data() + 2); const Square to = parser_bits::parseSquare(sv.data() + 2);
@@ -5850,7 +5752,7 @@ namespace chess
} }
} }
[[nodiscard]] std::optional<Move> tryUciToMove(const Position& pos, std::string_view sv) [[nodiscard]] inline std::optional<Move> tryUciToMove(const Position& pos, std::string_view sv)
{ {
if (sv.size() < 4 || sv.size() > 5) if (sv.size() < 4 || sv.size() > 5)
{ {
@@ -6300,7 +6202,7 @@ namespace binpack
}; };
[[nodiscard]] chess::Position pos_from_packed_sfen(const PackedSfen& sfen) [[nodiscard]] inline chess::Position pos_from_packed_sfen(const PackedSfen& sfen)
{ {
SfenPacker packer; SfenPacker packer;
auto& stream = packer.stream; auto& stream = packer.stream;
@@ -6655,14 +6557,12 @@ namespace binpack
if (!occupied.isSet(sqForward)) if (!occupied.isSet(sqForward))
{ {
destinations |= sqForward; destinations |= sqForward;
const chess::Square sqForward2 = sqForward + forward;
if ( if (
from.rank() == startRank from.rank() == startRank
&& !occupied.isSet(sqForward2) && !occupied.isSet(sqForward + forward)
) )
{ {
destinations |= sqForward2; destinations |= sqForward + forward;
} }
} }
@@ -6845,13 +6745,12 @@ namespace binpack
{ {
destinations |= sqForward; destinations |= sqForward;
const chess::Square sqForward2 = sqForward + forward;
if ( if (
move.from.rank() == startRank move.from.rank() == startRank
&& !occupied.isSet(sqForward2) && !occupied.isSet(sqForward + forward)
) )
{ {
destinations |= sqForward2; destinations |= sqForward + forward;
} }
} }
+4 -1
View File
@@ -133,9 +133,12 @@ namespace Learner
{ {
case SfenOutputType::Bin: case SfenOutputType::Bin:
return std::make_unique<BinSfenOutputStream>(filename); return std::make_unique<BinSfenOutputStream>(filename);
default: case SfenOutputType::Binpack:
return std::make_unique<BinpackSfenOutputStream>(filename); return std::make_unique<BinpackSfenOutputStream>(filename);
} }
assert(false);
return nullptr;
} }
// Helper class for exporting Sfen // Helper class for exporting Sfen
+259 -137
View File
@@ -30,6 +30,8 @@
#include "learn.h" #include "learn.h"
#include "multi_think.h" #include "multi_think.h"
#include "../extra/nnue_data_binpack_format.h"
#include <chrono> #include <chrono>
#include <climits> #include <climits>
#include <cmath> // std::exp(),std::pow(),std::log() #include <cmath> // std::exp(),std::pow(),std::log()
@@ -85,8 +87,8 @@ namespace Learner
static double dest_score_min_value = 0.0; static double dest_score_min_value = 0.0;
static double dest_score_max_value = 1.0; static double dest_score_max_value = 1.0;
// Assume teacher signals are the scores of deep searches, // Assume teacher signals are the scores of deep searches,
// and convert them into winning probabilities in the trainer. // and convert them into winning probabilities in the trainer.
// Sometimes we want to use the winning probabilities in the training // Sometimes we want to use the winning probabilities in the training
// data directly. In those cases, we set false to this variable. // data directly. In those cases, we set false to this variable.
static bool convert_teacher_signal_to_winning_probability = true; static bool convert_teacher_signal_to_winning_probability = true;
@@ -125,19 +127,19 @@ namespace Learner
// A function that converts the evaluation value to the winning rate [0,1] // A function that converts the evaluation value to the winning rate [0,1]
double winning_percentage(double value, int ply) double winning_percentage(double value, int ply)
{ {
if (use_wdl) if (use_wdl)
{ {
return winning_percentage_wdl(value, ply); return winning_percentage_wdl(value, ply);
} }
else else
{ {
return winning_percentage(value); return winning_percentage(value);
} }
} }
double calc_cross_entropy_of_winning_percentage( double calc_cross_entropy_of_winning_percentage(
double deep_win_rate, double deep_win_rate,
double shallow_eval, double shallow_eval,
int ply) int ply)
{ {
const double p = deep_win_rate; const double p = deep_win_rate;
@@ -146,8 +148,8 @@ namespace Learner
} }
double calc_d_cross_entropy_of_winning_percentage( double calc_d_cross_entropy_of_winning_percentage(
double deep_win_rate, double deep_win_rate,
double shallow_eval, double shallow_eval,
int ply) int ply)
{ {
constexpr double epsilon = 0.000001; constexpr double epsilon = 0.000001;
@@ -158,7 +160,7 @@ namespace Learner
const double y2 = calc_cross_entropy_of_winning_percentage( const double y2 = calc_cross_entropy_of_winning_percentage(
deep_win_rate, shallow_eval + epsilon, ply); deep_win_rate, shallow_eval + epsilon, ply);
// Divide by the winning_probability_coefficient to // Divide by the winning_probability_coefficient to
// match scale with the sigmoidal win rate // match scale with the sigmoidal win rate
return ((y2 - y1) / epsilon) / winning_probability_coefficient; return ((y2 - y1) / epsilon) / winning_probability_coefficient;
} }
@@ -195,7 +197,7 @@ namespace Learner
const double scaled_teacher_signal = get_scaled_signal(teacher_signal); const double scaled_teacher_signal = get_scaled_signal(teacher_signal);
double p = scaled_teacher_signal; double p = scaled_teacher_signal;
if (convert_teacher_signal_to_winning_probability) if (convert_teacher_signal_to_winning_probability)
{ {
p = winning_percentage(scaled_teacher_signal, ply); p = winning_percentage(scaled_teacher_signal, ply);
} }
@@ -217,7 +219,7 @@ namespace Learner
double calculate_t(int game_result) double calculate_t(int game_result)
{ {
// Use 1 as the correction term if the expected win rate is 1, // Use 1 as the correction term if the expected win rate is 1,
// 0 if you lose, and 0.5 if you draw. // 0 if you lose, and 0.5 if you draw.
// game_result = 1,0,-1 so add 1 and divide by 2. // game_result = 1,0,-1 so add 1 and divide by 2.
const double t = double(game_result + 1) * 0.5; const double t = double(game_result + 1) * 0.5;
@@ -235,13 +237,13 @@ namespace Learner
const double lambda = calculate_lambda(teacher_signal); const double lambda = calculate_lambda(teacher_signal);
double grad; double grad;
if (use_wdl) if (use_wdl)
{ {
const double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, psv.gamePly); const double dce_p = calc_d_cross_entropy_of_winning_percentage(p, shallow, psv.gamePly);
const double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, psv.gamePly); const double dce_t = calc_d_cross_entropy_of_winning_percentage(t, shallow, psv.gamePly);
grad = lambda * dce_p + (1.0 - lambda) * dce_t; grad = lambda * dce_p + (1.0 - lambda) * dce_t;
} }
else else
{ {
// Use the actual win rate as a correction term. // Use the actual win rate as a correction term.
// This is the idea of elmo (WCSC27), modern O-parts. // This is the idea of elmo (WCSC27), modern O-parts.
@@ -252,18 +254,18 @@ namespace Learner
} }
// Calculate cross entropy during learning // Calculate cross entropy during learning
// The individual cross entropy of the win/loss term and win // The individual cross entropy of the win/loss term and win
// rate term of the elmo expression is returned // rate term of the elmo expression is returned
// to the arguments cross_entropy_eval and cross_entropy_win. // to the arguments cross_entropy_eval and cross_entropy_win.
void calc_cross_entropy( void calc_cross_entropy(
Value teacher_signal, Value teacher_signal,
Value shallow, Value shallow,
const PackedSfenValue& psv, const PackedSfenValue& psv,
double& cross_entropy_eval, double& cross_entropy_eval,
double& cross_entropy_win, double& cross_entropy_win,
double& cross_entropy, double& cross_entropy,
double& entropy_eval, double& entropy_eval,
double& entropy_win, double& entropy_win,
double& entropy) double& entropy)
{ {
// Teacher winning probability. // Teacher winning probability.
@@ -292,24 +294,133 @@ namespace Learner
} }
// Other objective functions may be considered in the future... // Other objective functions may be considered in the future...
double calc_grad(Value shallow, const PackedSfenValue& psv) double calc_grad(Value shallow, const PackedSfenValue& psv)
{ {
return calc_grad((Value)psv.score, shallow, psv); return calc_grad((Value)psv.score, shallow, psv);
} }
// Common interface of the training-data readers used by SfenReader.
struct BasicSfenInputStream
{
    virtual ~BasicSfenInputStream() = default;

    // Returns the next entry, or std::nullopt when none could be read.
    virtual std::optional<PackedSfenValue> next() = 0;

    // Reports whether the stream has been found to be exhausted.
    virtual bool eof() const = 0;
};
struct BinSfenInputStream : BasicSfenInputStream
{
static constexpr auto openmode = ios::in | ios::binary;
static inline const std::string extension = "bin";
BinSfenInputStream(std::string filename) :
m_stream(filename, openmode),
m_eof(!m_stream)
{
}
std::optional<PackedSfenValue> next() override
{
PackedSfenValue e;
if(m_stream.read(reinterpret_cast<char*>(&e), sizeof(PackedSfenValue)))
{
return e;
}
else
{
m_eof = true;
return std::nullopt;
}
}
bool eof() const override
{
return m_eof;
}
~BinSfenInputStream() override {}
private:
fstream m_stream;
bool m_eof;
};
struct BinpackSfenInputStream : BasicSfenInputStream
{
static constexpr auto openmode = ios::in | ios::binary;
static inline const std::string extension = "binpack";
BinpackSfenInputStream(std::string filename) :
m_stream(filename, openmode),
m_eof(!m_stream.hasNext())
{
}
std::optional<PackedSfenValue> next() override
{
static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue));
if (!m_stream.hasNext())
{
m_eof = true;
return std::nullopt;
}
auto training_data_entry = m_stream.next();
auto v = binpack::trainingDataEntryToPackedSfenValue(training_data_entry);
PackedSfenValue psv;
// same layout, different types. One is from generic library.
std::memcpy(&psv, &v, sizeof(PackedSfenValue));
return psv;
}
bool eof() const override
{
return m_eof;
}
~BinpackSfenInputStream() override {}
private:
binpack::CompressedTrainingDataEntryReader m_stream;
bool m_eof;
};
// Returns true when `lhs` finishes with the characters of `end`.
// An empty `end` matches every string.
static bool ends_with(const std::string& lhs, const std::string& end)
{
    const std::size_t tail_len = end.size();

    // Compare the last tail_len characters of lhs against end; the size
    // check keeps the start offset from underflowing.
    return lhs.size() >= tail_len
        && lhs.compare(lhs.size() - tail_len, tail_len, end) == 0;
}
// Returns true when `filename` ends with a dot followed by `extension`.
// `extension` is given without the leading dot; the match is case sensitive.
static bool has_extension(const std::string& filename, const std::string& extension)
{
    const std::string dotted_suffix = "." + extension;
    return ends_with(filename, dotted_suffix);
}
// Opens `filename` with the reader that matches its extension:
// ".bin" -> BinSfenInputStream, ".binpack" -> BinpackSfenInputStream.
//
// Never returns nullptr. A file with any other extension is opened with the
// raw ".bin" reader after printing a warning. The callers dereference the
// result without a null check, and an assert would be compiled out under
// NDEBUG, so returning nullptr here would crash release builds.
static std::unique_ptr<BasicSfenInputStream> open_sfen_input_file(const std::string& filename)
{
    if (has_extension(filename, BinSfenInputStream::extension))
        return std::make_unique<BinSfenInputStream>(filename);

    if (has_extension(filename, BinpackSfenInputStream::extension))
        return std::make_unique<BinpackSfenInputStream>(filename);

    std::cout << "WARNING: unrecognized extension for " << filename
              << ", reading it as ." << BinSfenInputStream::extension << std::endl;

    return std::make_unique<BinSfenInputStream>(filename);
}
// Sfen reader // Sfen reader
struct SfenReader struct SfenReader
{ {
// Number of phases used for calculation such as mse // Number of phases used for calculation such as mse
// mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time. // mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time.
// Since search() is performed with depth = 1 in calculation of // Since search() is performed with depth = 1 in calculation of
// move match rate, simple comparison is not possible... // move match rate, simple comparison is not possible...
static constexpr uint64_t sfen_for_mse_size = 2000; static constexpr uint64_t sfen_for_mse_size = 2000;
// Number of phases buffered by each thread 0.1M phases. 4M phase at 40HT // Number of phases buffered by each thread 0.1M phases. 4M phase at 40HT
static constexpr size_t THREAD_BUFFER_SIZE = 10 * 1000; static constexpr size_t THREAD_BUFFER_SIZE = 10 * 1000;
// Buffer for reading files (If this is made larger, // Buffer for reading files (If this is made larger,
// the shuffle becomes larger and the phases may vary. // the shuffle becomes larger and the phases may vary.
// If it is too large, the memory consumption will increase. // If it is too large, the memory consumption will increase.
// SFEN_READ_SIZE is a multiple of THREAD_BUFFER_SIZE. // SFEN_READ_SIZE is a multiple of THREAD_BUFFER_SIZE.
@@ -322,7 +433,7 @@ namespace Learner
// Do not use std::random_device(). // Do not use std::random_device().
// Because it always returns the same integers on MinGW. // Because it always returns the same integers on MinGW.
SfenReader(int thread_num) : SfenReader(int thread_num) :
prng(std::chrono::system_clock::now().time_since_epoch().count()) prng(std::chrono::system_clock::now().time_since_epoch().count())
{ {
packed_sfens.resize(thread_num); packed_sfens.resize(thread_num);
@@ -369,13 +480,15 @@ namespace Learner
void read_validation_set(const string& file_name, int eval_limit) void read_validation_set(const string& file_name, int eval_limit)
{ {
ifstream input(file_name, ios::binary); auto input = open_sfen_input_file(file_name);
while (input) while(!input->eof())
{ {
PackedSfenValue p; std::optional<PackedSfenValue> p_opt = input->next();
if (input.read(reinterpret_cast<char*>(&p), sizeof(PackedSfenValue))) if (p_opt.has_value())
{ {
auto& p = *p_opt;
if (eval_limit < abs(p.score)) if (eval_limit < abs(p.score))
continue; continue;
@@ -398,7 +511,7 @@ namespace Learner
// then retrieve one and return it. // then retrieve one and return it.
auto& thread_ps = packed_sfens[thread_id]; auto& thread_ps = packed_sfens[thread_id];
// Fill the read buffer if there is no remaining buffer, // Fill the read buffer if there is no remaining buffer,
// but if it doesn't even exist, finish. // but if it doesn't even exist, finish.
// If the buffer is empty, fill it. // If the buffer is empty, fill it.
if ((thread_ps == nullptr || thread_ps->empty()) if ((thread_ps == nullptr || thread_ps->empty())
@@ -406,7 +519,7 @@ namespace Learner
return false; return false;
// read_to_thread_buffer_impl() returned true, // read_to_thread_buffer_impl() returned true,
// Since the filling of the thread buffer with the // Since the filling of the thread buffer with the
// phase has been completed successfully // phase has been completed successfully
// thread_ps->rbegin() is alive. // thread_ps->rbegin() is alive.
@@ -458,33 +571,42 @@ namespace Learner
// Start a thread that loads the phase file in the background. // Start a thread that loads the phase file in the background.
void start_file_read_worker() void start_file_read_worker()
{ {
// Construct a std::thread whose entry point calls file_read_worker() on this
// object, and keep the handle in file_worker_thread.
// The lambda uses a by-reference default capture, so `this` is reached through
// a pointer: the object has to stay alive for as long as the worker runs.
// NOTE(review): the matching join() is not visible in this excerpt -- presumably
// done elsewhere in the class; confirm before changing the object's lifetime.
file_worker_thread = std::thread([&] { file_worker_thread = std::thread([&] {
this->file_read_worker(); this->file_read_worker();
}); });
} }
void file_read_worker() void file_read_worker()
{ {
auto open_next_file = [&]() { auto open_next_file = [&]() {
if (fs.is_open())
fs.close();
// no more // no more
if (filenames.empty()) for(;;)
return false; {
sfen_input_stream.reset();
// Get the next file name. if (filenames.empty())
string filename = filenames.back(); return false;
filenames.pop_back();
fs.open(filename, ios::in | ios::binary); // Get the next file name.
cout << "open filename = " << filename << endl; string filename = filenames.back();
filenames.pop_back();
assert(fs); sfen_input_stream = open_sfen_input_file(filename);
cout << "open filename = " << filename << endl;
return true; // in case the file is empty or was deleted.
if (!sfen_input_stream->eof())
return true;
}
}; };
if (sfen_input_stream == nullptr && !open_next_file())
{
cout << "..end of files." << endl;
end_of_files = true;
return;
}
while (true) while (true)
{ {
// Wait for the buffer to run out. // Wait for the buffer to run out.
@@ -501,10 +623,10 @@ namespace Learner
// Read from the file into the file buffer. // Read from the file into the file buffer.
while (sfens.size() < SFEN_READ_SIZE) while (sfens.size() < SFEN_READ_SIZE)
{ {
PackedSfenValue p; std::optional<PackedSfenValue> p = sfen_input_stream->next();
if (fs.read(reinterpret_cast<char*>(&p), sizeof(PackedSfenValue))) if (p.has_value())
{ {
sfens.push_back(p); sfens.push_back(*p);
} }
else if(!open_next_file()) else if(!open_next_file())
{ {
@@ -535,8 +657,8 @@ namespace Learner
auto buf = std::make_unique<PSVector>(); auto buf = std::make_unique<PSVector>();
buf->resize(THREAD_BUFFER_SIZE); buf->resize(THREAD_BUFFER_SIZE);
memcpy( memcpy(
buf->data(), buf->data(),
&sfens[i * THREAD_BUFFER_SIZE], &sfens[i * THREAD_BUFFER_SIZE],
sizeof(PackedSfenValue) * THREAD_BUFFER_SIZE); sizeof(PackedSfenValue) * THREAD_BUFFER_SIZE);
buffers.emplace_back(std::move(buf)); buffers.emplace_back(std::move(buf));
@@ -545,7 +667,7 @@ namespace Learner
{ {
std::unique_lock<std::mutex> lk(mutex); std::unique_lock<std::mutex> lk(mutex);
// The mutex lock is required because the // The mutex lock is required because the
// contents of packed_sfens_pool are changed. // contents of packed_sfens_pool are changed.
for (auto& buf : buffers) for (auto& buf : buffers)
@@ -600,7 +722,7 @@ namespace Learner
atomic<bool> end_of_files; atomic<bool> end_of_files;
// handle of sfen file // handle of sfen file
std::fstream fs; std::unique_ptr<BasicSfenInputStream> sfen_input_stream;
// sfen for each thread // sfen for each thread
// (When the thread is used up, the thread should call delete to release it.) // (When the thread is used up, the thread should call delete to release it.)
@@ -621,9 +743,9 @@ namespace Learner
// Class to generate sfen with multiple threads // Class to generate sfen with multiple threads
struct LearnerThink : public MultiThink struct LearnerThink : public MultiThink
{ {
LearnerThink(SfenReader& sr_) : LearnerThink(SfenReader& sr_) :
sr(sr_), sr(sr_),
stop_flag(false), stop_flag(false),
save_only_once(false) save_only_once(false)
{ {
learn_sum_cross_entropy_eval = 0.0; learn_sum_cross_entropy_eval = 0.0;
@@ -644,9 +766,9 @@ namespace Learner
virtual void thread_worker(size_t thread_id); virtual void thread_worker(size_t thread_id);
// Start a thread that loads the phase file in the background. // Start a thread that loads the phase file in the background.
void start_file_read_worker() void start_file_read_worker()
{ {
// Thin forwarder: delegates to the SfenReader that this object holds by
// reference in `sr` (bound from the constructor argument), which is the
// object that actually creates the background reader thread.
sr.start_file_read_worker(); sr.start_file_read_worker();
} }
Value get_shallow_value(Position& task_pos); Value get_shallow_value(Position& task_pos);
@@ -674,7 +796,7 @@ namespace Learner
// Option not to learn kk/kkp/kpp/kppp // Option not to learn kk/kkp/kpp/kppp
std::array<bool, 4> freeze; std::array<bool, 4> freeze;
// If the absolute value of the evaluation value of the deep search // If the absolute value of the evaluation value of the deep search
// of the teacher phase exceeds this value, discard the teacher phase. // of the teacher phase exceeds this value, discard the teacher phase.
int eval_limit; int eval_limit;
@@ -742,7 +864,7 @@ namespace Learner
void LearnerThink::calc_loss(size_t thread_id, uint64_t done) void LearnerThink::calc_loss(size_t thread_id, uint64_t done)
{ {
// There is no point in hitting the replacement table, // There is no point in hitting the replacement table,
// so at this timing the generation of the replacement table is updated. // so at this timing the generation of the replacement table is updated.
// It doesn't matter if you have disabled the substitution table. // It doesn't matter if you have disabled the substitution table.
TT.new_search(); TT.new_search();
@@ -766,7 +888,7 @@ namespace Learner
atomic<double> sum_norm; atomic<double> sum_norm;
sum_norm = 0; sum_norm = 0;
// The number of times the pv first move of deep // The number of times the pv first move of deep
// search matches the pv first move of search(1). // search matches the pv first move of search(1).
atomic<int> move_accord_count; atomic<int> move_accord_count;
move_accord_count = 0; move_accord_count = 0;
@@ -778,7 +900,7 @@ namespace Learner
pos.set(StartFEN, false, &si, th); pos.set(StartFEN, false, &si, th);
std::cout << "hirate eval = " << Eval::evaluate(pos); std::cout << "hirate eval = " << Eval::evaluate(pos);
// It's better to parallelize here, but it's a bit // It's better to parallelize here, but it's a bit
// troublesome because the search before slave has not finished. // troublesome because the search before slave has not finished.
// I created a mechanism to call task, so I will use it. // I created a mechanism to call task, so I will use it.
@@ -792,7 +914,7 @@ namespace Learner
{ {
// Assign work to each thread using TaskDispatcher. // Assign work to each thread using TaskDispatcher.
// A task definition for that. // A task definition for that.
// It is not possible to capture pos used in ↑, // It is not possible to capture pos used in ↑,
// so specify the variables you want to capture one by one. // so specify the variables you want to capture one by one.
auto task = auto task =
[ [
@@ -823,7 +945,7 @@ namespace Learner
// Evaluation value of deep search // Evaluation value of deep search
auto deep_value = (Value)ps.score; auto deep_value = (Value)ps.score;
// Note) This code does not consider when // Note) This code does not consider when
// eval_limit is specified in the learn command. // eval_limit is specified in the learn command.
// --- calculation of cross entropy // --- calculation of cross entropy
@@ -834,14 +956,14 @@ namespace Learner
double test_cross_entropy_eval, test_cross_entropy_win, test_cross_entropy; double test_cross_entropy_eval, test_cross_entropy_win, test_cross_entropy;
double test_entropy_eval, test_entropy_win, test_entropy; double test_entropy_eval, test_entropy_win, test_entropy;
calc_cross_entropy( calc_cross_entropy(
deep_value, deep_value,
shallow_value, shallow_value,
ps, ps,
test_cross_entropy_eval, test_cross_entropy_eval,
test_cross_entropy_win, test_cross_entropy_win,
test_cross_entropy, test_cross_entropy,
test_entropy_eval, test_entropy_eval,
test_entropy_win, test_entropy_win,
test_entropy); test_entropy);
// The total cross entropy need not be abs() by definition. // The total cross entropy need not be abs() by definition.
@@ -878,9 +1000,9 @@ namespace Learner
latest_loss_sum += test_sum_cross_entropy - test_sum_entropy; latest_loss_sum += test_sum_cross_entropy - test_sum_entropy;
latest_loss_count += sr.sfen_for_mse.size(); latest_loss_count += sr.sfen_for_mse.size();
// learn_cross_entropy may be called train cross // learn_cross_entropy may be called train cross
// entropy in the world of machine learning, // entropy in the world of machine learning,
// When omitting the acronym, it is nice to be able to // When omitting the acronym, it is nice to be able to
// distinguish it from test cross entropy(tce) by writing it as lce. // distinguish it from test cross entropy(tce) by writing it as lce.
if (sr.sfen_for_mse.size() && done) if (sr.sfen_for_mse.size() && done)
@@ -907,7 +1029,7 @@ namespace Learner
} }
cout << endl; cout << endl;
} }
else else
{ {
cout << "Error! : sr.sfen_for_mse.size() = " << sr.sfen_for_mse.size() << " , done = " << done << endl; cout << "Error! : sr.sfen_for_mse.size() = " << sr.sfen_for_mse.size() << " , done = " << done << endl;
} }
@@ -977,7 +1099,7 @@ namespace Learner
{ {
sr.save_count = 0; sr.save_count = 0;
// During this time, as the gradient calculation proceeds, // During this time, as the gradient calculation proceeds,
// the value becomes too large and I feel annoyed, so stop other threads. // the value becomes too large and I feel annoyed, so stop other threads.
const bool converged = save(); const bool converged = save();
if (converged) if (converged)
@@ -1007,11 +1129,11 @@ namespace Learner
sr.last_done = sr.total_done; sr.last_done = sr.total_done;
} }
// Next time, I want you to do this series of // Next time, I want you to do this series of
// processing again when you process only mini_batch_size. // processing again when you process only mini_batch_size.
sr.next_update_weights += mini_batch_size; sr.next_update_weights += mini_batch_size;
// Since I was waiting for the update of this // Since I was waiting for the update of this
// sr.next_update_weights except the main thread, // sr.next_update_weights except the main thread,
// Once this value is updated, it will start moving again. // Once this value is updated, it will start moving again.
} }
@@ -1048,16 +1170,16 @@ namespace Learner
if (pos.set_from_packed_sfen(ps.sfen, &si, th, mirror) != 0) if (pos.set_from_packed_sfen(ps.sfen, &si, th, mirror) != 0)
{ {
// I got a strange sfen. Should be debugged! // I got a strange sfen. Should be debugged!
// Since it is an illegal sfen, it may not be // Since it is an illegal sfen, it may not be
// displayed with pos.sfen(), but it is better than not. // displayed with pos.sfen(), but it is better than not.
cout << "Error! : illigal packed sfen = " << pos.fen() << endl; cout << "Error! : illigal packed sfen = " << pos.fen() << endl;
goto RETRY_READ; goto RETRY_READ;
} }
// There is a possibility that all the pieces are blocked and stuck. // There is a possibility that all the pieces are blocked and stuck.
// Also, the declaration win phase is excluded from // Also, the declaration win phase is excluded from
// learning because you cannot go to leaf with PV moves. // learning because you cannot go to leaf with PV moves.
// (shouldn't write out such teacher aspect itself, // (shouldn't write out such teacher aspect itself,
// but may have written it out with an old generation routine) // but may have written it out with an old generation routine)
// Skip the position if there are no legal moves (=checkmated or stalemate). // Skip the position if there are no legal moves (=checkmated or stalemate).
if (MoveList<LEGAL>(pos).size() == 0) if (MoveList<LEGAL>(pos).size() == 0)
@@ -1073,7 +1195,7 @@ namespace Learner
const auto deep_value = (Value)ps.score; const auto deep_value = (Value)ps.score;
// I feel that the mini batch has a better gradient. // I feel that the mini batch has a better gradient.
// Go to the leaf node as it is, add only to the gradient array, // Go to the leaf node as it is, add only to the gradient array,
// and later try AdaGrad at the time of rmse aggregation. // and later try AdaGrad at the time of rmse aggregation.
const auto rootColor = pos.side_to_move(); const auto rootColor = pos.side_to_move();
@@ -1088,30 +1210,30 @@ namespace Learner
auto pos_add_grad = [&]() { auto pos_add_grad = [&]() {
// Use the value of evaluate in leaf as shallow_value. // Use the value of evaluate in leaf as shallow_value.
// Using the return value of qsearch() as shallow_value, // Using the return value of qsearch() as shallow_value,
// If PV is interrupted in the middle, the phase where // If PV is interrupted in the middle, the phase where
// evaluate() is called to calculate the gradient, // evaluate() is called to calculate the gradient,
// and I don't think this is a very desirable property, // and I don't think this is a very desirable property,
// as the aspect that gives that gradient will be different. // as the aspect that gives that gradient will be different.
// I have turned off the substitution table, but since // I have turned off the substitution table, but since
// the pv array has not been updated due to one stumbling block etc... // the pv array has not been updated due to one stumbling block etc...
const Value shallow_value = const Value shallow_value =
(rootColor == pos.side_to_move()) (rootColor == pos.side_to_move())
? Eval::evaluate(pos) ? Eval::evaluate(pos)
: -Eval::evaluate(pos); : -Eval::evaluate(pos);
// Calculate loss for training data // Calculate loss for training data
double learn_cross_entropy_eval, learn_cross_entropy_win, learn_cross_entropy; double learn_cross_entropy_eval, learn_cross_entropy_win, learn_cross_entropy;
double learn_entropy_eval, learn_entropy_win, learn_entropy; double learn_entropy_eval, learn_entropy_win, learn_entropy;
calc_cross_entropy( calc_cross_entropy(
deep_value, deep_value,
shallow_value, shallow_value,
ps, ps,
learn_cross_entropy_eval, learn_cross_entropy_eval,
learn_cross_entropy_win, learn_cross_entropy_win,
learn_cross_entropy, learn_cross_entropy,
learn_entropy_eval, learn_entropy_eval,
learn_entropy_win, learn_entropy_win,
learn_entropy); learn_entropy);
learn_sum_cross_entropy_eval += learn_cross_entropy_eval; learn_sum_cross_entropy_eval += learn_cross_entropy_eval;
@@ -1154,7 +1276,7 @@ namespace Learner
Eval::NNUE::update_eval(pos); Eval::NNUE::update_eval(pos);
} }
if (illegal_move) if (illegal_move)
{ {
sync_cout << "An illegal move was detected... Excluded the position from the learning data..." << sync_endl; sync_cout << "An illegal move was detected... Excluded the position from the learning data..." << sync_endl;
continue; continue;
@@ -1182,12 +1304,12 @@ namespace Learner
// Do not dig a subfolder because I want to save it only once. // Do not dig a subfolder because I want to save it only once.
Eval::save_eval(""); Eval::save_eval("");
} }
else if (is_final) else if (is_final)
{ {
Eval::save_eval("final"); Eval::save_eval("final");
return true; return true;
} }
else else
{ {
static int dir_number = 0; static int dir_number = 0;
const std::string dir_name = std::to_string(dir_number++); const std::string dir_name = std::to_string(dir_number++);
@@ -1199,27 +1321,27 @@ namespace Learner
latest_loss_sum = 0.0; latest_loss_sum = 0.0;
latest_loss_count = 0; latest_loss_count = 0;
cout << "loss: " << latest_loss; cout << "loss: " << latest_loss;
if (latest_loss < best_loss) if (latest_loss < best_loss)
{ {
cout << " < best (" << best_loss << "), accepted" << endl; cout << " < best (" << best_loss << "), accepted" << endl;
best_loss = latest_loss; best_loss = latest_loss;
best_nn_directory = Path::Combine((std::string)Options["EvalSaveDir"], dir_name); best_nn_directory = Path::Combine((std::string)Options["EvalSaveDir"], dir_name);
trials = newbob_num_trials; trials = newbob_num_trials;
} }
else else
{ {
cout << " >= best (" << best_loss << "), rejected" << endl; cout << " >= best (" << best_loss << "), rejected" << endl;
if (best_nn_directory.empty()) if (best_nn_directory.empty())
{ {
cout << "WARNING: no improvement from initial model" << endl; cout << "WARNING: no improvement from initial model" << endl;
} }
else else
{ {
cout << "restoring parameters from " << best_nn_directory << endl; cout << "restoring parameters from " << best_nn_directory << endl;
Eval::NNUE::RestoreParameters(best_nn_directory); Eval::NNUE::RestoreParameters(best_nn_directory);
} }
if (--trials > 0 && !is_final) if (--trials > 0 && !is_final)
{ {
cout cout
<< "reducing learning rate scale from " << newbob_scale << "reducing learning rate scale from " << newbob_scale
@@ -1230,8 +1352,8 @@ namespace Learner
Eval::NNUE::SetGlobalLearningRateScale(newbob_scale); Eval::NNUE::SetGlobalLearningRateScale(newbob_scale);
} }
} }
if (trials == 0) if (trials == 0)
{ {
cout << "converged" << endl; cout << "converged" << endl;
return true; return true;
@@ -1247,9 +1369,9 @@ namespace Learner
// sfen_file_streams: fstream of each teacher phase file // sfen_file_streams: fstream of each teacher phase file
// sfen_count_in_file: The number of teacher positions present in each file. // sfen_count_in_file: The number of teacher positions present in each file.
void shuffle_write( void shuffle_write(
const string& output_file_name, const string& output_file_name,
PRNG& prng, PRNG& prng,
vector<fstream>& sfen_file_streams, vector<fstream>& sfen_file_streams,
vector<uint64_t>& sfen_count_in_file) vector<uint64_t>& sfen_count_in_file)
{ {
uint64_t total_sfen_count = 0; uint64_t total_sfen_count = 0;
@@ -1323,7 +1445,7 @@ namespace Learner
// Temporary file is written to tmp/ folder for each buffer_size phase. // Temporary file is written to tmp/ folder for each buffer_size phase.
// For example, if buffer_size = 20M, you need a buffer of 20M*40bytes = 800MB. // For example, if buffer_size = 20M, you need a buffer of 20M*40bytes = 800MB.
// In a PC with a small memory, it would be better to reduce this. // In a PC with a small memory, it would be better to reduce this.
// However, if the number of files increases too much, // However, if the number of files increases too much,
// it will not be possible to open at the same time due to OS restrictions. // it will not be possible to open at the same time due to OS restrictions.
// There should have been a limit of 512 per process on Windows, so you can open here as 500, // There should have been a limit of 512 per process on Windows, so you can open here as 500,
// The current setting is 500 files x 20M = 10G = 10 billion phases. // The current setting is 500 files x 20M = 10G = 10 billion phases.
@@ -1377,7 +1499,7 @@ namespace Learner
// Read in units of sizeof(PackedSfenValue), // Read in units of sizeof(PackedSfenValue),
// Ignore the last remaining fraction. (Fails in fs.read, so exit while) // Ignore the last remaining fraction. (Fails in fs.read, so exit while)
// (The remaining fraction seems to be half-finished data // (The remaining fraction seems to be half-finished data
// that was created because it was stopped halfway during teacher generation.) // that was created because it was stopped halfway during teacher generation.)
} }
@@ -1385,14 +1507,14 @@ namespace Learner
write_buffer(buf_write_marker); write_buffer(buf_write_marker);
// Only shuffled files have been written write_file_count. // Only shuffled files have been written write_file_count.
// As a second pass, if you open all of them at the same time, // As a second pass, if you open all of them at the same time,
// select one at random and load one phase at a time // select one at random and load one phase at a time
// Now you have shuffled. // Now you have shuffled.
// Original file for shuffle + tmp file + file to write // Original file for shuffle + tmp file + file to write
// requires 3 times the storage capacity of the original file. // requires 3 times the storage capacity of the original file.
// 1 billion SSD is not enough for shuffling because it is 400GB for 10 billion phases. // 1 billion SSD is not enough for shuffling because it is 400GB for 10 billion phases.
// If you want to delete (or delete by hand) the // If you want to delete (or delete by hand) the
// original file at this point after writing to tmp, // original file at this point after writing to tmp,
// The storage capacity is about twice that of the original file. // The storage capacity is about twice that of the original file.
// So, maybe we should have an option to delete the original file. // So, maybe we should have an option to delete the original file.
@@ -1477,11 +1599,11 @@ namespace Learner
std::cout << "write : " << output_file_name << endl; std::cout << "write : " << output_file_name << endl;
// If the file to be written exceeds 2GB, it cannot be // If the file to be written exceeds 2GB, it cannot be
// written in one shot with fstream::write, so use wrapper. // written in one shot with fstream::write, so use wrapper.
write_memory_to_file( write_memory_to_file(
output_file_name, output_file_name,
(void*)&buf[0], (void*)&buf[0],
sizeof(PackedSfenValue) * buf.size()); sizeof(PackedSfenValue) * buf.size());
std::cout << "..shuffle_on_memory done." << std::endl; std::cout << "..shuffle_on_memory done." << std::endl;
@@ -1521,10 +1643,10 @@ namespace Learner
uint64_t buffer_size = 20000000; uint64_t buffer_size = 20000000;
// fast shuffling assuming each file is shuffled // fast shuffling assuming each file is shuffled
bool shuffle_quick = false; bool shuffle_quick = false;
// A function to read the entire file in memory and shuffle it. // A function to read the entire file in memory and shuffle it.
// (Requires file size memory) // (Requires file size memory)
bool shuffle_on_memory = false; bool shuffle_on_memory = false;
// Conversion of packed sfen. In plain, it consists of sfen(string), // Conversion of packed sfen. In plain, it consists of sfen(string),
// evaluation value (integer), move (eg 7g7f, string), result (loss-1, win 1, draw 0) // evaluation value (integer), move (eg 7g7f, string), result (loss-1, win 1, draw 0)
bool use_convert_plain = false; bool use_convert_plain = false;
// convert plain format teacher to Yaneura King's bin // convert plain format teacher to Yaneura King's bin
@@ -1541,15 +1663,15 @@ namespace Learner
// File name to write in those cases (default is "shuffled_sfen.bin") // File name to write in those cases (default is "shuffled_sfen.bin")
string output_file_name = "shuffled_sfen.bin"; string output_file_name = "shuffled_sfen.bin";
// If the absolute value of the evaluation value // If the absolute value of the evaluation value
// in the deep search of the teacher phase exceeds this value, // in the deep search of the teacher phase exceeds this value,
// that phase is discarded. // that phase is discarded.
int eval_limit = 32000; int eval_limit = 32000;
// Flag to save the evaluation function file only once near the end. // Flag to save the evaluation function file only once near the end.
bool save_only_once = false; bool save_only_once = false;
// Shuffle about what you are pre-reading on the teacher aspect. // Shuffle about what you are pre-reading on the teacher aspect.
// (Shuffle of about 10 million phases) // (Shuffle of about 10 million phases)
// Turn on if you want to pass a pre-shuffled file. // Turn on if you want to pass a pre-shuffled file.
bool no_shuffle = false; bool no_shuffle = false;
@@ -1559,8 +1681,8 @@ namespace Learner
ELMO_LAMBDA2 = 0.33; ELMO_LAMBDA2 = 0.33;
ELMO_LAMBDA_LIMIT = 32000; ELMO_LAMBDA_LIMIT = 32000;
// Discount rate. If this is set to a value other than 0, // Discount rate. If this is set to a value other than 0,
// the slope will be added even at other than the PV termination. // the slope will be added even at other than the PV termination.
// (At that time, apply this discount rate) // (At that time, apply this discount rate)
double discount_rate = 0; double discount_rate = 0;
@@ -1620,18 +1742,18 @@ namespace Learner
else if (option == "eta2_epoch") is >> eta2_epoch; else if (option == "eta2_epoch") is >> eta2_epoch;
// Accept also the old option name. // Accept also the old option name.
else if (option == "use_draw_in_training" else if (option == "use_draw_in_training"
|| option == "use_draw_games_in_training") || option == "use_draw_games_in_training")
is >> use_draw_games_in_training; is >> use_draw_games_in_training;
// Accept also the old option name. // Accept also the old option name.
else if (option == "use_draw_in_validation" else if (option == "use_draw_in_validation"
|| option == "use_draw_games_in_validation") || option == "use_draw_games_in_validation")
is >> use_draw_games_in_validation; is >> use_draw_games_in_validation;
// Accept also the old option name. // Accept also the old option name.
else if (option == "use_hash_in_training" else if (option == "use_hash_in_training"
|| option == "skip_duplicated_positions_in_training") || option == "skip_duplicated_positions_in_training")
is >> skip_duplicated_positions_in_training; is >> skip_duplicated_positions_in_training;
else if (option == "winning_probability_coefficient") is >> winning_probability_coefficient; else if (option == "winning_probability_coefficient") is >> winning_probability_coefficient;
@@ -1792,9 +1914,9 @@ namespace Learner
Eval::init_NNUE(); Eval::init_NNUE();
cout << "convert_bin_from_pgn-extract.." << endl; cout << "convert_bin_from_pgn-extract.." << endl;
convert_bin_from_pgn_extract( convert_bin_from_pgn_extract(
filenames, filenames,
output_file_name, output_file_name,
pgn_eval_side_to_move, pgn_eval_side_to_move,
convert_no_eval_fens_as_score_zero); convert_no_eval_fens_as_score_zero);
return; return;
@@ -1808,7 +1930,7 @@ namespace Learner
// Insert the file name for the number of loops. // Insert the file name for the number of loops.
for (int i = 0; i < loop; ++i) for (int i = 0; i < loop; ++i)
{ {
// sfen reader, I'll read it in reverse // sfen reader, I'll read it in reverse
// order so I'll reverse it here. I'm sorry. // order so I'll reverse it here. I'm sorry.
for (auto it = filenames.rbegin(); it != filenames.rend(); ++it) for (auto it = filenames.rbegin(); it != filenames.rend(); ++it)
{ {
@@ -1891,12 +2013,12 @@ namespace Learner
learn_think.mini_batch_size = mini_batch_size; learn_think.mini_batch_size = mini_batch_size;
if (validation_set_file_name.empty()) if (validation_set_file_name.empty())
{ {
// Get about 10,000 data for mse calculation. // Get about 10,000 data for mse calculation.
sr.read_for_mse(); sr.read_for_mse();
} }
else else
{ {
sr.read_validation_set(validation_set_file_name, eval_limit); sr.read_validation_set(validation_set_file_name, eval_limit);
} }