Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 12 additions & 115 deletions graviola/src/low/aarch64/aes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
use core::arch::aarch64::*;

use crate::low;
use crate::low::aarch64::cpu;

pub(crate) enum AesKey {
Aes128(AesKey128),
Expand All @@ -37,116 +36,6 @@ impl AesKey {
Self::Aes256(a256) => a256.encrypt_block(inout),
}
}

/// Applies the AES-CTR keystream to `cipher_inout`, in place.
///
/// `initial_counter` is a big-endian 128-bit counter block. The counter
/// is incremented *before* each block is used (see `_ctr`), so the first
/// 16 bytes of data are processed with `initial_counter + 1`; the
/// caller's counter value itself is never used for data.
pub(crate) fn ctr(&self, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) {
    // SAFETY: this crate requires the `aes` & `neon` cpu features
    unsafe { self._ctr(initial_counter, cipher_inout) }
}

/// AES-CTR keystream application, vectorised for aarch64.
///
/// Processes `cipher_inout` in three phases: eight blocks (128 bytes) at
/// a time, then single 16-byte blocks, then one trailing partial block
/// via a zero-padded stack buffer. The counter is incremented before
/// every block encryption.
///
/// NOTE(review): the increment is a per-32-bit-lane `vaddq_u32`, so only
/// the counter's final 32-bit word advances (no carry into higher words)
/// — the CTR32 behaviour GCM expects; confirm callers rely on exactly
/// this.
#[target_feature(enable = "aes,neon")]
fn _ctr(&self, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) {
    // counter and inc are big endian, so must be vrev32q_u8'd before use
    // SAFETY: `initial_counter` is 16 bytes and readable
    let counter = unsafe { vld1q_u8(initial_counter.as_ptr().cast()) };
    let mut counter = vreinterpretq_u32_u8(vrev32q_u8(counter));

    // `inc` is 1 in the last byte of a big-endian block, byte-reversed
    // into the same lane layout as `counter` above.
    let inc = vsetq_lane_u8(1, vdupq_n_u8(0), 15);
    let inc = vreinterpretq_u32_u8(vrev32q_u8(inc));

    // Phase 1: eight blocks per iteration.
    let mut by8 = cipher_inout.chunks_exact_mut(128);

    for cipher8 in by8.by_ref() {
        cpu::prefetch_rw(cipher8.as_ptr());
        // Pre-increment: derive eight successive counter blocks,
        // converting each back to big-endian byte order for AES.
        counter = vaddq_u32(counter, inc);
        let b0 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b1 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b2 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b3 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b4 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b5 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b6 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b7 = vrev32q_u8(vreinterpretq_u8_u32(counter));

        // Encrypt all eight counter blocks to produce the keystream.
        let (b0, b1, b2, b3, b4, b5, b6, b7) = match self {
            Self::Aes128(a128) => {
                _aes128_8_blocks(&a128.round_keys, b0, b1, b2, b3, b4, b5, b6, b7)
            }
            Self::Aes256(a256) => {
                _aes256_8_blocks(&a256.round_keys, b0, b1, b2, b3, b4, b5, b6, b7)
            }
        };

        // xor the keystream into the data in place.
        // SAFETY: cipher8 is 128 bytes long, via `chunks_exact_mut`
        unsafe {
            let b0 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(0).cast()), b0);
            let b1 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(16).cast()), b1);
            let b2 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(32).cast()), b2);
            let b3 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(48).cast()), b3);
            let b4 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(64).cast()), b4);
            let b5 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(80).cast()), b5);
            let b6 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(96).cast()), b6);
            let b7 = veorq_u8(vld1q_u8(cipher8.as_ptr().add(112).cast()), b7);

            vst1q_u8(cipher8.as_mut_ptr().add(0).cast(), b0);
            vst1q_u8(cipher8.as_mut_ptr().add(16).cast(), b1);
            vst1q_u8(cipher8.as_mut_ptr().add(32).cast(), b2);
            vst1q_u8(cipher8.as_mut_ptr().add(48).cast(), b3);
            vst1q_u8(cipher8.as_mut_ptr().add(64).cast(), b4);
            vst1q_u8(cipher8.as_mut_ptr().add(80).cast(), b5);
            vst1q_u8(cipher8.as_mut_ptr().add(96).cast(), b6);
            vst1q_u8(cipher8.as_mut_ptr().add(112).cast(), b7);
        }
    }

    // Phase 2: remaining whole 16-byte blocks, one at a time.
    let mut singles = by8.into_remainder().chunks_exact_mut(16);

    for cipher in singles.by_ref() {
        counter = vaddq_u32(counter, inc);
        let block = vrev32q_u8(vreinterpretq_u8_u32(counter));

        let block = match self {
            Self::Aes128(a128) => _aes128_block(&a128.round_keys, block),
            Self::Aes256(a256) => _aes256_block(&a256.round_keys, block),
        };

        // SAFETY: `cipher` is 16 bytes and writable, via `chunks_exact_mut`
        unsafe {
            let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block);
            vst1q_u8(cipher.as_mut_ptr().cast(), block);
        }
    }

    // Phase 3: a trailing partial block (< 16 bytes). Round-trip it
    // through a zero-padded 16-byte stack buffer so the vector loads and
    // stores stay in bounds, then copy back only `len` bytes.
    let cipher_inout = singles.into_remainder();
    if !cipher_inout.is_empty() {
        let mut cipher = [0u8; 16];
        let len = cipher_inout.len();
        debug_assert!(len < 16);
        cipher[..len].copy_from_slice(cipher_inout);

        counter = vaddq_u32(counter, inc);
        let block = vrev32q_u8(vreinterpretq_u8_u32(counter));

        let block = match self {
            Self::Aes128(a128) => _aes128_block(&a128.round_keys, block),
            Self::Aes256(a256) => _aes256_block(&a256.round_keys, block),
        };

        // SAFETY: `cipher` is 16 bytes and writable
        unsafe {
            let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block);
            vst1q_u8(cipher.as_mut_ptr().cast(), block)
        };

        cipher_inout.copy_from_slice(&cipher[..len]);
    }
}
}

