Add dedicated command for training data validation.

This commit is contained in:
Tomasz Sobczyk
2021-05-24 19:43:07 +02:00
parent 5676a50807
commit eac1d430b4
6 changed files with 299 additions and 0 deletions
+150
View File
@@ -7831,4 +7831,154 @@ namespace binpack
std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n";
}
// Validates a plain-text training data file.
//
// The file is read as whitespace-separated keys. The key "e" terminates an
// entry: the most recently read move string is converted against the current
// position and the resulting entry is checked with isValid(). Every other key
// is followed by a value that runs to the end of the line; the keys "fen",
// "move", "score", "ply" and "result" update the entry being built, and any
// other key is ignored.
//
// Progress is printed to std::cout every `reportSize` positions. On the first
// invalid entry a message is written to std::cerr and validation stops.
// If the file cannot be opened an error is written to std::cerr and the
// function returns without validating anything.
//
// Note: std::stoi throws on a non-numeric score/ply/result value; that
// exception is not handled here and propagates to the caller.
inline void validatePlain(std::string inputPath)
{
    constexpr std::size_t reportSize = 1000000;

    std::cout << "Validating " << inputPath << '\n';

    TrainingDataEntry e;

    std::string key;
    std::string value;
    std::string move;

    std::ifstream inputFile(inputPath);
    if (!inputFile)
    {
        // Without this check an unreadable path would be reported as a
        // successfully validated file containing 0 positions.
        std::cerr << "Failed to open " << inputPath << '\n';
        return;
    }
    const auto base = inputFile.tellg();
    std::size_t numProcessedPositions = 0;
    std::size_t numProcessedPositionsBatch = 0;

    for(;;)
    {
        inputFile >> key;
        if (!inputFile)
        {
            break;
        }

        if (key == "e"sv)
        {
            // End of entry: resolve the stored UCI move against the position
            // that was read for this entry, then validate the whole entry.
            e.move = chess::uci::uciToMove(e.pos, move);
            if (!e.isValid())
            {
                std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n';
                return;
            }

            ++numProcessedPositions;
            ++numProcessedPositionsBatch;

            if (numProcessedPositionsBatch >= reportSize)
            {
                numProcessedPositionsBatch -= reportSize;
                const auto cur = inputFile.tellg();
                std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
            }

            continue;
        }

        // Skip the separator after the key; the rest of the line is the value.
        inputFile >> std::ws;
        std::getline(inputFile, value, '\n');

        if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str());
        if (key == "move"sv) move = value;
        if (key == "score"sv) e.score = std::stoi(value);
        if (key == "ply"sv) e.ply = std::stoi(value);
        if (key == "result"sv) e.result = std::stoi(value);
    }

    if (numProcessedPositionsBatch)
    {
        // The loop only exits normally once extraction has failed, and
        // tellg() returns -1 on a failed stream. Reset the state first so the
        // real end offset is reported instead of a bogus byte count.
        inputFile.clear();
        const auto cur = inputFile.tellg();
        std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
    }

    std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n";
}
// Validates a .bin training data file made of fixed-size
// nodchip::PackedSfenValue records.
//
// Records are read one at a time, converted with
// packedSfenValueToTrainingDataEntry() and checked with isValid(). Reading
// stops at the first read that yields fewer bytes than a full record, so a
// trailing partial record is ignored without a diagnostic.
//
// Progress is printed to std::cout every `reportSize` positions. On the first
// invalid entry a message is written to std::cerr and validation stops.
// If the file cannot be opened an error is written to std::cerr and the
// function returns without validating anything.
inline void validateBin(std::string inputPath)
{
    constexpr std::size_t reportSize = 1000000;

    std::cout << "Validating " << inputPath << '\n';

    std::ifstream inputFile(inputPath, std::ios_base::binary);
    if (!inputFile)
    {
        // Without this check an unreadable path would be reported as a
        // successfully validated file containing 0 positions.
        std::cerr << "Failed to open " << inputPath << '\n';
        return;
    }
    const auto base = inputFile.tellg();
    std::size_t numProcessedPositions = 0;
    std::size_t numProcessedPositionsBatch = 0;

    nodchip::PackedSfenValue psv;

    for(;;)
    {
        inputFile.read(reinterpret_cast<char*>(&psv), sizeof(psv));
        // Compare against the size that was actually requested rather than a
        // hard-coded record size, so the check cannot drift from the read.
        if (inputFile.gcount() != static_cast<std::streamsize>(sizeof(psv)))
        {
            break;
        }

        auto e = packedSfenValueToTrainingDataEntry(psv);
        if (!e.isValid())
        {
            std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n';
            return;
        }

        ++numProcessedPositions;
        ++numProcessedPositionsBatch;

        if (numProcessedPositionsBatch >= reportSize)
        {
            numProcessedPositionsBatch -= reportSize;
            const auto cur = inputFile.tellg();
            std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
        }
    }

    if (numProcessedPositionsBatch)
    {
        // The short read that ended the loop set failbit, and tellg() returns
        // -1 on a failed stream. Reset the state first so the real end offset
        // is reported instead of a bogus byte count.
        inputFile.clear();
        const auto cur = inputFile.tellg();
        std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n";
    }

    std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n";
}
inline void validateBinpack(std::string inputPath)
{
constexpr std::size_t reportSize = 1000000;
std::cout << "Validating " << inputPath << '\n';
CompressedTrainingDataEntryReader reader(inputPath);
std::size_t numProcessedPositions = 0;
std::size_t numProcessedPositionsBatch = 0;
while(reader.hasNext())
{
auto e = reader.next();
if (!e.isValid())
{
std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n';
return;
}
++numProcessedPositions;
++numProcessedPositionsBatch;
if (numProcessedPositionsBatch >= reportSize)
{
numProcessedPositionsBatch -= reportSize;
std::cout << "Processed " << numProcessedPositions << " positions.\n";
}
}
if (numProcessedPositionsBatch)
{
std::cout << "Processed " << numProcessedPositions << " positions.\n";
}
std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n";
}
}