Add "max_time_*" options to "generate_training_data" tool that allow limiting the runtime by time instead of count.

This commit is contained in:
Tomasz Sobczyk
2021-09-14 14:47:24 +02:00
parent a02a6bf13e
commit 79abe1e662
2 changed files with 53 additions and 12 deletions
+50 -11
View File
@@ -149,7 +149,7 @@ namespace Stockfish::Tools
std::cout << prngs[0] << std::endl;
}
void generate(uint64_t limit);
void generate(uint64_t limit, uint64_t limit_seconds);
private:
Params params;
@@ -173,7 +173,8 @@ namespace Stockfish::Tools
void generate_worker(
Thread& th,
std::atomic<uint64_t>& counter,
uint64_t limit);
uint64_t limit,
uint64_t limit_seconds);
bool was_seen_before(const Position& pos);
@@ -198,6 +199,7 @@ namespace Stockfish::Tools
uint64_t limit,
Color result_color);
void initial_report();
void report(uint64_t done, uint64_t new_done);
void maybe_report(uint64_t done);
@@ -222,24 +224,23 @@ namespace Stockfish::Tools
limits.depth = 0;
}
void TrainingDataGenerator::generate(uint64_t limit)
void TrainingDataGenerator::generate(uint64_t limit, uint64_t limit_seconds)
{
last_stats_report_time = 0;
initial_report();
set_gensfen_search_limits();
std::atomic<uint64_t> counter{0};
Threads.execute_with_workers([&counter, limit, this](Thread& th) {
generate_worker(th, counter, limit);
Threads.execute_with_workers([&counter, limit, limit_seconds, this](Thread& th) {
generate_worker(th, counter, limit, limit_seconds);
});
Threads.wait_for_workers_finished();
sfen_writer.flush();
if (limit % REPORT_STATS_EVERY != 0)
{
report(limit, limit % REPORT_STATS_EVERY);
}
report(counter.load(), counter.load() % REPORT_STATS_EVERY + 1);
std::cout << std::endl;
}
@@ -247,7 +248,8 @@ namespace Stockfish::Tools
void TrainingDataGenerator::generate_worker(
Thread& th,
std::atomic<uint64_t>& counter,
uint64_t limit)
uint64_t limit,
uint64_t limit_seconds)
{
// For the time being, a game that reaches the maximum number
// of steps to write is treated as a draw.
@@ -262,9 +264,16 @@ namespace Stockfish::Tools
// end flag
bool quit = false;
const auto start_time = now();
// Repeat until the loop is told to quit or the time limit (limit_seconds) has elapsed.
while (!quit)
{
if (static_cast<uint64_t>(now() - start_time) / 1000 >= limit_seconds)
{
break;
}
// A Position needs to have the thread it depends on set.
// When running in parallel, the same is done for each entry of Threads
// (a vector<Thread*>), i.e. for Threads[0]...Threads[thread_num-1].
@@ -699,6 +708,22 @@ namespace Stockfish::Tools
return false;
}
// Emits the very first progress line ("0 sfens, at <time>") before any
// work has been done, and records the moment of that report in
// last_stats_report_time.
void TrainingDataGenerator::initial_report()
{
    // Acquire an output region from sync_region_cout for this report.
    out = sync_region_cout.new_region();

    // Sample the clock before writing, so the stored timestamp is the
    // one taken ahead of the output itself.
    const auto report_timestamp = now();

    out << endl;
    out << 0 << " sfens, ";
    out << "at " << now_string() << sync_endl;

    last_stats_report_time = report_timestamp;

    // Replace the region with a fresh one once the line has been written.
    out = sync_region_cout.new_region();
}
void TrainingDataGenerator::report(uint64_t done, uint64_t new_done)
{
const auto now_time = now();
@@ -744,6 +769,7 @@ namespace Stockfish::Tools
{
// Number of generated game records default = 8 billion phases (Ponanza specification)
uint64_t loop_max = 8000000000UL;
uint64_t time_max = 8000000000UL;
TrainingDataGenerator::Params params;
@@ -772,6 +798,18 @@ namespace Stockfish::Tools
is >> params.nodes;
else if (token == "count")
is >> loop_max;
else if (token == "max_time_seconds")
is >> time_max;
else if (token == "max_time_minutes")
{
is >> time_max;
time_max *= 60;
}
else if (token == "max_time_hours")
{
is >> time_max;
time_max *= 3600;
}
else if (token == "output_file_name")
is >> params.output_file_name;
else if (token == "eval_limit")
@@ -865,6 +903,7 @@ namespace Stockfish::Tools
<< " - search_depth_max = " << params.search_depth_max << endl
<< " - nodes = " << params.nodes << endl
<< " - count = " << loop_max << endl
<< " - max_time_seconds = " << time_max << endl
<< " - eval_limit = " << params.eval_limit << endl
<< " - num threads (UCI) = " << params.num_threads << endl
<< " - random_move_min_ply = " << params.random_move_minply << endl
@@ -890,7 +929,7 @@ namespace Stockfish::Tools
Threads.main()->ponder = false;
TrainingDataGenerator gensfen(params);
gensfen.generate(loop_max);
gensfen.generate(loop_max, time_max);
std::cout << "INFO: generate_training_data finished." << endl;
}