From 29e452c94ebc6480809126f1193798bb5ed68823 Mon Sep 17 00:00:00 2001 From: Brian Pane Date: Tue, 7 Apr 2026 08:50:41 -0700 Subject: [PATCH] Stitched AES-GCM for aarch64 --- graviola/src/low/aarch64/aes.rs | 127 ++----------------- graviola/src/low/aarch64/aes_gcm.rs | 182 +++++++++++++++++++++++++++- graviola/src/low/aarch64/ghash.rs | 59 ++++++++- 3 files changed, 245 insertions(+), 123 deletions(-) diff --git a/graviola/src/low/aarch64/aes.rs b/graviola/src/low/aarch64/aes.rs index 7e9b7fd76..dcc339e15 100644 --- a/graviola/src/low/aarch64/aes.rs +++ b/graviola/src/low/aarch64/aes.rs @@ -11,7 +11,6 @@ use core::arch::aarch64::*; use crate::low; -use crate::low::aarch64::cpu; pub(crate) enum AesKey { Aes128(AesKey128), @@ -37,116 +36,6 @@ impl AesKey { Self::Aes256(a256) => a256.encrypt_block(inout), } } - - pub(crate) fn ctr(&self, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) { - // SAFETY: this crate requires the `aes` & `neon` cpu features - unsafe { self._ctr(initial_counter, cipher_inout) } - } - - #[target_feature(enable = "aes,neon")] - fn _ctr(&self, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) { - // counter and inc are big endian, so must be vrev32q_u8'd before use - // SAFETY: `initial_counter` is 16 bytes and readable - let counter = unsafe { vld1q_u8(initial_counter.as_ptr().cast()) }; - let mut counter = vreinterpretq_u32_u8(vrev32q_u8(counter)); - - let inc = vsetq_lane_u8(1, vdupq_n_u8(0), 15); - let inc = vreinterpretq_u32_u8(vrev32q_u8(inc)); - - let mut by8 = cipher_inout.chunks_exact_mut(128); - - for cipher8 in by8.by_ref() { - cpu::prefetch_rw(cipher8.as_ptr()); - counter = vaddq_u32(counter, inc); - let b0 = vrev32q_u8(vreinterpretq_u8_u32(counter)); - counter = vaddq_u32(counter, inc); - let b1 = vrev32q_u8(vreinterpretq_u8_u32(counter)); - counter = vaddq_u32(counter, inc); - let b2 = vrev32q_u8(vreinterpretq_u8_u32(counter)); - counter = vaddq_u32(counter, inc); - let b3 = vrev32q_u8(vreinterpretq_u8_u32(counter)); - 
counter = vaddq_u32(counter, inc); - let b4 = vrev32q_u8(vreinterpretq_u8_u32(counter)); - counter = vaddq_u32(counter, inc); - let b5 = vrev32q_u8(vreinterpretq_u8_u32(counter)); - counter = vaddq_u32(counter, inc); - let b6 = vrev32q_u8(vreinterpretq_u8_u32(counter)); - counter = vaddq_u32(counter, inc); - let b7 = vrev32q_u8(vreinterpretq_u8_u32(counter)); - - let (b0, b1, b2, b3, b4, b5, b6, b7) = match self { - Self::Aes128(a128) => { - _aes128_8_blocks(&a128.round_keys, b0, b1, b2, b3, b4, b5, b6, b7) - } - Self::Aes256(a256) => { - _aes256_8_blocks(&a256.round_keys, b0, b1, b2, b3, b4, b5, b6, b7) - } - }; - - // SAFETY: cipher8 is 128 bytes long, via `chunks_exact_mut` - unsafe { - let b0 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(0).cast()), b0); - let b1 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(16).cast()), b1); - let b2 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(32).cast()), b2); - let b3 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(48).cast()), b3); - let b4 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(64).cast()), b4); - let b5 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(80).cast()), b5); - let b6 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(96).cast()), b6); - let b7 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(112).cast()), b7); - - vst1q_u8(cipher8.as_mut_ptr().add(0).cast(), b0); - vst1q_u8(cipher8.as_mut_ptr().add(16).cast(), b1); - vst1q_u8(cipher8.as_mut_ptr().add(32).cast(), b2); - vst1q_u8(cipher8.as_mut_ptr().add(48).cast(), b3); - vst1q_u8(cipher8.as_mut_ptr().add(64).cast(), b4); - vst1q_u8(cipher8.as_mut_ptr().add(80).cast(), b5); - vst1q_u8(cipher8.as_mut_ptr().add(96).cast(), b6); - vst1q_u8(cipher8.as_mut_ptr().add(112).cast(), b7); - } - } - - let mut singles = by8.into_remainder().chunks_exact_mut(16); - - for cipher in singles.by_ref() { - counter = vaddq_u32(counter, inc); - let block = vrev32q_u8(vreinterpretq_u8_u32(counter)); - - let block = match self { - Self::Aes128(a128) => _aes128_block(&a128.round_keys, block), - Self::Aes256(a256) => 
_aes256_block(&a256.round_keys, block), - }; - - // SAFETY: `cipher` is 16 bytes and writable, via `chunks_exact_mut` - unsafe { - let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block); - vst1q_u8(cipher.as_mut_ptr().cast(), block); - } - } - - let cipher_inout = singles.into_remainder(); - if !cipher_inout.is_empty() { - let mut cipher = [0u8; 16]; - let len = cipher_inout.len(); - debug_assert!(len < 16); - cipher[..len].copy_from_slice(cipher_inout); - - counter = vaddq_u32(counter, inc); - let block = vrev32q_u8(vreinterpretq_u8_u32(counter)); - - let block = match self { - Self::Aes128(a128) => _aes128_block(&a128.round_keys, block), - Self::Aes256(a256) => _aes256_block(&a256.round_keys, block), - }; - - // SAFETY: `cipher` is 16 bytes and writable - unsafe { - let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block); - vst1q_u8(cipher.as_mut_ptr().cast(), block) - }; - - cipher_inout.copy_from_slice(&cipher[..len]); - } - } } pub(crate) struct AesKey128 { @@ -184,6 +73,10 @@ impl AesKey128 { // SAFETY: this crate requires the `aes` cpu feature unsafe { aes128_block(&self.round_keys, inout) } } + + pub(crate) fn round_keys(&self) -> &[uint8x16_t; 11] { + &self.round_keys + } } impl Drop for AesKey128 { @@ -235,6 +128,10 @@ impl AesKey256 { // SAFETY: this crate requires the `aes` cpu feature unsafe { aes256_block(&self.round_keys, inout) } } + + pub(crate) fn round_keys(&self) -> &[uint8x16_t; 15] { + &self.round_keys + } } impl Drop for AesKey256 { @@ -283,7 +180,7 @@ fn aes128_block(round_keys: &[uint8x16_t; 11], block_inout: &mut [u8]) { #[target_feature(enable = "aes")] #[inline] -fn _aes128_block(round_keys: &[uint8x16_t; 11], block: uint8x16_t) -> uint8x16_t { +pub(crate) fn _aes128_block(round_keys: &[uint8x16_t; 11], block: uint8x16_t) -> uint8x16_t { let block = vaeseq_u8(block, round_keys[0]); let block = vaesmcq_u8(block); let block = vaeseq_u8(block, round_keys[1]); @@ -330,7 +227,7 @@ macro_rules! 
round_8 { #[target_feature(enable = "aes")] #[inline] -fn _aes128_8_blocks( +pub(crate) fn _aes128_8_blocks( round_keys: &[uint8x16_t; 11], mut b0: uint8x16_t, mut b1: uint8x16_t, @@ -391,7 +288,7 @@ fn aes256_block(round_keys: &[uint8x16_t; 15], block_inout: &mut [u8; 16]) { #[target_feature(enable = "aes")] #[inline] -fn _aes256_block(round_keys: &[uint8x16_t; 15], block: uint8x16_t) -> uint8x16_t { +pub(crate) fn _aes256_block(round_keys: &[uint8x16_t; 15], block: uint8x16_t) -> uint8x16_t { let block = vaeseq_u8(block, round_keys[0]); let block = vaesmcq_u8(block); let block = vaeseq_u8(block, round_keys[1]); @@ -424,7 +321,7 @@ fn _aes256_block(round_keys: &[uint8x16_t; 15], block: uint8x16_t) -> uint8x16_t #[target_feature(enable = "aes")] #[inline] -fn _aes256_8_blocks( +pub(crate) fn _aes256_8_blocks( round_keys: &[uint8x16_t; 15], mut b0: uint8x16_t, mut b1: uint8x16_t, diff --git a/graviola/src/low/aarch64/aes_gcm.rs b/graviola/src/low/aarch64/aes_gcm.rs index 099fbb665..23a4b25db 100644 --- a/graviola/src/low/aarch64/aes_gcm.rs +++ b/graviola/src/low/aarch64/aes_gcm.rs @@ -1,7 +1,10 @@ // Written for Graviola by Joe Birr-Pixton, 2024. // SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 +use core::arch::aarch64::*; + use crate::low::AesKey; +use crate::low::aarch64::cpu; use crate::low::ghash::Ghash; pub(crate) fn encrypt( @@ -11,9 +14,8 @@ pub(crate) fn encrypt( aad: &[u8], cipher_inout: &mut [u8], ) { - ghash.add(aad); - key.ctr(initial_counter, cipher_inout); - ghash.add(cipher_inout); + // SAFETY: this crate requires the `aes` & `neon` cpu features + unsafe { _cipher::<true>(key, ghash, initial_counter, aad, cipher_inout) } } pub(crate) fn decrypt( @@ -22,8 +24,178 @@ pub(crate) fn decrypt( initial_counter: &[u8; 16], aad: &[u8], cipher_inout: &mut [u8], +) { + // SAFETY: this crate requires the `aes` & `neon` cpu features + unsafe { _cipher::<false>(key, ghash, initial_counter, aad, cipher_inout) } +} + +// AES-GCM encrypt (if `ENC` is `true`) or decrypt. 
+#[target_feature(enable = "aes,neon")] +fn _cipher<const ENC: bool>( + key: &AesKey, + ghash: &mut Ghash<'_>, + initial_counter: &[u8; 16], + aad: &[u8], + cipher_inout: &mut [u8], ) { ghash.add(aad); - ghash.add(cipher_inout); - key.ctr(initial_counter, cipher_inout); + + // counter and inc are big endian, so must be vrev32q_u8'd before use + // SAFETY: `initial_counter` is 16 bytes and readable + let counter = unsafe { vld1q_u8(initial_counter.as_ptr().cast()) }; + let mut counter = vreinterpretq_u32_u8(vrev32q_u8(counter)); + + let inc = vsetq_lane_u8(1, vdupq_n_u8(0), 15); + let inc = vreinterpretq_u32_u8(vrev32q_u8(inc)); + + let mut by8 = cipher_inout.chunks_exact_mut(128); + + for cipher8 in by8.by_ref() { + cpu::prefetch_rw(cipher8.as_ptr()); + counter = vaddq_u32(counter, inc); + let b0 = vrev32q_u8(vreinterpretq_u8_u32(counter)); + counter = vaddq_u32(counter, inc); + let b1 = vrev32q_u8(vreinterpretq_u8_u32(counter)); + counter = vaddq_u32(counter, inc); + let b2 = vrev32q_u8(vreinterpretq_u8_u32(counter)); + counter = vaddq_u32(counter, inc); + let b3 = vrev32q_u8(vreinterpretq_u8_u32(counter)); + counter = vaddq_u32(counter, inc); + let b4 = vrev32q_u8(vreinterpretq_u8_u32(counter)); + counter = vaddq_u32(counter, inc); + let b5 = vrev32q_u8(vreinterpretq_u8_u32(counter)); + counter = vaddq_u32(counter, inc); + let b6 = vrev32q_u8(vreinterpretq_u8_u32(counter)); + counter = vaddq_u32(counter, inc); + let b7 = vrev32q_u8(vreinterpretq_u8_u32(counter)); + + let (b0, b1, b2, b3, b4, b5, b6, b7) = match key { + AesKey::Aes128(a128) => crate::low::aarch64::aes::_aes128_8_blocks( + a128.round_keys(), + b0, + b1, + b2, + b3, + b4, + b5, + b6, + b7, + ), + AesKey::Aes256(a256) => crate::low::aarch64::aes::_aes256_8_blocks( + a256.round_keys(), + b0, + b1, + b2, + b3, + b4, + b5, + b6, + b7, + ), + }; + + // SAFETY: cipher8 is 128 bytes long, via `chunks_exact_mut` + unsafe { + let a0 = vld1q_u8(cipher8.as_ptr().add(0).cast()); + let a1 = 
vld1q_u8(cipher8.as_ptr().add(16).cast()); + let a2 = vld1q_u8(cipher8.as_ptr().add(32).cast()); + let a3 = vld1q_u8(cipher8.as_ptr().add(48).cast()); + let a4 = vld1q_u8(cipher8.as_ptr().add(64).cast()); + let a5 = vld1q_u8(cipher8.as_ptr().add(80).cast()); + let a6 = vld1q_u8(cipher8.as_ptr().add(96).cast()); + let a7 = vld1q_u8(cipher8.as_ptr().add(112).cast()); + + if !ENC { + ghash.add_eight_blocks(a0, a1, a2, a3, a4, a5, a6, a7); + } + + let b0 = veorq_u8(a0, b0); + let b1 = veorq_u8(a1, b1); + let b2 = veorq_u8(a2, b2); + let b3 = veorq_u8(a3, b3); + let b4 = veorq_u8(a4, b4); + let b5 = veorq_u8(a5, b5); + let b6 = veorq_u8(a6, b6); + let b7 = veorq_u8(a7, b7); + + vst1q_u8(cipher8.as_mut_ptr().add(0).cast(), b0); + vst1q_u8(cipher8.as_mut_ptr().add(16).cast(), b1); + vst1q_u8(cipher8.as_mut_ptr().add(32).cast(), b2); + vst1q_u8(cipher8.as_mut_ptr().add(48).cast(), b3); + vst1q_u8(cipher8.as_mut_ptr().add(64).cast(), b4); + vst1q_u8(cipher8.as_mut_ptr().add(80).cast(), b5); + vst1q_u8(cipher8.as_mut_ptr().add(96).cast(), b6); + vst1q_u8(cipher8.as_mut_ptr().add(112).cast(), b7); + + if ENC { + ghash.add_eight_blocks(b0, b1, b2, b3, b4, b5, b6, b7); + } + } + } + + let mut singles = by8.into_remainder().chunks_exact_mut(16); + + for cipher in singles.by_ref() { + // SAFETY: cipher is 16 bytes long, via `chunks_exact_mut`. 
+ let input_block = unsafe { vld1q_u8(cipher.as_ptr().add(0).cast()) }; + if !ENC { + ghash.add_block(input_block); + } + counter = vaddq_u32(counter, inc); + let block = vrev32q_u8(vreinterpretq_u8_u32(counter)); + + let block = match key { + AesKey::Aes128(a128) => { + crate::low::aarch64::aes::_aes128_block(a128.round_keys(), block) + } + AesKey::Aes256(a256) => { + crate::low::aarch64::aes::_aes256_block(a256.round_keys(), block) + } + }; + + // SAFETY: `cipher` is 16 bytes and writable, via `chunks_exact_mut` + unsafe { + let block = veorq_u8(input_block, block); + vst1q_u8(cipher.as_mut_ptr().cast(), block); + if ENC { + ghash.add_block(block); + } + } + } + + { + let cipher_inout = singles.into_remainder(); + if !cipher_inout.is_empty() { + if !ENC { + ghash.add(cipher_inout); + } + let mut cipher = [0u8; 16]; + let len = cipher_inout.len(); + debug_assert!(len < 16); + cipher[..len].copy_from_slice(cipher_inout); + + counter = vaddq_u32(counter, inc); + let block = vrev32q_u8(vreinterpretq_u8_u32(counter)); + + let block = match key { + AesKey::Aes128(a128) => { + crate::low::aarch64::aes::_aes128_block(a128.round_keys(), block) + } + AesKey::Aes256(a256) => { + crate::low::aarch64::aes::_aes256_block(a256.round_keys(), block) + } + }; + + // SAFETY: `cipher` is 16 bytes and writable + unsafe { + let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block); + vst1q_u8(cipher.as_mut_ptr().cast(), block) + }; + + cipher_inout.copy_from_slice(&cipher[..len]); + if ENC { + ghash.add(cipher_inout); + } + } + } } diff --git a/graviola/src/low/aarch64/ghash.rs b/graviola/src/low/aarch64/ghash.rs index 6646da9ff..05f8f876c 100644 --- a/graviola/src/low/aarch64/ghash.rs +++ b/graviola/src/low/aarch64/ghash.rs @@ -54,8 +54,8 @@ impl Drop for GhashTable { } pub(crate) struct Ghash<'a> { - table: &'a GhashTable, - current: uint64x2_t, + pub(crate) table: &'a GhashTable, + pub(crate) current: uint64x2_t, } impl<'a> Ghash<'a> { @@ -112,7 +112,15 @@ impl<'a> Ghash<'a> { 
self.current = mul(self.current, self.table.powers[0]); } - fn eight_blocks( + #[inline] + pub(crate) fn add_block(&mut self, block: uint8x16_t) { + // SAFETY: This crate requires the `neon` CPU feature. + self.current = unsafe { veorq_u64(self.current, to_uint64x2_be(block)) }; + self.current = mul(self.current, self.table.powers[0]); + } + + #[inline] + pub(crate) fn eight_blocks( &mut self, b1: u128, b2: u128, @@ -137,6 +145,35 @@ impl<'a> Ghash<'a> { from_u128(b8), ); } + + #[inline] + pub(crate) fn add_eight_blocks( + &mut self, + b1: uint8x16_t, + b2: uint8x16_t, + b3: uint8x16_t, + b4: uint8x16_t, + b5: uint8x16_t, + b6: uint8x16_t, + b7: uint8x16_t, + b8: uint8x16_t, + ) { + // SAFETY: this crate requires the `neon` cpu feature + unsafe { + let b1 = veorq_u64(self.current, to_uint64x2_be(b1)); + self.current = mul8( + self.table, + b1, + to_uint64x2_be(b2), + to_uint64x2_be(b3), + to_uint64x2_be(b4), + to_uint64x2_be(b5), + to_uint64x2_be(b6), + to_uint64x2_be(b7), + to_uint64x2_be(b8), + ); + } + } } #[inline] @@ -204,6 +241,7 @@ fn _mul(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { } #[target_feature(enable = "neon,aes")] +#[inline] fn _mul8( table: &GhashTable, a: uint64x2_t, @@ -294,6 +332,21 @@ fn to_u128(u: uint64x2_t) -> u128 { unsafe { mem::transmute(u) } } +#[inline] +#[target_feature(enable = "neon")] +// Make a copy of `u` with the bytes reversed, and cast to `uint64x2_t`. +pub(crate) fn to_uint64x2_be(u: uint8x16_t) -> uint64x2_t { + // Reverse the order of the bytes in each of the two 64-bit lanes in `u`. + let u = vrev64q_u8(u); + let u = vreinterpretq_u64_u8(u); + + // Swap the locations of the two 64-bit lanes to finish reversing the bytes. 
+ let lane0 = vgetq_lane_u64(u, 0); + let lane1 = vgetq_lane_u64(u, 1); + let reversed = vsetq_lane_u64(lane0, u, 1); + vsetq_lane_u64(lane1, reversed, 0) +} + // SAFETY: u128 and uint64x2_t have the same size and meaning of bits const GF128_POLY_HI: uint64x2_t = unsafe { mem::transmute(0xc2000000_00000000_c2000000_00000000u128) };