mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 09:47:46 +00:00
Merge remote-tracking branch 'remotes/nodchip/master' into trainer
This commit is contained in:
+117
-12
@@ -1,16 +1,19 @@
|
||||
#if defined(EVAL_LEARN)
|
||||
#include "convert.h"
|
||||
|
||||
#include "multi_think.h"
|
||||
|
||||
#include "uci.h"
|
||||
#include "misc.h"
|
||||
#include "thread.h"
|
||||
#include "position.h"
|
||||
#include "tt.h"
|
||||
|
||||
// evaluate header for learning
|
||||
#include "../eval/evaluate_common.h"
|
||||
#include "eval/evaluate_common.h"
|
||||
|
||||
#include "learn.h"
|
||||
#include "multi_think.h"
|
||||
#include "../uci.h"
|
||||
#include "../syzygy/tbprobe.h"
|
||||
#include "../misc.h"
|
||||
#include "../thread.h"
|
||||
#include "../position.h"
|
||||
#include "../tt.h"
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include "syzygy/tbprobe.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
@@ -119,7 +122,7 @@ namespace Learner
|
||||
else if (token == "score") {
|
||||
double score;
|
||||
ss >> score;
|
||||
// Training Formula · Issue #71 · nodchip/Stockfish https://github.com/nodchip/Stockfish/issues/71
|
||||
// Training Formula ?Issue #71 ?nodchip/Stockfish https://github.com/nodchip/Stockfish/issues/71
|
||||
// Normalize to [0.0, 1.0].
|
||||
score = (score - src_score_min_value) / (src_score_max_value - src_score_min_value);
|
||||
// Scale to [dest_score_min_value, dest_score_max_value].
|
||||
@@ -497,5 +500,107 @@ namespace Learner
|
||||
ofs.close();
|
||||
std::cout << "all done" << std::endl;
|
||||
}
|
||||
|
||||
static inline const std::string plain_extension = ".plain";
|
||||
static inline const std::string bin_extension = ".bin";
|
||||
static inline const std::string binpack_extension = ".binpack";
|
||||
|
||||
static bool file_exists(const std::string& name)
|
||||
{
|
||||
std::ifstream f(name);
|
||||
return f.good();
|
||||
}
|
||||
|
||||
static bool ends_with(const std::string& lhs, const std::string& end)
|
||||
{
|
||||
if (end.size() > lhs.size()) return false;
|
||||
|
||||
return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
|
||||
}
|
||||
|
||||
static bool is_convert_of_type(
|
||||
const std::string& input_path,
|
||||
const std::string& output_path,
|
||||
const std::string& expected_input_extension,
|
||||
const std::string& expected_output_extension)
|
||||
{
|
||||
return ends_with(input_path, expected_input_extension)
|
||||
&& ends_with(output_path, expected_output_extension);
|
||||
}
|
||||
|
||||
using ConvertFunctionType = void(std::string inputPath, std::string outputPath, std::ios_base::openmode om);
|
||||
|
||||
static ConvertFunctionType* get_convert_function(const std::string& input_path, const std::string& output_path)
|
||||
{
|
||||
if (is_convert_of_type(input_path, output_path, plain_extension, bin_extension))
|
||||
return binpack::convertPlainToBin;
|
||||
if (is_convert_of_type(input_path, output_path, plain_extension, binpack_extension))
|
||||
return binpack::convertPlainToBinpack;
|
||||
|
||||
if (is_convert_of_type(input_path, output_path, bin_extension, plain_extension))
|
||||
return binpack::convertBinToPlain;
|
||||
if (is_convert_of_type(input_path, output_path, bin_extension, binpack_extension))
|
||||
return binpack::convertBinToBinpack;
|
||||
|
||||
if (is_convert_of_type(input_path, output_path, binpack_extension, plain_extension))
|
||||
return binpack::convertBinpackToPlain;
|
||||
if (is_convert_of_type(input_path, output_path, binpack_extension, bin_extension))
|
||||
return binpack::convertBinpackToBin;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void convert(const std::string& input_path, const std::string& output_path, std::ios_base::openmode om)
|
||||
{
|
||||
if(!file_exists(input_path))
|
||||
{
|
||||
std::cerr << "Input file does not exist.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
auto func = get_convert_function(input_path, output_path);
|
||||
if (func != nullptr)
|
||||
{
|
||||
func(input_path, output_path, om);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Conversion between files of these types is not supported.\n";
|
||||
}
|
||||
}
|
||||
|
||||
static void convert(const std::vector<std::string>& args)
|
||||
{
|
||||
if (args.size() < 2 || args.size() > 3)
|
||||
{
|
||||
std::cerr << "Invalid arguments.\n";
|
||||
std::cerr << "Usage: convert from_path to_path [append]\n";
|
||||
return;
|
||||
}
|
||||
|
||||
const bool append = (args.size() == 3) && (args[2] == "append");
|
||||
const std::ios_base::openmode openmode =
|
||||
append
|
||||
? std::ios_base::app
|
||||
: std::ios_base::trunc;
|
||||
|
||||
convert(args[0], args[1], openmode);
|
||||
}
|
||||
|
||||
void convert(istringstream& is)
|
||||
{
|
||||
std::vector<std::string> args;
|
||||
|
||||
while (true)
|
||||
{
|
||||
std::string token = "";
|
||||
is >> token;
|
||||
if (token == "")
|
||||
break;
|
||||
|
||||
args.push_back(token);
|
||||
}
|
||||
|
||||
convert(args);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
#ifndef _CONVERT_H_
|
||||
#define _CONVERT_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
namespace Learner {
|
||||
void convert_bin_from_pgn_extract(
|
||||
const std::vector<std::string>& filenames,
|
||||
const std::string& output_file_name,
|
||||
const bool pgn_eval_side_to_move,
|
||||
const bool convert_no_eval_fens_as_score_zero);
|
||||
|
||||
void convert_bin(
|
||||
const std::vector<std::string>& filenames,
|
||||
const std::string& output_file_name,
|
||||
const int ply_minimum,
|
||||
const int ply_maximum,
|
||||
const int interpolate_eval,
|
||||
const int src_score_min_value,
|
||||
const int src_score_max_value,
|
||||
const int dest_score_min_value,
|
||||
const int dest_score_max_value,
|
||||
const bool check_invalid_fen,
|
||||
const bool check_illegal_move);
|
||||
|
||||
void convert_plain(
|
||||
const std::vector<std::string>& filenames,
|
||||
const std::string& output_file_name);
|
||||
|
||||
void convert(std::istringstream& is);
|
||||
}
|
||||
|
||||
#endif
|
||||
+180
-40
@@ -1,17 +1,23 @@
|
||||
#if defined(EVAL_LEARN)
|
||||
#include "gensfen.h"
|
||||
|
||||
#include "../eval/evaluate_common.h"
|
||||
#include "../misc.h"
|
||||
#include "../nnue/evaluate_nnue_learner.h"
|
||||
#include "../position.h"
|
||||
#include "../syzygy/tbprobe.h"
|
||||
#include "../thread.h"
|
||||
#include "../tt.h"
|
||||
#include "../uci.h"
|
||||
#include "learn.h"
|
||||
#include "packed_sfen.h"
|
||||
#include "multi_think.h"
|
||||
#include "../syzygy/tbprobe.h"
|
||||
|
||||
#include "misc.h"
|
||||
#include "position.h"
|
||||
#include "thread.h"
|
||||
#include "tt.h"
|
||||
#include "uci.h"
|
||||
|
||||
#include "eval/evaluate_common.h"
|
||||
|
||||
#include "extra/nnue_data_binpack_format.h"
|
||||
|
||||
#include "nnue/evaluate_nnue_learner.h"
|
||||
|
||||
#include "syzygy/tbprobe.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
@@ -33,11 +39,107 @@ using namespace std;
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
enum struct SfenOutputType
|
||||
{
|
||||
Bin,
|
||||
Binpack
|
||||
};
|
||||
|
||||
static bool write_out_draw_game_in_training_data_generation = false;
|
||||
static bool detect_draw_by_consecutive_low_score = false;
|
||||
static bool detect_draw_by_insufficient_mating_material = false;
|
||||
|
||||
static std::vector<std::string> bookStart;
|
||||
static SfenOutputType sfen_output_type = SfenOutputType::Bin;
|
||||
|
||||
static bool ends_with(const std::string& lhs, const std::string& end)
|
||||
{
|
||||
if (end.size() > lhs.size()) return false;
|
||||
|
||||
return std::equal(end.rbegin(), end.rend(), lhs.rbegin());
|
||||
}
|
||||
|
||||
static std::string filename_with_extension(const std::string& filename, const std::string& ext)
|
||||
{
|
||||
if (ends_with(filename, ext))
|
||||
{
|
||||
return filename;
|
||||
}
|
||||
else
|
||||
{
|
||||
return filename + "." + ext;
|
||||
}
|
||||
}
|
||||
|
||||
struct BasicSfenOutputStream
|
||||
{
|
||||
virtual void write(const PSVector& sfens) = 0;
|
||||
virtual ~BasicSfenOutputStream() {}
|
||||
};
|
||||
|
||||
struct BinSfenOutputStream : BasicSfenOutputStream
|
||||
{
|
||||
static constexpr auto openmode = ios::out | ios::binary | ios::app;
|
||||
static inline const std::string extension = "bin";
|
||||
|
||||
BinSfenOutputStream(std::string filename) :
|
||||
m_stream(filename_with_extension(filename, extension), openmode)
|
||||
{
|
||||
}
|
||||
|
||||
void write(const PSVector& sfens) override
|
||||
{
|
||||
m_stream.write(reinterpret_cast<const char*>(sfens.data()), sizeof(PackedSfenValue) * sfens.size());
|
||||
}
|
||||
|
||||
~BinSfenOutputStream() override {}
|
||||
|
||||
private:
|
||||
fstream m_stream;
|
||||
};
|
||||
|
||||
struct BinpackSfenOutputStream : BasicSfenOutputStream
|
||||
{
|
||||
static constexpr auto openmode = ios::out | ios::binary | ios::app;
|
||||
static inline const std::string extension = "binpack";
|
||||
|
||||
BinpackSfenOutputStream(std::string filename) :
|
||||
m_stream(filename_with_extension(filename, extension), openmode)
|
||||
{
|
||||
}
|
||||
|
||||
void write(const PSVector& sfens) override
|
||||
{
|
||||
static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue));
|
||||
|
||||
for(auto& sfen : sfens)
|
||||
{
|
||||
// The library uses a type that's different but layout-compatibile.
|
||||
binpack::nodchip::PackedSfenValue e;
|
||||
std::memcpy(&e, &sfen, sizeof(binpack::nodchip::PackedSfenValue));
|
||||
m_stream.addTrainingDataEntry(binpack::packedSfenValueToTrainingDataEntry(e));
|
||||
}
|
||||
}
|
||||
|
||||
~BinpackSfenOutputStream() override {}
|
||||
|
||||
private:
|
||||
binpack::CompressedTrainingDataEntryWriter m_stream;
|
||||
};
|
||||
|
||||
static std::unique_ptr<BasicSfenOutputStream> create_new_sfen_output(const std::string& filename)
|
||||
{
|
||||
switch(sfen_output_type)
|
||||
{
|
||||
case SfenOutputType::Bin:
|
||||
return std::make_unique<BinSfenOutputStream>(filename);
|
||||
case SfenOutputType::Binpack:
|
||||
return std::make_unique<BinpackSfenOutputStream>(filename);
|
||||
}
|
||||
|
||||
assert(false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Helper class for exporting Sfen
|
||||
struct SfenWriter
|
||||
@@ -55,7 +157,7 @@ namespace Learner
|
||||
sfen_buffers_pool.reserve((size_t)thread_num * 10);
|
||||
sfen_buffers.resize(thread_num);
|
||||
|
||||
output_file_stream.open(filename_, ios::out | ios::binary | ios::app);
|
||||
output_file_stream = create_new_sfen_output(filename_);
|
||||
filename = filename_;
|
||||
|
||||
finished = false;
|
||||
@@ -65,7 +167,7 @@ namespace Learner
|
||||
{
|
||||
finished = true;
|
||||
file_worker_thread.join();
|
||||
output_file_stream.close();
|
||||
output_file_stream.reset();
|
||||
|
||||
#if defined(_DEBUG)
|
||||
{
|
||||
@@ -134,9 +236,6 @@ namespace Learner
|
||||
{
|
||||
// Also output the current time to console.
|
||||
sync_cout << endl << sfen_write_count << " sfens , at " << now_string() << sync_endl;
|
||||
|
||||
// This is enough for flush().
|
||||
output_file_stream.flush();
|
||||
};
|
||||
|
||||
while (!finished || sfen_buffers_pool.size())
|
||||
@@ -160,7 +259,7 @@ namespace Learner
|
||||
{
|
||||
for (auto& buf : buffers)
|
||||
{
|
||||
output_file_stream.write(reinterpret_cast<const char*>(buf->data()), sizeof(PackedSfenValue) * buf->size());
|
||||
output_file_stream->write(*buf);
|
||||
|
||||
sfen_write_count += buf->size();
|
||||
|
||||
@@ -171,8 +270,6 @@ namespace Learner
|
||||
{
|
||||
sfen_write_count_current_file = 0;
|
||||
|
||||
output_file_stream.close();
|
||||
|
||||
// Sequential number attached to the file
|
||||
int n = (int)(sfen_write_count / save_every);
|
||||
|
||||
@@ -180,7 +277,7 @@ namespace Learner
|
||||
// Add ios::app in consideration of overwriting.
|
||||
// (Depending on the operation, it may not be necessary.)
|
||||
string new_filename = filename + "_" + std::to_string(n);
|
||||
output_file_stream.open(new_filename, ios::out | ios::binary | ios::app);
|
||||
output_file_stream = create_new_sfen_output(new_filename);
|
||||
cout << endl << "output sfen file = " << new_filename << endl;
|
||||
}
|
||||
|
||||
@@ -214,7 +311,7 @@ namespace Learner
|
||||
|
||||
private:
|
||||
|
||||
fstream output_file_stream;
|
||||
std::unique_ptr<BasicSfenOutputStream> output_file_stream;
|
||||
|
||||
// A new net is saved after every save_every sfens are processed.
|
||||
uint64_t save_every = std::numeric_limits<uint64_t>::max();
|
||||
@@ -260,7 +357,8 @@ namespace Learner
|
||||
// It must be 2**N because it will be used as the mask to calculate hash_index.
|
||||
static_assert((GENSFEN_HASH_SIZE& (GENSFEN_HASH_SIZE - 1)) == 0);
|
||||
|
||||
MultiThinkGenSfen(int search_depth_min_, int search_depth_max_, SfenWriter& sw_) :
|
||||
MultiThinkGenSfen(int search_depth_min_, int search_depth_max_, SfenWriter& sw_, const std::string& seed) :
|
||||
MultiThink(seed),
|
||||
search_depth_min(search_depth_min_),
|
||||
search_depth_max(search_depth_max_),
|
||||
sfen_writer(sw_)
|
||||
@@ -759,20 +857,6 @@ namespace Learner
|
||||
break;
|
||||
}
|
||||
|
||||
if (pos.count<ALL_PIECES>() <= 6) {
|
||||
Tablebases::ProbeState probe_state;
|
||||
Tablebases::WDLScore wdl = Tablebases::probe_wdl(pos, &probe_state);
|
||||
assert(wdl != Tablebases::WDLScore::WDLScoreNone);
|
||||
if (wdl == Tablebases::WDLScore::WDLWin) {
|
||||
flush_psv(1);
|
||||
} else if (wdl == Tablebases::WDLScore::WDLLoss) {
|
||||
flush_psv(-1);
|
||||
} else {
|
||||
flush_psv(0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
{
|
||||
auto [search_value, search_pv] = search(pos, depth, 1, nodes);
|
||||
|
||||
@@ -819,6 +903,25 @@ namespace Learner
|
||||
goto SKIP_SAVE;
|
||||
}
|
||||
|
||||
// Look into the position hashtable to see if the same
|
||||
// position was seen before.
|
||||
// This is a good heuristic to exlude already seen
|
||||
// positions without many false positives.
|
||||
{
|
||||
auto key = pos.key();
|
||||
auto hash_index = (size_t)(key & (GENSFEN_HASH_SIZE - 1));
|
||||
auto old_key = hash[hash_index];
|
||||
if (key == old_key)
|
||||
{
|
||||
goto SKIP_SAVE;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Replace with the current key.
|
||||
hash[hash_index] = key;
|
||||
}
|
||||
}
|
||||
|
||||
// Pack the current position into a packed sfen and save it into the buffer.
|
||||
{
|
||||
a_psv.emplace_back(PackedSfenValue());
|
||||
@@ -916,7 +1019,7 @@ namespace Learner
|
||||
int write_maxply = 400;
|
||||
|
||||
// File name to write
|
||||
string output_file_name = "generated_kifu.bin";
|
||||
string output_file_name = "generated_kifu";
|
||||
|
||||
string token;
|
||||
|
||||
@@ -927,6 +1030,9 @@ namespace Learner
|
||||
// Add a random number to the end of the file name.
|
||||
bool random_file_name = false;
|
||||
|
||||
std::string sfen_format;
|
||||
std::string seed;
|
||||
|
||||
while (true)
|
||||
{
|
||||
token = "";
|
||||
@@ -980,10 +1086,26 @@ namespace Learner
|
||||
is >> detect_draw_by_consecutive_low_score;
|
||||
else if (token == "detect_draw_by_insufficient_mating_material")
|
||||
is >> detect_draw_by_insufficient_mating_material;
|
||||
else if (token == "sfen_format")
|
||||
is >> sfen_format;
|
||||
else if (token == "seed")
|
||||
is >> seed;
|
||||
else
|
||||
cout << "Error! : Illegal token " << token << endl;
|
||||
}
|
||||
|
||||
if (!sfen_format.empty())
|
||||
{
|
||||
if (sfen_format == "bin")
|
||||
sfen_output_type = SfenOutputType::Bin;
|
||||
else if (sfen_format == "binpack")
|
||||
sfen_output_type = SfenOutputType::Binpack;
|
||||
else
|
||||
{
|
||||
cout << "Unknown sfen format `" << sfen_format << "`. Using bin\n";
|
||||
}
|
||||
}
|
||||
|
||||
// If search depth2 is not set, leave it the same as search depth.
|
||||
if (search_depth_max == INT_MIN)
|
||||
search_depth_max = search_depth_min;
|
||||
@@ -994,7 +1116,7 @@ namespace Learner
|
||||
{
|
||||
// Give a random number to output_file_name at this point.
|
||||
// Do not use std::random_device(). Because it always the same integers on MinGW.
|
||||
PRNG r(std::chrono::system_clock::now().time_since_epoch().count());
|
||||
PRNG r(seed);
|
||||
// Just in case, reassign the random numbers.
|
||||
for (int i = 0; i < 10; ++i)
|
||||
r.rand(1);
|
||||
@@ -1018,6 +1140,8 @@ namespace Learner
|
||||
bookStart.push_back(line);
|
||||
}
|
||||
myfile.close();
|
||||
} else {
|
||||
bookStart.push_back(StartFEN);
|
||||
}
|
||||
}
|
||||
std::cout << "gensfen : " << endl
|
||||
@@ -1048,12 +1172,30 @@ namespace Learner
|
||||
|
||||
Threads.main()->ponder = false;
|
||||
|
||||
// About Search::Limits
|
||||
// Be careful because this member variable is global and affects other threads.
|
||||
{
|
||||
auto& limits = Search::Limits;
|
||||
|
||||
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
|
||||
limits.infinite = true;
|
||||
|
||||
// Since PV is an obstacle when displayed, erase it.
|
||||
limits.silent = true;
|
||||
|
||||
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
|
||||
limits.nodes = 0;
|
||||
|
||||
// depth is also processed by the one passed as an argument of Learner::search().
|
||||
limits.depth = 0;
|
||||
}
|
||||
|
||||
// Create and execute threads as many as Options["Threads"].
|
||||
{
|
||||
SfenWriter sfen_writer(output_file_name, thread_num);
|
||||
sfen_writer.set_save_interval(save_every);
|
||||
|
||||
MultiThinkGenSfen multi_think(search_depth_min, search_depth_max, sfen_writer);
|
||||
MultiThinkGenSfen multi_think(search_depth_min, search_depth_max, sfen_writer, seed);
|
||||
multi_think.nodes = nodes;
|
||||
multi_think.set_loop_max(loop_max);
|
||||
multi_think.eval_limit = eval_limit;
|
||||
@@ -1074,7 +1216,5 @@ namespace Learner
|
||||
}
|
||||
|
||||
std::cout << "gensfen finished." << endl;
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
#ifndef _GENSFEN_H_
|
||||
#define _GENSFEN_H_
|
||||
|
||||
#include "position.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace Learner {
|
||||
|
||||
// Automatic generation of teacher position
|
||||
void gen_sfen(Position& pos, std::istringstream& is);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -7,7 +7,7 @@
|
||||
// Floating point operation by 16bit type
|
||||
// Assume that the float type code generated by the compiler is in IEEE 754 format and use it.
|
||||
|
||||
#include "../types.h"
|
||||
#include "types.h"
|
||||
|
||||
namespace HalfFloat
|
||||
{
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+32
-95
@@ -1,10 +1,6 @@
|
||||
#ifndef _LEARN_H_
|
||||
#define _LEARN_H_
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
|
||||
#include <vector>
|
||||
|
||||
// ----------------------
|
||||
// Floating point for learning
|
||||
// ----------------------
|
||||
@@ -14,7 +10,7 @@
|
||||
// Even if it is a double type, there is almost no difference in the way of convergence, so fix it to float.
|
||||
|
||||
// when using float
|
||||
typedef float LearnFloatType;
|
||||
using LearnFloatType = float;
|
||||
|
||||
// when using double
|
||||
//typedef double LearnFloatType;
|
||||
@@ -36,107 +32,48 @@ typedef float LearnFloatType;
|
||||
// ----------------------
|
||||
// Definition of struct used in Learner
|
||||
// ----------------------
|
||||
#include "../position.h"
|
||||
|
||||
#include "packed_sfen.h"
|
||||
|
||||
#include "position.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
namespace Learner
|
||||
{
|
||||
// ----------------------
|
||||
// Settings for learning
|
||||
// ----------------------
|
||||
// ----------------------
|
||||
// Settings for learning
|
||||
// ----------------------
|
||||
|
||||
// mini-batch size.
|
||||
// Calculate the gradient by combining this number of phases.
|
||||
// If you make it smaller, the number of update_weights() will increase and the convergence will be faster. The gradient is incorrect.
|
||||
// If you increase it, the number of update_weights() decreases, so the convergence will be slow. The slope will come out accurately.
|
||||
// I don't think you need to change this value in most cases.
|
||||
// mini-batch size.
|
||||
// Calculate the gradient by combining this number of phases.
|
||||
// If you make it smaller, the number of update_weights() will increase and the convergence will be faster. The gradient is incorrect.
|
||||
// If you increase it, the number of update_weights() decreases, so the convergence will be slow. The slope will come out accurately.
|
||||
// I don't think you need to change this value in most cases.
|
||||
|
||||
constexpr std::size_t LEARN_MINI_BATCH_SIZE = 1000 * 1000 * 1;
|
||||
constexpr std::size_t LEARN_MINI_BATCH_SIZE = 1000 * 1000 * 1;
|
||||
|
||||
// The number of phases to read from the file at one time. After reading this much, shuffle.
|
||||
// It is better to have a certain size, but this number x 40 bytes x 3 times as much memory is consumed. 400MB*3 is consumed in the 10M phase.
|
||||
// Must be a multiple of THREAD_BUFFER_SIZE(=10000).
|
||||
// The number of phases to read from the file at one time. After reading this much, shuffle.
|
||||
// It is better to have a certain size, but this number x 40 bytes x 3 times as much memory is consumed. 400MB*3 is consumed in the 10M phase.
|
||||
// Must be a multiple of THREAD_BUFFER_SIZE(=10000).
|
||||
|
||||
constexpr std::size_t LEARN_SFEN_READ_SIZE = 1000 * 1000 * 10;
|
||||
constexpr std::size_t LEARN_SFEN_READ_SIZE = 1000 * 1000 * 10;
|
||||
|
||||
// Saving interval of evaluation function at learning. Save each time you learn this number of phases.
|
||||
// Needless to say, the longer the saving interval, the shorter the learning time.
|
||||
// Folder name is incremented for each save like 0/, 1/, 2/...
|
||||
// By default, once every 1 billion phases.
|
||||
constexpr std::size_t LEARN_EVAL_SAVE_INTERVAL = 1000000000ULL;
|
||||
// Saving interval of evaluation function at learning. Save each time you learn this number of phases.
|
||||
// Needless to say, the longer the saving interval, the shorter the learning time.
|
||||
// Folder name is incremented for each save like 0/, 1/, 2/...
|
||||
// By default, once every 1 billion phases.
|
||||
constexpr std::size_t LEARN_EVAL_SAVE_INTERVAL = 1000000000ULL;
|
||||
|
||||
// Reduce the output of rmse during learning to 1 for this number of times.
|
||||
// rmse calculation is done in one thread, so it takes some time, so reducing the output is effective.
|
||||
constexpr std::size_t LEARN_RMSE_OUTPUT_INTERVAL = 1;
|
||||
// Reduce the output of rmse during learning to 1 for this number of times.
|
||||
// rmse calculation is done in one thread, so it takes some time, so reducing the output is effective.
|
||||
constexpr std::size_t LEARN_RMSE_OUTPUT_INTERVAL = 1;
|
||||
|
||||
//Structure in which PackedSfen and evaluation value are integrated
|
||||
// If you write different contents for each option, it will be a problem when reusing the teacher game
|
||||
// For the time being, write all the following members regardless of the options.
|
||||
struct PackedSfenValue
|
||||
{
|
||||
// phase
|
||||
PackedSfen sfen;
|
||||
double calc_grad(Value shallow, const PackedSfenValue& psv);
|
||||
|
||||
// Evaluation value returned from Learner::search()
|
||||
int16_t score;
|
||||
|
||||
// PV first move
|
||||
// Used when finding the match rate with the teacher
|
||||
uint16_t move;
|
||||
|
||||
// Trouble of the phase from the initial phase.
|
||||
uint16_t gamePly;
|
||||
|
||||
// 1 if the player on this side ultimately wins the game. -1 if you are losing.
|
||||
// 0 if a draw is reached.
|
||||
// The draw is in the teacher position generation command gensfen,
|
||||
// Only write if LEARN_GENSFEN_DRAW_RESULT is enabled.
|
||||
int8_t game_result;
|
||||
|
||||
// When exchanging the file that wrote the teacher aspect with other people
|
||||
//Because this structure size is not fixed, pad it so that it is 40 bytes in any environment.
|
||||
uint8_t padding;
|
||||
|
||||
// 32 + 2 + 2 + 2 + 1 + 1 = 40bytes
|
||||
};
|
||||
|
||||
// Type that returns the reading line and the evaluation value at that time
|
||||
// Used in Learner::search(), Learner::qsearch().
|
||||
typedef std::pair<Value, std::vector<Move> > ValueAndPV;
|
||||
|
||||
// Phase array: PSVector stands for packed sfen vector.
|
||||
typedef std::vector<PackedSfenValue> PSVector;
|
||||
|
||||
// So far, only Yaneura King 2018 Otafuku has this stub
|
||||
// This stub is required if EVAL_LEARN is defined.
|
||||
extern Learner::ValueAndPV search(Position& pos, int depth , size_t multiPV = 1 , uint64_t NodesLimit = 0);
|
||||
extern Learner::ValueAndPV qsearch(Position& pos);
|
||||
|
||||
double calc_grad(Value shallow, const PackedSfenValue& psv);
|
||||
|
||||
void convert_bin_from_pgn_extract(
|
||||
const std::vector<std::string>& filenames,
|
||||
const std::string& output_file_name,
|
||||
const bool pgn_eval_side_to_move,
|
||||
const bool convert_no_eval_fens_as_score_zero);
|
||||
|
||||
void convert_bin(
|
||||
const std::vector<std::string>& filenames,
|
||||
const std::string& output_file_name,
|
||||
const int ply_minimum,
|
||||
const int ply_maximum,
|
||||
const int interpolate_eval,
|
||||
const int src_score_min_value,
|
||||
const int src_score_max_value,
|
||||
const int dest_score_min_value,
|
||||
const int dest_score_max_value,
|
||||
const bool check_invalid_fen,
|
||||
const bool check_illegal_move);
|
||||
|
||||
void convert_plain(
|
||||
const std::vector<std::string>& filenames,
|
||||
const std::string& output_file_name);
|
||||
// Learning from the generated game record
|
||||
void learn(Position& pos, std::istringstream& is);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif // ifndef _LEARN_H_
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
#include "learning_tools.h"
|
||||
|
||||
#if defined (EVAL_LEARN)
|
||||
|
||||
#include "../misc.h"
|
||||
#include "misc.h"
|
||||
|
||||
using namespace Eval;
|
||||
|
||||
@@ -18,5 +16,3 @@ namespace EvalLearningTools
|
||||
uint64_t Weight::eta1_epoch;
|
||||
uint64_t Weight::eta2_epoch;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -5,9 +5,7 @@
|
||||
|
||||
#include "learn.h"
|
||||
|
||||
#if defined (EVAL_LEARN)
|
||||
|
||||
#include "../misc.h" // PRNG , my_insertion_sort
|
||||
#include "misc.h" // PRNG , my_insertion_sort
|
||||
|
||||
#include <array>
|
||||
#include <cmath> // std::sqrt()
|
||||
@@ -98,5 +96,4 @@ namespace EvalLearningTools
|
||||
};
|
||||
}
|
||||
|
||||
#endif // defined (EVAL_LEARN)
|
||||
#endif
|
||||
|
||||
+10
-19
@@ -1,10 +1,9 @@
|
||||
#include "../types.h"
|
||||
#include "multi_think.h"
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
|
||||
#include "multi_think.h"
|
||||
#include "../tt.h"
|
||||
#include "../uci.h"
|
||||
#include "tt.h"
|
||||
#include "uci.h"
|
||||
#include "types.h"
|
||||
#include "search.h"
|
||||
|
||||
#include <thread>
|
||||
|
||||
@@ -27,14 +26,13 @@ void MultiThink::go_think()
|
||||
auto thread_num = (size_t)Options["Threads"];
|
||||
|
||||
// Secure end flag of worker thread
|
||||
thread_finished.resize(thread_num);
|
||||
|
||||
threads_finished=0;
|
||||
|
||||
// start worker thread
|
||||
for (size_t i = 0; i < thread_num; ++i)
|
||||
{
|
||||
thread_finished[i] = 0;
|
||||
threads.push_back(std::thread([i, this]
|
||||
{
|
||||
{
|
||||
// exhaust all processor threads.
|
||||
WinProcGroup::bindThisThread(i);
|
||||
|
||||
@@ -42,7 +40,7 @@ void MultiThink::go_think()
|
||||
this->thread_worker(i);
|
||||
|
||||
// Set the end flag because the thread has ended
|
||||
this->thread_finished[i] = 1;
|
||||
this->threads_finished++;
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -56,11 +54,7 @@ void MultiThink::go_think()
|
||||
// function to determine if all threads have finished
|
||||
auto threads_done = [&]()
|
||||
{
|
||||
// returns false if no one is finished
|
||||
for (auto& f : thread_finished)
|
||||
if (!f)
|
||||
return false;
|
||||
return true;
|
||||
return threads_finished == thread_num;
|
||||
};
|
||||
|
||||
// Call back if the callback function is set.
|
||||
@@ -105,6 +99,3 @@ void MultiThink::go_think()
|
||||
// Since the work itself may not have completed, output only that all threads have finished.
|
||||
std::cout << "all threads are joined." << std::endl;
|
||||
}
|
||||
|
||||
|
||||
#endif // defined(EVAL_LEARN)
|
||||
|
||||
+14
-17
@@ -1,17 +1,18 @@
|
||||
#ifndef _MULTI_THINK_
|
||||
#define _MULTI_THINK_
|
||||
|
||||
#if defined(EVAL_LEARN)
|
||||
#include "learn.h"
|
||||
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
|
||||
#include "../misc.h"
|
||||
#include "../learn/learn.h"
|
||||
#include "../thread_win32_osx.h"
|
||||
#include "misc.h"
|
||||
#include "thread_win32_osx.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <limits>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <cstdint>
|
||||
|
||||
|
||||
// Learning from a game record, when making yourself think and generating a fixed track, etc.
|
||||
// Helper class used when multiple threads want to call Search::think() individually.
|
||||
@@ -20,10 +21,11 @@ struct MultiThink
|
||||
{
|
||||
static constexpr std::uint64_t LOOP_COUNT_FINISHED = std::numeric_limits<std::uint64_t>::max();
|
||||
|
||||
MultiThink() : prng(std::chrono::system_clock::now().time_since_epoch().count())
|
||||
{
|
||||
loop_count = 0;
|
||||
}
|
||||
MultiThink() : prng{}, loop_count(0) { }
|
||||
|
||||
MultiThink(std::uint64_t seed) : prng(seed), loop_count(0) { }
|
||||
|
||||
MultiThink(const std::string& seed) : prng(seed), loop_count(0) { }
|
||||
|
||||
// Call this function from the master thread, each thread will think,
|
||||
// Return control when the thought ending condition is satisfied.
|
||||
@@ -94,10 +96,7 @@ private:
|
||||
std::mutex loop_mutex;
|
||||
|
||||
// Thread end flag.
|
||||
// vector<bool> may not be reflected properly when trying to rewrite from multiple threads...
|
||||
typedef uint8_t Flag;
|
||||
std::vector<Flag> thread_finished;
|
||||
|
||||
std::atomic<uint64_t> threads_finished;
|
||||
};
|
||||
|
||||
// Mechanism to process task during idle time.
|
||||
@@ -150,6 +149,4 @@ protected:
|
||||
std::mutex task_mutex;
|
||||
};
|
||||
|
||||
#endif // defined(EVAL_LEARN) && defined(YANEURAOU_2018_OTAFUKU_ENGINE)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
#ifndef _PACKED_SFEN_H_
|
||||
#define _PACKED_SFEN_H_
|
||||
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
namespace Learner {
|
||||
|
||||
// packed sfen
|
||||
struct PackedSfen { std::uint8_t data[32]; };
|
||||
|
||||
// Structure in which PackedSfen and evaluation value are integrated
|
||||
// If you write different contents for each option, it will be a problem when reusing the teacher game
|
||||
// For the time being, write all the following members regardless of the options.
|
||||
struct PackedSfenValue
|
||||
{
|
||||
// phase
|
||||
PackedSfen sfen;
|
||||
|
||||
// Evaluation value returned from Learner::search()
|
||||
std::int16_t score;
|
||||
|
||||
// PV first move
|
||||
// Used when finding the match rate with the teacher
|
||||
std::uint16_t move;
|
||||
|
||||
// Trouble of the phase from the initial phase.
|
||||
std::uint16_t gamePly;
|
||||
|
||||
// 1 if the player on this side ultimately wins the game. -1 if you are losing.
|
||||
// 0 if a draw is reached.
|
||||
// The draw is in the teacher position generation command gensfen,
|
||||
// Only write if LEARN_GENSFEN_DRAW_RESULT is enabled.
|
||||
std::int8_t game_result;
|
||||
|
||||
// When exchanging the file that wrote the teacher aspect with other people
|
||||
//Because this structure size is not fixed, pad it so that it is 40 bytes in any environment.
|
||||
std::uint8_t padding;
|
||||
|
||||
// 32 + 2 + 2 + 2 + 1 + 1 = 40bytes
|
||||
};
|
||||
|
||||
// Phase array: PSVector stands for packed sfen vector.
|
||||
using PSVector = std::vector<PackedSfenValue>;
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,402 @@
|
||||
#include "sfen_packer.h"
|
||||
|
||||
#include "packed_sfen.h"
|
||||
|
||||
#include "misc.h"
|
||||
#include "position.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <cstring> // std::memset()
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace Learner {
|
||||
|
||||
// Class that handles bitstream
|
||||
// useful when doing aspect encoding
|
||||
struct BitStream
|
||||
{
|
||||
// Set the memory to store the data in advance.
|
||||
// Assume that memory is cleared to 0.
|
||||
void set_data(std::uint8_t* data_) { data = data_; reset(); }
|
||||
|
||||
// Get the pointer passed in set_data().
|
||||
uint8_t* get_data() const { return data; }
|
||||
|
||||
// Get the cursor.
|
||||
int get_cursor() const { return bit_cursor; }
|
||||
|
||||
// reset the cursor
|
||||
void reset() { bit_cursor = 0; }
|
||||
|
||||
// Write 1bit to the stream.
|
||||
// If b is non-zero, write out 1. If 0, write 0.
|
||||
void write_one_bit(int b)
|
||||
{
|
||||
if (b)
|
||||
data[bit_cursor / 8] |= 1 << (bit_cursor & 7);
|
||||
|
||||
++bit_cursor;
|
||||
}
|
||||
|
||||
// Get 1 bit from the stream.
|
||||
int read_one_bit()
|
||||
{
|
||||
int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1;
|
||||
++bit_cursor;
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
// write n bits of data
|
||||
// Data shall be written out from the lower order of d.
|
||||
void write_n_bit(int d, int n)
|
||||
{
|
||||
for (int i = 0; i <n; ++i)
|
||||
write_one_bit(d & (1 << i));
|
||||
}
|
||||
|
||||
// read n bits of data
|
||||
// Reverse conversion of write_n_bit().
|
||||
int read_n_bit(int n)
|
||||
{
|
||||
int result = 0;
|
||||
for (int i = 0; i < n; ++i)
|
||||
result |= read_one_bit() ? (1 << i) : 0;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
// Next bit position to read/write.
|
||||
int bit_cursor;
|
||||
|
||||
// data entity
|
||||
std::uint8_t* data;
|
||||
};
|
||||
|
||||
// Class for compressing/decompressing sfen
|
||||
// sfen can be packed to 256bit (32bytes) by Huffman coding.
|
||||
// This is proven by mini. The above is Huffman coding.
|
||||
//
|
||||
// Internal format = 1-bit turn + 7-bit king position *2 + piece on board (Huffman coding) + hand piece (Huffman coding)
|
||||
// Side to move (White = 0, Black = 1) (1bit)
|
||||
// White King Position (6 bits)
|
||||
// Black King Position (6 bits)
|
||||
// Huffman Encoding of the board
|
||||
// Castling availability (1 bit x 4)
|
||||
// En passant square (1 or 1 + 6 bits)
|
||||
// Rule 50 (6 bits)
|
||||
// Game play (8 bits)
|
||||
//
|
||||
// TODO(someone): Rename SFEN to FEN.
|
||||
//
|
||||
struct SfenPacker
|
||||
{
|
||||
void pack(const Position& pos);
|
||||
|
||||
// sfen packed by pack() (256bit = 32bytes)
|
||||
// Or sfen to decode with unpack()
|
||||
uint8_t *data; // uint8_t[32];
|
||||
|
||||
BitStream stream;
|
||||
|
||||
// Output the board pieces to stream.
|
||||
void write_board_piece_to_stream(Piece pc);
|
||||
|
||||
// Read one board piece from stream
|
||||
Piece read_board_piece_from_stream();
|
||||
};
|
||||
|
||||
|
||||
// Huffman coding
|
||||
// * is simplified from mini encoding to make conversion easier.
|
||||
//
|
||||
// 1 box on the board (other than NO_PIECE) = 2 to 6 bits (+ 1-bit flag + 1-bit forward and backward)
|
||||
// 1 piece of hand piece = 1-5bit (+ 1-bit flag + 1bit ahead and behind)
|
||||
//
|
||||
// empty xxxxx0 + 0 (none)
|
||||
// step xxxx01 + 2 xxxx0 + 2
|
||||
// incense xx0011 + 2 xx001 + 2
|
||||
// Katsura xx1011 + 2 xx101 + 2
|
||||
// silver xx0111 + 2 xx011 + 2
|
||||
// Gold x01111 + 1 x0111 + 1 // Gold is valid and has no flags.
|
||||
// corner 011111 + 2 01111 + 2
|
||||
// Fly 111111 + 2 11111 + 2
|
||||
//
|
||||
// Assuming all pieces are on the board,
|
||||
// Sky 81-40 pieces = 41 boxes = 41bit
|
||||
// Walk 4bit*18 pieces = 72bit
|
||||
// Incense 6bit*4 pieces = 24bit
|
||||
// Katsura 6bit*4 pieces = 24bit
|
||||
// Silver 6bit*4 pieces = 24bit
|
||||
// Gold 6bit* 4 pieces = 24bit
|
||||
// corner 8bit* 2 pieces = 16bit
|
||||
// Fly 8bit* 2 pieces = 16bit
|
||||
// -------
|
||||
// 241bit + 1bit (turn) + 7bit × 2 (King's position after) = 256bit
|
||||
//
|
||||
// When the piece on the board moves to the hand piece, the piece on the board becomes empty, so the box on the board can be expressed with 1 bit,
|
||||
// Since the hand piece can be expressed by 1 bit less than the piece on the board, the total number of bits does not change in the end.
|
||||
// Therefore, in this expression, any aspect can be expressed by this bit number.
|
||||
// It is a hand piece and no flag is required, but if you include this, the bit number of the piece on the board will be -1
|
||||
// Since the total number of bits can be fixed, we will include this as well.
|
||||
|
||||
// Huffman Encoding
|
||||
//
|
||||
// Empty xxxxxxx0
|
||||
// Pawn xxxxx001 + 1 bit (Side to move)
|
||||
// Knight xxxxx011 + 1 bit (Side to move)
|
||||
// Bishop xxxxx101 + 1 bit (Side to move)
|
||||
// Rook xxxxx111 + 1 bit (Side to move)
|
||||
|
||||
struct HuffmanedPiece
|
||||
{
|
||||
int code; // how it will be coded
|
||||
int bits; // How many bits do you have
|
||||
};
|
||||
|
||||
constexpr HuffmanedPiece huffman_table[] =
|
||||
{
|
||||
{0b0000,1}, // NO_PIECE
|
||||
{0b0001,4}, // PAWN
|
||||
{0b0011,4}, // KNIGHT
|
||||
{0b0101,4}, // BISHOP
|
||||
{0b0111,4}, // ROOK
|
||||
{0b1001,4}, // QUEEN
|
||||
};
|
||||
|
||||
// Pack sfen and store in data[32].
|
||||
void SfenPacker::pack(const Position& pos)
|
||||
{
|
||||
// cout << pos;
|
||||
|
||||
memset(data, 0, 32 /* 256bit */);
|
||||
stream.set_data(data);
|
||||
|
||||
// turn
|
||||
// Side to move.
|
||||
stream.write_one_bit((int)(pos.side_to_move()));
|
||||
|
||||
// 7-bit positions for leading and trailing balls
|
||||
// White king and black king, 6 bits for each.
|
||||
for(auto c: Colors)
|
||||
stream.write_n_bit(pos.king_square(c), 6);
|
||||
|
||||
// Write the pieces on the board other than the kings.
|
||||
for (Rank r = RANK_8; r >= RANK_1; --r)
|
||||
{
|
||||
for (File f = FILE_A; f <= FILE_H; ++f)
|
||||
{
|
||||
Piece pc = pos.piece_on(make_square(f, r));
|
||||
if (type_of(pc) == KING)
|
||||
continue;
|
||||
write_board_piece_to_stream(pc);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(someone): Support chess960.
|
||||
stream.write_one_bit(pos.can_castle(WHITE_OO));
|
||||
stream.write_one_bit(pos.can_castle(WHITE_OOO));
|
||||
stream.write_one_bit(pos.can_castle(BLACK_OO));
|
||||
stream.write_one_bit(pos.can_castle(BLACK_OOO));
|
||||
|
||||
if (pos.ep_square() == SQ_NONE) {
|
||||
stream.write_one_bit(0);
|
||||
}
|
||||
else {
|
||||
stream.write_one_bit(1);
|
||||
stream.write_n_bit(static_cast<int>(pos.ep_square()), 6);
|
||||
}
|
||||
|
||||
stream.write_n_bit(pos.state()->rule50, 6);
|
||||
|
||||
stream.write_n_bit(1 + (pos.game_ply()-(pos.side_to_move() == BLACK)) / 2, 8);
|
||||
|
||||
assert(stream.get_cursor() <= 256);
|
||||
}
|
||||
|
||||
// Output the board pieces to stream.
|
||||
void SfenPacker::write_board_piece_to_stream(Piece pc)
|
||||
{
|
||||
// piece type
|
||||
PieceType pr = type_of(pc);
|
||||
auto c = huffman_table[pr];
|
||||
stream.write_n_bit(c.code, c.bits);
|
||||
|
||||
if (pc == NO_PIECE)
|
||||
return;
|
||||
|
||||
// first and second flag
|
||||
stream.write_one_bit(color_of(pc));
|
||||
}
|
||||
|
||||
// Read one board piece from stream
|
||||
Piece SfenPacker::read_board_piece_from_stream()
|
||||
{
|
||||
PieceType pr = NO_PIECE_TYPE;
|
||||
int code = 0, bits = 0;
|
||||
while (true)
|
||||
{
|
||||
code |= stream.read_one_bit() << bits;
|
||||
++bits;
|
||||
|
||||
assert(bits <= 6);
|
||||
|
||||
for (pr = NO_PIECE_TYPE; pr <KING; ++pr)
|
||||
if (huffman_table[pr].code == code
|
||||
&& huffman_table[pr].bits == bits)
|
||||
goto Found;
|
||||
}
|
||||
Found:;
|
||||
if (pr == NO_PIECE_TYPE)
|
||||
return NO_PIECE;
|
||||
|
||||
// first and second flag
|
||||
Color c = (Color)stream.read_one_bit();
|
||||
|
||||
return make_piece(c, pr);
|
||||
}
|
||||
|
||||
int set_from_packed_sfen(Position& pos, const PackedSfen& sfen, StateInfo* si, Thread* th, bool mirror)
|
||||
{
|
||||
SfenPacker packer;
|
||||
auto& stream = packer.stream;
|
||||
|
||||
// TODO: separate streams for writing and reading. Here we actually have to
|
||||
// const_cast which is not safe in the long run.
|
||||
stream.set_data(const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(&sfen)));
|
||||
|
||||
pos.clear();
|
||||
std::memset(si, 0, sizeof(StateInfo));
|
||||
std::fill_n(&pos.pieceList[0][0], sizeof(pos.pieceList) / sizeof(Square), SQ_NONE);
|
||||
pos.st = si;
|
||||
|
||||
// Active color
|
||||
pos.sideToMove = (Color)stream.read_one_bit();
|
||||
|
||||
pos.pieceList[W_KING][0] = SQUARE_NB;
|
||||
pos.pieceList[B_KING][0] = SQUARE_NB;
|
||||
|
||||
// First the position of the ball
|
||||
if (mirror)
|
||||
{
|
||||
for (auto c : Colors)
|
||||
pos.board[flip_file((Square)stream.read_n_bit(6))] = make_piece(c, KING);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto c : Colors)
|
||||
pos.board[stream.read_n_bit(6)] = make_piece(c, KING);
|
||||
}
|
||||
|
||||
// Piece placement
|
||||
for (Rank r = RANK_8; r >= RANK_1; --r)
|
||||
{
|
||||
for (File f = FILE_A; f <= FILE_H; ++f)
|
||||
{
|
||||
auto sq = make_square(f, r);
|
||||
if (mirror) {
|
||||
sq = flip_file(sq);
|
||||
}
|
||||
|
||||
// it seems there are already balls
|
||||
Piece pc;
|
||||
if (type_of(pos.board[sq]) != KING)
|
||||
{
|
||||
assert(pos.board[sq] == NO_PIECE);
|
||||
pc = packer.read_board_piece_from_stream();
|
||||
}
|
||||
else
|
||||
{
|
||||
pc = pos.board[sq];
|
||||
// put_piece() will catch ASSERT unless you remove it all.
|
||||
pos.board[sq] = NO_PIECE;
|
||||
}
|
||||
|
||||
// There may be no pieces, so skip in that case.
|
||||
if (pc == NO_PIECE)
|
||||
continue;
|
||||
|
||||
pos.put_piece(Piece(pc), sq);
|
||||
|
||||
if (stream.get_cursor()> 256)
|
||||
return 1;
|
||||
|
||||
//assert(stream.get_cursor() <= 256);
|
||||
}
|
||||
}
|
||||
|
||||
// Castling availability.
|
||||
// TODO(someone): Support chess960.
|
||||
pos.st->castlingRights = 0;
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(WHITE, SQ_H1); pos.piece_on(rsq) != W_ROOK; --rsq) {}
|
||||
pos.set_castling_right(WHITE, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(WHITE, SQ_A1); pos.piece_on(rsq) != W_ROOK; ++rsq) {}
|
||||
pos.set_castling_right(WHITE, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(BLACK, SQ_H1); pos.piece_on(rsq) != B_ROOK; --rsq) {}
|
||||
pos.set_castling_right(BLACK, rsq);
|
||||
}
|
||||
if (stream.read_one_bit()) {
|
||||
Square rsq;
|
||||
for (rsq = relative_square(BLACK, SQ_A1); pos.piece_on(rsq) != B_ROOK; ++rsq) {}
|
||||
pos.set_castling_right(BLACK, rsq);
|
||||
}
|
||||
|
||||
// En passant square. Ignore if no pawn capture is possible
|
||||
if (stream.read_one_bit()) {
|
||||
Square ep_square = static_cast<Square>(stream.read_n_bit(6));
|
||||
if (mirror) {
|
||||
ep_square = flip_file(ep_square);
|
||||
}
|
||||
pos.st->epSquare = ep_square;
|
||||
|
||||
if (!(pos.attackers_to(pos.st->epSquare) & pos.pieces(pos.sideToMove, PAWN))
|
||||
|| !(pos.pieces(~pos.sideToMove, PAWN) & (pos.st->epSquare + pawn_push(~pos.sideToMove))))
|
||||
pos.st->epSquare = SQ_NONE;
|
||||
}
|
||||
else {
|
||||
pos.st->epSquare = SQ_NONE;
|
||||
}
|
||||
|
||||
// Halfmove clock
|
||||
pos.st->rule50 = static_cast<Square>(stream.read_n_bit(6));
|
||||
|
||||
// Fullmove number
|
||||
pos.gamePly = static_cast<Square>(stream.read_n_bit(8));
|
||||
|
||||
// Convert from fullmove starting from 1 to gamePly starting from 0,
|
||||
// handle also common incorrect FEN with fullmove = 0.
|
||||
pos.gamePly = std::max(2 * (pos.gamePly - 1), 0) + (pos.sideToMove == BLACK);
|
||||
|
||||
assert(stream.get_cursor() <= 256);
|
||||
|
||||
pos.chess960 = false;
|
||||
pos.thisThread = th;
|
||||
pos.set_state(pos.st);
|
||||
|
||||
assert(pos.pos_is_ok());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
PackedSfen sfen_pack(Position& pos)
|
||||
{
|
||||
PackedSfen sfen;
|
||||
|
||||
SfenPacker sp;
|
||||
sp.data = (uint8_t*)&sfen;
|
||||
sp.pack(pos);
|
||||
|
||||
return sfen;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
#ifndef _SFEN_PACKER_H_
|
||||
#define _SFEN_PACKER_H_
|
||||
|
||||
#include "types.h"
|
||||
|
||||
#include "learn/packed_sfen.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
class Position;
|
||||
struct StateInfo;
|
||||
class Thread;
|
||||
|
||||
namespace Learner {
|
||||
|
||||
int set_from_packed_sfen(Position& pos, const PackedSfen& sfen, StateInfo* si, Thread* th, bool mirror);
|
||||
PackedSfen sfen_pack(Position& pos);
|
||||
}
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user