Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 64 additions & 41 deletions graviola/src/low/posint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,41 +199,32 @@ impl<const N: usize> PosInt<N> {
d.expand(x);

loop {
if u.is_odd() && v.is_odd() {
if v.less_than(&u) {
u = u.sub(&v);
a = a.add_mod(&c, y);
b = b.add_mod(&d, x);
} else {
v = v.sub(&u);
c = c.add_mod(&a, y);
d = d.add_mod(&b, x);
}
}
let u_and_v_are_odd = u64::from(u.is_odd() & v.is_odd());
let v_is_less_than_u = u64::from(v.less_than(&u));
let v_is_not_less_than_u = (!v_is_less_than_u) & 1;

u = u.sub(&v.mask(u_and_v_are_odd & v_is_less_than_u));
a = a.add_mod(&c.mask(u_and_v_are_odd & v_is_less_than_u), y);
b = b.add_mod(&d.mask(u_and_v_are_odd & v_is_less_than_u), x);

v = v.sub(&u.mask(u_and_v_are_odd & v_is_not_less_than_u));
c = c.add_mod(&a.mask(u_and_v_are_odd & v_is_not_less_than_u), y);
d = d.add_mod(&b.mask(u_and_v_are_odd & v_is_not_less_than_u), x);

assert!(u.is_even() || v.is_even());

if u.is_even() {
u = u.shift_right_1();
let u_is_even = u64::from(u.is_even());
let u_is_odd = u64::from(u.is_odd());
let a_or_b_is_odd = u64::from(a.is_odd() | b.is_odd());
let c_or_d_is_odd = u64::from(c.is_odd() | d.is_odd());

if a.is_odd() || b.is_odd() {
a = a.add_shift_right_1(y);
b = b.add_shift_right_1(x);
} else {
a = a.shift_right_1();
b = b.shift_right_1();
}
} else {
v = v.shift_right_1();

if c.is_odd() || d.is_odd() {
c = c.add_shift_right_1(y);
d = d.add_shift_right_1(x);
} else {
c = c.shift_right_1();
d = d.shift_right_1();
}
}
u = u.shift_right_small(u_is_even as _);
a = a.add_shift_right_small(&y.mask(a_or_b_is_odd & u_is_even), u_is_even as _);
b = b.add_shift_right_small(&x.mask(a_or_b_is_odd & u_is_even), u_is_even as _);

v = v.shift_right_small(u_is_odd as _);
c = c.add_shift_right_small(&y.mask(c_or_d_is_odd & u_is_odd), u_is_odd as _);
d = d.add_shift_right_small(&x.mask(c_or_d_is_odd & u_is_odd), u_is_odd as _);

if v.is_zero() {
match u.len_bits() {
Expand All @@ -244,6 +235,21 @@ impl<const N: usize> PosInt<N> {
}
}

// Create a copy of `self` in which every bit of the value has been logically ANDed with
// the low order bit of `mask_bit`.
// Note: The result's `used` field is equal to that of `self`, even if the result
// value contains leading zeros.
fn mask(&self, mask_bit: u64) -> Self {
let mask = mask_bit << 63;
let mask = mask | mask.saturating_sub(1);
Comment on lines +243 to +244
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On x86_64, rustc+llvm generates a cmov operation for the saturating_sub.
On Aarch64, it generates a csel operation. (Compiler Explorer)

As far as I can tell, both those instructions should run in constant time, independent of their inputs. But if there's a better idiom for "sign-extend this bit to fill 64 bits," I'm happy to switch to that.


let mut result = self.clone();
for i in 0..N {
result.words[i] &= mask;
}
result
}

/// Returns `self` >> shift.
///
/// This leaks the value of `shift` / 64, because this affects
Expand All @@ -260,11 +266,13 @@ impl<const N: usize> PosInt<N> {
r
}

/// Returns `self` >> 1.
pub(crate) fn shift_right_1(&self) -> Self {
/// Returns `self` >> `c`.
///
/// `c` must be <= 63.
pub(crate) fn shift_right_small(&self, c: u8) -> Self {
let mut r = Self::zero();
r.used = self.used;
low::bignum_shr_small(r.as_mut_words(), self.as_words(), 1);
low::bignum_shr_small(r.as_mut_words(), self.as_words(), c);
r
}

Expand Down Expand Up @@ -711,19 +719,20 @@ impl<const N: usize> PosInt<N> {
r
}

/// Computes (`self` + `b`) >> 1
/// Computes (`self` + `b`) >> `c`
// / `c` must be <= 63.
#[must_use]
pub(crate) fn add_shift_right_1(&self, b: &Self) -> Self {
pub(crate) fn add_shift_right_small(&self, b: &Self, c: u8) -> Self {
let mut tmp = Self::zero();
let carry = low::bignum_add(&mut tmp.words, self.as_words(), b.as_words());
tmp.used = low::bignum_digitsize(&tmp.words);

let mut r = Self::zero();
low::bignum_shr_small(&mut r.words, tmp.as_words(), 1);
low::bignum_shr_small(&mut r.words, tmp.as_words(), c & 63);
r.used = tmp.used;

// insert carry at top
r.words[r.used - 1] |= carry << 63;
r.words[r.used.saturating_sub(1)] |= carry << 63;

r
}
Expand Down Expand Up @@ -866,6 +875,20 @@ fn trim_leading_zeroes(mut bytes: &[u8]) -> &[u8] {
mod tests {
use super::*;

#[test]
fn test_mask() {
let x = PosInt::<2>::from_bytes(&[
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
0x0f, 0x00,
])
.unwrap();
assert!(x.mask(1).equals(&x));
assert!(x.mask(0).is_zero());
// Only the low order bit is used in masking
assert!(x.mask(0xa5a5a5a5a5a5a5a5).equals(&x));
assert!(x.mask(0xfffffffffffffffe).is_zero());
}

#[test]
fn from_bytes() {
// no bytes -> zero
Expand Down Expand Up @@ -1112,7 +1135,7 @@ mod tests {
}

#[test]
fn test_add_shift_right_1() {
fn test_add_shift_right_small() {
let a = PosInt::<1> {
words: [0x8000_0000_0000_0021; 1],
used: 1,
Expand All @@ -1121,14 +1144,14 @@ mod tests {
words: [0x8421_8421_8421_8421; 1],
used: 1,
};
let c = a.add_shift_right_1(&b);
let c = a.add_shift_right_small(&b, 1);
assert_eq!(c.as_words(), &[0x8210_c210_c210_c221]);

let b = PosInt::<1> {
words: [0x0421_8421_8421_8421; 1],
used: 1,
};
let c = a.add_shift_right_1(&b);
let c = a.add_shift_right_small(&b, 1);
assert_eq!(c.as_words(), &[0x4210_c210_c210_c221]);
}
}
Loading