diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..377796f7 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,75 @@ +language: cpp +dist: focal + +matrix: + include: + - os: linux + compiler: gcc + addons: + apt: + packages: ['g++-multilib', 'valgrind', 'expect', 'curl'] + env: + - COMPILER=g++ + - COMP=gcc + +branches: + only: + - master + +before_script: + - cd src + +script: + # Download net + - make net + + # Obtain bench reference from git log + - git log HEAD | grep "\b[Bb]ench[ :]\+[0-9]\{7\}" | head -n 1 | sed "s/[^0-9]*\([0-9]*\).*/\1/g" > git_sig + - export benchref=$(cat git_sig) + - echo "Reference bench:" $benchref + + # Compiler version string + - $COMPILER -v + + # test help target + - make help + + # Verify bench number against various builds + - export CXXFLAGS="-Werror -D_GLIBCXX_DEBUG" + - make clean && make -j2 ARCH=x86-64-modern optimize=no debug=yes build && ../tests/signature.sh $benchref + - export CXXFLAGS="-Werror" + - make clean && make -j2 ARCH=x86-64-modern build && ../tests/signature.sh $benchref + - make clean && make -j2 ARCH=x86-64-ssse3 build && ../tests/signature.sh $benchref + - make clean && make -j2 ARCH=x86-64-sse3-popcnt build && ../tests/signature.sh $benchref + - make clean && make -j2 ARCH=x86-64 build && ../tests/signature.sh $benchref + # TODO avoid _mm_malloc + # - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then make clean && make -j2 ARCH=general-64 build && ../tests/signature.sh $benchref; fi + - make clean && make -j2 ARCH=x86-64-modern profile-build && ../tests/signature.sh $benchref + + # compile only for some more advanced architectures (might not run in travis) + - make clean && make -j2 ARCH=x86-64-avx2 blas=yes build + + - make clean && make -j2 ARCH=x86-64-avx2 build + - make clean && make -j2 ARCH=x86-64-bmi2 build + - make clean && make -j2 ARCH=x86-64-avx512 build + - make clean && make -j2 ARCH=x86-64-vnni512 build + - make clean && make -j2 ARCH=x86-64-vnni256 build + + # + # Check perft and reproducible search + - make clean && make -j2 ARCH=x86-64-modern build + - ../tests/perft.sh + - ../tests/reprosearch.sh + + # + # Valgrind + # + - export CXXFLAGS="-O1 -fno-inline" + - make clean && make -j2 ARCH=x86-64-modern debug=yes optimize=no build > /dev/null && ../tests/instrumented.sh --valgrind + - ../tests/instrumented.sh --valgrind-thread + + # + # Sanitizer + # + - make clean && make -j2 ARCH=x86-64-modern sanitize=undefined optimize=no debug=yes build > /dev/null && ../tests/instrumented.sh --sanitizer-undefined + - make clean && make -j2 ARCH=x86-64-modern sanitize=thread optimize=no debug=yes build > /dev/null && ../tests/instrumented.sh --sanitizer-thread diff --git a/README.md b/README.md index 79db8170..467dd3c3 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,41 @@ For developers the following non-standard commands might be of interest, mainly * #### flip Flips the side to move. +### Generating Training Data + +To generate training data from the classic eval, use the generate_training_data command with the setting "Use NNUE" set to "false". The given example is generation in its simplest form. There are more commands. + +``` +uci +setoption name PruneAtShallowDepth value false +setoption name Use NNUE value false +setoption name Threads value X +setoption name Hash value Y +setoption name SyzygyPath value path +isready +generate_training_data depth A count B keep_draws 1 eval_limit 32000 +``` + +- `A` is the searched depth per move, or how far the engine looks forward. This value is an integer. +- `B` is the amount of positions generated. This value is also an integer. + +Specify how many threads and how much memory you would like to use with the `x` and `y` values. The option SyzygyPath is not necessary, but if you would like to use it, you must first have Syzygy endgame tablebases on your computer, which you can find [here](http://oics.olympuschess.com/tracker/index.php). You will need to have a torrent client to download these tablebases, as that is probably the fastest way to obtain them. The `path` is the path to the folder containing those tablebases. It does not have to be surrounded in quotes. + +This will create a file named "training_data.binpack" in the same folder as the binary containing the generated training data. Once generation is done, you can rename the file to something like "1billiondepth12.binpack" to remember the depth and quantity of the positions and move it to a folder named "trainingdata" in the same directory as the binaries. + +You will also need validation data that is used for loss calculation and accuracy computation. Validation data is generated in the same way as training data, but generally at most 1 million positions should be used as there's no need for more and it would just slow the learning process down. It may also be better to slightly increase the depth for validation data. After generation you can rename the validation data file to "val.binpack" and drop it in a folder named "validationdata" in the same directory to make it easier. + +## Training data formats. + +Currently there are 3 training data formats. Two of them are supported directly. + +- `.bin` - the original training data format. Uses 40 bytes per entry. Is supported directly by the `generate_training_data` command. +- `.plain` - a human readable training data format. This one is not supported directly by the `generate_training_data` command. It should not be used for data exchange because it's less compact than other formats. It is mostly useful for inspection of the data. +- `.binpack` - a compact binary training data format that exploits positions chains to further reduce size. It uses on average between 2 to 3 bytes per entry when generating data with `generate_training_data`. It is supported directly by `generate_training_data` command. It is currently the default for the `generate_training_data` command. A more in depth description can be found [here](docs/binpack.md) + +### Conversion between formats. + +There is a builting converted that support all 3 formats described above. Any of them can be converted to any other. For more information and usage guide see [here](docs/convert.md). ## A note on classical evaluation versus NNUE evaluation diff --git a/docs/binpack.md b/docs/binpack.md new file mode 100644 index 00000000..1940a5dc --- /dev/null +++ b/docs/binpack.md @@ -0,0 +1,42 @@ +# Binpack + +Binpack is a binary training data storage format designed to take advantage of position chains differing by a single move. Therefore it is very good at compactly storing data generated from real games (as opposed to random positions for example sourced from an opening book). + +It is currently implemented through a single header library in `extra/nnue_data_binpack_format.h`. + +Below follows a rough description of the format in a BNF-like notation. + +``` +[[nodiscard]] std::uint16_t signedToUnsigned(std::int16_t a) { + std::uint16_t r; + std::memcpy(&r, &a, sizeof(std::uint16_t)); + if (r & 0x8000) r ^= 0x7FFF; // flip value bits if negative + r = (r << 1) | (r >> 15); // store sign bit at bit 0 + return r; +} + +file := * +block := BINP* +chain := +stem := (32 bytes) +pos := https://github.com/Sopel97/nnue_data_compress/blob/master/src/chess/Position.h#L1166 (24 bytes) +move := https://github.com/Sopel97/nnue_data_compress/blob/master/src/chess/Chess.h#L1044 (2 bytes) +score := signedToUnsigned(score) (2 bytes, big endian) +ply_and_result := ply bitwise_or (signedToUnsigned(result) << 14) (2 bytes, big endian) +rule50 := rule_50_counter (2 bytes, big endian) + // this is a small defect from old version, + I didn't want to break backwards compatibility. Effectively means that there's + one byte left for something else in the future because rule50 always fits in one byte. + +movetext := * +count := number of plies in the movetext (2 bytes, big endian). Can be 0. +move_and_score := (~2 bytes) +encoded_move := oof this one is complicated to explain. + https://github.com/Sopel97/nnue_data_compress/blob/master/src/compress_file.cpp#L827. + https://github.com/Sopel97/chess_pos_db/blob/master/docs/bcgn/variable_length.md + +encoded_score := https://en.wikipedia.org/wiki/Variable-width_encoding + with block size of 4 bits + 1 bit for extension bit. + Encoded value is signedToUnsigned(-prev_score - current_score) + (scores are always seen from the perspective of side to move in , that's why the '-' before prev_score) +``` \ No newline at end of file diff --git a/docs/convert.md b/docs/convert.md new file mode 100644 index 00000000..132f66e0 --- /dev/null +++ b/docs/convert.md @@ -0,0 +1,18 @@ +# Convert + +`convert` allows conversion of training data between any of `.plain`, `.bin`, and `.binpack`. + +As all commands in stockfish `convert` can be invoked either from command line (as `stockfish.exe convert ...`) or in the interactive prompt. + +The syntax of this command is as follows: +``` +convert from_path to_path [append] [validate] +``` + +`from_path` is the path to the file to convert from. The type of the data is deduced based on its extension (one of `.plain`, `.bin`, `.binpack`). +`to_path` is the path to an output file. The type of the data is deduced from its extension. If the file does not exist it is created. + +`append` and `validate` can come in any order and are optional. +If `append` not specified then the output file will be truncated prior to any writes. If `append` is specified then the converted training data will be appended to the end of the output file. + +If `validate` is specified then the conversion will stop on the first illegal move found and a diagnostic will be shown. \ No newline at end of file diff --git a/docs/generate_training_data.md b/docs/generate_training_data.md new file mode 100644 index 00000000..734f7e81 --- /dev/null +++ b/docs/generate_training_data.md @@ -0,0 +1,63 @@ +# generate_training_data + +`generate_training_data` command allows generation of training data from self-play in a manner that suits training better than traditional games. It introduces random moves to diversify openings, and fixed depth evaluation. + +As all commands in stockfish `generate_training_data` can be invoked either from command line (as `stockfish.exe generate_training_data ...`, but this is not recommended because it's not possible to specify UCI options before `generate_training_data` executes) or in the interactive prompt. + +It is recommended to set the `PruneAtShallowDepth` UCI option to `false` as it will increase the quality of fixed depth searches. + +It is recommended to keep the `EnableTranspositionTable` UCI option at the default `true` value as it will make the generation process faster without noticably harming the uniformity of the data. + +`generate_training_data` takes named parameters in the form of `generate_training_data param_1_name param_1_value param_2_name param_2_value ...`. + +Currently the following options are available: + +`set_recommended_uci_options` - this is a modifier not a parameter, no value follows it. If specified then some UCI options are set to recommended values. + +`depth` - sets minimum and maximum depth of evaluation of each position. Default: 3. + +`mindepth` - minimum depth of evaluation of each position. If not specified then the same as `depth`. + +`maxdepth` - minimum depth of evaluation of each position. If not specified then the same as `depth`. + +`nodes` - the number of nodes to use for evaluation of each position. This number is multiplied by the number of PVs of the current search. This does NOT override the `depth` and `depth2` options. If specified then whichever of depth or nodes limit is reached first applies. + +`count` - the number of training data entries to generate. 1 entry == 1 position. Default: 8000000000 (8B). + +`output_file_name` - the name of the file to output to. If the extension is not present or doesn't match the selected training data format the right extension will be appened. Default: generated_kifu + +`eval_limit` - evaluations with higher absolute value than this will not be written and will terminate a self-play game. Should not exceed 10000 which is VALUE_KNOWN_WIN, but is only hardcapped at mate in 2 (\~30000). Default: 3000 + +`random_move_min_ply` - the minimal ply at which a random move may be executed instead of a move chosen by search. Default: 1. + +`random_move_max_ply` - the maximal ply at which a random move may be executed instead of a move chosen by search. Default: 24. + +`random_move_count` - maximum number of random moves in a single self-play game. Default: 5. + +`random_move_like_apery` - either 0 or 1. If 1 then random king moves will be followed by a random king move from the opponent whenever possible with 50% probability. Default: 0. + +`random_multi_pv` - the number of PVs used for determining the random move. If not specified then a truly random move will be chosen. If specified then a multiPV search will be performed the random move will be one of the moves chosen by the search. + +`random_multi_pv_diff` - Makes the multiPV random move selection consider only moves that are at most `random_multi_pv_diff` worse than the next best move. Default: 30000 (all multiPV moves). + +`random_multi_pv_depth` - the depth to use for multiPV search for random move. Default: `depth2`. + +`write_min_ply` - minimum ply for which the training data entry will be emitted. Default: 16. + +`write_max_ply` - maximum ply for which the training data entry will be emitted. Default: 400. + +`book` - a path to an opening book to use for the starting positions. Currently only .epd format is supported. If not specified then the starting position is always the standard chess starting position. + +`save_every` - the number of training data entries per file. If not specified then there will be always one file. If specified there may be more than one file generated (each having at most `save_every` training data entries) and each file will have a unique number attached. + +`random_file_name` - if specified then the output filename will be chosen randomly. Overrides `output_file_name`. + +`keep_draws` - either 0 or 1. If 1 then training data from drawn games will be emitted too. Default: 1. + +`adjudicate_draws_by_score` - either 0 or 1. If 1 then drawn games will be adjudicated when the score remains 0 for at least 8 plies after ply 80. Default: 1. + +`adjudicate_draws_by_insufficient_mating_material` - either 0 or 1. If 1 then position with insufficient material will be adjudicated as draws. Default: 1. + +`data_format` - format of the training data to use. Either `bin` or `binpack`. Default: `binpack`. + +`seed` - seed for the PRNG. Can be either a number or a string. If it's a string then its hash will be used. If not specified then the current time will be used. diff --git a/docs/generate_training_data_nonpv.md b/docs/generate_training_data_nonpv.md new file mode 100644 index 00000000..5e4e0ee5 --- /dev/null +++ b/docs/generate_training_data_nonpv.md @@ -0,0 +1,41 @@ +# generate_training_data_nonpv + +`generate_training_data_nonpv` command allows generation of training data from self-play in a manner that suits training better than traditional games. It plays fixed nodes self play games for exploration and records [some of] the evaluated positions. Then rescores them with fixed depth search. + +As all commands in stockfish `generate_training_data_nonpv` can be invoked either from command line (as `stockfish.exe generate_training_data_nonpv ...`, but this is not recommended because it's not possible to specify UCI options before `generate_training_data_nonpv` executes) or in the interactive prompt. + +It is recommended to set the `PruneAtShallowDepth` UCI option to `false` as it will increase the quality of fixed depth searches. + +It is recommended to keep the `EnableTranspositionTable` UCI option at the default `true` value as it will make the generation process faster without noticably harming the uniformity of the data. + +`generate_training_data_nonpv` takes named parameters in the form of `generate_training_data_nonpv param_1_name param_1_value param_2_name param_2_value ...`. + +Currently the following options are available: + +`depth` - the search depth to use for rescoring. Default: 3. + +`count` - the number of training data entries to generate. 1 entry == 1 position. Default: 1000000 (1M). + +`exploration_min_nodes` - the min number of nodes to use for exploraton during selfplay. Default: 5000. + +`exploration_max_nodes` - the max number of nodes to use for exploraton during selfplay. The number of nodes is chosen from a uniform distribution between min and max. Default: 15000. + +`exploration_save_rate` - the ratio of positions seen during exploration self play games that are saved for later rescoring. Default: 0.01 (meaning 1 in 100 positions seen during search get saved for rescoring). + +`output_file` - the name of the file to output to. If the extension is not present or doesn't match the selected training data format the right extension will be appened. Default: generated_gensfen_nonpv + +`eval_limit` - evaluations with higher absolute value than this will not be written and will terminate a self-play game. Should not exceed 10000 which is VALUE_KNOWN_WIN, but is only hardcapped at mate in 2 (\~30000). Default: 4000 + +`exploration_eval_limit` - same as `eval_limit` but used during exploration with a value from fixed depth search. + +`exploration_min_pieces` - the min number of pieces in the self play games to start the fixed depth search. Note that even if there's N pieces on the board the fixed nodes search usually reaches positions with less pieces and they are saved too. Default: 8. + +`exploration_max_ply` the max ply for the exploration self play. Default: 200. + +`smart_fen_skipping` - this is a flag option. When specified some position that are not good candidates for teaching are removed from the output. This includes positions where the best move is a capture or promotion, and position where a king is in check. + +`book` - a path to an opening book to use for the starting positions. Currently only .epd format is supported. If not specified then the starting position is always the standard chess starting position. + +`data_format` - format of the training data to use. Either `bin` or `binpack`. Default: `binpack`. + +`seed` - seed for the PRNG. Can be either a number or a string. If it's a string then its hash will be used. If not specified then the current time will be used. diff --git a/docs/stats.md b/docs/stats.md new file mode 100644 index 00000000..8fe50ee5 --- /dev/null +++ b/docs/stats.md @@ -0,0 +1,41 @@ +# Stats + +`gather_statistics` command allows gathering various statistics from a .bin or a .binpack file. The syntax is `gather_statistics (GROUP)* input_file FILENAME`. There can be many groups specified. Any statistic gatherer that belongs to at least one of the specified groups will be used. + +Simplest usage: `stockfish.exe gather_statistics all input_file a.binpack` + +Any name that doesn't designate an argument name or is not an argument will be interpreted as a group name. + +## Parameters + +`input_file` - the path to the .bin or .binpack input file to read + +`output_file` - optional path to the output file to write the results too. Results are always written on the console, so if this is specified the results will be written in both places. + +`max_count` - the maximum number of positions to process. Default: no limit. + +## Groups + +`all` - a special group designating all statistics gatherers available. + +`position_count` - the total number of positions in the file. + +`king`, `king_square_count` - the number of times a king was on each square. Output is layed out as a chessboard, with the 8th rank being the topmost. Separate values for white and black kings. + +`move`, `move_from_count` - same as `king_square_count` but for from_sq(move) + +`move`, `move_to_count` - same as `king_square_count` but for to_sq(move) + +`move`, `move_type` - the number of moves with each type. Includes normal, captures, castling, promotions, enpassant. The groups are not disjoint. + +`move`, `moved_piece_type` - the number of times a piece of each type was moved + +`piece_count` - the histogram of the number of pieces on the board + +`ply_discontinuities` - the number of times the ply jumped by a value different than 1 between two consecutive positions. Usually the number of games. + +`material_imbalance` - the histogram of imbalances, with values computed using "simple eval", that is pawn=1, bishop=knight=3, rook=5, queen=9 + +`results` - the distribution of game results + +`endgames_6man` - distribution of endgame configurations for <=6 pieces (including kings) diff --git a/docs/transform.md b/docs/transform.md new file mode 100644 index 00000000..486fd34e --- /dev/null +++ b/docs/transform.md @@ -0,0 +1,42 @@ +# Transform + +`transform` command exposes subcommands that perform some specific transformation over data. The call syntax is `transform `. Currently implemented subcommands are listed and described below. + +## `nudged_static` + +`transform nudged_static` takes named parameters in the form of `transform nudged_static param_1_name param_1_value param_2_name param_2_value ...` and flag parameters which don't require values. + +This command goes through positions in the input files and replaces the scores with new ones - generated from static eval - but slightly adjusted based on the scores in the original input file. + +Currently the following options are available: + +`input_file` - path to the input file. Supports bin and binpack formats. Default: in.binpack. + +`output_file` - path to the output file. Supports bin and binpack formats. Default: out.binpack. + +`absolute` - states that the adjustment should be bounded by an absolute value. After this token follows the maximum absolute adjustment. Values are always adjusted towards scores in the input file. This is the default mode. Default maximum adjustement: 5. + +`relative` - states that the adjustment should be bounded by a value relative in magnitude to the static eval value. After this token follows the maximum relative change - a floating point value greater than 0. For example a value of 0.1 only allows changing the static eval by at most 10% towards the score from the input file. + +`interpolate` states that the output score should be a value interpolated between static eval and the score from the input file. After this token follows the interpolation constant `t`. `t` of 0 means that only static eval is used. `t` of 1 means that only score from the input file is used. `t` of 0.5 means that the static eval and input score are averaged. It accepts values outside of range `<0, 1>`, but the usefulness is questionable. + +## `rescore` + +`transform rescore` takes named parameters in the form of `transform rescore param_1_name param_1_value param_2_name param_2_value ...` and flag parameters which don't require values. + +This tool respects the UCI option `Threads` and uses all available threads. + +This command takes a path to the input file that is either a .epd file which contains one FEN per line or a .bin or .binpack file and outputs a .bin or .binpack file with these positions rescored with specified depth search. + +Currently the following options are available: + +`input_file` - path to the input file. Default: in.binpack. + +`output_file` - path to the output .bin or .binpack file. The file is opened in append mode. Default: out.binpack. + +`depth` - the search depth to use for rescoring. Default: 3. + +`keep_moves` - whether to keep moves from the input file if available. Allows to keep compression in .binpack. Default: 1. + +`research_count` - number of additional searches of depth N done on the same position before using the eval. Default: 0. + diff --git a/docs/validate_training_data.md b/docs/validate_training_data.md new file mode 100644 index 00000000..e2bfc30c --- /dev/null +++ b/docs/validate_training_data.md @@ -0,0 +1,12 @@ +# validate_training_data + +`validate_training_data` allows validation of training data of types `.plain`, `.bin`, and `.binpack`. + +As all commands in stockfish `validate_training_data` can be invoked either from command line (as `stockfish.exe validate_training_data ...`) or in the interactive prompt. + +The syntax of this command is as follows: +``` +validate_training_data in_path +``` + +`in_path` is the path to the file to validate. The type of the data is deduced based on its extension (one of `.plain`, `.bin`, `.binpack`). \ No newline at end of file diff --git a/script/README.md b/script/README.md new file mode 100644 index 00000000..feb57ca2 --- /dev/null +++ b/script/README.md @@ -0,0 +1,52 @@ +# `pgn_to_plain` +This script converts pgn files into text file to apply `learn convert_bin` command. You need to import [python-chess](https://pypi.org/project/python-chess/) to use this script. + + + pip install python-chess + + +# Example of Qhapaq's finetune using `pgn_to_plain` + +## Download data +You can download data from [here](http://rebel13.nl/index.html) + +## Convert pgn files + +**Important : convert text will be superheavy (approx 200 byte / position)** + + python pgn_to_plain.py --pgn "pgn/*.pgn" --start_ply 1 --output converted_pgn.txt + + +`--pgn` option supports wildcard. When you use pgn files with elo >= 3300, You will get 1.7 GB text file. + + +## Convert into training data + + +### Example build command + + make nnue-learn ARCH=x86-64 + +See `src/Makefile` for detail. + + +### Convert + + ./stockfish + learn convert_bin converted_pgn.txt output_file_name pgn_bin.bin + learn shuffle pgn_bin.bin + +You also need to prepare validation data for training like following. + + python pgn_to_plain.py --pgn "pgn/ccrl-40-15-3400.pgn" --start_ply 1 --output ccrl-40-15-3400.txt + ./stockfish + learn convert_bin ccrl-40-15-3400.txt ccrl-40-15-3400_plain.bin + + +### Learn + + ./stockfish + setoption name Threads value 8 + learn shuffled_sfen.bin newbob_decay 0.5 validation_set_file_name ccrl-40-15-3400_plain.bin nn_batch_size 50000 batchsize 1000000 eval_save_interval 8000000 eta 0.05 lambda 0.0 eval_limit 3000 mirror_percentage 0 use_draw_in_training 1 + + diff --git a/script/extract_bin.py b/script/extract_bin.py new file mode 100644 index 00000000..9574aa17 --- /dev/null +++ b/script/extract_bin.py @@ -0,0 +1,42 @@ +import sys + +ENTRY_SIZE = 40 +NUM_ENTRIES_IN_CHUNK = 1024*1024 + +def copy(infile, outfile, count, times): + if times > 1: + outfile.write(infile.read(count*ENTRY_SIZE)*times) + else: + offset = 0 + while offset < count: + to_read = NUM_ENTRIES_IN_CHUNK if offset + NUM_ENTRIES_IN_CHUNK <= count else count - offset + + outfile.write(infile.read(to_read*ENTRY_SIZE)) + + offset += NUM_ENTRIES_IN_CHUNK + +def work(): + filename = sys.argv[1] + offset = int(sys.argv[2]) + count = int(sys.argv[3]) + times = int(sys.argv[4]) if len(sys.argv) >= 5 else 1 + + with open(filename, 'rb') as infile: + infile.seek(offset * ENTRY_SIZE) + filename_parts = filename.split('.') + out_path = '.'.join(filename_parts[:-1]) + '_' + str(offset) + '_' + str(count) + '_' + str(times) + '.' + filename_parts[-1] + with open(out_path, 'wb') as outfile: + copy(infile, outfile, count, times) + +def show_help(): + print('Usage: python extract_bin.py filename offset count [times]') + print('filename - the path to the .bin file to process') + print('offset - the number of sfens to skip') + print('count - the number of sfens to extract') + print('times - the number of times to repeat the extracted sfens. Default = 1') + print('The result is saved in a new file named `filename.stem`_`offset`_`count`_`times`.bin') + +if len(sys.argv) < 4: + show_help() +else: + work() diff --git a/script/interleave_binpacks.py b/script/interleave_binpacks.py new file mode 100644 index 00000000..02888b33 --- /dev/null +++ b/script/interleave_binpacks.py @@ -0,0 +1,86 @@ +import struct +import sys +import os +import random +from pathlib import Path + + +def copy_next_chunk(in_file, out_file): + chunk_header = in_file.read(8) + assert chunk_header[0:4] == b"BINP" + size = struct.unpack(" 0: + where = random.randrange(total_remaining) + i = 0 + while where >= in_files_remaining[i]: + where -= in_files_remaining[i] + i += 1 + size = copy_next_chunk(in_files[i], out_file) + in_files_remaining[i] -= size + total_remaining -= size + total_size += size + mib = total_size // 1024 // 1024 + if mib // 100 != prev_mib // 100: + print("Copied {} MiB".format(mib)) + prev_mib = mib + + out_file.close() + for in_file in in_files: + in_file.close() + + print("Merged {} bytes".format(total_size)) + + +main() diff --git a/script/pgn_to_plain.py b/script/pgn_to_plain.py new file mode 100644 index 00000000..c551c136 --- /dev/null +++ b/script/pgn_to_plain.py @@ -0,0 +1,110 @@ +import chess.pgn +import argparse +import glob +import re +from typing import List + +# todo close in c++ tools using pgn-extract +# https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/help.html#-w + +commentRe = re.compile("([+-]*M*[0-9.]*)/([0-9]*)") +mateRe = re.compile("([+-])M([0-9]*)") +flip_black = False + +def parse_result(result_str:str, board:chess.Board) -> int: + if result_str == "1/2-1/2": + return 0 + if result_str == "0-1": + if board.turn == chess.WHITE: + return -1 + else: + return 1 + elif result_str == "1-0": + if board.turn == chess.WHITE: + return 1 + else: + return -1 + else: + print("illegal result", result_str) + raise ValueError + +def game_sanity_check(game: chess.pgn.Game) -> bool: + if not game.headers["Result"] in ["1/2-1/2", "0-1", "1-0"]: + print("invalid result", game.headers["Result"]) + return False + return True + +def parse_comment_for_score(comment_str: str, board: chess.Board) -> int: + global commentRe + global mateRe + global flip_black + + try: + m = commentRe.search(comment_str) + if m: + score = m.group(1) + # depth = int(m.group(2)) + m = mateRe.search(score) + if m: + if m.group(1) == "+": + score = 32000 - int(m.group(2)) + else: + score = -32000 + int(m.group(2)) + else: + score = int(float(score) * 208) # pawn to SF PawnValueEg + + if flip_black and board.turn == chess.BLACK: + score = -score + else: + score = 0 + except: + score = 0 + + return score + +def parse_game(game: chess.pgn.Game, writer, start_play: int=1)->None: + board: chess.Board = game.board() + if not game_sanity_check(game): + return + + result: str = game.headers["Result"] + ply = 0 + for node in game.mainline(): + move = node.move + if ply >= start_play: + comment: str = node.comment + writer.write("fen " + board.fen() + "\n") + writer.write("move " + str(move) + "\n") + writer.write("score " + str(parse_comment_for_score(comment, board)) + "\n") + writer.write("ply " + str(ply)+"\n") + writer.write("result " + str(parse_result(result, board)) +"\n") + writer.write("e\n") + ply += 1 + board.push(move) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--pgn", type=str, required=True) + parser.add_argument("--start_ply", type=int, default=1) + parser.add_argument("--output", type=str, default="plain.txt") + parser.add_argument("--flip_black_score", action='store_true', dest='flip_black_score', help="Flip black score. Default: False") + args = parser.parse_args() + + global flip_black + flip_black = args.flip_black_score + + pgn_files: List[str] = glob.glob(args.pgn) + pgn_files = sorted(pgn_files, key=lambda x:float(re.findall("-(\d+).pgn",x)[0] if re.findall("-(\d+).pgn",x) else 0.0)) + f = open(args.output, 'w') + for pgn_file in pgn_files: + print("parse", pgn_file) + pgn_loader = open(pgn_file) + while True: + game = chess.pgn.read_game(pgn_loader) + if game is None: + break + parse_game(game, f, args.start_ply) + f.close() + +if __name__=="__main__": + main() diff --git a/script/shuffle_binpack.py b/script/shuffle_binpack.py new file mode 100644 index 00000000..ca3b0b8e --- /dev/null +++ b/script/shuffle_binpack.py @@ -0,0 +1,89 @@ +import struct +import sys +import os +import random +from pathlib import Path + +def index_binpack(file): + print('Indexing...') + index = [] + offset = 0 + report_every = 100 + prev_mib = -report_every + while file.peek(): + chunk_header = file.read(8) + assert chunk_header[0:4] == b'BINP' + size = struct.unpack(' 3: + # split the infile in split_count pieces, creating new outfile names based on the provided name + basefile = sys.argv[2] + split_count = int(sys.argv[3]) + base=os.path.splitext(basefile)[0] + ext=os.path.splitext(basefile)[1] + out_filenames = [] + for i in range(split_count): + out_filenames.append(base+"_{}".format(i)+ext) + else: + out_filenames = [sys.argv[2]] + + for out_filename in out_filenames: + if (Path(out_filename).exists()): + print('Output path {} already exists. Please specify a path to a file that does not exist.'.format(out_filename)) + return + + print(out_filenames) + + in_file = open(in_filename, 'rb') + index = index_binpack(in_file) + + print('Shuffling...') + random.shuffle(index) + + out_files = [] + for out_filename in out_filenames: + out_files.append(open(out_filename, 'wb')) + + copy_binpack_indexed(in_file, index, out_files) + + in_file.close() + for out_file in out_files: + out_file.close() + +main() diff --git a/src/Makefile b/src/Makefile index b021ba12..860c391d 100644 --- a/src/Makefile +++ b/src/Makefile @@ -26,6 +26,12 @@ else EXE = stockfish endif +### Establish the operating system name +KERNEL = $(shell uname -s) +ifeq ($(KERNEL),Linux) + OS = $(shell uname -o) +endif + ### Installation dir definitions PREFIX = /usr/local BINDIR = $(PREFIX)/bin @@ -33,25 +39,30 @@ BINDIR = $(PREFIX)/bin ### Built-in benchmark for pgo-builds ifeq ($(SDE_PATH),) PGOBENCH = ./$(EXE) bench + PGOGENSFEN = ./$(EXE) gensfen depth 3 loop 1000 sfen_format bin output_file_name $(PGO_TRAINING_DATA_FILE) else PGOBENCH = $(SDE_PATH) -- ./$(EXE) bench + PGOGENSFEN = $(SDE_PATH) -- ./$(EXE) gensfen depth 3 loop 1000 sfen_format bin output_file_name $(PGO_TRAINING_DATA_FILE) endif ### Source and object files SRCS = benchmark.cpp bitbase.cpp bitboard.cpp endgame.cpp evaluate.cpp main.cpp \ material.cpp misc.cpp movegen.cpp movepick.cpp pawns.cpp position.cpp psqt.cpp \ search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \ - nnue/evaluate_nnue.cpp nnue/features/half_ka_v2.cpp + nnue/evaluate_nnue.cpp \ + nnue/features/half_ka_v2.cpp \ + tools/validate_training_data.cpp \ + tools/sfen_packer.cpp \ + tools/training_data_generator.cpp \ + tools/training_data_generator_nonpv.cpp \ + tools/opening_book.cpp \ + tools/convert.cpp \ + tools/transform.cpp \ + tools/stats.cpp OBJS = $(notdir $(SRCS:.cpp=.o)) -VPATH = syzygy:nnue:nnue/features - -### Establish the operating system name -KERNEL = $(shell uname -s) -ifeq ($(KERNEL),Linux) - OS = $(shell uname -o) -endif +VPATH = syzygy:nnue:nnue/features:eval:extra:tools ### ========================================================================== ### Section 2. High-level Configuration @@ -315,8 +326,9 @@ endif ### ========================================================================== ### 3.1 Selecting compiler (default = gcc) -CXXFLAGS += -Wall -Wcast-qual -fno-exceptions -std=c++17 $(EXTRACXXFLAGS) -DEPENDFLAGS += -std=c++17 +ADDITIONAL_INCLUDE_DIRECTORIES = -I. +CXXFLAGS += -Wall -Wcast-qual -fno-exceptions -std=c++17 $(ADDITIONAL_INCLUDE_DIRECTORIES) $(EXTRACXXFLAGS) +DEPENDFLAGS += -std=c++17 $(ADDITIONAL_INCLUDE_DIRECTORIES) LDFLAGS += $(EXTRALDFLAGS) ifeq ($(COMP),) @@ -326,7 +338,7 @@ endif ifeq ($(COMP),gcc) comp=gcc CXX=g++ - CXXFLAGS += -pedantic -Wextra -Wshadow + CXXFLAGS += -pedantic -Wextra -Wshadow -lstdc++fs ifeq ($(arch),$(filter $(arch),armv7 armv8)) ifeq ($(OS),Android) @@ -471,13 +483,13 @@ ifneq ($(comp),mingw) # Haiku has pthreads in its libroot, so only link it in on other platforms ifneq ($(KERNEL),Haiku) ifneq ($(COMP),ndk) - LDFLAGS += -lpthread - endif + LDFLAGS += -lpthread endif endif endif +endif -### 3.2.1 Debugging +### 3.2.2 Debugging ifeq ($(debug),no) CXXFLAGS += -DNDEBUG else @@ -601,7 +613,7 @@ ifeq ($(neon),yes) CXXFLAGS += -mfpu=neon endif endif - endif +endif endif ### 3.7 pext @@ -752,6 +764,7 @@ profile-build: net config-sanity objclean profileclean @echo "" @echo "Step 2/4. Running benchmark for pgo-build ..." $(PGOBENCH) > /dev/null + $(PGOGENSFEN) > /dev/null @echo "" @echo "Step 3/4. Building optimized executable ..." $(MAKE) ARCH=$(ARCH) COMP=$(COMP) objclean @@ -798,12 +811,12 @@ net: # clean binaries and objects objclean: - @rm -f $(EXE) *.o ./syzygy/*.o ./nnue/*.o ./nnue/features/*.o + @rm -f $(EXE) *.o ./syzygy/*.o ./nnue/*.o ./nnue/features/*.o ./tools/*.o ./extra/*.o ./eval/*.o # clean auxiliary profiling files profileclean: @rm -rf profdir - @rm -f bench.txt *.gcda *.gcno ./syzygy/*.gcda ./nnue/*.gcda ./nnue/features/*.gcda *.s + @rm -f bench.txt *.gcda *.gcno ./syzygy/*.gcda ./nnue/*.gcda ./nnue/features/*.gcda *.s ./tools/*.gcda ./extra/*.gcda ./eval/*.gcda @rm -f stockfish.profdata *.profraw @rm -f stockfish.exe.lto_wrapper_args @rm -f stockfish.exe.ltrans.out @@ -913,6 +926,6 @@ icc-profile-use: all .depend: - -@$(CXX) $(DEPENDFLAGS) -MM $(SRCS) > $@ 2> /dev/null + -@$(CXX) $(DEPENDFLAGS) -MM $(SRCS) > $@ -include .depend diff --git a/src/evaluate.cpp b/src/evaluate.cpp index 64f91725..d44fb225 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -27,6 +27,8 @@ #include #include +#include "nnue/evaluate_nnue.h" + #include "bitboard.h" #include "evaluate.h" #include "material.h" @@ -37,7 +39,6 @@ #include "uci.h" #include "incbin/incbin.h" - // Macro to embed the default efficiently updatable neural network (NNUE) file // data in the engine binary (using incbin.h, by Dale Weiler). // This macro invocation will declare the following three variables @@ -53,15 +54,28 @@ const unsigned int gEmbeddedNNUESize = 1; #endif - using namespace std; namespace Stockfish { namespace Eval { - bool useNNUE; - string eval_file_loaded = "None"; + namespace NNUE { + string eval_file_loaded = "None"; + UseNNUEMode useNNUE; + + static UseNNUEMode nnue_mode_from_option(const UCI::Option& mode) + { + if (mode == "false") + return UseNNUEMode::False; + else if (mode == "true") + return UseNNUEMode::True; + else if (mode == "pure") + return UseNNUEMode::Pure; + + return UseNNUEMode::False; + } + } /// NNUE::init() tries to load a NNUE network at startup time, or when the engine /// receives a UCI command "setoption name EvalFile value nn-[a-z0-9]{12}.nnue" @@ -73,8 +87,8 @@ namespace Eval { void NNUE::init() { - useNNUE = Options["Use NNUE"]; - if (!useNNUE) + useNNUE = nnue_mode_from_option(Options["Use NNUE"]); + if (useNNUE == UseNNUEMode::False) return; string eval_file = string(Options["EvalFile"]); @@ -119,7 +133,7 @@ namespace Eval { string eval_file = string(Options["EvalFile"]); - if (useNNUE && eval_file_loaded != eval_file) + if (useNNUE != UseNNUEMode::False && eval_file_loaded != eval_file) { UCI::OptionsMap defaults; UCI::init(defaults); @@ -139,7 +153,7 @@ namespace Eval { exit(EXIT_FAILURE); } - if (useNNUE) + if (useNNUE != UseNNUEMode::False) sync_cout << "info string NNUE evaluation using " << eval_file << " enabled" << sync_endl; else sync_cout << "info string classical evaluation enabled" << sync_endl; @@ -1081,9 +1095,19 @@ make_v: Value Eval::evaluate(const Position& pos) { + pos.this_thread()->on_eval(); + Value v; - if (!Eval::useNNUE) + if (NNUE::useNNUE == NNUE::UseNNUEMode::Pure) { + v = NNUE::evaluate(pos); + + // Guarantee evaluation does not hit the tablebase range + v = std::clamp(v, VALUE_TB_LOSS_IN_MAX_PLY + 1, VALUE_TB_WIN_IN_MAX_PLY - 1); + + return v; + } + else if (NNUE::useNNUE == NNUE::UseNNUEMode::False) v = Evaluation(pos).value(); else { @@ -1172,7 +1196,7 @@ std::string Eval::trace(Position& pos) { v = pos.side_to_move() == WHITE ? v : -v; ss << "\nClassical evaluation " << to_cp(v) << " (white side)\n"; - if (Eval::useNNUE) + if (NNUE::useNNUE != NNUE::UseNNUEMode::False) { v = NNUE::evaluate(pos, false); v = pos.side_to_move() == WHITE ? v : -v; diff --git a/src/evaluate.h b/src/evaluate.h index 1ac224b6..ecf1e9a2 100644 --- a/src/evaluate.h +++ b/src/evaluate.h @@ -33,15 +33,21 @@ namespace Eval { std::string trace(Position& pos); Value evaluate(const Position& pos); - extern bool useNNUE; - extern std::string eval_file_loaded; - // The default net name MUST follow the format nn-[SHA256 first 12 digits].nnue // for the build process (profile-build and fishtest) to work. Do not change the // name of the macro, as it is used in the Makefile. #define EvalFileDefaultName "nn-46832cfbead3.nnue" namespace NNUE { + enum struct UseNNUEMode + { + False, + True, + Pure + }; + + extern UseNNUEMode useNNUE; + extern std::string eval_file_loaded; std::string trace(Position& pos); Value evaluate(const Position& pos, bool adjusted = false); diff --git a/src/extra/nnue_data_binpack_format.h b/src/extra/nnue_data_binpack_format.h new file mode 100644 index 00000000..0b1ac0aa --- /dev/null +++ b/src/extra/nnue_data_binpack_format.h @@ -0,0 +1,8000 @@ +/* + Stockfish, a UCI chess playing engine derived from Glaurung 2.1 + Copyright (C) 2004-2021 The Stockfish developers (see AUTHORS file) + + Stockfish is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Stockfish is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if (defined(_MSC_VER) || defined(__INTEL_COMPILER)) && !defined(__clang__) +#include +#endif + +namespace chess +{ + #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) + + #define FORCEINLINE __attribute__((always_inline)) + + #elif defined(_MSC_VER) + + // NOTE: for some reason it breaks the profiler a little + // keep it on only when not profiling. + //#define FORCEINLINE __forceinline + #define FORCEINLINE + + #else + + #define FORCEINLINE inline + + #endif + + #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) + + #define NOINLINE __attribute__((noinline)) + + #elif defined(_MSC_VER) + + #define NOINLINE __declspec(noinline) + + #else + + #define NOINLINE + + #endif + + namespace intrin + { + [[nodiscard]] constexpr int popcount_constexpr(std::uint64_t value) + { + int r = 0; + while (value) + { + value &= value - 1; + ++r; + } + return r; + } + + [[nodiscard]] constexpr int lsb_constexpr(std::uint64_t value) + { + int c = 0; + value &= ~value + 1; // leave only the lsb + if ((value & 0x00000000FFFFFFFFull) == 0) c += 32; + if ((value & 0x0000FFFF0000FFFFull) == 0) c += 16; + if ((value & 0x00FF00FF00FF00FFull) == 0) c += 8; + if ((value & 0x0F0F0F0F0F0F0F0Full) == 0) c += 4; + if ((value & 0x3333333333333333ull) == 0) c += 2; + if ((value & 0x5555555555555555ull) == 0) c += 1; + return c; + } + + [[nodiscard]] constexpr int msb_constexpr(std::uint64_t value) + { + int c = 63; + if ((value & 0xFFFFFFFF00000000ull) == 0) { c -= 32; value <<= 32; } + if ((value & 0xFFFF000000000000ull) == 0) { c -= 16; value <<= 16; } + if ((value & 0xFF00000000000000ull) == 0) { c -= 8; value <<= 8; } + if ((value & 0xF000000000000000ull) == 0) { c -= 4; value <<= 4; } + if ((value & 0xC000000000000000ull) == 0) { c -= 2; value <<= 2; } + if ((value & 0x8000000000000000ull) == 0) { c -= 1; } + return c; + } + } + + namespace intrin + { + [[nodiscard]] inline int popcount(std::uint64_t b) + { + #if (defined(_MSC_VER) || defined(__INTEL_COMPILER)) && !defined(__clang__) + + return static_cast(_mm_popcnt_u64(b)); + + #else + + return static_cast(__builtin_popcountll(b)); + + #endif + } + + #if defined(_MSC_VER) && !defined(__clang__) + + [[nodiscard]] inline int lsb(std::uint64_t value) + { + assert(value != 0); + + unsigned long idx; + _BitScanForward64(&idx, value); + return static_cast(idx); + } + + [[nodiscard]] inline int msb(std::uint64_t value) + { + assert(value != 0); + + unsigned long idx; + _BitScanReverse64(&idx, value); + return static_cast(idx); + } + + #else + + [[nodiscard]] inline int lsb(std::uint64_t value) + { + assert(value != 0); + + return __builtin_ctzll(value); + } + + [[nodiscard]] inline int msb(std::uint64_t value) + { + assert(value != 0); + + return 63 ^ __builtin_clzll(value); + } + + #endif + } + + template + [[nodiscard]] constexpr IntT floorLog2(IntT value) + { + return intrin::msb_constexpr(value); + } + + template + constexpr auto computeMasks() + { + static_assert(std::is_unsigned_v); + + constexpr std::size_t numBits = sizeof(IntT) * CHAR_BIT; + std::array nbitmasks{}; + + for (std::size_t i = 0; i < numBits; ++i) + { + nbitmasks[i] = (static_cast(1u) << i) - 1u; + } + nbitmasks[numBits] = ~static_cast(0u); + + return nbitmasks; + } + + template + constexpr auto nbitmask = computeMasks(); + + template > + inline ToT signExtend(FromT value) + { + static_assert(std::is_signed_v); + static_assert(std::is_unsigned_v); + static_assert(sizeof(ToT) == sizeof(FromT)); + + constexpr std::size_t totalBits = sizeof(FromT) * CHAR_BIT; + + static_assert(N > 0 && N <= totalBits); + + constexpr std::size_t unusedBits = totalBits - N; + if constexpr (ToT(~FromT(0)) >> 1 == ToT(~FromT(0))) + { + return ToT(value << unusedBits) >> ToT(unusedBits); + } + else + { + constexpr FromT mask = (~FromT(0)) >> unusedBits; + value &= mask; + if (value & (FromT(1) << (N - 1))) + { + value |= ~mask; + } + return static_cast(value); + } + } + + namespace lookup + { + constexpr int nthSetBitIndexNaive(std::uint64_t value, int n) + { + for (int i = 0; i < n; ++i) + { + value &= value - 1; + } + return intrin::lsb_constexpr(value); + } + + constexpr std::array, 256> nthSetBitIndex = []() + { + std::array, 256> t{}; + + for (int i = 0; i < 256; ++i) + { + for (int j = 0; j < 8; ++j) + { + t[i][j] = nthSetBitIndexNaive(i, j); + } + } + + return t; + }(); + } + + inline int nthSetBitIndex(std::uint64_t v, std::uint64_t n) + { + std::uint64_t shift = 0; + + std::uint64_t p = intrin::popcount(v & 0xFFFFFFFFull); + std::uint64_t pmask = static_cast(p > n) - 1ull; + v >>= 32 & pmask; + shift += 32 & pmask; + n -= p & pmask; + + p = intrin::popcount(v & 0xFFFFull); + pmask = static_cast(p > n) - 1ull; + v >>= 16 & pmask; + shift += 16 & pmask; + n -= p & pmask; + + p = intrin::popcount(v & 0xFFull); + pmask = static_cast(p > n) - 1ull; + shift += 8 & pmask; + v >>= 8 & pmask; + n -= p & pmask; + + return static_cast(lookup::nthSetBitIndex[v & 0xFFull][n] + shift); + } + + namespace util + { + inline std::size_t usedBits(std::size_t value) + { + if (value == 0) return 0; + return intrin::msb(value) + 1; + } + } + + template + struct EnumTraits; + + template + [[nodiscard]] constexpr auto hasEnumTraits() -> decltype(EnumTraits::cardinaliy, bool{}) + { + return true; + } + + template + [[nodiscard]] constexpr bool hasEnumTraits(...) + { + return false; + } + + template + [[nodiscard]] constexpr bool isNaturalIndex() noexcept + { + return EnumTraits::isNaturalIndex; + } + + template + [[nodiscard]] constexpr int cardinality() noexcept + { + return EnumTraits::cardinality; + } + + template + [[nodiscard]] constexpr const std::array()>& values() noexcept + { + return EnumTraits::values; + } + + template + [[nodiscard]] constexpr EnumT fromOrdinal(int id) noexcept + { + assert(!EnumTraits::isNaturalIndex || (id >= 0 && id < EnumTraits::cardinality)); + + return EnumTraits::fromOrdinal(id); + } + + template + [[nodiscard]] constexpr typename EnumTraits::IdType ordinal(EnumT v) noexcept + { + return EnumTraits::ordinal(v); + } + + template ()>> + [[nodiscard]] constexpr decltype(auto) toString(EnumT v, ArgsTs&&... args) + { + return EnumTraits::toString(v, std::forward(args)...); + } + + template + [[nodiscard]] constexpr decltype(auto) toString(EnumT v) + { + return EnumTraits::toString(v); + } + + template ()>> + [[nodiscard]] constexpr decltype(auto) toString(FormatT&& f, EnumT v) + { + return EnumTraits::toString(std::forward(f), v); + } + + template + [[nodiscard]] constexpr decltype(auto) toChar(EnumT v) + { + return EnumTraits::toChar(v); + } + + template + [[nodiscard]] constexpr decltype(auto) toChar(FormatT&& f, EnumT v) + { + return EnumTraits::toChar(std::forward(f), v); + } + + template + [[nodiscard]] constexpr decltype(auto) fromString(ArgsTs&& ... args) + { + return EnumTraits::fromString(std::forward(args)...); + } + + template + [[nodiscard]] constexpr decltype(auto) fromChar(ArgsTs&& ... args) + { + return EnumTraits::fromChar(std::forward(args)...); + } + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = bool; + + static constexpr int cardinality = 2; + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{ + false, + true + }; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + return static_cast(id); + } + }; + + template ()> + struct EnumArray + { + static_assert(isNaturalIndex(), "Enum must start with 0 and end with cardinality-1."); + + using value_type = ValueT; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = ValueT *; + using const_pointer = const ValueT*; + using reference = ValueT &; + using const_reference = const ValueT &; + + using iterator = pointer; + using const_iterator = const_pointer; + + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + using KeyType = EnumT; + using ValueType = ValueT; + + constexpr void fill(const ValueType& init) + { + for (auto& v : elements) + { + v = init; + } + } + + [[nodiscard]] constexpr ValueType& operator[](const KeyType& dir) + { + assert(static_cast(ordinal(dir)) < static_cast(SizeV)); + + return elements[ordinal(dir)]; + } + + [[nodiscard]] constexpr const ValueType& operator[](const KeyType& dir) const + { + assert(static_cast(ordinal(dir)) < static_cast(SizeV)); + + return elements[ordinal(dir)]; + } + + [[nodiscard]] constexpr ValueType& front() + { + return elements[0]; + } + + [[nodiscard]] constexpr const ValueType& front() const + { + return elements[0]; + } + + [[nodiscard]] constexpr ValueType& back() + { + return elements[SizeV - 1]; + } + + [[nodiscard]] constexpr const ValueType& back() const + { + return elements[SizeV - 1]; + } + + [[nodiscard]] constexpr pointer data() + { + return elements; + } + + [[nodiscard]] constexpr const_pointer data() const + { + return elements; + } + + [[nodiscard]] constexpr iterator begin() noexcept + { + return elements; + } + + [[nodiscard]] constexpr const_iterator begin() const noexcept + { + return elements; + } + + [[nodiscard]] constexpr iterator end() noexcept + { + return elements + SizeV; + } + + [[nodiscard]] constexpr const_iterator end() const noexcept + { + return elements + SizeV; + } + + [[nodiscard]] constexpr reverse_iterator rbegin() noexcept + { + return reverse_iterator(end()); + } + + [[nodiscard]] constexpr const_reverse_iterator rbegin() const noexcept + { + return const_reverse_iterator(end()); + } + + [[nodiscard]] constexpr reverse_iterator rend() noexcept + { + return reverse_iterator(begin()); + } + + [[nodiscard]] constexpr const_reverse_iterator rend() const noexcept + { + return const_reverse_iterator(begin()); + } + + [[nodiscard]] constexpr const_iterator cbegin() const noexcept + { + return begin(); + } + + [[nodiscard]] constexpr const_iterator cend() const noexcept + { + return end(); + } + + [[nodiscard]] constexpr const_reverse_iterator crbegin() const noexcept + { + return rbegin(); + } + + [[nodiscard]] constexpr const_reverse_iterator crend() const noexcept + { + return rend(); + } + + [[nodiscard]] constexpr size_type size() const noexcept + { + return SizeV; + } + + ValueT elements[SizeV]; + }; + + template (), std::size_t Size2V = cardinality()> + using EnumArray2 = EnumArray, Size1V>; + + enum struct Color : std::uint8_t + { + White, + Black + }; + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = Color; + + static constexpr int cardinality = 2; + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{ + Color::White, + Color::Black + }; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + assert(id >= 0 && id < cardinality); + + return static_cast(id); + } + + [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept + { + return std::string_view("wb" + ordinal(c), 1); + } + + [[nodiscard]] static constexpr char toChar(EnumType c) noexcept + { + return "wb"[ordinal(c)]; + } + + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept + { + if (c == 'w') return Color::White; + if (c == 'b') return Color::Black; + + return {}; + } + + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept + { + if (sv.size() != 1) return {}; + + return fromChar(sv[0]); + } + }; + + constexpr Color operator!(Color c) + { + return fromOrdinal(ordinal(c) ^ 1); + } + + enum struct PieceType : std::uint8_t + { + Pawn, + Knight, + Bishop, + Rook, + Queen, + King, + + None + }; + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = PieceType; + + static constexpr int cardinality = 7; + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{ + PieceType::Pawn, + PieceType::Knight, + PieceType::Bishop, + PieceType::Rook, + PieceType::Queen, + PieceType::King, + PieceType::None + }; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + assert(id >= 0 && id < cardinality); + + return static_cast(id); + } + + [[nodiscard]] static constexpr std::string_view toString(EnumType p, Color c) noexcept + { + return std::string_view("PpNnBbRrQqKk " + (chess::ordinal(p) * 2 + chess::ordinal(c)), 1); + } + + [[nodiscard]] static constexpr char toChar(EnumType p, Color c) noexcept + { + return "PpNnBbRrQqKk "[chess::ordinal(p) * 2 + chess::ordinal(c)]; + } + + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept + { + auto it = std::string_view("PpNnBbRrQqKk ").find(c); + if (it == std::string::npos) return {}; + else return static_cast(it/2); + } + + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept + { + if (sv.size() != 1) return {}; + + return fromChar(sv[0]); + } + }; + + struct Piece + { + [[nodiscard]] static constexpr Piece fromId(int id) + { + return Piece(id); + } + + [[nodiscard]] static constexpr Piece none() + { + return Piece(PieceType::None, Color::White); + } + + constexpr Piece() noexcept : + Piece(PieceType::None, Color::White) + { + + } + + constexpr Piece(PieceType type, Color color) noexcept : + m_id((ordinal(type) << 1) | ordinal(color)) + { + assert(type != PieceType::None || color == Color::White); + } + + constexpr Piece& operator=(const Piece& other) = default; + + [[nodiscard]] constexpr friend bool operator==(Piece lhs, Piece rhs) noexcept + { + return lhs.m_id == rhs.m_id; + } + + [[nodiscard]] constexpr friend bool operator!=(Piece lhs, Piece rhs) noexcept + { + return !(lhs == rhs); + } + + [[nodiscard]] constexpr PieceType type() const + { + return fromOrdinal(m_id >> 1); + } + + [[nodiscard]] constexpr Color color() const + { + return fromOrdinal(m_id & 1); + } + + [[nodiscard]] constexpr std::pair parts() const + { + return std::make_pair(type(), color()); + } + + [[nodiscard]] constexpr explicit operator int() const + { + return static_cast(m_id); + } + + private: + constexpr Piece(int id) : + m_id(id) + { + } + + std::uint8_t m_id; // lowest bit is a color, 7 highest bits are a piece type + }; + + [[nodiscard]] constexpr Piece operator|(PieceType type, Color color) noexcept + { + return Piece(type, color); + } + + [[nodiscard]] constexpr Piece operator|(Color color, PieceType type) noexcept + { + return Piece(type, color); + } + + constexpr Piece whitePawn = Piece(PieceType::Pawn, Color::White); + constexpr Piece whiteKnight = Piece(PieceType::Knight, Color::White); + constexpr Piece whiteBishop = Piece(PieceType::Bishop, Color::White); + constexpr Piece whiteRook = Piece(PieceType::Rook, Color::White); + constexpr Piece whiteQueen = Piece(PieceType::Queen, Color::White); + constexpr Piece whiteKing = Piece(PieceType::King, Color::White); + + constexpr Piece blackPawn = Piece(PieceType::Pawn, Color::Black); + constexpr Piece blackKnight = Piece(PieceType::Knight, Color::Black); + constexpr Piece blackBishop = Piece(PieceType::Bishop, Color::Black); + constexpr Piece blackRook = Piece(PieceType::Rook, Color::Black); + constexpr Piece blackQueen = Piece(PieceType::Queen, Color::Black); + constexpr Piece blackKing = Piece(PieceType::King, Color::Black); + + static_assert(Piece::none().type() == PieceType::None); + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = Piece; + + static constexpr int cardinality = 13; + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{ + whitePawn, + blackPawn, + whiteKnight, + blackKnight, + whiteBishop, + blackBishop, + whiteRook, + blackRook, + whiteQueen, + blackQueen, + whiteKing, + blackKing, + Piece::none() + }; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(int id) noexcept + { + assert(id >= 0 && id < cardinality); + + return Piece::fromId(id); + } + + [[nodiscard]] static constexpr std::string_view toString(EnumType p) noexcept + { + return std::string_view("PpNnBbRrQqKk " + ordinal(p), 1); + } + + [[nodiscard]] static constexpr char toChar(EnumType p) noexcept + { + return "PpNnBbRrQqKk "[ordinal(p)]; + } + + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept + { + auto it = std::string_view("PpNnBbRrQqKk ").find(c); + if (it == std::string::npos) return {}; + else return Piece::fromId(static_cast(it)); + } + + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept + { + if (sv.size() != 1) return {}; + + return fromChar(sv[0]); + } + }; + + template + struct Coord + { + constexpr Coord() noexcept : + m_i(0) + { + } + + constexpr explicit Coord(int i) noexcept : + m_i(i) + { + } + + [[nodiscard]] constexpr explicit operator int() const + { + return static_cast(m_i); + } + + constexpr friend Coord& operator++(Coord& c) + { + ++c.m_i; + return c; + } + + constexpr friend Coord& operator--(Coord& c) + { + --c.m_i; + return c; + } + + constexpr friend Coord& operator+=(Coord& c, int d) + { + c.m_i += d; + return c; + } + + constexpr friend Coord& operator-=(Coord& c, int d) + { + c.m_i -= d; + return c; + } + + constexpr friend Coord operator+(const Coord& c, int d) + { + Coord cpy(c); + cpy += d; + return cpy; + } + + constexpr friend Coord operator-(const Coord& c, int d) + { + Coord cpy(c); + cpy -= d; + return cpy; + } + + constexpr friend int operator-(const Coord& c1, const Coord& c2) + { + return c1.m_i - c2.m_i; + } + + [[nodiscard]] constexpr friend bool operator==(const Coord& c1, const Coord& c2) noexcept + { + return c1.m_i == c2.m_i; + } + + [[nodiscard]] constexpr friend bool operator!=(const Coord& c1, const Coord& c2) noexcept + { + return c1.m_i != c2.m_i; + } + + [[nodiscard]] constexpr friend bool operator<(const Coord& c1, const Coord& c2) noexcept + { + return c1.m_i < c2.m_i; + } + + [[nodiscard]] constexpr friend bool operator<=(const Coord& c1, const Coord& c2) noexcept + { + return c1.m_i <= c2.m_i; + } + + [[nodiscard]] constexpr friend bool operator>(const Coord& c1, const Coord& c2) noexcept + { + return c1.m_i > c2.m_i; + } + + [[nodiscard]] constexpr friend bool operator>=(const Coord& c1, const Coord& c2) noexcept + { + return c1.m_i >= c2.m_i; + } + + private: + std::int8_t m_i; + }; + + struct FileTag; + struct RankTag; + using File = Coord; + using Rank = Coord; + + constexpr File fileA = File(0); + constexpr File fileB = File(1); + constexpr File fileC = File(2); + constexpr File fileD = File(3); + constexpr File fileE = File(4); + constexpr File fileF = File(5); + constexpr File fileG = File(6); + constexpr File fileH = File(7); + + constexpr Rank rank1 = Rank(0); + constexpr Rank rank2 = Rank(1); + constexpr Rank rank3 = Rank(2); + constexpr Rank rank4 = Rank(3); + constexpr Rank rank5 = Rank(4); + constexpr Rank rank6 = Rank(5); + constexpr Rank rank7 = Rank(6); + constexpr Rank rank8 = Rank(7); + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = File; + + static constexpr int cardinality = 8; + static constexpr bool isNaturalIndex = true; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + assert(id >= 0 && id < cardinality); + + return static_cast(id); + } + + [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept + { + assert(ordinal(c) >= 0 && ordinal(c) < 8); + + return std::string_view("abcdefgh" + ordinal(c), 1); + } + + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept + { + if (c < 'a' || c > 'h') return {}; + return static_cast(c - 'a'); + } + + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept + { + if (sv.size() != 1) return {}; + + return fromChar(sv[0]); + } + }; + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = Rank; + + static constexpr int cardinality = 8; + static constexpr bool isNaturalIndex = true; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + assert(id >= 0 && id < cardinality); + + return static_cast(id); + } + + [[nodiscard]] static constexpr std::string_view toString(EnumType c) noexcept + { + assert(ordinal(c) >= 0 && ordinal(c) < 8); + + return std::string_view("12345678" + ordinal(c), 1); + } + + [[nodiscard]] static constexpr std::optional fromChar(char c) noexcept + { + if (c < '1' || c > '8') return {}; + return static_cast(c - '1'); + } + + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept + { + if (sv.size() != 1) return {}; + + return fromChar(sv[0]); + } + }; + + // files east + // ranks north + struct FlatSquareOffset + { + std::int8_t value; + + constexpr FlatSquareOffset() noexcept : + value(0) + { + } + + constexpr FlatSquareOffset(int files, int ranks) noexcept : + value(files + ranks * cardinality()) + { + assert(files + ranks * cardinality() >= std::numeric_limits::min()); + assert(files + ranks * cardinality() <= std::numeric_limits::max()); + } + + constexpr FlatSquareOffset operator-() const noexcept + { + return FlatSquareOffset(-value); + } + + private: + constexpr FlatSquareOffset(int v) noexcept : + value(v) + { + } + }; + + struct Offset + { + std::int8_t files; + std::int8_t ranks; + + constexpr Offset() : + files(0), + ranks(0) + { + } + + constexpr Offset(int files_, int ranks_) : + files(files_), + ranks(ranks_) + { + } + + [[nodiscard]] constexpr FlatSquareOffset flat() const + { + return { files, ranks }; + } + + [[nodiscard]] constexpr Offset operator-() const + { + return { -files, -ranks }; + } + }; + + struct SquareCoords + { + File file; + Rank rank; + + constexpr SquareCoords() noexcept : + file{}, + rank{} + { + } + + constexpr SquareCoords(File f, Rank r) noexcept : + file(f), + rank(r) + { + } + + constexpr friend SquareCoords& operator+=(SquareCoords& c, Offset offset) + { + c.file += offset.files; + c.rank += offset.ranks; + return c; + } + + [[nodiscard]] constexpr friend SquareCoords operator+(const SquareCoords& c, Offset offset) + { + SquareCoords cpy(c); + cpy.file += offset.files; + cpy.rank += offset.ranks; + return cpy; + } + + [[nodiscard]] constexpr bool isOk() const + { + return file >= fileA && file <= fileH && rank >= rank1 && rank <= rank8; + } + }; + + struct Square + { + private: + static constexpr std::int8_t m_noneId = cardinality() * cardinality(); + + static constexpr std::uint8_t fileMask = 0b111; + static constexpr std::uint8_t rankMask = 0b111000; + static constexpr std::uint8_t rankShift = 3; + + public: + [[nodiscard]] static constexpr Square none() + { + return Square(m_noneId); + } + + constexpr Square() noexcept : + m_id(0) + { + } + + constexpr explicit Square(int idx) noexcept : + m_id(idx) + { + assert(isOk() || m_id == m_noneId); + } + + constexpr Square(File file, Rank rank) noexcept : + m_id(ordinal(file) + ordinal(rank) * cardinality()) + { + assert(isOk()); + } + + constexpr explicit Square(SquareCoords coords) noexcept : + Square(coords.file, coords.rank) + { + } + + [[nodiscard]] constexpr friend bool operator<(Square lhs, Square rhs) noexcept + { + return lhs.m_id < rhs.m_id; + } + + [[nodiscard]] constexpr friend bool operator>(Square lhs, Square rhs) noexcept + { + return lhs.m_id > rhs.m_id; + } + + [[nodiscard]] constexpr friend bool operator<=(Square lhs, Square rhs) noexcept + { + return lhs.m_id <= rhs.m_id; + } + + [[nodiscard]] constexpr friend bool operator>=(Square lhs, Square rhs) noexcept + { + return lhs.m_id >= rhs.m_id; + } + + [[nodiscard]] constexpr friend bool operator==(Square lhs, Square rhs) noexcept + { + return lhs.m_id == rhs.m_id; + } + + [[nodiscard]] constexpr friend bool operator!=(Square lhs, Square rhs) noexcept + { + return !(lhs == rhs); + } + + constexpr friend Square& operator++(Square& sq) + { + ++sq.m_id; + return sq; + } + + constexpr friend Square& operator--(Square& sq) + { + --sq.m_id; + return sq; + } + + [[nodiscard]] constexpr friend Square operator+(Square sq, FlatSquareOffset offset) + { + Square sqCpy = sq; + sqCpy += offset; + return sqCpy; + } + + constexpr friend Square& operator+=(Square& sq, FlatSquareOffset offset) + { + assert(sq.m_id + offset.value >= 0 && sq.m_id + offset.value < Square::m_noneId); + sq.m_id += offset.value; + return sq; + } + + [[nodiscard]] constexpr friend Square operator+(Square sq, Offset offset) + { + assert(sq.file() + offset.files >= fileA); + assert(sq.file() + offset.files <= fileH); + assert(sq.rank() + offset.ranks >= rank1); + assert(sq.rank() + offset.ranks <= rank8); + return operator+(sq, offset.flat()); + } + + constexpr friend Square& operator+=(Square& sq, Offset offset) + { + return operator+=(sq, offset.flat()); + } + + [[nodiscard]] constexpr explicit operator int() const + { + return m_id; + } + + [[nodiscard]] constexpr File file() const + { + assert(isOk()); + return File(static_cast(m_id) & fileMask); + } + + [[nodiscard]] constexpr Rank rank() const + { + assert(isOk()); + return Rank(static_cast(m_id) >> rankShift); + } + + [[nodiscard]] constexpr SquareCoords coords() const + { + return { file(), rank() }; + } + + [[nodiscard]] constexpr Color color() const + { + assert(isOk()); + return !fromOrdinal((ordinal(rank()) + ordinal(file())) & 1); + } + + constexpr void flipVertically() + { + m_id ^= rankMask; + } + + constexpr void flipHorizontally() + { + m_id ^= fileMask; + } + + constexpr Square flippedVertically() const + { + return Square(m_id ^ rankMask); + } + + constexpr Square flippedHorizontally() const + { + return Square(m_id ^ fileMask); + } + + [[nodiscard]] constexpr bool isOk() const + { + return m_id >= 0 && m_id < m_noneId; + } + + private: + std::int8_t m_id; + }; + + constexpr Square a1(fileA, rank1); + constexpr Square a2(fileA, rank2); + constexpr Square a3(fileA, rank3); + constexpr Square a4(fileA, rank4); + constexpr Square a5(fileA, rank5); + constexpr Square a6(fileA, rank6); + constexpr Square a7(fileA, rank7); + constexpr Square a8(fileA, rank8); + + constexpr Square b1(fileB, rank1); + constexpr Square b2(fileB, rank2); + constexpr Square b3(fileB, rank3); + constexpr Square b4(fileB, rank4); + constexpr Square b5(fileB, rank5); + constexpr Square b6(fileB, rank6); + constexpr Square b7(fileB, rank7); + constexpr Square b8(fileB, rank8); + + constexpr Square c1(fileC, rank1); + constexpr Square c2(fileC, rank2); + constexpr Square c3(fileC, rank3); + constexpr Square c4(fileC, rank4); + constexpr Square c5(fileC, rank5); + constexpr Square c6(fileC, rank6); + constexpr Square c7(fileC, rank7); + constexpr Square c8(fileC, rank8); + + constexpr Square d1(fileD, rank1); + constexpr Square d2(fileD, rank2); + constexpr Square d3(fileD, rank3); + constexpr Square d4(fileD, rank4); + constexpr Square d5(fileD, rank5); + constexpr Square d6(fileD, rank6); + constexpr Square d7(fileD, rank7); + constexpr Square d8(fileD, rank8); + + constexpr Square e1(fileE, rank1); + constexpr Square e2(fileE, rank2); + constexpr Square e3(fileE, rank3); + constexpr Square e4(fileE, rank4); + constexpr Square e5(fileE, rank5); + constexpr Square e6(fileE, rank6); + constexpr Square e7(fileE, rank7); + constexpr Square e8(fileE, rank8); + + constexpr Square f1(fileF, rank1); + constexpr Square f2(fileF, rank2); + constexpr Square f3(fileF, rank3); + constexpr Square f4(fileF, rank4); + constexpr Square f5(fileF, rank5); + constexpr Square f6(fileF, rank6); + constexpr Square f7(fileF, rank7); + constexpr Square f8(fileF, rank8); + + constexpr Square g1(fileG, rank1); + constexpr Square g2(fileG, rank2); + constexpr Square g3(fileG, rank3); + constexpr Square g4(fileG, rank4); + constexpr Square g5(fileG, rank5); + constexpr Square g6(fileG, rank6); + constexpr Square g7(fileG, rank7); + constexpr Square g8(fileG, rank8); + + constexpr Square h1(fileH, rank1); + constexpr Square h2(fileH, rank2); + constexpr Square h3(fileH, rank3); + constexpr Square h4(fileH, rank4); + constexpr Square h5(fileH, rank5); + constexpr Square h6(fileH, rank6); + constexpr Square h7(fileH, rank7); + constexpr Square h8(fileH, rank8); + + static_assert(e1.color() == Color::Black); + static_assert(e8.color() == Color::White); + + static_assert(e1.file() == fileE); + static_assert(e1.rank() == rank1); + + static_assert(e1.flippedHorizontally() == d1); + static_assert(e1.flippedVertically() == e8); + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = Square; + + static constexpr int cardinality = chess::cardinality() * chess::cardinality(); + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{ + a1, b1, c1, d1, e1, f1, g1, h1, + a2, b2, c2, d2, e2, f2, g2, h2, + a3, b3, c3, d3, e3, f3, g3, h3, + a4, b4, c4, d4, e4, f4, g4, h4, + a5, b5, c5, d5, e5, f5, g5, h5, + a6, b6, c6, d6, e6, f6, g6, h6, + a7, b7, c7, d7, e7, f7, g7, h7, + a8, b8, c8, d8, e8, f8, g8, h8 + }; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + assert(id >= 0 && id < cardinality + 1); + + return static_cast(id); + } + + [[nodiscard]] static constexpr std::string_view toString(Square sq) + { + assert(sq.isOk()); + + return + std::string_view( + "a1b1c1d1e1f1g1h1" + "a2b2c2d2e2f2g2h2" + "a3b3c3d3e3f3g3h3" + "a4b4c4d4e4f4g4h4" + "a5b5c5d5e5f5g5h5" + "a6b6c6d6e6f6g6h6" + "a7b7c7d7e7f7g7h7" + "a8b8c8d8e8f8g8h8" + + (ordinal(sq) * 2), + 2 + ); + } + + [[nodiscard]] static constexpr std::optional fromString(std::string_view sv) noexcept + { + if (sv.size() != 2) return {}; + + const char f = sv[0]; + const char r = sv[1]; + if (f < 'a' || f > 'h') return {}; + if (r < '1' || r > '8') return {}; + + return Square(static_cast(f - 'a'), static_cast(r - '1')); + } + }; + + static_assert(toString(d1) == std::string_view("d1")); + static_assert(values()[29] == f4); + + enum struct MoveType : std::uint8_t + { + Normal, + Promotion, + Castle, + EnPassant + }; + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = MoveType; + + static constexpr int cardinality = 4; + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{ + MoveType::Normal, + MoveType::Promotion, + MoveType::Castle, + MoveType::EnPassant + }; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + assert(id >= 0 && id < cardinality); + + return static_cast(id); + } + }; + + enum struct CastleType : std::uint8_t + { + Short, + Long + }; + + [[nodiscard]] constexpr CastleType operator!(CastleType ct) + { + return static_cast(static_cast(ct) ^ 1); + } + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = CastleType; + + static constexpr int cardinality = 2; + static constexpr bool isNaturalIndex = true; + + static constexpr std::array values{ + CastleType::Short, + CastleType::Long + }; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + assert(id >= 0 && id < cardinality); + + return static_cast(id); + } + }; + + struct CompressedMove; + + // castling is encoded as a king capturing rook + // ep is encoded as a normal pawn capture (move.to is empty on the board) + struct Move + { + Square from; + Square to; + MoveType type = MoveType::Normal; + Piece promotedPiece = Piece::none(); + + [[nodiscard]] constexpr friend bool operator==(const Move& lhs, const Move& rhs) noexcept + { + return lhs.from == rhs.from + && lhs.to == rhs.to + && lhs.type == rhs.type + && lhs.promotedPiece == rhs.promotedPiece; + } + + [[nodiscard]] constexpr friend bool operator!=(const Move& lhs, const Move& rhs) noexcept + { + return !(lhs == rhs); + } + + [[nodiscard]] constexpr CompressedMove compress() const noexcept; + + [[nodiscard]] constexpr static Move null() + { + return Move{ Square::none(), Square::none() }; + } + + [[nodiscard]] constexpr static Move castle(CastleType ct, Color c); + + [[nodiscard]] constexpr static Move normal(Square from, Square to) + { + return Move{ from, to, MoveType::Normal, Piece::none() }; + } + + [[nodiscard]] constexpr static Move enPassant(Square from, Square to) + { + return Move{ from, to, MoveType::EnPassant, Piece::none() }; + } + + [[nodiscard]] constexpr static Move promotion(Square from, Square to, Piece piece) + { + return Move{ from, to, MoveType::Promotion, piece }; + } + }; + + namespace detail::castle + { + constexpr EnumArray2 moves = { { + {{ { e1, h1, MoveType::Castle }, { e8, h8, MoveType::Castle } }}, + {{ { e1, a1, MoveType::Castle }, { e8, a8, MoveType::Castle } }} + } }; + } + + [[nodiscard]] constexpr Move Move::castle(CastleType ct, Color c) + { + return detail::castle::moves[ct][c]; + } + + static_assert(sizeof(Move) == 4); + + struct CompressedMove + { + private: + // from most significant bits + // 2 bits for move type + // 6 bits for from square + // 6 bits for to square + // 2 bits for promoted piece type + // 0 if not a promotion + static constexpr std::uint16_t squareMask = 0b111111u; + static constexpr std::uint16_t promotedPieceTypeMask = 0b11u; + static constexpr std::uint16_t moveTypeMask = 0b11u; + + public: + [[nodiscard]] constexpr static CompressedMove readFromBigEndian(const unsigned char* data) + { + CompressedMove move{}; + move.m_packed = (data[0] << 8) | data[1]; + return move; + } + + constexpr CompressedMove() noexcept : + m_packed(0) + { + } + + // move must be either valid or a null move + constexpr CompressedMove(Move move) noexcept : + m_packed(0) + { + // else null move + if (move.from != move.to) + { + assert(move.from != Square::none()); + assert(move.to != Square::none()); + + m_packed = + (static_cast(ordinal(move.type)) << (16 - 2)) + | (static_cast(ordinal(move.from)) << (16 - 2 - 6)) + | (static_cast(ordinal(move.to)) << (16 - 2 - 6 - 6)); + + if (move.type == MoveType::Promotion) + { + assert(move.promotedPiece != Piece::none()); + + m_packed |= ordinal(move.promotedPiece.type()) - ordinal(PieceType::Knight); + } + else + { + assert(move.promotedPiece == Piece::none()); + } + } + } + + void writeToBigEndian(unsigned char* data) const + { + *data++ = m_packed >> 8; + *data++ = m_packed & 0xFF; + } + + [[nodiscard]] constexpr std::uint16_t packed() const + { + return m_packed; + } + + [[nodiscard]] constexpr MoveType type() const + { + return fromOrdinal(m_packed >> (16 - 2)); + } + + [[nodiscard]] constexpr Square from() const + { + return fromOrdinal((m_packed >> (16 - 2 - 6)) & squareMask); + } + + [[nodiscard]] constexpr Square to() const + { + return fromOrdinal((m_packed >> (16 - 2 - 6 - 6)) & squareMask); + } + + [[nodiscard]] constexpr Piece promotedPiece() const + { + if (type() == MoveType::Promotion) + { + const Color color = + (to().rank() == rank1) + ? Color::Black + : Color::White; + + const PieceType pt = fromOrdinal((m_packed & promotedPieceTypeMask) + ordinal(PieceType::Knight)); + return color | pt; + } + else + { + return Piece::none(); + } + } + + [[nodiscard]] constexpr Move decompress() const noexcept + { + if (m_packed == 0) + { + return Move::null(); + } + else + { + const MoveType type = fromOrdinal(m_packed >> (16 - 2)); + const Square from = fromOrdinal((m_packed >> (16 - 2 - 6)) & squareMask); + const Square to = fromOrdinal((m_packed >> (16 - 2 - 6 - 6)) & squareMask); + const Piece promotedPiece = [&]() { + if (type == MoveType::Promotion) + { + const Color color = + (to.rank() == rank1) + ? Color::Black + : Color::White; + + const PieceType pt = fromOrdinal((m_packed & promotedPieceTypeMask) + ordinal(PieceType::Knight)); + return color | pt; + } + else + { + return Piece::none(); + } + }(); + + return Move{ from, to, type, promotedPiece }; + } + } + + private: + std::uint16_t m_packed; + }; + + static_assert(sizeof(CompressedMove) == 2); + + [[nodiscard]] constexpr CompressedMove Move::compress() const noexcept + { + return CompressedMove(*this); + } + + static_assert(a4 + Offset{ 0, 1 } == a5); + static_assert(a4 + Offset{ 0, 2 } == a6); + static_assert(a4 + Offset{ 0, -2 } == a2); + static_assert(a4 + Offset{ 0, -1 } == a3); + + static_assert(e4 + Offset{ 1, 0 } == f4); + static_assert(e4 + Offset{ 2, 0 } == g4); + static_assert(e4 + Offset{ -1, 0 } == d4); + static_assert(e4 + Offset{ -2, 0 } == c4); + + enum struct CastlingRights : std::uint8_t + { + None = 0x0, + WhiteKingSide = 0x1, + WhiteQueenSide = 0x2, + BlackKingSide = 0x4, + BlackQueenSide = 0x8, + White = WhiteKingSide | WhiteQueenSide, + Black = BlackKingSide | BlackQueenSide, + All = WhiteKingSide | WhiteQueenSide | BlackKingSide | BlackQueenSide + }; + + [[nodiscard]] constexpr CastlingRights operator|(CastlingRights lhs, CastlingRights rhs) + { + return static_cast(static_cast(lhs) | static_cast(rhs)); + } + + [[nodiscard]] constexpr CastlingRights operator&(CastlingRights lhs, CastlingRights rhs) + { + return static_cast(static_cast(lhs) & static_cast(rhs)); + } + + [[nodiscard]] constexpr CastlingRights operator~(CastlingRights lhs) + { + return static_cast(~static_cast(lhs) & static_cast(CastlingRights::All)); + } + + constexpr CastlingRights& operator|=(CastlingRights& lhs, CastlingRights rhs) + { + lhs = static_cast(static_cast(lhs) | static_cast(rhs)); + return lhs; + } + + constexpr CastlingRights& operator&=(CastlingRights& lhs, CastlingRights rhs) + { + lhs = static_cast(static_cast(lhs) & static_cast(rhs)); + return lhs; + } + // checks whether lhs contains rhs + [[nodiscard]] constexpr bool contains(CastlingRights lhs, CastlingRights rhs) + { + return (lhs & rhs) == rhs; + } + + template <> + struct EnumTraits + { + using IdType = int; + using EnumType = CastlingRights; + + static constexpr int cardinality = 4; + static constexpr bool isNaturalIndex = false; + + static constexpr std::array values{ + CastlingRights::WhiteKingSide, + CastlingRights::WhiteQueenSide, + CastlingRights::BlackKingSide, + CastlingRights::BlackQueenSide + }; + + [[nodiscard]] static constexpr int ordinal(EnumType c) noexcept + { + return static_cast(c); + } + + [[nodiscard]] static constexpr EnumType fromOrdinal(IdType id) noexcept + { + return static_cast(id); + } + }; + + struct CompressedReverseMove; + + struct ReverseMove + { + Move move; + Piece capturedPiece; + Square oldEpSquare; + CastlingRights oldCastlingRights; + + // We need a well defined case for the starting position. + constexpr ReverseMove() : + move(Move::null()), + capturedPiece(Piece::none()), + oldEpSquare(Square::none()), + oldCastlingRights(CastlingRights::All) + { + } + + constexpr ReverseMove(const Move& move_, Piece capturedPiece_, Square oldEpSquare_, CastlingRights oldCastlingRights_) : + move(move_), + capturedPiece(capturedPiece_), + oldEpSquare(oldEpSquare_), + oldCastlingRights(oldCastlingRights_) + { + } + + constexpr bool isNull() const + { + return move.from == move.to; + } + + [[nodiscard]] constexpr CompressedReverseMove compress() const noexcept; + + [[nodiscard]] constexpr friend bool operator==(const ReverseMove& lhs, const ReverseMove& rhs) noexcept + { + return lhs.move == rhs.move + && lhs.capturedPiece == rhs.capturedPiece + && lhs.oldEpSquare == rhs.oldEpSquare + && lhs.oldCastlingRights == rhs.oldCastlingRights; + } + + [[nodiscard]] constexpr friend bool operator!=(const ReverseMove& lhs, const ReverseMove& rhs) noexcept + { + return !(lhs == rhs); + } + }; + + static_assert(sizeof(ReverseMove) == 7); + + struct CompressedReverseMove + { + private: + // we use 7 bits because it can be Square::none() + static constexpr std::uint32_t squareMask = 0b1111111u; + static constexpr std::uint32_t pieceMask = 0b1111u; + static constexpr std::uint32_t castlingRightsMask = 0b1111; + public: + + constexpr CompressedReverseMove() noexcept : + m_move{}, + m_oldState{} + { + } + + constexpr CompressedReverseMove(const ReverseMove& rm) noexcept : + m_move(rm.move.compress()), + m_oldState{ static_cast( + ((ordinal(rm.capturedPiece) & pieceMask) << 11) + | ((ordinal(rm.oldCastlingRights) & castlingRightsMask) << 7) + | (ordinal(rm.oldEpSquare) & squareMask) + ) + } + { + } + + [[nodiscard]] constexpr Move move() const + { + return m_move.decompress(); + } + + [[nodiscard]] const CompressedMove& compressedMove() const + { + return m_move; + } + + [[nodiscard]] constexpr Piece capturedPiece() const + { + return fromOrdinal(m_oldState >> 11); + } + + [[nodiscard]] constexpr CastlingRights oldCastlingRights() const + { + return fromOrdinal((m_oldState >> 7) & castlingRightsMask); + } + + [[nodiscard]] constexpr Square oldEpSquare() const + { + return fromOrdinal(m_oldState & squareMask); + } + + [[nodiscard]] constexpr ReverseMove decompress() const noexcept + { + const Piece capturedPiece = fromOrdinal(m_oldState >> 11); + const CastlingRights castlingRights = fromOrdinal((m_oldState >> 7) & castlingRightsMask); + // We could pack the ep square more, but don't have to, because + // can't save another byte anyway. + const Square epSquare = fromOrdinal(m_oldState & squareMask); + + return ReverseMove(m_move.decompress(), capturedPiece, epSquare, castlingRights); + } + + private: + CompressedMove m_move; + std::uint16_t m_oldState; + }; + + static_assert(sizeof(CompressedReverseMove) == 4); + + [[nodiscard]] constexpr CompressedReverseMove ReverseMove::compress() const noexcept + { + return CompressedReverseMove(*this); + } + + // This can be regarded as a perfect hash. Going back is hard. + struct PackedReverseMove + { + static constexpr std::uint32_t mask = 0x7FFFFFFu; + static constexpr std::size_t numBits = 27; + + private: + static constexpr std::uint32_t squareMask = 0b111111u; + static constexpr std::uint32_t pieceMask = 0b1111u; + static constexpr std::uint32_t pieceTypeMask = 0b111u; + static constexpr std::uint32_t castlingRightsMask = 0b1111; + static constexpr std::uint32_t fileMask = 0b111; + + public: + constexpr PackedReverseMove(const std::uint32_t packed) : + m_packed(packed) + { + + } + + constexpr PackedReverseMove(const ReverseMove& reverseMove) : + m_packed( + 0u + // The only move when square is none() is null move and + // then both squares are none(). No other move is like that + // so we don't lose any information by storing only + // the 6 bits of each square. + | ((ordinal(reverseMove.move.from) & squareMask) << 21) + | ((ordinal(reverseMove.move.to) & squareMask) << 15) + // Other masks are just for code clarity, they should + // never change the values. + | ((ordinal(reverseMove.capturedPiece) & pieceMask) << 11) + | ((ordinal(reverseMove.oldCastlingRights) & castlingRightsMask) << 7) + | ((ordinal(reverseMove.move.promotedPiece.type()) & pieceTypeMask) << 4) + | (((reverseMove.oldEpSquare != Square::none()) & 1) << 3) + // We probably could omit the squareMask here but for clarity it's left. + | (ordinal(Square(ordinal(reverseMove.oldEpSquare) & squareMask).file()) & fileMask) + ) + { + } + + constexpr std::uint32_t packed() const + { + return m_packed; + } + + constexpr ReverseMove unpack(Color sideThatMoved) const + { + ReverseMove rmove{}; + + rmove.move.from = fromOrdinal((m_packed >> 21) & squareMask); + rmove.move.to = fromOrdinal((m_packed >> 15) & squareMask); + rmove.capturedPiece = fromOrdinal((m_packed >> 11) & pieceMask); + rmove.oldCastlingRights = fromOrdinal((m_packed >> 7) & castlingRightsMask); + const PieceType promotedPieceType = fromOrdinal((m_packed >> 4) & pieceTypeMask); + if (promotedPieceType != PieceType::None) + { + rmove.move.promotedPiece = Piece(promotedPieceType, sideThatMoved); + rmove.move.type = MoveType::Promotion; + } + const bool hasEpSquare = static_cast((m_packed >> 3) & 1); + if (hasEpSquare) + { + // ep square is always where the opponent moved + const Rank rank = + sideThatMoved == Color::White + ? rank6 + : rank3; + const File file = fromOrdinal(m_packed & fileMask); + rmove.oldEpSquare = Square(file, rank); + if (rmove.oldEpSquare == rmove.move.to) + { + rmove.move.type = MoveType::EnPassant; + } + } + else + { + rmove.oldEpSquare = Square::none(); + } + + if (rmove.move.type == MoveType::Normal && rmove.oldCastlingRights != CastlingRights::None) + { + // If castling was possible then we know it was the king that moved from e1/e8. + if (rmove.move.from == e1) + { + if (rmove.move.to == h1 || rmove.move.to == a1) + { + rmove.move.type = MoveType::Castle; + } + } + else if (rmove.move.from == e8) + { + if (rmove.move.to == h8 || rmove.move.to == a8) + { + rmove.move.type = MoveType::Castle; + } + } + } + + return rmove; + } + + private: + // Uses only 27 lowest bits. + // Bit meaning from highest to lowest. + // - 6 bits from + // - 6 bits to + // - 4 bits for the captured piece + // - 4 bits for prev castling rights + // - 3 bits promoted piece type + // - 1 bit to specify if the ep square was valid (false if none()) + // - 3 bits for prev ep square file + std::uint32_t m_packed; + }; + + struct MoveCompareLess + { + [[nodiscard]] bool operator()(const Move& lhs, const Move& rhs) const noexcept + { + if (ordinal(lhs.from) < ordinal(rhs.from)) return true; + if (ordinal(lhs.from) > ordinal(rhs.from)) return false; + + if (ordinal(lhs.to) < ordinal(rhs.to)) return true; + if (ordinal(lhs.to) > ordinal(rhs.to)) return false; + + if (ordinal(lhs.type) < ordinal(rhs.type)) return true; + if (ordinal(lhs.type) > ordinal(rhs.type)) return false; + + if (ordinal(lhs.promotedPiece) < ordinal(rhs.promotedPiece)) return true; + + return false; + } + }; + + struct ReverseMoveCompareLess + { + [[nodiscard]] bool operator()(const ReverseMove& lhs, const ReverseMove& rhs) const noexcept + { + if (MoveCompareLess{}(lhs.move, rhs.move)) return true; + if (MoveCompareLess{}(rhs.move, lhs.move)) return false; + + if (ordinal(lhs.capturedPiece) < ordinal(rhs.capturedPiece)) return true; + if (ordinal(lhs.capturedPiece) > ordinal(rhs.capturedPiece)) return false; + + if (static_cast(lhs.oldCastlingRights) < static_cast(rhs.oldCastlingRights)) return true; + if (static_cast(lhs.oldCastlingRights) > static_cast(rhs.oldCastlingRights)) return false; + + if (ordinal(lhs.oldEpSquare) < ordinal(rhs.oldEpSquare)) return true; + if (ordinal(lhs.oldEpSquare) > ordinal(rhs.oldEpSquare)) return false; + + return false; + } + }; + + struct BitboardIterator + { + using value_type = Square; + using difference_type = std::ptrdiff_t; + using reference = Square; + using iterator_category = std::input_iterator_tag; + using pointer = const Square*; + + constexpr BitboardIterator() noexcept : + m_squares(0) + { + } + + constexpr BitboardIterator(std::uint64_t v) noexcept : + m_squares(v) + { + } + + constexpr BitboardIterator(const BitboardIterator&) = default; + constexpr BitboardIterator(BitboardIterator&&) = default; + constexpr BitboardIterator& operator=(const BitboardIterator&) = default; + constexpr BitboardIterator& operator=(BitboardIterator&&) = default; + + [[nodiscard]] constexpr bool friend operator==(BitboardIterator lhs, BitboardIterator rhs) noexcept + { + return lhs.m_squares == rhs.m_squares; + } + + [[nodiscard]] constexpr bool friend operator!=(BitboardIterator lhs, BitboardIterator rhs) noexcept + { + return lhs.m_squares != rhs.m_squares; + } + + [[nodiscard]] inline Square operator*() const + { + return first(); + } + + constexpr BitboardIterator& operator++() noexcept + { + popFirst(); + return *this; + } + + private: + std::uint64_t m_squares; + + constexpr void popFirst() noexcept + { + m_squares &= m_squares - 1; + } + + [[nodiscard]] inline Square first() const + { + assert(m_squares != 0); + + return fromOrdinal(intrin::lsb(m_squares)); + } + }; + + struct Bitboard + { + // bits counted from the LSB + // order is A1 B2 ... G8 H8 + // just like in Square + + public: + constexpr Bitboard() noexcept : + m_squares(0) + { + } + + private: + constexpr explicit Bitboard(Square sq) noexcept : + m_squares(static_cast(1ULL) << ordinal(sq)) + { + assert(sq.isOk()); + } + + constexpr explicit Bitboard(Rank r) noexcept : + m_squares(static_cast(0xFFULL) << (ordinal(r) * 8)) + { + } + + constexpr explicit Bitboard(File f) noexcept : + m_squares(static_cast(0x0101010101010101ULL) << ordinal(f)) + { + } + + constexpr explicit Bitboard(Color c) noexcept : + m_squares(c == Color::White ? 0xAA55AA55AA55AA55ULL : ~0xAA55AA55AA55AA55ULL) + { + } + + constexpr explicit Bitboard(std::uint64_t bb) noexcept : + m_squares(bb) + { + } + + // files A..file inclusive + static constexpr EnumArray m_filesUpToBB{ + 0x0101010101010101ULL, + 0x0303030303030303ULL, + 0x0707070707070707ULL, + 0x0F0F0F0F0F0F0F0FULL, + 0x1F1F1F1F1F1F1F1FULL, + 0x3F3F3F3F3F3F3F3FULL, + 0x7F7F7F7F7F7F7F7FULL, + 0xFFFFFFFFFFFFFFFFULL + }; + + public: + + [[nodiscard]] static constexpr Bitboard none() + { + return Bitboard{}; + } + + [[nodiscard]] static constexpr Bitboard all() + { + return ~none(); + } + + [[nodiscard]] static constexpr Bitboard square(Square sq) + { + return Bitboard(sq); + } + + [[nodiscard]] static constexpr Bitboard file(File f) + { + return Bitboard(f); + } + + [[nodiscard]] static constexpr Bitboard rank(Rank r) + { + return Bitboard(r); + } + + [[nodiscard]] static constexpr Bitboard color(Color c) + { + return Bitboard(c); + } + + [[nodiscard]] static constexpr Bitboard fromBits(std::uint64_t bits) + { + return Bitboard(bits); + } + + // inclusive + [[nodiscard]] static constexpr Bitboard betweenFiles(File left, File right) + { + assert(left <= right); + + if (left == fileA) + { + return Bitboard::fromBits(m_filesUpToBB[right]); + } + else + { + return Bitboard::fromBits(m_filesUpToBB[right] ^ m_filesUpToBB[left - 1]); + } + } + + [[nodiscard]] constexpr bool isEmpty() const + { + return m_squares == 0; + } + + [[nodiscard]] constexpr bool isSet(Square sq) const + { + return !!((m_squares >> ordinal(sq)) & 1ull); + } + + constexpr void set(Square sq) + { + *this |= Bitboard(sq); + } + + constexpr void unset(Square sq) + { + *this &= ~(Bitboard(sq)); + } + + constexpr void toggle(Square sq) + { + *this ^= Bitboard(sq); + } + + [[nodiscard]] constexpr BitboardIterator begin() const + { + return BitboardIterator(m_squares); + } + + [[nodiscard]] constexpr BitboardIterator end() const + { + return BitboardIterator{}; + } + + [[nodiscard]] constexpr BitboardIterator cbegin() const + { + return BitboardIterator(m_squares); + } + + [[nodiscard]] constexpr BitboardIterator cend() const + { + return BitboardIterator{}; + } + + [[nodiscard]] constexpr bool friend operator==(Bitboard lhs, Bitboard rhs) noexcept + { + return lhs.m_squares == rhs.m_squares; + } + + [[nodiscard]] constexpr bool friend operator!=(Bitboard lhs, Bitboard rhs) noexcept + { + return lhs.m_squares != rhs.m_squares; + } + + constexpr Bitboard shiftedVertically(int ranks) const + { + if (ranks >= 0) + { + return fromBits(m_squares << 8 * ranks); + } + else + { + return fromBits(m_squares >> -8 * ranks); + } + } + + template + constexpr void shift() + { + static_assert(files >= -7); + static_assert(ranks >= -7); + static_assert(files <= 7); + static_assert(ranks <= 7); + + if constexpr (files != 0) + { + constexpr Bitboard mask = + files > 0 + ? Bitboard::betweenFiles(fileA, fileH - files) + : Bitboard::betweenFiles(fileA - files, fileH); + + m_squares &= mask.m_squares; + } + + constexpr int shift = files + ranks * 8; + if constexpr (shift == 0) + { + return; + } + + if constexpr (shift < 0) + { + m_squares >>= -shift; + } + else + { + m_squares <<= shift; + } + } + + template + constexpr Bitboard shifted() const + { + Bitboard bbCpy(*this); + bbCpy.shift(); + return bbCpy; + } + + constexpr void shift(Offset offset) + { + assert(offset.files >= -7); + assert(offset.ranks >= -7); + assert(offset.files <= 7); + assert(offset.ranks <= 7); + + if (offset.files != 0) + { + const Bitboard mask = + offset.files > 0 + ? Bitboard::betweenFiles(fileA, fileH - offset.files) + : Bitboard::betweenFiles(fileA - offset.files, fileH); + + m_squares &= mask.m_squares; + } + + const int shift = offset.files + offset.ranks * 8; + if (shift < 0) + { + m_squares >>= -shift; + } + else + { + m_squares <<= shift; + } + } + + [[nodiscard]] constexpr Bitboard shifted(Offset offset) const + { + Bitboard bbCpy(*this); + bbCpy.shift(offset); + return bbCpy; + } + + [[nodiscard]] constexpr Bitboard operator~() const + { + Bitboard bb = *this; + bb.m_squares = ~m_squares; + return bb; + } + + constexpr Bitboard& operator^=(Color c) + { + m_squares ^= Bitboard(c).m_squares; + return *this; + } + + constexpr Bitboard& operator&=(Color c) + { + m_squares &= Bitboard(c).m_squares; + return *this; + } + + constexpr Bitboard& operator|=(Color c) + { + m_squares |= Bitboard(c).m_squares; + return *this; + } + + [[nodiscard]] constexpr Bitboard operator^(Color c) const + { + Bitboard bb = *this; + bb ^= c; + return bb; + } + + [[nodiscard]] constexpr Bitboard operator&(Color c) const + { + Bitboard bb = *this; + bb &= c; + return bb; + } + + [[nodiscard]] constexpr Bitboard operator|(Color c) const + { + Bitboard bb = *this; + bb |= c; + return bb; + } + + constexpr Bitboard& operator^=(Square sq) + { + m_squares ^= Bitboard(sq).m_squares; + return *this; + } + + constexpr Bitboard& operator&=(Square sq) + { + m_squares &= Bitboard(sq).m_squares; + return *this; + } + + constexpr Bitboard& operator|=(Square sq) + { + m_squares |= Bitboard(sq).m_squares; + return *this; + } + + [[nodiscard]] constexpr Bitboard operator^(Square sq) const + { + Bitboard bb = *this; + bb ^= sq; + return bb; + } + + [[nodiscard]] constexpr Bitboard operator&(Square sq) const + { + Bitboard bb = *this; + bb &= sq; + return bb; + } + + [[nodiscard]] constexpr Bitboard operator|(Square sq) const + { + Bitboard bb = *this; + bb |= sq; + return bb; + } + + [[nodiscard]] constexpr friend Bitboard operator^(Square sq, Bitboard bb) + { + return bb ^ sq; + } + + [[nodiscard]] constexpr friend Bitboard operator&(Square sq, Bitboard bb) + { + return bb & sq; + } + + [[nodiscard]] constexpr friend Bitboard operator|(Square sq, Bitboard bb) + { + return bb | sq; + } + + constexpr Bitboard& operator^=(Bitboard rhs) + { + m_squares ^= rhs.m_squares; + return *this; + } + + constexpr Bitboard& operator&=(Bitboard rhs) + { + m_squares &= rhs.m_squares; + return *this; + } + + constexpr Bitboard& operator|=(Bitboard rhs) + { + m_squares |= rhs.m_squares; + return *this; + } + + [[nodiscard]] constexpr Bitboard operator^(Bitboard sq) const + { + Bitboard bb = *this; + bb ^= sq; + return bb; + } + + [[nodiscard]] constexpr Bitboard operator&(Bitboard sq) const + { + Bitboard bb = *this; + bb &= sq; + return bb; + } + + [[nodiscard]] constexpr Bitboard operator|(Bitboard sq) const + { + Bitboard bb = *this; + bb |= sq; + return bb; + } + + [[nodiscard]] inline int count() const + { + return static_cast(intrin::popcount(m_squares)); + } + + [[nodiscard]] constexpr bool moreThanOne() const + { + return !!(m_squares & (m_squares - 1)); + } + + [[nodiscard]] constexpr bool exactlyOne() const + { + return m_squares != 0 && !moreThanOne(); + } + + [[nodiscard]] constexpr bool any() const + { + return !!m_squares; + } + + [[nodiscard]] inline Square first() const + { + assert(m_squares != 0); + + return fromOrdinal(intrin::lsb(m_squares)); + } + + [[nodiscard]] inline Square nth(int n) const + { + assert(count() > n); + + Bitboard cpy = *this; + while (n--) cpy.popFirst(); + return cpy.first(); + } + + [[nodiscard]] inline Square last() const + { + assert(m_squares != 0); + + return fromOrdinal(intrin::msb(m_squares)); + } + + [[nodiscard]] constexpr std::uint64_t bits() const + { + return m_squares; + } + + constexpr void popFirst() + { + assert(m_squares != 0); + + m_squares &= m_squares - 1; + } + + constexpr Bitboard& operator=(const Bitboard& other) = default; + + private: + std::uint64_t m_squares; + }; + + [[nodiscard]] constexpr Bitboard operator^(Square sq0, Square sq1) + { + return Bitboard::square(sq0) ^ sq1; + } + + [[nodiscard]] constexpr Bitboard operator&(Square sq0, Square sq1) + { + return Bitboard::square(sq0) & sq1; + } + + [[nodiscard]] constexpr Bitboard operator|(Square sq0, Square sq1) + { + return Bitboard::square(sq0) | sq1; + } + + [[nodiscard]] constexpr Bitboard operator""_bb(unsigned long long bits) + { + return Bitboard::fromBits(bits); + } + + namespace bb + { + namespace fancy_magics + { + // Implementation based on https://github.com/syzygy1/Cfish + + alignas(64) constexpr EnumArray g_rookMagics{ { + 0x0A80004000801220ull, + 0x8040004010002008ull, + 0x2080200010008008ull, + 0x1100100008210004ull, + 0xC200209084020008ull, + 0x2100010004000208ull, + 0x0400081000822421ull, + 0x0200010422048844ull, + 0x0800800080400024ull, + 0x0001402000401000ull, + 0x3000801000802001ull, + 0x4400800800100083ull, + 0x0904802402480080ull, + 0x4040800400020080ull, + 0x0018808042000100ull, + 0x4040800080004100ull, + 0x0040048001458024ull, + 0x00A0004000205000ull, + 0x3100808010002000ull, + 0x4825010010000820ull, + 0x5004808008000401ull, + 0x2024818004000A00ull, + 0x0005808002000100ull, + 0x2100060004806104ull, + 0x0080400880008421ull, + 0x4062220600410280ull, + 0x010A004A00108022ull, + 0x0000100080080080ull, + 0x0021000500080010ull, + 0x0044000202001008ull, + 0x0000100400080102ull, + 0xC020128200040545ull, + 0x0080002000400040ull, + 0x0000804000802004ull, + 0x0000120022004080ull, + 0x010A386103001001ull, + 0x9010080080800400ull, + 0x8440020080800400ull, + 0x0004228824001001ull, + 0x000000490A000084ull, + 0x0080002000504000ull, + 0x200020005000C000ull, + 0x0012088020420010ull, + 0x0010010080080800ull, + 0x0085001008010004ull, + 0x0002000204008080ull, + 0x0040413002040008ull, + 0x0000304081020004ull, + 0x0080204000800080ull, + 0x3008804000290100ull, + 0x1010100080200080ull, + 0x2008100208028080ull, + 0x5000850800910100ull, + 0x8402019004680200ull, + 0x0120911028020400ull, + 0x0000008044010200ull, + 0x0020850200244012ull, + 0x0020850200244012ull, + 0x0000102001040841ull, + 0x140900040A100021ull, + 0x000200282410A102ull, + 0x000200282410A102ull, + 0x000200282410A102ull, + 0x4048240043802106ull + } }; + + alignas(64) constexpr EnumArray g_bishopMagics{ { + 0x40106000A1160020ull, + 0x0020010250810120ull, + 0x2010010220280081ull, + 0x002806004050C040ull, + 0x0002021018000000ull, + 0x2001112010000400ull, + 0x0881010120218080ull, + 0x1030820110010500ull, + 0x0000120222042400ull, + 0x2000020404040044ull, + 0x8000480094208000ull, + 0x0003422A02000001ull, + 0x000A220210100040ull, + 0x8004820202226000ull, + 0x0018234854100800ull, + 0x0100004042101040ull, + 0x0004001004082820ull, + 0x0010000810010048ull, + 0x1014004208081300ull, + 0x2080818802044202ull, + 0x0040880C00A00100ull, + 0x0080400200522010ull, + 0x0001000188180B04ull, + 0x0080249202020204ull, + 0x1004400004100410ull, + 0x00013100A0022206ull, + 0x2148500001040080ull, + 0x4241080011004300ull, + 0x4020848004002000ull, + 0x10101380D1004100ull, + 0x0008004422020284ull, + 0x01010A1041008080ull, + 0x0808080400082121ull, + 0x0808080400082121ull, + 0x0091128200100C00ull, + 0x0202200802010104ull, + 0x8C0A020200440085ull, + 0x01A0008080B10040ull, + 0x0889520080122800ull, + 0x100902022202010Aull, + 0x04081A0816002000ull, + 0x0000681208005000ull, + 0x8170840041008802ull, + 0x0A00004200810805ull, + 0x0830404408210100ull, + 0x2602208106006102ull, + 0x1048300680802628ull, + 0x2602208106006102ull, + 0x0602010120110040ull, + 0x0941010801043000ull, + 0x000040440A210428ull, + 0x0008240020880021ull, + 0x0400002012048200ull, + 0x00AC102001210220ull, + 0x0220021002009900ull, + 0x84440C080A013080ull, + 0x0001008044200440ull, + 0x0004C04410841000ull, + 0x2000500104011130ull, + 0x1A0C010011C20229ull, + 0x0044800112202200ull, + 0x0434804908100424ull, + 0x0300404822C08200ull, + 0x48081010008A2A80ull + } }; + + alignas(64) static EnumArray g_rookMasks; + alignas(64) static EnumArray g_rookShifts; + alignas(64) static EnumArray g_rookAttacks; + + alignas(64) static EnumArray g_bishopMasks; + alignas(64) static EnumArray g_bishopShifts; + alignas(64) static EnumArray g_bishopAttacks; + + alignas(64) static std::array g_allRookAttacks; + alignas(64) static std::array g_allBishopAttacks; + + inline Bitboard bishopAttacks(Square s, Bitboard occupied) + { + const std::size_t idx = + (occupied & fancy_magics::g_bishopMasks[s]).bits() + * fancy_magics::g_bishopMagics[s] + >> fancy_magics::g_bishopShifts[s]; + + return fancy_magics::g_bishopAttacks[s][idx]; + } + + inline Bitboard rookAttacks(Square s, Bitboard occupied) + { + const std::size_t idx = + (occupied & fancy_magics::g_rookMasks[s]).bits() + * fancy_magics::g_rookMagics[s] + >> fancy_magics::g_rookShifts[s]; + + return fancy_magics::g_rookAttacks[s][idx]; + } + } + + [[nodiscard]] constexpr Bitboard square(Square sq) + { + return Bitboard::square(sq); + } + + [[nodiscard]] constexpr Bitboard rank(Rank rank) + { + return Bitboard::rank(rank); + } + + [[nodiscard]] constexpr Bitboard file(File file) + { + return Bitboard::file(file); + } + + [[nodiscard]] constexpr Bitboard color(Color c) + { + return Bitboard::color(c); + } + + [[nodiscard]] constexpr Bitboard before(Square sq) + { + return Bitboard::fromBits(nbitmask[ordinal(sq)]); + } + + constexpr Bitboard lightSquares = bb::color(Color::White); + constexpr Bitboard darkSquares = bb::color(Color::Black); + + constexpr Bitboard fileA = bb::file(chess::fileA); + constexpr Bitboard fileB = bb::file(chess::fileB); + constexpr Bitboard fileC = bb::file(chess::fileC); + constexpr Bitboard fileD = bb::file(chess::fileD); + constexpr Bitboard fileE = bb::file(chess::fileE); + constexpr Bitboard fileF = bb::file(chess::fileF); + constexpr Bitboard fileG = bb::file(chess::fileG); + constexpr Bitboard fileH = bb::file(chess::fileH); + + constexpr Bitboard rank1 = bb::rank(chess::rank1); + constexpr Bitboard rank2 = bb::rank(chess::rank2); + constexpr Bitboard rank3 = bb::rank(chess::rank3); + constexpr Bitboard rank4 = bb::rank(chess::rank4); + constexpr Bitboard rank5 = bb::rank(chess::rank5); + constexpr Bitboard rank6 = bb::rank(chess::rank6); + constexpr Bitboard rank7 = bb::rank(chess::rank7); + constexpr Bitboard rank8 = bb::rank(chess::rank8); + + constexpr Bitboard a1 = bb::square(chess::a1); + constexpr Bitboard a2 = bb::square(chess::a2); + constexpr Bitboard a3 = bb::square(chess::a3); + constexpr Bitboard a4 = bb::square(chess::a4); + constexpr Bitboard a5 = bb::square(chess::a5); + constexpr Bitboard a6 = bb::square(chess::a6); + constexpr Bitboard a7 = bb::square(chess::a7); + constexpr Bitboard a8 = bb::square(chess::a8); + + constexpr Bitboard b1 = bb::square(chess::b1); + constexpr Bitboard b2 = bb::square(chess::b2); + constexpr Bitboard b3 = bb::square(chess::b3); + constexpr Bitboard b4 = bb::square(chess::b4); + constexpr Bitboard b5 = bb::square(chess::b5); + constexpr Bitboard b6 = bb::square(chess::b6); + constexpr Bitboard b7 = bb::square(chess::b7); + constexpr Bitboard b8 = bb::square(chess::b8); + + constexpr Bitboard c1 = bb::square(chess::c1); + constexpr Bitboard c2 = bb::square(chess::c2); + constexpr Bitboard c3 = bb::square(chess::c3); + constexpr Bitboard c4 = bb::square(chess::c4); + constexpr Bitboard c5 = bb::square(chess::c5); + constexpr Bitboard c6 = bb::square(chess::c6); + constexpr Bitboard c7 = bb::square(chess::c7); + constexpr Bitboard c8 = bb::square(chess::c8); + + constexpr Bitboard d1 = bb::square(chess::d1); + constexpr Bitboard d2 = bb::square(chess::d2); + constexpr Bitboard d3 = bb::square(chess::d3); + constexpr Bitboard d4 = bb::square(chess::d4); + constexpr Bitboard d5 = bb::square(chess::d5); + constexpr Bitboard d6 = bb::square(chess::d6); + constexpr Bitboard d7 = bb::square(chess::d7); + constexpr Bitboard d8 = bb::square(chess::d8); + + constexpr Bitboard e1 = bb::square(chess::e1); + constexpr Bitboard e2 = bb::square(chess::e2); + constexpr Bitboard e3 = bb::square(chess::e3); + constexpr Bitboard e4 = bb::square(chess::e4); + constexpr Bitboard e5 = bb::square(chess::e5); + constexpr Bitboard e6 = bb::square(chess::e6); + constexpr Bitboard e7 = bb::square(chess::e7); + constexpr Bitboard e8 = bb::square(chess::e8); + + constexpr Bitboard f1 = bb::square(chess::f1); + constexpr Bitboard f2 = bb::square(chess::f2); + constexpr Bitboard f3 = bb::square(chess::f3); + constexpr Bitboard f4 = bb::square(chess::f4); + constexpr Bitboard f5 = bb::square(chess::f5); + constexpr Bitboard f6 = bb::square(chess::f6); + constexpr Bitboard f7 = bb::square(chess::f7); + constexpr Bitboard f8 = bb::square(chess::f8); + + constexpr Bitboard g1 = bb::square(chess::g1); + constexpr Bitboard g2 = bb::square(chess::g2); + constexpr Bitboard g3 = bb::square(chess::g3); + constexpr Bitboard g4 = bb::square(chess::g4); + constexpr Bitboard g5 = bb::square(chess::g5); + constexpr Bitboard g6 = bb::square(chess::g6); + constexpr Bitboard g7 = bb::square(chess::g7); + constexpr Bitboard g8 = bb::square(chess::g8); + + constexpr Bitboard h1 = bb::square(chess::h1); + constexpr Bitboard h2 = bb::square(chess::h2); + constexpr Bitboard h3 = bb::square(chess::h3); + constexpr Bitboard h4 = bb::square(chess::h4); + constexpr Bitboard h5 = bb::square(chess::h5); + constexpr Bitboard h6 = bb::square(chess::h6); + constexpr Bitboard h7 = bb::square(chess::h7); + constexpr Bitboard h8 = bb::square(chess::h8); + + [[nodiscard]] Bitboard between(Square s1, Square s2); + + [[nodiscard]] Bitboard line(Square s1, Square s2); + + template + [[nodiscard]] Bitboard pseudoAttacks(Square sq); + + [[nodiscard]] Bitboard pseudoAttacks(PieceType pt, Square sq); + + template + Bitboard attacks(Square sq, Bitboard occupied) + { + static_assert(PieceTypeV != PieceType::None && PieceTypeV != PieceType::Pawn); + + assert(sq.isOk()); + + if constexpr (PieceTypeV == PieceType::Bishop) + { + return fancy_magics::bishopAttacks(sq, occupied); + } + else if constexpr (PieceTypeV == PieceType::Rook) + { + return fancy_magics::rookAttacks(sq, occupied); + } + else if constexpr (PieceTypeV == PieceType::Queen) + { + return + fancy_magics::bishopAttacks(sq, occupied) + | fancy_magics::rookAttacks(sq, occupied); + } + else + { + return pseudoAttacks(sq); + } + } + + [[nodiscard]] inline Bitboard attacks(PieceType pt, Square sq, Bitboard occupied) + { + assert(sq.isOk()); + + switch (pt) + { + case PieceType::Bishop: + return attacks(sq, occupied); + case PieceType::Rook: + return attacks(sq, occupied); + case PieceType::Queen: + return attacks(sq, occupied); + default: + return pseudoAttacks(pt, sq); + } + } + + [[nodiscard]] inline Bitboard pawnAttacks(Bitboard pawns, Color color); + + [[nodiscard]] inline Bitboard westPawnAttacks(Bitboard pawns, Color color); + + [[nodiscard]] inline Bitboard eastPawnAttacks(Bitboard pawns, Color color); + + [[nodiscard]] inline bool isAttackedBySlider( + Square sq, + Bitboard bishops, + Bitboard rooks, + Bitboard queens, + Bitboard occupied + ); + + namespace detail + { + static constexpr std::array knightOffsets{ { {-1, -2}, {-1, 2}, {1, -2}, {1, 2}, {-2, -1}, {-2, 1}, {2, -1}, {2, 1} } }; + static constexpr std::array kingOffsets{ { {-1, -1}, {-1, 0}, {-1, 1}, {0, -1}, {0, 1}, {1, -1}, {1, 0}, {1, 1} } }; + + enum Direction + { + North = 0, + NorthEast, + East, + SouthEast, + South, + SouthWest, + West, + NorthWest + }; + + constexpr std::array offsets = { { + { 0, 1 }, + { 1, 1 }, + { 1, 0 }, + { 1, -1 }, + { 0, -1 }, + { -1, -1 }, + { -1, 0 }, + { -1, 1 } + } }; + + static constexpr std::array bishopOffsets{ + offsets[NorthEast], + offsets[SouthEast], + offsets[SouthWest], + offsets[NorthWest] + }; + static constexpr std::array rookOffsets{ + offsets[North], + offsets[East], + offsets[South], + offsets[West] + }; + + [[nodiscard]] static EnumArray generatePseudoAttacks_Pawn() + { + // pseudo attacks don't make sense for pawns + return {}; + } + + [[nodiscard]] static EnumArray generatePseudoAttacks_Knight() + { + EnumArray bbs{}; + + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + Bitboard bb{}; + + for (auto&& offset : knightOffsets) + { + const SquareCoords toSq = fromSq.coords() + offset; + if (toSq.isOk()) + { + bb |= Square(toSq); + } + } + + bbs[fromSq] = bb; + } + + return bbs; + } + + [[nodiscard]] static Bitboard generateSliderPseudoAttacks(const std::array & offsets_, Square fromSq) + { + assert(fromSq.isOk()); + + Bitboard bb{}; + + for (auto&& offset : offsets_) + { + SquareCoords fromSqC = fromSq.coords(); + + for (;;) + { + fromSqC += offset; + + if (!fromSqC.isOk()) + { + break; + } + + bb |= Square(fromSqC); + } + } + + return bb; + } + + [[nodiscard]] static EnumArray generatePseudoAttacks_Bishop() + { + EnumArray bbs{}; + + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + bbs[fromSq] = generateSliderPseudoAttacks(bishopOffsets, fromSq); + } + + return bbs; + } + + [[nodiscard]] static EnumArray generatePseudoAttacks_Rook() + { + EnumArray bbs{}; + + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + bbs[fromSq] = generateSliderPseudoAttacks(rookOffsets, fromSq); + } + + return bbs; + } + + [[nodiscard]] static EnumArray generatePseudoAttacks_Queen() + { + EnumArray bbs{}; + + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + bbs[fromSq] = + generateSliderPseudoAttacks(bishopOffsets, fromSq) + | generateSliderPseudoAttacks(rookOffsets, fromSq); + } + + return bbs; + } + + [[nodiscard]] static EnumArray generatePseudoAttacks_King() + { + EnumArray bbs{}; + + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + Bitboard bb{}; + + for (auto&& offset : kingOffsets) + { + const SquareCoords toSq = fromSq.coords() + offset; + if (toSq.isOk()) + { + bb |= Square(toSq); + } + } + + bbs[fromSq] = bb; + } + + return bbs; + } + + [[nodiscard]] static EnumArray2 generatePseudoAttacks() + { + return EnumArray2{ + generatePseudoAttacks_Pawn(), + generatePseudoAttacks_Knight(), + generatePseudoAttacks_Bishop(), + generatePseudoAttacks_Rook(), + generatePseudoAttacks_Queen(), + generatePseudoAttacks_King() + }; + } + + static const EnumArray2& pseudoAttacks() + { + static const EnumArray2 s_pseudoAttacks = generatePseudoAttacks(); + return s_pseudoAttacks; + } + + [[nodiscard]] static Bitboard generatePositiveRayAttacks(Direction dir, Square fromSq) + { + assert(fromSq.isOk()); + + Bitboard bb{}; + + const auto offset = offsets[dir]; + SquareCoords fromSqC = fromSq.coords(); + for (;;) + { + fromSqC += offset; + + if (!fromSqC.isOk()) + { + break; + } + + bb |= Square(fromSqC); + } + + return bb; + } + + // classical slider move generation approach https://www.chessprogramming.org/Classical_Approach + + [[nodiscard]] static EnumArray generatePositiveRayAttacks(Direction dir) + { + EnumArray bbs{}; + + for (Square fromSq = chess::a1; fromSq != Square::none(); ++fromSq) + { + bbs[fromSq] = generatePositiveRayAttacks(dir, fromSq); + } + + return bbs; + } + + [[nodiscard]] static std::array, 8> generatePositiveRayAttacks() + { + std::array, 8> bbs{}; + + bbs[North] = generatePositiveRayAttacks(North); + bbs[NorthEast] = generatePositiveRayAttacks(NorthEast); + bbs[East] = generatePositiveRayAttacks(East); + bbs[SouthEast] = generatePositiveRayAttacks(SouthEast); + bbs[South] = generatePositiveRayAttacks(South); + bbs[SouthWest] = generatePositiveRayAttacks(SouthWest); + bbs[West] = generatePositiveRayAttacks(West); + bbs[NorthWest] = generatePositiveRayAttacks(NorthWest); + + return bbs; + } + + + static const std::array, 8>& positiveRayAttacks() + { + static const std::array, 8> s_positiveRayAttacks = generatePositiveRayAttacks(); + return s_positiveRayAttacks; + } + + template + [[nodiscard]] static Bitboard slidingAttacks(Square sq, Bitboard occupied) + { + assert(sq.isOk()); + + Bitboard attacks = positiveRayAttacks()[DirV][sq]; + + if constexpr (DirV == NorthWest || DirV == North || DirV == NorthEast || DirV == East) + { + Bitboard blocker = (attacks & occupied) | h8; // set highest bit (H8) so msb never fails + return attacks ^ positiveRayAttacks()[DirV][blocker.first()]; + } + else + { + Bitboard blocker = (attacks & occupied) | a1; + return attacks ^ positiveRayAttacks()[DirV][blocker.last()]; + } + } + + template Bitboard slidingAttacks(Square, Bitboard); + template Bitboard slidingAttacks(Square, Bitboard); + template Bitboard slidingAttacks(Square, Bitboard); + template Bitboard slidingAttacks(Square, Bitboard); + template Bitboard slidingAttacks(Square, Bitboard); + template Bitboard slidingAttacks(Square, Bitboard); + template Bitboard slidingAttacks(Square, Bitboard); + template Bitboard slidingAttacks(Square, Bitboard); + + template + [[nodiscard]] inline Bitboard pieceSlidingAttacks(Square sq, Bitboard occupied) + { + static_assert( + PieceTypeV == PieceType::Rook + || PieceTypeV == PieceType::Bishop + || PieceTypeV == PieceType::Queen); + + assert(sq.isOk()); + + if constexpr (PieceTypeV == PieceType::Bishop) + { + return + detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied); + } + else if constexpr (PieceTypeV == PieceType::Rook) + { + return + detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied); + } + else // if constexpr (PieceTypeV == PieceType::Queen) + { + return + detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied) + | detail::slidingAttacks(sq, occupied); + } + } + + static Bitboard generateBetween(Square s1, Square s2) + { + Bitboard bb = Bitboard::none(); + + if (s1 == s2) + { + return bb; + } + + const int fd = s2.file() - s1.file(); + const int rd = s2.rank() - s1.rank(); + + if (fd == 0 || rd == 0 || fd == rd || fd == -rd) + { + // s1 and s2 lie on a line. + const int fileStep = (fd > 0) - (fd < 0); + const int rankStep = (rd > 0) - (rd < 0); + const auto step = FlatSquareOffset(fileStep, rankStep); + s1 += step; // omit s1 + while(s1 != s2) // omit s2 + { + bb |= s1; + s1 += step; + } + } + + return bb; + } + + static Bitboard generateLine(Square s1, Square s2) + { + for (PieceType pt : { PieceType::Bishop, PieceType::Rook }) + { + const Bitboard s1Attacks = pseudoAttacks()[pt][s1]; + if (s1Attacks.isSet(s2)) + { + const Bitboard s2Attacks = pseudoAttacks()[pt][s2]; + return (s1Attacks & s2Attacks) | s1 | s2; + } + } + + return Bitboard::none(); + } + + static const EnumArray2 between = []() + { + EnumArray2 between_; + + for (Square s1 : values()) + { + for (Square s2 : values()) + { + between_[s1][s2] = generateBetween(s1, s2); + } + } + + return between_; + }(); + + static const EnumArray2 line = []() + { + EnumArray2 line_; + + for (Square s1 : values()) + { + for (Square s2 : values()) + { + line_[s1][s2] = generateLine(s1, s2); + } + } + + return line_; + }(); + } + + namespace fancy_magics + { + enum struct MagicsType + { + Rook, + Bishop + }; + + template + [[nodiscard]] inline Bitboard slidingAttacks(Square sq, Bitboard occupied) + { + if (TypeV == MagicsType::Rook) + { + return chess::bb::detail::pieceSlidingAttacks(sq, occupied); + } + + if (TypeV == MagicsType::Bishop) + { + return chess::bb::detail::pieceSlidingAttacks(sq, occupied); + } + + return Bitboard::none(); + } + + template + [[nodiscard]] inline bool initMagics( + const EnumArray& magics, + std::array& table, + EnumArray& masks, + EnumArray& shifts, + EnumArray& attacks + ) + { + std::size_t size = 0; + for (Square sq : values()) + { + const Bitboard edges = + ((bb::rank1 | bb::rank8) & ~Bitboard::rank(sq.rank())) + | ((bb::fileA | bb::fileH) & ~Bitboard::file(sq.file())); + + Bitboard* currentAttacks = table.data() + size; + + attacks[sq] = currentAttacks; + masks[sq] = slidingAttacks(sq, Bitboard::none()) & ~edges; + shifts[sq] = 64 - masks[sq].count(); + + Bitboard occupied = Bitboard::none(); + do + { + const std::size_t idx = + (occupied & masks[sq]).bits() + * magics[sq] + >> shifts[sq]; + + currentAttacks[idx] = slidingAttacks(sq, occupied); + + ++size; + occupied = Bitboard::fromBits(occupied.bits() - masks[sq].bits()) & masks[sq]; + } while (occupied.any()); + } + + return true; + } + + static bool g_isRookMagicsInitialized = + initMagics(g_rookMagics, g_allRookAttacks, g_rookMasks, g_rookShifts, g_rookAttacks); + + static bool g_isBishopMagicsInitialized = + initMagics(g_bishopMagics, g_allBishopAttacks, g_bishopMasks, g_bishopShifts, g_bishopAttacks); + } + + [[nodiscard]] inline Bitboard between(Square s1, Square s2) + { + return detail::between[s1][s2]; + } + + [[nodiscard]] inline Bitboard line(Square s1, Square s2) + { + return detail::line[s1][s2]; + } + + template + [[nodiscard]] inline Bitboard pseudoAttacks(Square sq) + { + static_assert(PieceTypeV != PieceType::None && PieceTypeV != PieceType::Pawn); + + assert(sq.isOk()); + + return detail::pseudoAttacks()[PieceTypeV][sq]; + } + + [[nodiscard]] inline Bitboard pseudoAttacks(PieceType pt, Square sq) + { + assert(sq.isOk()); + + return detail::pseudoAttacks()[pt][sq]; + } + + [[nodiscard]] inline Bitboard pawnAttacks(Bitboard pawns, Color color) + { + if (color == Color::White) + { + return pawns.shifted<1, 1>() | pawns.shifted<-1, 1>(); + } + else + { + return pawns.shifted<1, -1>() | pawns.shifted<-1, -1>(); + } + } + + [[nodiscard]] inline Bitboard westPawnAttacks(Bitboard pawns, Color color) + { + if (color == Color::White) + { + return pawns.shifted<-1, 1>(); + } + else + { + return pawns.shifted<-1, -1>(); + } + } + + [[nodiscard]] inline Bitboard eastPawnAttacks(Bitboard pawns, Color color) + { + if (color == Color::White) + { + return pawns.shifted<1, 1>(); + } + else + { + return pawns.shifted<1, -1>(); + } + } + + [[nodiscard]] inline bool isAttackedBySlider( + Square sq, + Bitboard bishops, + Bitboard rooks, + Bitboard queens, + Bitboard occupied + ) + { + const Bitboard opponentBishopLikePieces = (bishops | queens); + const Bitboard bishopAttacks = bb::attacks(sq, occupied); + if ((bishopAttacks & opponentBishopLikePieces).any()) + { + return true; + } + + const Bitboard opponentRookLikePieces = (rooks | queens); + const Bitboard rookAttacks = bb::attacks(sq, occupied); + return (rookAttacks & opponentRookLikePieces).any(); + } + } + + struct CastlingTraits + { + static constexpr EnumArray2 rookDestination = { { {{ f1, d1 }}, {{ f8, d8 }} } }; + static constexpr EnumArray2 kingDestination = { { {{ g1, c1 }}, {{ g8, c8 }} } }; + + static constexpr EnumArray2 rookStart = { { {{ h1, a1 }}, {{ h8, a8 }} } }; + + static constexpr EnumArray kingStart = { { e1, e8 } }; + + static constexpr EnumArray2 castlingPath = { + { + {{ Bitboard::square(f1) | g1, Bitboard::square(b1) | c1 | d1 }}, + {{ Bitboard::square(f8) | g8, Bitboard::square(b8) | c8 | d8 }} + } + }; + + static constexpr EnumArray2 squarePassedByKing = { + { + {{ f1, d1 }}, + {{ f8, d8 }} + } + }; + + static constexpr EnumArray2 castlingRights = { + { + {{ CastlingRights::WhiteKingSide, CastlingRights::WhiteQueenSide }}, + {{ CastlingRights::BlackKingSide, CastlingRights::BlackQueenSide }} + } + }; + + // Move has to be a legal castling move. + static constexpr CastleType moveCastlingType(const Move& move) + { + return (move.to.file() == fileH) ? CastleType::Short : CastleType::Long; + } + + // Move must be a legal castling move. + static constexpr CastlingRights moveCastlingRight(Move move) + { + if (move.to == h1) return CastlingRights::WhiteKingSide; + if (move.to == a1) return CastlingRights::WhiteQueenSide; + if (move.to == h8) return CastlingRights::WhiteKingSide; + if (move.to == a8) return CastlingRights::WhiteQueenSide; + return CastlingRights::None; + } + }; + + namespace parser_bits + { + [[nodiscard]] constexpr bool isFile(char c) + { + return c >= 'a' && c <= 'h'; + } + + [[nodiscard]] constexpr bool isRank(char c) + { + return c >= '1' && c <= '8'; + } + + [[nodiscard]] constexpr Rank parseRank(char c) + { + assert(isRank(c)); + + return fromOrdinal(c - '1'); + } + + [[nodiscard]] constexpr File parseFile(char c) + { + assert(isFile(c)); + + return fromOrdinal(c - 'a'); + } + + [[nodiscard]] constexpr bool isSquare(const char* s) + { + return isFile(s[0]) && isRank(s[1]); + } + + [[nodiscard]] constexpr Square parseSquare(const char* s) + { + const File file = parseFile(s[0]); + const Rank rank = parseRank(s[1]); + return Square(file, rank); + } + + [[nodiscard]] constexpr std::optional tryParseSquare(std::string_view s) + { + if (s.size() != 2) return {}; + if (!isSquare(s.data())) return {}; + return parseSquare(s.data()); + } + + [[nodiscard]] constexpr std::optional tryParseEpSquare(std::string_view s) + { + if (s == std::string_view("-")) return Square::none(); + return tryParseSquare(s); + } + + [[nodiscard]] constexpr std::optional tryParseCastlingRights(std::string_view s) + { + if (s == std::string_view("-")) return CastlingRights::None; + + CastlingRights rights = CastlingRights::None; + + for (auto& c : s) + { + CastlingRights toAdd = CastlingRights::None; + switch (c) + { + case 'K': + toAdd = CastlingRights::WhiteKingSide; + break; + case 'Q': + toAdd = CastlingRights::WhiteQueenSide; + break; + case 'k': + toAdd = CastlingRights::BlackKingSide; + break; + case 'q': + toAdd = CastlingRights::BlackQueenSide; + break; + } + + // If there are duplicated castling rights specification we bail. + // If there is an invalid character we bail. + // (It always contains None) + if (contains(rights, toAdd)) return {}; + else rights |= toAdd; + } + + return rights; + } + + [[nodiscard]] constexpr CastlingRights readCastlingRights(const char*& s) + { + CastlingRights rights = CastlingRights::None; + + while (*s != ' ') + { + switch (*s) + { + case 'K': + rights |= CastlingRights::WhiteKingSide; + break; + case 'Q': + rights |= CastlingRights::WhiteQueenSide; + break; + case 'k': + rights |= CastlingRights::BlackKingSide; + break; + case 'q': + rights |= CastlingRights::BlackQueenSide; + break; + } + + ++s; + } + + return rights; + } + + FORCEINLINE inline void appendCastlingRightsToString(CastlingRights rights, std::string& str) + { + if (rights == CastlingRights::None) + { + str += '-'; + } + else + { + if (contains(rights, CastlingRights::WhiteKingSide)) str += 'K'; + if (contains(rights, CastlingRights::WhiteQueenSide)) str += 'Q'; + if (contains(rights, CastlingRights::BlackKingSide)) str += 'k'; + if (contains(rights, CastlingRights::BlackQueenSide)) str += 'q'; + } + } + + FORCEINLINE inline void appendSquareToString(Square sq, std::string& str) + { + str += static_cast('a' + ordinal(sq.file())); + str += static_cast('1' + ordinal(sq.rank())); + } + + FORCEINLINE inline void appendEpSquareToString(Square sq, std::string& str) + { + if (sq == Square::none()) + { + str += '-'; + } + else + { + appendSquareToString(sq, str); + } + } + + FORCEINLINE inline void appendRankToString(Rank r, std::string& str) + { + str += static_cast('1' + ordinal(r)); + } + + FORCEINLINE inline void appendFileToString(File f, std::string& str) + { + str += static_cast('a' + ordinal(f)); + } + + [[nodiscard]] FORCEINLINE inline bool isDigit(char c) + { + return c >= '0' && c <= '9'; + } + + [[nodiscard]] inline std::uint16_t parseUInt16(std::string_view sv) + { + assert(sv.size() > 0); + assert(sv.size() <= 5); + + std::uint16_t v = 0; + + std::size_t idx = 0; + switch (sv.size()) + { + case 5: + v += (sv[idx++] - '0') * 10000; + case 4: + v += (sv[idx++] - '0') * 1000; + case 3: + v += (sv[idx++] - '0') * 100; + case 2: + v += (sv[idx++] - '0') * 10; + case 1: + v += sv[idx] - '0'; + break; + + default: + assert(false); + } + + return v; + } + + [[nodiscard]] inline std::optional tryParseUInt16(std::string_view sv) + { + if (sv.size() == 0 || sv.size() > 5) return std::nullopt; + + std::uint32_t v = 0; + + std::size_t idx = 0; + switch (sv.size()) + { + case 5: + v += (sv[idx++] - '0') * 10000; + case 4: + v += (sv[idx++] - '0') * 1000; + case 3: + v += (sv[idx++] - '0') * 100; + case 2: + v += (sv[idx++] - '0') * 10; + case 1: + v += sv[idx] - '0'; + break; + + default: + assert(false); + } + + if (v > std::numeric_limits::max()) + { + return std::nullopt; + } + + return static_cast(v); + } + } + + + struct Board + { + constexpr Board() noexcept : + m_pieces{}, + m_pieceBB{}, + m_piecesByColorBB{}, + m_pieceCount{} + { + m_pieces.fill(Piece::none()); + m_pieceBB.fill(Bitboard::none()); + m_pieceBB[Piece::none()] = Bitboard::all(); + m_piecesByColorBB.fill(Bitboard::none()); + m_pieceCount.fill(0); + m_pieceCount[Piece::none()] = 64; + } + + [[nodiscard]] inline bool isValid() const + { + if (piecesBB(whiteKing).count() != 1) return false; + if (piecesBB(blackKing).count() != 1) return false; + if (((piecesBB(whitePawn) | piecesBB(blackPawn)) & (bb::rank(rank1) | bb::rank(rank8))).any()) return false; + return true; + } + + [[nodiscard]] inline std::string fen() const; + + [[nodiscard]] inline bool trySet(std::string_view boardState) + { + File f = fileA; + Rank r = rank8; + bool lastWasSkip = false; + for (auto c : boardState) + { + Piece piece = Piece::none(); + switch (c) + { + case 'r': + piece = Piece(PieceType::Rook, Color::Black); + break; + case 'n': + piece = Piece(PieceType::Knight, Color::Black); + break; + case 'b': + piece = Piece(PieceType::Bishop, Color::Black); + break; + case 'q': + piece = Piece(PieceType::Queen, Color::Black); + break; + case 'k': + piece = Piece(PieceType::King, Color::Black); + break; + case 'p': + piece = Piece(PieceType::Pawn, Color::Black); + break; + + case 'R': + piece = Piece(PieceType::Rook, Color::White); + break; + case 'N': + piece = Piece(PieceType::Knight, Color::White); + break; + case 'B': + piece = Piece(PieceType::Bishop, Color::White); + break; + case 'Q': + piece = Piece(PieceType::Queen, Color::White); + break; + case 'K': + piece = Piece(PieceType::King, Color::White); + break; + case 'P': + piece = Piece(PieceType::Pawn, Color::White); + break; + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + { + if (lastWasSkip) return false; + lastWasSkip = true; + + const int skip = c - '0'; + f += skip; + if (f > fileH + 1) return false; + break; + } + + case '/': + lastWasSkip = false; + if (f != fileH + 1) return false; + f = fileA; + --r; + break; + + default: + return false; + } + + if (piece != Piece::none()) + { + lastWasSkip = false; + + const Square sq(f, r); + if (!sq.isOk()) return false; + + place(piece, sq); + ++f; + } + } + + if (f != fileH + 1) return false; + if (r != rank1) return false; + + return isValid(); + } + + // returns side to move + [[nodiscard]] constexpr const char* set(const char* fen) + { + assert(fen != nullptr); + + File f = fileA; + Rank r = rank8; + auto current = fen; + bool done = false; + while (*current != '\0') + { + Piece piece = Piece::none(); + switch (*current) + { + case 'r': + piece = Piece(PieceType::Rook, Color::Black); + break; + case 'n': + piece = Piece(PieceType::Knight, Color::Black); + break; + case 'b': + piece = Piece(PieceType::Bishop, Color::Black); + break; + case 'q': + piece = Piece(PieceType::Queen, Color::Black); + break; + case 'k': + piece = Piece(PieceType::King, Color::Black); + break; + case 'p': + piece = Piece(PieceType::Pawn, Color::Black); + break; + + case 'R': + piece = Piece(PieceType::Rook, Color::White); + break; + case 'N': + piece = Piece(PieceType::Knight, Color::White); + break; + case 'B': + piece = Piece(PieceType::Bishop, Color::White); + break; + case 'Q': + piece = Piece(PieceType::Queen, Color::White); + break; + case 'K': + piece = Piece(PieceType::King, Color::White); + break; + case 'P': + piece = Piece(PieceType::Pawn, Color::White); + break; + + case ' ': + done = true; + break; + + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + { + const int skip = (*current) - '0'; + f += skip; + break; + } + + case '/': + f = fileA; + --r; + break; + + default: + break; + } + + if (done) + { + break; + } + + if (piece != Piece::none()) + { + place(piece, Square(f, r)); + ++f; + } + + ++current; + } + + return current; + } + + static constexpr Board fromFen(const char* fen) + { + Board board; + (void)board.set(fen); + return board; + } + + [[nodiscard]] constexpr friend bool operator==(const Board& lhs, const Board& rhs) noexcept + { + bool equal = true; + for (Square sq = a1; sq <= h8; ++sq) + { + if (lhs.m_pieces[sq] != rhs.m_pieces[sq]) + { + equal = false; + break; + } + } + + assert(bbsEqual(lhs, rhs) == equal); + + return equal; + } + + constexpr void place(Piece piece, Square sq) + { + assert(sq.isOk()); + + auto oldPiece = m_pieces[sq]; + m_pieceBB[oldPiece] ^= sq; + if (oldPiece != Piece::none()) + { + m_piecesByColorBB[oldPiece.color()] ^= sq; + } + m_pieces[sq] = piece; + m_pieceBB[piece] |= sq; + m_piecesByColorBB[piece.color()] |= sq; + --m_pieceCount[oldPiece]; + ++m_pieceCount[piece]; + } + + // returns captured piece + // doesn't check validity + inline constexpr Piece doMove(Move move) + { + if (move.type == MoveType::Normal) + { + const Piece capturedPiece = m_pieces[move.to]; + const Piece piece = m_pieces[move.from]; + + const Bitboard frombb = Bitboard::square(move.from); + const Bitboard tobb = Bitboard::square(move.to); + const Bitboard xormove = frombb ^ tobb; + + m_pieces[move.to] = piece; + m_pieces[move.from] = Piece::none(); + + m_pieceBB[piece] ^= xormove; + + m_piecesByColorBB[piece.color()] ^= xormove; + + if (capturedPiece == Piece::none()) + { + m_pieceBB[Piece::none()] ^= xormove; + } + else + { + m_pieceBB[capturedPiece] ^= tobb; + m_pieceBB[Piece::none()] ^= frombb; + + m_piecesByColorBB[capturedPiece.color()] ^= tobb; + + --m_pieceCount[capturedPiece]; + ++m_pieceCount[Piece::none()]; + } + + return capturedPiece; + } + + return doMoveColdPath(move); + } + + inline constexpr Piece doMoveColdPath(Move move) + { + if (move.type == MoveType::Promotion) + { + // We split it even though it's similar just because + // the normal case is much more common. + const Piece capturedPiece = m_pieces[move.to]; + const Piece fromPiece = m_pieces[move.from]; + const Piece toPiece = move.promotedPiece; + + m_pieces[move.to] = toPiece; + m_pieces[move.from] = Piece::none(); + + m_pieceBB[fromPiece] ^= move.from; + m_pieceBB[toPiece] ^= move.to; + + m_pieceBB[capturedPiece] ^= move.to; + m_pieceBB[Piece::none()] ^= move.from; + + m_piecesByColorBB[fromPiece.color()] ^= move.to; + m_piecesByColorBB[fromPiece.color()] ^= move.from; + if (capturedPiece != Piece::none()) + { + m_piecesByColorBB[capturedPiece.color()] ^= move.to; + --m_pieceCount[capturedPiece]; + ++m_pieceCount[Piece::none()]; + } + + --m_pieceCount[fromPiece]; + ++m_pieceCount[toPiece]; + + return capturedPiece; + } + else if (move.type == MoveType::EnPassant) + { + const Piece movedPiece = m_pieces[move.from]; + const Piece capturedPiece(PieceType::Pawn, !movedPiece.color()); + const Square capturedPieceSq(move.to.file(), move.from.rank()); + + // on ep move there are 3 squares involved + m_pieces[move.to] = movedPiece; + m_pieces[move.from] = Piece::none(); + m_pieces[capturedPieceSq] = Piece::none(); + + m_pieceBB[movedPiece] ^= move.from; + m_pieceBB[movedPiece] ^= move.to; + + m_pieceBB[Piece::none()] ^= move.from; + m_pieceBB[Piece::none()] ^= move.to; + + m_pieceBB[capturedPiece] ^= capturedPieceSq; + m_pieceBB[Piece::none()] ^= capturedPieceSq; + + m_piecesByColorBB[movedPiece.color()] ^= move.to; + m_piecesByColorBB[movedPiece.color()] ^= move.from; + m_piecesByColorBB[capturedPiece.color()] ^= capturedPieceSq; + + --m_pieceCount[capturedPiece]; + ++m_pieceCount[Piece::none()]; + + return capturedPiece; + } + else // if (move.type == MoveType::Castle) + { + const Square rookFromSq = move.to; + const Square kingFromSq = move.from; + + const Piece rook = m_pieces[rookFromSq]; + const Piece king = m_pieces[kingFromSq]; + const Color color = king.color(); + + const CastleType castleType = CastlingTraits::moveCastlingType(move); + const Square rookToSq = CastlingTraits::rookDestination[color][castleType]; + const Square kingToSq = CastlingTraits::kingDestination[color][castleType]; + + // 4 squares are involved + m_pieces[rookFromSq] = Piece::none(); + m_pieces[kingFromSq] = Piece::none(); + m_pieces[rookToSq] = rook; + m_pieces[kingToSq] = king; + + m_pieceBB[rook] ^= rookFromSq; + m_pieceBB[rook] ^= rookToSq; + + m_pieceBB[king] ^= kingFromSq; + m_pieceBB[king] ^= kingToSq; + + m_pieceBB[Piece::none()] ^= rookFromSq; + m_pieceBB[Piece::none()] ^= rookToSq; + + m_pieceBB[Piece::none()] ^= kingFromSq; + m_pieceBB[Piece::none()] ^= kingToSq; + + m_piecesByColorBB[color] ^= rookFromSq; + m_piecesByColorBB[color] ^= rookToSq; + m_piecesByColorBB[color] ^= kingFromSq; + m_piecesByColorBB[color] ^= kingToSq; + + return Piece::none(); + } + } + + constexpr void undoMove(Move move, Piece capturedPiece) + { + if (move.type == MoveType::Normal || move.type == MoveType::Promotion) + { + const Piece toPiece = m_pieces[move.to]; + const Piece fromPiece = move.promotedPiece == Piece::none() ? toPiece : Piece(PieceType::Pawn, toPiece.color()); + + m_pieces[move.from] = fromPiece; + m_pieces[move.to] = capturedPiece; + + m_pieceBB[fromPiece] ^= move.from; + m_pieceBB[toPiece] ^= move.to; + + m_pieceBB[capturedPiece] ^= move.to; + m_pieceBB[Piece::none()] ^= move.from; + + m_piecesByColorBB[fromPiece.color()] ^= move.to; + m_piecesByColorBB[fromPiece.color()] ^= move.from; + if (capturedPiece != Piece::none()) + { + m_piecesByColorBB[capturedPiece.color()] ^= move.to; + ++m_pieceCount[capturedPiece]; + --m_pieceCount[Piece::none()]; + } + + if (move.type == MoveType::Promotion) + { + --m_pieceCount[toPiece]; + ++m_pieceCount[fromPiece]; + } + } + else if (move.type == MoveType::EnPassant) + { + const Piece movedPiece = m_pieces[move.to]; + const Piece capturedPiece_(PieceType::Pawn, !movedPiece.color()); + const Square capturedPieceSq(move.to.file(), move.from.rank()); + + m_pieces[move.to] = Piece::none(); + m_pieces[move.from] = movedPiece; + m_pieces[capturedPieceSq] = capturedPiece_; + + m_pieceBB[movedPiece] ^= move.from; + m_pieceBB[movedPiece] ^= move.to; + + m_pieceBB[Piece::none()] ^= move.from; + m_pieceBB[Piece::none()] ^= move.to; + + // on ep move there are 3 squares involved + m_pieceBB[capturedPiece_] ^= capturedPieceSq; + m_pieceBB[Piece::none()] ^= capturedPieceSq; + + m_piecesByColorBB[movedPiece.color()] ^= move.to; + m_piecesByColorBB[movedPiece.color()] ^= move.from; + m_piecesByColorBB[capturedPiece_.color()] ^= capturedPieceSq; + + ++m_pieceCount[capturedPiece_]; + --m_pieceCount[Piece::none()]; + } + else // if (move.type == MoveType::Castle) + { + const Square rookFromSq = move.to; + const Square kingFromSq = move.from; + + const Color color = move.to.rank() == rank1 ? Color::White : Color::Black; + + const CastleType castleType = CastlingTraits::moveCastlingType(move); + const Square rookToSq = CastlingTraits::rookDestination[color][castleType]; + const Square kingToSq = CastlingTraits::kingDestination[color][castleType]; + + const Piece rook = m_pieces[rookToSq]; + const Piece king = m_pieces[kingToSq]; + + // 4 squares are involved + m_pieces[rookFromSq] = rook; + m_pieces[kingFromSq] = king; + m_pieces[rookToSq] = Piece::none(); + m_pieces[kingToSq] = Piece::none(); + + m_pieceBB[rook] ^= rookFromSq; + m_pieceBB[rook] ^= rookToSq; + + m_pieceBB[king] ^= kingFromSq; + m_pieceBB[king] ^= kingToSq; + + m_pieceBB[Piece::none()] ^= rookFromSq; + m_pieceBB[Piece::none()] ^= rookToSq; + + m_pieceBB[Piece::none()] ^= kingFromSq; + m_pieceBB[Piece::none()] ^= kingToSq; + + m_piecesByColorBB[color] ^= rookFromSq; + m_piecesByColorBB[color] ^= rookToSq; + m_piecesByColorBB[color] ^= kingFromSq; + m_piecesByColorBB[color] ^= kingToSq; + } + } + + // Returns whether a given square is attacked by any piece + // of `attackerColor` side. + [[nodiscard]] inline bool isSquareAttacked(Square sq, Color attackerColor) const; + + // Returns whether a given square is attacked by any piece + // of `attackerColor` side after `move` is made. + // Move must be pseudo legal. + [[nodiscard]] inline bool isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const; + + // Move must be pseudo legal. + // Must not be a king move. + [[nodiscard]] inline bool createsDiscoveredAttackOnOwnKing(Move move) const; + + // Returns whether a piece on a given square is attacked + // by any enemy piece. False if square is empty. + [[nodiscard]] inline bool isPieceAttacked(Square sq) const; + + // Returns whether a piece on a given square is attacked + // by any enemy piece after `move` is made. False if square is empty. + // Move must be pseudo legal. + [[nodiscard]] inline bool isPieceAttackedAfterMove(Move move, Square sq) const; + + // Returns whether the king of the moving side is attacked + // by any enemy piece after a move is made. + // Move must be pseudo legal. + [[nodiscard]] inline bool isOwnKingAttackedAfterMove(Move move) const; + + // Return a bitboard with all (pseudo legal) attacks by the piece on + // the given square. Empty if no piece on the square. + [[nodiscard]] inline Bitboard attacks(Square sq) const; + + // Returns a bitboard with all squared that have pieces + // that attack a given square (pseudo legally) + [[nodiscard]] inline Bitboard attackers(Square sq, Color attackerColor) const; + + [[nodiscard]] constexpr Piece pieceAt(Square sq) const + { + assert(sq.isOk()); + + return m_pieces[sq]; + } + + [[nodiscard]] constexpr Bitboard piecesBB(Color c) const + { + return m_piecesByColorBB[c]; + } + + [[nodiscard]] inline Square kingSquare(Color c) const + { + return piecesBB(Piece(PieceType::King, c)).first(); + } + + [[nodiscard]] constexpr Bitboard piecesBB(Piece pc) const + { + return m_pieceBB[pc]; + } + + [[nodiscard]] constexpr Bitboard piecesBB() const + { + Bitboard bb{}; + + // don't collect from null piece + return piecesBB(Color::White) | piecesBB(Color::Black); + + return bb; + } + + [[nodiscard]] constexpr std::uint8_t pieceCount(Piece pt) const + { + return m_pieceCount[pt]; + } + + [[nodiscard]] constexpr bool isPromotion(Square from, Square to) const + { + assert(from.isOk() && to.isOk()); + + return m_pieces[from].type() == PieceType::Pawn && (to.rank() == rank1 || to.rank() == rank8); + } + + const Piece* piecesRaw() const; + + private: + EnumArray m_pieces; + EnumArray m_pieceBB; + EnumArray m_piecesByColorBB; + EnumArray m_pieceCount; + + // NOTE: currently we don't track it because it's not + // required to perform ep if we don't need to check validity + // Square m_epSquare = Square::none(); + + [[nodiscard]] static constexpr bool bbsEqual(const Board& lhs, const Board& rhs) noexcept + { + for (Piece pc : values()) + { + if (lhs.m_pieceBB[pc] != rhs.m_pieceBB[pc]) + { + return false; + } + } + + return true; + } + }; + + struct Position; + + struct CompressedPosition; + + struct PositionHash128 + { + std::uint64_t high; + std::uint64_t low; + }; + + struct Position; + + struct MoveLegalityChecker + { + MoveLegalityChecker(const Position& position); + + [[nodiscard]] bool isPseudoLegalMoveLegal(const Move& move) const; + + private: + const Position* m_position; + Bitboard m_checkers; + Bitboard m_ourBlockersForKing; + Bitboard m_potentialCheckRemovals; + Square m_ksq; + }; + + struct Position : public Board + { + using BaseType = Board; + + constexpr Position() noexcept : + Board(), + m_sideToMove(Color::White), + m_epSquare(Square::none()), + m_castlingRights(CastlingRights::All), + m_rule50Counter(0), + m_ply(0) + { + } + + constexpr Position(const Board& board, Color sideToMove, Square epSquare, CastlingRights castlingRights) : + Board(board), + m_sideToMove(sideToMove), + m_epSquare(epSquare), + m_castlingRights(castlingRights), + m_rule50Counter(0), + m_ply(0) + { + } + + inline void set(std::string_view fen); + + // Returns false if the fen was not valid + // If the returned value was false the position + // is in unspecified state. + [[nodiscard]] inline bool trySet(std::string_view fen); + + [[nodiscard]] static inline Position fromFen(std::string_view fen); + + [[nodiscard]] static inline std::optional tryFromFen(std::string_view fen); + + [[nodiscard]] static inline Position startPosition(); + + [[nodiscard]] inline std::string fen() const; + + [[nodiscard]] MoveLegalityChecker moveLegalityChecker() const + { + return { *this }; + } + + constexpr void setEpSquareUnchecked(Square sq) + { + m_epSquare = sq; + } + + void setEpSquare(Square sq) + { + m_epSquare = sq; + nullifyEpSquareIfNotPossible(); + } + + constexpr void setSideToMove(Color color) + { + m_sideToMove = color; + } + + constexpr void addCastlingRights(CastlingRights rights) + { + m_castlingRights |= rights; + } + + constexpr void setCastlingRights(CastlingRights rights) + { + m_castlingRights = rights; + } + + constexpr void setRule50Counter(std::uint8_t v) + { + m_rule50Counter = v; + } + + constexpr void setPly(std::uint16_t ply) + { + m_ply = ply; + } + + inline ReverseMove doMove(const Move& move); + + constexpr void undoMove(const ReverseMove& reverseMove) + { + const Move& move = reverseMove.move; + BaseType::undoMove(move, reverseMove.capturedPiece); + + m_epSquare = reverseMove.oldEpSquare; + m_castlingRights = reverseMove.oldCastlingRights; + + m_sideToMove = !m_sideToMove; + + --m_ply; + if (m_rule50Counter > 0) + { + m_rule50Counter -= 1; + } + } + + [[nodiscard]] constexpr Color sideToMove() const + { + return m_sideToMove; + } + + [[nodiscard]] inline std::uint8_t rule50Counter() const + { + return m_rule50Counter; + } + + [[nodiscard]] inline std::uint16_t ply() const + { + return m_ply; + } + + [[nodiscard]] inline std::uint16_t fullMove() const + { + return m_ply / 2 + 1; + } + + inline void setFullMove(std::uint16_t hm) + { + m_ply = 2 * (hm - 1) + (m_sideToMove == Color::Black); + } + + [[nodiscard]] inline bool isCheck() const; + + [[nodiscard]] inline Bitboard checkers() const; + + [[nodiscard]] inline bool isCheckAfterMove(Move move) const; + + [[nodiscard]] inline bool isMoveLegal(Move move) const; + + [[nodiscard]] inline bool isPseudoLegalMoveLegal(Move move) const; + + [[nodiscard]] inline bool isMovePseudoLegal(Move move) const; + + // Returns all pieces that block a slider + // from attacking our king. When two or more + // pieces block a single slider then none + // of these pieces are included. + [[nodiscard]] inline Bitboard blockersForKing(Color color) const; + + [[nodiscard]] constexpr Square epSquare() const + { + return m_epSquare; + } + + [[nodiscard]] constexpr CastlingRights castlingRights() const + { + return m_castlingRights; + } + + [[nodiscard]] constexpr bool friend operator==(const Position& lhs, const Position& rhs) noexcept + { + return + lhs.m_sideToMove == rhs.m_sideToMove + && lhs.m_epSquare == rhs.m_epSquare + && lhs.m_castlingRights == rhs.m_castlingRights + && static_cast(lhs) == static_cast(rhs); + } + + [[nodiscard]] constexpr bool friend operator!=(const Position& lhs, const Position& rhs) noexcept + { + return !(lhs == rhs); + } + + // these are supposed to be used only for testing + // that's why there's this assert in afterMove + + [[nodiscard]] constexpr Position beforeMove(const ReverseMove& reverseMove) const + { + Position cpy(*this); + cpy.undoMove(reverseMove); + return cpy; + } + + [[nodiscard]] inline Position afterMove(Move move) const; + + [[nodiscard]] constexpr bool isEpPossible() const + { + return m_epSquare != Square::none(); + } + + [[nodiscard]] inline CompressedPosition compress() const; + + protected: + Color m_sideToMove; + Square m_epSquare; + CastlingRights m_castlingRights; + std::uint8_t m_rule50Counter; + std::uint16_t m_ply; + + static_assert(sizeof(Color) + sizeof(Square) + sizeof(CastlingRights) + sizeof(std::uint8_t) == 4); + + [[nodiscard]] inline bool isEpPossible(Square epSquare, Color sideToMove) const; + + [[nodiscard]] inline bool isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const; + + inline void nullifyEpSquareIfNotPossible(); + }; + + struct CompressedPosition + { + friend struct Position; + + // Occupied bitboard has bits set for + // each square with a piece on it. + // Each packedState byte holds 2 values (nibbles). + // First one at low bits, second one at high bits. + // Values correspond to consecutive squares + // in bitboard iteration order. + // Nibble values: + // these are the same as for Piece + // knights, bishops, queens can just be copied + // 0 : white pawn + // 1 : black pawn + // 2 : white knight + // 3 : black knight + // 4 : white bishop + // 5 : black bishop + // 6 : white rook + // 7 : black rook + // 8 : white queen + // 9 : black queen + // 10 : white king + // 11 : black king + // + // these are special + // 12 : pawn with ep square behind (white or black, depending on rank) + // 13 : white rook with coresponding castling rights + // 14 : black rook with coresponding castling rights + // 15 : black king and black is side to move + // + // Let N be the number of bits set in occupied bitboard. + // Only N nibbles are present. (N+1)/2 bytes are initialized. + + static CompressedPosition readFromBigEndian(const unsigned char* data) + { + CompressedPosition pos{}; + pos.m_occupied = Bitboard::fromBits( + (std::uint64_t)data[0] << 56 + | (std::uint64_t)data[1] << 48 + | (std::uint64_t)data[2] << 40 + | (std::uint64_t)data[3] << 32 + | (std::uint64_t)data[4] << 24 + | (std::uint64_t)data[5] << 16 + | (std::uint64_t)data[6] << 8 + | (std::uint64_t)data[7] + ); + std::memcpy(pos.m_packedState, data + 8, 16); + return pos; + } + + constexpr CompressedPosition() : + m_occupied{}, + m_packedState{} + { + } + + [[nodiscard]] friend bool operator<(const CompressedPosition& lhs, const CompressedPosition& rhs) + { + if (lhs.m_occupied.bits() < rhs.m_occupied.bits()) return true; + if (lhs.m_occupied.bits() > rhs.m_occupied.bits()) return false; + + return std::strcmp(reinterpret_cast(lhs.m_packedState), reinterpret_cast(rhs.m_packedState)) < 0; + } + + [[nodiscard]] friend bool operator==(const CompressedPosition& lhs, const CompressedPosition& rhs) + { + return lhs.m_occupied == rhs.m_occupied + && std::strcmp(reinterpret_cast(lhs.m_packedState), reinterpret_cast(rhs.m_packedState)) == 0; + } + + [[nodiscard]] inline Position decompress() const; + + [[nodiscard]] constexpr Bitboard pieceBB() const + { + return m_occupied; + } + + void writeToBigEndian(unsigned char* data) + { + const auto occupied = m_occupied.bits(); + *data++ = occupied >> 56; + *data++ = (occupied >> 48) & 0xFF; + *data++ = (occupied >> 40) & 0xFF; + *data++ = (occupied >> 32) & 0xFF; + *data++ = (occupied >> 24) & 0xFF; + *data++ = (occupied >> 16) & 0xFF; + *data++ = (occupied >> 8) & 0xFF; + *data++ = occupied & 0xFF; + std::memcpy(data, m_packedState, 16); + } + + private: + Bitboard m_occupied; + std::uint8_t m_packedState[16]; + }; + + namespace movegen + { + // For a pseudo-legal move the following are true: + // - the moving piece has the pos.sideToMove() color + // - the destination square is either empty or has a piece of the opposite color + // - if it is a pawn move it is valid (but may be illegal due to discovered checks) + // - if it is not a pawn move then the destination square is contained in attacks() + // - if it is a castling it is legal + // - a move other than castling may create a discovered attack on the king + // - a king may walk into a check + + template + inline void forEachPseudoLegalPawnMove(const Position& pos, Square from, FuncT&& f) + { + const Color sideToMove = pos.sideToMove(); + const Square epSquare = pos.epSquare(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + + Bitboard attackTargets = theirPieces; + if (epSquare != Square::none()) + { + attackTargets |= epSquare; + } + + const Bitboard attacks = bb::pawnAttacks(Bitboard::square(from), sideToMove) & attackTargets; + + const Rank secondToLastRank = sideToMove == Color::White ? rank7 : rank2; + const auto forward = sideToMove == Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); + + // promotions + if (from.rank() == secondToLastRank) + { + // capture promotions + for (Square toSq : attacks) + { + for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) + { + Move move{ from, toSq, MoveType::Promotion, Piece(pt, sideToMove) }; + f(move); + } + } + + // push promotions + const Square toSq = from + forward; + if (!occupied.isSet(toSq)) + { + for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) + { + Move move{ from, toSq, MoveType::Promotion, Piece(pt, sideToMove) }; + f(move); + } + } + } + else + { + // captures + for (Square toSq : attacks) + { + Move move{ from, toSq, (toSq == epSquare) ? MoveType::EnPassant : MoveType::Normal }; + f(move); + } + + const Square toSq = from + forward; + + // single push + if (!occupied.isSet(toSq)) + { + const Rank startRank = sideToMove == Color::White ? rank2 : rank7; + if (from.rank() == startRank) + { + // double push + const Square toSq2 = toSq + forward; + if (!occupied.isSet(toSq2)) + { + Move move{ from, toSq2 }; + f(move); + } + } + + Move move{ from, toSq }; + f(move); + } + } + } + + template + inline void forEachPseudoLegalPawnMove(const Position& pos, FuncT&& f) + { + const Square epSquare = pos.epSquare(); + const Bitboard ourPieces = pos.piecesBB(SideToMoveV); + const Bitboard theirPieces = pos.piecesBB(!SideToMoveV); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard pawns = pos.piecesBB(Piece(PieceType::Pawn, SideToMoveV)); + + const Bitboard secondToLastRank = SideToMoveV == Color::White ? bb::rank7 : bb::rank2; + const Bitboard secondRank = SideToMoveV == Color::White ? bb::rank2 : bb::rank7; + + const auto singlePawnMoveDestinationOffset = SideToMoveV == Color::White ? FlatSquareOffset(0, 1) : FlatSquareOffset(0, -1); + const auto doublePawnMoveDestinationOffset = SideToMoveV == Color::White ? FlatSquareOffset(0, 2) : FlatSquareOffset(0, -2); + + { + const int backward = SideToMoveV == Color::White ? -1 : 1; + const int backward2 = backward * 2; + + const Bitboard doublePawnMoveStarts = + pawns + & secondRank + & ~(occupied.shiftedVertically(backward) | occupied.shiftedVertically(backward2)); + + const Bitboard singlePawnMoveStarts = + pawns + & ~secondToLastRank + & ~occupied.shiftedVertically(backward); + + for (Square from : doublePawnMoveStarts) + { + const Square to = from + doublePawnMoveDestinationOffset; + f(Move::normal(from, to)); + } + + for (Square from : singlePawnMoveStarts) + { + const Square to = from + singlePawnMoveDestinationOffset; + f(Move::normal(from, to)); + } + } + + { + const Bitboard lastRank = SideToMoveV == Color::White ? bb::rank8 : bb::rank1; + const FlatSquareOffset westCaptureOffset = SideToMoveV == Color::White ? FlatSquareOffset(-1, 1) : FlatSquareOffset(-1, -1); + const FlatSquareOffset eastCaptureOffset = SideToMoveV == Color::White ? FlatSquareOffset(1, 1) : FlatSquareOffset(1, -1); + + const Bitboard pawnsWithWestCapture = bb::eastPawnAttacks(theirPieces & ~lastRank, !SideToMoveV) & pawns; + const Bitboard pawnsWithEastCapture = bb::westPawnAttacks(theirPieces & ~lastRank, !SideToMoveV) & pawns; + + for (Square from : pawnsWithWestCapture) + { + f(Move::normal(from, from + westCaptureOffset)); + } + + for (Square from : pawnsWithEastCapture) + { + f(Move::normal(from, from + eastCaptureOffset)); + } + } + + if (epSquare != Square::none()) + { + const Bitboard pawnsThatCanCapture = bb::pawnAttacks(Bitboard::square(epSquare), !SideToMoveV) & pawns; + for (Square from : pawnsThatCanCapture) + { + f(Move::enPassant(from, epSquare)); + } + } + + for (Square from : pawns & secondToLastRank) + { + const Bitboard attacks = bb::pawnAttacks(Bitboard::square(from), SideToMoveV) & theirPieces; + + // capture promotions + for (Square to : attacks) + { + for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) + { + Move move{ from, to, MoveType::Promotion, Piece(pt, SideToMoveV) }; + f(move); + } + } + + // push promotions + const Square to = from + singlePawnMoveDestinationOffset; + if (!occupied.isSet(to)) + { + for (PieceType pt : { PieceType::Knight, PieceType::Bishop, PieceType::Rook, PieceType::Queen }) + { + Move move{ from, to, MoveType::Promotion, Piece(pt, SideToMoveV) }; + f(move); + } + } + } + } + + template + inline void forEachPseudoLegalPawnMove(const Position& pos, FuncT&& f) + { + if (pos.sideToMove() == Color::White) + { + forEachPseudoLegalPawnMove(pos, std::forward(f)); + } + else + { + forEachPseudoLegalPawnMove(pos, std::forward(f)); + } + } + + template + inline void forEachPseudoLegalPieceMove(const Position& pos, Square from, FuncT&& f) + { + static_assert(PieceTypeV != PieceType::None); + + if constexpr (PieceTypeV == PieceType::Pawn) + { + forEachPseudoLegalPawnMove(pos, from, f); + } + else + { + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard attacks = bb::attacks(from, occupied) & ~ourPieces; + + for (Square toSq : attacks) + { + Move move{ from, toSq }; + f(move); + } + } + } + + template + inline void forEachPseudoLegalPieceMove(const Position& pos, FuncT&& f) + { + static_assert(PieceTypeV != PieceType::None); + + if constexpr (PieceTypeV == PieceType::Pawn) + { + forEachPseudoLegalPawnMove(pos, f); + } + else + { + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + const Bitboard pieces = pos.piecesBB(Piece(PieceTypeV, sideToMove)); + for (Square fromSq : pieces) + { + const Bitboard attacks = bb::attacks(fromSq, occupied) & ~ourPieces; + for (Square toSq : attacks) + { + Move move{ fromSq, toSq }; + f(move); + } + } + } + } + + template + inline void forEachCastlingMove(const Position& pos, FuncT&& f) + { + CastlingRights rights = pos.castlingRights(); + if (rights == CastlingRights::None) + { + return; + } + + const Color sideToMove = pos.sideToMove(); + const Bitboard ourPieces = pos.piecesBB(sideToMove); + const Bitboard theirPieces = pos.piecesBB(!sideToMove); + const Bitboard occupied = ourPieces | theirPieces; + + // we first reduce the set of legal castlings by checking the paths for pieces + if (sideToMove == Color::White) + { + if ((CastlingTraits::castlingPath[Color::White][CastleType::Short] & occupied).any()) rights &= ~CastlingRights::WhiteKingSide; + if ((CastlingTraits::castlingPath[Color::White][CastleType::Long] & occupied).any()) rights &= ~CastlingRights::WhiteQueenSide; + rights &= ~CastlingRights::Black; + } + else + { + if ((CastlingTraits::castlingPath[Color::Black][CastleType::Short] & occupied).any()) rights &= ~CastlingRights::BlackKingSide; + if ((CastlingTraits::castlingPath[Color::Black][CastleType::Long] & occupied).any()) rights &= ~CastlingRights::BlackQueenSide; + rights &= ~CastlingRights::White; + } + + if (rights == CastlingRights::None) + { + return; + } + + // King must not be in check. Done here because it is quite expensive. + const Square ksq = pos.kingSquare(sideToMove); + if (pos.isSquareAttacked(ksq, !sideToMove)) + { + return; + } + + // Loop through all possible castlings. + for (CastleType castlingType : values()) + { + const CastlingRights right = CastlingTraits::castlingRights[sideToMove][castlingType]; + + if (!contains(rights, right)) + { + continue; + } + + // If we have this castling right + // we check whether the king passes an attacked square. + const Square passedSquare = CastlingTraits::squarePassedByKing[sideToMove][castlingType]; + if (pos.isSquareAttacked(passedSquare, !sideToMove)) + { + continue; + } + + // If it's a castling move then the change in square occupation + // cannot have an effect because otherwise there would be + // a slider attacker attacking the castling king. + if (pos.isSquareAttacked(CastlingTraits::kingDestination[sideToMove][castlingType], !sideToMove)) + { + continue; + } + + // If not we can castle. + Move move = Move::castle(castlingType, sideToMove); + f(move); + } + } + + // Calls a given function for all pseudo legal moves for the position. + // `pos` must be a legal chess position + template + inline void forEachPseudoLegalMove(const Position& pos, FuncT&& func) + { + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachPseudoLegalPieceMove(pos, func); + forEachCastlingMove(pos, func); + } + + // Calls a given function for all legal moves for the position. + // `pos` must be a legal chess position + template + inline void forEachLegalMove(const Position& pos, FuncT&& func) + { + auto funcIfLegal = [&func, checker = pos.moveLegalityChecker()](Move move) { + if (checker.isPseudoLegalMoveLegal(move)) + { + func(move); + } + }; + + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachPseudoLegalPieceMove(pos, funcIfLegal); + forEachCastlingMove(pos, func); + } + + // Generates all pseudo legal moves for the position. + // `pos` must be a legal chess position + [[nodiscard]] std::vector generatePseudoLegalMoves(const Position& pos); + + // Generates all legal moves for the position. + // `pos` must be a legal chess position + [[nodiscard]] std::vector generateLegalMoves(const Position& pos); + } + + [[nodiscard]] inline bool Position::isCheck() const + { + return BaseType::isSquareAttacked(kingSquare(m_sideToMove), !m_sideToMove); + } + + [[nodiscard]] inline Bitboard Position::checkers() const + { + return BaseType::attackers(kingSquare(m_sideToMove), !m_sideToMove); + } + + [[nodiscard]] inline bool Position::isCheckAfterMove(Move move) const + { + return BaseType::isSquareAttackedAfterMove(move, kingSquare(!m_sideToMove), m_sideToMove); + } + + [[nodiscard]] inline bool Position::isMoveLegal(Move move) const + { + return + isMovePseudoLegal(move) + && isPseudoLegalMoveLegal(move); + } + + [[nodiscard]] inline bool Position::isPseudoLegalMoveLegal(Move move) const + { + return + (move.type == MoveType::Castle) + || !isOwnKingAttackedAfterMove(move); + } + + [[nodiscard]] inline bool Position::isMovePseudoLegal(Move move) const + { + if (!move.from.isOk() || !move.to.isOk()) + { + return false; + } + + if (move.from == move.to) + { + return false; + } + + if (move.type != MoveType::Promotion && move.promotedPiece != Piece::none()) + { + return false; + } + + const Piece movedPiece = pieceAt(move.from); + if (movedPiece == Piece::none()) + { + return false; + } + + if (movedPiece.color() != m_sideToMove) + { + return false; + } + + const Bitboard occupied = piecesBB(); + const Bitboard ourPieces = piecesBB(m_sideToMove); + const bool isNormal = move.type == MoveType::Normal; + + switch (movedPiece.type()) + { + case PieceType::Pawn: + { + bool isValid = false; + // TODO: use iterators so we don't loop over all moves + // when we can avoid it. + movegen::forEachPseudoLegalPawnMove(*this, move.from, [&isValid, &move](const Move& genMove) { + if (move == genMove) + { + isValid = true; + } + }); + return isValid; + } + + case PieceType::Bishop: + return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::Knight: + return isNormal && (bb::pseudoAttacks(move.from) & ~ourPieces).isSet(move.to); + + case PieceType::Rook: + return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::Queen: + return isNormal && (bb::attacks(move.from, occupied) & ~ourPieces).isSet(move.to); + + case PieceType::King: + { + if (move.type == MoveType::Castle) + { + bool isValid = false; + movegen::forEachCastlingMove(*this, [&isValid, &move](const Move& genMove) { + if (move == genMove) + { + isValid = true; + } + }); + return isValid; + } + else + { + return isNormal && (bb::pseudoAttacks(move.from) & ~ourPieces).isSet(move.to); + } + } + + default: + return false; + } + } + + [[nodiscard]] inline Bitboard Position::blockersForKing(Color color) const + { + const Color attackerColor = !color; + + const Bitboard occupied = piecesBB(); + + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + + const Square ksq = kingSquare(color); + + const Bitboard opponentBishopLikePieces = (bishops | queens); + const Bitboard bishopPseudoAttacks = bb::pseudoAttacks(ksq); + + const Bitboard opponentRookLikePieces = (rooks | queens); + const Bitboard rookPseudoAttacks = bb::pseudoAttacks(ksq); + + const Bitboard xrayers = + (bishopPseudoAttacks & opponentBishopLikePieces) + | (rookPseudoAttacks & opponentRookLikePieces); + + Bitboard allBlockers = Bitboard::none(); + + for (Square xrayer : xrayers) + { + const Bitboard blockers = bb::between(xrayer, ksq) & occupied; + if (blockers.exactlyOne()) + { + allBlockers |= blockers; + } + } + + return allBlockers; + } + + inline MoveLegalityChecker::MoveLegalityChecker(const Position& position) : + m_position(&position), + m_checkers(position.checkers()), + m_ourBlockersForKing( + position.blockersForKing(position.sideToMove()) + & position.piecesBB(position.sideToMove()) + ), + m_ksq(position.kingSquare(position.sideToMove())) + { + if (m_checkers.exactlyOne()) + { + const Bitboard knightCheckers = m_checkers & bb::pseudoAttacks(m_ksq); + if (knightCheckers.any()) + { + // We're checked by a knight, we have to remove it or move the king. + m_potentialCheckRemovals = knightCheckers; + } + else + { + // If we're not checked by a knight we can block it. + m_potentialCheckRemovals = bb::between(m_ksq, m_checkers.first()) | m_checkers; + } + } + else + { + // Double check, king has to move. + m_potentialCheckRemovals = Bitboard::none(); + } + } + + [[nodiscard]] inline bool MoveLegalityChecker::isPseudoLegalMoveLegal(const Move& move) const + { + if (m_checkers.any()) + { + if (move.from == m_ksq || move.type == MoveType::EnPassant) + { + return m_position->isPseudoLegalMoveLegal(move); + } + else + { + // This means there's only one check and we either + // blocked it or removed the piece that attacked + // our king. So the only threat is if it's a discovered check. + return + m_potentialCheckRemovals.isSet(move.to) + && !m_ourBlockersForKing.isSet(move.from); + } + } + else + { + if (move.from == m_ksq) + { + return m_position->isPseudoLegalMoveLegal(move); + } + else if (move.type == MoveType::EnPassant) + { + return !m_position->createsDiscoveredAttackOnOwnKing(move); + } + else if (m_ourBlockersForKing.isSet(move.from)) + { + // If it was a blocker it may have only moved in line with our king. + // Otherwise it's a discovered check. + return bb::line(m_ksq, move.from).isSet(move.to); + } + else + { + return true; + } + } + } + + static_assert(sizeof(CompressedPosition) == 24); + static_assert(std::is_trivially_copyable_v); + + namespace detail + { + [[nodiscard]] FORCEINLINE constexpr std::uint8_t compressOrdinaryPiece(const Position&, Square, Piece piece) + { + return static_cast(ordinal(piece)); + } + + [[nodiscard]] FORCEINLINE constexpr std::uint8_t compressPawn(const Position& position, Square sq, Piece piece) + { + const Square epSquare = position.epSquare(); + if (epSquare == Square::none()) + { + return static_cast(ordinal(piece)); + } + else + { + const Color sideToMove = position.sideToMove(); + const Rank rank = sq.rank(); + const File file = sq.file(); + // use bitwise operators, there is a lot of unpredictable branches but in + // total the result is quite predictable + if ( + (file == epSquare.file()) + && ( + ((rank == rank4) & (sideToMove == Color::Black)) + | ((rank == rank5) & (sideToMove == Color::White)) + ) + ) + { + return 12; + } + else + { + return static_cast(ordinal(piece)); + } + } + } + + [[nodiscard]] FORCEINLINE constexpr std::uint8_t compressRook(const Position& position, Square sq, Piece piece) + { + const CastlingRights castlingRights = position.castlingRights(); + const Color color = piece.color(); + + if (color == Color::White + && ( + (sq == a1 && contains(castlingRights, CastlingRights::WhiteQueenSide)) + || (sq == h1 && contains(castlingRights, CastlingRights::WhiteKingSide)) + ) + ) + { + return 13; + } + else if ( + color == Color::Black + && ( + (sq == a8 && contains(castlingRights, CastlingRights::BlackQueenSide)) + || (sq == h8 && contains(castlingRights, CastlingRights::BlackKingSide)) + ) + ) + { + return 14; + } + else + { + return static_cast(ordinal(piece)); + } + } + + [[nodiscard]] FORCEINLINE constexpr std::uint8_t compressKing(const Position& position, Square /* sq */, Piece piece) + { + const Color color = piece.color(); + const Color sideToMove = position.sideToMove(); + + if (color == Color::White) + { + return 10; + } + else if (sideToMove == Color::White) + { + return 11; + } + else + { + return 15; + } + } + } + + namespace detail::lookup + { + static constexpr EnumArray pieceCompressorFunc = []() { + EnumArray pieceCompressorFunc_{}; + + pieceCompressorFunc_[PieceType::Knight] = detail::compressOrdinaryPiece; + pieceCompressorFunc_[PieceType::Bishop] = detail::compressOrdinaryPiece; + pieceCompressorFunc_[PieceType::Queen] = detail::compressOrdinaryPiece; + + pieceCompressorFunc_[PieceType::Pawn] = detail::compressPawn; + pieceCompressorFunc_[PieceType::Rook] = detail::compressRook; + pieceCompressorFunc_[PieceType::King] = detail::compressKing; + + pieceCompressorFunc_[PieceType::None] = [](const Position&, Square, Piece) -> std::uint8_t { /* should never happen */ return 0; }; + + return pieceCompressorFunc_; + }(); + } + + [[nodiscard]] inline CompressedPosition Position::compress() const + { + auto compressPiece = [this](Square sq, Piece piece) -> std::uint8_t { + if (piece.type() == PieceType::Pawn) // it's likely to be a pawn + { + return detail::compressPawn(*this, sq, piece); + } + else + { + return detail::lookup::pieceCompressorFunc[piece.type()](*this, sq, piece); + } + }; + + const Bitboard occ = piecesBB(); + + CompressedPosition compressed; + compressed.m_occupied = occ; + + auto it = occ.begin(); + auto end = occ.end(); + for (int i = 0;; ++i) + { + if (it == end) break; + compressed.m_packedState[i] = compressPiece(*it, pieceAt(*it)); + ++it; + + if (it == end) break; + compressed.m_packedState[i] |= compressPiece(*it, pieceAt(*it)) << 4; + ++it; + } + + return compressed; + } + + [[nodiscard]] inline Position CompressedPosition::decompress() const + { + Position pos; + pos.setCastlingRights(CastlingRights::None); + + auto decompressPiece = [&pos](Square sq, std::uint8_t nibble) { + switch (nibble) + { + case 0: + case 1: + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + case 9: + case 10: + case 11: + { + pos.place(fromOrdinal(nibble), sq); + return; + } + + case 12: + { + const Rank rank = sq.rank(); + if (rank == rank4) + { + pos.place(whitePawn, sq); + pos.setEpSquareUnchecked(sq + Offset{ 0, -1 }); + } + else // (rank == rank5) + { + pos.place(blackPawn, sq); + pos.setEpSquareUnchecked(sq + Offset{ 0, 1 }); + } + return; + } + + case 13: + { + pos.place(whiteRook, sq); + if (sq == a1) + { + pos.addCastlingRights(CastlingRights::WhiteQueenSide); + } + else // (sq == H1) + { + pos.addCastlingRights(CastlingRights::WhiteKingSide); + } + return; + } + + case 14: + { + pos.place(blackRook, sq); + if (sq == a8) + { + pos.addCastlingRights(CastlingRights::BlackQueenSide); + } + else // (sq == H8) + { + pos.addCastlingRights(CastlingRights::BlackKingSide); + } + return; + } + + case 15: + { + pos.place(blackKing, sq); + pos.setSideToMove(Color::Black); + return; + } + + } + + return; + }; + + const Bitboard occ = m_occupied; + + auto it = occ.begin(); + auto end = occ.end(); + for (int i = 0;; ++i) + { + if (it == end) break; + decompressPiece(*it, m_packedState[i] & 0xF); + ++it; + + if (it == end) break; + decompressPiece(*it, m_packedState[i] >> 4); + ++it; + } + + return pos; + } + + + [[nodiscard]] bool Board::isSquareAttacked(Square sq, Color attackerColor) const + { + assert(sq.isOk()); + + const Bitboard occupied = piecesBB(); + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + + const Bitboard allSliders = (bishops | rooks | queens); + if ((bb::pseudoAttacks(sq) & allSliders).any()) + { + if (bb::isAttackedBySlider( + sq, + bishops, + rooks, + queens, + occupied + )) + { + return true; + } + } + + const Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); + if ((bb::pseudoAttacks(sq) & king).any()) + { + return true; + } + + const Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); + if ((bb::pseudoAttacks(sq) & knights).any()) + { + return true; + } + + const Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); + const Bitboard pawnAttacks = bb::pawnAttacks(pawns, attackerColor); + + return pawnAttacks.isSet(sq); + } + + [[nodiscard]] bool Board::isSquareAttackedAfterMove(Move move, Square sq, Color attackerColor) const + { + const Bitboard occupiedChange = Bitboard::square(move.from) | move.to; + + Bitboard occupied = (piecesBB() ^ move.from) | move.to; + + Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); + Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); + Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); + + if (move.type == MoveType::EnPassant) + { + const Square capturedPawnSq(move.to.file(), move.from.rank()); + occupied ^= capturedPawnSq; + pawns ^= capturedPawnSq; + } + else if (pieceAt(move.to) != Piece::none()) + { + const Bitboard notCaptured = ~Bitboard::square(move.to); + bishops &= notCaptured; + rooks &= notCaptured; + queens &= notCaptured; + knights &= notCaptured; + pawns &= notCaptured; + } + + // Potential attackers may have moved. + const Piece movedPiece = pieceAt(move.from); + if (movedPiece.color() == attackerColor) + { + switch (movedPiece.type()) + { + case PieceType::Pawn: + pawns ^= occupiedChange; + break; + case PieceType::Knight: + knights ^= occupiedChange; + break; + case PieceType::Bishop: + bishops ^= occupiedChange; + break; + case PieceType::Rook: + rooks ^= occupiedChange; + break; + case PieceType::Queen: + queens ^= occupiedChange; + break; + case PieceType::King: + { + if (move.type == MoveType::Castle) + { + const CastleType castleType = CastlingTraits::moveCastlingType(move); + + king ^= move.from; + king ^= CastlingTraits::kingDestination[attackerColor][castleType]; + rooks ^= move.to; + rooks ^= CastlingTraits::rookDestination[attackerColor][castleType]; + } + else + { + king ^= occupiedChange; + } + + break; + } + case PieceType::None: + assert(false); + } + } + + // If it's a castling move then the change in square occupation + // cannot have an effect because otherwise there would be + // a slider attacker attacking the castling king. + // (It could have an effect in chess960 if the slider + // attacker was behind the rook involved in castling, + // but we don't care about chess960.) + + const Bitboard allSliders = (bishops | rooks | queens); + if ((bb::pseudoAttacks(sq) & allSliders).any()) + { + if (bb::isAttackedBySlider( + sq, + bishops, + rooks, + queens, + occupied + )) + { + return true; + } + } + + if ((bb::pseudoAttacks(sq) & king).any()) + { + return true; + } + + if ((bb::pseudoAttacks(sq) & knights).any()) + { + return true; + } + + const Bitboard pawnAttacks = bb::pawnAttacks(pawns, attackerColor); + + return pawnAttacks.isSet(sq); + } + + [[nodiscard]] bool Board::createsDiscoveredAttackOnOwnKing(Move move) const + { + Bitboard occupied = (piecesBB() ^ move.from) | move.to; + + const Piece movedPiece = pieceAt(move.from); + const Color kingColor = movedPiece.color(); + const Color attackerColor = !kingColor; + const Square ksq = kingSquare(kingColor); + + Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + + if (move.type == MoveType::EnPassant) + { + const Square capturedPawnSq(move.to.file(), move.from.rank()); + occupied ^= capturedPawnSq; + } + else if (pieceAt(move.to) != Piece::none()) + { + const Bitboard notCaptured = ~Bitboard::square(move.to); + bishops &= notCaptured; + rooks &= notCaptured; + queens &= notCaptured; + } + + const Bitboard allSliders = (bishops | rooks | queens); + if ((bb::pseudoAttacks(ksq) & allSliders).any()) + { + if (bb::isAttackedBySlider( + ksq, + bishops, + rooks, + queens, + occupied + )) + { + return true; + } + } + + return false; + } + + [[nodiscard]] bool Board::isPieceAttacked(Square sq) const + { + const Piece piece = pieceAt(sq); + + if (piece == Piece::none()) + { + return false; + } + + return isSquareAttacked(sq, !piece.color()); + } + + [[nodiscard]] bool Board::isPieceAttackedAfterMove(Move move, Square sq) const + { + const Piece piece = pieceAt(sq); + + if (piece == Piece::none()) + { + return false; + } + + if (sq == move.from) + { + // We moved the piece we're interested in. + // For every move the piece ends up on the move.to except + // for the case of castling moves. + // But we know pseudo legal castling moves + // are already legal, so the king cannot be in check after. + if (move.type == MoveType::Castle) + { + return false; + } + + // So update the square we're interested in. + sq = move.to; + } + + return isSquareAttackedAfterMove(move, sq, !piece.color()); + } + + [[nodiscard]] bool Board::isOwnKingAttackedAfterMove(Move move) const + { + if (move.type == MoveType::Castle) + { + // Pseudo legal castling moves are already legal. + // This is ensured by the move generator. + return false; + } + + const Piece movedPiece = pieceAt(move.from); + + return isPieceAttackedAfterMove(move, kingSquare(movedPiece.color())); + } + + [[nodiscard]] Bitboard Board::attacks(Square sq) const + { + const Piece piece = pieceAt(sq); + if (piece == Piece::none()) + { + return Bitboard::none(); + } + + if (piece.type() == PieceType::Pawn) + { + return bb::pawnAttacks(Bitboard::square(sq), piece.color()); + } + else + { + return bb::attacks(piece.type(), sq, piecesBB()); + } + } + + [[nodiscard]] Bitboard Board::attackers(Square sq, Color attackerColor) const + { + // En-passant square is not included. + + Bitboard allAttackers = Bitboard::none(); + + const Bitboard occupied = piecesBB(); + + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, attackerColor)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, attackerColor)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, attackerColor)); + + const Bitboard bishopLikePieces = (bishops | queens); + const Bitboard bishopAttacks = bb::attacks(sq, occupied); + allAttackers |= bishopAttacks & bishopLikePieces; + + const Bitboard rookLikePieces = (rooks | queens); + const Bitboard rookAttacks = bb::attacks(sq, occupied); + allAttackers |= rookAttacks & rookLikePieces; + + const Bitboard king = piecesBB(Piece(PieceType::King, attackerColor)); + allAttackers |= bb::pseudoAttacks(sq) & king; + + const Bitboard knights = piecesBB(Piece(PieceType::Knight, attackerColor)); + allAttackers |= bb::pseudoAttacks(sq) & knights; + + const Bitboard pawns = piecesBB(Piece(PieceType::Pawn, attackerColor)); + allAttackers |= bb::pawnAttacks(Bitboard::square(sq), !attackerColor) & pawns; + + return allAttackers; + } + + inline const Piece* Board::piecesRaw() const + { + return m_pieces.data(); + } + + namespace detail::lookup + { + static constexpr EnumArray fenPiece = []() { + EnumArray fenPiece_{}; + + fenPiece_[whitePawn] = 'P'; + fenPiece_[blackPawn] = 'p'; + fenPiece_[whiteKnight] = 'N'; + fenPiece_[blackKnight] = 'n'; + fenPiece_[whiteBishop] = 'B'; + fenPiece_[blackBishop] = 'b'; + fenPiece_[whiteRook] = 'R'; + fenPiece_[blackRook] = 'r'; + fenPiece_[whiteQueen] = 'Q'; + fenPiece_[blackQueen] = 'q'; + fenPiece_[whiteKing] = 'K'; + fenPiece_[blackKing] = 'k'; + fenPiece_[Piece::none()] = 'X'; + + return fenPiece_; + }(); + } + + [[nodiscard]] inline std::string Board::fen() const + { + std::string fen; + fen.reserve(96); // longest fen is probably in range of around 88 + + Rank rank = rank8; + File file = fileA; + std::uint8_t emptyCounter = 0; + + for (;;) + { + const Square sq(file, rank); + const Piece piece = m_pieces[sq]; + + if (piece == Piece::none()) + { + ++emptyCounter; + } + else + { + if (emptyCounter != 0) + { + fen.push_back(static_cast(emptyCounter) + '0'); + emptyCounter = 0; + } + + fen.push_back(detail::lookup::fenPiece[piece]); + } + + ++file; + if (file > fileH) + { + file = fileA; + --rank; + + if (emptyCounter != 0) + { + fen.push_back(static_cast(emptyCounter) + '0'); + emptyCounter = 0; + } + + if (rank < rank1) + { + break; + } + fen.push_back('/'); + } + } + + return fen; + } + + void Position::set(std::string_view fen) + { + (void)trySet(fen); + } + + // Returns false if the fen was not valid + // If the returned value was false the position + // is in unspecified state. + [[nodiscard]] bool Position::trySet(std::string_view fen) + { + // Lazily splits by ' '. Returns empty string views if at the end. + auto nextPart = [fen, start = std::size_t{ 0 }]() mutable { + std::size_t end = fen.find(' ', start); + if (end == std::string::npos) + { + std::string_view substr = fen.substr(start); + start = fen.size(); + return substr; + } + else + { + std::string_view substr = fen.substr(start, end - start); + start = end + 1; // to skip whitespace + return substr; + } + }; + + if (!BaseType::trySet(nextPart())) return false; + + { + const auto side = nextPart(); + if (side == std::string_view("w")) m_sideToMove = Color::White; + else if (side == std::string_view("b")) m_sideToMove = Color::Black; + else return false; + + if (isSquareAttacked(kingSquare(!m_sideToMove), m_sideToMove)) return false; + } + + { + const auto castlingRights = nextPart(); + auto castlingRightsOpt = parser_bits::tryParseCastlingRights(castlingRights); + if (!castlingRightsOpt.has_value()) + { + return false; + } + else + { + m_castlingRights = *castlingRightsOpt; + } + } + + { + const auto epSquare = nextPart(); + auto epSquareOpt = parser_bits::tryParseEpSquare(epSquare); + if (!epSquareOpt.has_value()) + { + return false; + } + else + { + m_epSquare = *epSquareOpt; + } + } + + { + const auto rule50 = nextPart(); + if (!rule50.empty()) + { + m_rule50Counter = std::stoi(rule50.data()); + } + else + { + m_rule50Counter = 0; + } + } + + { + const auto fullMove = nextPart(); + if (!fullMove.empty()) + { + m_ply = 2 * (std::stoi(fullMove.data()) - 1) + (m_sideToMove == Color::Black); + } + else + { + m_ply = 0; + } + } + + nullifyEpSquareIfNotPossible(); + + return true; + } + + [[nodiscard]] Position Position::fromFen(std::string_view fen) + { + Position pos{}; + pos.set(fen); + return pos; + } + + [[nodiscard]] std::optional Position::tryFromFen(std::string_view fen) + { + Position pos{}; + if (pos.trySet(fen)) return pos; + else return {}; + } + + [[nodiscard]] Position Position::startPosition() + { + static const Position pos = fromFen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"); + return pos; + } + + [[nodiscard]] std::string Position::fen() const + { + std::string fen = Board::fen(); + + fen += ' '; + fen += m_sideToMove == Color::White ? 'w' : 'b'; + + fen += ' '; + parser_bits::appendCastlingRightsToString(m_castlingRights, fen); + + fen += ' '; + parser_bits::appendEpSquareToString(m_epSquare, fen); + + fen += ' '; + fen += std::to_string(m_rule50Counter); + + fen += ' '; + fen += std::to_string(fullMove()); + + return fen; + } + + namespace detail::lookup + { + static constexpr EnumArray preservedCastlingRights = []() { + EnumArray preservedCastlingRights_{}; + for (CastlingRights& rights : preservedCastlingRights_) + { + rights = ~CastlingRights::None; + } + + preservedCastlingRights_[e1] = ~CastlingRights::White; + preservedCastlingRights_[e8] = ~CastlingRights::Black; + + preservedCastlingRights_[h1] = ~CastlingRights::WhiteKingSide; + preservedCastlingRights_[a1] = ~CastlingRights::WhiteQueenSide; + preservedCastlingRights_[h8] = ~CastlingRights::BlackKingSide; + preservedCastlingRights_[a8] = ~CastlingRights::BlackQueenSide; + + return preservedCastlingRights_; + }(); + } + + inline ReverseMove Position::doMove(const Move& move) + { + assert(move.from.isOk() && move.to.isOk()); + + const PieceType movedPiece = pieceAt(move.from).type(); + + m_ply += 1; + m_rule50Counter += 1; + + if (move.type != MoveType::Castle && (movedPiece == PieceType::Pawn || pieceAt(move.to) != Piece::none())) + { + m_rule50Counter = 0; + } + + const Square oldEpSquare = m_epSquare; + const CastlingRights oldCastlingRights = m_castlingRights; + m_castlingRights &= detail::lookup::preservedCastlingRights[move.from]; + m_castlingRights &= detail::lookup::preservedCastlingRights[move.to]; + + m_epSquare = Square::none(); + // for double pushes move index differs by 16 or -16; + if((movedPiece == PieceType::Pawn) & ((ordinal(move.to) ^ ordinal(move.from)) == 16)) + { + m_epSquare = fromOrdinal((ordinal(move.to) + ordinal(move.from)) >> 1); + } + + const Piece captured = BaseType::doMove(move); + m_sideToMove = !m_sideToMove; + + nullifyEpSquareIfNotPossible(); + + return { move, captured, oldEpSquare, oldCastlingRights }; + } + + [[nodiscard]] inline Position Position::afterMove(Move move) const + { + Position cpy(*this); + auto pc = cpy.doMove(move); + + (void)pc; + //assert(cpy.beforeMove(move, pc) == *this); // this assert would result in infinite recursion + + return cpy; + } + + [[nodiscard]] inline bool Position::isEpPossible(Square epSquare, Color sideToMove) const + { + const Bitboard pawnsAttackingEpSquare = + bb::pawnAttacks(Bitboard::square(epSquare), !sideToMove) + & piecesBB(Piece(PieceType::Pawn, sideToMove)); + + if (!pawnsAttackingEpSquare.any()) + { + return false; + } + + return isEpPossibleColdPath(epSquare, pawnsAttackingEpSquare, sideToMove); + } + + [[nodiscard]] inline bool Position::isEpPossibleColdPath(Square epSquare, Bitboard pawnsAttackingEpSquare, Color sideToMove) const + { + if (pieceAt(epSquare) != Piece::none()) + { + return false; + } + + const auto forward = + sideToMove == chess::Color::White + ? FlatSquareOffset(0, 1) + : FlatSquareOffset(0, -1); + + if (pieceAt(epSquare + forward) != Piece::none()) + { + return false; + } + + if (pieceAt(epSquare + -forward) != Piece(PieceType::Pawn, !sideToMove)) + { + return false; + } + + // only set m_epSquare when it matters, ie. when + // the opposite side can actually capture + for (Square sq : pawnsAttackingEpSquare) + { + // If we're here the previous move by other side + // was a double pawn move so our king is either not in check + // or is attacked only by the moved pawn - in which + // case it can be captured by our pawn if it doesn't + // create a discovered check on our king. + // So overall we only have to check whether our king + // ends up being uncovered to a slider attack. + + const Square ksq = kingSquare(sideToMove); + + const Bitboard bishops = piecesBB(Piece(PieceType::Bishop, !sideToMove)); + const Bitboard rooks = piecesBB(Piece(PieceType::Rook, !sideToMove)); + const Bitboard queens = piecesBB(Piece(PieceType::Queen, !sideToMove)); + + const Bitboard relevantAttackers = bishops | rooks | queens; + const Bitboard pseudoSliderAttacksFromKing = bb::pseudoAttacks(ksq); + if ((relevantAttackers & pseudoSliderAttacksFromKing).isEmpty()) + { + // It's enough that one pawn can capture. + return true; + } + + const Square capturedPawnSq(epSquare.file(), sq.rank()); + const Bitboard occupied = ((piecesBB() ^ sq) | epSquare) ^ capturedPawnSq; + + if (!bb::isAttackedBySlider( + ksq, + bishops, + rooks, + queens, + occupied + )) + { + // It's enough that one pawn can capture. + return true; + } + } + + return false; + } + + inline void Position::nullifyEpSquareIfNotPossible() + { + if (m_epSquare != Square::none() && !isEpPossible(m_epSquare, m_sideToMove)) + { + m_epSquare = Square::none(); + } + } + + namespace uci + { + [[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move); + [[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv); + + [[nodiscard]] inline std::string moveToUci(const Position& pos, const Move& move) + { + std::string s; + + parser_bits::appendSquareToString(move.from, s); + + if (move.type == MoveType::Castle) + { + const CastleType castleType = CastlingTraits::moveCastlingType(move); + + const Square kingDestination = CastlingTraits::kingDestination[pos.sideToMove()][castleType]; + parser_bits::appendSquareToString(kingDestination, s); + } + else + { + parser_bits::appendSquareToString(move.to, s); + + if (move.type == MoveType::Promotion) + { + // lowercase piece symbol + s += EnumTraits::toChar(move.promotedPiece.type(), Color::Black); + } + } + + return s; + } + + [[nodiscard]] inline Move uciToMove(const Position& pos, std::string_view sv) + { + const Square from = parser_bits::parseSquare(sv.data()); + const Square to = parser_bits::parseSquare(sv.data() + 2); + + if (sv.size() == 5) + { + const PieceType promotedPieceType = *fromChar(sv[4]); + return Move::promotion(from, to, Piece(promotedPieceType, pos.sideToMove())); + } + else + { + if ( + pos.pieceAt(from).type() == PieceType::King + && std::abs(from.file() - to.file()) > 1 + ) + { + // uci king destinations are on files C or G. + const CastleType castleType = + (to.file() == fileG) + ? CastleType::Short + : CastleType::Long; + + return Move::castle(castleType, pos.sideToMove()); + } + else if (pos.pieceAt(from).type() == PieceType::Pawn && pos.epSquare() == to) + { + return Move::enPassant(from, to); + } + else + { + return Move::normal(from, to); + } + } + } + } +} + +namespace binpack +{ + constexpr std::size_t KiB = 1024; + constexpr std::size_t MiB = (1024*KiB); + constexpr std::size_t GiB = (1024*MiB); + + constexpr std::size_t suggestedChunkSize = MiB; + constexpr std::size_t maxMovelistSize = 10*KiB; // a safe upper bound + constexpr std::size_t maxChunkSize = 100*MiB; // to prevent malformed files from causing huge allocations + + using namespace std::literals; + + namespace nodchip + { + // This namespace contains modified code from https://github.com/nodchip/Stockfish + // which is released under GPL v3 license https://www.gnu.org/licenses/gpl-3.0.html + + using namespace std; + + struct StockfishMove + { + [[nodiscard]] static StockfishMove fromMove(chess::Move move) + { + StockfishMove sfm; + + sfm.m_raw = 0; + + unsigned moveFlag = 0; + if (move.type == chess::MoveType::Promotion) moveFlag = 1; + else if (move.type == chess::MoveType::EnPassant) moveFlag = 2; + else if (move.type == chess::MoveType::Castle) moveFlag = 3; + + unsigned promotionIndex = 0; + if (move.type == chess::MoveType::Promotion) + { + promotionIndex = static_cast(move.promotedPiece.type()) - static_cast(chess::PieceType::Knight); + } + + sfm.m_raw |= static_cast(moveFlag); + sfm.m_raw <<= 2; + sfm.m_raw |= static_cast(promotionIndex); + sfm.m_raw <<= 6; + sfm.m_raw |= static_cast(move.from); + sfm.m_raw <<= 6; + sfm.m_raw |= static_cast(move.to); + + return sfm; + } + + [[nodiscard]] chess::Move toMove() const + { + const chess::Square to = static_cast((m_raw & (0b111111 << 0) >> 0)); + const chess::Square from = static_cast((m_raw & (0b111111 << 6)) >> 6); + + const unsigned promotionIndex = (m_raw & (0b11 << 12)) >> 12; + const chess::PieceType promotionType = static_cast(static_cast(chess::PieceType::Knight) + promotionIndex); + + const unsigned moveFlag = (m_raw & (0b11 << 14)) >> 14; + chess::MoveType type = chess::MoveType::Normal; + if (moveFlag == 1) type = chess::MoveType::Promotion; + else if (moveFlag == 2) type = chess::MoveType::EnPassant; + else if (moveFlag == 3) type = chess::MoveType::Castle; + + if (type == chess::MoveType::Promotion) + { + const chess::Color stm = to.rank() == chess::rank8 ? chess::Color::White : chess::Color::Black; + return chess::Move{from, to, type, chess::Piece(promotionType, stm)}; + } + + return chess::Move{from, to, type}; + } + + [[nodiscard]] std::string toString() const + { + const chess::Square to = static_cast((m_raw & (0b111111 << 0) >> 0)); + const chess::Square from = static_cast((m_raw & (0b111111 << 6)) >> 6); + + const unsigned promotionIndex = (m_raw & (0b11 << 12)) >> 12; + const chess::PieceType promotionType = static_cast(static_cast(chess::PieceType::Knight) + promotionIndex); + + std::string r; + chess::parser_bits::appendSquareToString(from, r); + chess::parser_bits::appendSquareToString(to, r); + if (promotionType != chess::PieceType::None) + { + r += chess::EnumTraits::toChar(promotionType, chess::Color::Black); + } + + return r; + } + + private: + std::uint16_t m_raw; + }; + static_assert(sizeof(StockfishMove) == sizeof(std::uint16_t)); + + struct PackedSfen + { + uint8_t data[32]; + }; + + struct PackedSfenValue + { + // phase + PackedSfen sfen; + + // Evaluation value returned from Learner::search() + int16_t score; + + // PV first move + // Used when finding the match rate with the teacher + StockfishMove move; + + // Trouble of the phase from the initial phase. + uint16_t gamePly; + + // 1 if the player on this side ultimately wins the game. -1 if you are losing. + // 0 if a draw is reached. + // The draw is in the teacher position generation command gensfen, + // Only write if LEARN_GENSFEN_DRAW_RESULT is enabled. + int8_t game_result; + + // When exchanging the file that wrote the teacher aspect with other people + //Because this structure size is not fixed, pad it so that it is 40 bytes in any environment. + uint8_t padding; + + // 32 + 2 + 2 + 2 + 1 + 1 = 40bytes + }; + static_assert(sizeof(PackedSfenValue) == 40); + // Class that handles bitstream + + // useful when doing aspect encoding + struct BitStream + { + // Set the memory to store the data in advance. + // Assume that memory is cleared to 0. + void set_data(uint8_t* data_) { data = data_; reset(); } + + // Get the pointer passed in set_data(). + uint8_t* get_data() const { return data; } + + // Get the cursor. + int get_cursor() const { return bit_cursor; } + + // reset the cursor + void reset() { bit_cursor = 0; } + + // Write 1bit to the stream. + // If b is non-zero, write out 1. If 0, write 0. + void write_one_bit(int b) + { + if (b) + data[bit_cursor / 8] |= 1 << (bit_cursor & 7); + + ++bit_cursor; + } + + // Get 1 bit from the stream. + int read_one_bit() + { + int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1; + ++bit_cursor; + + return b; + } + + // write n bits of data + // Data shall be written out from the lower order of d. + void write_n_bit(int d, int n) + { + for (int i = 0; i (pos.kingSquare(chess::Color::White)), 6); + stream.write_n_bit(static_cast(pos.kingSquare(chess::Color::Black)), 6); + + // Write the pieces on the board other than the kings. + for (chess::Rank r = chess::rank8; r >= chess::rank1; --r) + { + for (chess::File f = chess::fileA; f <= chess::fileH; ++f) + { + chess::Piece pc = pos.pieceAt(chess::Square(f, r)); + if (pc.type() == chess::PieceType::King) + continue; + write_board_piece_to_stream(pc); + } + } + + // TODO(someone): Support chess960. + auto cr = pos.castlingRights(); + stream.write_one_bit(contains(cr, chess::CastlingRights::WhiteKingSide)); + stream.write_one_bit(contains(cr, chess::CastlingRights::WhiteQueenSide)); + stream.write_one_bit(contains(cr, chess::CastlingRights::BlackKingSide)); + stream.write_one_bit(contains(cr, chess::CastlingRights::BlackQueenSide)); + + if (pos.epSquare() == chess::Square::none()) { + stream.write_one_bit(0); + } + else { + stream.write_one_bit(1); + stream.write_n_bit(static_cast(pos.epSquare()), 6); + } + + stream.write_n_bit(pos.rule50Counter(), 6); + + stream.write_n_bit(pos.fullMove(), 8); + + // Write high bits of half move. This is a fix for the + // limited range of half move counter. + // This is backwards compatibile. + stream.write_n_bit(pos.fullMove() >> 8, 8); + + // Write the highest bit of rule50 at the end. This is a backwards + // compatibile fix for rule50 having only 6 bits stored. + // This bit is just ignored by the old parsers. + stream.write_n_bit(pos.rule50Counter() >> 6, 1); + + assert(stream.get_cursor() <= 256); + } + + // sfen packed by pack() (256bit = 32bytes) + // Or sfen to decode with unpack() + uint8_t *data; // uint8_t[32]; + + BitStream stream; + + // Output the board pieces to stream. + void write_board_piece_to_stream(chess::Piece pc) + { + // piece type + chess::PieceType pr = pc.type(); + auto c = huffman_table[static_cast(pr)]; + stream.write_n_bit(c.code, c.bits); + + if (pc == chess::Piece::none()) + return; + + // first and second flag + stream.write_one_bit(static_cast(pc.color())); + } + + // Read one board piece from stream + [[nodiscard]] chess::Piece read_board_piece_from_stream() + { + int pr = static_cast(chess::PieceType::None); + int code = 0, bits = 0; + while (true) + { + code |= stream.read_one_bit() << bits; + ++bits; + + assert(bits <= 6); + + for (pr = static_cast(chess::PieceType::Pawn); pr <= static_cast(chess::PieceType::None); ++pr) + if (huffman_table[pr].code == code + && huffman_table[pr].bits == bits) + goto Found; + } + Found:; + if (pr == static_cast(chess::PieceType::None)) + return chess::Piece::none(); + + // first and second flag + chess::Color c = (chess::Color)stream.read_one_bit(); + + return chess::Piece(static_cast(pr), c); + } + }; + + + [[nodiscard]] inline chess::Position pos_from_packed_sfen(const PackedSfen& sfen) + { + SfenPacker packer; + auto& stream = packer.stream; + stream.set_data(const_cast(reinterpret_cast(&sfen))); + + chess::Position pos{}; + + // Active color + pos.setSideToMove((chess::Color)stream.read_one_bit()); + + // First the position of the ball + pos.place(chess::Piece(chess::PieceType::King, chess::Color::White), static_cast(stream.read_n_bit(6))); + pos.place(chess::Piece(chess::PieceType::King, chess::Color::Black), static_cast(stream.read_n_bit(6))); + + // Piece placement + for (chess::Rank r = chess::rank8; r >= chess::rank1; --r) + { + for (chess::File f = chess::fileA; f <= chess::fileH; ++f) + { + auto sq = chess::Square(f, r); + + // it seems there are already balls + chess::Piece pc; + if (pos.pieceAt(sq).type() != chess::PieceType::King) + { + assert(pos.pieceAt(sq) == chess::Piece::none()); + pc = packer.read_board_piece_from_stream(); + } + else + { + pc = pos.pieceAt(sq); + } + + // There may be no pieces, so skip in that case. + if (pc == chess::Piece::none()) + continue; + + if (pc.type() != chess::PieceType::King) + { + pos.place(pc, sq); + } + + assert(stream.get_cursor() <= 256); + } + } + + // Castling availability. + chess::CastlingRights cr = chess::CastlingRights::None; + if (stream.read_one_bit()) { + cr |= chess::CastlingRights::WhiteKingSide; + } + if (stream.read_one_bit()) { + cr |= chess::CastlingRights::WhiteQueenSide; + } + if (stream.read_one_bit()) { + cr |= chess::CastlingRights::BlackKingSide; + } + if (stream.read_one_bit()) { + cr |= chess::CastlingRights::BlackQueenSide; + } + pos.setCastlingRights(cr); + + // En passant square. Ignore if no pawn capture is possible + if (stream.read_one_bit()) { + chess::Square ep_square = static_cast(stream.read_n_bit(6)); + pos.setEpSquare(ep_square); + } + + // Halfmove clock + std::uint8_t rule50 = stream.read_n_bit(6); + + // Fullmove number + std::uint16_t fullmove = stream.read_n_bit(8); + + // Fullmove number, high bits + // This was added as a fix for fullmove clock + // overflowing at 256. This change is backwards compatibile. + fullmove |= stream.read_n_bit(8) << 8; + + // Read the highest bit of rule50. This was added as a fix for rule50 + // counter having only 6 bits stored. + // In older entries this will just be a zero bit. + rule50 |= stream.read_n_bit(1) << 6; + + pos.setFullMove(fullmove); + pos.setRule50Counter(rule50); + + assert(stream.get_cursor() <= 256); + + return pos; + } + } + + struct CompressedTrainingDataFile + { + struct Header + { + std::uint32_t chunkSize; + }; + + CompressedTrainingDataFile(std::string path, std::ios_base::openmode om = std::ios_base::app) : + m_path(std::move(path)), + m_file(m_path, std::ios_base::binary | std::ios_base::in | std::ios_base::out | om) + { + // Necessary for MAC because app mode makes it put the reading + // head at the end. + m_file.seekg(0); + } + + void append(const char* data, std::uint32_t size) + { + writeChunkHeader({size}); + m_file.write(data, size); + } + + [[nodiscard]] bool hasNextChunk() + { + if (!m_file) + { + return false; + } + + m_file.peek(); + return !m_file.eof(); + } + + [[nodiscard]] std::vector readNextChunk() + { + auto size = readChunkHeader().chunkSize; + std::vector data(size); + m_file.read(reinterpret_cast(data.data()), size); + return data; + } + + private: + std::string m_path; + std::fstream m_file; + + void writeChunkHeader(Header h) + { + unsigned char header[8]; + header[0] = 'B'; + header[1] = 'I'; + header[2] = 'N'; + header[3] = 'P'; + header[4] = h.chunkSize; + header[5] = h.chunkSize >> 8; + header[6] = h.chunkSize >> 16; + header[7] = h.chunkSize >> 24; + m_file.write(reinterpret_cast(header), 8); + } + + [[nodiscard]] Header readChunkHeader() + { + unsigned char header[8]; + m_file.read(reinterpret_cast(header), 8); + if (header[0] != 'B' || header[1] != 'I' || header[2] != 'N' || header[3] != 'P') + { + assert(false); + // throw std::runtime_error("Invalid binpack file or chunk."); + } + + const std::uint32_t size = + header[4] + | (header[5] << 8) + | (header[6] << 16) + | (header[7] << 24); + + if (size > maxChunkSize) + { + assert(false); + // throw std::runtime_error("Chunks size larger than supported. Malformed file?"); + } + + return { size }; + } + }; + + [[nodiscard]] inline std::uint16_t signedToUnsigned(std::int16_t a) + { + std::uint16_t r; + std::memcpy(&r, &a, sizeof(std::uint16_t)); + if (r & 0x8000) + { + r ^= 0x7FFF; + } + r = (r << 1) | (r >> 15); + return r; + } + + [[nodiscard]] inline std::int16_t unsignedToSigned(std::uint16_t r) + { + std::int16_t a; + r = (r << 15) | (r >> 1); + if (r & 0x8000) + { + r ^= 0x7FFF; + } + std::memcpy(&a, &r, sizeof(std::uint16_t)); + return a; + } + + struct TrainingDataEntry + { + chess::Position pos; + chess::Move move; + std::int16_t score; + std::uint16_t ply; + std::int16_t result; + + [[nodiscard]] bool isValid() const + { + return pos.isMoveLegal(move); + } + }; + + [[nodiscard]] inline TrainingDataEntry packedSfenValueToTrainingDataEntry(const nodchip::PackedSfenValue& psv) + { + TrainingDataEntry ret; + + ret.pos = nodchip::pos_from_packed_sfen(psv.sfen); + ret.move = psv.move.toMove(); + ret.score = psv.score; + ret.ply = psv.gamePly; + ret.result = psv.game_result; + + return ret; + } + + [[nodiscard]] inline nodchip::PackedSfenValue trainingDataEntryToPackedSfenValue(const TrainingDataEntry& plain) + { + nodchip::PackedSfenValue ret; + + nodchip::SfenPacker sp; + sp.data = reinterpret_cast(&ret.sfen); + sp.pack(plain.pos); + + ret.score = plain.score; + ret.move = nodchip::StockfishMove::fromMove(plain.move); + ret.gamePly = plain.ply; + ret.game_result = plain.result; + ret.padding = 0xff; // for consistency with the .bin format. + + return ret; + } + + [[nodiscard]] inline bool isContinuation(const TrainingDataEntry& lhs, const TrainingDataEntry& rhs) + { + return + lhs.result == -rhs.result + && lhs.ply + 1 == rhs.ply + && lhs.pos.afterMove(lhs.move) == rhs.pos; + } + + struct PackedTrainingDataEntry + { + unsigned char bytes[32]; + }; + + [[nodiscard]] inline std::size_t usedBitsSafe(std::size_t value) + { + if (value == 0) return 0; + return chess::util::usedBits(value - 1); + } + + static constexpr std::size_t scoreVleBlockSize = 4; + + struct PackedMoveScoreListReader + { + TrainingDataEntry entry; + std::uint16_t numPlies; + unsigned char* movetext; + + PackedMoveScoreListReader(const TrainingDataEntry& entry_, unsigned char* movetext_, std::uint16_t numPlies_) : + entry(entry_), + numPlies(numPlies_), + movetext(movetext_), + m_lastScore(-entry_.score) + { + + } + + [[nodiscard]] std::uint8_t extractBitsLE8(std::size_t count) + { + if (count == 0) return 0; + + if (m_readBitsLeft == 0) + { + m_readOffset += 1; + m_readBitsLeft = 8; + } + + const std::uint8_t byte = movetext[m_readOffset] << (8 - m_readBitsLeft); + std::uint8_t bits = byte >> (8 - count); + + if (count > m_readBitsLeft) + { + const auto spillCount = count - m_readBitsLeft; + bits |= movetext[m_readOffset + 1] >> (8 - spillCount); + + m_readBitsLeft += 8; + m_readOffset += 1; + } + + m_readBitsLeft -= count; + + return bits; + } + + [[nodiscard]] std::uint16_t extractVle16(std::size_t blockSize) + { + auto mask = (1 << blockSize) - 1; + std::uint16_t v = 0; + std::size_t offset = 0; + for(;;) + { + std::uint16_t block = extractBitsLE8(blockSize + 1); + v |= ((block & mask) << offset); + if (!(block >> blockSize)) + { + break; + } + + offset += blockSize; + } + return v; + } + + [[nodiscard]] TrainingDataEntry nextEntry() + { + entry.pos.doMove(entry.move); + auto [move, score] = nextMoveScore(entry.pos); + entry.move = move; + entry.score = score; + entry.ply += 1; + entry.result = -entry.result; + return entry; + } + + [[nodiscard]] bool hasNext() const + { + return m_numReadPlies < numPlies; + } + + [[nodiscard]] std::pair nextMoveScore(const chess::Position& pos) + { + chess::Move move; + std::int16_t score; + + const chess::Color sideToMove = pos.sideToMove(); + const chess::Bitboard ourPieces = pos.piecesBB(sideToMove); + const chess::Bitboard theirPieces = pos.piecesBB(!sideToMove); + const chess::Bitboard occupied = ourPieces | theirPieces; + + const auto pieceId = extractBitsLE8(usedBitsSafe(ourPieces.count())); + const auto from = chess::Square(chess::nthSetBitIndex(ourPieces.bits(), pieceId)); + + const auto pt = pos.pieceAt(from).type(); + switch (pt) + { + case chess::PieceType::Pawn: + { + const chess::Rank promotionRank = pos.sideToMove() == chess::Color::White ? chess::rank7 : chess::rank2; + const chess::Rank startRank = pos.sideToMove() == chess::Color::White ? chess::rank2 : chess::rank7; + const auto forward = sideToMove == chess::Color::White ? chess::FlatSquareOffset(0, 1) : chess::FlatSquareOffset(0, -1); + + const chess::Square epSquare = pos.epSquare(); + + chess::Bitboard attackTargets = theirPieces; + if (epSquare != chess::Square::none()) + { + attackTargets |= epSquare; + } + + chess::Bitboard destinations = chess::bb::pawnAttacks(chess::Bitboard::square(from), sideToMove) & attackTargets; + + const chess::Square sqForward = from + forward; + if (!occupied.isSet(sqForward)) + { + destinations |= sqForward; + if ( + from.rank() == startRank + && !occupied.isSet(sqForward + forward) + ) + { + destinations |= sqForward + forward; + } + } + + const auto destinationsCount = destinations.count(); + if (from.rank() == promotionRank) + { + const auto moveId = extractBitsLE8(usedBitsSafe(destinationsCount * 4ull)); + const chess::Piece promotedPiece = chess::Piece( + chess::fromOrdinal(ordinal(chess::PieceType::Knight) + (moveId % 4ull)), + sideToMove + ); + const auto to = chess::Square(chess::nthSetBitIndex(destinations.bits(), moveId / 4ull)); + + move = chess::Move::promotion(from, to, promotedPiece); + break; + } + else + { + auto moveId = extractBitsLE8(usedBitsSafe(destinationsCount)); + const auto to = chess::Square(chess::nthSetBitIndex(destinations.bits(), moveId)); + if (to == epSquare) + { + move = chess::Move::enPassant(from, to); + break; + } + else + { + move = chess::Move::normal(from, to); + break; + } + } + } + case chess::PieceType::King: + { + const chess::CastlingRights ourCastlingRightsMask = + sideToMove == chess::Color::White + ? chess::CastlingRights::White + : chess::CastlingRights::Black; + + const chess::CastlingRights castlingRights = pos.castlingRights(); + + const chess::Bitboard attacks = chess::bb::pseudoAttacks(from) & ~ourPieces; + const std::size_t attacksSize = attacks.count(); + const std::size_t numCastlings = chess::intrin::popcount(ordinal(castlingRights & ourCastlingRightsMask)); + + const auto moveId = extractBitsLE8(usedBitsSafe(attacksSize + numCastlings)); + + if (moveId >= attacksSize) + { + const std::size_t idx = moveId - attacksSize; + + const chess::CastleType castleType = + idx == 0 + && chess::contains(castlingRights, chess::CastlingTraits::castlingRights[sideToMove][chess::CastleType::Long]) + ? chess::CastleType::Long + : chess::CastleType::Short; + + move = chess::Move::castle(castleType, sideToMove); + break; + } + else + { + auto to = chess::Square(chess::nthSetBitIndex(attacks.bits(), moveId)); + move = chess::Move::normal(from, to); + break; + } + break; + } + default: + { + const chess::Bitboard attacks = chess::bb::attacks(pt, from, occupied) & ~ourPieces; + const auto moveId = extractBitsLE8(usedBitsSafe(attacks.count())); + auto to = chess::Square(chess::nthSetBitIndex(attacks.bits(), moveId)); + move = chess::Move::normal(from, to); + break; + } + } + + score = m_lastScore + unsignedToSigned(extractVle16(scoreVleBlockSize)); + m_lastScore = -score; + + ++m_numReadPlies; + + return {move, score}; + } + + [[nodiscard]] std::size_t numReadBytes() + { + return m_readOffset + (m_readBitsLeft != 8); + } + + private: + std::size_t m_readBitsLeft = 8; + std::size_t m_readOffset = 0; + std::int16_t m_lastScore = 0; + std::uint16_t m_numReadPlies = 0; + }; + + struct PackedMoveScoreList + { + std::uint16_t numPlies = 0; + std::vector movetext; + + void clear(const TrainingDataEntry& e) + { + numPlies = 0; + movetext.clear(); + m_bitsLeft = 0; + m_lastScore = -e.score; + } + + void addBitsLE8(std::uint8_t bits, std::size_t count) + { + if (count == 0) return; + + if (m_bitsLeft == 0) + { + movetext.emplace_back(bits << (8 - count)); + m_bitsLeft = 8; + } + else if (count <= m_bitsLeft) + { + movetext.back() |= bits << (m_bitsLeft - count); + } + else + { + const auto spillCount = count - m_bitsLeft; + movetext.back() |= bits >> spillCount; + movetext.emplace_back(bits << (8 - spillCount)); + m_bitsLeft += 8; + } + + m_bitsLeft -= count; + } + + void addBitsVle16(std::uint16_t v, std::size_t blockSize) + { + auto mask = (1 << blockSize) - 1; + for(;;) + { + std::uint8_t block = (v & mask) | ((v > mask) << blockSize); + addBitsLE8(block, blockSize + 1); + v >>= blockSize; + if (v == 0) break; + } + } + + + void addMoveScore(const chess::Position& pos, chess::Move move, std::int16_t score) + { + const chess::Color sideToMove = pos.sideToMove(); + const chess::Bitboard ourPieces = pos.piecesBB(sideToMove); + const chess::Bitboard theirPieces = pos.piecesBB(!sideToMove); + const chess::Bitboard occupied = ourPieces | theirPieces; + + const std::uint8_t pieceId = (pos.piecesBB(sideToMove) & chess::bb::before(move.from)).count(); + std::size_t numMoves = 0; + int moveId = 0; + const auto pt = pos.pieceAt(move.from).type(); + switch (pt) + { + case chess::PieceType::Pawn: + { + const chess::Rank secondToLastRank = pos.sideToMove() == chess::Color::White ? chess::rank7 : chess::rank2; + const chess::Rank startRank = pos.sideToMove() == chess::Color::White ? chess::rank2 : chess::rank7; + const auto forward = sideToMove == chess::Color::White ? chess::FlatSquareOffset(0, 1) : chess::FlatSquareOffset(0, -1); + + const chess::Square epSquare = pos.epSquare(); + + chess::Bitboard attackTargets = theirPieces; + if (epSquare != chess::Square::none()) + { + attackTargets |= epSquare; + } + + chess::Bitboard destinations = chess::bb::pawnAttacks(chess::Bitboard::square(move.from), sideToMove) & attackTargets; + + const chess::Square sqForward = move.from + forward; + if (!occupied.isSet(sqForward)) + { + destinations |= sqForward; + + if ( + move.from.rank() == startRank + && !occupied.isSet(sqForward + forward) + ) + { + destinations |= sqForward + forward; + } + } + + moveId = (destinations & chess::bb::before(move.to)).count(); + numMoves = destinations.count(); + if (move.from.rank() == secondToLastRank) + { + const auto promotionIndex = (ordinal(move.promotedPiece.type()) - ordinal(chess::PieceType::Knight)); + moveId = moveId * 4 + promotionIndex; + numMoves *= 4; + } + + break; + } + case chess::PieceType::King: + { + const chess::CastlingRights ourCastlingRightsMask = + sideToMove == chess::Color::White + ? chess::CastlingRights::White + : chess::CastlingRights::Black; + + const chess::CastlingRights castlingRights = pos.castlingRights(); + + const chess::Bitboard attacks = chess::bb::pseudoAttacks(move.from) & ~ourPieces; + const auto attacksSize = attacks.count(); + const auto numCastlingRights = chess::intrin::popcount(ordinal(castlingRights & ourCastlingRightsMask)); + + numMoves += attacksSize; + numMoves += numCastlingRights; + + if (move.type == chess::MoveType::Castle) + { + const auto longCastlingRights = chess::CastlingTraits::castlingRights[sideToMove][chess::CastleType::Long]; + + moveId = attacksSize - 1; + + if (chess::contains(castlingRights, longCastlingRights)) + { + // We have to add one no matter if it's the used one or not. + moveId += 1; + } + + if (chess::CastlingTraits::moveCastlingType(move) == chess::CastleType::Short) + { + moveId += 1; + } + } + else + { + moveId = (attacks & chess::bb::before(move.to)).count(); + } + break; + } + default: + { + const chess::Bitboard attacks = chess::bb::attacks(pt, move.from, occupied) & ~ourPieces; + + moveId = (attacks & chess::bb::before(move.to)).count(); + numMoves = attacks.count(); + } + } + + const std::size_t numPieces = ourPieces.count(); + addBitsLE8(pieceId, usedBitsSafe(numPieces)); + addBitsLE8(moveId, usedBitsSafe(numMoves)); + + std::uint16_t scoreDelta = signedToUnsigned(score - m_lastScore); + addBitsVle16(scoreDelta, scoreVleBlockSize); + m_lastScore = -score; + + ++numPlies; + } + + private: + std::size_t m_bitsLeft = 0; + std::int16_t m_lastScore = 0; + }; + + + [[nodiscard]] inline PackedTrainingDataEntry packEntry(const TrainingDataEntry& plain) + { + PackedTrainingDataEntry packed; + + auto compressedPos = plain.pos.compress(); + auto compressedMove = plain.move.compress(); + + static_assert(sizeof(compressedPos) + sizeof(compressedMove) + 6 == sizeof(PackedTrainingDataEntry)); + + std::size_t offset = 0; + compressedPos.writeToBigEndian(packed.bytes); + offset += sizeof(compressedPos); + compressedMove.writeToBigEndian(packed.bytes + offset); + offset += sizeof(compressedMove); + std::uint16_t pr = plain.ply | (signedToUnsigned(plain.result) << 14); + packed.bytes[offset++] = signedToUnsigned(plain.score) >> 8; + packed.bytes[offset++] = signedToUnsigned(plain.score); + packed.bytes[offset++] = pr >> 8; + packed.bytes[offset++] = pr; + packed.bytes[offset++] = plain.pos.rule50Counter() >> 8; + packed.bytes[offset++] = plain.pos.rule50Counter(); + + return packed; + } + + [[nodiscard]] inline TrainingDataEntry unpackEntry(const PackedTrainingDataEntry& packed) + { + TrainingDataEntry plain; + + std::size_t offset = 0; + auto compressedPos = chess::CompressedPosition::readFromBigEndian(packed.bytes); + plain.pos = compressedPos.decompress(); + offset += sizeof(compressedPos); + auto compressedMove = chess::CompressedMove::readFromBigEndian(packed.bytes + offset); + plain.move = compressedMove.decompress(); + offset += sizeof(compressedMove); + plain.score = unsignedToSigned((packed.bytes[offset] << 8) | packed.bytes[offset+1]); + offset += 2; + std::uint16_t pr = (packed.bytes[offset] << 8) | packed.bytes[offset+1]; + plain.ply = pr & 0x3FFF; + plain.pos.setPly(plain.ply); + plain.result = unsignedToSigned(pr >> 14); + offset += 2; + plain.pos.setRule50Counter((packed.bytes[offset] << 8) | packed.bytes[offset+1]); + + return plain; + } + + struct CompressedTrainingDataEntryWriter + { + static constexpr std::size_t chunkSize = suggestedChunkSize; + + CompressedTrainingDataEntryWriter(std::string path, std::ios_base::openmode om = std::ios_base::app) : + m_outputFile(path, om), + m_lastEntry{}, + m_movelist{}, + m_packedSize(0), + m_packedEntries(chunkSize + maxMovelistSize), + m_isFirst(true) + { + m_lastEntry.ply = 0xFFFF; // so it's never a continuation + m_lastEntry.result = 0x7FFF; + } + + void addTrainingDataEntry(const TrainingDataEntry& e) + { + bool isCont = isContinuation(m_lastEntry, e); + if (isCont) + { + // add to movelist + m_movelist.addMoveScore(e.pos, e.move, e.score); + } + else + { + if (!m_isFirst) + { + writeMovelist(); + } + + if (m_packedSize >= chunkSize) + { + m_outputFile.append(m_packedEntries.data(), m_packedSize); + m_packedSize = 0; + } + + auto packed = packEntry(e); + std::memcpy(m_packedEntries.data() + m_packedSize, &packed, sizeof(PackedTrainingDataEntry)); + m_packedSize += sizeof(PackedTrainingDataEntry); + + m_movelist.clear(e); + + m_isFirst = false; + } + + m_lastEntry = e; + } + + ~CompressedTrainingDataEntryWriter() + { + if (m_packedSize > 0) + { + if (!m_isFirst) + { + writeMovelist(); + } + + m_outputFile.append(m_packedEntries.data(), m_packedSize); + m_packedSize = 0; + } + } + + private: + CompressedTrainingDataFile m_outputFile; + TrainingDataEntry m_lastEntry; + PackedMoveScoreList m_movelist; + std::size_t m_packedSize; + std::vector m_packedEntries; + bool m_isFirst; + + void writeMovelist() + { + m_packedEntries[m_packedSize++] = m_movelist.numPlies >> 8; + m_packedEntries[m_packedSize++] = m_movelist.numPlies; + if (m_movelist.numPlies > 0) + { + std::memcpy(m_packedEntries.data() + m_packedSize, m_movelist.movetext.data(), m_movelist.movetext.size()); + m_packedSize += m_movelist.movetext.size(); + } + }; + }; + + struct CompressedTrainingDataEntryReader + { + static constexpr std::size_t chunkSize = suggestedChunkSize; + + CompressedTrainingDataEntryReader(std::string path, std::ios_base::openmode om = std::ios_base::app) : + m_inputFile(path, om), + m_chunk(), + m_movelistReader(std::nullopt), + m_offset(0), + m_isEnd(false) + { + if (!m_inputFile.hasNextChunk()) + { + m_isEnd = true; + } + else + { + m_chunk = m_inputFile.readNextChunk(); + } + } + + [[nodiscard]] bool hasNext() + { + return !m_isEnd; + } + + [[nodiscard]] TrainingDataEntry next() + { + if (m_movelistReader.has_value()) + { + const auto e = m_movelistReader->nextEntry(); + + if (!m_movelistReader->hasNext()) + { + m_offset += m_movelistReader->numReadBytes(); + m_movelistReader.reset(); + + fetchNextChunkIfNeeded(); + } + + return e; + } + + PackedTrainingDataEntry packed; + std::memcpy(&packed, m_chunk.data() + m_offset, sizeof(PackedTrainingDataEntry)); + m_offset += sizeof(PackedTrainingDataEntry); + + const std::uint16_t numPlies = (m_chunk[m_offset] << 8) | m_chunk[m_offset + 1]; + m_offset += 2; + + const auto e = unpackEntry(packed); + + if (numPlies > 0) + { + m_movelistReader.emplace(e, reinterpret_cast(m_chunk.data()) + m_offset, numPlies); + } + else + { + fetchNextChunkIfNeeded(); + } + + return e; + } + + private: + CompressedTrainingDataFile m_inputFile; + std::vector m_chunk; + std::optional m_movelistReader; + std::size_t m_offset; + bool m_isEnd; + + void fetchNextChunkIfNeeded() + { + if (m_offset + sizeof(PackedTrainingDataEntry) + 2 > m_chunk.size()) + { + if (m_inputFile.hasNextChunk()) + { + m_chunk = m_inputFile.readNextChunk(); + m_offset = 0; + } + else + { + m_isEnd = true; + } + } + } + }; + + inline void emitPlainEntry(std::string& buffer, const TrainingDataEntry& plain) + { + buffer += "fen "; + buffer += plain.pos.fen(); + buffer += '\n'; + + buffer += "move "; + buffer += chess::uci::moveToUci(plain.pos, plain.move); + buffer += '\n'; + + buffer += "score "; + buffer += std::to_string(plain.score); + buffer += '\n'; + + buffer += "ply "; + buffer += std::to_string(plain.ply); + buffer += '\n'; + + buffer += "result "; + buffer += std::to_string(plain.result); + buffer += "\ne\n"; + } + + inline void emitBinEntry(std::vector& buffer, const TrainingDataEntry& plain) + { + auto psv = trainingDataEntryToPackedSfenValue(plain); + const char* data = reinterpret_cast(&psv); + buffer.insert(buffer.end(), data, data+sizeof(psv)); + } + + inline void convertPlainToBinpack(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + { + constexpr std::size_t reportEveryNPositions = 100'000; + + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + + CompressedTrainingDataEntryWriter writer(outputPath, om); + TrainingDataEntry e; + + std::string key; + std::string value; + std::string move; + + std::ifstream inputFile(inputPath); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + + for(;;) + { + inputFile >> key; + if (!inputFile) + { + break; + } + + if (key == "e"sv) + { + e.move = chess::uci::uciToMove(e.pos, move); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + writer.addTrainingDataEntry(e); + + ++numProcessedPositions; + const auto cur = inputFile.tellg(); + if (numProcessedPositions % reportEveryNPositions == 0) + { + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + continue; + } + + inputFile >> std::ws; + std::getline(inputFile, value, '\n'); + + if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str()); + if (key == "move"sv) move = value; + if (key == "score"sv) e.score = std::stoi(value); + if (key == "ply"sv) e.ply = std::stoi(value); + if (key == "result"sv) e.result = std::stoi(value); + } + + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; + } + + inline void convertBinpackToPlain(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + { + constexpr std::size_t bufferSize = MiB; + + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + + CompressedTrainingDataEntryReader reader(inputPath); + std::ofstream outputFile(outputPath, om); + const auto base = outputFile.tellp(); + std::size_t numProcessedPositions = 0; + std::string buffer; + buffer.reserve(bufferSize * 2); + + while(reader.hasNext()) + { + auto e = reader.next(); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + emitPlainEntry(buffer, e); + + ++numProcessedPositions; + + if (buffer.size() > bufferSize) + { + outputFile << buffer; + buffer.clear(); + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + } + + if (!buffer.empty()) + { + outputFile << buffer; + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; + } + + + inline void convertBinToBinpack(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + { + constexpr std::size_t reportEveryNPositions = 100'000; + + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + + CompressedTrainingDataEntryWriter writer(outputPath, om); + + std::ifstream inputFile(inputPath, std::ios_base::binary); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + + nodchip::PackedSfenValue psv; + for(;;) + { + inputFile.read(reinterpret_cast(&psv), sizeof(psv)); + if (inputFile.gcount() != 40) + { + break; + } + + auto e = packedSfenValueToTrainingDataEntry(psv); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + std::cerr << static_cast(e.move.type) << '\n'; + return; + } + + writer.addTrainingDataEntry(e); + + ++numProcessedPositions; + const auto cur = inputFile.tellg(); + if (numProcessedPositions % reportEveryNPositions == 0) + { + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + } + + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; + } + + inline void convertBinpackToBin(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + { + constexpr std::size_t bufferSize = MiB; + + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + + CompressedTrainingDataEntryReader reader(inputPath); + std::ofstream outputFile(outputPath, std::ios_base::binary | om); + const auto base = outputFile.tellp(); + std::size_t numProcessedPositions = 0; + std::vector buffer; + buffer.reserve(bufferSize * 2); + + while(reader.hasNext()) + { + auto e = reader.next(); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + emitBinEntry(buffer, e); + + ++numProcessedPositions; + + if (buffer.size() > bufferSize) + { + outputFile.write(buffer.data(), buffer.size()); + buffer.clear(); + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + } + + if (!buffer.empty()) + { + outputFile.write(buffer.data(), buffer.size()); + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; + } + + inline void convertBinToPlain(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + { + constexpr std::size_t bufferSize = MiB; + + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + + std::ifstream inputFile(inputPath, std::ios_base::binary); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + + std::ofstream outputFile(outputPath, om); + std::string buffer; + buffer.reserve(bufferSize * 2); + + nodchip::PackedSfenValue psv; + for(;;) + { + inputFile.read(reinterpret_cast(&psv), sizeof(psv)); + if (inputFile.gcount() != 40) + { + break; + } + + auto e = packedSfenValueToTrainingDataEntry(psv); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + emitPlainEntry(buffer, e); + + ++numProcessedPositions; + + if (buffer.size() > bufferSize) + { + outputFile << buffer; + buffer.clear(); + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + } + + if (!buffer.empty()) + { + outputFile << buffer; + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; + } + + inline void convertPlainToBin(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate) + { + constexpr std::size_t bufferSize = MiB; + + std::cout << "Converting " << inputPath << " to " << outputPath << '\n'; + + std::ofstream outputFile(outputPath, std::ios_base::binary | om); + std::vector buffer; + buffer.reserve(bufferSize * 2); + + TrainingDataEntry e; + + std::string key; + std::string value; + std::string move; + + std::ifstream inputFile(inputPath); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + + for(;;) + { + inputFile >> key; + if (!inputFile) + { + break; + } + + if (key == "e"sv) + { + e.move = chess::uci::uciToMove(e.pos, move); + if (validate && !e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + emitBinEntry(buffer, e); + + ++numProcessedPositions; + + if (buffer.size() > bufferSize) + { + outputFile.write(buffer.data(), buffer.size()); + buffer.clear(); + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + continue; + } + + inputFile >> std::ws; + std::getline(inputFile, value, '\n'); + + if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str()); + if (key == "move"sv) move = value; + if (key == "score"sv) e.score = std::stoi(value); + if (key == "ply"sv) e.ply = std::stoi(value); + if (key == "result"sv) e.result = std::stoi(value); + } + + if (!buffer.empty()) + { + outputFile.write(buffer.data(), buffer.size()); + + const auto cur = outputFile.tellp(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Converted " << numProcessedPositions << " positions.\n"; + } + + inline void validatePlain(std::string inputPath) + { + constexpr std::size_t reportSize = 1000000; + + std::cout << "Validating " << inputPath << '\n'; + + TrainingDataEntry e; + + std::string key; + std::string value; + std::string move; + + std::ifstream inputFile(inputPath); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + std::size_t numProcessedPositionsBatch = 0; + + for(;;) + { + inputFile >> key; + if (!inputFile) + { + break; + } + + if (key == "e"sv) + { + e.move = chess::uci::uciToMove(e.pos, move); + if (!e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + ++numProcessedPositions; + ++numProcessedPositionsBatch; + + if (numProcessedPositionsBatch >= reportSize) + { + numProcessedPositionsBatch -= reportSize; + const auto cur = inputFile.tellg(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + continue; + } + + inputFile >> std::ws; + std::getline(inputFile, value, '\n'); + + if (key == "fen"sv) e.pos = chess::Position::fromFen(value.c_str()); + if (key == "move"sv) move = value; + if (key == "score"sv) e.score = std::stoi(value); + if (key == "ply"sv) e.ply = std::stoi(value); + if (key == "result"sv) e.result = std::stoi(value); + } + + if (numProcessedPositionsBatch) + { + const auto cur = inputFile.tellg(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n"; + } + + inline void validateBin(std::string inputPath) + { + constexpr std::size_t reportSize = 1000000; + + std::cout << "Validating " << inputPath << '\n'; + + std::ifstream inputFile(inputPath, std::ios_base::binary); + const auto base = inputFile.tellg(); + std::size_t numProcessedPositions = 0; + std::size_t numProcessedPositionsBatch = 0; + + nodchip::PackedSfenValue psv; + for(;;) + { + inputFile.read(reinterpret_cast(&psv), sizeof(psv)); + if (inputFile.gcount() != 40) + { + break; + } + + auto e = packedSfenValueToTrainingDataEntry(psv); + if (!e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + ++numProcessedPositions; + ++numProcessedPositionsBatch; + + if (numProcessedPositionsBatch >= reportSize) + { + numProcessedPositionsBatch -= reportSize; + const auto cur = inputFile.tellg(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + } + + if (numProcessedPositionsBatch) + { + const auto cur = inputFile.tellg(); + std::cout << "Processed " << (cur - base) << " bytes and " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n"; + } + + inline void validateBinpack(std::string inputPath) + { + constexpr std::size_t reportSize = 1000000; + + std::cout << "Validating " << inputPath << '\n'; + + CompressedTrainingDataEntryReader reader(inputPath); + std::size_t numProcessedPositions = 0; + std::size_t numProcessedPositionsBatch = 0; + + while(reader.hasNext()) + { + auto e = reader.next(); + if (!e.isValid()) + { + std::cerr << "Illegal move " << chess::uci::moveToUci(e.pos, e.move) << " for position " << e.pos.fen() << '\n'; + return; + } + + ++numProcessedPositions; + ++numProcessedPositionsBatch; + + if (numProcessedPositionsBatch >= reportSize) + { + numProcessedPositionsBatch -= reportSize; + std::cout << "Processed " << numProcessedPositions << " positions.\n"; + } + } + + if (numProcessedPositionsBatch) + { + std::cout << "Processed " << numProcessedPositions << " positions.\n"; + } + + std::cout << "Finished. Validated " << numProcessedPositions << " positions.\n"; + } +} diff --git a/src/main.cpp b/src/main.cpp index 62e0ed52..ff6564a6 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -18,6 +18,8 @@ #include +#include "nnue/evaluate_nnue.h" + #include "bitboard.h" #include "endgame.h" #include "position.h" diff --git a/src/misc.cpp b/src/misc.cpp index 3b071ccb..d17e5c7e 100644 --- a/src/misc.cpp +++ b/src/misc.cpp @@ -63,6 +63,8 @@ using namespace std; namespace Stockfish { +SynchronizedRegionLogger sync_region_cout(std::cout); + namespace { /// Version number. If Version is left empty, then compile date in the format @@ -649,4 +651,30 @@ void init(int argc, char* argv[]) { } // namespace CommandLine +// Returns a string that represents the current time. (Used when learning evaluation functions) +std::string now_string() +{ + // Using std::ctime(), localtime() gives a warning that MSVC is not secure. + // This shouldn't happen in the C++ standard, but... + +#if defined(_MSC_VER) + // C4996 : 'ctime' : This function or variable may be unsafe.Consider using ctime_s instead. +#pragma warning(disable : 4996) +#endif + + auto now = std::chrono::system_clock::now(); + auto tp = std::chrono::system_clock::to_time_t(now); + auto result = string(std::ctime(&tp)); + + // remove line endings if they are included at the end + while (*result.rbegin() == '\n' || (*result.rbegin() == '\r')) + result.pop_back(); + return result; +} + +void sleep(int ms) +{ + std::this_thread::sleep_for(std::chrono::milliseconds(ms)); +} + } // namespace Stockfish diff --git a/src/misc.h b/src/misc.h index dae37cd9..8b332efb 100644 --- a/src/misc.h +++ b/src/misc.h @@ -19,12 +19,20 @@ #ifndef MISC_H_INCLUDED #define MISC_H_INCLUDED +#include #include #include +#include +#include #include #include #include + #include +#include +#include +#include +#include #include "types.h" @@ -128,6 +136,221 @@ private: std::size_t size_ = 0; }; +// This logger allows printing many parts in a region atomically +// but doesn't block the threads trying to append to other regions. +// Instead if some region tries to pring while other region holds +// the lock the messages are queued to be printed as soon as the +// current region releases the lock. +struct SynchronizedRegionLogger +{ + using RegionId = std::uint64_t; + + struct Region + { + friend struct SynchronizedRegionLogger; + + Region() : + logger(nullptr), region_id(0), is_held(false) + { + } + + Region(const Region&) = delete; + Region& operator=(const Region&) = delete; + + Region(Region&& other) : + logger(other.logger), region_id(other.region_id), is_held(other.is_held) + { + other.logger = nullptr; + other.is_held = false; + } + + Region& operator=(Region&& other) { + if (is_held && logger != nullptr) + { + logger->release_region(region_id); + } + + logger = other.logger; + region_id = other.region_id; + is_held = other.is_held; + + other.is_held = false; + + return *this; + } + + ~Region() { unlock(); } + + void unlock() { + if (is_held) { + is_held = false; + + if (logger != nullptr) + logger->release_region(region_id); + } + } + + Region& operator << (std::ostream&(*pManip)(std::ostream&)) { + if (logger != nullptr) + logger->write(region_id, pManip); + + return *this; + } + + template + Region& operator << (const T& value) { + if (logger != nullptr) + logger->write(region_id, value); + + return *this; + } + + private: + SynchronizedRegionLogger* logger; + RegionId region_id; + bool is_held; + + Region(SynchronizedRegionLogger& log, RegionId id) : + logger(&log), region_id(id), is_held(true) + { + } + }; + +private: + struct RegionBookkeeping + { + RegionBookkeeping(RegionId rid) : id(rid), is_held(true) {} + + std::vector pending_parts; + RegionId id; + bool is_held; + }; + + RegionId init_next_region() + { + static RegionId next_id = 0; + + std::lock_guard lock(mutex); + + const auto id = next_id++; + regions.emplace_back(id); + + return id; + } + + void write(RegionId id, std::ostream&(*pManip)(std::ostream&)) { + std::lock_guard lock(mutex); + + if (regions.empty()) + return; + + if (id == regions.front().id) { + // We can just directly print to the output because + // we are at the front of the region queue. + out << *pManip; + } else { + // We have to schedule the print until previous regions are + // processed + auto* region = find_region_nolock(id); + if (region == nullptr) + return; + + std::stringstream ss; + ss << *pManip; + region->pending_parts.emplace_back(std::move(ss).str()); + } + } + + template + void write(RegionId id, const T& value) { + std::lock_guard lock(mutex); + + if (regions.empty()) + return; + + if (id == regions.front().id) { + // We can just directly print to the output because + // we are at the front of the region queue. + out << value; + } else { + // We have to schedule the print until previous regions are + // processed + auto* region = find_region_nolock(id); + if (region == nullptr) + return; + + std::stringstream ss; + ss << value; + region->pending_parts.emplace_back(std::move(ss).str()); + } + } + + std::ostream& out; + + std::deque regions; + + std::mutex mutex; + + RegionBookkeeping* find_region_nolock(RegionId id) { + // Linear search because the amount of concurrent regions should be small. + auto it = std::find_if( + regions.begin(), + regions.end(), + [id](const RegionBookkeeping& r) { return r.id == id; }); + + if (it == regions.end()) + return nullptr; + else + return &*it; + } + + void release_region(RegionId id) { + std::lock_guard lock(mutex); + + auto* region = find_region_nolock(id); + if (region == nullptr) + return; + + region->is_held = false; + + process_backlog_nolock(); + } + + void process_backlog_nolock() + { + while(!regions.empty()) { + auto& region = regions.front(); + + for(auto& part : region.pending_parts) { + out << part; + } + + // If the region is still held then we don't + // want to start printing stuff from the next region. + if (region.is_held) + break; + + regions.pop_front(); + } + } + +public: + + SynchronizedRegionLogger(std::ostream& s) : + out(s) + { + } + + [[nodiscard]] Region new_region() { + const auto id = init_next_region(); + return Region(*this, id); + } + +}; + +extern SynchronizedRegionLogger sync_region_cout; + + /// xorshift64star Pseudo-Random Number Generator /// This class is based on original code written and dedicated /// to the public domain by Sebastiano Vigna (2014). @@ -143,6 +366,19 @@ private: /// For further analysis see /// +static uint64_t string_hash(const std::string& str) +{ + uint64_t h = 525201411107845655ull; + + for (auto c : str) { + h ^= static_cast(c); + h *= 0x5bd1e9955bd1e995ull; + h ^= h >> 47; + } + + return h; +} + class PRNG { uint64_t s; @@ -154,7 +390,9 @@ class PRNG { } public: + PRNG() { set_seed_from_time(); } PRNG(uint64_t seed) : s(seed) { assert(seed); } + PRNG(const std::string& seed) { set_seed(seed); } template T rand() { return T(rand64()); } @@ -162,8 +400,54 @@ public: /// Output values only have 1/8th of their bits set on average. template T sparse_rand() { return T(rand64() & rand64() & rand64()); } + // Returns a random number from 0 to n-1. (Not uniform distribution, but this is enough in reality) + uint64_t rand(uint64_t n) { return rand() % n; } + + // Return the random seed used internally. + uint64_t get_seed() const { return s; } + + void set_seed(uint64_t seed) { s = seed; } + + uint64_t next_random_seed() + { + uint64_t seed = 0; + for(int i = 0; i < 64; ++i) + { + const auto off = rand64() % 64; + seed |= (rand64() & (uint64_t(1) << off)) >> off; + seed <<= 1; + } + return seed; + } + + void set_seed_from_time() + { + set_seed(std::chrono::system_clock::now().time_since_epoch().count()); + } + + void set_seed(const std::string& str) + { + if (str.empty()) + { + set_seed_from_time(); + } + else if (std::all_of(str.begin(), str.end(), [](char c) { return std::isdigit(c);} )) { + set_seed(std::stoull(str)); + } + else + { + set_seed(string_hash(str)); + } + } }; +// Display a random seed. (For debugging) +inline std::ostream& operator<<(std::ostream& os, PRNG& prng) +{ + os << "PRNG::seed = " << std::hex << prng.get_seed() << std::dec; + return os; +} + inline uint64_t mul_hi64(uint64_t a, uint64_t b) { #if defined(__GNUC__) && defined(IS_64BIT) __extension__ typedef unsigned __int128 uint128; @@ -178,6 +462,74 @@ inline uint64_t mul_hi64(uint64_t a, uint64_t b) { #endif } +// This bitset can be accessed concurrently, provided +// the concurrent accesses are performed on distinct +// instances of underlying type. That means the cuncurrent +// accesses need to be spaced by at least +// bits_per_bucket bits. +// But at least best_concurrent_access_stride bits +// is recommended to prevent false sharing. +template +struct LargeBitset +{ +private: + constexpr static uint64_t cache_line_size = 64; + +public: + using UnderlyingType = uint64_t; + + constexpr static uint64_t num_bits = N; + constexpr static uint64_t bits_per_bucket = 8 * sizeof(uint64_t); + constexpr static uint64_t num_buckets = (num_bits + bits_per_bucket - 1) / bits_per_bucket; + constexpr static uint64_t best_concurrent_access_stride = 8 * cache_line_size; + + LargeBitset() + { + std::fill(std::begin(bits), std::end(bits), 0); + } + + void set(uint64_t idx) + { + const uint64_t bucket = idx / bits_per_bucket; + const uint64_t bit = uint64_t(1) << (idx % bits_per_bucket); + bits[bucket] |= bit; + } + + bool test(uint64_t idx) const + { + const uint64_t bucket = idx / bits_per_bucket; + const uint64_t bit = uint64_t(1) << (idx % bits_per_bucket); + return bits[bucket] & bit; + } + + uint64_t count() const + { + uint64_t c = 0; + uint64_t i = 0; + + for (; i < num_buckets - 3; i += 4) + { + uint64_t c0 = popcount(bits[i+0]); + uint64_t c1 = popcount(bits[i+1]); + uint64_t c2 = popcount(bits[i+2]); + uint64_t c3 = popcount(bits[i+3]); + c0 += c1; + c2 += c3; + c += c0 + c2; + } + + for (; i < num_buckets; ++i) + { + c += popcount(bits[i]); + } + + return c; + } + +private: + alignas(cache_line_size) UnderlyingType bits[num_buckets]; +}; + /// Under Windows it is not possible for a process to run on more than one /// logical processor group. This usually means to be limited to use max 64 /// cores. To overcome this, some special platform specific API should be @@ -188,6 +540,111 @@ namespace WinProcGroup { void bindThisThread(size_t idx); } +// Returns a string that represents the current time. (Used for log output when learning evaluation function) +std::string now_string(); +void sleep(int ms); + +namespace Algo { + // Fisher-Yates + template + void shuffle(std::vector& buf, Rng&& prng) + { + const auto size = buf.size(); + for (uint64_t i = 0; i < size; ++i) + std::swap(buf[i], buf[prng.rand(size - i) + i]); + } + + // split the string + inline std::vector split(const std::string& input, char delimiter) { + std::istringstream stream(input); + std::string field; + std::vector fields; + + while (std::getline(stream, field, delimiter)) { + fields.push_back(field); + } + + return fields; + } +} + +// -------------------- +// Path +// -------------------- + +// Something like Path class in C#. File name manipulation. +// Match with the C# method name. +struct Path +{ + // Combine the path name and file name and return it. + // If the folder name is not an empty string, append it if there is no'/' or'\\' at the end. + static std::string combine(const std::string& folder, const std::string& filename) + { + if (folder.length() >= 1 && *folder.rbegin() != '/' && *folder.rbegin() != '\\') + return folder + "/" + filename; + + return folder + filename; + } + + // Get the file name part (excluding the folder name) from the full path expression. + static std::string get_file_name(const std::string& path) + { + // I don't know which "\" or "/" is used. + auto path_index1 = path.find_last_of("\\") + 1; + auto path_index2 = path.find_last_of("/") + 1; + auto path_index = std::max(path_index1, path_index2); + + return path.substr(path_index); + } +}; + +// It is ignored when new even though alignas is specified & because it is ignored when the STL container allocates memory, +// A custom allocator used for that. +template +class AlignedAllocator { +public: + using value_type = T; + + AlignedAllocator() {} + AlignedAllocator(const AlignedAllocator&) {} + AlignedAllocator(AlignedAllocator&&) {} + + template AlignedAllocator(const AlignedAllocator&) {} + + T* allocate(std::size_t n) { return (T*)std_aligned_alloc(alignof(T), n * sizeof(T)); } + void deallocate(T* p, std::size_t ) { std_aligned_free(p); } +}; + +template +class CacheLineAlignedAllocator { +public: + using value_type = T; + + constexpr static uint64_t cache_line_size = 64; + + CacheLineAlignedAllocator() {} + CacheLineAlignedAllocator(const CacheLineAlignedAllocator&) {} + CacheLineAlignedAllocator(CacheLineAlignedAllocator&&) {} + + template CacheLineAlignedAllocator(const CacheLineAlignedAllocator&) {} + + T* allocate(std::size_t n) { return (T*)std_aligned_alloc(cache_line_size, n * sizeof(T)); } + void deallocate(T* p, std::size_t) { std_aligned_free(p); } +}; + +// -------------------- +// Dependency Wrapper +// -------------------- + +namespace Dependency +{ + // In the Linux environment, if you getline() the text file is'\r\n' + // Since'\r' remains at the end, write a wrapper to remove this'\r'. + // So when calling getline() on fstream, + // just write getline() instead of std::getline() and use this function. + extern bool getline(std::ifstream& fs, std::string& s); +} + namespace CommandLine { void init(int argc, char* argv[]); diff --git a/src/movegen.h b/src/movegen.h index 3f895f05..704cd3e2 100644 --- a/src/movegen.h +++ b/src/movegen.h @@ -68,6 +68,9 @@ struct MoveList { return std::find(begin(), end(), move) != end(); } + // returns the i th element + const ExtMove at(size_t i) const { assert(0 <= i && i < size()); return begin()[i]; } + private: ExtMove moveList[MAX_MOVES], *last; }; diff --git a/src/nnue/layers/input_slice.h b/src/nnue/layers/input_slice.h index b6bf1727..2ecaf0ee 100644 --- a/src/nnue/layers/input_slice.h +++ b/src/nnue/layers/input_slice.h @@ -57,7 +57,6 @@ class InputSlice { bool write_parameters(std::ostream& /*stream*/) const { return true; } - // Forward propagation const OutputType* propagate( const TransformedFeatureType* transformedFeatures, diff --git a/src/nnue/nnue_common.h b/src/nnue/nnue_common.h index 75ac7862..efc33fb8 100644 --- a/src/nnue/nnue_common.h +++ b/src/nnue/nnue_common.h @@ -21,6 +21,8 @@ #ifndef NNUE_COMMON_H_INCLUDED #define NNUE_COMMON_H_INCLUDED +#include "../types.h" + #include #include diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 59a965ac..47fe9c06 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -24,6 +24,9 @@ #include "nnue_common.h" #include "nnue_architecture.h" +#include "../misc.h" +#include "../position.h" + #include // std::memset() namespace Stockfish::Eval::NNUE { diff --git a/src/position.cpp b/src/position.cpp index 0686d245..6bb2edb4 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -23,6 +23,8 @@ #include #include +#include "nnue/evaluate_nnue.h" + #include "bitboard.h" #include "misc.h" #include "movegen.h" @@ -32,6 +34,9 @@ #include "uci.h" #include "syzygy/tbprobe.h" +#include "tools/packed_sfen.h" +#include "tools/sfen_packer.h" + using std::string; namespace Stockfish { @@ -754,7 +759,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { else st->nonPawnMaterial[them] -= PieceValue[MG][captured]; - if (Eval::useNNUE) + if (Eval::NNUE::useNNUE != Eval::NNUE::UseNNUEMode::False) { dp.dirty_num = 2; // 1 piece moved, 1 piece captured dp.piece[1] = captured; @@ -798,7 +803,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { // Move the piece. The tricky Chess960 castling is handled earlier if (type_of(m) != CASTLING) { - if (Eval::useNNUE) + if (Eval::NNUE::useNNUE != Eval::NNUE::UseNNUEMode::False) { dp.piece[0] = pc; dp.from[0] = from; @@ -829,7 +834,7 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { remove_piece(to); put_piece(promotion, to); - if (Eval::useNNUE) + if (Eval::NNUE::useNNUE != Eval::NNUE::UseNNUEMode::False) { // Promoting pawn to SQ_NONE, promoted piece from SQ_NONE dp.to[0] = SQ_NONE; @@ -967,7 +972,7 @@ void Position::do_castling(Color us, Square from, Square& to, Square& rfrom, Squ rto = relative_square(us, kingSide ? SQ_F1 : SQ_D1); to = relative_square(us, kingSide ? SQ_G1 : SQ_C1); - if (Do && Eval::useNNUE) + if (Do && Eval::NNUE::useNNUE != Eval::NNUE::UseNNUEMode::False) { auto& dp = st->dirtyPiece; dp.piece[0] = make_piece(us, KING); @@ -1001,6 +1006,7 @@ void Position::do_null_move(StateInfo& newSt) { newSt.previous = st; st = &newSt; + // Used by NNUE st->dirtyPiece.dirty_num = 0; st->dirtyPiece.piece[0] = NO_PIECE; // Avoid checks in UpdateAccumulator() st->accumulator.computed[WHITE] = false; @@ -1177,6 +1183,22 @@ bool Position::is_draw(int ply) const { } +/// Position::is_fifty_move_draw() returns true if a game can be claimed +/// by a fifty-move draw rule. + +bool Position::is_fifty_move_draw() const { + + return (st->rule50 > 99 && (!checkers() || MoveList(*this).size())); +} + + +/// Position::is_three_fold_repetition() returns true if there is 3-fold repetition. +bool Position::is_three_fold_repetition() const { + + return st->repetition < 0; +} + + // Position::has_repeated() tests whether there has been at least one repetition // of positions since the last capture or pawn move. @@ -1345,4 +1367,18 @@ bool Position::pos_is_ok() const { return true; } +// Add a function that directly unpacks for speed. It's pretty tough. +// Write it by combining packer::unpack() and Position::set(). +// If there is a problem with the passed phase and there is an error, non-zero is returned. +int Position::set_from_packed_sfen(const Tools::PackedSfen& sfen , StateInfo* si, Thread* th) +{ + return Tools::set_from_packed_sfen(*this, sfen, si, th); +} + +// Get the packed sfen. Returns to the buffer specified in the argument. +void Position::sfen_pack(Tools::PackedSfen& sfen) +{ + sfen = Tools::sfen_pack(*this); +} + } // namespace Stockfish diff --git a/src/position.h b/src/position.h index 9f694a79..20f999bc 100644 --- a/src/position.h +++ b/src/position.h @@ -31,6 +31,9 @@ #include "nnue/nnue_accumulator.h" +#include "tools/packed_sfen.h" +#include "tools/sfen_packer.h" + namespace Stockfish { /// StateInfo struct stores information needed to restore a Position object to @@ -157,6 +160,8 @@ public: bool is_chess960() const; Thread* this_thread() const; bool is_draw(int ply) const; + bool is_fifty_move_draw() const; + bool is_three_fold_repetition() const; bool has_game_cycle(int ply) const; bool has_repeated() const; int rule50_count() const; @@ -171,6 +176,28 @@ public: // Used by NNUE StateInfo* state() const; + // --sfenization helper + + friend int Tools::set_from_packed_sfen(Position& pos, const Tools::PackedSfen& sfen, StateInfo* si, Thread* th); + + // Get the packed sfen. Returns to the buffer specified in the argument. + // Do not include gamePly in pack. + void sfen_pack(Tools::PackedSfen& sfen); + + // It is slow to go through sfen, so I made a function to set packed sfen directly. + // Equivalent to pos.set(sfen_unpack(data),si,th);. + // If there is a problem with the passed phase and there is an error, non-zero is returned. + // PackedSfen does not include gamePly so it cannot be restored. If you want to set it, specify it with an argument. + int set_from_packed_sfen(const Tools::PackedSfen& sfen, StateInfo* si, Thread* th); + + void clear() { std::memset(this, 0, sizeof(Position)); } + + // Give the board, hand piece, and turn, and return the sfen. + //static std::string sfen_from_rawdata(Piece board[81], Hand hands[2], Color turn, int gamePly); + + // Returns the position of the ball on the c side. + Square king_square(Color c) const { return lsb(pieces(c, KING)); } + void put_piece(Piece pc, Square s); void remove_piece(Square s); @@ -414,6 +441,8 @@ inline StateInfo* Position::state() const { return st; } +static const char* const StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"; + } // namespace Stockfish #endif // #ifndef POSITION_H_INCLUDED diff --git a/src/search.cpp b/src/search.cpp index c48b74bc..b2d6e4b0 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -23,6 +23,8 @@ #include #include +#include "nnue/evaluate_nnue.h" + #include "evaluate.h" #include "misc.h" #include "movegen.h" @@ -42,20 +44,12 @@ namespace Search { LimitsType Limits; } -namespace Tablebases { - - int Cardinality; - bool RootInTB; - bool UseRule50; - Depth ProbeDepth; -} - -namespace TB = Tablebases; - using std::string; using Eval::evaluate; using namespace Search; +bool Search::prune_at_shallow_depth = true; + namespace { // Different node types, used as a template parameter @@ -276,6 +270,9 @@ void Thread::search() { bestValue = delta = alpha = -VALUE_INFINITE; beta = VALUE_INFINITE; + if (!this->rootMoves.empty()) + Tablebases::rank_root_moves(this->rootPos, this->rootMoves); + if (mainThread) { if (mainThread->bestPreviousScore == VALUE_INFINITE) @@ -368,7 +365,7 @@ void Thread::search() { // Start with a small aspiration window and, in the case of a fail // high/low, re-search with a bigger window until we don't fail // high/low anymore. - int failedHighCnt = 0; + failedHighCnt = 0; while (true) { Depth adjustedDepth = std::max(1, rootDepth - failedHighCnt - searchAgainCounter); @@ -673,27 +670,27 @@ namespace { } // Step 5. Tablebases probe - if (!rootNode && TB::Cardinality) + if (!rootNode && thisThread->Cardinality) { int piecesCount = pos.count(); - if ( piecesCount <= TB::Cardinality - && (piecesCount < TB::Cardinality || depth >= TB::ProbeDepth) + if ( piecesCount <= thisThread->Cardinality + && (piecesCount < thisThread->Cardinality || depth >= thisThread->ProbeDepth) && pos.rule50_count() == 0 && !pos.can_castle(ANY_CASTLING)) { - TB::ProbeState err; - TB::WDLScore wdl = Tablebases::probe_wdl(pos, &err); + Tablebases::ProbeState err; + Tablebases::WDLScore wdl = Tablebases::probe_wdl(pos, &err); // Force check of time on the next occasion if (thisThread == Threads.main()) static_cast(thisThread)->callsCnt = 0; - if (err != TB::ProbeState::FAIL) + if (err != Tablebases::ProbeState::FAIL) { thisThread->tbHits.fetch_add(1, std::memory_order_relaxed); - int drawScore = TB::UseRule50 ? 1 : 0; + int drawScore = thisThread->UseRule50 ? 1 : 0; // use the range VALUE_MATE_IN_MAX_PLY to VALUE_TB_WIN_IN_MAX_PLY to score value = wdl < -drawScore ? VALUE_MATED_IN_MAX_PLY + ss->ply + 1 @@ -976,7 +973,9 @@ moves_loop: // When in check, search starts here ss->moveCount = ++moveCount; - if (rootNode && thisThread == Threads.main() && Time.elapsed() > 3000) + if (rootNode && thisThread == Threads.main() && Time.elapsed() > 3000 + && !Limits.silent + ) sync_cout << "info depth " << depth << " currmove " << UCI::move(move, pos.is_chess960()) << " currmovenumber " << moveCount + thisThread->pvIdx << sync_endl; @@ -993,6 +992,7 @@ moves_loop: // When in check, search starts here // Step 13. Pruning at shallow depth (~200 Elo). Depth conditions are important for mate finding. if ( !rootNode + && (PvNode ? prune_at_shallow_depth : true) && pos.non_pawn_material(us) && bestValue > VALUE_TB_LOSS_IN_MAX_PLY) { @@ -1800,7 +1800,7 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) { size_t pvIdx = pos.this_thread()->pvIdx; size_t multiPV = std::min((size_t)Options["MultiPV"], rootMoves.size()); uint64_t nodesSearched = Threads.nodes_searched(); - uint64_t tbHits = Threads.tb_hits() + (TB::RootInTB ? rootMoves.size() : 0); + uint64_t tbHits = Threads.tb_hits() + (pos.this_thread()->rootInTB ? rootMoves.size() : 0); for (size_t i = 0; i < multiPV; ++i) { @@ -1815,7 +1815,7 @@ string UCI::pv(const Position& pos, Depth depth, Value alpha, Value beta) { if (v == -VALUE_INFINITE) v = VALUE_ZERO; - bool tb = TB::RootInTB && abs(v) < VALUE_MATE_IN_MAX_PLY; + bool tb = pos.this_thread()->rootInTB && abs(v) < VALUE_MATE_IN_MAX_PLY; v = tb ? rootMoves[i].tbScore : v; if (ss.rdbuf()->in_avail()) // Not at first line @@ -1884,34 +1884,38 @@ bool RootMove::extract_ponder_from_tt(Position& pos) { void Tablebases::rank_root_moves(Position& pos, Search::RootMoves& rootMoves) { - RootInTB = false; - UseRule50 = bool(Options["Syzygy50MoveRule"]); - ProbeDepth = int(Options["SyzygyProbeDepth"]); - Cardinality = int(Options["SyzygyProbeLimit"]); + pos.this_thread()->Cardinality = int(Options["SyzygyProbeLimit"]); + pos.this_thread()->ProbeDepth = int(Options["SyzygyProbeDepth"]); + pos.this_thread()->UseRule50 = bool(Options["Syzygy50MoveRule"]); + pos.this_thread()->rootInTB = false; + + auto& cardinality = pos.this_thread()->Cardinality; + auto& probeDepth = pos.this_thread()->ProbeDepth; + auto& rootInTB = pos.this_thread()->rootInTB; bool dtz_available = true; // Tables with fewer pieces than SyzygyProbeLimit are searched with // ProbeDepth == DEPTH_ZERO - if (Cardinality > MaxCardinality) + if (cardinality > Tablebases::MaxCardinality) { - Cardinality = MaxCardinality; - ProbeDepth = 0; + cardinality = Tablebases::MaxCardinality; + probeDepth = 0; } - if (Cardinality >= popcount(pos.pieces()) && !pos.can_castle(ANY_CASTLING)) + if (cardinality >= popcount(pos.pieces()) && !pos.can_castle(ANY_CASTLING)) { // Rank moves using DTZ tables - RootInTB = root_probe(pos, rootMoves); + rootInTB = root_probe(pos, rootMoves); - if (!RootInTB) + if (!rootInTB) { // DTZ tables are missing; try to rank moves using WDL tables dtz_available = false; - RootInTB = root_probe_wdl(pos, rootMoves); + rootInTB = root_probe_wdl(pos, rootMoves); } } - if (RootInTB) + if (rootInTB) { // Sort moves according to TB rank std::stable_sort(rootMoves.begin(), rootMoves.end(), @@ -1919,7 +1923,7 @@ void Tablebases::rank_root_moves(Position& pos, Search::RootMoves& rootMoves) { // Probe during search only if DTZ is not available and we are winning if (dtz_available || rootMoves[0].tbScore <= VALUE_DRAW) - Cardinality = 0; + cardinality = 0; } else { @@ -1927,6 +1931,270 @@ void Tablebases::rank_root_moves(Position& pos, Search::RootMoves& rootMoves) { for (auto& m : rootMoves) m.tbRank = 0; } + +} + +// --- expose the functions such as fixed depth search used for learning to the outside +namespace Search +{ + // For learning, prepare a stub that can call search,qsearch() from one thread. + // From now on, it is better to have a Searcher and prepare a substitution table for each thread like Apery. + // It might have been good. + + // Initialization for learning. + // Called from Tools::search(),Tools::qsearch(). + static bool init_for_search(Position& pos, Stack* ss) + { + + // RootNode requires ss->ply == 0. + // Because it clears to zero, ss->ply == 0, so it's okay... + + std::memset(ss - 7, 0, 10 * sizeof(Stack)); + + // Regarding this_thread. + + { + auto th = pos.this_thread(); + + th->completedDepth = 0; + th->selDepth = 0; + th->rootDepth = 0; + th->nmpMinPly = th->bestMoveChanges = th->failedHighCnt = 0; + th->ttHitAverage = TtHitAverageWindow * TtHitAverageResolution / 2; + + // Zero initialization of the number of search nodes + th->nodes = 0; + + // Clear all history types. This initialization takes a little time, and the accuracy of the search is rather low, so the good and bad are not well understood. + // th->clear(); + + int ct = int(Options["Contempt"]) * PawnValueEg / 100; // From centipawns + Color us = pos.side_to_move(); + + // In analysis mode, adjust contempt in accordance with user preference + if (Limits.infinite || Options["UCI_AnalyseMode"]) + ct = Options["Analysis Contempt"] == "Off" ? 0 + : Options["Analysis Contempt"] == "Both" ? ct + : Options["Analysis Contempt"] == "White" && us == BLACK ? -ct + : Options["Analysis Contempt"] == "Black" && us == WHITE ? -ct + : ct; + + // Evaluation score is from the white point of view + th->contempt = (us == WHITE ? make_score(ct, ct / 2) + : -make_score(ct, ct / 2)); + + for (int i = 7; i > 0; i--) + (ss - i)->continuationHistory = &th->continuationHistory[0][0][NO_PIECE][0]; // Use as a sentinel + + // set rootMoves + auto& rootMoves = th->rootMoves; + + rootMoves.clear(); + for (auto m: MoveList(pos)) + rootMoves.push_back(Search::RootMove(m)); + + // Check if we're at a terminal node. Otherwise we end up returning + // malformed PV later on. + if (rootMoves.empty()) + return false; + + Tablebases::rank_root_moves(pos, rootMoves); + } + + return true; + } + + // Stationary search. + // + // Precondition) Search thread is set by pos.set_this_thread(Threads[thread_id]). + // Also, when Threads.stop arrives, the search is interrupted, so the PV at that time is not correct. + // After returning from search(), if Threads.stop == true, do not use the search result. + // Also, note that before calling, if you do not call it with Threads.stop == false, the search will be interrupted and it will return. + // + // If it is clogged, MOVE_RESIGN is returned in the PV array. + // + //Although it was possible to specify alpha and beta with arguments, this will show the result when searching in that window + // Because it writes to the substitution table, the value that can be pruned is written to that window when learning + // As it has a bad effect, I decided to stop allowing the window range to be specified. + ValueAndPV qsearch(Position& pos) + { + Stack stack[MAX_PLY+10], *ss = stack+7; + Move pv[MAX_PLY+1]; + + if (!init_for_search(pos, ss)) + return {}; + + ss->pv = pv; // For the time being, it must be a dummy and somewhere with a buffer. + + if (pos.is_draw(0)) { + // Return draw value if draw. + return { VALUE_DRAW, {} }; + } + + // Is it stuck? + if (MoveList(pos).size() == 0) + { + // Return the mated value if checkmated. + return { mated_in(/*ss->ply*/ 0 + 1), {} }; + } + + auto bestValue = Stockfish::qsearch(pos, ss, -VALUE_INFINITE, VALUE_INFINITE, 0); + + // Returns the PV obtained. + std::vector pvs; + for (Move* p = &ss->pv[0]; is_ok(*p); ++p) + pvs.push_back(*p); + + return ValueAndPV(bestValue, pvs); + } + + // Normal search. Depth depth (specified as an integer). + // 3 If you want a score for hand reading, + // auto v = search(pos,3); + // Do something like + // Evaluation value is obtained in v.first and PV is obtained in v.second. + // When multi pv is enabled, you can get the PV (reading line) array in pos.this_thread()->rootMoves[N].pv. + // Specify multi pv with the argument multiPV of this function. (The value of Options["MultiPV"] is ignored) + // + // Declaration win judgment is not done as root (because it is troublesome to handle), so it is not done here. + // Handle it by the caller. + // + // Precondition) Search thread is set by pos.set_this_thread(Threads[thread_id]). + // Also, when Threads.stop arrives, the search is interrupted, so the PV at that time is not correct. + // After returning from search(), if Threads.stop == true, do not use the search result. + // Also, note that before calling, if you do not call it with Threads.stop == false, the search will be interrupted and it will return. + + ValueAndPV search(Position& pos, int depth_, size_t multiPV /* = 1 */, uint64_t nodesLimit /* = 0 */) + { + std::vector pvs; + + Depth depth = depth_; + if (depth < 0) + return std::pair>(Eval::evaluate(pos), std::vector()); + + if (depth == 0) + return qsearch(pos); + + Stack stack[MAX_PLY + 10], * ss = stack + 7; + Move pv[MAX_PLY + 1]; + + if (!init_for_search(pos, ss)) + return {}; + + ss->pv = pv; // For the time being, it must be a dummy and somewhere with a buffer. + + // Initialize the variables related to this_thread + auto th = pos.this_thread(); + auto& rootDepth = th->rootDepth; + auto& pvIdx = th->pvIdx; + auto& pvLast = th->pvLast; + auto& rootMoves = th->rootMoves; + auto& completedDepth = th->completedDepth; + auto& selDepth = th->selDepth; + + // A function to search the top N of this stage as best move + //size_t multiPV = Options["MultiPV"]; + + // Do not exceed the number of moves in this situation + multiPV = std::min(multiPV, rootMoves.size()); + + // If you do not multiply the node limit by the value of MultiPV, you will not be thinking about the same node for one candidate hand when you fix the depth and have MultiPV. + nodesLimit *= multiPV; + + Value alpha = -VALUE_INFINITE; + Value beta = VALUE_INFINITE; + Value delta = -VALUE_INFINITE; + Value bestValue = -VALUE_INFINITE; + + while ((rootDepth += 1) <= depth + // exit this loop even if the node limit is exceeded + // The number of search nodes is passed in the argument of this function. + && !(nodesLimit /* limited nodes */ && th->nodes.load(std::memory_order_relaxed) >= nodesLimit) + ) + { + for (RootMove& rm : rootMoves) + rm.previousScore = rm.score; + + size_t pvFirst = 0; + pvLast = 0; + + // MultiPV loop. We perform a full root search for each PV line + for (pvIdx = 0; pvIdx < multiPV && !Threads.stop; ++pvIdx) + { + if (pvIdx == pvLast) + { + pvFirst = pvLast; + for (pvLast++; pvLast < rootMoves.size(); pvLast++) + if (rootMoves[pvLast].tbRank != rootMoves[pvFirst].tbRank) + break; + } + + // selDepth output with USI info for each depth and PV line + selDepth = 0; + + // Switch to aspiration search for depth 5 and above. + if (rootDepth >= 4) + { + Value prev = rootMoves[pvIdx].previousScore; + delta = Value(17); + alpha = std::max(prev - delta,-VALUE_INFINITE); + beta = std::min(prev + delta, VALUE_INFINITE); + } + + while (true) + { + Depth adjustedDepth = std::max(1, rootDepth); + bestValue = Stockfish::search(pos, ss, alpha, beta, adjustedDepth, false); + + stable_sort(rootMoves.begin() + pvIdx, rootMoves.end()); + //my_stable_sort(pos.this_thread()->thread_id(),&rootMoves[0] + pvIdx, rootMoves.size() - pvIdx); + + // Expand aspiration window for fail low/high. + // However, if it is the value specified by the argument, it will be treated as fail low/high and break. + if (bestValue <= alpha) + { + beta = (alpha + beta) / 2; + alpha = std::max(bestValue - delta, -VALUE_INFINITE); + } + else if (bestValue >= beta) + { + beta = std::min(bestValue + delta, VALUE_INFINITE); + } + else + break; + + delta += delta / 4 + 5; + assert(-VALUE_INFINITE <= alpha && beta <= VALUE_INFINITE); + + // runaway check + //assert(th->nodes.load(std::memory_order_relaxed) <= 1000000 ); + } + + stable_sort(rootMoves.begin(), rootMoves.begin() + pvIdx + 1); + //my_stable_sort(pos.this_thread()->thread_id() , &rootMoves[0] , pvIdx + 1); + + } // multi PV + + completedDepth = rootDepth; + } + + // Pass PV_is(ok) to eliminate this PV, there may be NULL_MOVE in the middle. + // MOVE_WIN has never been thrust. (For now) + for (Move move : rootMoves[0].pv) + { + if (!is_ok(move)) + break; + pvs.push_back(move); + } + + //sync_cout << rootDepth << sync_endl; + + // Considering multiPV, the score of rootMoves[0] is returned as bestValue. + bestValue = rootMoves[0].score; + + return ValueAndPV(bestValue, pvs); + } + } } // namespace Stockfish diff --git a/src/search.h b/src/search.h index 801baacc..73a3bd86 100644 --- a/src/search.h +++ b/src/search.h @@ -24,6 +24,7 @@ #include "misc.h" #include "movepick.h" #include "types.h" +#include "uci.h" namespace Stockfish { @@ -34,6 +35,7 @@ namespace Search { /// Threshold used for countermoves based pruning constexpr int CounterMovePruneThreshold = 0; +extern bool prune_at_shallow_depth; /// Stack struct keeps track of the information we need to remember from nodes /// shallower and deeper in the tree during the search. Each search thread has @@ -90,6 +92,7 @@ struct LimitsType { time[WHITE] = time[BLACK] = inc[WHITE] = inc[BLACK] = npmsec = movetime = TimePoint(0); movestogo = depth = mate = perft = infinite = 0; nodes = 0; + silent = false; } bool use_time_management() const { @@ -100,6 +103,9 @@ struct LimitsType { TimePoint time[COLOR_NB], inc[COLOR_NB], npmsec, movetime, startTime; int movestogo, depth, mate, perft, infinite; int64_t nodes; + // Silent mode that does not output to the screen (for continuous self-play in process) + // Do not output PV at this time. + bool silent; }; extern LimitsType Limits; @@ -107,7 +113,13 @@ extern LimitsType Limits; void init(); void clear(); -} // namespace Search +// A pair of reader and evaluation value. Returned by Tools::search(),Tools::qsearch(). +using ValueAndPV = std::pair>; + +ValueAndPV qsearch(Position& pos); +ValueAndPV search(Position& pos, int depth_, size_t multiPV = 1, uint64_t nodesLimit = 0); + +} } // namespace Stockfish diff --git a/src/syzygy/tbprobe.cpp b/src/syzygy/tbprobe.cpp index 96b2970f..f382edbc 100644 --- a/src/syzygy/tbprobe.cpp +++ b/src/syzygy/tbprobe.cpp @@ -52,7 +52,7 @@ using namespace Stockfish::Tablebases; -int Stockfish::Tablebases::MaxCardinality; +int Stockfish::Tablebases::MaxCardinality = 0; namespace Stockfish { diff --git a/src/thread.cpp b/src/thread.cpp index da8e1d05..5051708d 100644 --- a/src/thread.cpp +++ b/src/thread.cpp @@ -37,6 +37,7 @@ ThreadPool Threads; // Global object Thread::Thread(size_t n) : idx(n), stdThread(&Thread::idle_loop, this) { wait_for_search_finished(); + wait_for_worker_finished(); } @@ -82,6 +83,14 @@ void Thread::start_searching() { cv.notify_one(); // Wake up the thread in idle_loop() } +void Thread::execute_with_worker(std::function t) +{ + std::lock_guard lk(mutex); + worker = std::move(t); + searching = true; + cv.notify_one(); // Wake up the thread in idle_loop() +} + /// Thread::wait_for_search_finished() blocks on the condition variable /// until the thread has finished searching. @@ -93,6 +102,12 @@ void Thread::wait_for_search_finished() { } +void Thread::wait_for_worker_finished() { + + std::unique_lock lk(mutex); + cv.wait(lk, [&]{ return !searching; }); +} + /// Thread::idle_loop() is where the thread is parked, blocked on the /// condition variable, when it has no work to do. @@ -110,15 +125,25 @@ void Thread::idle_loop() { { std::unique_lock lk(mutex); searching = false; + worker = nullptr; cv.notify_one(); // Wake up anyone waiting for search finished cv.wait(lk, [&]{ return searching; }); if (exit) return; + auto wrk = std::move(worker); + lk.unlock(); - search(); + if (wrk) + { + wrk(*this); + } + else + { + search(); + } } } @@ -165,6 +190,13 @@ void ThreadPool::clear() { main()->previousTimeReduction = 1.0; } +void ThreadPool::execute_with_workers(const std::function& worker) +{ + for(Thread* th : *this) + { + th->execute_with_worker(worker); + } +} /// ThreadPool::start_thinking() wakes up main thread waiting in idle_loop() and /// returns immediately. Main thread will wake up other threads and start the search. @@ -185,9 +217,6 @@ void ThreadPool::start_thinking(Position& pos, StateListPtr& states, || std::count(limits.searchmoves.begin(), limits.searchmoves.end(), m)) rootMoves.emplace_back(m); - if (!rootMoves.empty()) - Tablebases::rank_root_moves(pos, rootMoves); - // After ownership transfer 'states' becomes empty, so if we stop the search // and call 'go' again without setting a new position states.get() == NULL. assert(states.get() || setupStates.get()); @@ -263,4 +292,11 @@ void ThreadPool::wait_for_search_finished() const { th->wait_for_search_finished(); } + +void ThreadPool::wait_for_workers_finished() const { + + for (Thread* th : *this) + th->wait_for_worker_finished(); +} + } // namespace Stockfish diff --git a/src/thread.h b/src/thread.h index 5bfa2359..c0218577 100644 --- a/src/thread.h +++ b/src/thread.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "material.h" #include "movepick.h" @@ -39,24 +40,51 @@ namespace Stockfish { /// pointer to an entry its life time is unlimited and we don't have /// to care about someone changing the entry under our feet. +namespace Detail { + + template + struct TypeIdentity { + using Type = T; + }; + +} + class Thread { std::mutex mutex; std::condition_variable cv; size_t idx; bool exit = false, searching = true; // Set before starting std::thread + std::function worker; + std::function on_eval_callback; NativeThread stdThread; public: explicit Thread(size_t); virtual ~Thread(); virtual void search(); + + // The function object to be executed is taken by value to remove + // the need for separate lvalue and rvalue overloads. + // The worker thread needs to have ownership of the task + // to be executed because otherwise there's no way to manage its lifetime. + virtual void execute_with_worker(std::function t); + void clear(); void idle_loop(); void start_searching(); void wait_for_search_finished(); size_t id() const { return idx; } + void wait_for_worker_finished(); + + template + void set_eval_callback(FuncT&& f) { on_eval_callback = std::forward(f); } + + void clear_eval_callback() { on_eval_callback = nullptr; } + + void on_eval() { if (on_eval_callback) on_eval_callback(rootPos); } + Pawns::Table pawnsTable; Material::Table materialTable; size_t pvIdx, pvLast; @@ -75,6 +103,11 @@ public: CapturePieceToHistory captureHistory; ContinuationHistory continuationHistory[2][2]; Score trend; + int failedHighCnt; + bool rootInTB; + int Cardinality; + bool UseRule50; + Depth ProbeDepth; }; @@ -102,6 +135,61 @@ struct MainThread : public Thread { struct ThreadPool : public std::vector { + // Each thread gets its own copy of the `worker` function object. + // This means that each worker thread will have exclusive access + // to the state of the `worker` function object. + void execute_with_workers(const std::function& worker); + + template + void for_each_index_with_workers( + IndexT begin, + typename Detail::TypeIdentity::Type end, + FuncT func) + { + // This value must outlive the function call. + // It's fairly safe if we make it static + // because for_each_index_with_workers + // is not reentrant nor thread safe. + static std::atomic i_atomic; + i_atomic.store(begin); + + execute_with_workers( + [end, func](Thread& th) mutable { + for(;;) { + const auto i = i_atomic.fetch_add(1); + if (i >= end) + break; + + func(th, i); + } + }); + } + + template + void for_each_index_chunk_with_workers( + IndexT begin, + typename Detail::TypeIdentity::Type end, + FuncT func) + { + // This value must outlive the function call. + // It's fairly safe if we make it static + // because for_each_index_with_workers + // is not reentrant nor thread safe. + const IndexT size = end - begin; + const IndexT chunk_size = (size + this->size()) / this->size(); + + execute_with_workers( + [chunk_size, end, func](Thread& th) mutable { + const IndexT thread_id = th.id(); + const IndexT offset = chunk_size * thread_id; + if (offset >= end) + return; + + const IndexT count = offset + chunk_size > end ? end - offset : chunk_size; + func(th, offset, count); + }); + } + void start_thinking(Position&, StateListPtr&, const Search::LimitsType&, bool = false); void clear(); void set(size_t); @@ -112,6 +200,7 @@ struct ThreadPool : public std::vector { Thread* get_best_thread() const; void start_searching(); void wait_for_search_finished() const; + void wait_for_workers_finished() const; std::atomic_bool stop, increaseDepth; diff --git a/src/tools/convert.cpp b/src/tools/convert.cpp new file mode 100644 index 00000000..03f3e4a7 --- /dev/null +++ b/src/tools/convert.cpp @@ -0,0 +1,815 @@ +#include "convert.h" + +#include "uci.h" +#include "misc.h" +#include "thread.h" +#include "position.h" +#include "tt.h" + +#include "extra/nnue_data_binpack_format.h" + +#include "nnue/evaluate_nnue.h" + +#include "syzygy/tbprobe.h" + +#include +#include +#include +#include +#include +#include // std::exp(),std::pow(),std::log() +#include // memcpy() +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +namespace sys = std::filesystem; + +namespace Stockfish::Tools +{ + bool fen_is_ok(Position& pos, std::string input_fen) { + std::string pos_fen = pos.fen(); + std::istringstream ss_input(input_fen); + std::istringstream ss_pos(pos_fen); + + // example : "2r4r/4kpp1/nb1np3/p2p3p/B2P1BP1/PP6/4NPKP/2R1R3 w - h6 0 24" + // --> "2r4r/4kpp1/nb1np3/p2p3p/B2P1BP1/PP6/4NPKP/2R1R3" + std::string str_input, str_pos; + ss_input >> str_input; + ss_pos >> str_pos; + + // Only compare "Piece placement field" between input_fen and pos.fen(). + return str_input == str_pos; + } + + void convert_bin( + const vector& filenames, + const string& output_file_name, + const int ply_minimum, + const int ply_maximum, + const int interpolate_eval, + const int src_score_min_value, + const int src_score_max_value, + const int dest_score_min_value, + const int dest_score_max_value, + const bool check_invalid_fen, + const bool check_illegal_move) + { + std::cout << "check_invalid_fen=" << check_invalid_fen << std::endl; + std::cout << "check_illegal_move=" << check_illegal_move << std::endl; + + std::fstream fs; + uint64_t data_size = 0; + uint64_t filtered_size = 0; + uint64_t filtered_size_fen = 0; + uint64_t filtered_size_move = 0; + uint64_t filtered_size_ply = 0; + auto th = Threads.main(); + auto& tpos = th->rootPos; + // convert plain rag to packed sfenvalue for Yaneura king + fs.open(output_file_name, ios::app | ios::binary); + StateListPtr states; + for (auto filename : filenames) { + std::cout << "convert " << filename << " ... "; + std::string line; + ifstream ifs; + ifs.open(filename); + PackedSfenValue p; + data_size = 0; + filtered_size = 0; + filtered_size_fen = 0; + filtered_size_move = 0; + filtered_size_ply = 0; + p.gamePly = 1; // Not included in apery format. Should be initialized + bool ignore_flag_fen = false; + bool ignore_flag_move = false; + bool ignore_flag_ply = false; + while (std::getline(ifs, line)) { + std::stringstream ss(line); + std::string token; + std::string value; + ss >> token; + if (token == "fen") { + states = StateListPtr(new std::deque(1)); // Drop old and create a new one + std::string input_fen = line.substr(4); + tpos.set(input_fen, false, &states->back(), Threads.main()); + if (check_invalid_fen && !fen_is_ok(tpos, input_fen)) { + ignore_flag_fen = true; + filtered_size_fen++; + } + else { + tpos.sfen_pack(p.sfen); + } + } + else if (token == "move") { + ss >> value; + Move move = UCI::to_move(tpos, value); + if (check_illegal_move && move == MOVE_NONE) { + ignore_flag_move = true; + filtered_size_move++; + } + else { + p.move = move; + } + } + else if (token == "score") { + double score; + ss >> score; + // Training Formula ?Issue #71 ?nodchip/Stockfish https://github.com/nodchip/Stockfish/issues/71 + // Normalize to [0.0, 1.0]. + score = (score - src_score_min_value) / (src_score_max_value - src_score_min_value); + // Scale to [dest_score_min_value, dest_score_max_value]. + score = score * (dest_score_max_value - dest_score_min_value) + dest_score_min_value; + p.score = std::clamp((int32_t)std::round(score), -(int32_t)VALUE_MATE, (int32_t)VALUE_MATE); + } + else if (token == "ply") { + int temp; + ss >> temp; + if (temp < ply_minimum || temp > ply_maximum) { + ignore_flag_ply = true; + filtered_size_ply++; + } + p.gamePly = uint16_t(temp); // No cast here? + if (interpolate_eval != 0) { + p.score = min(3000, interpolate_eval * temp); + } + } + else if (token == "result") { + int temp; + ss >> temp; + p.game_result = int8_t(temp); // Do you need a cast here? + if (interpolate_eval) { + p.score = p.score * p.game_result; + } + } + else if (token == "e") { + if (!(ignore_flag_fen || ignore_flag_move || ignore_flag_ply)) { + fs.write((char*)&p, sizeof(PackedSfenValue)); + data_size += 1; + // debug + // std::cout< 0.25 * PawnValueEg + // #-4 --> -mate_in(4) + // #3 --> mate_in(3) + // -M4 --> -mate_in(4) + // +M3 --> mate_in(3) + Value parse_score_from_pgn_extract(std::string eval, bool& success) { + success = true; + + if (eval.substr(0, 1) == "#") { + if (eval.substr(1, 1) == "-") { + return -mate_in(stoi(eval.substr(2, eval.length() - 2))); + } + else { + return mate_in(stoi(eval.substr(1, eval.length() - 1))); + } + } + else if (eval.substr(0, 2) == "-M") { + //std::cout << "eval=" << eval << std::endl; + return -mate_in(stoi(eval.substr(2, eval.length() - 2))); + } + else if (eval.substr(0, 2) == "+M") { + //std::cout << "eval=" << eval << std::endl; + return mate_in(stoi(eval.substr(2, eval.length() - 2))); + } + else { + char* endptr; + double value = strtod(eval.c_str(), &endptr); + + if (*endptr != '\0') { + success = false; + return VALUE_ZERO; + } + else { + return Value(value * static_cast(PawnValueEg)); + } + } + } + + // for Debug + //#define DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT + + bool is_like_fen(std::string fen) { + int count_space = std::count(fen.cbegin(), fen.cend(), ' '); + int count_slash = std::count(fen.cbegin(), fen.cend(), '/'); + +#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT) + //std::cout << "count_space=" << count_space << std::endl; + //std::cout << "count_slash=" << count_slash << std::endl; +#endif + + return count_space == 5 && count_slash == 7; + } + + void convert_bin_from_pgn_extract( + const vector& filenames, + const string& output_file_name, + const bool pgn_eval_side_to_move, + const bool convert_no_eval_fens_as_score_zero) + { + std::cout << "pgn_eval_side_to_move=" << pgn_eval_side_to_move << std::endl; + std::cout << "convert_no_eval_fens_as_score_zero=" << convert_no_eval_fens_as_score_zero << std::endl; + + auto th = Threads.main(); + auto& pos = th->rootPos; + + std::fstream ofs; + ofs.open(output_file_name, ios::out | ios::binary); + + int game_count = 0; + int fen_count = 0; + + for (auto filename : filenames) { + std::cout << now_string() << " convert " << filename << std::endl; + ifstream ifs; + ifs.open(filename); + + int game_result = 0; + + std::string line; + while (std::getline(ifs, line)) { + + if (line.empty()) { + continue; + } + + else if (line.substr(0, 1) == "[") { + std::regex pattern_result(R"(\[Result (.+?)\])"); + std::smatch match; + + // example: [Result "1-0"] + if (std::regex_search(line, match, pattern_result)) { + game_result = parse_game_result_from_pgn_extract(match.str(1)); +#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT) + std::cout << "game_result=" << game_result << std::endl; +#endif + game_count++; + if (game_count % 10000 == 0) { + std::cout << now_string() << " game_count=" << game_count << ", fen_count=" << fen_count << std::endl; + } + } + + continue; + } + + else { + int gamePly = 1; + auto itr = line.cbegin(); + + while (true) { + gamePly++; + + PackedSfenValue psv; + memset((char*)&psv, 0, sizeof(PackedSfenValue)); + + // fen + { + bool fen_found = false; + + while (!fen_found) { + std::regex pattern_bracket(R"(\{(.+?)\})"); + std::smatch match; + if (!std::regex_search(itr, line.cend(), match, pattern_bracket)) { + break; + } + + itr += match.position(0) + match.length(0) - 1; + std::string str_fen = match.str(1); + trim(str_fen); + + if (is_like_fen(str_fen)) { + fen_found = true; + + StateInfo si; + pos.set(str_fen, false, &si, th); + pos.sfen_pack(psv.sfen); + } + +#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT) + std::cout << "str_fen=" << str_fen << std::endl; + std::cout << "fen_found=" << fen_found << std::endl; +#endif + } + + if (!fen_found) { + break; + } + } + + // move + { + std::regex pattern_move(R"(\}(.+?)\{)"); + std::smatch match; + if (!std::regex_search(itr, line.cend(), match, pattern_move)) { + break; + } + + itr += match.position(0) + match.length(0) - 1; + std::string str_move = match.str(1); + trim(str_move); +#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT) + std::cout << "str_move=" << str_move << std::endl; +#endif + psv.move = UCI::to_move(pos, str_move); + } + + // eval + bool eval_found = false; + { + std::regex pattern_bracket(R"(\{(.+?)\})"); + std::smatch match; + if (!std::regex_search(itr, line.cend(), match, pattern_bracket)) { + break; + } + + std::string str_eval_clk = match.str(1); + trim(str_eval_clk); +#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT) + std::cout << "str_eval_clk=" << str_eval_clk << std::endl; +#endif + + // example: { [%eval 0.25] [%clk 0:10:00] } + // example: { [%eval #-4] [%clk 0:10:00] } + // example: { [%eval #3] [%clk 0:10:00] } + // example: { +0.71/22 1.2s } + // example: { -M4/7 0.003s } + // example: { M3/245 0.017s } + // example: { +M1/245 0.010s, White mates } + // example: { 0.60 } + // example: { book } + // example: { rnbqkb1r/pp3ppp/2p1pn2/3p4/2PP4/2N2N2/PP2PPPP/R1BQKB1R w KQkq - 0 5 } + + // Considering the absence of eval + if (!is_like_fen(str_eval_clk)) { + itr += match.position(0) + match.length(0) - 1; + + if (str_eval_clk != "book") { + std::regex pattern_eval1(R"(\[\%eval (.+?)\])"); + std::regex pattern_eval2(R"((.+?)\/)"); + + std::string str_eval; + if (std::regex_search(str_eval_clk, match, pattern_eval1) || + std::regex_search(str_eval_clk, match, pattern_eval2)) { + str_eval = match.str(1); + trim(str_eval); + } + else { + str_eval = str_eval_clk; + } + + bool success = false; + Value value = parse_score_from_pgn_extract(str_eval, success); + if (success) { + eval_found = true; + psv.score = std::clamp(value, -VALUE_MATE, VALUE_MATE); + } + +#if defined(DEBUG_CONVERT_BIN_FROM_PGN_EXTRACT) + std::cout << "str_eval=" << str_eval << std::endl; + std::cout << "success=" << success << ", psv.score=" << psv.score << std::endl; +#endif + } + } + } + + // write + if (eval_found || convert_no_eval_fens_as_score_zero) { + if (!eval_found && convert_no_eval_fens_as_score_zero) { + psv.score = 0; + } + + psv.gamePly = gamePly; + psv.game_result = game_result; + + if (pos.side_to_move() == BLACK) { + if (!pgn_eval_side_to_move) { + psv.score *= -1; + } + psv.game_result *= -1; + } + + ofs.write((char*)&psv, sizeof(PackedSfenValue)); + + fen_count++; + } + } + + game_result = 0; + } + } + } + + std::cout << now_string() << " game_count=" << game_count << ", fen_count=" << fen_count << std::endl; + std::cout << now_string() << " all done" << std::endl; + ofs.close(); + } + + void convert_plain( + const vector& filenames, + const string& output_file_name) + { + Position tpos; + std::ofstream ofs; + ofs.open(output_file_name, ios::app); + auto th = Threads.main(); + for (auto filename : filenames) { + std::cout << "convert " << filename << " ... "; + + // Just convert packedsfenvalue to text + std::fstream fs; + fs.open(filename, ios::in | ios::binary); + PackedSfenValue p; + while (true) + { + if (fs.read((char*)&p, sizeof(PackedSfenValue))) { + StateInfo si; + tpos.set_from_packed_sfen(p.sfen, &si, th); + + // write as plain text + ofs << "fen " << tpos.fen() << std::endl; + ofs << "move " << UCI::move(Move(p.move), false) << std::endl; + ofs << "score " << p.score << std::endl; + ofs << "ply " << int(p.gamePly) << std::endl; + ofs << "result " << int(p.game_result) << std::endl; + ofs << "e" << std::endl; + } + else { + break; + } + } + fs.close(); + std::cout << "done" << std::endl; + } + ofs.close(); + std::cout << "all done" << std::endl; + } + + static inline const std::string plain_extension = ".plain"; + static inline const std::string bin_extension = ".bin"; + static inline const std::string binpack_extension = ".binpack"; + + static bool file_exists(const std::string& name) + { + std::ifstream f(name); + return f.good(); + } + + static bool ends_with(const std::string& lhs, const std::string& end) + { + if (end.size() > lhs.size()) return false; + + return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); + } + + static bool is_convert_of_type( + const std::string& input_path, + const std::string& output_path, + const std::string& expected_input_extension, + const std::string& expected_output_extension) + { + return ends_with(input_path, expected_input_extension) + && ends_with(output_path, expected_output_extension); + } + + using ConvertFunctionType = void(std::string inputPath, std::string outputPath, std::ios_base::openmode om, bool validate); + + static ConvertFunctionType* get_convert_function(const std::string& input_path, const std::string& output_path) + { + if (is_convert_of_type(input_path, output_path, plain_extension, bin_extension)) + return binpack::convertPlainToBin; + if (is_convert_of_type(input_path, output_path, plain_extension, binpack_extension)) + return binpack::convertPlainToBinpack; + + if (is_convert_of_type(input_path, output_path, bin_extension, plain_extension)) + return binpack::convertBinToPlain; + if (is_convert_of_type(input_path, output_path, bin_extension, binpack_extension)) + return binpack::convertBinToBinpack; + + if (is_convert_of_type(input_path, output_path, binpack_extension, plain_extension)) + return binpack::convertBinpackToPlain; + if (is_convert_of_type(input_path, output_path, binpack_extension, bin_extension)) + return binpack::convertBinpackToBin; + + return nullptr; + } + + static void convert(const std::string& input_path, const std::string& output_path, std::ios_base::openmode om, bool validate) + { + if(!file_exists(input_path)) + { + std::cerr << "Input file does not exist.\n"; + return; + } + + auto func = get_convert_function(input_path, output_path); + if (func != nullptr) + { + func(input_path, output_path, om, validate); + } + else + { + std::cerr << "Conversion between files of these types is not supported.\n"; + } + } + + static void convert(const std::vector& args) + { + if (args.size() < 2 || args.size() > 4) + { + std::cerr << "Invalid arguments.\n"; + std::cerr << "Usage: convert from_path to_path [append] [validate]\n"; + return; + } + + const bool append = std::find(args.begin() + 2, args.end(), "append") != args.end(); + const bool validate = std::find(args.begin() + 2, args.end(), "validate") != args.end(); + + const std::ios_base::openmode openmode = + append + ? std::ios_base::app + : std::ios_base::trunc; + + convert(args[0], args[1], openmode, validate); + } + + void convert(istringstream& is) + { + std::vector args; + + while (true) + { + std::string token = ""; + is >> token; + if (token == "") + break; + + args.push_back(token); + } + + convert(args); + } + + static void append_files_from_dir( + std::vector& filenames, + const std::string& base_dir, + const std::string& target_dir) + { + string kif_base_dir = Path::combine(base_dir, target_dir); + + sys::path p(kif_base_dir); // Origin of enumeration + std::for_each(sys::directory_iterator(p), sys::directory_iterator(), + [&](const sys::path& path) { + if (sys::is_regular_file(path)) + filenames.push_back(Path::combine(target_dir, path.filename().generic_string())); + }); + } + + static void rebase_files( + std::vector& filenames, + const std::string& base_dir) + { + for (auto& file : filenames) + { + file = Path::combine(base_dir, file); + } + } + + void convert_bin_from_pgn_extract(std::istringstream& is) + { + std::vector filenames; + + string base_dir; + string target_dir; + + bool pgn_eval_side_to_move = false; + bool convert_no_eval_fens_as_score_zero = false; + + string output_file_name = "shuffled_sfen.bin"; + + while (true) + { + string option; + is >> option; + + if (option == "") + break; + + if (option == "targetdir") is >> target_dir; + else if (option == "targetfile") + { + std::string filename; + is >> filename; + filenames.push_back(filename); + } + + else if (option == "basedir") is >> base_dir; + + else if (option == "pgn_eval_side_to_move") is >> pgn_eval_side_to_move; + else if (option == "convert_no_eval_fens_as_score_zero") is >> convert_no_eval_fens_as_score_zero; + else if (option == "output_file_name") is >> output_file_name; + else + { + cout << "Unknown option: " << option << ". Ignoring.\n"; + } + } + + if (!target_dir.empty()) + { + append_files_from_dir(filenames, base_dir, target_dir); + } + rebase_files(filenames, base_dir); + + Eval::NNUE::init(); + + cout << "convert_bin_from_pgn-extract.." << endl; + convert_bin_from_pgn_extract( + filenames, + output_file_name, + pgn_eval_side_to_move, + convert_no_eval_fens_as_score_zero); + } + + void convert_bin(std::istringstream& is) + { + std::vector filenames; + + string base_dir; + string target_dir; + + int ply_minimum = 0; + int ply_maximum = 114514; + bool interpolate_eval = 0; + bool check_invalid_fen = false; + bool check_illegal_move = false; + + bool pgn_eval_side_to_move = false; + bool convert_no_eval_fens_as_score_zero = false; + + double src_score_min_value = 0.0; + double src_score_max_value = 1.0; + double dest_score_min_value = 0.0; + double dest_score_max_value = 1.0; + + string output_file_name = "shuffled_sfen.bin"; + + while (true) + { + string option; + is >> option; + + if (option == "") + break; + + if (option == "targetdir") is >> target_dir; + else if (option == "targetfile") + { + std::string filename; + is >> filename; + filenames.push_back(filename); + } + + else if (option == "basedir") is >> base_dir; + + else if (option == "ply_minimum") is >> ply_minimum; + else if (option == "ply_maximum") is >> ply_maximum; + else if (option == "interpolate_eval") is >> interpolate_eval; + else if (option == "check_invalid_fen") is >> check_invalid_fen; + else if (option == "check_illegal_move") is >> check_illegal_move; + else if (option == "pgn_eval_side_to_move") is >> pgn_eval_side_to_move; + else if (option == "convert_no_eval_fens_as_score_zero") is >> convert_no_eval_fens_as_score_zero; + else if (option == "src_score_min_value") is >> src_score_min_value; + else if (option == "src_score_max_value") is >> src_score_max_value; + else if (option == "dest_score_min_value") is >> dest_score_min_value; + else if (option == "dest_score_max_value") is >> dest_score_max_value; + else if (option == "output_file_name") is >> output_file_name; + else + { + cout << "Unknown option: " << option << ". Ignoring.\n"; + } + } + + if (!target_dir.empty()) + { + append_files_from_dir(filenames, base_dir, target_dir); + } + rebase_files(filenames, base_dir); + + Eval::NNUE::init(); + + cout << "convert_bin.." << endl; + convert_bin( + filenames, + output_file_name, + ply_minimum, + ply_maximum, + interpolate_eval, + src_score_min_value, + src_score_max_value, + dest_score_min_value, + dest_score_max_value, + check_invalid_fen, + check_illegal_move + ); + } + + void convert_plain(std::istringstream& is) + { + std::vector filenames; + + string base_dir; + string target_dir; + + string output_file_name = "shuffled_sfen.bin"; + + while (true) + { + string option; + is >> option; + + if (option == "") + break; + + if (option == "targetdir") is >> target_dir; + else if (option == "targetfile") + { + std::string filename; + is >> filename; + filenames.push_back(filename); + } + + else if (option == "basedir") is >> base_dir; + + else if (option == "output_file_name") is >> output_file_name; + else + { + cout << "Unknown option: " << option << ". Ignoring.\n"; + } + } + + if (!target_dir.empty()) + { + append_files_from_dir(filenames, base_dir, target_dir); + } + rebase_files(filenames, base_dir); + + Eval::NNUE::init(); + + cout << "convert_plain.." << endl; + convert_plain(filenames, output_file_name); + } +} diff --git a/src/tools/convert.h b/src/tools/convert.h new file mode 100644 index 00000000..0bd9ee23 --- /dev/null +++ b/src/tools/convert.h @@ -0,0 +1,18 @@ +#ifndef _CONVERT_H_ +#define _CONVERT_H_ + +#include +#include +#include + +namespace Stockfish::Tools { + void convert(std::istringstream& is); + + void convert_bin_from_pgn_extract(std::istringstream& is); + + void convert_bin(std::istringstream& is); + + void convert_plain(std::istringstream& is); +} + +#endif diff --git a/src/tools/opening_book.cpp b/src/tools/opening_book.cpp new file mode 100644 index 00000000..63ff7ed1 --- /dev/null +++ b/src/tools/opening_book.cpp @@ -0,0 +1,43 @@ +#include "opening_book.h" + +#include + +namespace Stockfish::Tools { + + EpdOpeningBook::EpdOpeningBook(const std::string& file, PRNG& prng) : + OpeningBook(file) + { + std::ifstream in(file); + if (!in) + { + return; + } + + std::string line; + while (std::getline(in, line)) + { + if (line.empty()) + continue; + + fens.emplace_back(line); + } + + Algo::shuffle(fens, prng); + } + + static bool ends_with(const std::string& lhs, const std::string& end) + { + if (end.size() > lhs.size()) return false; + + return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); + } + + std::unique_ptr open_opening_book(const std::string& filename, PRNG& prng) + { + if (ends_with(filename, ".epd")) + return std::make_unique(filename, prng); + + return nullptr; + } + +} diff --git a/src/tools/opening_book.h b/src/tools/opening_book.h new file mode 100644 index 00000000..8323ed10 --- /dev/null +++ b/src/tools/opening_book.h @@ -0,0 +1,60 @@ +#ifndef LEARN_OPENING_BOOK_H +#define LEARN_OPENING_BOOK_H + +#include "misc.h" +#include "position.h" +#include "thread.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace Stockfish::Tools { + + struct OpeningBook { + + const std::string& next_fen() + { + assert(fens.size() > 0); + + std::unique_lock lock(mutex); + + auto& fen = fens[current_index++]; + if (current_index >= fens.size()) + current_index = 0; + + return fen; + } + + std::size_t size() const { return fens.size(); } + + const std::string& get_filename() const { return filename; } + + protected: + OpeningBook(const std::string& file) : + filename(file), + current_index(0) + { + } + + + std::mutex mutex; + std::string filename; + std::vector fens; + std::size_t current_index; + }; + + struct EpdOpeningBook : OpeningBook { + + EpdOpeningBook(const std::string& file, PRNG& prng); + }; + + std::unique_ptr open_opening_book(const std::string& filename, PRNG& prng); + +} + +#endif diff --git a/src/tools/packed_sfen.h b/src/tools/packed_sfen.h new file mode 100644 index 00000000..969405ea --- /dev/null +++ b/src/tools/packed_sfen.h @@ -0,0 +1,46 @@ +#ifndef _PACKED_SFEN_H_ +#define _PACKED_SFEN_H_ + +#include +#include + +namespace Stockfish::Tools { + + // packed sfen + struct PackedSfen { std::uint8_t data[32]; }; + + // Structure in which PackedSfen and evaluation value are integrated + // If you write different contents for each option, it will be a problem when reusing the teacher game + // For the time being, write all the following members regardless of the options. + struct PackedSfenValue + { + // phase + PackedSfen sfen; + + // Evaluation value returned from Tools::search() + std::int16_t score; + + // PV first move + // Used when finding the match rate with the teacher + std::uint16_t move; + + // Trouble of the phase from the initial phase. + std::uint16_t gamePly; + + // 1 if the player on this side ultimately wins the game. -1 if you are losing. + // 0 if a draw is reached. + // The draw is in the teacher position generation command gensfen, + // Only write if LEARN_GENSFEN_DRAW_RESULT is enabled. + std::int8_t game_result; + + // When exchanging the file that wrote the teacher aspect with other people + //Because this structure size is not fixed, pad it so that it is 40 bytes in any environment. + std::uint8_t padding; + + // 32 + 2 + 2 + 2 + 1 + 1 = 40bytes + }; + + // Phase array: PSVector stands for packed sfen vector. + using PSVector = std::vector; +} +#endif diff --git a/src/tools/sfen_packer.cpp b/src/tools/sfen_packer.cpp new file mode 100644 index 00000000..8182503c --- /dev/null +++ b/src/tools/sfen_packer.cpp @@ -0,0 +1,384 @@ +#include "sfen_packer.h" + +#include "packed_sfen.h" + +#include "misc.h" +#include "position.h" + +#include +#include +#include // std::memset() + +using namespace std; + +namespace Stockfish::Tools { + + // Class that handles bitstream + // useful when doing aspect encoding + struct BitStream + { + // Set the memory to store the data in advance. + // Assume that memory is cleared to 0. + void set_data(std::uint8_t* data_) { data = data_; reset(); } + + // Get the pointer passed in set_data(). + uint8_t* get_data() const { return data; } + + // Get the cursor. + int get_cursor() const { return bit_cursor; } + + // reset the cursor + void reset() { bit_cursor = 0; } + + // Write 1bit to the stream. + // If b is non-zero, write out 1. If 0, write 0. + void write_one_bit(int b) + { + if (b) + data[bit_cursor / 8] |= 1 << (bit_cursor & 7); + + ++bit_cursor; + } + + // Get 1 bit from the stream. + int read_one_bit() + { + int b = (data[bit_cursor / 8] >> (bit_cursor & 7)) & 1; + ++bit_cursor; + + return b; + } + + // write n bits of data + // Data shall be written out from the lower order of d. + void write_n_bit(int d, int n) + { + for (int i = 0; i = RANK_1; --r) + { + for (File f = FILE_A; f <= FILE_H; ++f) + { + Piece pc = pos.piece_on(make_square(f, r)); + if (type_of(pc) == KING) + continue; + write_board_piece_to_stream(pc); + } + } + + // TODO(someone): Support chess960. + stream.write_one_bit(pos.can_castle(WHITE_OO)); + stream.write_one_bit(pos.can_castle(WHITE_OOO)); + stream.write_one_bit(pos.can_castle(BLACK_OO)); + stream.write_one_bit(pos.can_castle(BLACK_OOO)); + + if (pos.ep_square() == SQ_NONE) { + stream.write_one_bit(0); + } + else { + stream.write_one_bit(1); + stream.write_n_bit(static_cast(pos.ep_square()), 6); + } + + stream.write_n_bit(pos.state()->rule50, 6); + + const int fm = 1 + (pos.game_ply()-(pos.side_to_move() == BLACK)) / 2; + stream.write_n_bit(fm, 8); + + // Write high bits of half move. This is a fix for the + // limited range of half move counter. + // This is backwards compatibile. + stream.write_n_bit(fm >> 8, 8); + + // Write the highest bit of rule50 at the end. This is a backwards + // compatibile fix for rule50 having only 6 bits stored. + // This bit is just ignored by the old parsers. + stream.write_n_bit(pos.state()->rule50 >> 6, 1); + + assert(stream.get_cursor() <= 256); + } + + // Output the board pieces to stream. + void SfenPacker::write_board_piece_to_stream(Piece pc) + { + // piece type + PieceType pr = type_of(pc); + auto c = huffman_table[pr]; + stream.write_n_bit(c.code, c.bits); + + if (pc == NO_PIECE) + return; + + // first and second flag + stream.write_one_bit(color_of(pc)); + } + + // Read one board piece from stream + Piece SfenPacker::read_board_piece_from_stream() + { + PieceType pr = NO_PIECE_TYPE; + int code = 0, bits = 0; + while (true) + { + code |= stream.read_one_bit() << bits; + ++bits; + + assert(bits <= 6); + + for (pr = NO_PIECE_TYPE; pr (reinterpret_cast(&sfen))); + + pos.clear(); + std::memset(si, 0, sizeof(StateInfo)); + si->accumulator.state[WHITE] = Eval::NNUE::INIT; + si->accumulator.state[BLACK] = Eval::NNUE::INIT; + pos.st = si; + + // Active color + pos.sideToMove = (Color)stream.read_one_bit(); + + // First the position of the ball + for (auto c : Colors) + pos.board[stream.read_n_bit(6)] = make_piece(c, KING); + + // Piece placement + for (Rank r = RANK_8; r >= RANK_1; --r) + { + for (File f = FILE_A; f <= FILE_H; ++f) + { + auto sq = make_square(f, r); + + // it seems there are already balls + Piece pc; + if (type_of(pos.board[sq]) != KING) + { + assert(pos.board[sq] == NO_PIECE); + pc = packer.read_board_piece_from_stream(); + } + else + { + pc = pos.board[sq]; + // put_piece() will catch ASSERT unless you remove it all. + pos.board[sq] = NO_PIECE; + } + + // There may be no pieces, so skip in that case. + if (pc == NO_PIECE) + continue; + + pos.put_piece(Piece(pc), sq); + + if (stream.get_cursor()> 256) + return 1; + } + } + + // Castling availability. + // TODO(someone): Support chess960. + pos.st->castlingRights = 0; + if (stream.read_one_bit()) { + Square rsq; + for (rsq = relative_square(WHITE, SQ_H1); pos.piece_on(rsq) != W_ROOK; --rsq) {} + pos.set_castling_right(WHITE, rsq); + } + if (stream.read_one_bit()) { + Square rsq; + for (rsq = relative_square(WHITE, SQ_A1); pos.piece_on(rsq) != W_ROOK; ++rsq) {} + pos.set_castling_right(WHITE, rsq); + } + if (stream.read_one_bit()) { + Square rsq; + for (rsq = relative_square(BLACK, SQ_H1); pos.piece_on(rsq) != B_ROOK; --rsq) {} + pos.set_castling_right(BLACK, rsq); + } + if (stream.read_one_bit()) { + Square rsq; + for (rsq = relative_square(BLACK, SQ_A1); pos.piece_on(rsq) != B_ROOK; ++rsq) {} + pos.set_castling_right(BLACK, rsq); + } + + // En passant square. Ignore if no pawn capture is possible + if (stream.read_one_bit()) { + Square ep_square = static_cast(stream.read_n_bit(6)); + pos.st->epSquare = ep_square; + + if (!(pos.attackers_to(pos.st->epSquare) & pos.pieces(pos.sideToMove, PAWN)) + || !(pos.pieces(~pos.sideToMove, PAWN) & (pos.st->epSquare + pawn_push(~pos.sideToMove)))) + pos.st->epSquare = SQ_NONE; + } + else { + pos.st->epSquare = SQ_NONE; + } + + // Halfmove clock + pos.st->rule50 = stream.read_n_bit(6); + + // Fullmove number + pos.gamePly = stream.read_n_bit(8); + + // Read the highest bit of rule50. This was added as a fix for rule50 + // counter having only 6 bits stored. + // In older entries this will just be a zero bit. + pos.gamePly |= stream.read_n_bit(8) << 8; + + // Read the highest bit of rule50. This was added as a fix for rule50 + // counter having only 6 bits stored. + // In older entries this will just be a zero bit. + pos.st->rule50 |= stream.read_n_bit(1) << 6; + + // Convert from fullmove starting from 1 to gamePly starting from 0, + // handle also common incorrect FEN with fullmove = 0. + pos.gamePly = std::max(2 * (pos.gamePly - 1), 0) + (pos.sideToMove == BLACK); + + assert(stream.get_cursor() <= 256); + + pos.chess960 = false; + pos.thisThread = th; + pos.set_state(pos.st); + + assert(pos.pos_is_ok()); + + return 0; + } + + PackedSfen sfen_pack(Position& pos) + { + PackedSfen sfen; + + SfenPacker sp; + sp.data = (uint8_t*)&sfen; + sp.pack(pos); + + return sfen; + } +} diff --git a/src/tools/sfen_packer.h b/src/tools/sfen_packer.h new file mode 100644 index 00000000..42093043 --- /dev/null +++ b/src/tools/sfen_packer.h @@ -0,0 +1,22 @@ +#ifndef _SFEN_PACKER_H_ +#define _SFEN_PACKER_H_ + +#include "types.h" + +#include "packed_sfen.h" + +#include + +namespace Stockfish { + class Position; + struct StateInfo; + class Thread; +} + +namespace Stockfish::Tools { + + int set_from_packed_sfen(Position& pos, const PackedSfen& sfen, StateInfo* si, Thread* th); + PackedSfen sfen_pack(Position& pos); +} + +#endif \ No newline at end of file diff --git a/src/tools/sfen_reader.h b/src/tools/sfen_reader.h new file mode 100644 index 00000000..28a6104d --- /dev/null +++ b/src/tools/sfen_reader.h @@ -0,0 +1,352 @@ +#include "sfen_stream.h" + +#include "packed_sfen.h" + +#include "misc.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Stockfish::Tools{ + + enum struct SfenReaderMode + { + Sequential, + Cyclic + }; + + // Sfen reader + struct SfenReader + { + // Number of phases buffered by each thread 0.1M phases. 4M phase at 40HT + static constexpr size_t DEFAULT_THREAD_BUFFER_SIZE = 10 * 1000; + + // Buffer for reading files (If this is made larger, + // the shuffle becomes larger and the phases may vary. + // If it is too large, the memory consumption will increase. + // SFEN_READ_SIZE is a multiple of THREAD_BUFFER_SIZE. + static constexpr const size_t DEFAULT_SFEN_READ_SIZE = 1000 * 1000 * 10; + + // Do not use std::random_device(). + // Because it always the same integers on MinGW. + SfenReader( + const std::vector& filenames_, + bool do_shuffle, + SfenReaderMode mode_, + int thread_num, + const std::string& seed, + size_t read_size = DEFAULT_SFEN_READ_SIZE, + size_t buffer_size = DEFAULT_THREAD_BUFFER_SIZE + ) : + filenames(filenames_.begin(), filenames_.end()), + mode(mode_), + // Due to the implementation of waiting for buffer empty a bit + // the read size must be at least twice the buffer size. + sfen_read_size(std::max(read_size, buffer_size * 2)), + thread_buffer_size(buffer_size), + prng(seed) + { + packed_sfens.resize(thread_num); + total_read = 0; + end_of_files = false; + shuffle = do_shuffle; + stop_flag = false; + num_buffers_in_pool.store(0); + + file_worker_thread = std::thread([&] { + this->file_read_worker(); + }); + } + + ~SfenReader() + { + stop_flag = true; + + if (file_worker_thread.joinable()) + file_worker_thread.join(); + } + + // Load the phase for calculation such as mse. + PSVector read_some(uint64_t count, uint64_t count_tries, std::function do_take) + { + PSVector psv; + psv.reserve(count); + + for (uint64_t i = 0; i < count_tries; ++i) + { + PackedSfenValue ps; + if (!read_to_thread_buffer(0, ps)) + { + std::cout << "ERROR (sfen_reader): Reading failed." << std::endl; + return psv; + } + + if (do_take(ps)) + { + psv.push_back(ps); + + if (psv.size() >= count) + break; + } + } + + return psv; + } + + // [ASYNC] Thread returns one aspect. Otherwise returns false. + bool read_to_thread_buffer(size_t thread_id, PackedSfenValue& ps) + { + // If there are any positions left in the thread buffer + // then retrieve one and return it. + auto& thread_ps = packed_sfens[thread_id]; + + // Fill the read buffer if there is no remaining buffer, + // but if it doesn't even exist, finish. + // If the buffer is empty, fill it. + if ((thread_ps == nullptr || thread_ps->empty()) + && !read_to_thread_buffer_impl(thread_id)) + return false; + + // read_to_thread_buffer_impl() returned true, + // Since the filling of the thread buffer with the + // phase has been completed successfully + // thread_ps->rbegin() is alive. + + ps = thread_ps->back(); + thread_ps->pop_back(); + + // If you've run out of buffers, call delete yourself to free this buffer. + if (thread_ps->empty()) + { + thread_ps.reset(); + } + + return true; + } + + // [ASYNC] Read some aspects into thread buffer. + bool read_to_thread_buffer_impl(size_t thread_id) + { + while (true) + { + { + std::unique_lock lk(mutex); + // If you can fill from the file buffer, that's fine. + if (packed_sfens_pool.size() != 0) + { + // It seems that filling is possible, so fill and finish. + + packed_sfens[thread_id] = std::move(packed_sfens_pool.front()); + packed_sfens_pool.pop_front(); + num_buffers_in_pool.fetch_sub(1); + + total_read += thread_buffer_size; + + return true; + } + } + + // The file to read is already gone. No more use. + if (end_of_files) + return false; + + // Waiting for file worker to fill packed_sfens_pool. + // The mutex isn't locked, so it should fill up soon. + // Poor man's condition variable. + sleep(1); + } + + } + + void file_read_worker() + { + std::string currentFilename; + uint64_t numEntriesReadFromCurrentFile = 0; + + auto open_next_file = [&]() { + // no more + for(;;) + { + sfen_input_stream.reset(); + + if (filenames.empty()) + return false; + + // Get the next file name. + currentFilename = filenames.front(); + filenames.pop_front(); + + numEntriesReadFromCurrentFile = 0; + + sfen_input_stream = open_sfen_input_file(currentFilename); + + auto out = sync_region_cout.new_region(); + if (sfen_input_stream == nullptr) + { + out << "INFO (sfen_reader): File does not exist: " << currentFilename << '\n'; + } + else + { + out << "INFO (sfen_reader): Opened file for reading: " << currentFilename << '\n'; + + // in case the file is empty or was deleted. + if (sfen_input_stream->eof()) + { + out << " - File empty, nothing to read.\n"; + } + else + { + return true; + } + } + } + }; + + if (sfen_input_stream == nullptr && !open_next_file()) + { + auto out = sync_region_cout.new_region(); + out << "INFO (sfen_reader): End of files." << std::endl; + end_of_files = true; + return; + } + + // We want to set the `end_of_files` only after we read everything AND copy to the buffer pool. + bool local_end_of_files = false; + while (!local_end_of_files) + { + // Wait for the buffer to run out. + // This size() is read only, so you don't need to lock it. + while (!stop_flag && num_buffers_in_pool.load() >= sfen_read_size / thread_buffer_size) + sleep(100); + + if (stop_flag) + return; + + PSVector sfens; + sfens.reserve(sfen_read_size); + + // Read from the file into the file buffer. + while (sfens.size() < sfen_read_size) + { + std::optional p = sfen_input_stream->next(); + if (p.has_value()) + { + sfens.push_back(*p); + ++numEntriesReadFromCurrentFile; + } + else + { + if (mode == SfenReaderMode::Cyclic + && numEntriesReadFromCurrentFile > 0) + { + // The file contained data so we add it again to the end of the queue. + filenames.emplace_back(currentFilename); + } + + if(!open_next_file()) + { + // There was no next file. Abort. + auto out = sync_region_cout.new_region(); + out << "INFO (sfen_reader): End of files." << std::endl; + local_end_of_files = true; + break; + } + } + } + + // Shuffle the read phase data. + if (shuffle) + { + Algo::shuffle(sfens, prng); + } + + std::vector> buffers; + for (size_t offset = 0; offset < sfens.size(); offset += thread_buffer_size) + { + const size_t count = + offset + thread_buffer_size > sfens.size() + ? sfens.size() - offset + : thread_buffer_size; + + // Delete this pointer on the receiving side. + auto buf = std::make_unique(); + buf->resize(count); + memcpy( + buf->data(), + &sfens[offset], + sizeof(PackedSfenValue) * count); + + buffers.emplace_back(std::move(buf)); + } + + { + std::unique_lock lk(mutex); + + // The mutex lock is required because the% + // contents of packed_sfens_pool are changed. + + for (auto& buf : buffers) + { + num_buffers_in_pool.fetch_add(1); + packed_sfens_pool.emplace_back(std::move(buf)); + } + } + } + + end_of_files = true; + } + + protected: + + // worker thread reading file in background + std::thread file_worker_thread; + + // sfen files + std::deque filenames; + + std::atomic stop_flag; + + // number of phases read (file to memory buffer) + std::atomic total_read; + + // Do not shuffle when reading the phase. + bool shuffle; + + SfenReaderMode mode; + + size_t sfen_read_size; + size_t thread_buffer_size; + + // Random number to shuffle when reading the phase + PRNG prng; + + // Did you read the files and reached the end? + std::atomic end_of_files; + + // handle of sfen file + std::unique_ptr sfen_input_stream; + + // sfen for each thread + // (When the thread is used up, the thread should call delete to release it.) + std::vector> packed_sfens; + + // Mutex when accessing packed_sfens_pool + std::mutex mutex; + + // pool of sfen. The worker thread read from the file is added here. + // Each worker thread fills its own packed_sfens[thread_id] from here. + // * Lock and access the mutex. + std::list> packed_sfens_pool; + std::atomic num_buffers_in_pool; + }; +} diff --git a/src/tools/sfen_stream.h b/src/tools/sfen_stream.h new file mode 100644 index 00000000..6122e091 --- /dev/null +++ b/src/tools/sfen_stream.h @@ -0,0 +1,222 @@ +#ifndef _SFEN_STREAM_H_ +#define _SFEN_STREAM_H_ + +#include "packed_sfen.h" + +#include "extra/nnue_data_binpack_format.h" + +#include +#include +#include +#include + +namespace Stockfish::Tools { + + enum struct SfenOutputType + { + Bin, + Binpack + }; + + static bool ends_with(const std::string& lhs, const std::string& end) + { + if (end.size() > lhs.size()) return false; + + return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); + } + + static bool has_extension(const std::string& filename, const std::string& extension) + { + return ends_with(filename, "." + extension); + } + + static std::string filename_with_extension(const std::string& filename, const std::string& ext) + { + if (ends_with(filename, ext)) + { + return filename; + } + else + { + return filename + "." + ext; + } + } + + struct BasicSfenInputStream + { + virtual std::optional next() = 0; + virtual bool eof() const = 0; + virtual ~BasicSfenInputStream() {} + }; + + struct BinSfenInputStream : BasicSfenInputStream + { + static constexpr auto openmode = std::ios::in | std::ios::binary; + static inline const std::string extension = "bin"; + + BinSfenInputStream(std::string filename) : + m_stream(filename, openmode), + m_eof(!m_stream) + { + } + + std::optional next() override + { + PackedSfenValue e; + if(m_stream.read(reinterpret_cast(&e), sizeof(PackedSfenValue))) + { + return e; + } + else + { + m_eof = true; + return std::nullopt; + } + } + + bool eof() const override + { + return m_eof; + } + + ~BinSfenInputStream() override {} + + private: + std::fstream m_stream; + bool m_eof; + }; + + struct BinpackSfenInputStream : BasicSfenInputStream + { + static constexpr auto openmode = std::ios::in | std::ios::binary; + static inline const std::string extension = "binpack"; + + BinpackSfenInputStream(std::string filename) : + m_stream(filename, openmode), + m_eof(!m_stream.hasNext()) + { + } + + std::optional next() override + { + static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue)); + + if (!m_stream.hasNext()) + { + m_eof = true; + return std::nullopt; + } + + auto training_data_entry = m_stream.next(); + auto v = binpack::trainingDataEntryToPackedSfenValue(training_data_entry); + PackedSfenValue psv; + // same layout, different types. One is from generic library. + std::memcpy(&psv, &v, sizeof(PackedSfenValue)); + + return psv; + } + + bool eof() const override + { + return m_eof; + } + + ~BinpackSfenInputStream() override {} + + private: + binpack::CompressedTrainingDataEntryReader m_stream; + bool m_eof; + }; + + struct BasicSfenOutputStream + { + virtual void write(const PSVector& sfens) = 0; + virtual ~BasicSfenOutputStream() {} + }; + + struct BinSfenOutputStream : BasicSfenOutputStream + { + static constexpr auto openmode = std::ios::out | std::ios::binary | std::ios::app; + static inline const std::string extension = "bin"; + + BinSfenOutputStream(std::string filename) : + m_stream(filename_with_extension(filename, extension), openmode) + { + } + + void write(const PSVector& sfens) override + { + m_stream.write(reinterpret_cast(sfens.data()), sizeof(PackedSfenValue) * sfens.size()); + } + + ~BinSfenOutputStream() override {} + + private: + std::fstream m_stream; + }; + + struct BinpackSfenOutputStream : BasicSfenOutputStream + { + static constexpr auto openmode = std::ios::out | std::ios::binary | std::ios::app; + static inline const std::string extension = "binpack"; + + BinpackSfenOutputStream(std::string filename) : + m_stream(filename_with_extension(filename, extension), openmode) + { + } + + void write(const PSVector& sfens) override + { + static_assert(sizeof(binpack::nodchip::PackedSfenValue) == sizeof(PackedSfenValue)); + + for(auto& sfen : sfens) + { + // The library uses a type that's different but layout-compatibile. + binpack::nodchip::PackedSfenValue e; + std::memcpy(&e, &sfen, sizeof(binpack::nodchip::PackedSfenValue)); + m_stream.addTrainingDataEntry(binpack::packedSfenValueToTrainingDataEntry(e)); + } + } + + ~BinpackSfenOutputStream() override {} + + private: + binpack::CompressedTrainingDataEntryWriter m_stream; + }; + + inline std::unique_ptr open_sfen_input_file(const std::string& filename) + { + if (has_extension(filename, BinSfenInputStream::extension)) + return std::make_unique(filename); + else if (has_extension(filename, BinpackSfenInputStream::extension)) + return std::make_unique(filename); + + return nullptr; + } + + inline std::unique_ptr create_new_sfen_output(const std::string& filename, SfenOutputType sfen_output_type) + { + switch(sfen_output_type) + { + case SfenOutputType::Bin: + return std::make_unique(filename); + case SfenOutputType::Binpack: + return std::make_unique(filename); + } + + assert(false); + return nullptr; + } + + inline std::unique_ptr create_new_sfen_output(const std::string& filename) + { + if (has_extension(filename, BinSfenOutputStream::extension)) + return std::make_unique(filename); + else if (has_extension(filename, BinpackSfenOutputStream::extension)) + return std::make_unique(filename); + + return nullptr; + } +} + +#endif \ No newline at end of file diff --git a/src/tools/sfen_writer.h b/src/tools/sfen_writer.h new file mode 100644 index 00000000..10dc98d6 --- /dev/null +++ b/src/tools/sfen_writer.h @@ -0,0 +1,203 @@ +#include "packed_sfen.h" +#include "sfen_stream.h" + +#include "misc.h" + +#include "extra/nnue_data_binpack_format.h" + +#include "syzygy/tbprobe.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Stockfish::Tools { + + // Helper class for exporting Sfen + struct SfenWriter + { + // Amount of sfens required to flush the buffer. + static constexpr size_t SFEN_WRITE_SIZE = 5000; + + // File name to write and number of threads to create + SfenWriter(std::string filename_, int thread_num, uint64_t save_count, SfenOutputType sfen_output_type) + { + sfen_buffers_pool.reserve((size_t)thread_num * 10); + sfen_buffers.resize(thread_num); + + auto out = sync_region_cout.new_region(); + out << "INFO (sfen_writer): Creating new data file at " << filename_ << std::endl; + + sfen_format = sfen_output_type; + output_file_stream = create_new_sfen_output(filename_, sfen_format); + filename = filename_; + save_every = save_count; + + finished = false; + + file_worker_thread = std::thread([&] { this->file_write_worker(); }); + } + + ~SfenWriter() + { + flush(); + + finished = true; + file_worker_thread.join(); + output_file_stream.reset(); + +#if !defined(NDEBUG) + { + // All buffers should be empty since file_worker_thread + // should have written everything before exiting. + for (const auto& p : sfen_buffers) { assert(p == nullptr); (void)p ; } + assert(sfen_buffers_pool.empty()); + } +#endif + } + + void write(size_t thread_id, const PackedSfenValue& psv) + { + // We have a buffer for each thread and add it there. + // If the buffer overflows, write it to a file. + + // This buffer is prepared for each thread. + auto& buf = sfen_buffers[thread_id]; + + // Secure since there is no buf at the first time + // and immediately after writing the thread buffer. + if (!buf) + { + buf = std::make_unique(); + buf->reserve(SFEN_WRITE_SIZE); + } + + // Buffer is exclusive to this thread. + // There is no need for a critical section. + buf->push_back(psv); + + if (buf->size() >= SFEN_WRITE_SIZE) + { + // If you load it in sfen_buffers_pool, the worker will do the rest. + + // Critical section since sfen_buffers_pool is shared among threads. + std::unique_lock lk(mutex); + sfen_buffers_pool.emplace_back(std::move(buf)); + } + } + + void flush() + { + for (size_t i = 0; i < sfen_buffers.size(); ++i) + { + flush(i); + } + } + + // Move what remains in the buffer for your thread to a buffer for writing to a file. + void flush(size_t thread_id) + { + std::unique_lock lk(mutex); + + auto& buf = sfen_buffers[thread_id]; + + // There is a case that buf==nullptr, so that check is necessary. + if (buf && buf->size() != 0) + { + sfen_buffers_pool.emplace_back(std::move(buf)); + } + } + + // Dedicated thread to write to file + void file_write_worker() + { + while (!finished || sfen_buffers_pool.size()) + { + std::vector> buffers; + { + std::unique_lock lk(mutex); + + // Atomically swap take the filled buffers and + // create a new buffer pool for threads to fill. + buffers = std::move(sfen_buffers_pool); + sfen_buffers_pool = std::vector>(); + } + + if (!buffers.size()) + { + // Poor man's condition variable. + sleep(100); + } + else + { + for (auto& buf : buffers) + { + output_file_stream->write(*buf); + + sfen_write_count += buf->size(); + + // Add the processed number here, and if it exceeds save_every, + // change the file name and reset this counter. + sfen_write_count_current_file += buf->size(); + if (sfen_write_count_current_file >= save_every) + { + sfen_write_count_current_file = 0; + + // Sequential number attached to the file + int n = (int)(sfen_write_count / save_every); + + // Rename the file and open it again. + // Add ios::app in consideration of overwriting. + // (Depending on the operation, it may not be necessary.) + std::string new_filename = filename + "_" + std::to_string(n); + output_file_stream = create_new_sfen_output(new_filename, sfen_format); + + auto out = sync_region_cout.new_region(); + out << "INFO (sfen_writer): Creating new data file at " << new_filename << std::endl; + } + } + } + } + } + + private: + + std::unique_ptr output_file_stream; + + // A new net is saved after every save_every sfens are processed. + uint64_t save_every = std::numeric_limits::max(); + + // File name passed in the constructor + std::string filename; + + // Thread to write to the file + std::thread file_worker_thread; + + // Flag that all threads have finished + std::atomic finished; + + SfenOutputType sfen_format; + + // buffer before writing to file + // sfen_buffers is the buffer for each thread + // sfen_buffers_pool is a buffer for writing. + // After loading the phase in the former buffer by SFEN_WRITE_SIZE, + // transfer it to the latter. + std::vector> sfen_buffers; + std::vector> sfen_buffers_pool; + + // Mutex required to access sfen_buffers_pool + std::mutex mutex; + + // Number of sfens written in total, and the + // number of sfens written in the current file. + uint64_t sfen_write_count = 0; + uint64_t sfen_write_count_current_file = 0; + }; +} diff --git a/src/tools/stats.cpp b/src/tools/stats.cpp new file mode 100644 index 00000000..6786abd8 --- /dev/null +++ b/src/tools/stats.cpp @@ -0,0 +1,1277 @@ +#include "stats.h" + +#include "sfen_stream.h" +#include "packed_sfen.h" +#include "sfen_writer.h" + +#include "thread.h" +#include "position.h" +#include "evaluate.h" +#include "search.h" + +#include "nnue/evaluate_nnue.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Stockfish::Tools::Stats +{ + struct Indentation + { + char character = ' '; + int width_per_indent = 4; + int num_indents = 0; + + [[nodiscard]] Indentation next() const + { + return Indentation{ character, width_per_indent, num_indents + 1 }; + } + + [[nodiscard]] std::string to_string() const + { + return std::string(num_indents * width_per_indent, character); + } + }; + + template + [[nodiscard]] int get_num_base_10_digits(IntT v) + { + int digits = 1; + while (v != 0) + { + digits += 1; + v /= 10; + } + return digits; + } + + [[nodiscard]] std::string left_pad_to_length(const std::string& str, char ch, int length) + { + const int str_size = static_cast(str.size()); + if (str_size < length) + { + return std::string(length - str_size, ch) + str; + } + else + { + return str; + } + } + + [[nodiscard]] std::string right_pad_to_length(const std::string& str, char ch, int length) + { + const int str_size = static_cast(str.size()); + if (str_size < length) + { + return str + std::string(length - str_size, ch); + } + else + { + return str; + } + } + + [[nodiscard]] std::string indent_text(const std::string& text, Indentation indent) + { + std::string delimiter = "\n"; + std::string indent_str = indent.to_string(); + + std::string indented; + + std::string::size_type pos = 0; + std::string::size_type prev = 0; + while ((pos = text.find(delimiter, prev)) != std::string::npos) + { + std::string line = text.substr(prev, pos - prev); + indented += indent_str + line + delimiter; + prev = pos + delimiter.size(); + } + + { + std::string line = text.substr(prev); + indented += indent_str + line; + } + + return indented; + } + + struct IndentedTextBlock + { + Indentation indentation; + std::string text; + + IndentedTextBlock(Indentation indent, std::string str) : + indentation(indent), + text(std::move(str)) + { + } + + [[nodiscard]] static std::string join(const std::vector& blocks, const std::string& delimiter) + { + std::string result; + + bool is_first = true; + for (auto&& [indentation, text] : blocks) + { + if (!is_first) + { + result += delimiter; + } + + result += indent_text(text, indentation); + + is_first = false; + } + + return result; + } + }; + + struct StatisticOutputEntryNode + { + [[nodiscard]] const std::vector>& get_children() const + { + return m_children; + } + + template + StatisticOutputEntryNode& emplace_child(Ts&&... args) + { + return *(m_children.emplace_back(std::make_unique(std::forward(args)...))); + } + + template + StatisticOutputEntryNode& add_child(std::unique_ptr&& node) + { + return *(m_children.emplace_back(std::move(node))); + } + + [[nodiscard]] virtual std::vector to_indented_text_blocks(Indentation indent) const = 0; + + protected: + std::vector> m_children; + + void add_indented_children_blocks(std::vector& blocks, Indentation indent) const + { + for (auto&& child : m_children) + { + auto part = child->to_indented_text_blocks(indent.next()); + blocks.insert(blocks.end(), part.begin(), part.end()); + } + } + }; + + struct StatisticOutputEntryHeader : StatisticOutputEntryNode + { + StatisticOutputEntryHeader(const std::string& text) : + m_text(text) + { + } + + [[nodiscard]] virtual std::vector to_indented_text_blocks(Indentation indent) const override + { + std::vector blocks; + + blocks.emplace_back(indent, m_text); + + this->add_indented_children_blocks(blocks, indent); + + return blocks; + } + + private: + std::string m_text; + }; + + template + struct StatisticOutputEntryValue : StatisticOutputEntryNode + { + StatisticOutputEntryValue(const std::string& name, const T& value, bool value_in_new_line = false) : + m_value(name, value), + m_value_in_new_line(value_in_new_line) + { + } + + [[nodiscard]] virtual std::vector to_indented_text_blocks(Indentation indent) const override + { + std::vector blocks; + + std::string value_str; + if constexpr (std::is_same_v) + { + value_str = m_value.second; + } + else + { + value_str = std::to_string(m_value.second); + } + + if (m_value_in_new_line) + { + blocks.emplace_back(indent, m_value.first + ": "); + blocks.emplace_back(indent.next(), value_str); + } + else + { + blocks.emplace_back(indent, m_value.first + ": " + value_str); + } + + this->add_indented_children_blocks(blocks, indent); + + return blocks; + } + + private: + std::pair m_value; + bool m_value_in_new_line; + }; + + struct StatisticOutput + { + template + StatisticOutputEntryNode& emplace_node(Ts&&... args) + { + return *(m_nodes.emplace_back(std::make_unique(std::forward(args)...))); + } + + template + StatisticOutputEntryNode& add_child(std::unique_ptr&& node) + { + return *(m_nodes.emplace_back(std::move(node))); + } + + [[nodiscard]] const std::vector>& get_nodes() const + { + return m_nodes; + } + + void add(StatisticOutput&& other) + { + for (auto&& node : other.m_nodes) + { + m_nodes.emplace_back(std::move(node)); + } + } + + [[nodiscard]] std::string to_string() const + { + std::vector blocks; + + for (auto&& node : m_nodes) + { + auto part = node->to_indented_text_blocks(Indentation{}); + blocks.insert(blocks.end(), part.begin(), part.end()); + } + + return IndentedTextBlock::join(blocks, "\n"); + } + + private: + std::vector> m_nodes; + }; + + struct StatisticGathererBase + { + virtual void on_entry(const Position&, const Move&, const PackedSfenValue&) {} + virtual void reset() = 0; + [[nodiscard]] virtual const std::string& get_name() const = 0; + [[nodiscard]] virtual StatisticOutput get_output() const = 0; + }; + + struct StatisticGathererFactoryBase + { + [[nodiscard]] virtual std::unique_ptr create() const = 0; + [[nodiscard]] virtual const std::string& get_name() const = 0; + }; + + template + struct StatisticGathererFactory : StatisticGathererFactoryBase + { + static inline std::string name = T::name; + + [[nodiscard]] std::unique_ptr create() const override + { + return std::make_unique(); + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + }; + + struct StatisticGathererSet : StatisticGathererBase + { + void add(const StatisticGathererFactoryBase& factory) + { + const std::string name = factory.get_name(); + if (m_gatherers_names.count(name) == 0) + { + m_gatherers_names.insert(name); + m_gatherers.emplace_back(factory.create()); + } + } + + void add(std::unique_ptr&& gatherer) + { + const std::string name = gatherer->get_name(); + if (m_gatherers_names.count(name) == 0) + { + m_gatherers_names.insert(name); + m_gatherers.emplace_back(std::move(gatherer)); + } + } + + void on_entry(const Position& pos, const Move& move, const PackedSfenValue& psv) override + { + for (auto& g : m_gatherers) + { + g->on_entry(pos, move, psv); + } + } + + void reset() override + { + for (auto& g : m_gatherers) + { + g->reset(); + } + } + + [[nodiscard]] const std::string& get_name() const override + { + static std::string name = "SET"; + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + for (auto&& s : m_gatherers) + { + out.add(s->get_output()); + } + return out; + } + + private: + std::vector> m_gatherers; + std::set m_gatherers_names; + }; + + struct StatisticGathererRegistry + { + void add_statistic_gatherers_by_group( + StatisticGathererSet& gatherers, + const std::string& group) const + { + auto it = m_gatherers_by_group.find(group); + if (it != m_gatherers_by_group.end()) + { + for (auto& factory : it->second) + { + gatherers.add(*factory); + } + } + } + + template + void add(const ArgsTs&... group) + { + auto dummy = {(add_single(group), 0)...}; + (void)dummy; + add_single("all"); + } + + private: + std::map>> m_gatherers_by_group; + std::map> m_gatherers_names_by_group; + + template + void add_single(const ArgT& group) + { + using FactoryT = StatisticGathererFactory; + + if (m_gatherers_names_by_group[group].count(FactoryT::name) == 0) + { + m_gatherers_by_group[group].emplace_back(std::make_unique()); + m_gatherers_names_by_group[group].insert(FactoryT::name); + } + } + }; + + /* + Statistic gatherer helpers + */ + + template + struct StatPerSquare + { + StatPerSquare() + { + for (int i = 0; i < SQUARE_NB; ++i) + m_squares[i] = 0; + } + + [[nodiscard]] T& operator[](Square sq) + { + return m_squares[sq]; + } + + [[nodiscard]] const T& operator[](Square sq) const + { + return m_squares[sq]; + } + + [[nodiscard]] std::unique_ptr get_output_node(const std::string& name) const + { + int max_digits = 1; + for (int i = 0; i < SQUARE_NB; ++i) + { + const int d = get_num_base_10_digits(m_squares[i]); + if (d > max_digits) + { + max_digits = d; + } + } + + std::stringstream ss; + for (int i = 0; i < SQUARE_NB; ++i) + { + ss << std::setw(max_digits) << m_squares[i ^ (int)SQ_A8] << ' '; + if ((i + 1) % 8 == 0) + ss << '\n'; + } + + return std::make_unique>(name, ss.str(), true); + } + + private: + std::array m_squares; + }; + + /* + Definitions for specific statistic gatherers follow: + */ + + struct PositionCounter : StatisticGathererBase + { + static inline std::string name = "PositionCounter"; + + PositionCounter() : + m_num_positions(0) + { + } + + void on_entry(const Position&, const Move&, const PackedSfenValue&) override + { + m_num_positions += 1; + } + + void reset() override + { + m_num_positions = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + out.emplace_node>("Number of positions", m_num_positions); + return out; + } + + private: + std::uint64_t m_num_positions; + }; + + struct KingSquareCounter : StatisticGathererBase + { + static inline std::string name = "KingSquareCounter"; + + KingSquareCounter() : + m_white{}, + m_black{} + { + + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override + { + m_white[pos.square(WHITE)] += 1; + m_black[pos.square(BLACK)] += 1; + } + + void reset() override + { + m_white = StatPerSquare{}; + m_black = StatPerSquare{}; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("King square distribution:"); + header.add_child(m_white.get_output_node("White king squares")); + header.add_child(m_black.get_output_node("Black king squares")); + return out; + } + + private: + StatPerSquare m_white; + StatPerSquare m_black; + }; + + struct MoveFromCounter : StatisticGathererBase + { + static inline std::string name = "MoveFromCounter"; + + MoveFromCounter() : + m_white{}, + m_black{} + { + + } + + void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override + { + if (pos.side_to_move() == WHITE) + m_white[from_sq(move)] += 1; + else + m_black[from_sq(move)] += 1; + } + + void reset() override + { + m_white = StatPerSquare{}; + m_black = StatPerSquare{}; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Move from square distribution:"); + header.add_child(m_white.get_output_node("White move from squares")); + header.add_child(m_black.get_output_node("Black move from squares")); + return out; + } + + private: + StatPerSquare m_white; + StatPerSquare m_black; + }; + + struct MoveToCounter : StatisticGathererBase + { + static inline std::string name = "MoveToCounter"; + + MoveToCounter() : + m_white{}, + m_black{} + { + + } + + void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override + { + if (pos.side_to_move() == WHITE) + m_white[to_sq(move)] += 1; + else + m_black[to_sq(move)] += 1; + } + + void reset() override + { + m_white = StatPerSquare{}; + m_black = StatPerSquare{}; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Move to square distribution:"); + header.add_child(m_white.get_output_node("White move to squares")); + header.add_child(m_black.get_output_node("Black move to squares")); + return out; + } + + private: + StatPerSquare m_white; + StatPerSquare m_black; + }; + + struct MoveTypeCounter : StatisticGathererBase + { + static inline std::string name = "MoveTypeCounter"; + + MoveTypeCounter() : + m_total(0), + m_normal(0), + m_capture(0), + m_promotion(0), + m_castling(0), + m_enpassant(0) + { + + } + + void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override + { + m_total += 1; + + if (!pos.empty(to_sq(move))) + m_capture += 1; + + if (type_of(move) == CASTLING) + m_castling += 1; + else if (type_of(move) == PROMOTION) + m_promotion += 1; + else if (type_of(move) == EN_PASSANT) + m_enpassant += 1; + else if (type_of(move) == NORMAL) + m_normal += 1; + } + + void reset() override + { + m_total = 0; + m_normal = 0; + m_capture = 0; + m_promotion = 0; + m_castling = 0; + m_enpassant = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Number of moves by type:"); + header.emplace_child>("Total", m_total); + header.emplace_child>("Normal", m_normal); + header.emplace_child>("Capture", m_capture); + header.emplace_child>("Promotion", m_promotion); + header.emplace_child>("Castling", m_castling); + header.emplace_child>("En-passant", m_enpassant); + return out; + } + + private: + std::uint64_t m_total; + std::uint64_t m_normal; + std::uint64_t m_capture; + std::uint64_t m_promotion; + std::uint64_t m_castling; + std::uint64_t m_enpassant; + }; + + struct PieceCountCounter : StatisticGathererBase + { + static inline std::string name = "PieceCountCounter"; + + PieceCountCounter() + { + reset(); + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override + { + m_piece_count_hist[popcount(pos.pieces())] += 1; + } + + void reset() override + { + for (int i = 0; i < SQUARE_NB; ++i) + m_piece_count_hist[i] = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Number of positions by piece count:"); + bool do_write = false; + for (int i = SQUARE_NB - 1; i >= 0; --i) + { + if (m_piece_count_hist[i] != 0) + do_write = true; + + // Start writing when the first non-zero number pops up. + if (do_write) + { + header.emplace_child>(std::to_string(i), m_piece_count_hist[i]); + } + } + return out; + } + + private: + std::uint64_t m_piece_count_hist[SQUARE_NB]; + }; + + struct MovedPieceTypeCounter : StatisticGathererBase + { + static inline std::string name = "MovedPieceTypeCounter"; + + MovedPieceTypeCounter() + { + reset(); + } + + void on_entry(const Position& pos, const Move& move, const PackedSfenValue&) override + { + m_moved_piece_type_hist[type_of(pos.piece_on(from_sq(move)))] += 1; + } + + void reset() override + { + for (int i = 0; i < PIECE_TYPE_NB; ++i) + m_moved_piece_type_hist[i] = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Number of moves by piece type:"); + header.emplace_child>("Pawn", m_moved_piece_type_hist[PAWN]); + header.emplace_child>("Knight", m_moved_piece_type_hist[KNIGHT]); + header.emplace_child>("Bishop", m_moved_piece_type_hist[BISHOP]); + header.emplace_child>("Rook", m_moved_piece_type_hist[ROOK]); + header.emplace_child>("Queen", m_moved_piece_type_hist[QUEEN]); + header.emplace_child>("King", m_moved_piece_type_hist[KING]); + return out; + } + + private: + std::uint64_t m_moved_piece_type_hist[PIECE_TYPE_NB]; + }; + + struct PlyDiscontinuitiesCounter : StatisticGathererBase + { + static inline std::string name = "PlyDiscontinuitiesCounter"; + + PlyDiscontinuitiesCounter() + { + reset(); + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override + { + const int current_ply = pos.game_ply(); + if (m_prev_ply != -1) + { + const bool is_discontinuity = (current_ply != (m_prev_ply + 1)); + if (is_discontinuity) + { + m_num_discontinuities += 1; + } + } + m_prev_ply = current_ply; + } + + void reset() override + { + m_num_discontinuities = 0; + m_prev_ply = -1; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + out.emplace_node>("Number of ply discontinuities (usually games)", m_num_discontinuities); + return out; + } + + private: + std::uint64_t m_num_discontinuities; + int m_prev_ply; + }; + + struct MaterialImbalanceDistribution : StatisticGathererBase + { + static inline std::string name = "MaterialImbalanceDistribution"; + static constexpr int max_imbalance = 64; + + MaterialImbalanceDistribution() + { + reset(); + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue&) override + { + const int imbalance = get_simple_material(pos, WHITE) - get_simple_material(pos, BLACK); + const int imbalance_idx = std::clamp(imbalance, -max_imbalance, max_imbalance) + max_imbalance; + m_num_imbalances[imbalance_idx] += 1; + } + + void reset() override + { + for (auto& imb : m_num_imbalances) + imb = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Number of \"simple eval\" imbalances for white's perspective:"); + const int key_length = get_num_base_10_digits(max_imbalance) + 1; + int min_non_zero = max_imbalance; + int max_non_zero = -max_imbalance; + for (int i = -max_imbalance; i <= max_imbalance; ++i) + { + if (m_num_imbalances[i + max_imbalance] != 0) + { + min_non_zero = std::min(min_non_zero, i); + max_non_zero = std::max(max_non_zero, i); + } + } + + for (int i = min_non_zero; i <= max_non_zero; ++i) + { + header.emplace_child>( + left_pad_to_length(std::to_string(i), ' ', key_length), + m_num_imbalances[i + max_imbalance] + ); + } + return out; + } + + private: + std::uint64_t m_num_imbalances[max_imbalance + 1 + max_imbalance]; + + [[nodiscard]] int get_simple_material(const Position& pos, Color c) + { + return + 9 * pos.count(c) + + 5 * pos.count(c) + + 3 * pos.count(c) + + 3 * pos.count(c) + + pos.count(c); + } + }; + + struct ResultDistribution : StatisticGathererBase + { + static inline std::string name = "ResultDistribution"; + + ResultDistribution() + { + reset(); + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue& psv) override + { + const Color stm = pos.side_to_move(); + if (psv.game_result == 0) + { + m_draws += 1; + } + else if (psv.game_result == 1) + { + m_stm_wins += 1; + m_wins[stm] += 1; + } + else + { + m_stm_loses += 1; + m_wins[~stm] += 1; + } + } + + void reset() override + { + m_wins[WHITE] = 0; + m_wins[BLACK] = 0; + m_draws = 0; + m_stm_wins = 0; + m_stm_loses = 0; + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Distribution of results:"); + header.emplace_child>("White wins", m_wins[WHITE]); + header.emplace_child>("Black wins", m_wins[BLACK]); + header.emplace_child>("Draws", m_draws); + header.emplace_child>("Side to move wins", m_stm_wins); + header.emplace_child>("Side to move loses", m_stm_loses); + return out; + } + + private: + std::uint64_t m_wins[COLOR_NB]; + std::uint64_t m_draws; + std::uint64_t m_stm_wins; + std::uint64_t m_stm_loses; + }; + + template + struct EndgameConfigurations : StatisticGathererBase + { + static_assert(MaxManCount < 10); + static_assert(MaxManCount > 2); + + static inline std::string name = std::string("EndgameConfigurations") + std::to_string(MaxManCount); + + using MaterialKey = std::uint64_t; + + EndgameConfigurations() + { + reset(); + } + + void on_entry(const Position& pos, const Move&, const PackedSfenValue& psv) override + { + const int piece_count = pos.count(); + if (piece_count > MaxManCount) + { + return; + } + + const auto index = get_material_key_for_position(pos); + auto& entry = m_entries[index]; + entry.count += 1; + if (psv.game_result == 0) + { + entry.draws += 1; + } + else + { + const Color winner_side = psv.game_result == 1 ? pos.side_to_move() : ~pos.side_to_move(); + if (winner_side == WHITE) + { + entry.white_wins += 1; + } + else + { + entry.black_wins += 1; + } + } + } + + void reset() override + { + m_entries.clear(); + } + + [[nodiscard]] const std::string& get_name() const override + { + return name; + } + + [[nodiscard]] StatisticOutput get_output() const override + { + StatisticOutput out; + auto& header = out.emplace_node("Distribution of endgame configurations (count W D L Perf%):"); + std::vector> flattened(m_entries.begin(), m_entries.end()); + std::sort(flattened.begin(), flattened.end(), [](const auto& lhs, const auto& rhs) { return lhs.second.count > rhs.second.count; }); + for (auto&& [index, entry] : flattened) + { + header.emplace_child>( + get_padded_name_by_material_key(index), + entry.to_string() + ); + } + return out; + } + + private: + struct Entry + { + std::uint64_t count = 0; + std::uint64_t white_wins = 0; + std::uint64_t black_wins = 0; + std::uint64_t draws = 0; + + [[nodiscard]] std::string to_string() const + { + constexpr int wide_column_width = 9; + constexpr int narrow_column_width = 4; + + const float perf = + (white_wins + draws / 2.0f) + / (white_wins + black_wins + draws); + + return + left_pad_to_length(std::to_string(count), ' ', wide_column_width) + ' ' + + left_pad_to_length(std::to_string(white_wins), ' ', wide_column_width) + ' ' + + left_pad_to_length(std::to_string(draws), ' ', wide_column_width) + ' ' + + left_pad_to_length(std::to_string(black_wins), ' ', wide_column_width) + ' ' + + left_pad_to_length(std::to_string(static_cast(perf * 100.0f + 0.5f)), ' ', narrow_column_width) + '%'; + } + }; + // can support up to 17 pieces. + // it's basically the material string encoded as a number in base 8 + // encoding is from the least significant digit to most significant + // v=1, P=2, N=3, B=4, R=5, Q=6, K=7. 0 indicates end + std::map m_entries; + + [[nodiscard]] MaterialKey get_material_key_for_position(const Position& pos) const + { + MaterialKey index = 0; + std::uint64_t shift = 0; + + index += 7 << shift; shift += 3; + + for (int i = 0; i < pos.count(WHITE); ++i) { index += 2 << shift; shift += 3; } + for (int i = 0; i < pos.count(WHITE); ++i) { index += 3 << shift; shift += 3; } + for (int i = 0; i < pos.count(WHITE); ++i) { index += 4 << shift; shift += 3; } + for (int i = 0; i < pos.count(WHITE); ++i) { index += 5 << shift; shift += 3; } + for (int i = 0; i < pos.count(WHITE); ++i) { index += 6 << shift; shift += 3; } + + index += 1 << shift; shift += 3; + index += 7 << shift; shift += 3; + + for (int i = 0; i < pos.count(BLACK); ++i) { index += 2 << shift; shift += 3; } + for (int i = 0; i < pos.count(BLACK); ++i) { index += 3 << shift; shift += 3; } + for (int i = 0; i < pos.count(BLACK); ++i) { index += 4 << shift; shift += 3; } + for (int i = 0; i < pos.count(BLACK); ++i) { index += 5 << shift; shift += 3; } + for (int i = 0; i < pos.count(BLACK); ++i) { index += 6 << shift; shift += 3; } + + return index; + } + + [[nodiscard]] std::string get_padded_name_by_material_key(MaterialKey index) const + { + std::string sides[COLOR_NB]; + int material[COLOR_NB] = { 0, 0 }; + Color side = WHITE; + + while (index != 0) + { + switch (index % 8) + { + case 1: + side = BLACK; + break; + case 2: + sides[side] += 'P'; + material[side] += 1; + break; + case 3: + sides[side] += 'N'; + material[side] += 3; + break; + case 4: + sides[side] += 'B'; + material[side] += 3; + break; + case 5: + sides[side] += 'R'; + material[side] += 5; + break; + case 6: + sides[side] += 'Q'; + material[side] += 9; + break; + case 7: + sides[side] += 'K'; + break; + default: + break; + } + index >>= 3; + } + + const int imbalance = material[WHITE] - material[BLACK]; + const std::string imbalance_str = + std::string(imbalance > 0 ? "+" : "") // force + sign for positive values + + std::string(imbalance == 0 ? " " : "") // pad 0 + + std::to_string(imbalance); + + return + right_pad_to_length(sides[WHITE], ' ', MaxManCount-1) + + 'v' + + right_pad_to_length(sides[BLACK], ' ', MaxManCount-1) + + " (" + + right_pad_to_length(imbalance_str, ' ', 3) + + ')'; + } + }; + + /* + This function provides factories for all possible statistic gatherers. + Each new statistic gatherer needs to be added there. + */ + const auto& get_statistics_gatherers_registry() + { + static StatisticGathererRegistry s_reg = [](){ + StatisticGathererRegistry reg; + + reg.add("position_count"); + + reg.add("king", "king_square_count"); + + reg.add("move", "move_from_count"); + reg.add("move", "move_to_count"); + reg.add("move", "move_type"); + reg.add("move", "moved_piece_type"); + + reg.add("ply_discontinuities"); + + reg.add("material_imbalance"); + + reg.add("results"); + + reg.add("piece_count"); + + reg.add>("endgames_6man"); + + return reg; + }(); + + return s_reg; + } + + void do_gather_statistics( + const std::string& filename, + StatisticGathererSet& statistic_gatherers, + std::uint64_t max_count, + const std::optional& output_filename) + { + Thread* th = Threads.main(); + Position& pos = th->rootPos; + StateInfo si; + + auto in = Tools::open_sfen_input_file(filename); + + auto on_entry = [&](const Position& position, const Move& move, const PackedSfenValue& psv) { + statistic_gatherers.on_entry(position, move, psv); + }; + + if (in == nullptr) + { + std::cerr << "Invalid input file type.\n"; + return; + } + + uint64_t num_processed = 0; + while (num_processed < max_count) + { + auto v = in->next(); + if (!v.has_value()) + break; + + auto& psv = v.value(); + + pos.set_from_packed_sfen(psv.sfen, &si, th); + + on_entry(pos, (Move)psv.move, psv); + + num_processed += 1; + if (num_processed % 1'000'000 == 0) + { + std::cout << "Processed " << num_processed << " positions.\n"; + } + } + + std::cout << "Finished gathering statistics.\n\n"; + std::cout << "Results:\n\n"; + + const auto output_str = statistic_gatherers.get_output().to_string(); + std::cout << output_str; + if (output_filename.has_value()) + { + std::ofstream out_file(*output_filename); + out_file << output_str; + } + } + + void gather_statistics(std::istringstream& is) + { + Eval::NNUE::init(); + + auto& registry = get_statistics_gatherers_registry(); + + StatisticGathererSet statistic_gatherers; + + std::string input_file; + std::optional output_file; + std::uint64_t max_count = std::numeric_limits::max(); + + while(true) + { + std::string token; + is >> token; + + if (token == "") + break; + + if (token == "input_file") + is >> input_file; + else if (token == "output_file") + { + std::string s; + is >> s; + output_file = s; + } + else if (token == "max_count") + is >> max_count; + else + registry.add_statistic_gatherers_by_group(statistic_gatherers, token); + } + + do_gather_statistics(input_file, statistic_gatherers, max_count, output_file); + } + +} diff --git a/src/tools/stats.h b/src/tools/stats.h new file mode 100644 index 00000000..3276bab6 --- /dev/null +++ b/src/tools/stats.h @@ -0,0 +1,12 @@ +#ifndef _STATS_H_ +#define _STATS_H_ + +#include + +namespace Stockfish::Tools::Stats { + + void gather_statistics(std::istringstream& is); + +} + +#endif diff --git a/src/tools/training_data_generator.cpp b/src/tools/training_data_generator.cpp new file mode 100644 index 00000000..6495d566 --- /dev/null +++ b/src/tools/training_data_generator.cpp @@ -0,0 +1,899 @@ +#include "training_data_generator.h" + +#include "sfen_writer.h" +#include "packed_sfen.h" +#include "opening_book.h" + +#include "misc.h" +#include "position.h" +#include "thread.h" +#include "tt.h" +#include "uci.h" + +#include "extra/nnue_data_binpack_format.h" + +#include "nnue/evaluate_nnue.h" + +#include "syzygy/tbprobe.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +namespace Stockfish::Tools +{ + // Class to generate sfen with multiple threads + struct TrainingDataGenerator + { + struct Params + { + // Min and max depths for search during gensfen + int search_depth_min = 3; + int search_depth_max = -1; + + // Number of the nodes to be searched. + // 0 represents no limits. + uint64_t nodes = 0; + + // Upper limit of evaluation value of generated situation + int eval_limit = 3000; + + // minimum ply with random move + // maximum ply with random move + // Number of random moves in one station + int random_move_minply = 1; + int random_move_maxply = 24; + int random_move_count = 5; + + // Move kings with a probability of 1/N when randomly moving like Apery software. + // When you move the king again, there is a 1/N chance that it will randomly moved + // once in the opponent's turn. + // Apery has N=2. Specifying 0 here disables this function. + int random_move_like_apery = 0; + + // For when using multi pv instead of random move. + // random_multi_pv is the number of candidates for MultiPV. + // When adopting the move of the candidate move, the difference + // between the evaluation value of the move of the 1st place + // and the evaluation value of the move of the Nth place is. + // Must be in the range random_multi_pv_diff. + // random_multi_pv_depth is the search depth for MultiPV. + int random_multi_pv = 0; + int random_multi_pv_diff = 32000; + int random_multi_pv_depth = -1; + + // The minimum and maximum ply (number of steps from + // the initial phase) of the sfens to write out. + int write_minply = 16; + int write_maxply = 400; + + uint64_t save_every = std::numeric_limits::max(); + + std::string output_file_name = "training_data"; + + SfenOutputType sfen_format = SfenOutputType::Binpack; + + std::string seed; + + bool write_out_draw_game_in_training_data_generation = true; + bool detect_draw_by_consecutive_low_score = true; + bool detect_draw_by_insufficient_mating_material = true; + + uint64_t num_threads; + + std::string book; + + void enforce_constraints() + { + search_depth_max = std::max(search_depth_min, search_depth_max); + + // Limit the maximum to a one-stop score. (Otherwise you might not end the loop) + eval_limit = std::min(eval_limit, (int)mate_in(2)); + + save_every = std::max(save_every, REPORT_STATS_EVERY); + + num_threads = Options["Threads"]; + + random_multi_pv_depth = std::max(search_depth_max, random_multi_pv_depth); + } + }; + + // Hash to limit the export of identical sfens + static constexpr uint64_t GENSFEN_HASH_SIZE = 64 * 1024 * 1024; + // It must be 2**N because it will be used as the mask to calculate hash_index. + static_assert((GENSFEN_HASH_SIZE& (GENSFEN_HASH_SIZE - 1)) == 0); + + static constexpr uint64_t REPORT_DOT_EVERY = 5000; + static constexpr uint64_t REPORT_STATS_EVERY = 200000; + static_assert(REPORT_STATS_EVERY % REPORT_DOT_EVERY == 0); + + TrainingDataGenerator( + const Params& prm + ) : + params(prm), + sfen_writer(prm.output_file_name, prm.num_threads, prm.save_every, prm.sfen_format) + { + hash.resize(GENSFEN_HASH_SIZE); + prngs.reserve(prm.num_threads); + auto seed = prm.seed; + for (uint64_t i = 0; i < prm.num_threads; ++i) + { + prngs.emplace_back(seed); + seed = prngs.back().next_random_seed(); + } + + if (!prm.book.empty()) + { + opening_book = open_opening_book(prm.book, prngs[0]); + if (opening_book == nullptr) + { + std::cout << "WARNING: Failed to open opening book " << prm.book << ". Falling back to startpos.\n"; + } + } + + // Output seed to veryfy by the user if it's not identical by chance. + std::cout << prngs[0] << std::endl; + } + + void generate(uint64_t limit); + + private: + Params params; + + std::vector prngs; + + std::mutex stats_mutex; + TimePoint last_stats_report_time; + + // sfen exporter + SfenWriter sfen_writer; + + SynchronizedRegionLogger::Region out; + + vector hash; // 64MB*sizeof(HASH_KEY) = 512MB + + std::unique_ptr opening_book; + + static void set_gensfen_search_limits(); + + void generate_worker( + Thread& th, + std::atomic& counter, + uint64_t limit); + + bool was_seen_before(const Position& pos); + + optional get_current_game_result( + Position& pos, + const vector& move_hist_scores) const; + + vector generate_random_move_flags(PRNG& prng); + + optional choose_random_move( + PRNG& prng, + Position& pos, + std::vector& random_move_flag, + int ply, + int& random_move_c); + + bool commit_psv( + Thread& th, + PSVector& sfens, + int8_t lastTurnIsWin, + std::atomic& counter, + uint64_t limit, + Color result_color); + + void report(uint64_t done, uint64_t new_done); + + void maybe_report(uint64_t done); + }; + + void TrainingDataGenerator::set_gensfen_search_limits() + { + // About Search::Limits + // Be careful because this member variable is global and affects other threads. + auto& limits = Search::Limits; + + // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done) + limits.infinite = true; + + // Since PV is an obstacle when displayed, erase it. + limits.silent = true; + + // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it. + limits.nodes = 0; + + // depth is also processed by the one passed as an argument of Tools::search(). + limits.depth = 0; + } + + void TrainingDataGenerator::generate(uint64_t limit) + { + last_stats_report_time = 0; + + set_gensfen_search_limits(); + + std::atomic counter{0}; + Threads.execute_with_workers([&counter, limit, this](Thread& th) { + generate_worker(th, counter, limit); + }); + Threads.wait_for_workers_finished(); + + sfen_writer.flush(); + + if (limit % REPORT_STATS_EVERY != 0) + { + report(limit, limit % REPORT_STATS_EVERY); + } + + std::cout << std::endl; + } + + void TrainingDataGenerator::generate_worker( + Thread& th, + std::atomic& counter, + uint64_t limit) + { + // For the time being, it will be treated as a draw + // at the maximum number of steps to write. + // Maximum StateInfo + Search PV to advance to leaf buffer + std::vector> states( + params.write_maxply + MAX_PLY /* == search_depth_min + α */); + + StateInfo si; + + auto& prng = prngs[th.id()]; + + // end flag + bool quit = false; + + // repeat until the specified number of times + while (!quit) + { + // It is necessary to set a dependent thread for Position. + // When parallelizing, Threads (since this is a vector, + // Do the same for up to Threads[0]...Threads[thread_num-1]. + auto& pos = th.rootPos; + if (opening_book != nullptr) + { + auto& fen = opening_book->next_fen(); + pos.set(fen, false, &si, &th); + } + else + { + pos.set(StartFEN, false, &si, &th); + } + + int resign_counter = 0; + bool should_resign = prng.rand(10) > 1; + // Vector for holding the sfens in the current simulated game. + PSVector packed_sfens; + packed_sfens.reserve(params.write_maxply + MAX_PLY); + + // Precomputed flags. Used internally by choose_random_move. + vector random_move_flag = generate_random_move_flags(prng); + + // A counter that keeps track of the number of random moves + // When random_move_minply == -1, random moves are + // performed continuously, so use it at this time. + // Used internally by choose_random_move. + int actual_random_move_count = 0; + + // Save history of move scores for adjudication + vector move_hist_scores; + + auto flush_psv = [&](int8_t result) { + quit = commit_psv(th, packed_sfens, result, counter, limit, pos.side_to_move()); + }; + + for (int ply = 0; ; ++ply) + { + // Current search depth + const int depth = params.search_depth_min + (int)prng.rand(params.search_depth_max - params.search_depth_min + 1); + + // Starting search calls init_for_search + auto [search_value, search_pv] = Search::search(pos, depth, 1, params.nodes); + + // This has to be performed after search because it needs to know + // rootMoves which are filled in init_for_search. + const auto result = get_current_game_result(pos, move_hist_scores); + if (result.has_value()) + { + flush_psv(result.value()); + break; + } + + // Always adjudivate by eval limit. + // Also because of this we don't have to check for TB/MATE scores + if (abs(search_value) >= params.eval_limit) + { + resign_counter++; + if ((should_resign && resign_counter >= 4) || abs(search_value) >= VALUE_KNOWN_WIN) { + flush_psv((search_value >= params.eval_limit) ? 1 : -1); + break; + } + } + else + { + resign_counter = 0; + } + + // In case there is no PV and the game was not ended here + // there is nothing we can do, we can't continue the game, + // we don't know the result, so discard this game. + if (search_pv.empty()) + { + break; + } + + // Save the move score for adjudication. + move_hist_scores.push_back(search_value); + + // Discard stuff before write_minply is reached + // because it can harm training due to overfitting. + // Initial positions would be too common. + if (ply >= params.write_minply && !was_seen_before(pos)) + { + auto& psv = packed_sfens.emplace_back(); + + // Here we only write the position data. + // Result is added after the whole game is done. + pos.sfen_pack(psv.sfen); + + psv.score = search_value; + psv.move = search_pv[0]; + psv.gamePly = ply; + } + + // Update the next move according to best search result or random move. + auto random_move = choose_random_move(prng, pos, random_move_flag, ply, actual_random_move_count); + const Move next_move = random_move.has_value() ? *random_move : search_pv[0]; + + // We don't have the whole game yet, but it ended, + // so the writing process ends and the next game starts. + // This shouldn't really happen. + if (!is_ok(next_move)) + { + break; + } + + // Do move. + pos.do_move(next_move, states[ply]); + } + } + } + + bool TrainingDataGenerator::was_seen_before(const Position& pos) + { + // Look into the position hashtable to see if the same + // position was seen before. + // This is a good heuristic to exlude already seen + // positions without many false positives. + auto key = pos.key(); + auto hash_index = (size_t)(key & (GENSFEN_HASH_SIZE - 1)); + auto old_key = hash[hash_index]; + if (key == old_key) + { + return true; + } + else + { + // Replace with the current key. + hash[hash_index] = key; + return false; + } + } + + optional TrainingDataGenerator::get_current_game_result( + Position& pos, + const vector& move_hist_scores) const + { + // Variables for draw adjudication. + // Todo: Make this as an option. + + // start the adjudication when ply reaches this value + constexpr int adj_draw_ply = 80; + + // 4 move scores for each side have to be checked + constexpr int adj_draw_cnt = 8; + + // move score in CP + constexpr int adj_draw_score = 0; + + // For the time being, it will be treated as a + // draw at the maximum number of steps to write. + const int ply = move_hist_scores.size(); + + // has it reached the max length or is a draw by fifty-move rule + // or by 3-fold repetition + if (ply >= params.write_maxply + || pos.is_fifty_move_draw() + || pos.is_three_fold_repetition()) + { + return 0; + } + + if(pos.this_thread()->rootMoves.empty()) + { + // If there is no legal move + return pos.checkers() + ? -1 /* mate */ + : 0 /* stalemate */; + } + + // Adjudicate game to a draw if the last 4 scores of each engine is 0. + if (params.detect_draw_by_consecutive_low_score) + { + if (ply >= adj_draw_ply) + { + int num_cons_plies_within_draw_score = 0; + bool is_adj_draw = false; + + for (auto it = move_hist_scores.rbegin(); + it != move_hist_scores.rend(); ++it) + { + if (abs(*it) <= adj_draw_score) + { + num_cons_plies_within_draw_score++; + } + else + { + // Draw scores must happen on consecutive plies + break; + } + + if (num_cons_plies_within_draw_score >= adj_draw_cnt) + { + is_adj_draw = true; + break; + } + } + + if (is_adj_draw) + { + return 0; + } + } + } + + // Draw by insufficient mating material + if (params.detect_draw_by_insufficient_mating_material) + { + if (pos.count() <= 4) + { + int num_pieces = pos.count(); + + // (1) KvK + if (num_pieces == 2) + { + return 0; + } + + // (2) KvK + 1 minor piece + if (num_pieces == 3) + { + int minor_pc = pos.count(WHITE) + pos.count(WHITE) + + pos.count(BLACK) + pos.count(BLACK); + if (minor_pc == 1) + { + return 0; + } + } + + // (3) KBvKB, bishops of the same color + else if (num_pieces == 4) + { + if (pos.count(WHITE) == 1 && pos.count(BLACK) == 1) + { + // Color of bishops is black. + if ((pos.pieces(WHITE, BISHOP) & DarkSquares) + && (pos.pieces(BLACK, BISHOP) & DarkSquares)) + { + return 0; + } + // Color of bishops is white. + if ((pos.pieces(WHITE, BISHOP) & ~DarkSquares) + && (pos.pieces(BLACK, BISHOP) & ~DarkSquares)) + { + return 0; + } + } + } + } + } + + return nullopt; + } + + vector TrainingDataGenerator::generate_random_move_flags(PRNG& prng) + { + vector random_move_flag; + + // Depending on random move selection parameters setup + // the array of flags that indicates whether a random move + // be taken at a given ply. + + // Make an array like a[0] = 0 ,a[1] = 1, ... + // Fisher-Yates shuffle and take out the first N items. + // Actually, I only want N pieces, so I only need + // to shuffle the first N pieces with Fisher-Yates. + + vector a; + a.reserve((size_t)params.random_move_maxply); + + // random_move_minply ,random_move_maxply is specified by 1 origin, + // Note that we are handling 0 origin here. + for (int i = std::max(params.random_move_minply - 1, 0); i < params.random_move_maxply; ++i) + { + a.push_back(i); + } + + // In case of Apery random move, insert() may be called random_move_count times. + // Reserve only the size considering it. + random_move_flag.resize((size_t)params.random_move_maxply + params.random_move_count); + + // A random move that exceeds the size() of a[] cannot be applied, so limit it. + for (int i = 0; i < std::min(params.random_move_count, (int)a.size()); ++i) + { + swap(a[i], a[prng.rand((uint64_t)a.size() - i) + i]); + random_move_flag[a[i]] = true; + } + + return random_move_flag; + } + + optional TrainingDataGenerator::choose_random_move( + PRNG& prng, + Position& pos, + std::vector& random_move_flag, + int ply, + int& random_move_c) + { + optional random_move; + + // Randomly choose one from legal move + if ( + // 1. Random move of random_move_count times from random_move_minply to random_move_maxply + (params.random_move_minply != -1 && ply < (int)random_move_flag.size() && random_move_flag[ply]) || + // 2. A mode to perform random move of random_move_count times after leaving the startpos + (params.random_move_minply == -1 && random_move_c < params.random_move_count)) + { + ++random_move_c; + + // It's not a mate, so there should be one legal move... + if (params.random_multi_pv == 0) + { + // Normal random move + MoveList list(pos); + + // I don't really know the goodness and badness of making this the Apery method. + if (params.random_move_like_apery == 0 + || prng.rand(params.random_move_like_apery) != 0) + { + // Normally one move from legal move + random_move = list.at((size_t)prng.rand((uint64_t)list.size())); + } + else + { + // if you can move the king, move the king + Move moves[8]; // Near 8 + Move* p = &moves[0]; + for (auto& m : list) + { + if (type_of(pos.moved_piece(m)) == KING) + { + *(p++) = m; + } + } + + size_t n = p - &moves[0]; + if (n != 0) + { + // move to move the king + random_move = moves[prng.rand(n)]; + + // In Apery method, at this time there is a 1/2 chance + // that the opponent will also move randomly + if (prng.rand(2) == 0) + { + // Is it a simple hack to add a "1" next to random_move_flag[ply]? + random_move_flag.insert(random_move_flag.begin() + ply + 1, 1, true); + } + } + else + { + // Normally one move from legal move + random_move = list.at((size_t)prng.rand((uint64_t)list.size())); + } + } + } + else + { + Search::search(pos, params.random_multi_pv_depth, params.random_multi_pv); + + // Select one from the top N hands of root Moves + auto& rm = pos.this_thread()->rootMoves; + + uint64_t s = min((uint64_t)rm.size(), (uint64_t)params.random_multi_pv); + for (uint64_t i = 1; i < s; ++i) + { + // The difference from the evaluation value of rm[0] must + // be within the range of random_multi_pv_diff. + // It can be assumed that rm[x].score is arranged in descending order. + if (rm[0].score > rm[i].score + params.random_multi_pv_diff) + { + s = i; + break; + } + } + + random_move = rm[prng.rand(s)].pv[0]; + } + } + + return random_move; + } + + // Write out the phases loaded in sfens to a file. + // result: win/loss in the next phase after the final phase in sfens + // 1 when winning. -1 when losing. Pass 0 for a draw. + // Return value: true if the specified number of + // sfens has already been reached and the process ends. + bool TrainingDataGenerator::commit_psv( + Thread& th, + PSVector& sfens, + int8_t result, + std::atomic& counter, + uint64_t limit, + Color result_color) + { + if (!params.write_out_draw_game_in_training_data_generation && result == 0) + { + // We didn't write anything so why quit. + return false; + } + + auto side_to_move_from_sfen = [](auto& sfen){ + return (Color)(sfen.sfen.data[0] & 1); + }; + + // From the final stage (one step before) to the first stage, give information on the outcome of the game for each stage. + // The phases stored in sfens are assumed to be continuous (in order). + for (auto it = sfens.rbegin(); it != sfens.rend(); ++it) + { + // The side to move is packed as the lowest bit of the first byte + const Color side_to_move = side_to_move_from_sfen(*it); + it->game_result = side_to_move == result_color ? result : -result; + } + + // Write sfens in move order to make potential compression easier + for (auto& sfen : sfens) + { + // Return true if there is already enough data generated. + const auto iter = counter.fetch_add(1); + if (iter >= limit) + return true; + + // because `iter` was done, now we do one more + maybe_report(iter + 1); + + // Write out one sfen. + sfen_writer.write(th.id(), sfen); + } + + return false; + } + + void TrainingDataGenerator::report(uint64_t done, uint64_t new_done) + { + const auto now_time = now(); + const TimePoint elapsed = now_time - last_stats_report_time + 1; + + out + << endl + << done << " sfens, " + << new_done * 1000 / elapsed << " sfens/second, " + << "at " << now_string() << sync_endl; + + last_stats_report_time = now_time; + + out = sync_region_cout.new_region(); + } + + void TrainingDataGenerator::maybe_report(uint64_t done) + { + if (done % REPORT_DOT_EVERY == 0) + { + std::lock_guard lock(stats_mutex); + + if (last_stats_report_time == 0) + { + last_stats_report_time = now(); + out = sync_region_cout.new_region(); + } + + if (done != 0) + { + out << '.'; + + if (done % REPORT_STATS_EVERY == 0) + { + report(done, REPORT_STATS_EVERY); + } + } + } + } + + // Command to generate a game record + void generate_training_data(istringstream& is) + { + // Number of generated game records default = 8 billion phases (Ponanza specification) + uint64_t loop_max = 8000000000UL; + + TrainingDataGenerator::Params params; + + // Add a random number to the end of the file name. + bool random_file_name = false; + std::string sfen_format = "binpack"; + + string token; + while (true) + { + token = ""; + is >> token; + if (token == "") + break; + + if (token == "depth") + { + is >> params.search_depth_min; + params.search_depth_max = params.search_depth_min; + } + else if (token == "min_depth") + is >> params.search_depth_min; + else if (token == "max_depth") + is >> params.search_depth_min; + else if (token == "nodes") + is >> params.nodes; + else if (token == "count") + is >> loop_max; + else if (token == "output_file_name") + is >> params.output_file_name; + else if (token == "eval_limit") + is >> params.eval_limit; + else if (token == "random_move_min_ply") + is >> params.random_move_minply; + else if (token == "random_move_max_ply") + is >> params.random_move_maxply; + else if (token == "random_move_count") + is >> params.random_move_count; + else if (token == "random_move_like_apery") + is >> params.random_move_like_apery; + else if (token == "random_multi_pv") + is >> params.random_multi_pv; + else if (token == "random_multi_pv_diff") + is >> params.random_multi_pv_diff; + else if (token == "random_multi_pv_depth") + is >> params.random_multi_pv_depth; + else if (token == "write_min_ply") + is >> params.write_minply; + else if (token == "write_max_ply") + is >> params.write_maxply; + else if (token == "save_every") + is >> params.save_every; + else if (token == "book") + is >> params.book; + else if (token == "random_file_name") + is >> random_file_name; + else if (token == "keep_draws") + is >> params.write_out_draw_game_in_training_data_generation; + else if (token == "adjudicate_draws_by_score") + is >> params.detect_draw_by_consecutive_low_score; + else if (token == "adjudicate_draws_by_insufficient_material") + is >> params.detect_draw_by_insufficient_mating_material; + else if (token == "data_format") + is >> sfen_format; + else if (token == "seed") + is >> params.seed; + else if (token == "set_recommended_uci_options") + { + UCI::setoption("Contempt", "0"); + UCI::setoption("Skill Level", "20"); + UCI::setoption("UCI_Chess960", "false"); + UCI::setoption("UCI_AnalyseMode", "false"); + UCI::setoption("UCI_LimitStrength", "false"); + UCI::setoption("PruneAtShallowDepth", "false"); + UCI::setoption("EnableTranspositionTable", "true"); + } + else + { + cout << "ERROR: Unknown option " << token << ". Exiting...\n"; + return; + } + } + + if (!sfen_format.empty()) + { + if (sfen_format == "bin") + params.sfen_format = SfenOutputType::Bin; + else if (sfen_format == "binpack") + params.sfen_format = SfenOutputType::Binpack; + else + cout << "WARNING: Unknown sfen format `" << sfen_format << "`. Using bin\n"; + } + + if (random_file_name) + { + // Give a random number to output_file_name at this point. + // Do not use std::random_device(). Because it always the same integers on MinGW. + PRNG r(params.seed); + + // Just in case, reassign the random numbers. + for (int i = 0; i < 10; ++i) + r.rand(1); + + auto to_hex = [](uint64_t u) { + std::stringstream ss; + ss << std::hex << u; + return ss.str(); + }; + + // I don't want to wear 64bit numbers by accident, so I'next_move going to make a 64bit number 2 just in case. + params.output_file_name += "_" + to_hex(r.rand()) + to_hex(r.rand()); + } + + params.enforce_constraints(); + + std::cout << "INFO: Executing generate_training_data command\n"; + + std::cout << "INFO: Parameters:\n"; + std::cout + << " - search_depth_min = " << params.search_depth_min << endl + << " - search_depth_max = " << params.search_depth_max << endl + << " - nodes = " << params.nodes << endl + << " - count = " << loop_max << endl + << " - eval_limit = " << params.eval_limit << endl + << " - num threads (UCI) = " << params.num_threads << endl + << " - random_move_min_ply = " << params.random_move_minply << endl + << " - random_move_max_ply = " << params.random_move_maxply << endl + << " - random_move_count = " << params.random_move_count << endl + << " - random_move_like_apery = " << params.random_move_like_apery << endl + << " - random_multi_pv = " << params.random_multi_pv << endl + << " - random_multi_pv_diff = " << params.random_multi_pv_diff << endl + << " - random_multi_pv_depth = " << params.random_multi_pv_depth << endl + << " - write_min_ply = " << params.write_minply << endl + << " - write_max_ply = " << params.write_maxply << endl + << " - book = " << params.book << endl + << " - output_file_name = " << params.output_file_name << endl + << " - save_every = " << params.save_every << endl + << " - random_file_name = " << random_file_name << endl + << " - write_drawn_games = " << params.write_out_draw_game_in_training_data_generation << endl + << " - draw by low score = " << params.detect_draw_by_consecutive_low_score << endl + << " - draw by insuff. mat. = " << params.detect_draw_by_insufficient_mating_material << endl; + + // Show if the training data generator uses NNUE. + Eval::NNUE::verify(); + + Threads.main()->ponder = false; + + TrainingDataGenerator gensfen(params); + gensfen.generate(loop_max); + + std::cout << "INFO: generate_training_data finished." << endl; + } +} diff --git a/src/tools/training_data_generator.h b/src/tools/training_data_generator.h new file mode 100644 index 00000000..9a105155 --- /dev/null +++ b/src/tools/training_data_generator.h @@ -0,0 +1,14 @@ +#ifndef _GENSFEN_H_ +#define _GENSFEN_H_ + +#include "position.h" + +#include + +namespace Stockfish::Tools { + + // Automatic generation of teacher position + void generate_training_data(std::istringstream& is); +} + +#endif \ No newline at end of file diff --git a/src/tools/training_data_generator_nonpv.cpp b/src/tools/training_data_generator_nonpv.cpp new file mode 100644 index 00000000..278259c6 --- /dev/null +++ b/src/tools/training_data_generator_nonpv.cpp @@ -0,0 +1,491 @@ +#include "training_data_generator_nonpv.h" + +#include "sfen_writer.h" +#include "packed_sfen.h" +#include "opening_book.h" + +#include "misc.h" +#include "position.h" +#include "thread.h" +#include "tt.h" +#include "uci.h" + +#include "extra/nnue_data_binpack_format.h" + +#include "nnue/evaluate_nnue.h" + +#include "syzygy/tbprobe.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +namespace Stockfish::Tools +{ + // Class to generate sfen with multiple threads + struct TrainingDataGeneratorNonPv + { + struct Params + { + // The depth for search on the fens gathered during exploration + int search_depth = 3; + + // the min/max number of nodes to use for exploration per ply + int exploration_min_nodes = 5000; + int exploration_max_nodes = 15000; + + // The pct of positions explored that are saved for rescoring + float exploration_save_rate = 0.01; + + // Upper limit of evaluation value of generated situation + int eval_limit = 4000; + + // the upper limit on evaluation during exploration selfplay + int exploration_eval_limit = 4000; + + int exploration_max_ply = 200; + + int exploration_min_pieces = 8; + + std::string output_file_name = "training_data_nonpv"; + + SfenOutputType sfen_format = SfenOutputType::Binpack; + + std::string seed; + + int num_threads; + + std::string book; + + bool smart_fen_skipping = false; + + void enforce_constraints() + { + // Limit the maximum to a one-stop score. (Otherwise you might not end the loop) + eval_limit = std::min(eval_limit, (int)mate_in(2)); + exploration_eval_limit = std::min(eval_limit, (int)mate_in(2)); + exploration_min_nodes = std::max(100, exploration_min_nodes); + exploration_max_nodes = std::max(exploration_min_nodes, exploration_max_nodes); + + num_threads = Options["Threads"]; + } + }; + + static constexpr uint64_t REPORT_DOT_EVERY = 5000; + static constexpr uint64_t REPORT_STATS_EVERY = 200000; + static_assert(REPORT_STATS_EVERY % REPORT_DOT_EVERY == 0); + + TrainingDataGeneratorNonPv( + const Params& prm + ) : + params(prm), + prng(prm.seed), + sfen_writer(prm.output_file_name, prm.num_threads, std::numeric_limits::max(), prm.sfen_format) + { + if (!prm.book.empty()) + { + opening_book = open_opening_book(prm.book, prng); + if (opening_book == nullptr) + { + std::cout << "WARNING: Failed to open opening book " << prm.book << ". Falling back to startpos.\n"; + } + } + + // Output seed to veryfy by the user if it's not identical by chance. + std::cout << prng << std::endl; + } + + void generate(uint64_t limit); + + private: + Params params; + + PRNG prng; + + std::mutex stats_mutex; + TimePoint last_stats_report_time; + + // sfen exporter + SfenWriter sfen_writer; + + SynchronizedRegionLogger::Region out; + + std::unique_ptr opening_book; + + static void set_gensfen_search_limits(); + + void generate_worker( + Thread& th, + std::atomic& counter, + uint64_t limit); + + bool commit_psv( + Thread& th, + PSVector& sfens, + std::atomic& counter, + uint64_t limit); + + PSVector do_exploration( + Thread& th, + int count); + + void report(uint64_t done, uint64_t new_done); + + void maybe_report(uint64_t done); + }; + + void TrainingDataGeneratorNonPv::set_gensfen_search_limits() + { + // About Search::Limits + // Be careful because this member variable is global and affects other threads. + auto& limits = Search::Limits; + + // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done) + limits.infinite = true; + + // Since PV is an obstacle when displayed, erase it. + limits.silent = true; + + // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it. + limits.nodes = 0; + + // depth is also processed by the one passed as an argument of Tools::search(). + limits.depth = 0; + } + + void TrainingDataGeneratorNonPv::generate(uint64_t limit) + { + last_stats_report_time = 0; + + set_gensfen_search_limits(); + + std::atomic counter{0}; + Threads.execute_with_workers([&counter, limit, this](Thread& th) { + generate_worker(th, counter, limit); + }); + Threads.wait_for_workers_finished(); + + sfen_writer.flush(); + + if (limit % REPORT_STATS_EVERY != 0) + { + report(limit, limit % REPORT_STATS_EVERY); + } + + std::cout << std::endl; + } + + PSVector TrainingDataGeneratorNonPv::do_exploration( + Thread& th, + int count) + { + constexpr int max_depth = 30; + + PSVector psv; + + std::vector> states( + max_depth + MAX_PLY /* == search_depth_min + α */); + + th.set_eval_callback([this, &psv](Position& pos) { + if ((double)prng.rand() / std::numeric_limits::max() < params.exploration_save_rate) + { + psv.emplace_back(); + pos.sfen_pack(psv.back().sfen); + } + }); + + auto& pos = th.rootPos; + StateInfo si; + + for (int i = 0; i < count; ++i) + { + if (opening_book != nullptr) + { + auto& fen = opening_book->next_fen(); + pos.set(fen, false, &si, &th); + } + else + { + pos.set(StartFEN, false, &si, &th); + } + + for(int ply = 0; ply < params.exploration_max_ply; ++ply) + { + auto nodes = prng.rand(params.exploration_max_nodes - params.exploration_min_nodes + 1) + params.exploration_min_nodes; + + auto [search_value, search_pv] = Search::search(pos, max_depth, 1, nodes); + + if (search_pv.empty()) + { + break; + } + + if (std::abs(search_value) > params.exploration_eval_limit) + { + break; + } + + pos.do_move(search_pv[0], states[ply]); + + if (popcount(pos.pieces()) < params.exploration_min_pieces) + { + break; + } + } + } + + th.clear_eval_callback(); + + return psv; + } + + void TrainingDataGeneratorNonPv::generate_worker( + Thread& th, + std::atomic& counter, + uint64_t limit) + { + constexpr int exploration_batch_size = 1; + + StateInfo si; + + PSVector psv; + + // end flag + bool quit = false; + + // repeat until the specified number of times + while (!quit) + { + // It is necessary to set a dependent thread for Position. + // When parallelizing, Threads (since this is a vector, + // Do the same for up to Threads[0]...Threads[thread_num-1]. + auto& pos = th.rootPos; + + auto packed_sfens = do_exploration(th, exploration_batch_size); + psv.clear(); + + for (auto& ps : packed_sfens) + { + pos.set_from_packed_sfen(ps.sfen, &si, &th); + pos.state()->rule50 = 0; + + if (params.smart_fen_skipping && pos.checkers()) + { + continue; + } + + auto [search_value, search_pv] = Search::search(pos, params.search_depth, 1); + + if (search_pv.empty()) + { + continue; + } + + if (std::abs(search_value) > params.eval_limit) + { + continue; + } + + if (params.smart_fen_skipping && pos.capture_or_promotion(search_pv[0])) + { + continue; + } + + auto& new_ps = psv.emplace_back(); + pos.sfen_pack(new_ps.sfen); + new_ps.score = search_value; + new_ps.move = search_pv[0]; + new_ps.gamePly = 1; + new_ps.game_result = 0; + new_ps.padding = 0; + } + + quit = commit_psv(th, psv, counter, limit); + } + } + + // Write out the phases loaded in sfens to a file. + // result: win/loss in the next phase after the final phase in sfens + // 1 when winning. -1 when losing. Pass 0 for a draw. + // Return value: true if the specified number of + // sfens has already been reached and the process ends. + bool TrainingDataGeneratorNonPv::commit_psv( + Thread& th, + PSVector& sfens, + std::atomic& counter, + uint64_t limit) + { + // Write sfens in move order to make potential compression easier + for (auto& sfen : sfens) + { + // Return true if there is already enough data generated. + const auto iter = counter.fetch_add(1); + if (iter >= limit) + return true; + + // because `iter` was done, now we do one more + maybe_report(iter + 1); + + // Write out one sfen. + sfen_writer.write(th.id(), sfen); + } + + return false; + } + + void TrainingDataGeneratorNonPv::report(uint64_t done, uint64_t new_done) + { + const auto now_time = now(); + const TimePoint elapsed = now_time - last_stats_report_time + 1; + + out + << endl + << done << " sfens, " + << new_done * 1000 / elapsed << " sfens/second, " + << "at " << now_string() << sync_endl; + + last_stats_report_time = now_time; + + out = sync_region_cout.new_region(); + } + + void TrainingDataGeneratorNonPv::maybe_report(uint64_t done) + { + if (done % REPORT_DOT_EVERY == 0) + { + std::lock_guard lock(stats_mutex); + + if (last_stats_report_time == 0) + { + last_stats_report_time = now(); + out = sync_region_cout.new_region(); + } + + if (done != 0) + { + out << '.'; + + if (done % REPORT_STATS_EVERY == 0) + { + report(done, REPORT_STATS_EVERY); + } + } + } + } + + // Command to generate a game record + void generate_training_data_nonpv(istringstream& is) + { + // Number of generated game records default = 8 billion phases (Ponanza specification) + TrainingDataGeneratorNonPv::Params params; + + uint64_t count = 1'000'000; + + // Add a random number to the end of the file name. + std::string sfen_format = "binpack"; + + string token; + while (true) + { + token = ""; + is >> token; + if (token == "") + break; + + if (token == "depth") + is >> params.search_depth; + else if (token == "count") + is >> count; + else if (token == "output_file") + is >> params.output_file_name; + else if (token == "exploration_eval_limit") + is >> params.exploration_eval_limit; + else if (token == "eval_limit") + is >> params.eval_limit; + else if (token == "exploration_min_nodes") + is >> params.exploration_min_nodes; + else if (token == "exploration_max_nodes") + is >> params.exploration_max_nodes; + else if (token == "exploration_min_pieces") + is >> params.exploration_min_pieces; + else if (token == "exploration_save_rate") + is >> params.exploration_save_rate; + else if (token == "book") + is >> params.book; + else if (token == "data_format") + is >> sfen_format; + else if (token == "seed") + is >> params.seed; + else if (token == "smart_fen_skipping") + params.smart_fen_skipping = true; + else if (token == "set_recommended_uci_options") + { + UCI::setoption("Contempt", "0"); + UCI::setoption("Skill Level", "20"); + UCI::setoption("UCI_Chess960", "false"); + UCI::setoption("UCI_AnalyseMode", "false"); + UCI::setoption("UCI_LimitStrength", "false"); + UCI::setoption("PruneAtShallowDepth", "false"); + UCI::setoption("EnableTranspositionTable", "true"); + } + else + { + cout << "ERROR: Unknown option " << token << ". Exiting...\n"; + return; + } + } + + if (!sfen_format.empty()) + { + if (sfen_format == "bin") + params.sfen_format = SfenOutputType::Bin; + else if (sfen_format == "binpack") + params.sfen_format = SfenOutputType::Binpack; + else + cout << "WARNING: Unknown sfen format `" << sfen_format << "`. Using bin\n"; + } + + params.enforce_constraints(); + + std::cout << "INFO: Executing generate_training_data_nonpv command\n"; + + std::cout << "INFO: Parameters:\n"; + std::cout + << " - search_depth = " << params.search_depth << endl + << " - output_file = " << params.output_file_name << endl + << " - exploration_eval_limit = " << params.exploration_eval_limit << endl + << " - eval_limit = " << params.eval_limit << endl + << " - exploration_min_nodes = " << params.exploration_min_nodes << endl + << " - exploration_max_nodes = " << params.exploration_max_nodes << endl + << " - exploration_min_pieces = " << params.exploration_min_pieces << endl + << " - exploration_save_rate = " << params.exploration_save_rate << endl + << " - book = " << params.book << endl + << " - data_format = " << sfen_format << endl + << " - seed = " << params.seed << endl + << " - count = " << count << endl; + + // Show if the training data generator uses NNUE. + Eval::NNUE::verify(); + + Threads.main()->ponder = false; + + TrainingDataGeneratorNonPv gensfen(params); + gensfen.generate(count); + + std::cout << "INFO: generate_training_data_nonpv finished." << endl; + } +} diff --git a/src/tools/training_data_generator_nonpv.h b/src/tools/training_data_generator_nonpv.h new file mode 100644 index 00000000..8bd093cb --- /dev/null +++ b/src/tools/training_data_generator_nonpv.h @@ -0,0 +1,12 @@ +#ifndef _GENSFEN_NONPV_H_ +#define _GENSFEN_NONPV_H_ + +#include + +namespace Stockfish::Tools { + + // Automatic generation of teacher position + void generate_training_data_nonpv(std::istringstream& is); +} + +#endif \ No newline at end of file diff --git a/src/tools/transform.cpp b/src/tools/transform.cpp new file mode 100644 index 00000000..f657b410 --- /dev/null +++ b/src/tools/transform.cpp @@ -0,0 +1,524 @@ +#include "transform.h" + +#include "sfen_stream.h" +#include "packed_sfen.h" +#include "sfen_writer.h" + +#include "thread.h" +#include "position.h" +#include "evaluate.h" +#include "search.h" + +#include "nnue/evaluate_nnue.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Stockfish::Tools +{ + using CommandFunc = void(*)(std::istringstream&); + + enum struct NudgedStaticMode + { + Absolute, + Relative, + Interpolate + }; + + struct NudgedStaticParams + { + std::string input_filename = "in.binpack"; + std::string output_filename = "out.binpack"; + NudgedStaticMode mode = NudgedStaticMode::Absolute; + int absolute_nudge = 5; + float relative_nudge = 0.1; + float interpolate_nudge = 0.1; + + void enforce_constraints() + { + relative_nudge = std::max(relative_nudge, 0.0f); + absolute_nudge = std::max(absolute_nudge, 0); + } + }; + + struct RescoreParams + { + std::string input_filename = "in.epd"; + std::string output_filename = "out.binpack"; + int depth = 3; + int research_count = 0; + bool keep_moves = true; + + void enforce_constraints() + { + depth = std::max(1, depth); + research_count = std::max(0, research_count); + } + }; + + [[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16) + { + auto saturate_i32_to_i16 = [](int v) { + return static_cast( + std::clamp( + v, + (int)std::numeric_limits::min(), + (int)std::numeric_limits::max() + ) + ); + }; + + auto saturate_f32_to_i16 = [saturate_i32_to_i16](float v) { + return saturate_i32_to_i16((int)v); + }; + + int static_eval = static_eval_i16; + int deep_eval = deep_eval_i16; + + switch(params.mode) + { + case NudgedStaticMode::Absolute: + return saturate_i32_to_i16( + static_eval + std::clamp( + deep_eval - static_eval, + -params.absolute_nudge, + params.absolute_nudge + ) + ); + + case NudgedStaticMode::Relative: + return saturate_f32_to_i16( + (float)static_eval * std::clamp( + (float)deep_eval / (float)static_eval, + (1.0f - params.relative_nudge), + (1.0f + params.relative_nudge) + ) + ); + + case NudgedStaticMode::Interpolate: + return saturate_f32_to_i16( + (float)static_eval * (1.0f - params.interpolate_nudge) + + (float)deep_eval * params.interpolate_nudge + ); + + default: + assert(false); + return 0; + } + } + + void do_nudged_static(NudgedStaticParams& params) + { + Thread* th = Threads.main(); + Position& pos = th->rootPos; + StateInfo si; + + auto in = Tools::open_sfen_input_file(params.input_filename); + auto out = Tools::create_new_sfen_output(params.output_filename); + + if (in == nullptr) + { + std::cerr << "Invalid input file type.\n"; + return; + } + + if (out == nullptr) + { + std::cerr << "Invalid output file type.\n"; + return; + } + + PSVector buffer; + uint64_t batch_size = 1'000'000; + + buffer.reserve(batch_size); + + uint64_t num_processed = 0; + for (;;) + { + auto v = in->next(); + if (!v.has_value()) + break; + + auto& ps = v.value(); + + pos.set_from_packed_sfen(ps.sfen, &si, th); + auto static_eval = Eval::evaluate(pos); + auto deep_eval = ps.score; + ps.score = nudge(params, static_eval, deep_eval); + + buffer.emplace_back(ps); + if (buffer.size() >= batch_size) + { + num_processed += buffer.size(); + + out->write(buffer); + buffer.clear(); + + std::cout << "Processed " << num_processed << " positions.\n"; + } + } + + if (!buffer.empty()) + { + num_processed += buffer.size(); + + out->write(buffer); + buffer.clear(); + + std::cout << "Processed " << num_processed << " positions.\n"; + } + + std::cout << "Finished.\n"; + } + + void nudged_static(std::istringstream& is) + { + NudgedStaticParams params{}; + + while(true) + { + std::string token; + is >> token; + + if (token == "") + break; + + if (token == "absolute") + { + params.mode = NudgedStaticMode::Absolute; + is >> params.absolute_nudge; + } + else if (token == "relative") + { + params.mode = NudgedStaticMode::Relative; + is >> params.relative_nudge; + } + else if (token == "interpolate") + { + params.mode = NudgedStaticMode::Interpolate; + is >> params.interpolate_nudge; + } + else if (token == "input_file") + is >> params.input_filename; + else if (token == "output_file") + is >> params.output_filename; + else + { + std::cout << "ERROR: Unknown option " << token << ". Exiting...\n"; + return; + } + } + + std::cout << "Performing transform nudged_static with parameters:\n"; + std::cout << "input_file : " << params.input_filename << '\n'; + std::cout << "output_file : " << params.output_filename << '\n'; + std::cout << "\n"; + if (params.mode == NudgedStaticMode::Absolute) + { + std::cout << "mode : absolute\n"; + std::cout << "absolute_nudge : " << params.absolute_nudge << '\n'; + } + else if (params.mode == NudgedStaticMode::Relative) + { + std::cout << "mode : relative\n"; + std::cout << "relative_nudge : " << params.relative_nudge << '\n'; + } + else if (params.mode == NudgedStaticMode::Interpolate) + { + std::cout << "mode : interpolate\n"; + std::cout << "interpolate_nudge : " << params.interpolate_nudge << '\n'; + } + std::cout << '\n'; + + params.enforce_constraints(); + do_nudged_static(params); + } + + void do_rescore_epd(RescoreParams& params) + { + std::ifstream fens_file(params.input_filename); + + auto next_fen = [&fens_file, mutex = std::mutex{}]() mutable -> std::optional{ + std::string fen; + + std::unique_lock lock(mutex); + + if (std::getline(fens_file, fen) && fen.size() >= 10) + { + return fen; + } + else + { + return std::nullopt; + } + }; + + PSVector buffer; + uint64_t batch_size = 10'000; + + buffer.reserve(batch_size); + + auto out = Tools::create_new_sfen_output(params.output_filename); + + std::mutex mutex; + uint64_t num_processed = 0; + + // About Search::Limits + // Be careful because this member variable is global and affects other threads. + auto& limits = Search::Limits; + + // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done) + limits.infinite = true; + + // Since PV is an obstacle when displayed, erase it. + limits.silent = true; + + // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it. + limits.nodes = 0; + + // depth is also processed by the one passed as an argument of Tools::search(). + limits.depth = 0; + + Threads.execute_with_workers([&](auto& th){ + Position& pos = th.rootPos; + StateInfo si; + + for(;;) + { + auto fen = next_fen(); + if (!fen.has_value()) + return; + + pos.set(*fen, false, &si, &th); + pos.state()->rule50 = 0; + + + for (int cnt = 0; cnt < params.research_count; ++cnt) + Search::search(pos, params.depth, 1); + + auto [search_value, search_pv] = Search::search(pos, params.depth, 1); + + if (search_pv.empty()) + continue; + + PackedSfenValue ps; + pos.sfen_pack(ps.sfen); + ps.score = search_value; + ps.move = search_pv[0]; + ps.gamePly = 1; + ps.game_result = 0; + ps.padding = 0; + + std::unique_lock lock(mutex); + buffer.emplace_back(ps); + if (buffer.size() >= batch_size) + { + num_processed += buffer.size(); + + out->write(buffer); + buffer.clear(); + + std::cout << "Processed " << num_processed << " positions.\n"; + } + } + }); + Threads.wait_for_workers_finished(); + + if (!buffer.empty()) + { + num_processed += buffer.size(); + + out->write(buffer); + buffer.clear(); + + std::cout << "Processed " << num_processed << " positions.\n"; + } + + std::cout << "Finished.\n"; + } + + void do_rescore_data(RescoreParams& params) + { + // TODO: Use SfenReader once it works correctly in sequential mode. See issue #271 + auto in = Tools::open_sfen_input_file(params.input_filename); + auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> PSVector { + + PSVector psv; + psv.reserve(n); + + std::unique_lock lock(mutex); + + for (int i = 0; i < n; ++i) + { + auto ps_opt = in->next(); + if (ps_opt.has_value()) + { + psv.emplace_back(*ps_opt); + } + else + { + break; + } + } + + return psv; + }; + + auto sfen_format = ends_with(params.output_filename, ".binpack") ? SfenOutputType::Binpack : SfenOutputType::Bin; + + auto out = SfenWriter( + params.output_filename, + Threads.size(), + std::numeric_limits::max(), + sfen_format); + + // About Search::Limits + // Be careful because this member variable is global and affects other threads. + auto& limits = Search::Limits; + + // Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done) + limits.infinite = true; + + // Since PV is an obstacle when displayed, erase it. + limits.silent = true; + + // If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it. + limits.nodes = 0; + + // depth is also processed by the one passed as an argument of Tools::search(). + limits.depth = 0; + + std::atomic num_processed = 0; + + Threads.execute_with_workers([&](auto& th){ + Position& pos = th.rootPos; + StateInfo si; + + for (;;) + { + PSVector psv = readsome(5000); + if (psv.empty()) + break; + + for(auto& ps : psv) + { + pos.set_from_packed_sfen(ps.sfen, &si, &th); + + for (int cnt = 0; cnt < params.research_count; ++cnt) + Search::search(pos, params.depth, 1); + + auto [search_value, search_pv] = Search::search(pos, params.depth, 1); + + if (search_pv.empty()) + continue; + + pos.sfen_pack(ps.sfen); + ps.score = search_value; + if (!params.keep_moves) + ps.move = search_pv[0]; + ps.padding = 0; + + out.write(th.id(), ps); + + auto p = num_processed.fetch_add(1) + 1; + if (p % 10000 == 0) + { + std::cout << "Processed " << p << " positions.\n"; + } + } + } + }); + Threads.wait_for_workers_finished(); + + std::cout << "Finished.\n"; + } + + void do_rescore(RescoreParams& params) + { + if (ends_with(params.input_filename, ".epd")) + { + do_rescore_epd(params); + } + else if (ends_with(params.input_filename, ".bin") || ends_with(params.input_filename, ".binpack")) + { + do_rescore_data(params); + } + else + { + std::cerr << "Invalid input file type.\n"; + } + } + + void rescore(std::istringstream& is) + { + RescoreParams params{}; + + while(true) + { + std::string token; + is >> token; + + if (token == "") + break; + + if (token == "depth") + is >> params.depth; + else if (token == "input_file") + is >> params.input_filename; + else if (token == "output_file") + is >> params.output_filename; + else if (token == "keep_moves") + is >> params.keep_moves; + else if (token == "research_count") + is >> params.research_count; + else + { + std::cout << "ERROR: Unknown option " << token << ". Exiting...\n"; + return; + } + } + + params.enforce_constraints(); + + std::cout << "Performing transform rescore with parameters:\n"; + std::cout << "depth : " << params.depth << '\n'; + std::cout << "input_file : " << params.input_filename << '\n'; + std::cout << "output_file : " << params.output_filename << '\n'; + std::cout << "keep_moves : " << params.keep_moves << '\n'; + std::cout << "research_count : " << params.research_count << '\n'; + std::cout << '\n'; + + do_rescore(params); + } + + void transform(std::istringstream& is) + { + const std::map subcommands = { + { "nudged_static", &nudged_static }, + { "rescore", &rescore } + }; + + Eval::NNUE::init(); + + std::string subcommand; + is >> subcommand; + + auto func = subcommands.find(subcommand); + if (func == subcommands.end()) + { + std::cout << "Invalid subcommand " << subcommand << ". Exiting...\n"; + return; + } + + func->second(is); + } + +} diff --git a/src/tools/transform.h b/src/tools/transform.h new file mode 100644 index 00000000..dcaf9a1d --- /dev/null +++ b/src/tools/transform.h @@ -0,0 +1,12 @@ +#ifndef _TRANSFORM_H_ +#define _TRANSFORM_H_ + +#include + +namespace Stockfish::Tools { + + void transform(std::istringstream& is); + +} + +#endif diff --git a/src/tools/validate_training_data.cpp b/src/tools/validate_training_data.cpp new file mode 100644 index 00000000..18cae456 --- /dev/null +++ b/src/tools/validate_training_data.cpp @@ -0,0 +1,122 @@ +#include "validate_training_data.h" + +#include "uci.h" +#include "misc.h" +#include "thread.h" +#include "position.h" +#include "tt.h" + +#include "extra/nnue_data_binpack_format.h" + +#include "nnue/evaluate_nnue.h" + +#include "syzygy/tbprobe.h" + +#include +#include +#include +#include +#include +#include // std::exp(),std::pow(),std::log() +#include // memcpy() +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +namespace sys = std::filesystem; + +namespace Stockfish::Tools +{ + static inline const std::string plain_extension = ".plain"; + static inline const std::string bin_extension = ".bin"; + static inline const std::string binpack_extension = ".binpack"; + + static bool file_exists(const std::string& name) + { + std::ifstream f(name); + return f.good(); + } + + static bool ends_with(const std::string& lhs, const std::string& end) + { + if (end.size() > lhs.size()) return false; + + return std::equal(end.rbegin(), end.rend(), lhs.rbegin()); + } + + static bool is_validation_of_type( + const std::string& input_path, + const std::string& expected_input_extension) + { + return ends_with(input_path, expected_input_extension); + } + + using ValidateFunctionType = void(std::string inputPath); + + static ValidateFunctionType* get_validate_function(const std::string& input_path) + { + if (is_validation_of_type(input_path, plain_extension)) + return binpack::validatePlain; + + if (is_validation_of_type(input_path, bin_extension)) + return binpack::validateBin; + + if (is_validation_of_type(input_path, binpack_extension)) + return binpack::validateBinpack; + + return nullptr; + } + + static void validate_training_data(const std::string& input_path) + { + if(!file_exists(input_path)) + { + std::cerr << "Input file does not exist.\n"; + return; + } + + auto func = get_validate_function(input_path); + if (func != nullptr) + { + func(input_path); + } + else + { + std::cerr << "Validation of files of this type is not supported.\n"; + } + } + + static void validate_training_data(const std::vector& args) + { + if (args.size() != 1) + { + std::cerr << "Invalid arguments.\n"; + std::cerr << "Usage: validate in_path\n"; + return; + } + + validate_training_data(args[0]); + } + + void validate_training_data(istringstream& is) + { + std::vector args; + + while (true) + { + std::string token = ""; + is >> token; + if (token == "") + break; + + args.push_back(token); + } + + validate_training_data(args); + } +} diff --git a/src/tools/validate_training_data.h b/src/tools/validate_training_data.h new file mode 100644 index 00000000..0c62ab50 --- /dev/null +++ b/src/tools/validate_training_data.h @@ -0,0 +1,12 @@ +#ifndef _VALIDATE_TRAINING_DATA_H_ +#define _VALIDATE_TRAINING_DATA_H_ + +#include +#include +#include + +namespace Stockfish::Tools { + void validate_training_data(std::istringstream& is); +} + +#endif diff --git a/src/tt.cpp b/src/tt.cpp index 1f495ca9..981fb4e2 100644 --- a/src/tt.cpp +++ b/src/tt.cpp @@ -30,11 +30,16 @@ namespace Stockfish { TranspositionTable TT; // Our global transposition table +bool TranspositionTable::enable_transposition_table = true; + /// TTEntry::save() populates the TTEntry with a new node's data, possibly /// overwriting an old position. Update is not atomic and can be racy. void TTEntry::save(Key k, Value v, bool pv, Bound b, Depth d, Move m, Value ev) { + if (!TranspositionTable::enable_transposition_table) { + return; + } // Preserve any existing move for the same position if (m || (uint16_t)k != key16) move16 = (uint16_t)m; @@ -119,6 +124,11 @@ void TranspositionTable::clear() { TTEntry* TranspositionTable::probe(const Key key, bool& found) const { + if (!enable_transposition_table) { + found = false; + return first_entry(0); + } + TTEntry* const tte = first_entry(key); const uint16_t key16 = (uint16_t)key; // Use the low 16 bits as key inside the cluster diff --git a/src/tt.h b/src/tt.h index d915d92e..b58cda63 100644 --- a/src/tt.h +++ b/src/tt.h @@ -92,6 +92,8 @@ public: return &table[mul_hi64(key, clusterCount)].entry[0]; } + static bool enable_transposition_table; + private: friend struct TTEntry; diff --git a/src/types.h b/src/types.h index 0bd4a1c4..bca02fc8 100644 --- a/src/types.h +++ b/src/types.h @@ -137,6 +137,8 @@ enum Color { WHITE, BLACK, COLOR_NB = 2 }; +constexpr Color Colors[2] = { WHITE, BLACK }; + enum CastlingRights { NO_CASTLING, WHITE_OO, @@ -449,6 +451,11 @@ constexpr Square to_sq(Move m) { return Square(m & 0x3F); } +// Return relative square when turning the board 180 degrees +constexpr Square rotate180(Square sq) { + return (Square)(sq ^ 0x3F); +} + constexpr int from_to(Move m) { return m & 0xFFF; } diff --git a/src/uci.cpp b/src/uci.cpp index b3738a4a..b2382299 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -22,15 +22,23 @@ #include #include +#include "nnue/evaluate_nnue.h" #include "evaluate.h" #include "movegen.h" #include "position.h" #include "search.h" +#include "syzygy/tbprobe.h" #include "thread.h" #include "timeman.h" #include "tt.h" #include "uci.h" -#include "syzygy/tbprobe.h" + +#include "tools/validate_training_data.h" +#include "tools/training_data_generator.h" +#include "tools/training_data_generator_nonpv.h" +#include "tools/convert.h" +#include "tools/transform.h" +#include "tools/stats.h" using namespace std; @@ -40,10 +48,6 @@ extern vector setup_bench(const Position&, istream&); namespace { - // FEN string of the initial position, normal chess - const char* StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"; - - // position() is called when engine receives the "position" UCI command. // The function sets up the position described in the given FEN string ("fen") // or the starting position ("startpos") and then makes the moves given in the @@ -96,7 +100,7 @@ namespace { // setoption() is called when engine receives the "setoption" UCI command. The // function updates the UCI option ("name") to the given value ("value"). - void setoption(istringstream& is) { + void setoption_from_stream(istringstream& is) { string token, name, value; @@ -110,10 +114,7 @@ namespace { while (is >> token) value += (value.empty() ? "" : " ") + token; - if (Options.count(name)) - Options[name] = value; - else - sync_cout << "No such option: " << name << sync_endl; + UCI::setoption(name, value); } @@ -182,7 +183,7 @@ namespace { else trace_eval(pos); } - else if (token == "setoption") setoption(is); + else if (token == "setoption") setoption_from_stream(is); else if (token == "position") position(pos, is, states); else if (token == "ucinewgame") { Search::clear(); elapsed = now(); } // Search::clear() may take some while } @@ -197,6 +198,14 @@ namespace { << "\nNodes/second : " << 1000 * nodes / elapsed << endl; } + void setoption(const std::string& name, const std::string& value) + { + if (Options.count(name)) + Options[name] = value; + else + sync_cout << "No such option: " << name << sync_endl; + } + // The win rate model returns the probability (per mille) of winning given an eval // and a game-ply. The model fits rather accurately the LTC fishtest statistics. int win_rate_model(Value v, int ply) { @@ -215,12 +224,49 @@ namespace { // Transform eval to centipawns with limited range double x = std::clamp(double(100 * v) / PawnValueEg, -2000.0, 2000.0); + // Transform eval to centipawns with limited range + double x = std::clamp(double(100 * v) / PawnValueEg, -1000.0, 1000.0); + // Return win rate in per mille (rounded to nearest) return int(0.5 + 1000 / (1 + std::exp((a - x) / b))); } } // namespace +// -------------------- +// Call qsearch(),search() directly for testing +// -------------------- + +void qsearch_cmd(Position& pos) +{ + cout << "qsearch : "; + auto pv = Search::qsearch(pos); + cout << "Value = " << pv.first << " , " << UCI::value(pv.first) << " , PV = "; + for (auto m : pv.second) + cout << UCI::move(m, false) << " "; + cout << endl; +} + +void search_cmd(Position& pos, istringstream& is) +{ + string token; + int depth = 1; + int multi_pv = (int)Options["MultiPV"]; + while (is >> token) + { + if (token == "depth") + is >> depth; + if (token == "multipv") + is >> multi_pv; + } + + cout << "search depth = " << depth << " , multi_pv = " << multi_pv << " : "; + auto pv = Search::search(pos, depth, multi_pv); + cout << "Value = " << pv.first << " , " << UCI::value(pv.first) << " , PV = "; + for (auto m : pv.second) + cout << UCI::move(m, false) << " "; + cout << endl; +} /// UCI::loop() waits for a command from stdin, parses it and calls the appropriate /// function. Also intercepts EOF from stdin to ensure gracefully exiting if the @@ -264,7 +310,7 @@ void UCI::loop(int argc, char* argv[]) { << "\n" << Options << "\nuciok" << sync_endl; - else if (token == "setoption") setoption(is); + else if (token == "setoption") setoption_from_stream(is); else if (token == "go") go(pos, is, states); else if (token == "position") position(pos, is, states); else if (token == "ucinewgame") Search::clear(); @@ -285,6 +331,25 @@ void UCI::loop(int argc, char* argv[]) { filename = f; Eval::NNUE::save_eval(filename); } + else if (token == "generate_training_data") Tools::generate_training_data(is); + else if (token == "generate_training_data") Tools::generate_training_data_nonpv(is); + else if (token == "convert") Tools::convert(is); + else if (token == "validate_training_data") Tools::validate_training_data(is); + else if (token == "convert_bin") Tools::convert_bin(is); + else if (token == "convert_plain") Tools::convert_plain(is); + else if (token == "convert_bin_from_pgn_extract") Tools::convert_bin_from_pgn_extract(is); + else if (token == "transform") Tools::transform(is); + else if (token == "gather_statistics") Tools::Stats::gather_statistics(is); + + // Command to call qsearch(),search() directly for testing + else if (token == "qsearch") qsearch_cmd(pos); + else if (token == "search") search_cmd(pos, is); + else if (token == "tasktest") + { + Threads.execute_with_workers([](auto& th) { + std::cout << th.id() << '\n'; + }); + } else if (!token.empty() && token[0] != '#') sync_cout << "Unknown command: " << cmd << sync_endl; diff --git a/src/uci.h b/src/uci.h index d3160109..9e4c6e2f 100644 --- a/src/uci.h +++ b/src/uci.h @@ -75,6 +75,7 @@ std::string move(Move m, bool chess960); std::string pv(const Position& pos, Depth depth, Value alpha, Value beta); std::string wdl(Value v, int ply); Move to_move(const Position& pos, std::string& str); +void setoption(const std::string& name, const std::string& value); } // namespace UCI diff --git a/src/ucioption.cpp b/src/ucioption.cpp index 07b3027d..5af78ec4 100644 --- a/src/ucioption.cpp +++ b/src/ucioption.cpp @@ -21,6 +21,7 @@ #include #include +#include "nnue/evaluate_nnue.h" #include "evaluate.h" #include "misc.h" #include "search.h" @@ -45,6 +46,12 @@ void on_threads(const Option& o) { Threads.set(size_t(o)); } void on_tb_path(const Option& o) { Tablebases::init(o); } void on_use_NNUE(const Option& ) { Eval::NNUE::init(); } void on_eval_file(const Option& ) { Eval::NNUE::init(); } +void on_prune_at_shallow_depth(const Option& o) { + Search::prune_at_shallow_depth = o; +} +void on_enable_transposition_table(const Option& o) { + TranspositionTable::enable_transposition_table = o; +} /// Our case insensitive less() function as required by UCI protocol bool CaseInsensitiveLess::operator() (const string& s1, const string& s2) const { @@ -79,8 +86,22 @@ void init(OptionsMap& o) { o["SyzygyProbeDepth"] << Option(1, 1, 100); o["Syzygy50MoveRule"] << Option(true); o["SyzygyProbeLimit"] << Option(7, 0, 7); - o["Use NNUE"] << Option(true, on_use_NNUE); + o["Use NNUE"] << Option("true var true var false var pure", "true", on_use_NNUE); o["EvalFile"] << Option(EvalFileDefaultName, on_eval_file); + // When the evaluation function is loaded at the ucinewgame timing, it is necessary to convert the new evaluation function. + // I want to hit the test eval convert command, but there is no new evaluation function + // It ends abnormally before executing this command. + // Therefore, with this hidden option, you can suppress the loading of the evaluation function when ucinewgame, + // Hit the test eval convert command. + o["SkipLoadingEval"] << Option(false); + // When learning the evaluation function, you can change the folder to save the evaluation function. + // Evalsave by default. This folder shall be prepared in advance. + // Automatically create a folder under this folder like "0/", "1/", ... and save the evaluation function file there. + o["EvalSaveDir"] << Option("evalsave"); + // Prune at shallow depth on PV nodes. False is recommended when using fixed depth search. + o["PruneAtShallowDepth"] << Option(true, on_prune_at_shallow_depth); + // Enable transposition table. + o["EnableTranspositionTable"] << Option(true, on_enable_transposition_table); } @@ -134,7 +155,7 @@ Option::operator double() const { } Option::operator std::string() const { - assert(type == "string"); + assert(type == "check" || type == "spin" || type == "combo" || type == "button" || type == "string"); return currentValue; } diff --git a/tests/instrumented.sh b/tests/instrumented.sh index e9455eab..62e31d46 100755 --- a/tests/instrumented.sh +++ b/tests/instrumented.sh @@ -16,6 +16,9 @@ case $1 in exeprefix='valgrind --error-exitcode=42 --errors-for-leak-kinds=all --leak-check=full' postfix='1>/dev/null' threads="1" + bench_depth=5 + go_depth=10 + tt_size=16 ;; --valgrind-thread) echo "valgrind-thread testing started" @@ -23,6 +26,9 @@ case $1 in exeprefix='valgrind --fair-sched=try --error-exitcode=42' postfix='1>/dev/null' threads="2" + bench_depth=5 + go_depth=10 + tt_size=16 ;; --sanitizer-undefined) echo "sanitizer-undefined testing started" @@ -30,6 +36,9 @@ case $1 in exeprefix='' postfix='2>&1 | grep -A50 "runtime error:"' threads="1" + bench_depth=8 + go_depth=20 + tt_size=128 ;; --sanitizer-thread) echo "sanitizer-thread testing started" @@ -37,6 +46,9 @@ case $1 in exeprefix='' postfix='2>&1 | grep -A50 "WARNING: ThreadSanitizer:"' threads="2" + bench_depth=8 + go_depth=20 + tt_size=128 cat << EOF > tsan.supp race:Stockfish::TTEntry::move @@ -70,7 +82,7 @@ for args in "eval" \ "go depth 10" \ "go movetime 1000" \ "go wtime 8000 btime 8000 winc 500 binc 500" \ - "bench 128 $threads 8 default depth" + "bench $tt_size $threads $bench_depth default depth" do echo "$prefix $exeprefix ./stockfish $args $postfix" @@ -121,7 +133,7 @@ cat << EOF > syzygy.exp send "uci\n" send "setoption name SyzygyPath value ../tests/syzygy/\n" expect "info string Found 35 tablebases" {} timeout {exit 1} - send "bench 128 1 8 default depth\n" + send "bench $tt_size 1 $bench_depth default depth\n" send "quit\n" expect eof @@ -130,7 +142,57 @@ cat << EOF > syzygy.exp exit \$value EOF -for exp in game.exp syzygy.exp +# generate_training_data testing 01 +cat << EOF > data_generation01.exp + set timeout 240 + spawn $exeprefix ./stockfish + + send "uci\n" + expect "uciok" + + send "setoption name Threads value $threads\n" + send "setoption name Use NNUE value false\n" + send "isready\n" + send "generate_training_data depth 3 count 100 keep_draws 1 eval_limit 32000 output_file_name training_data/training_data.bin output_format bin\n" + expect "INFO: Gensfen finished." + send "convert_plain targetfile training_data/training_data.bin output_file_name training_data.txt\n" + expect "all done" + send "generate_training_data depth 3 count 100 keep_draws 1 eval_limit 32000 output_file_name training_data/training_data.binpack output_format binpack\n" + expect "INFO: Gensfen finished." + + send "quit\n" + expect eof + + # return error code of the spawned program, useful for valgrind + lassign [wait] pid spawnid os_error_flag value + exit \$value +EOF + +# generate_training_data testing 02 +cat << EOF > data_generation02.exp + set timeout 240 + spawn $exeprefix ./stockfish + + send "uci\n" + expect "uciok" + + send "setoption name Threads value $threads\n" + send "setoption name Use NNUE value true\n" + send "isready\n" + send "generate_training_data depth 4 count 50 keep_draws 1 eval_limit 32000 output_file_name validation_data/validation_data.bin output_format bin\n" + expect "INFO: Gensfen finished." + send "generate_training_data depth 4 count 50 keep_draws 1 eval_limit 32000 output_file_name validation_data/validation_data.binpack output_format binpack\n" + expect "INFO: Gensfen finished." + + send "quit\n" + expect eof + + # return error code of the spawned program, useful for valgrind + lassign [wait] pid spawnid os_error_flag value + exit \$value +EOF + +for exp in game.exp data_generation01.exe data_generation02.exp do echo "$prefix expect $exp $postfix"