From 7dcda459aeca65d630bafc871f485b7a683648a3 Mon Sep 17 00:00:00 2001 From: Kevin Lu Date: Wed, 8 Apr 2026 11:57:24 -0700 Subject: [PATCH 1/3] pairwise initial --- engine/nnue/network.cpp | 64 +++++++++++++++++++++++++++-------------- engine/nnue/network.hpp | 2 +- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index bddf499..6bf7bee 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -30,37 +30,57 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { + // For pairwise multiplication, we need to first multiply the accumulators together + // Then do standard matmul + // int32_t score = 0; + // for (int i = 0; i < HL_SIZE / 2; i++) { + // int32_t stm1 = std::clamp((int)stm.val[i], 0, QA); + // int32_t stm2 = std::clamp((int)stm.val[i + HL_SIZE / 2], 0, QA); + + // int32_t ntm1 = std::clamp((int)ntm.val[i], 0, QA); + // int32_t ntm2 = std::clamp((int)ntm.val[i + HL_SIZE / 2], 0, QA); + + // score += stm1 * stm2 * net.output_weights[nbucket][i]; + // score += ntm1 * ntm2 * net.output_weights[nbucket][i + HL_SIZE / 2]; + // } + __m256i sum = _mm256_setzero_si256(); - const __m256i zero = _mm256_setzero_si256(); - const __m256i qa_vec = _mm256_set1_epi16(QA); + const __m256i zero = _mm256_setzero_si256(); + const __m256i qa_vec = _mm256_set1_epi16(QA); - for (int i = 0; i < HL_SIZE; i += 16) { - __m256i stm_vals = _mm256_loadu_si256((__m256i*)&stm.val[i]); - __m256i ntm_vals = _mm256_loadu_si256((__m256i*)&ntm.val[i]); + for (int i = 0; i < HL_SIZE; i += 16) { + __m256i stm_vals1 = _mm256_loadu_si256((__m256i*)&stm.val[i]); + __m256i stm_vals2 = _mm256_loadu_si256((__m256i*)&stm.val[i + HL_SIZE / 2]); + __m256i ntm_vals1 = _mm256_loadu_si256((__m256i*)&ntm.val[i]); + __m256i ntm_vals2 = _mm256_loadu_si256((__m256i*)&ntm.val[i + HL_SIZE / 2]); - stm_vals = _mm256_max_epi16(stm_vals, zero); - stm_vals = _mm256_min_epi16(stm_vals, qa_vec); + stm_vals1 = _mm256_max_epi16(stm_vals1, zero); + stm_vals1 = _mm256_min_epi16(stm_vals1, qa_vec); + stm_vals2 = _mm256_max_epi16(stm_vals2, zero); + stm_vals2 = _mm256_min_epi16(stm_vals2, qa_vec); - ntm_vals = _mm256_max_epi16(ntm_vals, zero); - ntm_vals = _mm256_min_epi16(ntm_vals, qa_vec); + ntm_vals1 = _mm256_max_epi16(ntm_vals1, zero); + ntm_vals1 = _mm256_min_epi16(ntm_vals1, qa_vec); + ntm_vals2 = _mm256_max_epi16(ntm_vals2, zero); + ntm_vals2 = _mm256_min_epi16(ntm_vals2, qa_vec); - __m256i stm_weights = _mm256_loadu_si256((__m256i*)&net.output_weights[nbucket][i]); - __m256i ntm_weights = _mm256_loadu_si256((__m256i*)&net.output_weights[nbucket][HL_SIZE + i]); + __m256i stm_weights = _mm256_loadu_si256((__m256i*)&net.output_weights[nbucket][i]); + __m256i ntm_weights = _mm256_loadu_si256((__m256i*)&net.output_weights[nbucket][HL_SIZE / 2 + i]); - __m256i stm_prod = _mm256_mullo_epi16(stm_vals, stm_weights); - __m256i ntm_prod = _mm256_mullo_epi16(ntm_vals, ntm_weights); + __m256i stm_prod = _mm256_mullo_epi16(stm_vals1, stm_weights); + __m256i ntm_prod = _mm256_mullo_epi16(ntm_vals1, ntm_weights); - __m256i stm_res = _mm256_madd_epi16(stm_prod, stm_vals); - __m256i ntm_res = _mm256_madd_epi16(ntm_prod, ntm_vals); + __m256i stm_res = _mm256_madd_epi16(stm_prod, stm_vals2); + __m256i ntm_res = _mm256_madd_epi16(ntm_prod, ntm_vals2); - sum = _mm256_add_epi32(sum, stm_res); - sum = _mm256_add_epi32(sum, ntm_res); - } + sum = _mm256_add_epi32(sum, stm_res); + sum = _mm256_add_epi32(sum, ntm_res); + } - __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); - int32_t score = _mm_cvtsi128_si32(sum_128); + __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); + int32_t score = _mm_cvtsi128_si32(sum_128); score /= QA; score += net.output_bias[nbucket]; diff --git a/engine/nnue/network.hpp b/engine/nnue/network.hpp index 10bf396..6122ed0 100644 --- a/engine/nnue/network.hpp +++ b/engine/nnue/network.hpp @@ -28,7 +28,7 @@ struct Accumulator { struct Network { int16_t accumulator_weights[INPUT_SIZE * NINPUTS][HL_SIZE]; int16_t accumulator_biases[HL_SIZE]; - int16_t output_weights[NBUCKETS][2 * HL_SIZE]; + int16_t output_weights[NBUCKETS][HL_SIZE]; int16_t output_bias[NBUCKETS]; void load(); From 5d5855ccf3c69b090259291168901521c48848d6 Mon Sep 17 00:00:00 2001 From: William Ma Date: Wed, 8 Apr 2026 15:56:13 -0400 Subject: [PATCH 2/3] vectorize Bench: 3851225 --- engine/nnue/network.cpp | 62 ++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/engine/nnue/network.cpp b/engine/nnue/network.cpp index 6bf7bee..fe51e6f 100644 --- a/engine/nnue/network.cpp +++ b/engine/nnue/network.cpp @@ -30,57 +30,43 @@ int calculate_index(Square sq, PieceType pt, bool side, bool perspective, int nb } int32_t nnue_eval(const Network &net, const Accumulator &stm, const Accumulator &ntm, uint8_t nbucket) { - // For pairwise multiplication, we need to first multiply the accumulators together - // Then do standard matmul - // int32_t score = 0; - // for (int i = 0; i < HL_SIZE / 2; i++) { - // int32_t stm1 = std::clamp((int)stm.val[i], 0, QA); - // int32_t stm2 = std::clamp((int)stm.val[i + HL_SIZE / 2], 0, QA); - - // int32_t ntm1 = std::clamp((int)ntm.val[i], 0, QA); - // int32_t ntm2 = std::clamp((int)ntm.val[i + HL_SIZE / 2], 0, QA); - - // score += stm1 * stm2 * net.output_weights[nbucket][i]; - // score += ntm1 * ntm2 * net.output_weights[nbucket][i + HL_SIZE / 2]; - // } - __m256i sum = _mm256_setzero_si256(); - const __m256i zero = _mm256_setzero_si256(); - const __m256i qa_vec = _mm256_set1_epi16(QA); + const __m256i zero = _mm256_setzero_si256(); + const __m256i qa_vec = _mm256_set1_epi16(QA); - for (int i = 0; i < HL_SIZE; i += 16) { - __m256i stm_vals1 = _mm256_loadu_si256((__m256i*)&stm.val[i]); - __m256i stm_vals2 = _mm256_loadu_si256((__m256i*)&stm.val[i + HL_SIZE / 2]); - __m256i ntm_vals1 = _mm256_loadu_si256((__m256i*)&ntm.val[i]); - __m256i ntm_vals2 = _mm256_loadu_si256((__m256i*)&ntm.val[i + HL_SIZE / 2]); + for (int i = 0; i < HL_SIZE / 2; i += 16) { + __m256i stm_vals1 = _mm256_loadu_si256((__m256i *)&stm.val[i]); + __m256i stm_vals2 = _mm256_loadu_si256((__m256i *)&stm.val[i + HL_SIZE / 2]); + __m256i ntm_vals1 = _mm256_loadu_si256((__m256i *)&ntm.val[i]); + __m256i ntm_vals2 = _mm256_loadu_si256((__m256i *)&ntm.val[i + HL_SIZE / 2]); - stm_vals1 = _mm256_max_epi16(stm_vals1, zero); + stm_vals1 = _mm256_max_epi16(stm_vals1, zero); stm_vals1 = _mm256_min_epi16(stm_vals1, qa_vec); stm_vals2 = _mm256_max_epi16(stm_vals2, zero); - stm_vals2 = _mm256_min_epi16(stm_vals2, qa_vec); + stm_vals2 = _mm256_min_epi16(stm_vals2, qa_vec); - ntm_vals1 = _mm256_max_epi16(ntm_vals1, zero); + ntm_vals1 = _mm256_max_epi16(ntm_vals1, zero); ntm_vals1 = _mm256_min_epi16(ntm_vals1, qa_vec); ntm_vals2 = _mm256_max_epi16(ntm_vals2, zero); - ntm_vals2 = _mm256_min_epi16(ntm_vals2, qa_vec); + ntm_vals2 = _mm256_min_epi16(ntm_vals2, qa_vec); - __m256i stm_weights = _mm256_loadu_si256((__m256i*)&net.output_weights[nbucket][i]); - __m256i ntm_weights = _mm256_loadu_si256((__m256i*)&net.output_weights[nbucket][HL_SIZE / 2 + i]); + __m256i stm_weights = _mm256_loadu_si256((__m256i *)&net.output_weights[nbucket][i]); + __m256i ntm_weights = _mm256_loadu_si256((__m256i *)&net.output_weights[nbucket][i + HL_SIZE / 2]); - __m256i stm_prod = _mm256_mullo_epi16(stm_vals1, stm_weights); - __m256i ntm_prod = _mm256_mullo_epi16(ntm_vals1, ntm_weights); + __m256i stm_prod = _mm256_mullo_epi16(stm_vals1, stm_weights); + __m256i ntm_prod = _mm256_mullo_epi16(ntm_vals1, ntm_weights); - __m256i stm_res = _mm256_madd_epi16(stm_prod, stm_vals2); - __m256i ntm_res = _mm256_madd_epi16(ntm_prod, ntm_vals2); + __m256i stm_res = _mm256_madd_epi16(stm_prod, stm_vals2); + __m256i ntm_res = _mm256_madd_epi16(ntm_prod, ntm_vals2); - sum = _mm256_add_epi32(sum, stm_res); - sum = _mm256_add_epi32(sum, ntm_res); - } + sum = _mm256_add_epi32(sum, stm_res); + sum = _mm256_add_epi32(sum, ntm_res); + } - __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); - sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); - int32_t score = _mm_cvtsi128_si32(sum_128); + __m128i sum_128 = _mm_add_epi32(_mm256_extracti128_si256(sum, 0), _mm256_extracti128_si256(sum, 1)); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(2, 3, 0, 1))); + sum_128 = _mm_add_epi32(sum_128, _mm_shuffle_epi32(sum_128, _MM_SHUFFLE(1, 0, 3, 2))); + int32_t score = _mm_cvtsi128_si32(sum_128); score /= QA; score += net.output_bias[nbucket]; From 79c249928cbe5bc604cf83d4d0e74abfd0d82dd1 Mon Sep 17 00:00:00 2001 From: William Ma Date: Thu, 9 Apr 2026 17:07:47 -0400 Subject: [PATCH 3/3] Fix pawnvalue command Bench: 3851225 --- engine/main.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/main.cpp b/engine/main.cpp index 1fc1118..79b6105 100644 --- a/engine/main.cpp +++ b/engine/main.cpp @@ -413,6 +413,7 @@ __attribute__((weak)) int main(int argc, char *argv[]) { for (int i = 0; i < 8; i++) { pos.reset_startpos(); pos.mailbox[SQ_A2 + i] = NO_PIECE; + am.refresh_finny(pos, am.idx); Value score = eval(pos, am); int diff = startpos_score - score; tot += diff;