mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 01:37:46 +00:00
Tools transform option for filtering data for training nn-335a9b2d8a80.nnue (#4324)
Append hash of first master net trained with filter method Hardcode depth 6 and remove option to set depth Underscores for consistency Filter out standard startpos positions too
This commit is contained in:
+225
-1
@@ -63,6 +63,13 @@ namespace Stockfish::Tools
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct FilterParams
|
||||||
|
{
|
||||||
|
std::string input_filename = "in.binpack";
|
||||||
|
std::string output_filename = "out.binpack";
|
||||||
|
bool debug_print = false;
|
||||||
|
};
|
||||||
|
|
||||||
[[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16)
|
[[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16)
|
||||||
{
|
{
|
||||||
auto saturate_i32_to_i16 = [](int v) {
|
auto saturate_i32_to_i16 = [](int v) {
|
||||||
@@ -502,11 +509,228 @@ namespace Stockfish::Tools
|
|||||||
do_rescore(params);
|
do_rescore(params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void do_filter_data_335a9b2d8a80(FilterParams& params)
|
||||||
|
{
|
||||||
|
// TODO: Use SfenReader once it works correctly in sequential mode. See issue #271
|
||||||
|
auto in = Tools::open_sfen_input_file(params.input_filename);
|
||||||
|
auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> PSVector {
|
||||||
|
|
||||||
|
PSVector psv;
|
||||||
|
psv.reserve(n);
|
||||||
|
|
||||||
|
std::unique_lock lock(mutex);
|
||||||
|
|
||||||
|
for (int i = 0; i < n; ++i)
|
||||||
|
{
|
||||||
|
auto ps_opt = in->next();
|
||||||
|
if (ps_opt.has_value())
|
||||||
|
{
|
||||||
|
psv.emplace_back(*ps_opt);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return psv;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto sfen_format = SfenOutputType::Binpack;
|
||||||
|
|
||||||
|
auto out = SfenWriter(
|
||||||
|
params.output_filename,
|
||||||
|
Threads.size(),
|
||||||
|
std::numeric_limits<std::uint64_t>::max(),
|
||||||
|
sfen_format);
|
||||||
|
|
||||||
|
// About Search::Limits
|
||||||
|
// Be careful because this member variable is global and affects other threads.
|
||||||
|
auto& limits = Search::Limits;
|
||||||
|
|
||||||
|
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
|
||||||
|
limits.infinite = true;
|
||||||
|
|
||||||
|
// Since PV is an obstacle when displayed, erase it.
|
||||||
|
limits.silent = true;
|
||||||
|
|
||||||
|
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
|
||||||
|
limits.nodes = 0;
|
||||||
|
|
||||||
|
// depth is also processed by the one passed as an argument of Tools::search().
|
||||||
|
limits.depth = 0;
|
||||||
|
|
||||||
|
std::atomic<std::uint64_t> num_processed = 0;
|
||||||
|
std::atomic<std::uint64_t> num_standard_startpos = 0;
|
||||||
|
std::atomic<std::uint64_t> num_position_in_check = 0;
|
||||||
|
std::atomic<std::uint64_t> num_move_already_is_capture = 0;
|
||||||
|
std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap0 = 0;
|
||||||
|
std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap1 = 0;
|
||||||
|
|
||||||
|
Threads.execute_with_workers([&](auto& th){
|
||||||
|
Position& pos = th.rootPos;
|
||||||
|
StateInfo si;
|
||||||
|
const bool frc = Options["UCI_Chess960"];
|
||||||
|
|
||||||
|
const bool debug_print = params.debug_print;
|
||||||
|
for (;;)
|
||||||
|
{
|
||||||
|
PSVector psv = readsome(5000);
|
||||||
|
if (psv.empty())
|
||||||
|
break;
|
||||||
|
|
||||||
|
for(auto& ps : psv)
|
||||||
|
{
|
||||||
|
pos.set_from_packed_sfen(ps.sfen, &si, &th, frc);
|
||||||
|
bool should_skip_position = false;
|
||||||
|
if (pos.checkers()) {
|
||||||
|
// Skip if in check
|
||||||
|
if (debug_print) {
|
||||||
|
sync_cout << "[debug] " << pos.fen() << sync_endl
|
||||||
|
<< "[debug] Position is in check" << sync_endl
|
||||||
|
<< "[debug]" << sync_endl;
|
||||||
|
}
|
||||||
|
num_position_in_check.fetch_add(1);
|
||||||
|
should_skip_position = true;
|
||||||
|
} else if (pos.capture_or_promotion((Stockfish::Move)ps.move)) {
|
||||||
|
// Skip if the provided move is already a capture or promotion
|
||||||
|
if (debug_print) {
|
||||||
|
sync_cout << "[debug] " << pos.fen() << sync_endl
|
||||||
|
<< "[debug] Provided move is capture or promo: "
|
||||||
|
<< UCI::move((Stockfish::Move)ps.move, false)
|
||||||
|
<< sync_endl
|
||||||
|
<< "[debug]" << sync_endl;
|
||||||
|
}
|
||||||
|
num_move_already_is_capture.fetch_add(1);
|
||||||
|
should_skip_position = true;
|
||||||
|
} else if (pos.fen() == "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1") {
|
||||||
|
num_standard_startpos.fetch_add(1);
|
||||||
|
should_skip_position = true;
|
||||||
|
} else {
|
||||||
|
auto [search_val, pvs] = Search::search(pos, 6, 2);
|
||||||
|
if (!pvs.empty() && th.rootMoves.size() > 0) {
|
||||||
|
auto best_move = th.rootMoves[0].pv[0];
|
||||||
|
bool more_than_one_valid_move = th.rootMoves.size() > 1;
|
||||||
|
if (debug_print) {
|
||||||
|
sync_cout << "[debug] " << pos.fen() << sync_endl;
|
||||||
|
sync_cout << "[debug] Main PV move: "
|
||||||
|
<< UCI::move(best_move, false) << " "
|
||||||
|
<< th.rootMoves[0].score << " " << sync_endl;
|
||||||
|
if (more_than_one_valid_move) {
|
||||||
|
sync_cout << "[debug] 2nd PV move: "
|
||||||
|
<< UCI::move(th.rootMoves[1].pv[0], false) << " "
|
||||||
|
<< th.rootMoves[1].score << " " << sync_endl;
|
||||||
|
} else {
|
||||||
|
sync_cout << "[debug] The only valid move" << sync_endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (pos.capture_or_promotion(best_move)) {
|
||||||
|
// skip if multipv 1st line bestmove is a capture or promo
|
||||||
|
if (debug_print) {
|
||||||
|
sync_cout << "[debug] Move is capture or promo: " << UCI::move(best_move, false)
|
||||||
|
<< sync_endl
|
||||||
|
<< "[debug] 1st best move at depth 6 multipv 2" << sync_endl
|
||||||
|
<< "[debug]" << sync_endl;
|
||||||
|
}
|
||||||
|
num_capture_or_promo_skipped_multipv_cap0.fetch_add(1);
|
||||||
|
should_skip_position = true;
|
||||||
|
} else if (more_than_one_valid_move && pos.capture_or_promotion(th.rootMoves[1].pv[0])) {
|
||||||
|
// skip if multipv 2nd line bestmove is a capture or promo
|
||||||
|
if (debug_print) {
|
||||||
|
sync_cout << "[debug] Move is capture or promo: " << UCI::move(best_move, false)
|
||||||
|
<< sync_endl
|
||||||
|
<< "[debug] 2nd best move at depth 6 multipv 2" << sync_endl
|
||||||
|
<< "[debug]" << sync_endl;
|
||||||
|
}
|
||||||
|
num_capture_or_promo_skipped_multipv_cap1.fetch_add(1);
|
||||||
|
should_skip_position = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pos.sfen_pack(ps.sfen, false);
|
||||||
|
// nnue-pytorch training data loader skips positions with score VALUE_NONE
|
||||||
|
if (should_skip_position)
|
||||||
|
ps.score = 32002; // VALUE_NONE
|
||||||
|
ps.padding = 0;
|
||||||
|
|
||||||
|
out.write(th.id(), ps);
|
||||||
|
|
||||||
|
auto p = num_processed.fetch_add(1) + 1;
|
||||||
|
if (p % 10000 == 0) {
|
||||||
|
auto c = num_position_in_check.load();
|
||||||
|
auto a = num_move_already_is_capture.load();
|
||||||
|
auto s = num_standard_startpos.load();
|
||||||
|
auto multipv_cap0 = num_capture_or_promo_skipped_multipv_cap0.load();
|
||||||
|
auto multipv_cap1 = num_capture_or_promo_skipped_multipv_cap1.load();
|
||||||
|
sync_cout << "Processed " << p << " positions. Skipped " << (c + a + s + multipv_cap0 + multipv_cap1) << " positions."
|
||||||
|
<< sync_endl
|
||||||
|
<< " Static filter: " << (a + c + s)
|
||||||
|
<< " (capture or promo: " << a << ", in check: " << c << ", startpos: " << s << ")"
|
||||||
|
<< sync_endl
|
||||||
|
<< " MultiPV filter: " << (multipv_cap0 + multipv_cap1)
|
||||||
|
<< " (cap0: " << multipv_cap0 << ", cap1: " << multipv_cap1 << ")"
|
||||||
|
<< " depth 6 multipv 2" << sync_endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
Threads.wait_for_workers_finished();
|
||||||
|
|
||||||
|
std::cout << "Finished.\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
void do_filter_335a9b2d8a80(FilterParams& params)
|
||||||
|
{
|
||||||
|
if (ends_with(params.input_filename, ".binpack"))
|
||||||
|
{
|
||||||
|
do_filter_data_335a9b2d8a80(params);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::cerr << "Invalid input file type.\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void filter_335a9b2d8a80(std::istringstream& is)
|
||||||
|
{
|
||||||
|
FilterParams params{};
|
||||||
|
|
||||||
|
while(true)
|
||||||
|
{
|
||||||
|
std::string token;
|
||||||
|
is >> token;
|
||||||
|
|
||||||
|
if (token == "")
|
||||||
|
break;
|
||||||
|
|
||||||
|
else if (token == "input_file")
|
||||||
|
is >> params.input_filename;
|
||||||
|
else if (token == "output_file")
|
||||||
|
is >> params.output_filename;
|
||||||
|
else if (token == "debug_print")
|
||||||
|
is >> params.debug_print;
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::cout << "ERROR: Unknown option " << token << ". Exiting...\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "Performing transform filter_335a9b2d8a80 with parameters:\n";
|
||||||
|
std::cout << "input_file : " << params.input_filename << '\n';
|
||||||
|
std::cout << "output_file : " << params.output_filename << '\n';
|
||||||
|
std::cout << "debug_print : " << params.debug_print << '\n';
|
||||||
|
std::cout << '\n';
|
||||||
|
|
||||||
|
do_filter_335a9b2d8a80(params);
|
||||||
|
}
|
||||||
|
|
||||||
void transform(std::istringstream& is)
|
void transform(std::istringstream& is)
|
||||||
{
|
{
|
||||||
const std::map<std::string, CommandFunc> subcommands = {
|
const std::map<std::string, CommandFunc> subcommands = {
|
||||||
{ "nudged_static", &nudged_static },
|
{ "nudged_static", &nudged_static },
|
||||||
{ "rescore", &rescore }
|
{ "rescore", &rescore },
|
||||||
|
{ "filter_335a9b2d8a80", &filter_335a9b2d8a80 }
|
||||||
};
|
};
|
||||||
|
|
||||||
Eval::NNUE::init();
|
Eval::NNUE::init();
|
||||||
|
|||||||
Reference in New Issue
Block a user