#ifndef _NNUE_TRAINER_FEATURE_TRANSFORMER_H_ #define _NNUE_TRAINER_FEATURE_TRANSFORMER_H_ #include "trainer.h" #include "extra/stockfish_blas.h" #include "features/all_factorizers.h" #include "learn/learn.h" #include "nnue/nnue_feature_transformer.h" #include "thread.h" #include #include #include #include #include // Specialization for feature transformer of learning class template of NNUE evaluation function namespace Eval::NNUE { // Learning: Input feature converter template <> class Trainer { private: // Type of layer to learn using LayerType = FeatureTransformer; public: template friend struct AlignedDeleter; template friend std::shared_ptr make_aligned_shared_ptr(ArgumentTypes&&... arguments); // factory function static std::shared_ptr create(LayerType* target_layer) { return make_aligned_shared_ptr(target_layer); } // Set options such as hyperparameters void send_message(Message* message) { if (receive_message("momentum", message)) { momentum_ = static_cast(std::stod(message->value)); } if (receive_message("learning_rate_scale", message)) { learning_rate_scale_ = static_cast(std::stod(message->value)); } if (receive_message("reset", message)) { dequantize_parameters(); } if (receive_message("quantize_parameters", message)) { quantize_parameters(); } if (receive_message("clear_unobserved_feature_weights", message)) { clear_unobserved_feature_weights(); } if (receive_message("check_health", message)) { check_health(); } } // Initialize the parameters with random numbers template void initialize(RNG& rng) { std::fill(std::begin(weights_), std::end(weights_), +kZero); const double kSigma = 0.1 / std::sqrt(RawFeatures::kMaxActiveDimensions); auto distribution = std::normal_distribution(0.0, kSigma); for (IndexType i = 0; i < kHalfDimensions * RawFeatures::kDimensions; ++i) { const auto weight = static_cast(distribution(rng)); weights_[i] = weight; } for (IndexType i = 0; i < kHalfDimensions; ++i) { biases_[i] = static_cast(0.5); } quantize_parameters(); } const LearnFloatType* step_start(ThreadPool& thread_pool, const std::vector& combined_batch) { if (output_.size() < kOutputDimensions * combined_batch.size()) { output_.resize(kOutputDimensions * combined_batch.size()); gradients_.resize(kOutputDimensions * combined_batch.size()); } if (thread_stat_states_.size() < thread_pool.size()) { thread_stat_states_.resize(thread_pool.size()); } if (thread_bias_states_.size() < thread_pool.size()) { thread_bias_states_.resize(thread_pool.size()); } batch_ = &combined_batch; auto& main_thread_bias_state = thread_bias_states_[0]; #if defined(USE_BLAS) cblas_sscal( kHalfDimensions, momentum_, main_thread_bias_state.biases_diff_, 1 ); #else Blas::sscal( kHalfDimensions, momentum_, main_thread_bias_state.biases_diff_, 1 ); #endif for (IndexType i = 1; i < thread_bias_states_.size(); ++i) thread_bias_states_[i].reset(); return output_.data(); } // forward propagation void propagate(Thread& th, uint64_t offset, uint64_t count) { auto& thread_stat_state = thread_stat_states_[th.thread_idx()]; for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType c = 0; c < 2; ++c) { const IndexType output_offset = batch_offset + kHalfDimensions * c; #if defined(USE_BLAS) cblas_scopy( kHalfDimensions, biases_, 1, &output_[output_offset], 1 ); for (const auto& feature : (*batch_)[b].training_features[c]) { const IndexType weights_offset = kHalfDimensions * feature.get_index(); cblas_saxpy( kHalfDimensions, (float)feature.get_count(), &weights_[weights_offset], 1, &output_[output_offset], 1 ); } #else Blas::scopy( kHalfDimensions, biases_, 1, &output_[output_offset], 1 ); for (const auto& feature : (*batch_)[b].training_features[c]) { const IndexType weights_offset = kHalfDimensions * feature.get_index(); Blas::saxpy( kHalfDimensions, (float)feature.get_count(), &weights_[weights_offset], 1, &output_[output_offset], 1 ); } #endif } } #if defined (USE_SSE2) { static_assert(kHalfDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time"); auto m128_hmin_ps = [](__m128 x3210) { __m128 x0032 = _mm_shuffle_ps(x3210, x3210, _MM_SHUFFLE(0, 0, 3, 2)); __m128 min_x_x_13_20 = _mm_min_ps(x3210, x0032); // a = [ # , # , min(x[1], x[3]) , min(x[2], x[0]) ] __m128 min_x_x_20_13 = _mm_shuffle_ps(min_x_x_13_20, min_x_x_13_20, _MM_SHUFFLE(0, 0, 0, 1)); return _mm_cvtss_f32(_mm_min_ps(min_x_x_13_20, min_x_x_20_13)); }; auto m128_hmax_ps = [](__m128 x3210) { __m128 x0032 = _mm_shuffle_ps(x3210, x3210, _MM_SHUFFLE(0, 0, 3, 2)); __m128 max_x_x_13_20 = _mm_max_ps(x3210, x0032); // a = [ # , # , max(x[1], x[3]) , max(x[2], x[0]) ] __m128 max_x_x_20_13 = _mm_shuffle_ps(max_x_x_13_20, max_x_x_13_20, _MM_SHUFFLE(0, 0, 0, 1)); return _mm_cvtss_f32(_mm_max_ps(max_x_x_13_20, max_x_x_20_13)); }; const __m128 kZero4 = _mm_set1_ps(+kZero); const __m128 kOne4 = _mm_set1_ps(+kOne); __m128 min_pre_activation0 = _mm_set1_ps(thread_stat_state.min_pre_activation_); __m128 min_pre_activation1 = _mm_set1_ps(thread_stat_state.min_pre_activation_); __m128 max_pre_activation0 = _mm_set1_ps(thread_stat_state.max_pre_activation_); __m128 max_pre_activation1 = _mm_set1_ps(thread_stat_state.max_pre_activation_); for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType i = 0; i < kOutputDimensions; i += 16) { __m128 out0 = _mm_loadu_ps(&output_[batch_offset + i + 0]); __m128 out1 = _mm_loadu_ps(&output_[batch_offset + i + 4]); __m128 out2 = _mm_loadu_ps(&output_[batch_offset + i + 8]); __m128 out3 = _mm_loadu_ps(&output_[batch_offset + i + 12]); __m128 min01 = _mm_min_ps(out0, out1); __m128 min23 = _mm_min_ps(out2, out3); __m128 max01 = _mm_max_ps(out0, out1); __m128 max23 = _mm_max_ps(out2, out3); min_pre_activation0 = _mm_min_ps(min_pre_activation0, min01); min_pre_activation1 = _mm_min_ps(min_pre_activation1, min23); max_pre_activation0 = _mm_max_ps(max_pre_activation0, max01); max_pre_activation1 = _mm_max_ps(max_pre_activation1, max23); out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0)); out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1)); out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2)); out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3)); _mm_storeu_ps(&output_[batch_offset + i + 0], out0); _mm_storeu_ps(&output_[batch_offset + i + 4], out1); _mm_storeu_ps(&output_[batch_offset + i + 8], out2); _mm_storeu_ps(&output_[batch_offset + i + 12], out3); } } thread_stat_state.min_pre_activation_ = m128_hmin_ps(_mm_min_ps(min_pre_activation0, min_pre_activation1)); thread_stat_state.max_pre_activation_ = m128_hmax_ps(_mm_max_ps(max_pre_activation0, max_pre_activation1)); for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType half = 0; half < 2; ++half) { const IndexType half_offset = batch_offset + half * kHalfDimensions; for (IndexType i = 0; i < kHalfDimensions; i += 16) { const __m128 out0 = _mm_loadu_ps(&output_[i + 0 + half_offset]); const __m128 out1 = _mm_loadu_ps(&output_[i + 4 + half_offset]); const __m128 out2 = _mm_loadu_ps(&output_[i + 8 + half_offset]); const __m128 out3 = _mm_loadu_ps(&output_[i + 12 + half_offset]); __m128 minact0 = _mm_loadu_ps(&thread_stat_state.min_activations_[i + 0]); __m128 minact1 = _mm_loadu_ps(&thread_stat_state.min_activations_[i + 4]); __m128 minact2 = _mm_loadu_ps(&thread_stat_state.min_activations_[i + 8]); __m128 minact3 = _mm_loadu_ps(&thread_stat_state.min_activations_[i + 12]); __m128 maxact0 = _mm_loadu_ps(&thread_stat_state.max_activations_[i + 0]); __m128 maxact1 = _mm_loadu_ps(&thread_stat_state.max_activations_[i + 4]); __m128 maxact2 = _mm_loadu_ps(&thread_stat_state.max_activations_[i + 8]); __m128 maxact3 = _mm_loadu_ps(&thread_stat_state.max_activations_[i + 12]); minact0 = _mm_min_ps(out0, minact0); minact1 = _mm_min_ps(out1, minact1); minact2 = _mm_min_ps(out2, minact2); minact3 = _mm_min_ps(out3, minact3); maxact0 = _mm_max_ps(out0, maxact0); maxact1 = _mm_max_ps(out1, maxact1); maxact2 = _mm_max_ps(out2, maxact2); maxact3 = _mm_max_ps(out3, maxact3); _mm_storeu_ps(&thread_stat_state.min_activations_[i + 0], minact0); _mm_storeu_ps(&thread_stat_state.min_activations_[i + 4], minact1); _mm_storeu_ps(&thread_stat_state.min_activations_[i + 8], minact2); _mm_storeu_ps(&thread_stat_state.min_activations_[i + 12], minact3); _mm_storeu_ps(&thread_stat_state.max_activations_[i + 0], maxact0); _mm_storeu_ps(&thread_stat_state.max_activations_[i + 4], maxact1); _mm_storeu_ps(&thread_stat_state.max_activations_[i + 8], maxact2); _mm_storeu_ps(&thread_stat_state.max_activations_[i + 12], maxact3); } } } } #else // clipped ReLU for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType i = 0; i < kOutputDimensions; ++i) { const IndexType index = batch_offset + i; thread_stat_state.min_pre_activation_ = std::min(thread_stat_state.min_pre_activation_, output_[index]); thread_stat_state.max_pre_activation_ = std::max(thread_stat_state.max_pre_activation_, output_[index]); output_[index] = std::max(+kZero, std::min(+kOne, output_[index])); const IndexType t = i % kHalfDimensions; thread_stat_state.min_activations_[t] = std::min(thread_stat_state.min_activations_[t], output_[index]); thread_stat_state.max_activations_[t] = std::max(thread_stat_state.max_activations_[t], output_[index]); } } #endif } // backpropagation void backpropagate(Thread& th, const LearnFloatType* gradients, uint64_t offset, uint64_t count) { auto& thread_stat_state = thread_stat_states_[th.thread_idx()]; auto& thread_bias_state = thread_bias_states_[th.thread_idx()]; #if defined (USE_SSE2) { static_assert(kHalfDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time"); const __m128 kZero4 = _mm_set1_ps(+kZero); const __m128 kOne4 = _mm_set1_ps(+kOne); for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType i = 0; i < kOutputDimensions; i += 16) { __m128 out0 = _mm_loadu_ps(&output_[batch_offset + i + 0]); __m128 out1 = _mm_loadu_ps(&output_[batch_offset + i + 4]); __m128 out2 = _mm_loadu_ps(&output_[batch_offset + i + 8]); __m128 out3 = _mm_loadu_ps(&output_[batch_offset + i + 12]); __m128 clipped0 = _mm_or_ps(_mm_cmple_ps(out0, kZero4), _mm_cmpge_ps(out0, kOne4)); __m128 clipped1 = _mm_or_ps(_mm_cmple_ps(out1, kZero4), _mm_cmpge_ps(out1, kOne4)); __m128 clipped2 = _mm_or_ps(_mm_cmple_ps(out2, kZero4), _mm_cmpge_ps(out2, kOne4)); __m128 clipped3 = _mm_or_ps(_mm_cmple_ps(out3, kZero4), _mm_cmpge_ps(out3, kOne4)); __m128 grad0 = _mm_loadu_ps(&gradients[batch_offset + i + 0]); __m128 grad1 = _mm_loadu_ps(&gradients[batch_offset + i + 4]); __m128 grad2 = _mm_loadu_ps(&gradients[batch_offset + i + 8]); __m128 grad3 = _mm_loadu_ps(&gradients[batch_offset + i + 12]); grad0 = _mm_andnot_ps(clipped0, grad0); grad1 = _mm_andnot_ps(clipped1, grad1); grad2 = _mm_andnot_ps(clipped2, grad2); grad3 = _mm_andnot_ps(clipped3, grad3); _mm_storeu_ps(&gradients_[batch_offset + i + 0], grad0); _mm_storeu_ps(&gradients_[batch_offset + i + 4], grad1); _mm_storeu_ps(&gradients_[batch_offset + i + 8], grad2); _mm_storeu_ps(&gradients_[batch_offset + i + 12], grad3); const int clipped_mask = (_mm_movemask_ps(clipped0) << 0) | (_mm_movemask_ps(clipped1) << 4) | (_mm_movemask_ps(clipped2) << 8) | (_mm_movemask_ps(clipped3) << 12); thread_stat_state.num_clipped_ += popcount(clipped_mask); } } } #else for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType i = 0; i < kOutputDimensions; ++i) { const IndexType index = batch_offset + i; const bool clipped = (output_[index] <= kZero) | (output_[index] >= kOne); gradients_[index] = gradients[index] * !clipped; thread_stat_state.num_clipped_ += clipped; } } #endif thread_stat_state.num_total_ += count * kOutputDimensions; #if defined(USE_BLAS) for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType c = 0; c < 2; ++c) { const IndexType output_offset = batch_offset + kHalfDimensions * c; cblas_saxpy( kHalfDimensions, 1.0, &gradients_[output_offset], 1, thread_bias_state.biases_diff_, 1 ); } } #else for (IndexType b = offset; b < offset + count; ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType c = 0; c < 2; ++c) { const IndexType output_offset = batch_offset + kHalfDimensions * c; Blas::saxpy( kHalfDimensions, 1.0, &gradients_[output_offset], 1, thread_bias_state.biases_diff_, 1 ); } } #endif } void reduce_thread_stat_state() { for (IndexType i = 1; i < thread_stat_states_.size(); ++i) { thread_stat_states_[0] += thread_stat_states_[i]; } } void reduce_thread_bias_state() { for (IndexType i = 1; i < thread_bias_states_.size(); ++i) { thread_bias_states_[0] += thread_bias_states_[i]; } } void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate) { const LearnFloatType local_learning_rate = learning_rate * learning_rate_scale_; // Since the weight matrix updates only the columns corresponding to the features that appeared in the input, // Correct the learning rate and adjust the scale without using momentum const LearnFloatType effective_learning_rate = static_cast(local_learning_rate / (1.0 - momentum_)); reduce_thread_bias_state(); auto& main_thread_state = thread_bias_states_[0]; #if defined(USE_BLAS) cblas_saxpy( kHalfDimensions, -local_learning_rate, main_thread_state.biases_diff_, 1, biases_, 1 ); #else Blas::saxpy( kHalfDimensions, -local_learning_rate, main_thread_state.biases_diff_, 1, biases_, 1 ); #endif thread_pool.execute_with_workers( [&, num_threads = thread_pool.size()](Thread& th) { const auto thread_index = th.thread_idx(); for (IndexType b = 0; b < batch_->size(); ++b) { const IndexType batch_offset = kOutputDimensions * b; for (IndexType c = 0; c < 2; ++c) { const IndexType output_offset = batch_offset + kHalfDimensions * c; for (const auto& feature : (*batch_)[b].training_features[c]) { const IndexType feature_index = feature.get_index(); // We assign each bucket a continuous range of bits at least // of cache line size to prevent false sharing. // For HalfKP this is enough to saturate about 80 threads. const IndexType thread_bucket = (feature_index / BitsetType::best_concurrent_access_stride) % num_threads; if (thread_bucket != thread_index) continue; // This operation can be performed safely because // each thread accesses a different memory location // (even a different cache line) observed_features.set(feature_index); const IndexType weights_offset = kHalfDimensions * feature_index; const auto scale = static_cast( effective_learning_rate / feature.get_count()); #if defined (USE_BLAS) cblas_saxpy( kHalfDimensions, -scale, &gradients_[output_offset], 1, &weights_[weights_offset], 1 ); #else Blas::saxpy( kHalfDimensions, -scale, &gradients_[output_offset], 1, &weights_[weights_offset], 1 ); #endif } } } } ); thread_pool.wait_for_workers_finished(); } private: // constructor Trainer(LayerType* target_layer) : batch_(nullptr), target_layer_(target_layer), biases_(), weights_(), momentum_(0.2), learning_rate_scale_(1.0) { dequantize_parameters(); } // Weight saturation and parameterization void quantize_parameters() { for (IndexType i = 0; i < kHalfDimensions; ++i) { target_layer_->biases_[i] = round(biases_[i] * kBiasScale); } std::vector training_features; Threads.for_each_index_with_workers( 0, RawFeatures::kDimensions, [this, training_features](Thread&, int j) mutable { training_features.clear(); Features::Factorizer::append_training_features( j, &training_features); for (IndexType i = 0; i < kHalfDimensions; ++i) { double sum = 0.0; for (const auto& feature : training_features) { sum += weights_[kHalfDimensions * feature.get_index() + i]; } target_layer_->weights_[kHalfDimensions * j + i] = round(sum * kWeightScale); } } ); Threads.wait_for_workers_finished(); } void reset_stats() { for (auto& state : thread_stat_states_) state.reset(); } // read parameterized integer void dequantize_parameters() { for (IndexType i = 0; i < kHalfDimensions; ++i) { biases_[i] = static_cast( target_layer_->biases_[i] / kBiasScale); } std::fill(std::begin(weights_), std::end(weights_), +kZero); for (IndexType i = 0; i < kHalfDimensions * RawFeatures::kDimensions; ++i) { weights_[i] = static_cast( target_layer_->weights_[i] / kWeightScale); } reset_stats(); for (auto& state : thread_bias_states_) state.reset(); } // Set the weight corresponding to the feature that does not appear in the learning data to 0 void clear_unobserved_feature_weights() { for (IndexType i = 0; i < kInputDimensions; ++i) { if (!observed_features.test(i)) { std::fill(std::begin(weights_) + kHalfDimensions * i, std::begin(weights_) + kHalfDimensions * (i + 1), +kZero); } } quantize_parameters(); } // Check if there are any problems with learning void check_health() { constexpr LearnFloatType kPreActivationLimit = std::numeric_limits::max() / kWeightScale; reduce_thread_stat_state(); auto& main_thread_state = thread_stat_states_[0]; const auto largest_min_activation = *std::max_element( std::begin(main_thread_state.min_activations_), std::end(main_thread_state.min_activations_)); const auto smallest_max_activation = *std::min_element( std::begin(main_thread_state.max_activations_), std::end(main_thread_state.max_activations_)); double abs_bias_sum = 0.0; double abs_weight_sum = 0.0; for(auto b : biases_) abs_bias_sum += std::abs(b); for(auto w : weights_) abs_weight_sum += std::abs(w); auto out = sync_region_cout.new_region(); out << "INFO (check_health):" << " layer " << LayerType::kLayerIndex << " - " << LayerType::get_name() << std::endl; out << " - observed " << observed_features.count() << " (out of " << kInputDimensions << ") features" << std::endl; out << " - (min, max) of pre-activations = " << main_thread_state.min_pre_activation_ << ", " << main_thread_state.max_pre_activation_ << " (limit = " << kPreActivationLimit << ")" << std::endl; out << " - largest min activation = " << largest_min_activation << " , smallest max activation = " << smallest_max_activation << std::endl; out << " - avg_abs_bias = " << abs_bias_sum / std::size(biases_) << std::endl; out << " - avg_abs_weight = " << abs_weight_sum / std::size(weights_) << std::endl; out << " - clipped " << static_cast(main_thread_state.num_clipped_) / main_thread_state.num_total_ * 100.0 << "% of outputs" << std::endl; out.unlock(); reset_stats(); } // number of input/output dimensions static constexpr IndexType kInputDimensions = Features::Factorizer::get_dimensions(); static constexpr IndexType kOutputDimensions = LayerType::kOutputDimensions; static constexpr IndexType kHalfDimensions = LayerType::kHalfDimensions; // Coefficient used for parameterization static constexpr LearnFloatType kActivationScale = std::numeric_limits::max(); static constexpr LearnFloatType kBiasScale = kActivationScale; static constexpr LearnFloatType kWeightScale = kActivationScale; // LearnFloatType constant static constexpr LearnFloatType kZero = static_cast(0.0); static constexpr LearnFloatType kOne = static_cast(1.0); // mini batch const std::vector* batch_; // layer to learn LayerType* const target_layer_; IndexType num_total_; // parameter alignas(kCacheLineSize) LearnFloatType biases_[kHalfDimensions]; alignas(kCacheLineSize) LearnFloatType weights_[kHalfDimensions * kInputDimensions]; // Buffer used for updating parameters std::vector> gradients_; // Forward propagation buffer std::vector> output_; // Features that appeared in the training data using BitsetType = LargeBitset; BitsetType observed_features; // hyper parameter LearnFloatType momentum_; LearnFloatType learning_rate_scale_; struct alignas(kCacheLineSize) ThreadStatState { alignas(kCacheLineSize) LearnFloatType min_activations_[kHalfDimensions]; alignas(kCacheLineSize) LearnFloatType max_activations_[kHalfDimensions]; LearnFloatType min_pre_activation_; LearnFloatType max_pre_activation_; IndexType num_clipped_; IndexType num_total_; ThreadStatState() { reset(); } ThreadStatState& operator+=(const ThreadStatState& other) { for (IndexType i = 0; i < kHalfDimensions; ++i) { min_activations_[i] = std::min(min_activations_[i], other.min_activations_[i]); } for (IndexType i = 0; i < kHalfDimensions; ++i) { max_activations_[i] = std::max(max_activations_[i], other.max_activations_[i]); } min_pre_activation_ = std::min(min_pre_activation_, other.min_pre_activation_); max_pre_activation_ = std::max(max_pre_activation_, other.max_pre_activation_); num_clipped_ += other.num_clipped_; num_total_ += other.num_total_; return *this; } void reset() { std::fill(std::begin(min_activations_), std::end(min_activations_), std::numeric_limits::max()); std::fill(std::begin(max_activations_), std::end(max_activations_), std::numeric_limits::lowest()); min_pre_activation_ = std::numeric_limits::max(); max_pre_activation_ = std::numeric_limits::lowest(); num_clipped_ = 0; num_total_ = 0; } }; struct alignas(kCacheLineSize) ThreadBiasState { alignas(kCacheLineSize) LearnFloatType biases_diff_[kHalfDimensions]; ThreadBiasState() { reset(); } ThreadBiasState& operator+=(const ThreadBiasState& other) { for (IndexType i = 0; i < kHalfDimensions; ++i) { biases_diff_[i] += other.biases_diff_[i]; } return *this; } void reset() { std::fill(std::begin(biases_diff_), std::end(biases_diff_), 0.0f); } }; std::vector> thread_stat_states_; std::vector> thread_bias_states_; }; } // namespace Eval::NNUE #endif