mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 08:37:44 +00:00
Add "sfen_format" option in gensfen. Valid values are "bin" and "binpack". It determines the output format of the sfens. Binpack is a highly compressed formats for consecutive sfens. Extension is now determined by the used format, output_file_name should contain just the stem.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
+118
-11
@@ -11,6 +11,8 @@
|
|||||||
#include "learn.h"
|
#include "learn.h"
|
||||||
#include "multi_think.h"
|
#include "multi_think.h"
|
||||||
|
|
||||||
|
#include "../extra/nnue_data_binpack_format.h"
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <climits>
|
#include <climits>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
@@ -32,6 +34,12 @@ using namespace std;
|
|||||||
|
|
||||||
namespace Learner
|
namespace Learner
|
||||||
{
|
{
|
||||||
|
enum struct SfenOutputType
|
||||||
|
{
|
||||||
|
Bin,
|
||||||
|
Binpack
|
||||||
|
};
|
||||||
|
|
||||||
static bool write_out_draw_game_in_training_data_generation = false;
|
static bool write_out_draw_game_in_training_data_generation = false;
|
||||||
static bool detect_draw_by_consecutive_low_score = false;
|
static bool detect_draw_by_consecutive_low_score = false;
|
||||||
static bool detect_draw_by_insufficient_mating_material = false;
|
static bool detect_draw_by_insufficient_mating_material = false;
|
||||||
@@ -42,6 +50,94 @@ namespace Learner
|
|||||||
// https://discordapp.com/channels/435943710472011776/733545871911813221/748524079761326192
|
// https://discordapp.com/channels/435943710472011776/733545871911813221/748524079761326192
|
||||||
extern bool use_raw_nnue_eval;
|
extern bool use_raw_nnue_eval;
|
||||||
|
|
||||||
|
static SfenOutputType sfen_output_type = SfenOutputType::Bin;
|
||||||
|
|
||||||
|
static bool ends_with(const std::string& lhs, const std::string& end)
|
||||||
|
{
|
||||||
|
if (end.size() > lhs.size()) return false;
|
||||||
|
|
||||||
|
return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string filename_with_extension(const std::string& filename, const std::string& ext)
|
||||||
|
{
|
||||||
|
if (ends_with(filename, ext))
|
||||||
|
{
|
||||||
|
return filename;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return filename + "." + ext;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct BasicSfenOutputStream
|
||||||
|
{
|
||||||
|
virtual void write(const PSVector& sfens) = 0;
|
||||||
|
virtual ~BasicSfenOutputStream() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BinSfenOutputStream : BasicSfenOutputStream
|
||||||
|
{
|
||||||
|
static constexpr auto openmode = ios::out | ios::binary | ios::app;
|
||||||
|
static inline const std::string extension = "bin";
|
||||||
|
|
||||||
|
BinSfenOutputStream(std::string filename) :
|
||||||
|
m_stream(filename_with_extension(filename, extension), openmode)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
void write(const PSVector& sfens) override
|
||||||
|
{
|
||||||
|
m_stream.write(reinterpret_cast<const char*>(sfens.data()), sizeof(PackedSfenValue) * sfens.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
~BinSfenOutputStream() override {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
fstream m_stream;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BinpackSfenOutputStream : BasicSfenOutputStream
|
||||||
|
{
|
||||||
|
static constexpr auto openmode = ios::out | ios::binary | ios::app;
|
||||||
|
static inline const std::string extension = "binpack";
|
||||||
|
|
||||||
|
BinpackSfenOutputStream(std::string filename) :
|
||||||
|
m_stream(filename_with_extension(filename, extension), openmode)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
void write(const PSVector& sfens) override
|
||||||
|
{
|
||||||
|
static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue));
|
||||||
|
|
||||||
|
for(auto& sfen : sfens)
|
||||||
|
{
|
||||||
|
// The library uses a type that's different but layout-compatibile.
|
||||||
|
binpack::nodchip::PackedSfenValue e;
|
||||||
|
std::memcpy(&e, &sfen, sizeof(binpack::nodchip::PackedSfenValue));
|
||||||
|
m_stream.addTrainingDataEntry(binpack::packedSfenValueToTrainingDataEntry(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~BinpackSfenOutputStream() override {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
binpack::CompressedTrainingDataEntryWriter m_stream;
|
||||||
|
};
|
||||||
|
|
||||||
|
static std::unique_ptr<BasicSfenOutputStream> create_new_sfen_output(const std::string& filename)
|
||||||
|
{
|
||||||
|
switch(sfen_output_type)
|
||||||
|
{
|
||||||
|
case SfenOutputType::Bin:
|
||||||
|
return std::make_unique<BinSfenOutputStream>(filename);
|
||||||
|
default:
|
||||||
|
return std::make_unique<BinpackSfenOutputStream>(filename);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Helper class for exporting Sfen
|
// Helper class for exporting Sfen
|
||||||
struct SfenWriter
|
struct SfenWriter
|
||||||
{
|
{
|
||||||
@@ -58,7 +154,7 @@ namespace Learner
|
|||||||
sfen_buffers_pool.reserve((size_t)thread_num * 10);
|
sfen_buffers_pool.reserve((size_t)thread_num * 10);
|
||||||
sfen_buffers.resize(thread_num);
|
sfen_buffers.resize(thread_num);
|
||||||
|
|
||||||
output_file_stream.open(filename_, ios::out | ios::binary | ios::app);
|
output_file_stream = create_new_sfen_output(filename_);
|
||||||
filename = filename_;
|
filename = filename_;
|
||||||
|
|
||||||
finished = false;
|
finished = false;
|
||||||
@@ -68,7 +164,7 @@ namespace Learner
|
|||||||
{
|
{
|
||||||
finished = true;
|
finished = true;
|
||||||
file_worker_thread.join();
|
file_worker_thread.join();
|
||||||
output_file_stream.close();
|
output_file_stream.reset();
|
||||||
|
|
||||||
#if defined(_DEBUG)
|
#if defined(_DEBUG)
|
||||||
{
|
{
|
||||||
@@ -137,9 +233,6 @@ namespace Learner
|
|||||||
{
|
{
|
||||||
// Also output the current time to console.
|
// Also output the current time to console.
|
||||||
sync_cout << endl << sfen_write_count << " sfens , at " << now_string() << sync_endl;
|
sync_cout << endl << sfen_write_count << " sfens , at " << now_string() << sync_endl;
|
||||||
|
|
||||||
// This is enough for flush().
|
|
||||||
output_file_stream.flush();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
while (!finished || sfen_buffers_pool.size())
|
while (!finished || sfen_buffers_pool.size())
|
||||||
@@ -163,7 +256,7 @@ namespace Learner
|
|||||||
{
|
{
|
||||||
for (auto& buf : buffers)
|
for (auto& buf : buffers)
|
||||||
{
|
{
|
||||||
output_file_stream.write(reinterpret_cast<const char*>(buf->data()), sizeof(PackedSfenValue) * buf->size());
|
output_file_stream->write(*buf);
|
||||||
|
|
||||||
sfen_write_count += buf->size();
|
sfen_write_count += buf->size();
|
||||||
|
|
||||||
@@ -174,8 +267,6 @@ namespace Learner
|
|||||||
{
|
{
|
||||||
sfen_write_count_current_file = 0;
|
sfen_write_count_current_file = 0;
|
||||||
|
|
||||||
output_file_stream.close();
|
|
||||||
|
|
||||||
// Sequential number attached to the file
|
// Sequential number attached to the file
|
||||||
int n = (int)(sfen_write_count / save_every);
|
int n = (int)(sfen_write_count / save_every);
|
||||||
|
|
||||||
@@ -183,7 +274,7 @@ namespace Learner
|
|||||||
// Add ios::app in consideration of overwriting.
|
// Add ios::app in consideration of overwriting.
|
||||||
// (Depending on the operation, it may not be necessary.)
|
// (Depending on the operation, it may not be necessary.)
|
||||||
string new_filename = filename + "_" + std::to_string(n);
|
string new_filename = filename + "_" + std::to_string(n);
|
||||||
output_file_stream.open(new_filename, ios::out | ios::binary | ios::app);
|
output_file_stream = create_new_sfen_output(new_filename);
|
||||||
cout << endl << "output sfen file = " << new_filename << endl;
|
cout << endl << "output sfen file = " << new_filename << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,7 +308,7 @@ namespace Learner
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
||||||
fstream output_file_stream;
|
std::unique_ptr<BasicSfenOutputStream> output_file_stream;
|
||||||
|
|
||||||
// A new net is saved after every save_every sfens are processed.
|
// A new net is saved after every save_every sfens are processed.
|
||||||
uint64_t save_every = std::numeric_limits<uint64_t>::max();
|
uint64_t save_every = std::numeric_limits<uint64_t>::max();
|
||||||
@@ -951,7 +1042,7 @@ namespace Learner
|
|||||||
int write_maxply = 400;
|
int write_maxply = 400;
|
||||||
|
|
||||||
// File name to write
|
// File name to write
|
||||||
string output_file_name = "generated_kifu.bin";
|
string output_file_name = "generated_kifu";
|
||||||
|
|
||||||
string token;
|
string token;
|
||||||
|
|
||||||
@@ -962,6 +1053,8 @@ namespace Learner
|
|||||||
// Add a random number to the end of the file name.
|
// Add a random number to the end of the file name.
|
||||||
bool random_file_name = false;
|
bool random_file_name = false;
|
||||||
|
|
||||||
|
std::string sfen_format;
|
||||||
|
|
||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
token = "";
|
token = "";
|
||||||
@@ -1017,10 +1110,24 @@ namespace Learner
|
|||||||
is >> detect_draw_by_insufficient_mating_material;
|
is >> detect_draw_by_insufficient_mating_material;
|
||||||
else if (token == "use_raw_nnue_eval")
|
else if (token == "use_raw_nnue_eval")
|
||||||
is >> use_raw_nnue_eval;
|
is >> use_raw_nnue_eval;
|
||||||
|
else if (token == "sfen_format")
|
||||||
|
is >> sfen_format;
|
||||||
else
|
else
|
||||||
cout << "Error! : Illegal token " << token << endl;
|
cout << "Error! : Illegal token " << token << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!sfen_format.empty())
|
||||||
|
{
|
||||||
|
if (sfen_format == "bin")
|
||||||
|
sfen_output_type = SfenOutputType::Bin;
|
||||||
|
else if (sfen_format == "binpack")
|
||||||
|
sfen_output_type = SfenOutputType::Binpack;
|
||||||
|
else
|
||||||
|
{
|
||||||
|
cout << "Unknown sfen format `" << sfen_format << "`. Using bin\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If search depth2 is not set, leave it the same as search depth.
|
// If search depth2 is not set, leave it the same as search depth.
|
||||||
if (search_depth_max == INT_MIN)
|
if (search_depth_max == INT_MIN)
|
||||||
search_depth_max = search_depth_min;
|
search_depth_max = search_depth_min;
|
||||||
|
|||||||
Reference in New Issue
Block a user