Changed to use TB in the training data generator. #67

This commit is contained in:
nodchip
2020-08-10 12:17:26 +09:00
parent e65c515d6b
commit bac96aa04a
+28 -22
View File
@@ -23,6 +23,7 @@
#include "learn.h" #include "learn.h"
#include "multi_think.h" #include "multi_think.h"
#include "../uci.h" #include "../uci.h"
#include "../syzygy/tbprobe.h"
// evaluate header for learning // evaluate header for learning
#include "../eval/evaluate_common.h" #include "../eval/evaluate_common.h"
@@ -522,13 +523,18 @@ void MultiThinkGenSfen::thread_worker(size_t thread_id)
break; break;
} }
// Isn't all pieces stuck and stuck? // Initialize the Syzygy Ending Tablebase and sort the moves.
if (MoveList<LEGAL>(pos).size() == 0) Search::RootMoves rootMoves;
{ for (const auto& m : MoveList<LEGAL>(pos))
// (write up to the previous phase of this phase) rootMoves.emplace_back(m);
// Write the positions other than this position if checkmated. if (!rootMoves.empty())
if (pos.checkers()) // Mate Tablebases::rank_root_moves(pos, rootMoves);
flush_psv(-1);
// If there is no legal move, terminate the game if position
// is mate or a stalemate.
else {
if (pos.checkers()) // Mate
flush_psv(-1);
else if (use_draw_in_training_data_generation) { else if (use_draw_in_training_data_generation) {
flush_psv(0); // Stalemate flush_psv(0); // Stalemate
} }
@@ -636,10 +642,10 @@ void MultiThinkGenSfen::thread_worker(size_t thread_id)
// cout << pos; // cout << pos;
auto v = Eval::evaluate(pos); auto v = Eval::evaluate(pos);
// evaluate() returns the evaluation value on the turn side, so // evaluate() returns the evaluation value on the turn side, so
// If it's a turn different from root_color, you must invert v and return it. // If it's a turn different from root_color, you must invert v and return it.
if (rootColor != pos.side_to_move()) if (rootColor != pos.side_to_move())
v = -v; v = -v;
// Rewind. // Rewind.
// Is it C++x14, and isn't there even foreach to turn in reverse? // Is it C++x14, and isn't there even foreach to turn in reverse?
@@ -2472,7 +2478,7 @@ void convert_bin(const vector<string>& filenames, const string& output_file_name
std::string value; std::string value;
ss >> token; ss >> token;
if (token == "fen") { if (token == "fen") {
states = StateListPtr(new std::deque<StateInfo>(1)); // Drop old and create a new one states = StateListPtr(new std::deque<StateInfo>(1)); // Drop old and create a new one
std::string input_fen = line.substr(4); std::string input_fen = line.substr(4);
tpos.set(input_fen, false, &states->back(), Threads.main()); tpos.set(input_fen, false, &states->back(), Threads.main());
if (!tpos.pos_is_ok() || tpos.fen() != input_fen) { if (!tpos.pos_is_ok() || tpos.fen() != input_fen) {
@@ -2480,8 +2486,8 @@ void convert_bin(const vector<string>& filenames, const string& output_file_name
filtered_size_fen++; filtered_size_fen++;
} }
else { else {
tpos.sfen_pack(p.sfen); tpos.sfen_pack(p.sfen);
} }
} }
else if (token == "move") { else if (token == "move") {
ss >> value; ss >> value;
@@ -2508,7 +2514,7 @@ void convert_bin(const vector<string>& filenames, const string& output_file_name
} }
p.gamePly = uint16_t(temp); // No cast here? p.gamePly = uint16_t(temp); // No cast here?
if (interpolate_eval != 0){ if (interpolate_eval != 0){
p.score = min(3000, interpolate_eval * temp); p.score = min(3000, interpolate_eval * temp);
} }
} }
else if (token == "result") { else if (token == "result") {
@@ -2516,17 +2522,17 @@ void convert_bin(const vector<string>& filenames, const string& output_file_name
ss >> temp; ss >> temp;
p.game_result = int8_t(temp); // Do you need a cast here? p.game_result = int8_t(temp); // Do you need a cast here?
if (interpolate_eval){ if (interpolate_eval){
p.score = p.score * p.game_result; p.score = p.score * p.game_result;
} }
} }
else if (token == "e") { else if (token == "e") {
if(!(ignore_flag_fen || ignore_flag_move || ignore_flag_ply)){ if(!(ignore_flag_fen || ignore_flag_move || ignore_flag_ply)){
fs.write((char*)&p, sizeof(PackedSfenValue)); fs.write((char*)&p, sizeof(PackedSfenValue));
data_size+=1; data_size+=1;
// debug // debug
// std::cout<<tpos<<std::endl; // std::cout<<tpos<<std::endl;
// std::cout<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl; // std::cout<<p.score<<","<<int(p.gamePly)<<","<<int(p.game_result)<<std::endl;
} }
else { else {
filtered_size++; filtered_size++;
} }