diff --git a/src/tools/transform.cpp b/src/tools/transform.cpp
index 0712658b..5abf24b8 100644
--- a/src/tools/transform.cpp
+++ b/src/tools/transform.cpp
@@ -63,6 +63,14 @@ namespace Stockfish::Tools
         }
     };
 
+    // Options of the "filter_335a9b2d8a80" transform (see filter_335a9b2d8a80()).
+    struct FilterParams
+    {
+        std::string input_filename = "in.binpack";
+        std::string output_filename = "out.binpack";
+        bool debug_print = false;
+    };
+
     [[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16)
     {
         auto saturate_i32_to_i16 = [](int v) {
@@ -502,11 +510,242 @@ namespace Stockfish::Tools
         do_rescore(params);
     }
 
+    // Reads every position from params.input_filename and writes each one to
+    // params.output_filename as a binpack. Nothing is dropped: a position that
+    // is in check, whose stored move is a capture/promotion, that equals the
+    // standard start position, or whose 1st/2nd line of a depth 6, MultiPV 2
+    // search starts with a capture/promotion gets score 32002 (VALUE_NONE).
+    // The per-position work runs inside Threads.execute_with_workers().
+    // NOTE(review): sync_cout/sync_endl are assumed to lock/unlock the output
+    // mutex, so every statement below uses exactly one sync_cout/sync_endl pair.
+    void do_filter_data_335a9b2d8a80(FilterParams& params)
+    {
+        // TODO: Use SfenReader once it works correctly in sequential mode. See issue #271
+        auto in = Tools::open_sfen_input_file(params.input_filename);
+        auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> PSVector {
+
+            PSVector psv;
+            psv.reserve(n);
+
+            std::unique_lock lock(mutex);
+
+            for (int i = 0; i < n; ++i)
+            {
+                auto ps_opt = in->next();
+                if (ps_opt.has_value())
+                {
+                    psv.emplace_back(*ps_opt);
+                }
+                else
+                {
+                    break;
+                }
+            }
+
+            return psv;
+        };
+
+        auto sfen_format = SfenOutputType::Binpack;
+
+        auto out = SfenWriter(
+            params.output_filename,
+            Threads.size(),
+            std::numeric_limits<std::uint64_t>::max(),
+            sfen_format);
+
+        // About Search::Limits
+        // Be careful because this member variable is global and affects other threads.
+        auto& limits = Search::Limits;
+
+        // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
+        limits.infinite = true;
+
+        // Since PV is an obstacle when displayed, erase it.
+        limits.silent = true;
+
+        // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
+        limits.nodes = 0;
+
+        // depth is also processed by the one passed as an argument of Tools::search().
+        limits.depth = 0;
+
+        // Explicit 64-bit counters: a deduced atomic<int> would overflow on very large data sets.
+        std::atomic<std::uint64_t> num_processed = 0;
+        std::atomic<std::uint64_t> num_standard_startpos = 0;
+        std::atomic<std::uint64_t> num_position_in_check = 0;
+        std::atomic<std::uint64_t> num_move_already_is_capture = 0;
+        std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap0 = 0;
+        std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap1 = 0;
+
+        Threads.execute_with_workers([&](auto& th){
+            Position& pos = th.rootPos;
+            StateInfo si;
+            const bool frc = Options["UCI_Chess960"];
+
+            const bool debug_print = params.debug_print;
+            for (;;)
+            {
+                PSVector psv = readsome(5000);
+                if (psv.empty())
+                    break;
+
+                for(auto& ps : psv)
+                {
+                    pos.set_from_packed_sfen(ps.sfen, &si, &th, frc);
+                    bool should_skip_position = false;
+                    if (pos.checkers()) {
+                        // Skip if in check
+                        if (debug_print) {
+                            sync_cout << "[debug] " << pos.fen() << '\n'
+                                      << "[debug] Position is in check" << '\n'
+                                      << "[debug]" << sync_endl;
+                        }
+                        num_position_in_check.fetch_add(1);
+                        should_skip_position = true;
+                    } else if (pos.capture_or_promotion(static_cast<Stockfish::Move>(ps.move))) {
+                        // Skip if the provided move is already a capture or promotion
+                        if (debug_print) {
+                            sync_cout << "[debug] " << pos.fen() << '\n'
+                                      << "[debug] Provided move is capture or promo: "
+                                      << UCI::move(static_cast<Stockfish::Move>(ps.move), false)
+                                      << '\n'
+                                      << "[debug]" << sync_endl;
+                        }
+                        num_move_already_is_capture.fetch_add(1);
+                        should_skip_position = true;
+                    } else if (pos.fen() == "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1") {
+                        num_standard_startpos.fetch_add(1);
+                        should_skip_position = true;
+                    } else {
+                        auto [search_val, pvs] = Search::search(pos, 6, 2);
+                        if (!pvs.empty() && th.rootMoves.size() > 0) {
+                            auto best_move = th.rootMoves[0].pv[0];
+                            bool more_than_one_valid_move = th.rootMoves.size() > 1;
+                            if (debug_print) {
+                                sync_cout << "[debug] " << pos.fen() << sync_endl;
+                                sync_cout << "[debug] Main PV move: "
+                                          << UCI::move(best_move, false) << " "
+                                          << th.rootMoves[0].score << " " << sync_endl;
+                                if (more_than_one_valid_move) {
+                                    sync_cout << "[debug] 2nd PV move: "
+                                              << UCI::move(th.rootMoves[1].pv[0], false) << " "
+                                              << th.rootMoves[1].score << " " << sync_endl;
+                                } else {
+                                    sync_cout << "[debug] The only valid move" << sync_endl;
+                                }
+                            }
+                            if (pos.capture_or_promotion(best_move)) {
+                                // skip if multipv 1st line bestmove is a capture or promo
+                                if (debug_print) {
+                                    sync_cout << "[debug] Move is capture or promo: " << UCI::move(best_move, false)
+                                              << '\n'
+                                              << "[debug] 1st best move at depth 6 multipv 2" << '\n'
+                                              << "[debug]" << sync_endl;
+                                }
+                                num_capture_or_promo_skipped_multipv_cap0.fetch_add(1);
+                                should_skip_position = true;
+                            } else if (more_than_one_valid_move && pos.capture_or_promotion(th.rootMoves[1].pv[0])) {
+                                // skip if multipv 2nd line bestmove is a capture or promo
+                                if (debug_print) {
+                                    sync_cout << "[debug] Move is capture or promo: " << UCI::move(th.rootMoves[1].pv[0], false)
+                                              << '\n'
+                                              << "[debug] 2nd best move at depth 6 multipv 2" << '\n'
+                                              << "[debug]" << sync_endl;
+                                }
+                                num_capture_or_promo_skipped_multipv_cap1.fetch_add(1);
+                                should_skip_position = true;
+                            }
+                        }
+                    }
+                    pos.sfen_pack(ps.sfen, false);
+                    // nnue-pytorch training data loader skips positions with score VALUE_NONE
+                    if (should_skip_position)
+                        ps.score = 32002; // VALUE_NONE
+                    ps.padding = 0;
+
+                    out.write(th.id(), ps);
+
+                    auto p = num_processed.fetch_add(1) + 1;
+                    if (p % 10000 == 0) {
+                        auto c = num_position_in_check.load();
+                        auto a = num_move_already_is_capture.load();
+                        auto s = num_standard_startpos.load();
+                        auto multipv_cap0 = num_capture_or_promo_skipped_multipv_cap0.load();
+                        auto multipv_cap1 = num_capture_or_promo_skipped_multipv_cap1.load();
+                        sync_cout << "Processed " << p << " positions. Skipped " << (c + a + s + multipv_cap0 + multipv_cap1) << " positions."
+                                  << '\n'
+                                  << " Static filter: " << (a + c + s)
+                                  << " (capture or promo: " << a << ", in check: " << c << ", startpos: " << s << ")"
+                                  << '\n'
+                                  << " MultiPV filter: " << (multipv_cap0 + multipv_cap1)
+                                  << " (cap0: " << multipv_cap0 << ", cap1: " << multipv_cap1 << ")"
+                                  << " depth 6 multipv 2" << sync_endl;
+                    }
+                }
+            }
+        });
+        Threads.wait_for_workers_finished();
+
+        std::cout << "Finished.\n";
+    }
+
+    // Runs the filter when the input file name ends with ".binpack";
+    // otherwise reports the unsupported type on stderr and does nothing.
+    void do_filter_335a9b2d8a80(FilterParams& params)
+    {
+        if (ends_with(params.input_filename, ".binpack"))
+        {
+            do_filter_data_335a9b2d8a80(params);
+        }
+        else
+        {
+            std::cerr << "Invalid input file type.\n";
+        }
+    }
+
+    // Entry point of the "filter_335a9b2d8a80" transform subcommand.
+    // Recognised options: input_file <path>, output_file <path> and
+    // debug_print <bool>. An unknown option aborts without doing any work.
+    void filter_335a9b2d8a80(std::istringstream& is)
+    {
+        FilterParams params{};
+
+        while(true)
+        {
+            std::string token;
+            is >> token;
+
+            if (token.empty())
+                break;
+
+            else if (token == "input_file")
+                is >> params.input_filename;
+            else if (token == "output_file")
+                is >> params.output_filename;
+            else if (token == "debug_print")
+                is >> params.debug_print;
+            else
+            {
+                std::cout << "ERROR: Unknown option " << token << ". Exiting...\n";
+                return;
+            }
+        }
+
+        std::cout << "Performing transform filter_335a9b2d8a80 with parameters:\n";
+        std::cout << "input_file : " << params.input_filename << '\n';
+        std::cout << "output_file : " << params.output_filename << '\n';
+        std::cout << "debug_print : " << params.debug_print << '\n';
+        std::cout << '\n';
+
+        do_filter_335a9b2d8a80(params);
+    }
+
     void transform(std::istringstream& is)
     {
         const std::map subcommands = {
             { "nudged_static", &nudged_static },
-            { "rescore", &rescore }
+            { "rescore", &rescore },
+            { "filter_335a9b2d8a80", &filter_335a9b2d8a80 }
         };
 
         Eval::NNUE::init();