Tools transform option for filtering data for training nn-335a9b2d8a80.nnue (#4324)

Append hash of first master net trained with filter method

Hardcode depth 6 and remove option to set depth

Underscores for consistency

Filter out standard startpos positions too
This commit is contained in:
Linmiao Xu
2023-02-05 08:21:48 -05:00
committed by GitHub
parent 073d71a36b
commit 8e16592430
+225 -1
View File
@@ -63,6 +63,13 @@ namespace Stockfish::Tools
}
};
struct FilterParams
{
std::string input_filename = "in.binpack";
std::string output_filename = "out.binpack";
bool debug_print = false;
};
[[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16)
{
auto saturate_i32_to_i16 = [](int v) {
@@ -502,11 +509,228 @@ namespace Stockfish::Tools
do_rescore(params);
}
void do_filter_data_335a9b2d8a80(FilterParams& params)
{
// TODO: Use SfenReader once it works correctly in sequential mode. See issue #271
auto in = Tools::open_sfen_input_file(params.input_filename);
auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> PSVector {
PSVector psv;
psv.reserve(n);
std::unique_lock lock(mutex);
for (int i = 0; i < n; ++i)
{
auto ps_opt = in->next();
if (ps_opt.has_value())
{
psv.emplace_back(*ps_opt);
}
else
{
break;
}
}
return psv;
};
auto sfen_format = SfenOutputType::Binpack;
auto out = SfenWriter(
params.output_filename,
Threads.size(),
std::numeric_limits<std::uint64_t>::max(),
sfen_format);
// About Search::Limits
// Be careful because this member variable is global and affects other threads.
auto& limits = Search::Limits;
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
limits.infinite = true;
// Since PV is an obstacle when displayed, erase it.
limits.silent = true;
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
limits.nodes = 0;
// depth is also processed by the one passed as an argument of Tools::search().
limits.depth = 0;
std::atomic<std::uint64_t> num_processed = 0;
std::atomic<std::uint64_t> num_standard_startpos = 0;
std::atomic<std::uint64_t> num_position_in_check = 0;
std::atomic<std::uint64_t> num_move_already_is_capture = 0;
std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap0 = 0;
std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap1 = 0;
Threads.execute_with_workers([&](auto& th){
Position& pos = th.rootPos;
StateInfo si;
const bool frc = Options["UCI_Chess960"];
const bool debug_print = params.debug_print;
for (;;)
{
PSVector psv = readsome(5000);
if (psv.empty())
break;
for(auto& ps : psv)
{
pos.set_from_packed_sfen(ps.sfen, &si, &th, frc);
bool should_skip_position = false;
if (pos.checkers()) {
// Skip if in check
if (debug_print) {
sync_cout << "[debug] " << pos.fen() << sync_endl
<< "[debug] Position is in check" << sync_endl
<< "[debug]" << sync_endl;
}
num_position_in_check.fetch_add(1);
should_skip_position = true;
} else if (pos.capture_or_promotion((Stockfish::Move)ps.move)) {
// Skip if the provided move is already a capture or promotion
if (debug_print) {
sync_cout << "[debug] " << pos.fen() << sync_endl
<< "[debug] Provided move is capture or promo: "
<< UCI::move((Stockfish::Move)ps.move, false)
<< sync_endl
<< "[debug]" << sync_endl;
}
num_move_already_is_capture.fetch_add(1);
should_skip_position = true;
} else if (pos.fen() == "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1") {
num_standard_startpos.fetch_add(1);
should_skip_position = true;
} else {
auto [search_val, pvs] = Search::search(pos, 6, 2);
if (!pvs.empty() && th.rootMoves.size() > 0) {
auto best_move = th.rootMoves[0].pv[0];
bool more_than_one_valid_move = th.rootMoves.size() > 1;
if (debug_print) {
sync_cout << "[debug] " << pos.fen() << sync_endl;
sync_cout << "[debug] Main PV move: "
<< UCI::move(best_move, false) << " "
<< th.rootMoves[0].score << " " << sync_endl;
if (more_than_one_valid_move) {
sync_cout << "[debug] 2nd PV move: "
<< UCI::move(th.rootMoves[1].pv[0], false) << " "
<< th.rootMoves[1].score << " " << sync_endl;
} else {
sync_cout << "[debug] The only valid move" << sync_endl;
}
}
if (pos.capture_or_promotion(best_move)) {
// skip if multipv 1st line bestmove is a capture or promo
if (debug_print) {
sync_cout << "[debug] Move is capture or promo: " << UCI::move(best_move, false)
<< sync_endl
<< "[debug] 1st best move at depth 6 multipv 2" << sync_endl
<< "[debug]" << sync_endl;
}
num_capture_or_promo_skipped_multipv_cap0.fetch_add(1);
should_skip_position = true;
} else if (more_than_one_valid_move && pos.capture_or_promotion(th.rootMoves[1].pv[0])) {
// skip if multipv 2nd line bestmove is a capture or promo
if (debug_print) {
sync_cout << "[debug] Move is capture or promo: " << UCI::move(best_move, false)
<< sync_endl
<< "[debug] 2nd best move at depth 6 multipv 2" << sync_endl
<< "[debug]" << sync_endl;
}
num_capture_or_promo_skipped_multipv_cap1.fetch_add(1);
should_skip_position = true;
}
}
}
pos.sfen_pack(ps.sfen, false);
// nnue-pytorch training data loader skips positions with score VALUE_NONE
if (should_skip_position)
ps.score = 32002; // VALUE_NONE
ps.padding = 0;
out.write(th.id(), ps);
auto p = num_processed.fetch_add(1) + 1;
if (p % 10000 == 0) {
auto c = num_position_in_check.load();
auto a = num_move_already_is_capture.load();
auto s = num_standard_startpos.load();
auto multipv_cap0 = num_capture_or_promo_skipped_multipv_cap0.load();
auto multipv_cap1 = num_capture_or_promo_skipped_multipv_cap1.load();
sync_cout << "Processed " << p << " positions. Skipped " << (c + a + s + multipv_cap0 + multipv_cap1) << " positions."
<< sync_endl
<< " Static filter: " << (a + c + s)
<< " (capture or promo: " << a << ", in check: " << c << ", startpos: " << s << ")"
<< sync_endl
<< " MultiPV filter: " << (multipv_cap0 + multipv_cap1)
<< " (cap0: " << multipv_cap0 << ", cap1: " << multipv_cap1 << ")"
<< " depth 6 multipv 2" << sync_endl;
}
}
}
});
Threads.wait_for_workers_finished();
std::cout << "Finished.\n";
}
void do_filter_335a9b2d8a80(FilterParams& params)
{
if (ends_with(params.input_filename, ".binpack"))
{
do_filter_data_335a9b2d8a80(params);
}
else
{
std::cerr << "Invalid input file type.\n";
}
}
void filter_335a9b2d8a80(std::istringstream& is)
{
FilterParams params{};
while(true)
{
std::string token;
is >> token;
if (token == "")
break;
else if (token == "input_file")
is >> params.input_filename;
else if (token == "output_file")
is >> params.output_filename;
else if (token == "debug_print")
is >> params.debug_print;
else
{
std::cout << "ERROR: Unknown option " << token << ". Exiting...\n";
return;
}
}
std::cout << "Performing transform filter_335a9b2d8a80 with parameters:\n";
std::cout << "input_file : " << params.input_filename << '\n';
std::cout << "output_file : " << params.output_filename << '\n';
std::cout << "debug_print : " << params.debug_print << '\n';
std::cout << '\n';
do_filter_335a9b2d8a80(params);
}
void transform(std::istringstream& is)
{
const std::map<std::string, CommandFunc> subcommands = {
{ "nudged_static", &nudged_static },
{ "rescore", &rescore }
{ "rescore", &rescore },
{ "filter_335a9b2d8a80", &filter_335a9b2d8a80 }
};
Eval::NNUE::init();