diff --git a/src/nnue/layers/affine_transform.h b/src/nnue/layers/affine_transform.h index 34777ef6..adf152ee 100644 --- a/src/nnue/layers/affine_transform.h +++ b/src/nnue/layers/affine_transform.h @@ -301,20 +301,40 @@ namespace Eval::NNUE::Layers { } else if constexpr (kOutputDimensions == 1) { - constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth; - - vec_t sum0 = vec_setzero(); - - const auto row0 = reinterpret_cast(&weights_[0]); - - for (int j = 0; j < (int)kNumChunks; ++j) +#if defined (USE_AVX512) + if constexpr (kPaddedInputDimensions % (kSimdWidth * 2) != 0) { - const vec_t in = input_vector[j]; + constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth; + const auto input_vector256 = reinterpret_cast(input); - vec_add_dpbusd_32(sum0, in, row0[j]); + __m256i sum0 = _mm256_setzero_si256(); + const auto row0 = reinterpret_cast(&weights_[0]); + + for (int j = 0; j < (int)kNumChunks; ++j) + { + const __m256i in = input_vector256[j]; + m256_add_dpbusd_epi32(sum0, in, row0[j]); + } + output[0] = m256_hadd(sum0, biases_[0]); } + else +#endif + { +#if defined (USE_AVX512) + constexpr IndexType kNumChunks = kPaddedInputDimensions / (kSimdWidth * 2); +#else + constexpr IndexType kNumChunks = kPaddedInputDimensions / kSimdWidth; +#endif + vec_t sum0 = vec_setzero(); + const auto row0 = reinterpret_cast(&weights_[0]); - output[0] = vec_hadd(sum0, biases_[0]); + for (int j = 0; j < (int)kNumChunks; ++j) + { + const vec_t in = input_vector[j]; + vec_add_dpbusd_32(sum0, in, row0[j]); + } + output[0] = vec_hadd(sum0, biases_[0]); + } } #else