Add "max_time_*" options to "generate_training_data" tool that allow limiting the runtime by time instead of count.

This commit is contained in:
Tomasz Sobczyk
2021-09-14 14:47:24 +02:00
parent a02a6bf13e
commit 79abe1e662
2 changed files with 53 additions and 12 deletions
+3 -1
View File
@@ -22,7 +22,9 @@ Currently the following options are available:
`nodes` - the number of nodes to use for evaluation of each position. This number is multiplied by the number of PVs of the current search. This does NOT override the `depth` and `depth2` options. If specified then whichever of depth or nodes limit is reached first applies. `nodes` - the number of nodes to use for evaluation of each position. This number is multiplied by the number of PVs of the current search. This does NOT override the `depth` and `depth2` options. If specified then whichever of depth or nodes limit is reached first applies.
`count` - the number of training data entries to generate. 1 entry == 1 position. Default: 8000000000 (8B). `count` - the number of training data entries to generate. 1 entry == 1 position. If both `count` and `max_time_*` are specified the data generation process will end when any of the conditions is fulfilled. Default: 8000000000 (8B).
`max_time_seconds`, `max_time_minutes`, `max_time_hours` - specifies the maximum runtime for the data generation. The data generation will NOT be interrupted while a self-play game is in progress. If both `count` and `max_time_*` are specified the data generation process will end when any of the conditions is fulfilled. Default: \~250 years.
`output_file_name` - the name of the file to output to. If the extension is not present or doesn't match the selected training data format the right extension will be appended. Default: generated_kifu `output_file_name` - the name of the file to output to. If the extension is not present or doesn't match the selected training data format the right extension will be appended. Default: generated_kifu
+50 -11
View File
@@ -149,7 +149,7 @@ namespace Stockfish::Tools
std::cout << prngs[0] << std::endl; std::cout << prngs[0] << std::endl;
} }
void generate(uint64_t limit); void generate(uint64_t limit, uint64_t limit_seconds);
private: private:
Params params; Params params;
@@ -173,7 +173,8 @@ namespace Stockfish::Tools
void generate_worker( void generate_worker(
Thread& th, Thread& th,
std::atomic<uint64_t>& counter, std::atomic<uint64_t>& counter,
uint64_t limit); uint64_t limit,
uint64_t limit_seconds);
bool was_seen_before(const Position& pos); bool was_seen_before(const Position& pos);
@@ -198,6 +199,7 @@ namespace Stockfish::Tools
uint64_t limit, uint64_t limit,
Color result_color); Color result_color);
void initial_report();
void report(uint64_t done, uint64_t new_done); void report(uint64_t done, uint64_t new_done);
void maybe_report(uint64_t done); void maybe_report(uint64_t done);
@@ -222,24 +224,23 @@ namespace Stockfish::Tools
limits.depth = 0; limits.depth = 0;
} }
void TrainingDataGenerator::generate(uint64_t limit) void TrainingDataGenerator::generate(uint64_t limit, uint64_t limit_seconds)
{ {
last_stats_report_time = 0; last_stats_report_time = 0;
initial_report();
set_gensfen_search_limits(); set_gensfen_search_limits();
std::atomic<uint64_t> counter{0}; std::atomic<uint64_t> counter{0};
Threads.execute_with_workers([&counter, limit, this](Thread& th) { Threads.execute_with_workers([&counter, limit, limit_seconds, this](Thread& th) {
generate_worker(th, counter, limit); generate_worker(th, counter, limit, limit_seconds);
}); });
Threads.wait_for_workers_finished(); Threads.wait_for_workers_finished();
sfen_writer.flush(); sfen_writer.flush();
if (limit % REPORT_STATS_EVERY != 0) report(counter.load(), counter.load() % REPORT_STATS_EVERY + 1);
{
report(limit, limit % REPORT_STATS_EVERY);
}
std::cout << std::endl; std::cout << std::endl;
} }
@@ -247,7 +248,8 @@ namespace Stockfish::Tools
void TrainingDataGenerator::generate_worker( void TrainingDataGenerator::generate_worker(
Thread& th, Thread& th,
std::atomic<uint64_t>& counter, std::atomic<uint64_t>& counter,
uint64_t limit) uint64_t limit,
uint64_t limit_seconds)
{ {
// For the time being, it will be treated as a draw // For the time being, it will be treated as a draw
// at the maximum number of steps to write. // at the maximum number of steps to write.
@@ -262,9 +264,16 @@ namespace Stockfish::Tools
// end flag // end flag
bool quit = false; bool quit = false;
const auto start_time = now();
// repeat until the specified number of times // repeat until the specified number of times
while (!quit) while (!quit)
{ {
if (static_cast<uint64_t>(now() - start_time) / 1000 >= limit_seconds)
{
break;
}
// It is necessary to set a dependent thread for Position. // It is necessary to set a dependent thread for Position.
// When parallelizing, Threads (since this is a vector<Thread*>, // When parallelizing, Threads (since this is a vector<Thread*>,
// Do the same for up to Threads[0]...Threads[thread_num-1]. // Do the same for up to Threads[0]...Threads[thread_num-1].
@@ -699,6 +708,22 @@ namespace Stockfish::Tools
return false; return false;
} }
// Emits the very first progress line ("0 sfens, at <time>") before any
// data has been generated, and records the current timestamp in
// last_stats_report_time.
void TrainingDataGenerator::initial_report()
{
// Acquire an output region from sync_region_cout for this report.
out = sync_region_cout.new_region();
// Capture the timestamp once so the stored report time matches this report.
const auto now_time = now();
out
<< endl
<< 0 << " sfens, "
<< "at " << now_string() << sync_endl;
// Remember when this report was made.
// NOTE(review): presumably later reports use this to measure elapsed time - confirm in report().
last_stats_report_time = now_time;
// Replace the region with a fresh one for whatever is written next.
out = sync_region_cout.new_region();
}
void TrainingDataGenerator::report(uint64_t done, uint64_t new_done) void TrainingDataGenerator::report(uint64_t done, uint64_t new_done)
{ {
const auto now_time = now(); const auto now_time = now();
@@ -744,6 +769,7 @@ namespace Stockfish::Tools
{ {
// Number of generated game records default = 8 billion phases (Ponanza specification) // Number of generated game records default = 8 billion phases (Ponanza specification)
uint64_t loop_max = 8000000000UL; uint64_t loop_max = 8000000000UL;
uint64_t time_max = 8000000000UL;
TrainingDataGenerator::Params params; TrainingDataGenerator::Params params;
@@ -772,6 +798,18 @@ namespace Stockfish::Tools
is >> params.nodes; is >> params.nodes;
else if (token == "count") else if (token == "count")
is >> loop_max; is >> loop_max;
else if (token == "max_time_seconds")
is >> time_max;
else if (token == "max_time_minutes")
{
is >> time_max;
time_max *= 60;
}
else if (token == "max_time_hours")
{
is >> time_max;
time_max *= 3600;
}
else if (token == "output_file_name") else if (token == "output_file_name")
is >> params.output_file_name; is >> params.output_file_name;
else if (token == "eval_limit") else if (token == "eval_limit")
@@ -865,6 +903,7 @@ namespace Stockfish::Tools
<< " - search_depth_max = " << params.search_depth_max << endl << " - search_depth_max = " << params.search_depth_max << endl
<< " - nodes = " << params.nodes << endl << " - nodes = " << params.nodes << endl
<< " - count = " << loop_max << endl << " - count = " << loop_max << endl
<< " - max_time_seconds = " << time_max << endl
<< " - eval_limit = " << params.eval_limit << endl << " - eval_limit = " << params.eval_limit << endl
<< " - num threads (UCI) = " << params.num_threads << endl << " - num threads (UCI) = " << params.num_threads << endl
<< " - random_move_min_ply = " << params.random_move_minply << endl << " - random_move_min_ply = " << params.random_move_minply << endl
@@ -890,7 +929,7 @@ namespace Stockfish::Tools
Threads.main()->ponder = false; Threads.main()->ponder = false;
TrainingDataGenerator gensfen(params); TrainingDataGenerator gensfen(params);
gensfen.generate(loop_max); gensfen.generate(loop_max, time_max);
std::cout << "INFO: generate_training_data finished." << endl; std::cout << "INFO: generate_training_data finished." << endl;
} }