pub(crate) struct AesKey128 {
Expand Down Expand Up @@ -184,6 +73,10 @@ impl AesKey128 {
// SAFETY: this crate requires the `aes` cpu feature
unsafe { aes128_block(&self.round_keys, inout) }
}

/// Borrows the 11 expanded AES-128 round keys, for callers that drive
/// the raw block functions (e.g. `_aes128_block`) directly.
pub(crate) fn round_keys(&self) -> &[uint8x16_t; 11] {
    &self.round_keys
}
}

impl Drop for AesKey128 {
Expand Down Expand Up @@ -235,6 +128,10 @@ impl AesKey256 {
// SAFETY: this crate requires the `aes` cpu feature
unsafe { aes256_block(&self.round_keys, inout) }
}

/// Borrows the 15 expanded AES-256 round keys, for callers that drive
/// the raw block functions (e.g. `_aes256_block`) directly.
pub(crate) fn round_keys(&self) -> &[uint8x16_t; 15] {
    &self.round_keys
}
}

impl Drop for AesKey256 {
Expand Down Expand Up @@ -283,7 +180,7 @@ fn aes128_block(round_keys: &[uint8x16_t; 11], block_inout: &mut [u8]) {

#[target_feature(enable = "aes")]
#[inline]
fn _aes128_block(round_keys: &[uint8x16_t; 11], block: uint8x16_t) -> uint8x16_t {
pub(crate) fn _aes128_block(round_keys: &[uint8x16_t; 11], block: uint8x16_t) -> uint8x16_t {
let block = vaeseq_u8(block, round_keys[0]);
let block = vaesmcq_u8(block);
let block = vaeseq_u8(block, round_keys[1]);
Expand Down Expand Up @@ -330,7 +227,7 @@ macro_rules! round_8 {

#[target_feature(enable = "aes")]
#[inline]
fn _aes128_8_blocks(
pub(crate) fn _aes128_8_blocks(
round_keys: &[uint8x16_t; 11],
mut b0: uint8x16_t,
mut b1: uint8x16_t,
Expand Down Expand Up @@ -391,7 +288,7 @@ fn aes256_block(round_keys: &[uint8x16_t; 15], block_inout: &mut [u8; 16]) {

#[target_feature(enable = "aes")]
#[inline]
fn _aes256_block(round_keys: &[uint8x16_t; 15], block: uint8x16_t) -> uint8x16_t {
pub(crate) fn _aes256_block(round_keys: &[uint8x16_t; 15], block: uint8x16_t) -> uint8x16_t {
let block = vaeseq_u8(block, round_keys[0]);
let block = vaesmcq_u8(block);
let block = vaeseq_u8(block, round_keys[1]);
Expand Down Expand Up @@ -424,7 +321,7 @@ fn _aes256_block(round_keys: &[uint8x16_t; 15], block: uint8x16_t) -> uint8x16_t

#[target_feature(enable = "aes")]
#[inline]
fn _aes256_8_blocks(
pub(crate) fn _aes256_8_blocks(
round_keys: &[uint8x16_t; 15],
mut b0: uint8x16_t,
mut b1: uint8x16_t,
Expand Down
182 changes: 177 additions & 5 deletions graviola/src/low/aarch64/aes_gcm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
// Written for Graviola by Joe Birr-Pixton, 2024.
// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0

use core::arch::aarch64::*;

use crate::low::AesKey;
use crate::low::aarch64::cpu;
use crate::low::ghash::Ghash;

pub(crate) fn encrypt(
Expand All @@ -11,9 +14,8 @@ pub(crate) fn encrypt(
aad: &[u8],
cipher_inout: &mut [u8],
) {
ghash.add(aad);
key.ctr(initial_counter, cipher_inout);
ghash.add(cipher_inout);
// SAFETY: this crate requires the `aes` & `neon` cpu features
unsafe { _cipher::<true>(key, ghash, initial_counter, aad, cipher_inout) }
}

pub(crate) fn decrypt(
Expand All @@ -22,8 +24,178 @@ pub(crate) fn decrypt(
initial_counter: &[u8; 16],
aad: &[u8],
cipher_inout: &mut [u8],
) {
// SAFETY: this crate requires the `aes` & `neon` cpu features
unsafe { _cipher::<false>(key, ghash, initial_counter, aad, cipher_inout) }
}

// AES-GCM encrypt (if `ENC` is `true`) or decrypt.
//
// Fuses AES-CTR keystream generation with GHASH accumulation over the
// ciphertext, one pass over the data. For encryption the ciphertext is
// what we *write*, so GHASH follows the xor; for decryption the
// ciphertext is what we *read*, so GHASH precedes it. `ENC` is a const
// generic, so each `if ENC` branch is resolved at compile time.
//
// BUGFIX(review): the previous body additionally called
// `ghash.add(cipher_inout)` and `key.ctr(...)` up front — residue of the
// old unfused implementation. `AesKey::ctr` no longer exists, and
// keeping those calls would have applied CTR and GHASH to the whole
// buffer a second time (i.e. double-encrypting and double-hashing).
// They are removed; only the AAD is hashed before the fused loop.
#[target_feature(enable = "aes,neon")]
fn _cipher<const ENC: bool>(
    key: &AesKey,
    ghash: &mut Ghash<'_>,
    initial_counter: &[u8; 16],
    aad: &[u8],
    cipher_inout: &mut [u8],
) {
    // The additional data is authenticated first, before any ciphertext.
    ghash.add(aad);

    // counter and inc are big endian, so must be vrev32q_u8'd before use
    // SAFETY: `initial_counter` is 16 bytes and readable
    let counter = unsafe { vld1q_u8(initial_counter.as_ptr().cast()) };
    let mut counter = vreinterpretq_u32_u8(vrev32q_u8(counter));

    // `inc` is 1 in the last byte of a big-endian block, byte-reversed
    // into the same lane layout as `counter` above.
    let inc = vsetq_lane_u8(1, vdupq_n_u8(0), 15);
    let inc = vreinterpretq_u32_u8(vrev32q_u8(inc));

    // Phase 1: eight blocks (128 bytes) per iteration.
    let mut by8 = cipher_inout.chunks_exact_mut(128);

    for cipher8 in by8.by_ref() {
        cpu::prefetch_rw(cipher8.as_ptr());
        // Pre-increment: the caller's counter value itself is never used
        // for data. Each counter is converted back to big-endian bytes.
        counter = vaddq_u32(counter, inc);
        let b0 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b1 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b2 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b3 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b4 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b5 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b6 = vrev32q_u8(vreinterpretq_u8_u32(counter));
        counter = vaddq_u32(counter, inc);
        let b7 = vrev32q_u8(vreinterpretq_u8_u32(counter));

        // Encrypt all eight counter blocks to produce the keystream.
        let (b0, b1, b2, b3, b4, b5, b6, b7) = match key {
            AesKey::Aes128(a128) => crate::low::aarch64::aes::_aes128_8_blocks(
                a128.round_keys(),
                b0,
                b1,
                b2,
                b3,
                b4,
                b5,
                b6,
                b7,
            ),
            AesKey::Aes256(a256) => crate::low::aarch64::aes::_aes256_8_blocks(
                a256.round_keys(),
                b0,
                b1,
                b2,
                b3,
                b4,
                b5,
                b6,
                b7,
            ),
        };

        // SAFETY: cipher8 is 128 bytes long, via `chunks_exact_mut`
        unsafe {
            let a0 = vld1q_u8(cipher8.as_ptr().add(0).cast());
            let a1 = vld1q_u8(cipher8.as_ptr().add(16).cast());
            let a2 = vld1q_u8(cipher8.as_ptr().add(32).cast());
            let a3 = vld1q_u8(cipher8.as_ptr().add(48).cast());
            let a4 = vld1q_u8(cipher8.as_ptr().add(64).cast());
            let a5 = vld1q_u8(cipher8.as_ptr().add(80).cast());
            let a6 = vld1q_u8(cipher8.as_ptr().add(96).cast());
            let a7 = vld1q_u8(cipher8.as_ptr().add(112).cast());

            if !ENC {
                // Decrypting: the input is the ciphertext; hash as read.
                ghash.add_eight_blocks(a0, a1, a2, a3, a4, a5, a6, a7);
            }

            let b0 = veorq_u8(a0, b0);
            let b1 = veorq_u8(a1, b1);
            let b2 = veorq_u8(a2, b2);
            let b3 = veorq_u8(a3, b3);
            let b4 = veorq_u8(a4, b4);
            let b5 = veorq_u8(a5, b5);
            let b6 = veorq_u8(a6, b6);
            let b7 = veorq_u8(a7, b7);

            vst1q_u8(cipher8.as_mut_ptr().add(0).cast(), b0);
            vst1q_u8(cipher8.as_mut_ptr().add(16).cast(), b1);
            vst1q_u8(cipher8.as_mut_ptr().add(32).cast(), b2);
            vst1q_u8(cipher8.as_mut_ptr().add(48).cast(), b3);
            vst1q_u8(cipher8.as_mut_ptr().add(64).cast(), b4);
            vst1q_u8(cipher8.as_mut_ptr().add(80).cast(), b5);
            vst1q_u8(cipher8.as_mut_ptr().add(96).cast(), b6);
            vst1q_u8(cipher8.as_mut_ptr().add(112).cast(), b7);

            if ENC {
                // Encrypting: the output is the ciphertext; hash what
                // was just written.
                ghash.add_eight_blocks(b0, b1, b2, b3, b4, b5, b6, b7);
            }
        }
    }

    // Phase 2: remaining whole 16-byte blocks, one at a time.
    let mut singles = by8.into_remainder().chunks_exact_mut(16);

    for cipher in singles.by_ref() {
        // SAFETY: cipher is 16 bytes long, via `chunks_exact_mut`.
        let input_block = unsafe { vld1q_u8(cipher.as_ptr().add(0).cast()) };
        if !ENC {
            ghash.add_block(input_block);
        }
        counter = vaddq_u32(counter, inc);
        let block = vrev32q_u8(vreinterpretq_u8_u32(counter));

        let block = match key {
            AesKey::Aes128(a128) => {
                crate::low::aarch64::aes::_aes128_block(a128.round_keys(), block)
            }
            AesKey::Aes256(a256) => {
                crate::low::aarch64::aes::_aes256_block(a256.round_keys(), block)
            }
        };

        // SAFETY: `cipher` is 16 bytes and writable, via `chunks_exact_mut`
        unsafe {
            let block = veorq_u8(input_block, block);
            vst1q_u8(cipher.as_mut_ptr().cast(), block);
            if ENC {
                ghash.add_block(block);
            }
        }
    }

    // Phase 3: a trailing partial block (< 16 bytes). Round-trip it
    // through a zero-padded 16-byte stack buffer so the vector load and
    // store stay in bounds; only `len` bytes are copied back, and only
    // those `len` bytes of ciphertext enter GHASH.
    let cipher_inout = singles.into_remainder();
    if !cipher_inout.is_empty() {
        if !ENC {
            // Decrypting: hash the partial ciphertext before decrypting.
            ghash.add(cipher_inout);
        }
        let mut cipher = [0u8; 16];
        let len = cipher_inout.len();
        debug_assert!(len < 16);
        cipher[..len].copy_from_slice(cipher_inout);

        counter = vaddq_u32(counter, inc);
        let block = vrev32q_u8(vreinterpretq_u8_u32(counter));

        let block = match key {
            AesKey::Aes128(a128) => {
                crate::low::aarch64::aes::_aes128_block(a128.round_keys(), block)
            }
            AesKey::Aes256(a256) => {
                crate::low::aarch64::aes::_aes256_block(a256.round_keys(), block)
            }
        };

        // SAFETY: `cipher` is 16 bytes and writable
        unsafe {
            let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block);
            vst1q_u8(cipher.as_mut_ptr().cast(), block)
        };

        cipher_inout.copy_from_slice(&cipher[..len]);
        if ENC {
            // Encrypting: hash the partial ciphertext just produced.
            ghash.add(cipher_inout);
        }
    }
}
Loading
Loading