mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 03:57:45 +00:00
[cluster] keep track of node counts cluster-wide.
This generalizes exchange of signals between the ranks using a non-blocking all-reduce. It is now used for the stop signal and the node count, but should be easily generalizable (TB hits, and ponder still missing). It avoids having long-lived outstanding non-blocking collectives (removes an early posted Ibarrier). A bit too short a test, but not worse than before: Score of new-r4-1t vs old-r4-1t: 459 - 401 - 1505 [0.512] 2365 Elo difference: 8.52 +/- 8.43
This commit is contained in:
committed by
Stéphane Nicolet
parent
2f882309d5
commit
87f0fa55a0
+96
-40
@@ -37,13 +37,21 @@ namespace Cluster {
|
||||
|
||||
static int world_rank = MPI_PROC_NULL;
|
||||
static int world_size = 0;
|
||||
static bool stop_signal = false;
|
||||
static MPI_Request reqStop = MPI_REQUEST_NULL;
|
||||
|
||||
static MPI_Request reqSignals = MPI_REQUEST_NULL;
|
||||
static uint64_t signalsCallCounter = 0;
|
||||
|
||||
enum Signals : int { SIG_NODES = 0, SIG_STOP = 1, SIG_NB = 2};
|
||||
static uint64_t signalsSend[SIG_NB] = {};
|
||||
static uint64_t signalsRecv[SIG_NB] = {};
|
||||
|
||||
static uint64_t nodesSearchedOthers = 0;
|
||||
static uint64_t stopSignalsPosted = 0;
|
||||
|
||||
static MPI_Comm InputComm = MPI_COMM_NULL;
|
||||
static MPI_Comm TTComm = MPI_COMM_NULL;
|
||||
static MPI_Comm MoveComm = MPI_COMM_NULL;
|
||||
static MPI_Comm StopComm = MPI_COMM_NULL;
|
||||
static MPI_Comm signalsComm = MPI_COMM_NULL;
|
||||
|
||||
static std::vector<KeyedTTEntry> TTBuff;
|
||||
|
||||
@@ -75,14 +83,34 @@ void init() {
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &InputComm);
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &TTComm);
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &MoveComm);
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &StopComm);
|
||||
MPI_Comm_dup(MPI_COMM_WORLD, &signalsComm);
|
||||
}
|
||||
|
||||
void finalize() {
|
||||
|
||||
|
||||
// free data tyes and communicators
|
||||
MPI_Type_free(&MIDatatype);
|
||||
|
||||
MPI_Comm_free(&InputComm);
|
||||
MPI_Comm_free(&TTComm);
|
||||
MPI_Comm_free(&MoveComm);
|
||||
MPI_Comm_free(&signalsComm);
|
||||
|
||||
MPI_Finalize();
|
||||
}
|
||||
|
||||
int size() {
|
||||
|
||||
return world_size;
|
||||
}
|
||||
|
||||
int rank() {
|
||||
|
||||
return world_rank;
|
||||
}
|
||||
|
||||
|
||||
bool getline(std::istream& input, std::string& str) {
|
||||
|
||||
int size;
|
||||
@@ -124,45 +152,60 @@ bool getline(std::istream& input, std::string& str) {
|
||||
return state;
|
||||
}
|
||||
|
||||
void sync_start() {
|
||||
void signals_send() {
|
||||
|
||||
stop_signal = false;
|
||||
|
||||
// Start listening to stop signal
|
||||
if (!is_root())
|
||||
MPI_Ibarrier(StopComm, &reqStop);
|
||||
signalsSend[SIG_NODES] = Threads.nodes_searched();
|
||||
signalsSend[SIG_STOP] = Threads.stop;
|
||||
MPI_Iallreduce(signalsSend, signalsRecv, SIG_NB, MPI_UINT64_T,
|
||||
MPI_SUM, signalsComm, &reqSignals);
|
||||
++signalsCallCounter;
|
||||
}
|
||||
|
||||
void sync_stop() {
|
||||
void signals_process() {
|
||||
|
||||
if (is_root())
|
||||
nodesSearchedOthers = signalsRecv[SIG_NODES] - signalsSend[SIG_NODES];
|
||||
stopSignalsPosted = signalsRecv[SIG_STOP];
|
||||
if (signalsRecv[SIG_STOP] > 0)
|
||||
Threads.stop = true;
|
||||
}
|
||||
|
||||
void signals_sync() {
|
||||
|
||||
while(stopSignalsPosted < uint64_t(size()))
|
||||
signals_poll();
|
||||
|
||||
// finalize outstanding messages of the signal loops. We might have issued one call less than needed on some ranks.
|
||||
uint64_t globalCounter;
|
||||
MPI_Allreduce(&signalsCallCounter, &globalCounter, 1, MPI_UINT64_T, MPI_MAX, MoveComm); // MoveComm needed
|
||||
if (signalsCallCounter < globalCounter)
|
||||
signals_send();
|
||||
|
||||
assert(signalsCallCounter == globalCounter);
|
||||
|
||||
MPI_Wait(&reqSignals, MPI_STATUS_IGNORE);
|
||||
|
||||
signals_process();
|
||||
|
||||
}
|
||||
|
||||
void signals_init() {
|
||||
|
||||
stopSignalsPosted = nodesSearchedOthers = 0;
|
||||
|
||||
signalsSend[SIG_NODES] = signalsRecv[SIG_NODES] = 0;
|
||||
signalsSend[SIG_STOP] = signalsRecv[SIG_STOP] = 0;
|
||||
|
||||
}
|
||||
|
||||
void signals_poll() {
|
||||
|
||||
int flag;
|
||||
MPI_Test(&reqSignals, &flag, MPI_STATUS_IGNORE);
|
||||
if (flag)
|
||||
{
|
||||
if (!stop_signal && Threads.stop)
|
||||
{
|
||||
// Signal the cluster about stopping
|
||||
stop_signal = true;
|
||||
MPI_Ibarrier(StopComm, &reqStop);
|
||||
MPI_Wait(&reqStop, MPI_STATUS_IGNORE);
|
||||
}
|
||||
signals_process();
|
||||
signals_send();
|
||||
}
|
||||
else
|
||||
{
|
||||
int flagStop;
|
||||
// Check if we've received any stop signal
|
||||
MPI_Test(&reqStop, &flagStop, MPI_STATUS_IGNORE);
|
||||
if (flagStop)
|
||||
Threads.stop = true;
|
||||
}
|
||||
}
|
||||
|
||||
int size() {
|
||||
|
||||
return world_size;
|
||||
}
|
||||
|
||||
int rank() {
|
||||
|
||||
return world_rank;
|
||||
}
|
||||
|
||||
void save(Thread* thread, TTEntry* tte,
|
||||
@@ -270,10 +313,23 @@ void pick_moves(MoveInfo& mi) {
|
||||
MPI_Bcast(&mi, 1, MIDatatype, 0, MoveComm);
|
||||
}
|
||||
|
||||
void sum(uint64_t& val) {
|
||||
uint64_t nodes_searched() {
|
||||
|
||||
const uint64_t send = val;
|
||||
MPI_Reduce(&send, &val, 1, MPI_UINT64_T, MPI_SUM, 0, MoveComm);
|
||||
return nodesSearchedOthers + Threads.nodes_searched();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#include "cluster.h"
|
||||
#include "thread.h"
|
||||
|
||||
namespace Cluster {
|
||||
|
||||
uint64_t nodes_searched() {
|
||||
|
||||
return Threads.nodes_searched();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user