mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 10:57:43 +00:00
Rewrite gensfen to use stockfish's thread pool.
This commit is contained in:
+84
-96
@@ -1,7 +1,6 @@
|
||||
#include "gensfen.h"
|
||||
|
||||
#include "packed_sfen.h"
|
||||
#include "multi_think.h"
|
||||
#include "sfen_stream.h"
|
||||
|
||||
#include "misc.h"
|
||||
@@ -261,7 +260,7 @@ namespace Learner
|
||||
// -----------------------------------
|
||||
|
||||
// Class to generate sfen with multiple threads
|
||||
struct MultiThinkGenSfen : public MultiThink
|
||||
struct MultiThinkGenSfen
|
||||
{
|
||||
// Hash to limit the export of identical sfens
|
||||
static constexpr uint64_t GENSFEN_HASH_SIZE = 64 * 1024 * 1024;
|
||||
@@ -269,7 +268,7 @@ namespace Learner
|
||||
static_assert((GENSFEN_HASH_SIZE& (GENSFEN_HASH_SIZE - 1)) == 0);
|
||||
|
||||
MultiThinkGenSfen(int search_depth_min_, int search_depth_max_, SfenWriter& sw_, const std::string& seed) :
|
||||
MultiThink(seed),
|
||||
prng(seed),
|
||||
search_depth_min(search_depth_min_),
|
||||
search_depth_max(search_depth_max_),
|
||||
sfen_writer(sw_)
|
||||
@@ -285,7 +284,9 @@ namespace Learner
|
||||
sfen_writer.start_file_write_worker();
|
||||
}
|
||||
|
||||
void thread_worker(size_t thread_id) override;
|
||||
void gensfen(uint64_t limit);
|
||||
|
||||
void thread_worker(Thread& th, std::atomic<uint64_t>& counter, uint64_t limit);
|
||||
|
||||
optional<int8_t> get_current_game_result(
|
||||
Position& pos,
|
||||
@@ -293,7 +294,14 @@ namespace Learner
|
||||
|
||||
vector<uint8_t> generate_random_move_flags();
|
||||
|
||||
bool commit_psv(PSVector& a_psv, size_t thread_id, int8_t lastTurnIsWin);
|
||||
bool was_seen_before(const Position& pos);
|
||||
|
||||
bool commit_psv(
|
||||
Thread& th,
|
||||
PSVector& sfens,
|
||||
int8_t lastTurnIsWin,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit);
|
||||
|
||||
optional<Move> choose_random_move(
|
||||
Position& pos,
|
||||
@@ -301,6 +309,8 @@ namespace Learner
|
||||
int ply,
|
||||
int& random_move_c);
|
||||
|
||||
PRNG prng;
|
||||
|
||||
// Min and max depths for search during gensfen
|
||||
int search_depth_min;
|
||||
int search_depth_max;
|
||||
@@ -347,6 +357,15 @@ namespace Learner
|
||||
vector<Key> hash; // 64MB*sizeof(HASH_KEY) = 512MB
|
||||
};
|
||||
|
||||
void MultiThinkGenSfen::gensfen(uint64_t limit)
|
||||
{
|
||||
std::atomic<uint64_t> counter{0};
|
||||
Threads.execute_with_workers([&counter, limit, this](Thread& th) {
|
||||
thread_worker(th, counter, limit);
|
||||
});
|
||||
Threads.wait_for_workers_finished();
|
||||
}
|
||||
|
||||
optional<int8_t> MultiThinkGenSfen::get_current_game_result(
|
||||
Position& pos,
|
||||
const vector<int>& move_hist_scores) const
|
||||
@@ -470,7 +489,12 @@ namespace Learner
|
||||
// 1 when winning. -1 when losing. Pass 0 for a draw.
|
||||
// Return value: true if the specified number of
|
||||
// sfens has already been reached and the process ends.
|
||||
bool MultiThinkGenSfen::commit_psv(PSVector& sfens, size_t thread_id, int8_t lastTurnIsWin)
|
||||
bool MultiThinkGenSfen::commit_psv(
|
||||
Thread& th,
|
||||
PSVector& sfens,
|
||||
int8_t lastTurnIsWin,
|
||||
std::atomic<uint64_t>& counter,
|
||||
uint64_t limit)
|
||||
{
|
||||
if (!write_out_draw_game_in_training_data_generation && lastTurnIsWin == 0)
|
||||
{
|
||||
@@ -482,34 +506,26 @@ namespace Learner
|
||||
|
||||
// From the final stage (one step before) to the first stage, give information on the outcome of the game for each stage.
|
||||
// The phases stored in sfens are assumed to be continuous (in order).
|
||||
bool quit = false;
|
||||
int num_sfens_to_commit = 0;
|
||||
for (auto it = sfens.rbegin(); it != sfens.rend(); ++it)
|
||||
{
|
||||
// If is_win == 0 (draw), multiply by -1 and it will remain 0 (draw)
|
||||
is_win = -is_win;
|
||||
it->game_result = is_win;
|
||||
|
||||
// See how many sfens were already written and get the next id.
|
||||
// Exit if requested number of sfens reached.
|
||||
auto now_loop_count = get_next_loop_count();
|
||||
if (now_loop_count == LOOP_COUNT_FINISHED)
|
||||
{
|
||||
quit = true;
|
||||
break;
|
||||
}
|
||||
|
||||
++num_sfens_to_commit;
|
||||
}
|
||||
|
||||
// Write sfens in move order to make potential compression easier
|
||||
for (auto it = sfens.end() - num_sfens_to_commit; it != sfens.end(); ++it)
|
||||
for (auto& sfen : sfens)
|
||||
{
|
||||
// Return true if there is already enough data generated.
|
||||
const auto iter = counter.fetch_add(1);
|
||||
if (iter >= limit)
|
||||
return true;
|
||||
|
||||
// Write out one sfen.
|
||||
sfen_writer.write(thread_id, *it);
|
||||
sfen_writer.write(th.thread_idx(), sfen);
|
||||
}
|
||||
|
||||
return quit;
|
||||
return false;
|
||||
}
|
||||
|
||||
optional<Move> MultiThinkGenSfen::choose_random_move(
|
||||
@@ -640,8 +656,29 @@ namespace Learner
|
||||
return random_move_flag;
|
||||
}
|
||||
|
||||
bool MultiThinkGenSfen::was_seen_before(const Position& pos)
|
||||
{
|
||||
// Look into the position hashtable to see if the same
|
||||
// position was seen before.
|
||||
// This is a good heuristic to exlude already seen
|
||||
// positions without many false positives.
|
||||
auto key = pos.key();
|
||||
auto hash_index = (size_t)(key & (GENSFEN_HASH_SIZE - 1));
|
||||
auto old_key = hash[hash_index];
|
||||
if (key == old_key)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Replace with the current key.
|
||||
hash[hash_index] = key;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// thread_id = 0..Threads.size()-1
|
||||
void MultiThinkGenSfen::thread_worker(size_t thread_id)
|
||||
void MultiThinkGenSfen::thread_worker(Thread& th, std::atomic<uint64_t>& counter, uint64_t limit)
|
||||
{
|
||||
// For the time being, it will be treated as a draw
|
||||
// at the maximum number of steps to write.
|
||||
@@ -660,10 +697,8 @@ namespace Learner
|
||||
// It is necessary to set a dependent thread for Position.
|
||||
// When parallelizing, Threads (since this is a vector<Thread*>,
|
||||
// Do the same for up to Threads[0]...Threads[thread_num-1].
|
||||
auto th = Threads[thread_id];
|
||||
|
||||
auto& pos = th->rootPos;
|
||||
pos.set(StartFEN, false, &si, th);
|
||||
auto& pos = th.rootPos;
|
||||
pos.set(StartFEN, false, &si, &th);
|
||||
|
||||
int resign_counter = 0;
|
||||
bool should_resign = prng.rand(10) > 1;
|
||||
@@ -684,13 +719,11 @@ namespace Learner
|
||||
vector<int> move_hist_scores;
|
||||
|
||||
auto flush_psv = [&](int8_t result) {
|
||||
quit = commit_psv(a_psv, thread_id, result);
|
||||
quit = commit_psv(th, a_psv, result, counter, limit);
|
||||
};
|
||||
|
||||
for (int ply = 0; ; ++ply)
|
||||
{
|
||||
Move next_move = MOVE_NONE;
|
||||
|
||||
// Current search depth
|
||||
const int depth = search_depth_min + (int)prng.rand(search_depth_max - search_depth_min + 1);
|
||||
|
||||
@@ -715,18 +748,17 @@ namespace Learner
|
||||
flush_psv((search_value >= eval_limit) ? 1 : -1);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
resign_counter = 0;
|
||||
}
|
||||
// Verification of a strange move
|
||||
if (search_pv.size() > 0
|
||||
&& (search_pv[0] == MOVE_NONE || search_pv[0] == MOVE_NULL))
|
||||
|
||||
// In case there is no PV and the game was not ended here
|
||||
// there is nothing we can do, we can't continue the game,
|
||||
// we don't know the result, so discard this game.
|
||||
if (search_pv.empty())
|
||||
{
|
||||
// (???)
|
||||
// MOVE_WIN is checking if it is the declaration victory stage before this
|
||||
// The declarative winning move should never come back here.
|
||||
// Also, when MOVE_RESIGN, search_value is a one-stop score, which should be the minimum value of eval_limit (-31998)...
|
||||
cout << "Error! : " << pos.fen() << next_move << search_value << endl;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -736,34 +768,10 @@ namespace Learner
|
||||
// Discard stuff before write_minply is reached
|
||||
// because it can harm training due to overfitting.
|
||||
// Initial positions would be too common.
|
||||
if (ply < write_minply - 1)
|
||||
{
|
||||
a_psv.clear();
|
||||
goto SKIP_SAVE;
|
||||
}
|
||||
|
||||
// Look into the position hashtable to see if the same
|
||||
// position was seen before.
|
||||
// This is a good heuristic to exlude already seen
|
||||
// positions without many false positives.
|
||||
{
|
||||
auto key = pos.key();
|
||||
auto hash_index = (size_t)(key & (GENSFEN_HASH_SIZE - 1));
|
||||
auto old_key = hash[hash_index];
|
||||
if (key == old_key)
|
||||
{
|
||||
goto SKIP_SAVE;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Replace with the current key.
|
||||
hash[hash_index] = key;
|
||||
}
|
||||
}
|
||||
|
||||
// Pack the current position into a packed sfen and save it into the buffer.
|
||||
if (ply >= write_minply && !was_seen_before(pos))
|
||||
{
|
||||
a_psv.emplace_back(PackedSfenValue());
|
||||
|
||||
auto& psv = a_psv.back();
|
||||
|
||||
// Here we only write the position data.
|
||||
@@ -771,48 +779,29 @@ namespace Learner
|
||||
pos.sfen_pack(psv.sfen);
|
||||
|
||||
psv.score = search_value;
|
||||
|
||||
psv.gamePly = ply;
|
||||
|
||||
// Take out the first PV move. This should be present unless depth 0.
|
||||
assert(search_pv.size() >= 1);
|
||||
psv.move = search_pv[0];
|
||||
}
|
||||
|
||||
SKIP_SAVE:;
|
||||
// Update the next move according to best search result or random move.
|
||||
auto random_move = choose_random_move(pos, random_move_flag, ply, actual_random_move_count);
|
||||
const Move next_move = random_move.has_value() ? *random_move : search_pv[0];
|
||||
|
||||
// For some reason, We could not get PV (hit the substitution table etc. and got stuck?)
|
||||
// so go to the next game. It's a rare case, so you can ignore it.
|
||||
if (search_pv.size() == 0)
|
||||
// We don't have the whole game yet, but it ended,
|
||||
// so the writing process ends and the next game starts.
|
||||
// This shouldn't really happen.
|
||||
if (!is_ok(next_move))
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
// Update the next move according to best search result.
|
||||
next_move = search_pv[0];
|
||||
|
||||
// Random move.
|
||||
auto random_move = choose_random_move(pos, random_move_flag, ply, actual_random_move_count);
|
||||
if (random_move.has_value())
|
||||
{
|
||||
next_move = random_move.value();
|
||||
|
||||
// We don't have the whole game yet, but it ended,
|
||||
// so the writing process ends and the next game starts.
|
||||
if (!is_ok(next_move))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Do move.
|
||||
pos.do_move(next_move, states[ply]);
|
||||
|
||||
} // for (int ply = 0; ; ++ply)
|
||||
}
|
||||
}
|
||||
|
||||
} // while(!quit)
|
||||
|
||||
sfen_writer.finalize(thread_id);
|
||||
sfen_writer.finalize(th.thread_idx());
|
||||
}
|
||||
|
||||
// -----------------------------------
|
||||
@@ -1029,7 +1018,6 @@ namespace Learner
|
||||
|
||||
MultiThinkGenSfen multi_think(search_depth_min, search_depth_max, sfen_writer, seed);
|
||||
multi_think.nodes = nodes;
|
||||
multi_think.set_loop_max(loop_max);
|
||||
multi_think.eval_limit = eval_limit;
|
||||
multi_think.random_move_minply = random_move_minply;
|
||||
multi_think.random_move_maxply = random_move_maxply;
|
||||
@@ -1041,7 +1029,7 @@ namespace Learner
|
||||
multi_think.write_minply = write_minply;
|
||||
multi_think.write_maxply = write_maxply;
|
||||
multi_think.start_file_write_worker();
|
||||
multi_think.go_think();
|
||||
multi_think.gensfen(loop_max);
|
||||
|
||||
// Since we are joining with the destructor of SfenWriter, please give a message that it has finished after the join
|
||||
// Enclose this in a block because it should be displayed.
|
||||
|
||||
Reference in New Issue
Block a user