Add 'validation_count' option for 'learn' that specifies how many positions to use for validation

This commit is contained in:
Tomasz Sobczyk
2020-12-19 23:46:18 +01:00
committed by nodchip
parent a7378f3249
commit f56613ebf6
3 changed files with 44 additions and 57 deletions
+2
View File
@@ -80,6 +80,8 @@ Currently the following options are available:
`validation_set_file_name` - path to the file with training data to be used for validation (loss computation and move accuracy) `validation_set_file_name` - path to the file with training data to be used for validation (loss computation and move accuracy)
`validation_count` - the number of positions to use for validation. Default: 2000.
`sfen_read_size` - the number of sfens to always keep in the buffer. Default: 10000000 (10M) `sfen_read_size` - the number of sfens to always keep in the buffer. Default: 10000000 (10M)
`thread_buffer_size` - the number of sfens to copy at once to each thread requesting more sfens for learning. Default: 10000 `thread_buffer_size` - the number of sfens to copy at once to each thread requesting more sfens for learning. Default: 10000
+30 -21
View File
@@ -482,6 +482,12 @@ namespace Learner
// Mini batch size. Be sure to set it on the side that uses this class. // Mini batch size. Be sure to set it on the side that uses this class.
uint64_t mini_batch_size = LEARN_MINI_BATCH_SIZE; uint64_t mini_batch_size = LEARN_MINI_BATCH_SIZE;
// Number of positions used for validation calculations such as MSE.
// A mini-batch size of 1M is standard, so 0.2% of that should be negligible in terms of time.
// Since search() is performed with depth = 1 when calculating the
// move match rate, a simple comparison is not possible...
uint64_t validation_count = 2000;
// Option to exclude early stage from learning // Option to exclude early stage from learning
int reduction_gameply = 1; int reduction_gameply = 1;
@@ -550,16 +556,10 @@ namespace Learner
} }
}; };
// Number of phases used for calculation such as mse
// mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time.
// Since search() is performed with depth = 1 in calculation of
// move match rate, simple comparison is not possible...
static constexpr uint64_t sfen_for_mse_size = 2000;
LearnerThink(const Params& prm) : LearnerThink(const Params& prm) :
params(prm), params(prm),
prng(prm.seed), prng(prm.seed),
sr( train_sr(
prm.filenames, prm.filenames,
prm.shuffle, prm.shuffle,
SfenReaderMode::Cyclic, SfenReaderMode::Cyclic,
@@ -567,6 +567,14 @@ namespace Learner
std::to_string(prng.next_random_seed()), std::to_string(prng.next_random_seed()),
prm.sfen_read_size, prm.sfen_read_size,
prm.thread_buffer_size), prm.thread_buffer_size),
validation_sr(
prm.validation_set_file_name.empty() ? prm.filenames : std::vector<std::string>{ prm.validation_set_file_name },
prm.shuffle,
SfenReaderMode::Cyclic,
1,
std::to_string(prng.next_random_seed()),
prm.sfen_read_size,
prm.thread_buffer_size),
learn_loss_sum{} learn_loss_sum{}
{ {
save_count = 0; save_count = 0;
@@ -612,7 +620,8 @@ namespace Learner
PRNG prng; PRNG prng;
// sfen reader // sfen reader
SfenReader sr; SfenReader train_sr;
SfenReader validation_sr;
uint64_t save_count; uint64_t save_count;
uint64_t loss_output_count; uint64_t loss_output_count;
@@ -666,28 +675,26 @@ namespace Learner
Eval::NNUE::verify_any_net_loaded(); Eval::NNUE::verify_any_net_loaded();
const PSVector sfen_for_mse = const PSVector validation_data =
params.validation_set_file_name.empty() validation_sr.read_some(
? sr.read_for_mse(sfen_for_mse_size) params.validation_count,
: sr.read_validation_set(
params.validation_set_file_name,
params.eval_limit, params.eval_limit,
params.use_draw_games_in_validation); params.use_draw_games_in_validation
);
if (params.validation_set_file_name.empty() if (validation_data.size() != params.validation_count)
&& sfen_for_mse.size() != sfen_for_mse_size)
{ {
auto out = sync_region_cout.new_region(); auto out = sync_region_cout.new_region();
out out
<< "INFO (learn): Error reading sfen_for_mse. Read " << sfen_for_mse.size() << "INFO (learn): Error reading validation data. Read " << validation_data.size()
<< " out of " << sfen_for_mse_size << '\n'; << " out of " << params.validation_count << '\n';
return; return;
} }
if (params.newbob_decay != 1.0) { if (params.newbob_decay != 1.0) {
calc_loss(sfen_for_mse, 0); calc_loss(validation_data, 0);
best_loss = latest_loss_sum / latest_loss_count; best_loss = latest_loss_sum / latest_loss_count;
latest_loss_sum = 0.0; latest_loss_sum = 0.0;
@@ -714,7 +721,7 @@ namespace Learner
if (stop_flag) if (stop_flag)
break; break;
update_weights(sfen_for_mse, epoch); update_weights(validation_data, epoch);
if (stop_flag) if (stop_flag)
break; break;
@@ -742,7 +749,7 @@ namespace Learner
RETRY_READ:; RETRY_READ:;
if (!sr.read_to_thread_buffer(thread_id, ps)) if (!train_sr.read_to_thread_buffer(thread_id, ps))
{ {
// If we ran out of data we stop completely // If we ran out of data we stop completely
// because there's nothing left to do. // because there's nothing left to do.
@@ -1146,6 +1153,7 @@ namespace Learner
is >> filename; is >> filename;
params.filenames.push_back(filename); params.filenames.push_back(filename);
} }
else if (option == "validation_count") is >> params.validation_count;
// Specify the number of loops // Specify the number of loops
else if (option == "epochs") is >> epochs; else if (option == "epochs") is >> epochs;
@@ -1260,6 +1268,7 @@ namespace Learner
out << " - validation set : " << params.validation_set_file_name << endl; out << " - validation set : " << params.validation_set_file_name << endl;
} }
out << " - validation count : " << params.validation_count << endl;
out << " - epochs : " << epochs << endl; out << " - epochs : " << epochs << endl;
out << " - epochs * minibatch size : " << epochs * params.mini_batch_size << endl; out << " - epochs * minibatch size : " << epochs * params.mini_batch_size << endl;
out << " - eval_limit : " << params.eval_limit << endl; out << " - eval_limit : " << params.eval_limit << endl;
+12 -36
View File
@@ -73,10 +73,10 @@ namespace Learner{
} }
// Load positions for calculations such as MSE. // Load positions for calculations such as MSE.
PSVector read_for_mse(uint64_t count) PSVector read_some(uint64_t count, int eval_limit, bool use_draw_games)
{ {
PSVector sfen_for_mse; PSVector psv;
sfen_for_mse.reserve(count); psv.reserve(count);
for (uint64_t i = 0; i < count; ++i) for (uint64_t i = 0; i < count; ++i)
{ {
@@ -84,43 +84,19 @@ namespace Learner{
if (!read_to_thread_buffer(0, ps)) if (!read_to_thread_buffer(0, ps))
{ {
std::cout << "ERROR (sfen_reader): Reading failed." << std::endl; std::cout << "ERROR (sfen_reader): Reading failed." << std::endl;
return sfen_for_mse; return psv;
} }
sfen_for_mse.push_back(ps); if (eval_limit < abs(ps.score))
continue;
if (!use_draw_games && ps.game_result == 0)
continue;
psv.push_back(ps);
} }
return sfen_for_mse; return psv;
}
PSVector read_validation_set(const std::string& file_name, int eval_limit, bool use_draw_games)
{
PSVector sfen_for_mse;
auto input = open_sfen_input_file(file_name);
while(!input->eof())
{
std::optional<PackedSfenValue> p_opt = input->next();
if (p_opt.has_value())
{
auto& p = *p_opt;
if (eval_limit < abs(p.score))
continue;
if (!use_draw_games && p.game_result == 0)
continue;
sfen_for_mse.push_back(p);
}
else
{
break;
}
}
return sfen_for_mse;
} }
// [ASYNC] Thread returns one aspect. Otherwise returns false. // [ASYNC] Thread returns one aspect. Otherwise returns false.