Support for binpack format in sfenreader in learner. Automatically detect file extension and choose the correct reader (bin or binpack)

This commit is contained in:
Tomasz Sobczyk
2020-09-10 00:31:38 +02:00
committed by nodchip
parent 020e66d2e6
commit 6b76ebc2ca
3 changed files with 328 additions and 304 deletions
+65 -166
View File
@@ -2745,9 +2745,6 @@ namespace chess
0x000200282410A102ull, 0x000200282410A102ull,
0x4048240043802106ull 0x4048240043802106ull
} }; } };
alignas(64) extern EnumArray<Square, Bitboard> g_rookMasks;
alignas(64) extern EnumArray<Square, std::uint8_t> g_rookShifts;
alignas(64) extern EnumArray<Square, const Bitboard*> g_rookAttacks;
alignas(64) constexpr EnumArray<Square, std::uint64_t> g_bishopMagics{ { alignas(64) constexpr EnumArray<Square, std::uint64_t> g_bishopMagics{ {
0x40106000A1160020ull, 0x40106000A1160020ull,
@@ -2815,9 +2812,17 @@ namespace chess
0x0300404822C08200ull, 0x0300404822C08200ull,
0x48081010008A2A80ull 0x48081010008A2A80ull
} }; } };
alignas(64) extern EnumArray<Square, Bitboard> g_bishopMasks;
alignas(64) extern EnumArray<Square, std::uint8_t> g_bishopShifts; alignas(64) static EnumArray<Square, Bitboard> g_rookMasks;
alignas(64) extern EnumArray<Square, const Bitboard*> g_bishopAttacks; alignas(64) static EnumArray<Square, std::uint8_t> g_rookShifts;
alignas(64) static EnumArray<Square, const Bitboard*> g_rookAttacks;
alignas(64) static EnumArray<Square, Bitboard> g_bishopMasks;
alignas(64) static EnumArray<Square, std::uint8_t> g_bishopShifts;
alignas(64) static EnumArray<Square, const Bitboard*> g_bishopAttacks;
alignas(64) static std::array<Bitboard, 102400> g_allRookAttacks;
alignas(64) static std::array<Bitboard, 5248> g_allBishopAttacks;
inline Bitboard bishopAttacks(Square s, Bitboard occupied) inline Bitboard bishopAttacks(Square s, Bitboard occupied)
{ {
@@ -3402,17 +3407,6 @@ namespace chess
Bishop Bishop
}; };
alignas(64) EnumArray<Square, Bitboard> g_rookMasks;
alignas(64) EnumArray<Square, std::uint8_t> g_rookShifts;
alignas(64) EnumArray<Square, const Bitboard*> g_rookAttacks;
alignas(64) EnumArray<Square, Bitboard> g_bishopMasks;
alignas(64) EnumArray<Square, std::uint8_t> g_bishopShifts;
alignas(64) EnumArray<Square, const Bitboard*> g_bishopAttacks;
alignas(64) static std::array<Bitboard, 102400> g_allRookAttacks;
alignas(64) static std::array<Bitboard, 5248> g_allBishopAttacks;
template <MagicsType TypeV> template <MagicsType TypeV>
[[nodiscard]] inline Bitboard slidingAttacks(Square sq, Bitboard occupied) [[nodiscard]] inline Bitboard slidingAttacks(Square sq, Bitboard occupied)
{ {
@@ -3857,7 +3851,7 @@ namespace chess
return true; return true;
} }
[[nodiscard]] std::string fen() const; [[nodiscard]] inline std::string fen() const;
[[nodiscard]] inline bool trySet(std::string_view boardState) [[nodiscard]] inline bool trySet(std::string_view boardState)
{ {
@@ -4093,7 +4087,7 @@ namespace chess
// returns captured piece // returns captured piece
// doesn't check validity // doesn't check validity
FORCEINLINE constexpr Piece doMove(Move move) inline constexpr Piece doMove(Move move)
{ {
if (move.type == MoveType::Normal) if (move.type == MoveType::Normal)
{ {
@@ -4132,7 +4126,7 @@ namespace chess
return doMoveColdPath(move); return doMoveColdPath(move);
} }
NOINLINE constexpr Piece doMoveColdPath(Move move) inline constexpr Piece doMoveColdPath(Move move)
{ {
if (move.type == MoveType::Promotion) if (move.type == MoveType::Promotion)
{ {
@@ -4333,38 +4327,38 @@ namespace chess
// Returns whether a given square is attacked by any piece // Returns whether a given square is attacked by any piece
// of `attackerColor` side. // of `attackerColor` side.
[[nodiscard]] bool isSquareAttacked(Square sq, Color attackerColor) const; [[nodiscard]] inline bool isSquareAttacked(Square sq, Color attackerColor) const;
// Returns whether a given square is attacked by any piece // Returns whether a given square is attacked by any piece
// of `attackerColor` side after `move` is made. // of `attackerColor` side after `move` is made.
// Move must be pseudo legal. // Move must be pseudo legal.
[[nodiscard]] bool isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const; [[nodiscard]] inline bool isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const;
// Move must be pseudo legal. // Move must be pseudo legal.
// Must not be a king move. // Must not be a king move.
[[nodiscard]] bool createsDiscoveredAttackOnOwnKing(Move move) const; [[nodiscard]] inline bool createsDiscoveredAttackOnOwnKing(Move move) const;
// Returns whether a piece on a given square is attacked // Returns whether a piece on a given square is attacked
// by any enemy piece. False if square is empty. // by any enemy piece. False if square is empty.
[[nodiscard]] bool isPieceAttacked(Square sq) const; [[nodiscard]] inline bool isPieceAttacked(Square sq) const;
// Returns whether a piece on a given square is attacked // Returns whether a piece on a given square is attacked
// by any enemy piece after `move` is made. False if square is empty. // by any enemy piece after `move` is made. False if square is empty.
// Move must be pseudo legal. // Move must be pseudo legal.
[[nodiscard]] bool isPieceAttackedAfterMove(Move move, Square sq) const; [[nodiscard]] inline bool isPieceAttackedAfterMove(Move move, Square sq) const;
// Returns whether the king of the moving side is attacked // Returns whether the king of the moving side is attacked
// by any enemy piece after a move is made. // by any enemy piece after a move is made.
// Move must be pseudo legal. // Move must be pseudo legal.
[[nodiscard]] bool isOwnKingAttackedAfterMove(Move move) const; [[nodiscard]] inline bool isOwnKingAttackedAfterMove(Move move) const;
// Return a bitboard with all (pseudo legal) attacks by the piece on // Return a bitboard with all (pseudo legal) attacks by the piece on
// the given square. Empty if no piece on the square. // the given square. Empty if no piece on the square.
[[nodiscard]] Bitboard attacks(Square sq) const; [[nodiscard]] inline Bitboard attacks(Square sq) const;
// Returns a bitboard with all squared that have pieces // Returns a bitboard with all squared that have pieces
// that attack a given square (pseudo legally) // that attack a given square (pseudo legally)
[[nodiscard]] Bitboard attackers(Square sq, Color attackerColor) const; [[nodiscard]] inline Bitboard attackers(Square sq, Color attackerColor) const;
[[nodiscard]] constexpr Piece pieceAt(Square sq) const [[nodiscard]] constexpr Piece pieceAt(Square sq) const
{ {
@@ -4438,20 +4432,6 @@ namespace chess
struct Position; struct Position;
struct MoveLegalityChecker
{
MoveLegalityChecker(const Position& position);
[[nodiscard]] bool isPseudoLegalMoveLegal(const Move& move) const;
private:
const Position* m_position;
Bitboard m_checkers;
Bitboard m_ourBlockersForKing;
Bitboard m_potentialCheckRemovals;
Square m_ksq;
};
struct CompressedPosition; struct CompressedPosition;
struct PositionHash128 struct PositionHash128
@@ -4484,20 +4464,20 @@ namespace chess
{ {
} }
void set(std::string_view fen); inline void set(std::string_view fen);
// Returns false if the fen was not valid // Returns false if the fen was not valid
// If the returned value was false the position // If the returned value was false the position
// is in unspecified state. // is in unspecified state.
[[nodiscard]] bool trySet(std::string_view fen); [[nodiscard]] inline bool trySet(std::string_view fen);
[[nodiscard]] static Position fromFen(std::string_view fen); [[nodiscard]] static inline Position fromFen(std::string_view fen);
[[nodiscard]] static std::optional<Position> tryFromFen(std::string_view fen); [[nodiscard]] static inline std::optional<Position> tryFromFen(std::string_view fen);
[[nodiscard]] static Position startPosition(); [[nodiscard]] static inline Position startPosition();
[[nodiscard]] std::string fen() const; [[nodiscard]] inline std::string fen() const;
constexpr void setEpSquareUnchecked(Square sq) constexpr void setEpSquareUnchecked(Square sq)
{ {
@@ -4535,7 +4515,7 @@ namespace chess
m_ply = ply; m_ply = ply;
} }
ReverseMove doMove(const Move& move); inline ReverseMove doMove(const Move& move);
constexpr void undoMove(const ReverseMove& reverseMove) constexpr void undoMove(const ReverseMove& reverseMove)
{ {
@@ -4559,49 +4539,44 @@ namespace chess
return m_sideToMove; return m_sideToMove;
} }
[[nodiscard]] std::uint8_t rule50Counter() const [[nodiscard]] inline std::uint8_t rule50Counter() const
{ {
return m_rule50Counter; return m_rule50Counter;
} }
[[nodiscard]] std::uint16_t ply() const [[nodiscard]] inline std::uint16_t ply() const
{ {
return m_ply; return m_ply;
} }
[[nodiscard]] std::uint16_t halfMove() const [[nodiscard]] inline std::uint16_t halfMove() const
{ {
return (m_ply + 1) / 2; return (m_ply + 1) / 2;
} }
void setHalfMove(std::uint16_t hm) inline void setHalfMove(std::uint16_t hm)
{ {
m_ply = 2 * hm - 1 + (m_sideToMove == Color::Black); m_ply = 2 * hm - 1 + (m_sideToMove == Color::Black);
} }
[[nodiscard]] bool isCheck() const; [[nodiscard]] inline bool isCheck() const;
[[nodiscard]] Bitboard checkers() const; [[nodiscard]] inline Bitboard checkers() const;
[[nodiscard]] bool isCheckAfterMove(Move move) const; [[nodiscard]] inline bool isCheckAfterMove(Move move) const;
// Checks whether ANY `move` is legal. // Checks whether ANY `move` is legal.
[[nodiscard]] bool isMoveLegal(Move move) const; [[nodiscard]] inline bool isMoveLegal(Move move) const;
[[nodiscard]] bool isPseudoLegalMoveLegal(Move move) const; [[nodiscard]] inline bool isPseudoLegalMoveLegal(Move move) const;
[[nodiscard]] bool isMovePseudoLegal(Move move) const; [[nodiscard]] inline bool isMovePseudoLegal(Move move) const;
// Returns all pieces that block a slider // Returns all pieces that block a slider
// from attacking our king. When two or more // from attacking our king. When two or more
// pieces block a single slider then none // pieces block a single slider then none
// of these pieces are included. // of these pieces are included.
[[nodiscard]] Bitboard blockersForKing(Color color) const; [[nodiscard]] inline Bitboard blockersForKing(Color color) const;
[[nodiscard]] MoveLegalityChecker moveLegalityChecker() const
{
return { *this };
}
[[nodiscard]] constexpr Square epSquare() const [[nodiscard]] constexpr Square epSquare() const
{ {
@@ -4637,7 +4612,7 @@ namespace chess
return cpy; return cpy;
} }
[[nodiscard]] Position afterMove(Move move) const; [[nodiscard]] inline Position afterMove(Move move) const;
[[nodiscard]] constexpr bool isEpPossible() const [[nodiscard]] constexpr bool isEpPossible() const
{ {
@@ -4655,11 +4630,11 @@ namespace chess
static_assert(sizeof(Color) + sizeof(Square) + sizeof(CastlingRights) + sizeof(std::uint8_t) == 4); static_assert(sizeof(Color) + sizeof(Square) + sizeof(CastlingRights) + sizeof(std::uint8_t) == 4);
[[nodiscard]] FORCEINLINE bool isEpPossible(Square epSquare, Color sideToMove) const; [[nodiscard]] inline bool isEpPossible(Square epSquare, Color sideToMove) const;
[[nodiscard]] NOINLINE bool isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const; [[nodiscard]] inline bool isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const;
void nullifyEpSquareIfNotPossible(); inline void nullifyEpSquareIfNotPossible();
}; };
struct CompressedPosition struct CompressedPosition
@@ -5302,7 +5277,7 @@ namespace chess
return allAttackers; return allAttackers;
} }
const Piece* Board::piecesRaw() const inline const Piece* Board::piecesRaw() const
{ {
return m_pieces.data(); return m_pieces.data();
} }
@@ -5330,7 +5305,7 @@ namespace chess
}(); }();
} }
[[nodiscard]] std::string Board::fen() const [[nodiscard]] inline std::string Board::fen() const
{ {
std::string fen; std::string fen;
fen.reserve(96); // longest fen is probably in range of around 88 fen.reserve(96); // longest fen is probably in range of around 88
@@ -5382,79 +5357,6 @@ namespace chess
return fen; return fen;
} }
MoveLegalityChecker::MoveLegalityChecker(const Position& position) :
m_position(&position),
m_checkers(position.checkers()),
m_ourBlockersForKing(
position.blockersForKing(position.sideToMove())
& position.piecesBB(position.sideToMove())
),
m_ksq(position.kingSquare(position.sideToMove()))
{
if (m_checkers.exactlyOne())
{
const Bitboard knightCheckers = m_checkers & bb::pseudoAttacks<PieceType::Knight>(m_ksq);
if (knightCheckers.any())
{
// We're checked by a knight, we have to remove it or move the king.
m_potentialCheckRemovals = knightCheckers;
}
else
{
// If we're not checked by a knight we can block it.
m_potentialCheckRemovals = bb::between(m_ksq, m_checkers.first()) | m_checkers;
}
}
else
{
// Double check, king has to move.
m_potentialCheckRemovals = Bitboard::none();
}
}
[[nodiscard]] bool MoveLegalityChecker::isPseudoLegalMoveLegal(const Move& move) const
{
const Piece movedPiece = m_position->pieceAt(move.from);
if (m_checkers.any())
{
if (move.from == m_ksq || move.type == MoveType::EnPassant)
{
return m_position->isPseudoLegalMoveLegal(move);
}
else
{
// This means there's only one check and we either
// blocked it or removed the piece that attacked
// our king. So the only threat is if it's a discovered check.
return
m_potentialCheckRemovals.isSet(move.to)
&& !m_ourBlockersForKing.isSet(move.from);
}
}
else
{
if (move.from == m_ksq)
{
return m_position->isPseudoLegalMoveLegal(move);
}
else if (move.type == MoveType::EnPassant)
{
return !m_position->createsDiscoveredAttackOnOwnKing(move);
}
else if (m_ourBlockersForKing.isSet(move.from))
{
// If it was a blocker it may have only moved in line with our king.
// Otherwise it's a discovered check.
return bb::line(m_ksq, move.from).isSet(move.to);
}
else
{
return true;
}
}
}
void Position::set(std::string_view fen) void Position::set(std::string_view fen)
{ {
(void)trySet(fen); (void)trySet(fen);
@@ -5611,7 +5513,7 @@ namespace chess
}(); }();
} }
ReverseMove Position::doMove(const Move& move) inline ReverseMove Position::doMove(const Move& move)
{ {
assert(move.from.isOk() && move.to.isOk()); assert(move.from.isOk() && move.to.isOk());
@@ -5649,12 +5551,12 @@ namespace chess
return { move, captured, oldEpSquare, oldCastlingRights }; return { move, captured, oldEpSquare, oldCastlingRights };
} }
[[nodiscard]] bool Position::isCheck() const [[nodiscard]] inline bool Position::isCheck() const
{ {
return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove); return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove);
} }
[[nodiscard]] Bitboard Position::checkers() const [[nodiscard]] inline Bitboard Position::checkers() const
{ {
return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove); return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove);
} }
@@ -5664,7 +5566,7 @@ namespace chess
return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove); return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove);
} }
[[nodiscard]] Bitboard Position::blockersForKing(Color color) const [[nodiscard]] inline Bitboard Position::blockersForKing(Color color) const
{ {
const Color attackerColor = !color; const Color attackerColor = !color;
@@ -5700,7 +5602,7 @@ namespace chess
return allBlockers; return allBlockers;
} }
[[nodiscard]] Position Position::afterMove(Move move) const [[nodiscard]] inline Position Position::afterMove(Move move) const
{ {
Position cpy(*this); Position cpy(*this);
auto pc = cpy.doMove(move); auto pc = cpy.doMove(move);
@@ -5711,7 +5613,7 @@ namespace chess
return cpy; return cpy;
} }
[[nodiscard]] FORCEINLINE bool Position::isEpPossible(Square epSquare, Color sideToMove) const [[nodiscard]] inline bool Position::isEpPossible(Square epSquare, Color sideToMove) const
{ {
const Bitboard pawnsAttackingEpSquare = const Bitboard pawnsAttackingEpSquare =
bb::pawnAttacks(Bitboard::square(epSquare), !sideToMove) bb::pawnAttacks(Bitboard::square(epSquare), !sideToMove)
@@ -5725,7 +5627,7 @@ namespace chess
return isEpPossibleColdPath(epSquare, pawnsAttackingEpSquare, sideToMove); return isEpPossibleColdPath(epSquare, pawnsAttackingEpSquare, sideToMove);
} }
[[nodiscard]] NOINLINE bool Position::isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const [[nodiscard]] inline bool Position::isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const
{ {
// only set m_epSquare when it matters, ie. when // only set m_epSquare when it matters, ie. when
// the opposite side can actually capture // the opposite side can actually capture
@@ -5772,7 +5674,7 @@ namespace chess
return false; return false;
} }
void Position::nullifyEpSquareIfNotPossible() inline void Position::nullifyEpSquareIfNotPossible()
{ {
if (m_epSquare != Square::none() && !isEpPossible(m_epSquare, m_sideToMove)) if (m_epSquare != Square::none() && !isEpPossible(m_epSquare, m_sideToMove))
{ {
@@ -5782,12 +5684,12 @@ namespace chess
namespace uci namespace uci
{ {
[[nodiscard]] std::string moveToUci(const Position& pos, const Move& move); [[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move);
[[nodiscard]] Move uciToMove(const Position& pos, std::string_view sv); [[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv);
[[nodiscard]] std::optional<Move> tryUciToMove(const Position& pos, std::string_view sv); [[nodiscard]] inline std::optional<Move> tryUciToMove(const Position& pos, std::string_view sv);
[[nodiscard]] std::string moveToUci(const Position& pos, const Move& move) [[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move)
{ {
std::string s; std::string s;
@@ -5814,7 +5716,7 @@ namespace chess
return s; return s;
} }
[[nodiscard]] Move uciToMove(const Position& pos, std::string_view sv) [[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv)
{ {
const Square from = parser_bits::parseSquare(sv.data()); const Square from = parser_bits::parseSquare(sv.data());
const Square to = parser_bits::parseSquare(sv.data() + 2); const Square to = parser_bits::parseSquare(sv.data() + 2);
@@ -5850,7 +5752,7 @@ namespace chess
} }
} }
[[nodiscard]] std::optional<Move> tryUciToMove(const Position& pos, std::string_view sv) [[nodiscard]] inline std::optional<Move> tryUciToMove(const Position& pos, std::string_view sv)
{ {
if (sv.size() < 4 || sv.size() > 5) if (sv.size() < 4 || sv.size() > 5)
{ {
@@ -6300,7 +6202,7 @@ namespace binpack
}; };
[[nodiscard]] chess::Position pos_from_packed_sfen(const PackedSfen& sfen) [[nodiscard]] inline chess::Position pos_from_packed_sfen(const PackedSfen& sfen)
{ {
SfenPacker packer; SfenPacker packer;
auto& stream = packer.stream; auto& stream = packer.stream;
@@ -6655,14 +6557,12 @@ namespace binpack
if (!occupied.isSet(sqForward)) if (!occupied.isSet(sqForward))
{ {
destinations |= sqForward; destinations |= sqForward;
const chess::Square sqForward2 = sqForward + forward;
if ( if (
from.rank() == startRank from.rank() == startRank
&& !occupied.isSet(sqForward2) && !occupied.isSet(sqForward + forward)
) )
{ {
destinations |= sqForward2; destinations |= sqForward + forward;
} }
} }
@@ -6845,13 +6745,12 @@ namespace binpack
{ {
destinations |= sqForward; destinations |= sqForward;
const chess::Square sqForward2 = sqForward + forward;
if ( if (
move.from.rank() == startRank move.from.rank() == startRank
&& !occupied.isSet(sqForward2) && !occupied.isSet(sqForward + forward)
) )
{ {
destinations |= sqForward2; destinations |= sqForward + forward;
} }
} }
+4 -1
View File
@@ -133,9 +133,12 @@ namespace Learner
{ {
case SfenOutputType::Bin: case SfenOutputType::Bin:
return std::make_unique<BinSfenOutputStream>(filename); return std::make_unique<BinSfenOutputStream>(filename);
default: case SfenOutputType::Binpack:
return std::make_unique<BinpackSfenOutputStream>(filename); return std::make_unique<BinpackSfenOutputStream>(filename);
} }
assert(false);
return nullptr;
} }
// Helper class for exporting Sfen // Helper class for exporting Sfen
+143 -21
View File
@@ -30,6 +30,8 @@
#include "learn.h" #include "learn.h"
#include "multi_think.h" #include "multi_think.h"
#include "../extra/nnue_data_binpack_format.h"
#include <chrono> #include <chrono>
#include <climits> #include <climits>
#include <cmath> // std::exp(),std::pow(),std::log() #include <cmath> // std::exp(),std::pow(),std::log()
@@ -297,6 +299,115 @@ namespace Learner
return calc_grad((Value)psv.score, shallow, psv); return calc_grad((Value)psv.score, shallow, psv);
} }
struct BasicSfenInputStream
{
virtual std::optional<PackedSfenValue> next() = 0;
virtual bool eof() const = 0;
virtual ~BasicSfenInputStream() {}
};
struct BinSfenInputStream : BasicSfenInputStream
{
static constexpr auto openmode = ios::in | ios::binary;
static inline const std::string extension = "bin";
BinSfenInputStream(std::string filename) :
m_stream(filename, openmode),
m_eof(!m_stream)
{
}
std::optional<PackedSfenValue> next() override
{
PackedSfenValue e;
if(m_stream.read(reinterpret_cast<char*>(&e), sizeof(PackedSfenValue)))
{
return e;
}
else
{
m_eof = true;
return std::nullopt;
}
}
bool eof() const override
{
return m_eof;
}
~BinSfenInputStream() override {}
private:
fstream m_stream;
bool m_eof;
};
struct BinpackSfenInputStream : BasicSfenInputStream
{
static constexpr auto openmode = ios::in | ios::binary;
static inline const std::string extension = "binpack";
BinpackSfenInputStream(std::string filename) :
m_stream(filename, openmode),
m_eof(!m_stream.hasNext())
{
}
std::optional<PackedSfenValue> next() override
{
static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue));
if (!m_stream.hasNext())
{
m_eof = true;
return std::nullopt;
}
auto training_data_entry = m_stream.next();
auto v = binpack::trainingDataEntryToPackedSfenValue(training_data_entry);
PackedSfenValue psv;
// same layout, different types. One is from generic library.
std::memcpy(&psv, &v, sizeof(PackedSfenValue));
return psv;
}
bool eof() const override
{
return m_eof;
}
~BinpackSfenInputStream() override {}
private:
binpack::CompressedTrainingDataEntryReader m_stream;
bool m_eof;
};
static bool ends_with(const std::string& lhs, const std::string& end)
{
if (end.size() > lhs.size()) return false;
return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
}
static bool has_extension(const std::string& filename, const std::string& extension)
{
return ends_with(filename, "." + extension);
}
static std::unique_ptr<BasicSfenInputStream> open_sfen_input_file(const std::string& filename)
{
if (has_extension(filename, BinSfenInputStream::extension))
return std::make_unique<BinSfenInputStream>(filename);
else if (has_extension(filename, BinpackSfenInputStream::extension))
return std::make_unique<BinpackSfenInputStream>(filename);
assert(false);
return nullptr;
}
// Sfen reader // Sfen reader
struct SfenReader struct SfenReader
{ {
@@ -369,13 +480,15 @@ namespace Learner
void read_validation_set(const string& file_name, int eval_limit) void read_validation_set(const string& file_name, int eval_limit)
{ {
ifstream input(file_name, ios::binary); auto input = open_sfen_input_file(file_name);
while (input) while(!input->eof())
{ {
PackedSfenValue p; std::optional<PackedSfenValue> p_opt = input->next();
if (input.read(reinterpret_cast<char*>(&p), sizeof(PackedSfenValue))) if (p_opt.has_value())
{ {
auto& p = *p_opt;
if (eval_limit < abs(p.score)) if (eval_limit < abs(p.score))
continue; continue;
@@ -466,25 +579,34 @@ namespace Learner
void file_read_worker() void file_read_worker()
{ {
auto open_next_file = [&]() { auto open_next_file = [&]() {
if (fs.is_open())
fs.close();
// no more // no more
if (filenames.empty()) for(;;)
return false; {
sfen_input_stream.reset();
// Get the next file name. if (filenames.empty())
string filename = filenames.back(); return false;
filenames.pop_back();
fs.open(filename, ios::in | ios::binary); // Get the next file name.
cout << "open filename = " << filename << endl; string filename = filenames.back();
filenames.pop_back();
assert(fs); sfen_input_stream = open_sfen_input_file(filename);
cout << "open filename = " << filename << endl;
return true; // in case the file is empty or was deleted.
if (!sfen_input_stream->eof())
return true;
}
}; };
if (sfen_input_stream == nullptr && !open_next_file())
{
cout << "..end of files." << endl;
end_of_files = true;
return;
}
while (true) while (true)
{ {
// Wait for the buffer to run out. // Wait for the buffer to run out.
@@ -501,10 +623,10 @@ namespace Learner
// Read from the file into the file buffer. // Read from the file into the file buffer.
while (sfens.size() < SFEN_READ_SIZE) while (sfens.size() < SFEN_READ_SIZE)
{ {
PackedSfenValue p; std::optional<PackedSfenValue> p = sfen_input_stream->next();
if (fs.read(reinterpret_cast<char*>(&p), sizeof(PackedSfenValue))) if (p.has_value())
{ {
sfens.push_back(p); sfens.push_back(*p);
} }
else if(!open_next_file()) else if(!open_next_file())
{ {
@@ -545,7 +667,7 @@ namespace Learner
{ {
std::unique_lock<std::mutex> lk(mutex); std::unique_lock<std::mutex> lk(mutex);
// The mutex lock is required because the // The mutex lock is required because the%
// contents of packed_sfens_pool are changed. // contents of packed_sfens_pool are changed.
for (auto& buf : buffers) for (auto& buf : buffers)
@@ -600,7 +722,7 @@ namespace Learner
atomic<bool> end_of_files; atomic<bool> end_of_files;
// handle of sfen file // handle of sfen file
std::fstream fs; std::unique_ptr<BasicSfenInputStream> sfen_input_stream;
// sfen for each thread // sfen for each thread
// (When the thread is used up, the thread should call delete to release it.) // (When the thread is used up, the thread should call delete to release it.)