Simple filtering for validation data.

This commit is contained in:
Tomasz Sobczyk
2020-12-21 20:57:51 +01:00
committed by nodchip
parent 50df3a7389
commit 6853b4aac2
3 changed files with 61 additions and 16 deletions
+3 -1
View File
@@ -66,7 +66,9 @@ Currently the following options are available:
`assume_quiet` - this is a flag option. When specified learn will not perform qsearch to reach a quiet position.
`smart_fen_skipping` - this is a flag option. When specified, some positions that are not good candidates for teaching are skipped. This includes positions where the best move is a capture or promotion, and positions where a king is in check.
`smart_fen_skipping` - this is a flag option. When specified, some positions that are not good candidates for teaching are skipped. This includes positions where the best move is a capture or promotion, and positions where a king is in check. Default: 1.
`smart_fen_skipping_for_validation` - the same as `smart_fen_skipping`, but applies to the validation data set. Default: 0.
`newbob_num_trials` - determines after how many subsequent rejected nets the training process will be terminated. Default: 4.
+49 -7
View File
@@ -521,6 +521,7 @@ namespace Learner
bool assume_quiet = false;
bool smart_fen_skipping = false;
bool smart_fen_skipping_for_validation = false;
double learning_rate = 1.0;
double max_grad = 1.0;
@@ -593,6 +594,8 @@ namespace Learner
private:
static void set_learning_search_limits();
PSVector fetch_next_validation_set();
void learn_worker(Thread& th, std::atomic<uint64_t>& counter, uint64_t limit);
void update_weights(const PSVector& psv, uint64_t epoch);
@@ -665,6 +668,44 @@ namespace Learner
limits.depth = 0;
}
// Builds the validation set by pulling entries from validation_sr.
//
// The read is dispatched through the main thread's execute_with_worker so that
// the filter can use that thread's root position as scratch space. An entry is
// kept only if:
//   - |score| does not exceed params.eval_limit,
//   - it is not a drawn game, unless params.use_draw_games_in_validation is set,
//   - when params.smart_fen_skipping_for_validation is set: the packed sfen
//     can be set up, the stored move is not a capture or promotion, and the
//     side to move is not in check.
//
// Returns the accepted entries; the result may hold fewer than
// params.validation_count entries, so callers must check the size.
PSVector LearnerThink::fetch_next_validation_set()
{
    PSVector collected;

    auto main_thread = Threads.main();
    main_thread->execute_with_worker([&collected, this](auto& th){
        const auto accept_entry = [&th, this](const PackedSfenValue& entry) -> bool {
            if (params.eval_limit < abs(entry.score))
                return false;

            if (entry.game_result == 0 && !params.use_draw_games_in_validation)
                return false;

            // Without smart skipping the cheap checks above are the whole filter.
            if (!params.smart_fen_skipping_for_validation)
                return true;

            StateInfo state;
            auto& board = th.rootPos;

            // A non-zero result means the position could not be set up.
            if (board.set_from_packed_sfen(entry.sfen, &state, &th) != 0)
                return false;

            const bool unsuitable =
                board.capture_or_promotion(static_cast<Move>(entry.move))
                || board.checkers();

            return !unsuitable;
        };

        // Cap the number of entries examined to have a reasonable bound
        // on the running time.
        const auto max_tries = params.validation_count * 100;

        collected = validation_sr.read_some(
            params.validation_count,
            max_tries,
            accept_entry
        );
    });

    // Block until the worker has finished filling `collected`.
    main_thread->wait_for_worker_finished();

    return collected;
}
void LearnerThink::learn(uint64_t epochs)
{
#if defined(_OPENMP)
@@ -675,19 +716,16 @@ namespace Learner
Eval::NNUE::verify_any_net_loaded();
const PSVector validation_data =
validation_sr.read_some(
params.validation_count,
params.eval_limit,
params.use_draw_games_in_validation
);
const PSVector validation_data = fetch_next_validation_set();
if (validation_data.size() != params.validation_count)
{
auto out = sync_region_cout.new_region();
out
<< "INFO (learn): Error reading validation data. Read " << validation_data.size()
<< " out of " << params.validation_count << '\n';
<< " out of " << params.validation_count << '\n'
<< "INFO (learn): This either means that less than 1% of the validation data passed the filter"
<< " or the file is empty\n";
return;
}
@@ -1235,6 +1273,7 @@ namespace Learner
else if (option == "verbose") params.verbose = true;
else if (option == "assume_quiet") params.assume_quiet = true;
else if (option == "smart_fen_skipping") params.smart_fen_skipping = true;
else if (option == "smart_fen_skipping_for_validation") params.smart_fen_skipping_for_validation = true;
else
{
out << "INFO: Unknown option: " << option << ". Ignoring.\n";
@@ -1306,6 +1345,9 @@ namespace Learner
out << " - sfen_read_size : " << params.sfen_read_size << endl;
out << " - thread_buffer_size : " << params.thread_buffer_size << endl;
out << " - smart_fen_skipping : " << params.smart_fen_skipping << endl;
out << " - smart_fen_skipping_val : " << params.smart_fen_skipping_for_validation << endl;
out << " - seed : " << params.seed << endl;
out << " - verbose : " << (params.verbose ? "true" : "false") << endl;
+9 -8
View File
@@ -15,6 +15,7 @@
#include <iostream>
#include <cstdint>
#include <thread>
#include <functional>
namespace Learner{
@@ -73,12 +74,12 @@ namespace Learner{
}
// Load positions for calculations such as MSE.
PSVector read_some(uint64_t count, int eval_limit, bool use_draw_games)
PSVector read_some(uint64_t count, uint64_t count_tries, std::function<bool(const PackedSfenValue&)> do_take)
{
PSVector psv;
psv.reserve(count);
for (uint64_t i = 0; i < count; ++i)
for (uint64_t i = 0; i < count_tries; ++i)
{
PackedSfenValue ps;
if (!read_to_thread_buffer(0, ps))
@@ -87,13 +88,13 @@ namespace Learner{
return psv;
}
if (eval_limit < abs(ps.score))
continue;
if (!use_draw_games && ps.game_result == 0)
continue;
if (do_take(ps))
{
psv.push_back(ps);
if (psv.size() >= count)
break;
}
}
return psv;