Indentation consistency in learn folder

This commit is contained in:
Tomasz Sobczyk
2020-10-14 19:53:41 +02:00
committed by nodchip
parent 880d23af1c
commit 904adb9a32
6 changed files with 607 additions and 612 deletions
+1 -2
View File
@@ -3,7 +3,6 @@
#include "packed_sfen.h" #include "packed_sfen.h"
#include "multi_think.h" #include "multi_think.h"
#include "sfen_stream.h" #include "sfen_stream.h"
#include "../syzygy/tbprobe.h"
#include "misc.h" #include "misc.h"
#include "position.h" #include "position.h"
@@ -73,7 +72,7 @@ namespace Learner
file_worker_thread.join(); file_worker_thread.join();
output_file_stream.reset(); output_file_stream.reset();
#if defined(_DEBUG) #if !defined(NDEBUG)
{ {
// All buffers should be empty since file_worker_thread // All buffers should be empty since file_worker_thread
// should have written everything before exiting. // should have written everything before exiting.
+89 -89
View File
@@ -11,122 +11,122 @@
namespace HalfFloat namespace HalfFloat
{ {
// IEEE 754 float 32 format is : // IEEE 754 float 32 format is :
// sign(1bit) + exponent(8bits) + fraction(23bits) = 32bits // sign(1bit) + exponent(8bits) + fraction(23bits) = 32bits
// //
// Our float16 format is : // Our float16 format is :
// sign(1bit) + exponent(5bits) + fraction(10bits) = 16bits // sign(1bit) + exponent(5bits) + fraction(10bits) = 16bits
union float32_converter union float32_converter
{ {
int32_t n; int32_t n;
float f; float f;
}; };
// 16-bit float // 16-bit float
struct float16 struct float16
{ {
// --- constructors // --- constructors
float16() {} float16() {}
float16(int16_t n) { from_float((float)n); } float16(int16_t n) { from_float((float)n); }
float16(int32_t n) { from_float((float)n); } float16(int32_t n) { from_float((float)n); }
float16(float n) { from_float(n); } float16(float n) { from_float(n); }
float16(double n) { from_float((float)n); } float16(double n) { from_float((float)n); }
// build from a float // build from a float
void from_float(float f) { *this = to_float16(f); } void from_float(float f) { *this = to_float16(f); }
// --- implicit converters // --- implicit converters
operator int32_t() const { return (int32_t)to_float(*this); } operator int32_t() const { return (int32_t)to_float(*this); }
operator float() const { return to_float(*this); } operator float() const { return to_float(*this); }
operator double() const { return double(to_float(*this)); } operator double() const { return double(to_float(*this)); }
// --- operators // --- operators
float16 operator += (float16 rhs) { from_float(to_float(*this) + to_float(rhs)); return *this; } float16 operator += (float16 rhs) { from_float(to_float(*this) + to_float(rhs)); return *this; }
float16 operator -= (float16 rhs) { from_float(to_float(*this) - to_float(rhs)); return *this; } float16 operator -= (float16 rhs) { from_float(to_float(*this) - to_float(rhs)); return *this; }
float16 operator *= (float16 rhs) { from_float(to_float(*this) * to_float(rhs)); return *this; } float16 operator *= (float16 rhs) { from_float(to_float(*this) * to_float(rhs)); return *this; }
float16 operator /= (float16 rhs) { from_float(to_float(*this) / to_float(rhs)); return *this; } float16 operator /= (float16 rhs) { from_float(to_float(*this) / to_float(rhs)); return *this; }
float16 operator + (float16 rhs) const { return float16(*this) += rhs; } float16 operator + (float16 rhs) const { return float16(*this) += rhs; }
float16 operator - (float16 rhs) const { return float16(*this) -= rhs; } float16 operator - (float16 rhs) const { return float16(*this) -= rhs; }
float16 operator * (float16 rhs) const { return float16(*this) *= rhs; } float16 operator * (float16 rhs) const { return float16(*this) *= rhs; }
float16 operator / (float16 rhs) const { return float16(*this) /= rhs; } float16 operator / (float16 rhs) const { return float16(*this) /= rhs; }
float16 operator - () const { return float16(-to_float(*this)); } float16 operator - () const { return float16(-to_float(*this)); }
bool operator == (float16 rhs) const { return this->v_ == rhs.v_; } bool operator == (float16 rhs) const { return this->v_ == rhs.v_; }
bool operator != (float16 rhs) const { return !(*this == rhs); } bool operator != (float16 rhs) const { return !(*this == rhs); }
static void UnitTest() { unit_test(); } static void UnitTest() { unit_test(); }
private: private:
// --- entity // --- entity
uint16_t v_; uint16_t v_;
// --- conversion between float and float16 // --- conversion between float and float16
static float16 to_float16(float f) static float16 to_float16(float f)
{ {
float32_converter c; float32_converter c;
c.f = f; c.f = f;
u32 n = c.n; u32 n = c.n;
// The sign bit is MSB in common. // The sign bit is MSB in common.
uint16_t sign_bit = (n >> 16) & 0x8000; uint16_t sign_bit = (n >> 16) & 0x8000;
// The exponent of IEEE 754's float 32 is biased +127 , so we change this bias into +15 and limited to 5-bit. // The exponent of IEEE 754's float 32 is biased +127 , so we change this bias into +15 and limited to 5-bit.
uint16_t exponent = (((n >> 23) - 127 + 15) & 0x1f) << 10; uint16_t exponent = (((n >> 23) - 127 + 15) & 0x1f) << 10;
// The fraction is limited to 10-bit. // The fraction is limited to 10-bit.
uint16_t fraction = (n >> (23-10)) & 0x3ff; uint16_t fraction = (n >> (23-10)) & 0x3ff;
float16 f_; float16 f_;
f_.v_ = sign_bit | exponent | fraction; f_.v_ = sign_bit | exponent | fraction;
return f_; return f_;
} }
static float to_float(float16 v) static float to_float(float16 v)
{ {
u32 sign_bit = (v.v_ & 0x8000) << 16; u32 sign_bit = (v.v_ & 0x8000) << 16;
u32 exponent = ((((v.v_ >> 10) & 0x1f) - 15 + 127) & 0xff) << 23; u32 exponent = ((((v.v_ >> 10) & 0x1f) - 15 + 127) & 0xff) << 23;
u32 fraction = (v.v_ & 0x3ff) << (23 - 10); u32 fraction = (v.v_ & 0x3ff) << (23 - 10);
float32_converter c; float32_converter c;
c.n = sign_bit | exponent | fraction; c.n = sign_bit | exponent | fraction;
return c.f; return c.f;
} }
// It is not a unit test, but I confirmed that it can be calculated. I'll fix the code later (maybe). // It is not a unit test, but I confirmed that it can be calculated. I'll fix the code later (maybe).
static void unit_test() static void unit_test()
{ {
float16 a, b, c, d; float16 a, b, c, d;
a = 1; a = 1;
std::cout << (float)a << std::endl; std::cout << (float)a << std::endl;
b = -118.625; b = -118.625;
std::cout << (float)b << std::endl; std::cout << (float)b << std::endl;
c = 2.5; c = 2.5;
std::cout << (float)c << std::endl; std::cout << (float)c << std::endl;
d = a + c; d = a + c;
std::cout << (float)d << std::endl; std::cout << (float)d << std::endl;
c *= 1.5; c *= 1.5;
std::cout << (float)c << std::endl; std::cout << (float)c << std::endl;
b /= 3; b /= 3;
std::cout << (float)b << std::endl; std::cout << (float)b << std::endl;
float f1 = 1.5; float f1 = 1.5;
a += f1; a += f1;
std::cout << (float)a << std::endl; std::cout << (float)a << std::endl;
a += f1 * (float)a; a += f1 * (float)a;
std::cout << (float)a << std::endl; std::cout << (float)a << std::endl;
} }
}; };
} }
+8 -8
View File
@@ -1066,14 +1066,14 @@ namespace Learner
pos.do_move((Move)ps.move, state[ply++]); pos.do_move((Move)ps.move, state[ply++]);
// There is a possibility that all the pieces are blocked and stuck. // There is a possibility that all the pieces are blocked and stuck.
// Also, the declaration win phase is excluded from // Also, the declaration win phase is excluded from
// learning because you cannot go to leaf with PV moves. // learning because you cannot go to leaf with PV moves.
// (shouldn't write out such teacher aspect itself, // (shouldn't write out such teacher aspect itself,
// but may have written it out with an old generation routine) // but may have written it out with an old generation routine)
// Skip the position if there are no legal moves (=checkmated or stalemate). // Skip the position if there are no legal moves (=checkmated or stalemate).
if (MoveList<LEGAL>(pos).size() == 0) if (MoveList<LEGAL>(pos).size() == 0)
goto RETRY_READ; goto RETRY_READ;
// Evaluation value of shallow search (qsearch) // Evaluation value of shallow search (qsearch)
const auto [_, pv] = Search::qsearch(pos); const auto [_, pv] = Search::qsearch(pos);
+72 -72
View File
@@ -1,103 +1,103 @@
#include "multi_think.h" #include "multi_think.h"
#include "nnue/evaluate_nnue.h"
#include "tt.h" #include "tt.h"
#include "uci.h" #include "uci.h"
#include "types.h" #include "types.h"
#include "search.h" #include "search.h"
#include "nnue/evaluate_nnue.h"
#include <thread> #include <thread>
void MultiThink::go_think() void MultiThink::go_think()
{ {
// Read evaluation function, etc. // Read evaluation function, etc.
// In the case of the learn command, the value of the evaluation function may be corrected after reading the evaluation function, so // In the case of the learn command, the value of the evaluation function may be corrected after reading the evaluation function, so
// Skip memory corruption check. // Skip memory corruption check.
Eval::NNUE::init(); Eval::NNUE::init();
// Call the derived class's init(). // Call the derived class's init().
init(); init();
// The loop upper limit is set with set_loop_max(). // The loop upper limit is set with set_loop_max().
loop_count = 0; loop_count = 0;
done_count = 0; done_count = 0;
// Create threads as many as Options["Threads"] and start thinking. // Create threads as many as Options["Threads"] and start thinking.
std::vector<std::thread> threads; std::vector<std::thread> threads;
auto thread_num = (size_t)Options["Threads"]; auto thread_num = (size_t)Options["Threads"];
// Secure end flag of worker thread // Secure end flag of worker thread
threads_finished=0; threads_finished=0;
// start worker thread // start worker thread
for (size_t i = 0; i < thread_num; ++i) for (size_t i = 0; i < thread_num; ++i)
{ {
threads.push_back(std::thread([i, this] threads.push_back(std::thread([i, this]
{ {
// exhaust all processor threads. // exhaust all processor threads.
WinProcGroup::bindThisThread(i); WinProcGroup::bindThisThread(i);
// execute the overridden process // execute the overridden process
this->thread_worker(i); this->thread_worker(i);
// Set the end flag because the thread has ended // Set the end flag because the thread has ended
this->threads_finished++; this->threads_finished++;
})); }));
} }
// wait for all threads to finish // wait for all threads to finish
// for (auto& th :threads) // for (auto& th :threads)
// th.join(); // th.join();
// If you write like, the thread will rush here while it is still working, // If you write like, the thread will rush here while it is still working,
// During that time, callback_func() cannot be called and you cannot save. // During that time, callback_func() cannot be called and you cannot save.
// Therefore, you need to check the end flag yourself. // Therefore, you need to check the end flag yourself.
// function to determine if all threads have finished // function to determine if all threads have finished
auto threads_done = [&]() auto threads_done = [&]()
{ {
return threads_finished == thread_num; return threads_finished == thread_num;
}; };
// Call back if the callback function is set. // Call back if the callback function is set.
auto do_a_callback = [&]() auto do_a_callback = [&]()
{ {
if (callback_func) if (callback_func)
callback_func(); callback_func();
}; };
for (uint64_t i = 0 ; ; ) for (uint64_t i = 0 ; ; )
{ {
// If all threads have finished, exit the loop. // If all threads have finished, exit the loop.
if (threads_done()) if (threads_done())
break; break;
sleep(1000); sleep(1000);
// callback_func() is called every callback_seconds. // callback_func() is called every callback_seconds.
if (++i == callback_seconds) if (++i == callback_seconds)
{ {
do_a_callback(); do_a_callback();
// Since I am returning from ↑, I reset the counter, so // Since I am returning from ↑, I reset the counter, so
// no matter how long it takes to save() etc. in do_a_callback() // no matter how long it takes to save() etc. in do_a_callback()
// The next call will take a certain amount of time. // The next call will take a certain amount of time.
i = 0; i = 0;
} }
} }
// Last save. // Last save.
std::cout << std::endl << "finalize.."; std::cout << std::endl << "finalize..";
// do_a_callback(); // do_a_callback();
// → It should be saved by the caller, so I feel that it is not necessary here. // → It should be saved by the caller, so I feel that it is not necessary here.
// It is possible that the exit code of the thread is running but the exit code of the thread is running, so // It is possible that the exit code of the thread is running but the exit code of the thread is running, so
// We need to wait for the end with join(). // We need to wait for the end with join().
for (auto& th : threads) for (auto& th : threads)
th.join(); th.join();
// The file writing thread etc. are still running only when all threads are finished // The file writing thread etc. are still running only when all threads are finished
// Since the work itself may not have completed, output only that all threads have finished. // Since the work itself may not have completed, output only that all threads have finished.
std::cout << "all threads are joined." << std::endl; std::cout << "all threads are joined." << std::endl;
} }
+94 -94
View File
@@ -19,84 +19,84 @@
// Derive and use this class. // Derive and use this class.
struct MultiThink struct MultiThink
{ {
static constexpr std::uint64_t LOOP_COUNT_FINISHED = std::numeric_limits<std::uint64_t>::max(); static constexpr std::uint64_t LOOP_COUNT_FINISHED = std::numeric_limits<std::uint64_t>::max();
MultiThink() : prng{}, loop_count(0) { } MultiThink() : prng{}, loop_count(0) { }
MultiThink(std::uint64_t seed) : prng(seed), loop_count(0) { } MultiThink(std::uint64_t seed) : prng(seed), loop_count(0) { }
MultiThink(const std::string& seed) : prng(seed), loop_count(0) { } MultiThink(const std::string& seed) : prng(seed), loop_count(0) { }
// Call this function from the master thread, each thread will think, // Call this function from the master thread, each thread will think,
// Return control when the thought ending condition is satisfied. // Return control when the thought ending condition is satisfied.
// Do something else. // Do something else.
// ・It is safe for each thread to call Learner::search(),qsearch() // ・It is safe for each thread to call Learner::search(),qsearch()
// Separates the substitution table for each thread. (It will be restored after the end.) // Separates the substitution table for each thread. (It will be restored after the end.)
// ・Book is not thread safe when in on the fly mode, so temporarily change this mode. // ・Book is not thread safe when in on the fly mode, so temporarily change this mode.
// Turn it off. // Turn it off.
// [Requirements] // [Requirements]
// 1) Override thread_worker() // 1) Override thread_worker()
// 2) Set the loop count with set_loop_max() // 2) Set the loop count with set_loop_max()
// 3) set a function to be called back periodically (if necessary) // 3) set a function to be called back periodically (if necessary)
// callback_func and callback_interval // callback_func and callback_interval
void go_think(); void go_think();
// If there is something you want to initialize on the derived class side, override this, // If there is something you want to initialize on the derived class side, override this,
// Called when initialization is completed with go_think(). // Called when initialization is completed with go_think().
// It is better to read the fixed trace at that timing. // It is better to read the fixed trace at that timing.
virtual void init() {} virtual void init() {}
// A thread worker that is called by creating a thread when you go_think() // A thread worker that is called by creating a thread when you go_think()
// Override and use this. // Override and use this.
virtual void thread_worker(size_t thread_id) = 0; virtual void thread_worker(size_t thread_id) = 0;
// Called back every callback_seconds [seconds] when go_think(). // Called back every callback_seconds [seconds] when go_think().
std::function<void()> callback_func; std::function<void()> callback_func;
uint64_t callback_seconds = 600; uint64_t callback_seconds = 600;
// Set the number of times worker processes (calls Search::think()). // Set the number of times worker processes (calls Search::think()).
void set_loop_max(uint64_t loop_max_) { loop_max = loop_max_; } void set_loop_max(uint64_t loop_max_) { loop_max = loop_max_; }
// Get the value set by set_loop_max(). // Get the value set by set_loop_max().
uint64_t get_loop_max() const { return loop_max; } uint64_t get_loop_max() const { return loop_max; }
// [ASYNC] Take the value of the loop counter and add the loop counter after taking it out. // [ASYNC] Take the value of the loop counter and add the loop counter after taking it out.
// If the loop counter has reached loop_max, return UINT64_MAX. // If the loop counter has reached loop_max, return UINT64_MAX.
// If you want to generate a phase, you must call this function at the time of generating the phase, // If you want to generate a phase, you must call this function at the time of generating the phase,
// Please note that the number of generated phases and the value of the counter will not match. // Please note that the number of generated phases and the value of the counter will not match.
uint64_t get_next_loop_count() { uint64_t get_next_loop_count() {
std::unique_lock<std::mutex> lk(loop_mutex); std::unique_lock<std::mutex> lk(loop_mutex);
if (loop_count >= loop_max) if (loop_count >= loop_max)
return LOOP_COUNT_FINISHED; return LOOP_COUNT_FINISHED;
return loop_count++; return loop_count++;
} }
// [ASYNC] For returning the processed number. Each time it is called, it returns a counter that is incremented. // [ASYNC] For returning the processed number. Each time it is called, it returns a counter that is incremented.
uint64_t get_done_count() { uint64_t get_done_count() {
std::unique_lock<std::mutex> lk(loop_mutex); std::unique_lock<std::mutex> lk(loop_mutex);
return ++done_count; return ++done_count;
} }
// Mutex when worker thread accesses I/O // Mutex when worker thread accesses I/O
std::mutex io_mutex; std::mutex io_mutex;
protected: protected:
// Random number generator body // Random number generator body
AsyncPRNG prng; AsyncPRNG prng;
private: private:
// number of times worker processes (calls Search::think()) // number of times worker processes (calls Search::think())
std::atomic<uint64_t> loop_max; std::atomic<uint64_t> loop_max;
// number of times the worker has processed (calls Search::think()) // number of times the worker has processed (calls Search::think())
std::atomic<uint64_t> loop_count; std::atomic<uint64_t> loop_count;
// To return the number of times it has been processed. // To return the number of times it has been processed.
std::atomic<uint64_t> done_count; std::atomic<uint64_t> done_count;
// Mutex when changing the variables in ↑ // Mutex when changing the variables in ↑
std::mutex loop_mutex; std::mutex loop_mutex;
// Thread end flag. // Thread end flag.
std::atomic<uint64_t> threads_finished; std::atomic<uint64_t> threads_finished;
}; };
// Mechanism to process task during idle time. // Mechanism to process task during idle time.
@@ -105,48 +105,48 @@ private:
// Convenient to use when you want to write MultiThink thread worker in master-slave method. // Convenient to use when you want to write MultiThink thread worker in master-slave method.
struct TaskDispatcher struct TaskDispatcher
{ {
typedef std::function<void(size_t /* thread_id */)> Task; typedef std::function<void(size_t /* thread_id */)> Task;
// slave calls this function during idle. // slave calls this function during idle.
void on_idle(size_t thread_id) void on_idle(size_t thread_id)
{ {
Task task; Task task;
while ((task = get_task_async()) != nullptr) while ((task = get_task_async()) != nullptr)
task(thread_id); task(thread_id);
sleep(1); sleep(1);
} }
// Stack [ASYNC] task. // Stack [ASYNC] task.
void push_task_async(Task task) void push_task_async(Task task)
{ {
std::unique_lock<std::mutex> lk(task_mutex); std::unique_lock<std::mutex> lk(task_mutex);
tasks.push_back(task); tasks.push_back(task);
} }
// Allocate size array elements for task in advance. // Allocate size array elements for task in advance.
void task_reserve(size_t size) void task_reserve(size_t size)
{ {
tasks.reserve(size); tasks.reserve(size);
} }
protected: protected:
// set of tasks // set of tasks
std::vector<Task> tasks; std::vector<Task> tasks;
// Take out one [ASYNC] task. Called from on_idle(). // Take out one [ASYNC] task. Called from on_idle().
Task get_task_async() Task get_task_async()
{ {
std::unique_lock<std::mutex> lk(task_mutex); std::unique_lock<std::mutex> lk(task_mutex);
if (tasks.size() == 0) if (tasks.size() == 0)
return nullptr; return nullptr;
Task task = *tasks.rbegin(); Task task = *tasks.rbegin();
tasks.pop_back(); tasks.pop_back();
return task; return task;
} }
// a mutex for accessing tasks // a mutex for accessing tasks
std::mutex task_mutex; std::mutex task_mutex;
}; };
#endif #endif
+343 -347
View File
@@ -13,378 +13,374 @@ using namespace std;
namespace Learner { namespace Learner {
// Class that handles bitstream // Class that handles bitstream
// useful when doing aspect encoding // useful when doing aspect encoding
struct BitStream struct BitStream
{
// Set the memory to store the data in advance.
// Assume that memory is cleared to 0.
void set_data(std::uint8_t* data_) { data = data_; reset(); }
// Get the pointer passed in set_data().
uint8_t* get_data() const { return data; }
// Get the cursor.
int get_cursor() const { return bit_cursor; }
// reset the cursor
void reset() { bit_cursor = 0; }
// Write 1bit to the stream.
// If b is non-zero, write out 1. If 0, write 0.
void write_one_bit(int b)
{ {
if (b) // Set the memory to store the data in advance.
data[bit_cursor / 8] |= 1 << (bit_cursor & 7); // Assume that memory is cleared to 0.
void set_data(std::uint8_t* data_) { data = data_; reset(); }
++bit_cursor; // Get the pointer passed in set_data().
} uint8_t* get_data() const { return data; }
// Get 1 bit from the stream. // Get the cursor.
int read_one_bit() int get_cursor() const { return bit_cursor; }
// reset the cursor
void reset() { bit_cursor = 0; }
// Write 1bit to the stream.
// If b is non-zero, write out 1. If 0, write 0.
void write_one_bit(int b)
{
if (b)
data[bit_cursor / 8] |= 1 << (bit_cursor & 7);
++bit_cursor;
}
// Get 1 bit from the stream.
int read_one_bit()
{
int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1;
++bit_cursor;
return b;
}
// write n bits of data
// Data shall be written out from the lower order of d.
void write_n_bit(int d, int n)
{
for (int i = 0; i <n; ++i)
write_one_bit(d & (1 << i));
}
// read n bits of data
// Reverse conversion of write_n_bit().
int read_n_bit(int n)
{
int result = 0;
for (int i = 0; i < n; ++i)
result |= read_one_bit() ? (1 << i) : 0;
return result;
}
private:
// Next bit position to read/write.
int bit_cursor;
// data entity
std::uint8_t* data;
};
// Class for compressing/decompressing sfen
// sfen can be packed to 256bit (32bytes) by Huffman coding.
// This is proven by mini. The above is Huffman coding.
//
// Internal format = 1-bit turn + 7-bit king position *2 + piece on board (Huffman coding) + hand piece (Huffman coding)
// Side to move (White = 0, Black = 1) (1bit)
// White King Position (6 bits)
// Black King Position (6 bits)
// Huffman Encoding of the board
// Castling availability (1 bit x 4)
// En passant square (1 or 1 + 6 bits)
// Rule 50 (6 bits)
// Game play (8 bits)
//
// TODO(someone): Rename SFEN to FEN.
//
struct SfenPacker
{ {
int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1; void pack(const Position& pos);
++bit_cursor;
return b; // sfen packed by pack() (256bit = 32bytes)
} // Or sfen to decode with unpack()
uint8_t *data; // uint8_t[32];
// write n bits of data BitStream stream;
// Data shall be written out from the lower order of d.
void write_n_bit(int d, int n) // Output the board pieces to stream.
void write_board_piece_to_stream(Piece pc);
// Read one board piece from stream
Piece read_board_piece_from_stream();
};
// Huffman coding
// * is simplified from mini encoding to make conversion easier.
//
// Huffman Encoding
//
// Empty xxxxxxx0
// Pawn xxxxx001 + 1 bit (Color)
// Knight xxxxx011 + 1 bit (Color)
// Bishop xxxxx101 + 1 bit (Color)
// Rook xxxxx111 + 1 bit (Color)
// Queen xxxx1001 + 1 bit (Color)
//
// Worst case:
// - 32 empty squares 32 bits
// - 30 pieces 150 bits
// - 2 kings 12 bits
// - castling rights 4 bits
// - ep square 7 bits
// - rule50 7 bits
// - game ply 16 bits
// - TOTAL 228 bits < 256 bits
struct HuffmanedPiece
{ {
for (int i = 0; i <n; ++i) int code; // how it will be coded
write_one_bit(d & (1 << i)); int bits; // How many bits do you have
} };
// read n bits of data constexpr HuffmanedPiece huffman_table[] =
// Reverse conversion of write_n_bit().
int read_n_bit(int n)
{ {
int result = 0; {0b0000,1}, // NO_PIECE
for (int i = 0; i < n; ++i) {0b0001,4}, // PAWN
result |= read_one_bit() ? (1 << i) : 0; {0b0011,4}, // KNIGHT
{0b0101,4}, // BISHOP
{0b0111,4}, // ROOK
{0b1001,4}, // QUEEN
};
return result; // Pack sfen and store in data[32].
void SfenPacker::pack(const Position& pos)
{
memset(data, 0, 32 /* 256bit */);
stream.set_data(data);
// turn
// Side to move.
stream.write_one_bit((int)(pos.side_to_move()));
// 7-bit positions for leading and trailing balls
// White king and black king, 6 bits for each.
for(auto c: Colors)
stream.write_n_bit(pos.king_square(c), 6);
// Write the pieces on the board other than the kings.
for (Rank r = RANK_8; r >= RANK_1; --r)
{
for (File f = FILE_A; f <= FILE_H; ++f)
{
Piece pc = pos.piece_on(make_square(f, r));
if (type_of(pc) == KING)
continue;
write_board_piece_to_stream(pc);
}
}
// TODO(someone): Support chess960.
stream.write_one_bit(pos.can_castle(WHITE_OO));
stream.write_one_bit(pos.can_castle(WHITE_OOO));
stream.write_one_bit(pos.can_castle(BLACK_OO));
stream.write_one_bit(pos.can_castle(BLACK_OOO));
if (pos.ep_square() == SQ_NONE) {
stream.write_one_bit(0);
}
else {
stream.write_one_bit(1);
stream.write_n_bit(static_cast<int>(pos.ep_square()), 6);
}
stream.write_n_bit(pos.state()->rule50, 6);
const int fm = 1 + (pos.game_ply()-(pos.side_to_move() == BLACK)) / 2;
stream.write_n_bit(fm, 8);
// Write high bits of half move. This is a fix for the
// limited range of half move counter.
// This is backwards compatibile.
stream.write_n_bit(fm >> 8, 8);
// Write the highest bit of rule50 at the end. This is a backwards
// compatibile fix for rule50 having only 6 bits stored.
// This bit is just ignored by the old parsers.
stream.write_n_bit(pos.state()->rule50 >> 6, 1);
assert(stream.get_cursor() <= 256);
} }
private:
// Next bit position to read/write.
int bit_cursor;
// data entity
std::uint8_t* data;
};
// Class for compressing/decompressing sfen
// sfen can be packed to 256bit (32bytes) by Huffman coding.
// This is proven by mini. The above is Huffman coding.
//
// Internal format = 1-bit turn + 7-bit king position *2 + piece on board (Huffman coding) + hand piece (Huffman coding)
// Side to move (White = 0, Black = 1) (1bit)
// White King Position (6 bits)
// Black King Position (6 bits)
// Huffman Encoding of the board
// Castling availability (1 bit x 4)
// En passant square (1 or 1 + 6 bits)
// Rule 50 (6 bits)
// Game play (8 bits)
//
// TODO(someone): Rename SFEN to FEN.
//
struct SfenPacker
{
void pack(const Position& pos);
// sfen packed by pack() (256bit = 32bytes)
// Or sfen to decode with unpack()
uint8_t *data; // uint8_t[32];
BitStream stream;
// Output the board pieces to stream. // Output the board pieces to stream.
void write_board_piece_to_stream(Piece pc); void SfenPacker::write_board_piece_to_stream(Piece pc)
{
// piece type
PieceType pr = type_of(pc);
auto c = huffman_table[pr];
stream.write_n_bit(c.code, c.bits);
if (pc == NO_PIECE)
return;
// first and second flag
stream.write_one_bit(color_of(pc));
}
// Read one board piece from stream // Read one board piece from stream
Piece read_board_piece_from_stream(); Piece SfenPacker::read_board_piece_from_stream()
};
// Huffman coding
// * is simplified from mini encoding to make conversion easier.
//
// Huffman Encoding
//
// Empty xxxxxxx0
// Pawn xxxxx001 + 1 bit (Color)
// Knight xxxxx011 + 1 bit (Color)
// Bishop xxxxx101 + 1 bit (Color)
// Rook xxxxx111 + 1 bit (Color)
// Queen xxxx1001 + 1 bit (Color)
//
// Worst case:
// - 32 empty squares 32 bits
// - 30 pieces 150 bits
// - 2 kings 12 bits
// - castling rights 4 bits
// - ep square 7 bits
// - rule50 7 bits
// - game ply 16 bits
// - TOTAL 228 bits < 256 bits
struct HuffmanedPiece
{
int code; // how it will be coded
int bits; // How many bits do you have
};
constexpr HuffmanedPiece huffman_table[] =
{
{0b0000,1}, // NO_PIECE
{0b0001,4}, // PAWN
{0b0011,4}, // KNIGHT
{0b0101,4}, // BISHOP
{0b0111,4}, // ROOK
{0b1001,4}, // QUEEN
};
// Pack sfen and store in data[32].
void SfenPacker::pack(const Position& pos)
{
// cout << pos;
memset(data, 0, 32 /* 256bit */);
stream.set_data(data);
// turn
// Side to move.
stream.write_one_bit((int)(pos.side_to_move()));
// 7-bit positions for leading and trailing balls
// White king and black king, 6 bits for each.
for(auto c: Colors)
stream.write_n_bit(pos.king_square(c), 6);
// Write the pieces on the board other than the kings.
for (Rank r = RANK_8; r >= RANK_1; --r)
{ {
for (File f = FILE_A; f <= FILE_H; ++f) PieceType pr = NO_PIECE_TYPE;
{ int code = 0, bits = 0;
Piece pc = pos.piece_on(make_square(f, r)); while (true)
if (type_of(pc) == KING)
continue;
write_board_piece_to_stream(pc);
}
}
// TODO(someone): Support chess960.
stream.write_one_bit(pos.can_castle(WHITE_OO));
stream.write_one_bit(pos.can_castle(WHITE_OOO));
stream.write_one_bit(pos.can_castle(BLACK_OO));
stream.write_one_bit(pos.can_castle(BLACK_OOO));
if (pos.ep_square() == SQ_NONE) {
stream.write_one_bit(0);
}
else {
stream.write_one_bit(1);
stream.write_n_bit(static_cast<int>(pos.ep_square()), 6);
}
stream.write_n_bit(pos.state()->rule50, 6);
const int fm = 1 + (pos.game_ply()-(pos.side_to_move() == BLACK)) / 2;
stream.write_n_bit(fm, 8);
// Write high bits of half move. This is a fix for the
// limited range of half move counter.
// This is backwards compatibile.
stream.write_n_bit(fm >> 8, 8);
// Write the highest bit of rule50 at the end. This is a backwards
// compatibile fix for rule50 having only 6 bits stored.
// This bit is just ignored by the old parsers.
stream.write_n_bit(pos.state()->rule50 >> 6, 1);
assert(stream.get_cursor() <= 256);
}
// Output the board pieces to stream.
void SfenPacker::write_board_piece_to_stream(Piece pc)
{
// piece type
PieceType pr = type_of(pc);
auto c = huffman_table[pr];
stream.write_n_bit(c.code, c.bits);
if (pc == NO_PIECE)
return;
// first and second flag
stream.write_one_bit(color_of(pc));
}
// Read one board piece from stream
Piece SfenPacker::read_board_piece_from_stream()
{
PieceType pr = NO_PIECE_TYPE;
int code = 0, bits = 0;
while (true)
{
code |= stream.read_one_bit() << bits;
++bits;
assert(bits <= 6);
for (pr = NO_PIECE_TYPE; pr <KING; ++pr)
if (huffman_table[pr].code == code
&& huffman_table[pr].bits == bits)
goto Found;
}
Found:;
if (pr == NO_PIECE_TYPE)
return NO_PIECE;
// first and second flag
Color c = (Color)stream.read_one_bit();
return make_piece(c, pr);
}
int set_from_packed_sfen(Position& pos, const PackedSfen& sfen, StateInfo* si, Thread* th)
{
SfenPacker packer;
auto& stream = packer.stream;
// TODO: separate streams for writing and reading. Here we actually have to
// const_cast which is not safe in the long run.
stream.set_data(const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(&sfen)));
pos.clear();
std::memset(si, 0, sizeof(StateInfo));
std::fill_n(&pos.pieceList[0][0], sizeof(pos.pieceList) / sizeof(Square), SQ_NONE);
pos.st = si;
// Active color
pos.sideToMove = (Color)stream.read_one_bit();
pos.pieceList[W_KING][0] = SQUARE_NB;
pos.pieceList[B_KING][0] = SQUARE_NB;
// First the position of the ball
for (auto c : Colors)
pos.board[stream.read_n_bit(6)] = make_piece(c, KING);
// Piece placement
for (Rank r = RANK_8; r >= RANK_1; --r)
{
for (File f = FILE_A; f <= FILE_H; ++f)
{
auto sq = make_square(f, r);
// it seems there are already balls
Piece pc;
if (type_of(pos.board[sq]) != KING)
{ {
assert(pos.board[sq] == NO_PIECE); code |= stream.read_one_bit() << bits;
pc = packer.read_board_piece_from_stream(); ++bits;
assert(bits <= 6);
for (pr = NO_PIECE_TYPE; pr <KING; ++pr)
if (huffman_table[pr].code == code
&& huffman_table[pr].bits == bits)
goto Found;
} }
else Found:;
if (pr == NO_PIECE_TYPE)
return NO_PIECE;
// first and second flag
Color c = (Color)stream.read_one_bit();
return make_piece(c, pr);
}
int set_from_packed_sfen(Position& pos, const PackedSfen& sfen, StateInfo* si, Thread* th)
{
SfenPacker packer;
auto& stream = packer.stream;
// TODO: separate streams for writing and reading. Here we actually have to
// const_cast which is not safe in the long run.
stream.set_data(const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(&sfen)));
pos.clear();
std::memset(si, 0, sizeof(StateInfo));
std::fill_n(&pos.pieceList[0][0], sizeof(pos.pieceList) / sizeof(Square), SQ_NONE);
pos.st = si;
// Active color
pos.sideToMove = (Color)stream.read_one_bit();
pos.pieceList[W_KING][0] = SQUARE_NB;
pos.pieceList[B_KING][0] = SQUARE_NB;
// First the position of the ball
for (auto c : Colors)
pos.board[stream.read_n_bit(6)] = make_piece(c, KING);
// Piece placement
for (Rank r = RANK_8; r >= RANK_1; --r)
{ {
pc = pos.board[sq]; for (File f = FILE_A; f <= FILE_H; ++f)
// put_piece() will catch ASSERT unless you remove it all. {
pos.board[sq] = NO_PIECE; auto sq = make_square(f, r);
// it seems there are already balls
Piece pc;
if (type_of(pos.board[sq]) != KING)
{
assert(pos.board[sq] == NO_PIECE);
pc = packer.read_board_piece_from_stream();
}
else
{
pc = pos.board[sq];
// put_piece() will catch ASSERT unless you remove it all.
pos.board[sq] = NO_PIECE;
}
// There may be no pieces, so skip in that case.
if (pc == NO_PIECE)
continue;
pos.put_piece(Piece(pc), sq);
if (stream.get_cursor()> 256)
return 1;
}
} }
// There may be no pieces, so skip in that case. // Castling availability.
if (pc == NO_PIECE) // TODO(someone): Support chess960.
continue; pos.st->castlingRights = 0;
if (stream.read_one_bit()) {
Square rsq;
for (rsq = relative_square(WHITE, SQ_H1); pos.piece_on(rsq) != W_ROOK; --rsq) {}
pos.set_castling_right(WHITE, rsq);
}
if (stream.read_one_bit()) {
Square rsq;
for (rsq = relative_square(WHITE, SQ_A1); pos.piece_on(rsq) != W_ROOK; ++rsq) {}
pos.set_castling_right(WHITE, rsq);
}
if (stream.read_one_bit()) {
Square rsq;
for (rsq = relative_square(BLACK, SQ_H1); pos.piece_on(rsq) != B_ROOK; --rsq) {}
pos.set_castling_right(BLACK, rsq);
}
if (stream.read_one_bit()) {
Square rsq;
for (rsq = relative_square(BLACK, SQ_A1); pos.piece_on(rsq) != B_ROOK; ++rsq) {}
pos.set_castling_right(BLACK, rsq);
}
pos.put_piece(Piece(pc), sq); // En passant square. Ignore if no pawn capture is possible
if (stream.read_one_bit()) {
Square ep_square = static_cast<Square>(stream.read_n_bit(6));
pos.st->epSquare = ep_square;
if (stream.get_cursor()> 256) if (!(pos.attackers_to(pos.st->epSquare) & pos.pieces(pos.sideToMove, PAWN))
return 1; || !(pos.pieces(~pos.sideToMove, PAWN) & (pos.st->epSquare + pawn_push(~pos.sideToMove))))
pos.st->epSquare = SQ_NONE;
}
else {
pos.st->epSquare = SQ_NONE;
}
//assert(stream.get_cursor() <= 256); // Halfmove clock
} pos.st->rule50 = stream.read_n_bit(6);
// Fullmove number
pos.gamePly = stream.read_n_bit(8);
// Read the highest bit of rule50. This was added as a fix for rule50
// counter having only 6 bits stored.
// In older entries this will just be a zero bit.
pos.gamePly |= stream.read_n_bit(8) << 8;
// Read the highest bit of rule50. This was added as a fix for rule50
// counter having only 6 bits stored.
// In older entries this will just be a zero bit.
pos.st->rule50 |= stream.read_n_bit(1) << 6;
// Convert from fullmove starting from 1 to gamePly starting from 0,
// handle also common incorrect FEN with fullmove = 0.
pos.gamePly = std::max(2 * (pos.gamePly - 1), 0) + (pos.sideToMove == BLACK);
assert(stream.get_cursor() <= 256);
pos.chess960 = false;
pos.thisThread = th;
pos.set_state(pos.st);
assert(pos.pos_is_ok());
return 0;
} }
// Castling availability. PackedSfen sfen_pack(Position& pos)
// TODO(someone): Support chess960. {
pos.st->castlingRights = 0; PackedSfen sfen;
if (stream.read_one_bit()) {
Square rsq; SfenPacker sp;
for (rsq = relative_square(WHITE, SQ_H1); pos.piece_on(rsq) != W_ROOK; --rsq) {} sp.data = (uint8_t*)&sfen;
pos.set_castling_right(WHITE, rsq); sp.pack(pos);
return sfen;
} }
if (stream.read_one_bit()) {
Square rsq;
for (rsq = relative_square(WHITE, SQ_A1); pos.piece_on(rsq) != W_ROOK; ++rsq) {}
pos.set_castling_right(WHITE, rsq);
}
if (stream.read_one_bit()) {
Square rsq;
for (rsq = relative_square(BLACK, SQ_H1); pos.piece_on(rsq) != B_ROOK; --rsq) {}
pos.set_castling_right(BLACK, rsq);
}
if (stream.read_one_bit()) {
Square rsq;
for (rsq = relative_square(BLACK, SQ_A1); pos.piece_on(rsq) != B_ROOK; ++rsq) {}
pos.set_castling_right(BLACK, rsq);
}
// En passant square. Ignore if no pawn capture is possible
if (stream.read_one_bit()) {
Square ep_square = static_cast<Square>(stream.read_n_bit(6));
pos.st->epSquare = ep_square;
if (!(pos.attackers_to(pos.st->epSquare) & pos.pieces(pos.sideToMove, PAWN))
|| !(pos.pieces(~pos.sideToMove, PAWN) & (pos.st->epSquare + pawn_push(~pos.sideToMove))))
pos.st->epSquare = SQ_NONE;
}
else {
pos.st->epSquare = SQ_NONE;
}
// Halfmove clock
pos.st->rule50 = stream.read_n_bit(6);
// Fullmove number
pos.gamePly = stream.read_n_bit(8);
// Read the highest bit of rule50. This was added as a fix for rule50
// counter having only 6 bits stored.
// In older entries this will just be a zero bit.
pos.gamePly |= stream.read_n_bit(8) << 8;
// Read the highest bit of rule50. This was added as a fix for rule50
// counter having only 6 bits stored.
// In older entries this will just be a zero bit.
pos.st->rule50 |= stream.read_n_bit(1) << 6;
// Convert from fullmove starting from 1 to gamePly starting from 0,
// handle also common incorrect FEN with fullmove = 0.
pos.gamePly = std::max(2 * (pos.gamePly - 1), 0) + (pos.sideToMove == BLACK);
assert(stream.get_cursor() <= 256);
pos.chess960 = false;
pos.thisThread = th;
pos.set_state(pos.st);
assert(pos.pos_is_ok());
return 0;
}
PackedSfen sfen_pack(Position& pos)
{
PackedSfen sfen;
SfenPacker sp;
sp.data = (uint8_t*)&sfen;
sp.pack(pos);
return sfen;
}
} }