mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 02:47:45 +00:00
Tools transform option for filtering data for training nn-335a9b2d8a80.nnue (#4324)
Append hash of first master net trained with filter method. Hardcode depth 6 and remove option to set depth. Underscores for consistency. Filter out standard startpos positions too.
This commit is contained in:
+225
-1
@@ -63,6 +63,13 @@ namespace Stockfish::Tools
|
||||
}
|
||||
};
|
||||
|
||||
// Settings for the filter_335a9b2d8a80 transform.
struct FilterParams
{
    // Path handed to Tools::open_sfen_input_file() to read positions from.
    std::string input_filename{"in.binpack"};

    // Path handed to SfenWriter for the filtered output.
    std::string output_filename{"out.binpack"};

    // When true, a "[debug]" trace is printed for each filtering decision.
    bool debug_print{false};
};
|
||||
|
||||
[[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16)
|
||||
{
|
||||
auto saturate_i32_to_i16 = [](int v) {
|
||||
@@ -502,11 +509,228 @@ namespace Stockfish::Tools
|
||||
do_rescore(params);
|
||||
}
|
||||
|
||||
// Reads every position from params.input_filename, decides whether it should
// be excluded from training, and writes ALL positions to
// params.output_filename. Excluded positions are not dropped: their score is
// overwritten with VALUE_NONE so that the training data loader skips them.
//
// A position is excluded when any of the following holds (checked in order):
//   1. the side to move is in check;
//   2. the move stored with the position is a capture or promotion;
//   3. the position is the standard chess start position;
//   4. in a fixed depth-6, MultiPV-2 search, the first move of the best line
//      or of the second-best line is a capture or promotion.
//
// The work is spread across all worker threads; progress is reported every
// 10000 processed positions. When params.debug_print is set, a "[debug]"
// trace is printed for each decision.
void do_filter_data_335a9b2d8a80(FilterParams& params)
{
    // Fixed search settings for this filter (intentionally not configurable).
    constexpr int search_depth = 6;
    constexpr int search_multipv = 2;

    // Number of positions a worker fetches per trip to the shared reader.
    constexpr int read_batch_size = 5000;

    // Print a progress summary each time this many positions were processed.
    constexpr std::uint64_t report_interval = 10000;

    // Score written to excluded positions (VALUE_NONE).
    constexpr int score_value_none = 32002;

    const std::string standard_startpos_fen =
        "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";

    // TODO: Use SfenReader once it works correctly in sequential mode. See issue #271
    auto in = Tools::open_sfen_input_file(params.input_filename);

    // Pulls up to n positions from the shared input. The reader is guarded by
    // a mutex because every worker thread calls this concurrently. An empty
    // result means the input is exhausted.
    auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> PSVector {

        PSVector psv;
        psv.reserve(n);

        std::unique_lock lock(mutex);

        for (int i = 0; i < n; ++i)
        {
            auto ps_opt = in->next();
            if (!ps_opt.has_value())
                break;

            psv.emplace_back(*ps_opt);
        }

        return psv;
    };

    auto sfen_format = SfenOutputType::Binpack;

    auto out = SfenWriter(
        params.output_filename,
        Threads.size(),
        std::numeric_limits<std::uint64_t>::max(),
        sfen_format);

    // About Search::Limits
    // Be careful because this member variable is global and affects other threads.
    auto& limits = Search::Limits;

    // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
    limits.infinite = true;

    // Since PV is an obstacle when displayed, erase it.
    limits.silent = true;

    // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
    limits.nodes = 0;

    // depth is also processed by the one passed as an argument of Tools::search().
    limits.depth = 0;

    // Counters are shared by all workers, hence atomic.
    std::atomic<std::uint64_t> num_processed = 0;
    std::atomic<std::uint64_t> num_standard_startpos = 0;
    std::atomic<std::uint64_t> num_position_in_check = 0;
    std::atomic<std::uint64_t> num_move_already_is_capture = 0;
    std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap0 = 0;
    std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap1 = 0;

    // NOTE(review): sync_endl is assumed to end the critical section opened by
    // sync_cout (confirm in misc.h). Every message below therefore uses exactly
    // one sync_cout/sync_endl pair, with plain '\n' for the line breaks in
    // between, so a multi-line message is printed under a single lock.
    Threads.execute_with_workers([&](auto& th){
        Position& pos = th.rootPos;
        StateInfo si;
        const bool frc = Options["UCI_Chess960"];

        const bool debug_print = params.debug_print;

        for (;;)
        {
            PSVector psv = readsome(read_batch_size);
            if (psv.empty())
                break;

            for (auto& ps : psv)
            {
                pos.set_from_packed_sfen(ps.sfen, &si, &th, frc);

                const auto provided_move = static_cast<Stockfish::Move>(ps.move);
                bool should_skip_position = false;

                if (pos.checkers()) {
                    // Skip if in check
                    if (debug_print) {
                        sync_cout << "[debug] " << pos.fen() << '\n'
                                  << "[debug] Position is in check" << '\n'
                                  << "[debug]" << sync_endl;
                    }
                    num_position_in_check.fetch_add(1);
                    should_skip_position = true;
                } else if (pos.capture_or_promotion(provided_move)) {
                    // Skip if the provided move is already a capture or promotion
                    if (debug_print) {
                        sync_cout << "[debug] " << pos.fen() << '\n'
                                  << "[debug] Provided move is capture or promo: "
                                  << UCI::move(provided_move, false) << '\n'
                                  << "[debug]" << sync_endl;
                    }
                    num_move_already_is_capture.fetch_add(1);
                    should_skip_position = true;
                } else if (pos.fen() == standard_startpos_fen) {
                    // Skip the standard start position
                    num_standard_startpos.fetch_add(1);
                    should_skip_position = true;
                } else {
                    auto [search_val, pvs] = Search::search(pos, search_depth, search_multipv);

                    if (!pvs.empty() && !th.rootMoves.empty()) {
                        const auto best_move = th.rootMoves[0].pv[0];
                        const bool more_than_one_valid_move = th.rootMoves.size() > 1;

                        if (debug_print) {
                            sync_cout << "[debug] " << pos.fen() << sync_endl;
                            sync_cout << "[debug] Main PV move: "
                                      << UCI::move(best_move, false) << " "
                                      << th.rootMoves[0].score << " " << sync_endl;
                            if (more_than_one_valid_move) {
                                sync_cout << "[debug] 2nd PV move: "
                                          << UCI::move(th.rootMoves[1].pv[0], false) << " "
                                          << th.rootMoves[1].score << " " << sync_endl;
                            } else {
                                sync_cout << "[debug] The only valid move" << sync_endl;
                            }
                        }

                        if (pos.capture_or_promotion(best_move)) {
                            // skip if multipv 1st line bestmove is a capture or promo
                            if (debug_print) {
                                sync_cout << "[debug] Move is capture or promo: "
                                          << UCI::move(best_move, false) << '\n'
                                          << "[debug] 1st best move at depth " << search_depth
                                          << " multipv " << search_multipv << '\n'
                                          << "[debug]" << sync_endl;
                            }
                            num_capture_or_promo_skipped_multipv_cap0.fetch_add(1);
                            should_skip_position = true;
                        } else if (more_than_one_valid_move
                                   && pos.capture_or_promotion(th.rootMoves[1].pv[0])) {
                            // skip if multipv 2nd line bestmove is a capture or promo
                            if (debug_print) {
                                // Report the second-line move: it is the one
                                // that triggered this skip, not best_move.
                                const auto second_move = th.rootMoves[1].pv[0];
                                sync_cout << "[debug] Move is capture or promo: "
                                          << UCI::move(second_move, false) << '\n'
                                          << "[debug] 2nd best move at depth " << search_depth
                                          << " multipv " << search_multipv << '\n'
                                          << "[debug]" << sync_endl;
                            }
                            num_capture_or_promo_skipped_multipv_cap1.fetch_add(1);
                            should_skip_position = true;
                        }
                    }
                }

                pos.sfen_pack(ps.sfen, false);

                // nnue-pytorch training data loader skips positions with score VALUE_NONE
                if (should_skip_position)
                    ps.score = score_value_none;
                ps.padding = 0;

                out.write(th.id(), ps);

                const auto p = num_processed.fetch_add(1) + 1;
                if (p % report_interval == 0) {
                    const auto c = num_position_in_check.load();
                    const auto a = num_move_already_is_capture.load();
                    const auto s = num_standard_startpos.load();
                    const auto multipv_cap0 = num_capture_or_promo_skipped_multipv_cap0.load();
                    const auto multipv_cap1 = num_capture_or_promo_skipped_multipv_cap1.load();

                    sync_cout << "Processed " << p << " positions. Skipped "
                              << (c + a + s + multipv_cap0 + multipv_cap1) << " positions." << '\n'
                              << " Static filter: " << (a + c + s)
                              << " (capture or promo: " << a << ", in check: " << c << ", startpos: " << s << ")" << '\n'
                              << " MultiPV filter: " << (multipv_cap0 + multipv_cap1)
                              << " (cap0: " << multipv_cap0 << ", cap1: " << multipv_cap1 << ")"
                              << " depth " << search_depth << " multipv " << search_multipv << sync_endl;
                }
            }
        }
    });
    Threads.wait_for_workers_finished();

    std::cout << "Finished.\n";
}
|
||||
|
||||
// Dispatches the filter by input file type. Only ".binpack" input is
// handled; for anything else an error is written to stderr and nothing is
// processed.
void do_filter_335a9b2d8a80(FilterParams& params)
{
    const bool is_binpack = ends_with(params.input_filename, ".binpack");

    if (!is_binpack)
    {
        std::cerr << "Invalid input file type.\n";
        return;
    }

    do_filter_data_335a9b2d8a80(params);
}
|
||||
|
||||
void filter_335a9b2d8a80(std::istringstream& is)
|
||||
{
|
||||
FilterParams params{};
|
||||
|
||||
while(true)
|
||||
{
|
||||
std::string token;
|
||||
is >> token;
|
||||
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
else if (token == "input_file")
|
||||
is >> params.input_filename;
|
||||
else if (token == "output_file")
|
||||
is >> params.output_filename;
|
||||
else if (token == "debug_print")
|
||||
is >> params.debug_print;
|
||||
else
|
||||
{
|
||||
std::cout << "ERROR: Unknown option " << token << ". Exiting...\n";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Performing transform filter_335a9b2d8a80 with parameters:\n";
|
||||
std::cout << "input_file : " << params.input_filename << '\n';
|
||||
std::cout << "output_file : " << params.output_filename << '\n';
|
||||
std::cout << "debug_print : " << params.debug_print << '\n';
|
||||
std::cout << '\n';
|
||||
|
||||
do_filter_335a9b2d8a80(params);
|
||||
}
|
||||
|
||||
void transform(std::istringstream& is)
|
||||
{
|
||||
const std::map<std::string, CommandFunc> subcommands = {
|
||||
{ "nudged_static", &nudged_static },
|
||||
{ "rescore", &rescore }
|
||||
{ "rescore", &rescore },
|
||||
{ "filter_335a9b2d8a80", &filter_335a9b2d8a80 }
|
||||
};
|
||||
|
||||
Eval::NNUE::init();
|
||||
|
||||
Reference in New Issue
Block a user