Rewrite gensfen to use stockfish's thread pool.

This commit is contained in:
Tomasz Sobczyk
2020-10-24 21:18:06 +02:00
committed by nodchip
parent 0e528995c2
commit af238fe132
+84 -96
View File
@@ -1,7 +1,6 @@
#include "gensfen.h"
#include "packed_sfen.h"
#include "multi_think.h"
#include "sfen_stream.h"
#include "misc.h"
@@ -261,7 +260,7 @@ namespace Learner
// -----------------------------------
// Class to generate sfen with multiple threads
struct MultiThinkGenSfen : public MultiThink
struct MultiThinkGenSfen
{
// Hash to limit the export of identical sfens
static constexpr uint64_t GENSFEN_HASH_SIZE = 64 * 1024 * 1024;
@@ -269,7 +268,7 @@ namespace Learner
static_assert((GENSFEN_HASH_SIZE& (GENSFEN_HASH_SIZE - 1)) == 0);
MultiThinkGenSfen(int search_depth_min_, int search_depth_max_, SfenWriter& sw_, const std::string& seed) :
MultiThink(seed),
prng(seed),
search_depth_min(search_depth_min_),
search_depth_max(search_depth_max_),
sfen_writer(sw_)
@@ -285,7 +284,9 @@ namespace Learner
sfen_writer.start_file_write_worker();
}
void thread_worker(size_t thread_id) override;
void gensfen(uint64_t limit);
void thread_worker(Thread& th, std::atomic<uint64_t>& counter, uint64_t limit);
optional<int8_t> get_current_game_result(
Position& pos,
@@ -293,7 +294,14 @@ namespace Learner
vector<uint8_t> generate_random_move_flags();
bool commit_psv(PSVector& a_psv, size_t thread_id, int8_t lastTurnIsWin);
bool was_seen_before(const Position& pos);
bool commit_psv(
Thread& th,
PSVector& sfens,
int8_t lastTurnIsWin,
std::atomic<uint64_t>& counter,
uint64_t limit);
optional<Move> choose_random_move(
Position& pos,
@@ -301,6 +309,8 @@ namespace Learner
int ply,
int& random_move_c);
PRNG prng;
// Min and max depths for search during gensfen
int search_depth_min;
int search_depth_max;
@@ -347,6 +357,15 @@ namespace Learner
vector<Key> hash; // 64MB*sizeof(HASH_KEY) = 512MB
};
// Entry point for training data generation.
// Hands a job to the global `Threads` pool that calls thread_worker() with
// a shared counter (starting at zero) and the given `limit`, then calls
// Threads.wait_for_workers_finished() before returning.
void MultiThinkGenSfen::gensfen(uint64_t limit)
{
    // Shared between all invocations of the job below. It is captured by
    // reference, so it has to stay alive until the workers are done with it;
    // presumably wait_for_workers_finished() guarantees that - verify
    // against the thread pool implementation.
    std::atomic<uint64_t> num_written{0};

    auto job = [this, limit, &num_written](Thread& worker) {
        thread_worker(worker, num_written, limit);
    };

    Threads.execute_with_workers(job);
    Threads.wait_for_workers_finished();
}
optional<int8_t> MultiThinkGenSfen::get_current_game_result(
Position& pos,
const vector<int>& move_hist_scores) const
@@ -470,7 +489,12 @@ namespace Learner
// 1 when winning. -1 when losing. Pass 0 for a draw.
// Return value: true if the specified number of
// sfens has already been reached and the process ends.
bool MultiThinkGenSfen::commit_psv(PSVector& sfens, size_t thread_id, int8_t lastTurnIsWin)
bool MultiThinkGenSfen::commit_psv(
Thread& th,
PSVector& sfens,
int8_t lastTurnIsWin,
std::atomic<uint64_t>& counter,
uint64_t limit)
{
if (!write_out_draw_game_in_training_data_generation && lastTurnIsWin == 0)
{
@@ -482,34 +506,26 @@ namespace Learner
// From the final stage (one step before) to the first stage, give information on the outcome of the game for each stage.
// The phases stored in sfens are assumed to be continuous (in order).
bool quit = false;
int num_sfens_to_commit = 0;
for (auto it = sfens.rbegin(); it != sfens.rend(); ++it)
{
// If is_win == 0 (draw), multiply by -1 and it will remain 0 (draw)
is_win = -is_win;
it->game_result = is_win;
// See how many sfens were already written and get the next id.
// Exit if requested number of sfens reached.
auto now_loop_count = get_next_loop_count();
if (now_loop_count == LOOP_COUNT_FINISHED)
{
quit = true;
break;
}
++num_sfens_to_commit;
}
// Write sfens in move order to make potential compression easier
for (auto it = sfens.end() - num_sfens_to_commit; it != sfens.end(); ++it)
for (auto& sfen : sfens)
{
// Return true if there is already enough data generated.
const auto iter = counter.fetch_add(1);
if (iter >= limit)
return true;
// Write out one sfen.
sfen_writer.write(thread_id, *it);
sfen_writer.write(th.thread_idx(), sfen);
}
return quit;
return false;
}
optional<Move> MultiThinkGenSfen::choose_random_move(
@@ -640,8 +656,29 @@ namespace Learner
return random_move_flag;
}
// Checks the position hashtable for the key of `pos`.
//
// The table slot is selected by the low bits of the position key
// (GENSFEN_HASH_SIZE is a power of two, so the mask is a valid index).
// Returns true when that slot already holds exactly this key.
// Otherwise the slot is overwritten with the key and false is returned,
// so a later call with the same key reports true unless the slot was
// replaced in between.
//
// This is a good heuristic to exclude already seen positions
// without many false positives.
//
// NOTE(review): no locking is done here; confirm whether `hash` may be
// accessed from several workers at once and whether that is acceptable.
bool MultiThinkGenSfen::was_seen_before(const Position& pos)
{
    const auto position_key = pos.key();
    const auto slot = (size_t)(position_key & (GENSFEN_HASH_SIZE - 1));

    // Same key already stored in this slot: treat the position as seen.
    if (hash[slot] == position_key)
        return true;

    // First sighting (or a different key occupied the slot): remember it.
    hash[slot] = position_key;
    return false;
}
// thread_id = 0..Threads.size()-1
void MultiThinkGenSfen::thread_worker(size_t thread_id)
void MultiThinkGenSfen::thread_worker(Thread& th, std::atomic<uint64_t>& counter, uint64_t limit)
{
// For the time being, it will be treated as a draw
// at the maximum number of steps to write.
@@ -660,10 +697,8 @@ namespace Learner
// It is necessary to set a dependent thread for Position.
// When parallelizing, Threads (since this is a vector<Thread*>,
// Do the same for up to Threads[0]...Threads[thread_num-1].
auto th = Threads[thread_id];
auto& pos = th->rootPos;
pos.set(StartFEN, false, &si, th);
auto& pos = th.rootPos;
pos.set(StartFEN, false, &si, &th);
int resign_counter = 0;
bool should_resign = prng.rand(10) > 1;
@@ -684,13 +719,11 @@ namespace Learner
vector<int> move_hist_scores;
auto flush_psv = [&](int8_t result) {
quit = commit_psv(a_psv, thread_id, result);
quit = commit_psv(th, a_psv, result, counter, limit);
};
for (int ply = 0; ; ++ply)
{
Move next_move = MOVE_NONE;
// Current search depth
const int depth = search_depth_min + (int)prng.rand(search_depth_max - search_depth_min + 1);
@@ -715,18 +748,17 @@ namespace Learner
flush_psv((search_value >= eval_limit) ? 1 : -1);
break;
}
} else {
}
else
{
resign_counter = 0;
}
// Verification of a strange move
if (search_pv.size() > 0
&& (search_pv[0] == MOVE_NONE || search_pv[0] == MOVE_NULL))
// In case there is no PV and the game was not ended here
// there is nothing we can do, we can't continue the game,
// we don't know the result, so discard this game.
if (search_pv.empty())
{
// (???)
// MOVE_WIN is checking if it is the declaration victory stage before this
// The declarative winning move should never come back here.
// Also, when MOVE_RESIGN, search_value is a one-stop score, which should be the minimum value of eval_limit (-31998)...
cout << "Error! : " << pos.fen() << next_move << search_value << endl;
break;
}
@@ -736,34 +768,10 @@ namespace Learner
// Discard stuff before write_minply is reached
// because it can harm training due to overfitting.
// Initial positions would be too common.
if (ply < write_minply - 1)
{
a_psv.clear();
goto SKIP_SAVE;
}
// Look into the position hashtable to see if the same
// position was seen before.
// This is a good heuristic to exclude already seen
// positions without many false positives.
{
auto key = pos.key();
auto hash_index = (size_t)(key & (GENSFEN_HASH_SIZE - 1));
auto old_key = hash[hash_index];
if (key == old_key)
{
goto SKIP_SAVE;
}
else
{
// Replace with the current key.
hash[hash_index] = key;
}
}
// Pack the current position into a packed sfen and save it into the buffer.
if (ply >= write_minply && !was_seen_before(pos))
{
a_psv.emplace_back(PackedSfenValue());
auto& psv = a_psv.back();
// Here we only write the position data.
@@ -771,48 +779,29 @@ namespace Learner
pos.sfen_pack(psv.sfen);
psv.score = search_value;
psv.gamePly = ply;
// Take out the first PV move. This should be present unless depth 0.
assert(search_pv.size() >= 1);
psv.move = search_pv[0];
}
SKIP_SAVE:;
// Update the next move according to best search result or random move.
auto random_move = choose_random_move(pos, random_move_flag, ply, actual_random_move_count);
const Move next_move = random_move.has_value() ? *random_move : search_pv[0];
// For some reason, We could not get PV (hit the substitution table etc. and got stuck?)
// so go to the next game. It's a rare case, so you can ignore it.
if (search_pv.size() == 0)
// We don't have the whole game yet, but it ended,
// so the writing process ends and the next game starts.
// This shouldn't really happen.
if (!is_ok(next_move))
{
break;
}
// Update the next move according to best search result.
next_move = search_pv[0];
// Random move.
auto random_move = choose_random_move(pos, random_move_flag, ply, actual_random_move_count);
if (random_move.has_value())
{
next_move = random_move.value();
// We don't have the whole game yet, but it ended,
// so the writing process ends and the next game starts.
if (!is_ok(next_move))
{
break;
}
}
// Do move.
pos.do_move(next_move, states[ply]);
} // for (int ply = 0; ; ++ply)
}
}
} // while(!quit)
sfen_writer.finalize(thread_id);
sfen_writer.finalize(th.thread_idx());
}
// -----------------------------------
@@ -1029,7 +1018,6 @@ namespace Learner
MultiThinkGenSfen multi_think(search_depth_min, search_depth_max, sfen_writer, seed);
multi_think.nodes = nodes;
multi_think.set_loop_max(loop_max);
multi_think.eval_limit = eval_limit;
multi_think.random_move_minply = random_move_minply;
multi_think.random_move_maxply = random_move_maxply;
@@ -1041,7 +1029,7 @@ namespace Learner
multi_think.write_minply = write_minply;
multi_think.write_maxply = write_maxply;
multi_think.start_file_write_worker();
multi_think.go_think();
multi_think.gensfen(loop_max);
// Since we are joining with the destructor of SfenWriter, please give a message that it has finished after the join
// Enclose this in a block because it should be displayed.