Prepare trainer affine transform.

This commit is contained in:
Tomasz Sobczyk
2020-11-22 19:23:30 +01:00
committed by nodchip
parent 4ea8572b6d
commit 0d4b803b08
+142 -73
View File
@@ -91,19 +91,52 @@ namespace Eval::NNUE {
quantize_parameters(); quantize_parameters();
} }
// forward propagation const LearnFloatType* step_start(ThreadPool& thread_pool, const std::vector<Example>& combined_batch)
const LearnFloatType* propagate(ThreadPool& thread_pool, const std::vector<Example>& batch) { {
if (output_.size() < kOutputDimensions * batch.size()) { if (output_.size() < kOutputDimensions * combined_batch.size()) {
output_.resize(kOutputDimensions * batch.size()); output_.resize(kOutputDimensions * combined_batch.size());
gradients_.resize(kInputDimensions * batch.size()); gradients_.resize(kInputDimensions * combined_batch.size());
} }
batch_size_ = static_cast<IndexType>(batch.size()); if (thread_states_.size() < thread_pool.size())
batch_input_ = previous_layer_trainer_->propagate(thread_pool, batch); {
thread_states_.resize(thread_pool.size());
}
combined_batch_size_ = static_cast<IndexType>(combined_batch.size());
combined_batch_input_ = previous_layer_trainer_->step_start(thread_pool, combined_batch);
auto& main_thread_state = thread_states_[0];
#if defined(USE_BLAS) #if defined(USE_BLAS)
for (IndexType b = 0; b < batch_size_; ++b) { // update
cblas_sscal(
kOutputDimensions, momentum_, main_thread_state.biases_diff_, 1
);
#else
Blas::sscal(
kOutputDimensions, momentum_, main_thread_state.biases_diff_, 1
);
#endif
for (IndexType i = 1; i < thread_states_.size(); ++i)
thread_states_[i].reset_biases();
return output_.data();
}
// forward propagation
void propagate(Thread& th, const uint64_t offset, const uint64_t count) {
previous_layer_trainer_->propagate(th, offset, count);
#if defined(USE_BLAS)
for (IndexType b = offset; b < offset + count; ++b) {
const IndexType batch_offset = kOutputDimensions * b; const IndexType batch_offset = kOutputDimensions * b;
cblas_scopy( cblas_scopy(
kOutputDimensions, biases_, 1, &output_[batch_offset], 1 kOutputDimensions, biases_, 1, &output_[batch_offset], 1
@@ -112,149 +145,151 @@ namespace Eval::NNUE {
cblas_sgemm( cblas_sgemm(
CblasColMajor, CblasTrans, CblasNoTrans, CblasColMajor, CblasTrans, CblasNoTrans,
kOutputDimensions, batch_size_, kInputDimensions, kOutputDimensions, count, kInputDimensions,
1.0, 1.0,
weights_, kInputDimensions, weights_, kInputDimensions,
batch_input_, kInputDimensions, combined_batch_input_ + offset * kInputDimensions, kInputDimensions,
1.0, 1.0,
&output_[0], kOutputDimensions &output_[offset * kOutputDimensions], kOutputDimensions
); );
#else #else
for (IndexType b = 0; b < batch_size_; ++b) { for (IndexType b = offset; b < offset + count; ++b) {
const IndexType batch_offset = kOutputDimensions * b; const IndexType batch_offset = kOutputDimensions * b;
Blas::scopy( Blas::scopy(
thread_pool,
kOutputDimensions, biases_, 1, &output_[batch_offset], 1 kOutputDimensions, biases_, 1, &output_[batch_offset], 1
); );
} }
Blas::sgemm( Blas::sgemm(
thread_pool,
Blas::MatrixLayout::ColMajor, Blas::MatrixTranspose::Trans, Blas::MatrixTranspose::NoTrans, Blas::MatrixLayout::ColMajor, Blas::MatrixTranspose::Trans, Blas::MatrixTranspose::NoTrans,
kOutputDimensions, batch_size_, kInputDimensions, kOutputDimensions, count, kInputDimensions,
1.0, 1.0,
weights_, kInputDimensions, weights_, kInputDimensions,
batch_input_, kInputDimensions, combined_batch_input_ + offset * kInputDimensions, kInputDimensions,
1.0, 1.0,
&output_[0], kOutputDimensions &output_[offset * kOutputDimensions], kOutputDimensions
); );
#endif #endif
return output_.data();
} }
// backpropagation // backpropagation
void backpropagate(ThreadPool& thread_pool, void backpropagate(Thread& th,
const LearnFloatType* gradients, const LearnFloatType* gradients,
LearnFloatType learning_rate) { uint64_t offset,
uint64_t count) {
const LearnFloatType local_learning_rate =
learning_rate * learning_rate_scale_;
auto& thread_state = thread_states_[th.thread_idx()];
const auto momentum = th.thread_idx() == 0 ? momentum_ : 0.0f;
#if defined(USE_BLAS) #if defined(USE_BLAS)
cblas_sgemm( cblas_sgemm(
CblasColMajor, CblasNoTrans, CblasNoTrans, CblasColMajor, CblasNoTrans, CblasNoTrans,
kInputDimensions, batch_size_, kOutputDimensions, kInputDimensions, count, kOutputDimensions,
1.0, 1.0,
weights_, kInputDimensions, weights_, kInputDimensions,
gradients, kOutputDimensions, gradients + offset * kOutputDimensions, kOutputDimensions,
0.0, 0.0,
&gradients_[0], kInputDimensions &gradients_[offset * kInputDimensions], kInputDimensions
); );
// update for (IndexType b = offset; b < offset + count; ++b) {
cblas_sscal(
kOutputDimensions, momentum_, biases_diff_, 1
);
for (IndexType b = 0; b < batch_size_; ++b) {
const IndexType batch_offset = kOutputDimensions * b; const IndexType batch_offset = kOutputDimensions * b;
cblas_saxpy( cblas_saxpy(
kOutputDimensions, 1.0, kOutputDimensions, 1.0,
&gradients[batch_offset], 1, biases_diff_, 1 &gradients[batch_offset], 1, thread_state.biases_diff_, 1
); );
} }
cblas_sgemm( cblas_sgemm(
CblasRowMajor, CblasTrans, CblasNoTrans, CblasRowMajor, CblasTrans, CblasNoTrans,
kOutputDimensions, kInputDimensions, batch_size_, kOutputDimensions, kInputDimensions, count,
1.0, 1.0,
gradients, kOutputDimensions, gradients + offset * kOutputDimensions, kOutputDimensions,
batch_input_, kInputDimensions, combined_batch_input_ + offset * kInputDimensions, kInputDimensions,
momentum_, momentum,
weights_diff_, kInputDimensions thread_state.weights_diff_, kInputDimensions
); );
#else #else
// backpropagate // backpropagate
Blas::sgemm( Blas::sgemm(
thread_pool,
Blas::MatrixLayout::ColMajor, Blas::MatrixTranspose::NoTrans, Blas::MatrixTranspose::NoTrans, Blas::MatrixLayout::ColMajor, Blas::MatrixTranspose::NoTrans, Blas::MatrixTranspose::NoTrans,
kInputDimensions, batch_size_, kOutputDimensions, kInputDimensions, count, kOutputDimensions,
1.0, 1.0,
weights_, kInputDimensions, weights_, kInputDimensions,
gradients, kOutputDimensions, gradients + offset * kOutputDimensions, kOutputDimensions,
0.0, 0.0,
&gradients_[0], kInputDimensions &gradients_[offset * kInputDimensions], kInputDimensions
); );
for (IndexType b = offset; b < offset + count; ++b) {
Blas::sscal(
thread_pool,
kOutputDimensions, momentum_, biases_diff_, 1
);
for (IndexType b = 0; b < batch_size_; ++b) {
const IndexType batch_offset = kOutputDimensions * b; const IndexType batch_offset = kOutputDimensions * b;
Blas::saxpy(thread_pool, kOutputDimensions, 1.0, Blas::saxpy(kOutputDimensions, 1.0,
&gradients[batch_offset], 1, biases_diff_, 1); &gradients[batch_offset], 1, thread_state.biases_diff_, 1);
} }
Blas::sgemm( Blas::sgemm(
thread_pool,
Blas::MatrixLayout::RowMajor, Blas::MatrixTranspose::Trans, Blas::MatrixTranspose::NoTrans, Blas::MatrixLayout::RowMajor, Blas::MatrixTranspose::Trans, Blas::MatrixTranspose::NoTrans,
kOutputDimensions, kInputDimensions, batch_size_, kOutputDimensions, kInputDimensions, count,
1.0, 1.0,
gradients, kOutputDimensions, gradients + offset * kOutputDimensions, kOutputDimensions,
batch_input_, kInputDimensions, combined_batch_input_ + offset * kInputDimensions, kInputDimensions,
momentum_, momentum,
weights_diff_, kInputDimensions thread_state.weights_diff_, kInputDimensions
); );
#endif #endif
previous_layer_trainer_->backpropagate(th, gradients_.data(), offset, count);
}
void reduce_thread_state()
{
for (IndexType i = 1; i < thread_states_.size(); ++i)
{
thread_states_[0] += thread_states_[i];
}
}
void step_end(ThreadPool& thread_pool, LearnFloatType learning_rate)
{
const LearnFloatType local_learning_rate =
learning_rate * learning_rate_scale_;
reduce_thread_state();
auto& main_thread_state = thread_states_[0];
for (IndexType i = 0; i < kOutputDimensions; ++i) { for (IndexType i = 0; i < kOutputDimensions; ++i) {
const double d = local_learning_rate * biases_diff_[i]; const double d = local_learning_rate * main_thread_state.biases_diff_[i];
biases_[i] -= d; biases_[i] -= d;
abs_biases_diff_sum_ += std::abs(d); abs_biases_diff_sum_ += std::abs(d);
} }
num_biases_diffs_ += kOutputDimensions; num_biases_diffs_ += kOutputDimensions;
for (IndexType i = 0; i < kOutputDimensions * kInputDimensions; ++i) { for (IndexType i = 0; i < kOutputDimensions * kInputDimensions; ++i) {
const double d = local_learning_rate * weights_diff_[i]; const double d = local_learning_rate * main_thread_state.weights_diff_[i];
weights_[i] -= d; weights_[i] -= d;
abs_weights_diff_sum_ += std::abs(d); abs_weights_diff_sum_ += std::abs(d);
} }
num_weights_diffs_ += kOutputDimensions * kInputDimensions; num_weights_diffs_ += kOutputDimensions * kInputDimensions;
previous_layer_trainer_->backpropagate(thread_pool, gradients_.data(), learning_rate); previous_layer_trainer_->step_end(thread_pool, learning_rate);
} }
private: private:
// constructor // constructor
Trainer(LayerType* target_layer, FeatureTransformer* ft) : Trainer(LayerType* target_layer, FeatureTransformer* ft) :
batch_size_(0), combined_batch_size_(0),
batch_input_(nullptr), combined_batch_input_(nullptr),
previous_layer_trainer_(Trainer<PreviousLayer>::create( previous_layer_trainer_(Trainer<PreviousLayer>::create(
&target_layer->previous_layer_, ft)), &target_layer->previous_layer_, ft)),
target_layer_(target_layer), target_layer_(target_layer),
biases_(), biases_(),
weights_(), weights_(),
biases_diff_(),
weights_diff_(),
momentum_(0.2), momentum_(0.2),
learning_rate_scale_(1.0) { learning_rate_scale_(1.0) {
@@ -335,10 +370,12 @@ namespace Eval::NNUE {
} }
} }
std::fill(std::begin(biases_diff_), std::end(biases_diff_), for (auto& state : thread_states_)
static_cast<LearnFloatType>(0.0)); {
std::fill(std::begin(weights_diff_), std::end(weights_diff_), state.reset_weights();
static_cast<LearnFloatType>(0.0)); state.reset_biases();
}
reset_stats(); reset_stats();
} }
@@ -365,7 +402,7 @@ namespace Eval::NNUE {
std::numeric_limits<typename LayerType::WeightType>::max() / kWeightScale; std::numeric_limits<typename LayerType::WeightType>::max() / kWeightScale;
// number of samples in mini-batch // number of samples in mini-batch
IndexType batch_size_; IndexType combined_batch_size_;
double abs_biases_diff_sum_; double abs_biases_diff_sum_;
double abs_weights_diff_sum_; double abs_weights_diff_sum_;
@@ -373,7 +410,7 @@ namespace Eval::NNUE {
uint64_t num_weights_diffs_; uint64_t num_weights_diffs_;
// Input mini batch // Input mini batch
const LearnFloatType* batch_input_; const LearnFloatType* combined_batch_input_;
// Trainer of the previous layer // Trainer of the previous layer
const std::shared_ptr<Trainer<PreviousLayer>> previous_layer_trainer_; const std::shared_ptr<Trainer<PreviousLayer>> previous_layer_trainer_;
@@ -382,12 +419,44 @@ namespace Eval::NNUE {
LayerType* const target_layer_; LayerType* const target_layer_;
// parameter // parameter
struct alignas(kCacheLineSize) ThreadState
{
// Buffer used for updating parameters
alignas(kCacheLineSize) LearnFloatType biases_diff_[kOutputDimensions];
alignas(kCacheLineSize) LearnFloatType weights_diff_[kOutputDimensions * kInputDimensions];
ThreadState() { reset_weights(); reset_biases(); }
ThreadState& operator+=(const ThreadState& other)
{
for (IndexType i = 0; i < kOutputDimensions; ++i)
{
biases_diff_[i] += other.biases_diff_[i];
}
for (IndexType i = 0; i < kOutputDimensions * kInputDimensions; ++i)
{
weights_diff_[i] += other.weights_diff_[i];
}
return *this;
}
void reset_weights()
{
std::fill(std::begin(weights_diff_), std::end(weights_diff_), 0.0f);
}
void reset_biases()
{
std::fill(std::begin(biases_diff_), std::end(biases_diff_), 0.0f);
}
};
alignas(kCacheLineSize) LearnFloatType biases_[kOutputDimensions]; alignas(kCacheLineSize) LearnFloatType biases_[kOutputDimensions];
alignas(kCacheLineSize) LearnFloatType weights_[kOutputDimensions * kInputDimensions]; alignas(kCacheLineSize) LearnFloatType weights_[kOutputDimensions * kInputDimensions];
// Buffer used for updating parameters std::vector<ThreadState, CacheLineAlignedAllocator<ThreadState>> thread_states_;
alignas(kCacheLineSize) LearnFloatType biases_diff_[kOutputDimensions];
alignas(kCacheLineSize) LearnFloatType weights_diff_[kOutputDimensions * kInputDimensions];
// Forward propagation buffer // Forward propagation buffer
std::vector<LearnFloatType, CacheLineAlignedAllocator<LearnFloatType>> output_; std::vector<LearnFloatType, CacheLineAlignedAllocator<LearnFloatType>> output_;