mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 16:47:37 +00:00
Add 'validation_count' option for 'learn' that specifies how many positions to use for validation
This commit is contained in:
@@ -80,6 +80,8 @@ Currently the following options are available:
|
|||||||
|
|
||||||
`validation_set_file_name` - path to the file with training data to be used for validation (loss computation and move accuracy)
|
`validation_set_file_name` - path to the file with training data to be used for validation (loss computation and move accuracy)
|
||||||
|
|
||||||
|
`validation_count` - the number of positions to use for validation. Default: 2000.
|
||||||
|
|
||||||
`sfen_read_size` - the number of sfens to always keep in the buffer. Default: 10000000 (10M)
|
`sfen_read_size` - the number of sfens to always keep in the buffer. Default: 10000000 (10M)
|
||||||
|
|
||||||
`thread_buffer_size` - the number of sfens to copy at once to each thread requesting more sfens for learning. Default: 10000
|
`thread_buffer_size` - the number of sfens to copy at once to each thread requesting more sfens for learning. Default: 10000
|
||||||
|
|||||||
+30
-21
@@ -482,6 +482,12 @@ namespace Learner
|
|||||||
// Mini batch size size. Be sure to set it on the side that uses this class.
|
// Mini batch size size. Be sure to set it on the side that uses this class.
|
||||||
uint64_t mini_batch_size = LEARN_MINI_BATCH_SIZE;
|
uint64_t mini_batch_size = LEARN_MINI_BATCH_SIZE;
|
||||||
|
|
||||||
|
// Number of phases used for calculation such as mse
|
||||||
|
// mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time.
|
||||||
|
// Since search() is performed with depth = 1 in calculation of
|
||||||
|
// move match rate, simple comparison is not possible...
|
||||||
|
uint64_t validation_count = 2000;
|
||||||
|
|
||||||
// Option to exclude early stage from learning
|
// Option to exclude early stage from learning
|
||||||
int reduction_gameply = 1;
|
int reduction_gameply = 1;
|
||||||
|
|
||||||
@@ -550,16 +556,10 @@ namespace Learner
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Number of phases used for calculation such as mse
|
|
||||||
// mini-batch size = 1M is standard, so 0.2% of that should be negligible in terms of time.
|
|
||||||
// Since search() is performed with depth = 1 in calculation of
|
|
||||||
// move match rate, simple comparison is not possible...
|
|
||||||
static constexpr uint64_t sfen_for_mse_size = 2000;
|
|
||||||
|
|
||||||
LearnerThink(const Params& prm) :
|
LearnerThink(const Params& prm) :
|
||||||
params(prm),
|
params(prm),
|
||||||
prng(prm.seed),
|
prng(prm.seed),
|
||||||
sr(
|
train_sr(
|
||||||
prm.filenames,
|
prm.filenames,
|
||||||
prm.shuffle,
|
prm.shuffle,
|
||||||
SfenReaderMode::Cyclic,
|
SfenReaderMode::Cyclic,
|
||||||
@@ -567,6 +567,14 @@ namespace Learner
|
|||||||
std::to_string(prng.next_random_seed()),
|
std::to_string(prng.next_random_seed()),
|
||||||
prm.sfen_read_size,
|
prm.sfen_read_size,
|
||||||
prm.thread_buffer_size),
|
prm.thread_buffer_size),
|
||||||
|
validation_sr(
|
||||||
|
prm.validation_set_file_name.empty() ? prm.filenames : std::vector<std::string>{ prm.validation_set_file_name },
|
||||||
|
prm.shuffle,
|
||||||
|
SfenReaderMode::Cyclic,
|
||||||
|
1,
|
||||||
|
std::to_string(prng.next_random_seed()),
|
||||||
|
prm.sfen_read_size,
|
||||||
|
prm.thread_buffer_size),
|
||||||
learn_loss_sum{}
|
learn_loss_sum{}
|
||||||
{
|
{
|
||||||
save_count = 0;
|
save_count = 0;
|
||||||
@@ -612,7 +620,8 @@ namespace Learner
|
|||||||
PRNG prng;
|
PRNG prng;
|
||||||
|
|
||||||
// sfen reader
|
// sfen reader
|
||||||
SfenReader sr;
|
SfenReader train_sr;
|
||||||
|
SfenReader validation_sr;
|
||||||
|
|
||||||
uint64_t save_count;
|
uint64_t save_count;
|
||||||
uint64_t loss_output_count;
|
uint64_t loss_output_count;
|
||||||
@@ -666,28 +675,26 @@ namespace Learner
|
|||||||
|
|
||||||
Eval::NNUE::verify_any_net_loaded();
|
Eval::NNUE::verify_any_net_loaded();
|
||||||
|
|
||||||
const PSVector sfen_for_mse =
|
const PSVector validation_data =
|
||||||
params.validation_set_file_name.empty()
|
validation_sr.read_some(
|
||||||
? sr.read_for_mse(sfen_for_mse_size)
|
params.validation_count,
|
||||||
: sr.read_validation_set(
|
|
||||||
params.validation_set_file_name,
|
|
||||||
params.eval_limit,
|
params.eval_limit,
|
||||||
params.use_draw_games_in_validation);
|
params.use_draw_games_in_validation
|
||||||
|
);
|
||||||
|
|
||||||
if (params.validation_set_file_name.empty()
|
if (validation_data.size() != params.validation_count)
|
||||||
&& sfen_for_mse.size() != sfen_for_mse_size)
|
|
||||||
{
|
{
|
||||||
auto out = sync_region_cout.new_region();
|
auto out = sync_region_cout.new_region();
|
||||||
out
|
out
|
||||||
<< "INFO (learn): Error reading sfen_for_mse. Read " << sfen_for_mse.size()
|
<< "INFO (learn): Error reading validation data. Read " << validation_data.size()
|
||||||
<< " out of " << sfen_for_mse_size << '\n';
|
<< " out of " << params.validation_count << '\n';
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.newbob_decay != 1.0) {
|
if (params.newbob_decay != 1.0) {
|
||||||
|
|
||||||
calc_loss(sfen_for_mse, 0);
|
calc_loss(validation_data, 0);
|
||||||
|
|
||||||
best_loss = latest_loss_sum / latest_loss_count;
|
best_loss = latest_loss_sum / latest_loss_count;
|
||||||
latest_loss_sum = 0.0;
|
latest_loss_sum = 0.0;
|
||||||
@@ -714,7 +721,7 @@ namespace Learner
|
|||||||
if (stop_flag)
|
if (stop_flag)
|
||||||
break;
|
break;
|
||||||
|
|
||||||
update_weights(sfen_for_mse, epoch);
|
update_weights(validation_data, epoch);
|
||||||
|
|
||||||
if (stop_flag)
|
if (stop_flag)
|
||||||
break;
|
break;
|
||||||
@@ -742,7 +749,7 @@ namespace Learner
|
|||||||
|
|
||||||
RETRY_READ:;
|
RETRY_READ:;
|
||||||
|
|
||||||
if (!sr.read_to_thread_buffer(thread_id, ps))
|
if (!train_sr.read_to_thread_buffer(thread_id, ps))
|
||||||
{
|
{
|
||||||
// If we ran out of data we stop completely
|
// If we ran out of data we stop completely
|
||||||
// because there's nothing left to do.
|
// because there's nothing left to do.
|
||||||
@@ -1146,6 +1153,7 @@ namespace Learner
|
|||||||
is >> filename;
|
is >> filename;
|
||||||
params.filenames.push_back(filename);
|
params.filenames.push_back(filename);
|
||||||
}
|
}
|
||||||
|
else if (option == "validation_count") is >> params.validation_count;
|
||||||
|
|
||||||
// Specify the number of loops
|
// Specify the number of loops
|
||||||
else if (option == "epochs") is >> epochs;
|
else if (option == "epochs") is >> epochs;
|
||||||
@@ -1260,6 +1268,7 @@ namespace Learner
|
|||||||
out << " - validation set : " << params.validation_set_file_name << endl;
|
out << " - validation set : " << params.validation_set_file_name << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
out << " - validation count : " << params.validation_count << endl;
|
||||||
out << " - epochs : " << epochs << endl;
|
out << " - epochs : " << epochs << endl;
|
||||||
out << " - epochs * minibatch size : " << epochs * params.mini_batch_size << endl;
|
out << " - epochs * minibatch size : " << epochs * params.mini_batch_size << endl;
|
||||||
out << " - eval_limit : " << params.eval_limit << endl;
|
out << " - eval_limit : " << params.eval_limit << endl;
|
||||||
|
|||||||
+8
-32
@@ -73,10 +73,10 @@ namespace Learner{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load the phase for calculation such as mse.
|
// Load the phase for calculation such as mse.
|
||||||
PSVector read_for_mse(uint64_t count)
|
PSVector read_some(uint64_t count, int eval_limit, bool use_draw_games)
|
||||||
{
|
{
|
||||||
PSVector sfen_for_mse;
|
PSVector psv;
|
||||||
sfen_for_mse.reserve(count);
|
psv.reserve(count);
|
||||||
|
|
||||||
for (uint64_t i = 0; i < count; ++i)
|
for (uint64_t i = 0; i < count; ++i)
|
||||||
{
|
{
|
||||||
@@ -84,43 +84,19 @@ namespace Learner{
|
|||||||
if (!read_to_thread_buffer(0, ps))
|
if (!read_to_thread_buffer(0, ps))
|
||||||
{
|
{
|
||||||
std::cout << "ERROR (sfen_reader): Reading failed." << std::endl;
|
std::cout << "ERROR (sfen_reader): Reading failed." << std::endl;
|
||||||
return sfen_for_mse;
|
return psv;
|
||||||
}
|
}
|
||||||
|
|
||||||
sfen_for_mse.push_back(ps);
|
if (eval_limit < abs(ps.score))
|
||||||
}
|
|
||||||
|
|
||||||
return sfen_for_mse;
|
|
||||||
}
|
|
||||||
|
|
||||||
PSVector read_validation_set(const std::string& file_name, int eval_limit, bool use_draw_games)
|
|
||||||
{
|
|
||||||
PSVector sfen_for_mse;
|
|
||||||
|
|
||||||
auto input = open_sfen_input_file(file_name);
|
|
||||||
|
|
||||||
while(!input->eof())
|
|
||||||
{
|
|
||||||
std::optional<PackedSfenValue> p_opt = input->next();
|
|
||||||
if (p_opt.has_value())
|
|
||||||
{
|
|
||||||
auto& p = *p_opt;
|
|
||||||
|
|
||||||
if (eval_limit < abs(p.score))
|
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (!use_draw_games && p.game_result == 0)
|
if (!use_draw_games && ps.game_result == 0)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
sfen_for_mse.push_back(p);
|
psv.push_back(ps);
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sfen_for_mse;
|
return psv;
|
||||||
}
|
}
|
||||||
|
|
||||||
// [ASYNC] Thread returns one aspect. Otherwise returns false.
|
// [ASYNC] Thread returns one aspect. Otherwise returns false.
|
||||||
|
|||||||
Reference in New Issue
Block a user