mirror of
https://github.com/opelly27/Stockfish.git
synced 2026-05-20 12:07:43 +00:00
Optimize trainer clipped relu propagate
This commit is contained in:
@@ -50,7 +50,73 @@ namespace Eval::NNUE {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto input = previous_layer_trainer_->propagate(thread_pool, batch);
|
const auto input = previous_layer_trainer_->propagate(thread_pool, batch);
|
||||||
|
|
||||||
batch_size_ = static_cast<IndexType>(batch.size());
|
batch_size_ = static_cast<IndexType>(batch.size());
|
||||||
|
|
||||||
|
#if defined (USE_SSE2)
|
||||||
|
|
||||||
|
{
|
||||||
|
static_assert(kOutputDimensions % 16 == 0, "This implementation assumes that it can process 16 floats at a time");
|
||||||
|
|
||||||
|
const __m128 kZero4 = _mm_set1_ps(+kZero);
|
||||||
|
const __m128 kOne4 = _mm_set1_ps(+kOne);
|
||||||
|
|
||||||
|
for (IndexType b = 0; b < batch.size(); ++b)
|
||||||
|
{
|
||||||
|
const IndexType batch_offset = kOutputDimensions * b;
|
||||||
|
|
||||||
|
for (IndexType i = 0; i < kOutputDimensions; i += 16)
|
||||||
|
{
|
||||||
|
__m128 out0 = _mm_loadu_ps(&input[i + 0 + batch_offset]);
|
||||||
|
__m128 out1 = _mm_loadu_ps(&input[i + 4 + batch_offset]);
|
||||||
|
__m128 out2 = _mm_loadu_ps(&input[i + 8 + batch_offset]);
|
||||||
|
__m128 out3 = _mm_loadu_ps(&input[i + 12 + batch_offset]);
|
||||||
|
|
||||||
|
out0 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out0));
|
||||||
|
out1 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out1));
|
||||||
|
out2 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out2));
|
||||||
|
out3 = _mm_max_ps(kZero4, _mm_min_ps(kOne4, out3));
|
||||||
|
|
||||||
|
_mm_storeu_ps(&output_[i + 0 + batch_offset], out0);
|
||||||
|
_mm_storeu_ps(&output_[i + 4 + batch_offset], out1);
|
||||||
|
_mm_storeu_ps(&output_[i + 8 + batch_offset], out2);
|
||||||
|
_mm_storeu_ps(&output_[i + 12 + batch_offset], out3);
|
||||||
|
|
||||||
|
__m128 minact0 = _mm_loadu_ps(&min_activations_[i + 0]);
|
||||||
|
__m128 minact1 = _mm_loadu_ps(&min_activations_[i + 4]);
|
||||||
|
__m128 minact2 = _mm_loadu_ps(&min_activations_[i + 8]);
|
||||||
|
__m128 minact3 = _mm_loadu_ps(&min_activations_[i + 12]);
|
||||||
|
|
||||||
|
__m128 maxact0 = _mm_loadu_ps(&max_activations_[i + 0]);
|
||||||
|
__m128 maxact1 = _mm_loadu_ps(&max_activations_[i + 4]);
|
||||||
|
__m128 maxact2 = _mm_loadu_ps(&max_activations_[i + 8]);
|
||||||
|
__m128 maxact3 = _mm_loadu_ps(&max_activations_[i + 12]);
|
||||||
|
|
||||||
|
minact0 = _mm_min_ps(out0, minact0);
|
||||||
|
minact1 = _mm_min_ps(out1, minact1);
|
||||||
|
minact2 = _mm_min_ps(out2, minact2);
|
||||||
|
minact3 = _mm_min_ps(out3, minact3);
|
||||||
|
|
||||||
|
maxact0 = _mm_max_ps(out0, maxact0);
|
||||||
|
maxact1 = _mm_max_ps(out1, maxact1);
|
||||||
|
maxact2 = _mm_max_ps(out2, maxact2);
|
||||||
|
maxact3 = _mm_max_ps(out3, maxact3);
|
||||||
|
|
||||||
|
_mm_storeu_ps(&min_activations_[i + 0], minact0);
|
||||||
|
_mm_storeu_ps(&min_activations_[i + 4], minact1);
|
||||||
|
_mm_storeu_ps(&min_activations_[i + 8], minact2);
|
||||||
|
_mm_storeu_ps(&min_activations_[i + 12], minact3);
|
||||||
|
|
||||||
|
_mm_storeu_ps(&max_activations_[i + 0], maxact0);
|
||||||
|
_mm_storeu_ps(&max_activations_[i + 4], maxact1);
|
||||||
|
_mm_storeu_ps(&max_activations_[i + 8], maxact2);
|
||||||
|
_mm_storeu_ps(&max_activations_[i + 12], maxact3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
for (IndexType b = 0; b < batch_size_; ++b) {
|
for (IndexType b = 0; b < batch_size_; ++b) {
|
||||||
const IndexType batch_offset = kOutputDimensions * b;
|
const IndexType batch_offset = kOutputDimensions * b;
|
||||||
for (IndexType i = 0; i < kOutputDimensions; ++i) {
|
for (IndexType i = 0; i < kOutputDimensions; ++i) {
|
||||||
@@ -61,6 +127,8 @@ namespace Eval::NNUE {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
return output_.data();
|
return output_.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user