Skip to content

Commit

Permalink
optimise bounds checks
Browse files Browse the repository at this point in the history
This brings a performance improvement of 40-100%,
making this implementation as fast as the C++ alternative in kagome.

Where possible, the compiler is aided to optimise away the bounds checks without
any unsafe code. However, a fair amount of unsafe code was needed,
but it doesn't lower the security posture as the needed assertions
were already being made.

Signed-off-by: alindima <[email protected]>
  • Loading branch information
alindima committed Dec 18, 2023
1 parent bf96cb6 commit e6827f3
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 107 deletions.
134 changes: 110 additions & 24 deletions reed-solomon-novelpoly/src/field/inc_afft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ pub struct AdditiveFFT {
}

/// Formal derivative of polynomial in the new?? basis
pub fn formal_derivative(cos: &mut [Additive], size: usize) {
for i in 1..size {
pub fn formal_derivative(cos: &mut [Additive]) {
for i in 1..cos.len() {
let length = ((i ^ (i - 1)) + 1) >> 1;
for j in (i - length)..i {
cos[j] ^= cos.get(j + length).copied().unwrap_or(Additive::ZERO);
}
}
let mut i = size;
let mut i = cos.len();
while i < FIELD_SIZE && i < cos.len() {
for j in 0..size {
for j in 0..cos.len() {
cos[j] ^= cos.get(j + i).copied().unwrap_or(Additive::ZERO);
}
i <<= 1;
Expand All @@ -32,9 +32,11 @@ pub fn formal_derivative(cos: &mut [Additive], size: usize) {

/// Formal derivative of polynomial in tweaked?? basis
#[allow(non_snake_case)]
pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) {
pub fn tweaked_formal_derivative(codeword: &mut [Additive]) {
#[cfg(b_is_not_one)]
let B = unsafe { &AFFT.B };
#[cfg(b_is_not_one)]
let n = codeword.len();

// We change nothing when multiplying by b from B.
#[cfg(b_is_not_one)]
Expand All @@ -44,7 +46,7 @@ pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) {
codeword[i + 1] = codeword[i + 1].mul(b);
}

formal_derivative(codeword, n);
formal_derivative(codeword);

// Again changes nothing by multiplying by b although b differs here.
#[cfg(b_is_not_one)]
Expand Down Expand Up @@ -86,7 +88,9 @@ fn b_is_one() {
// We're hunting for the differences and trying to understand the algorithm.

/// Inverse additive FFT in the "novel polynomial basis"
pub fn inverse_afft(data: &mut [Additive], size: usize, index: usize) {
///
/// # Safety
///
/// See the safety section of `AdditiveFFT::inverse_afft`.
pub unsafe fn inverse_afft(data: &mut [Additive], size: usize, index: usize) {

Check failure on line 93 in reed-solomon-novelpoly/src/field/inc_afft.rs

View workflow job for this annotation

GitHub Actions / Clippy

unsafe function's docs miss `# Safety` section
unsafe { &AFFT }.inverse_afft(data, size, index)
}

Expand All @@ -96,7 +100,9 @@ pub fn inverse_afft_faster8(data: &mut [Additive], size: usize, index: usize) {
}

/// Additive FFT in the "novel polynomial basis"
pub fn afft(data: &mut [Additive], size: usize, index: usize) {
///
/// # Safety
///
/// See the safety section of `AdditiveFFT::afft`.
pub unsafe fn afft(data: &mut [Additive], size: usize, index: usize) {

Check failure on line 105 in reed-solomon-novelpoly/src/field/inc_afft.rs

View workflow job for this annotation

GitHub Actions / Clippy

unsafe function's docs miss `# Safety` section
unsafe { &AFFT }.afft(data, size, index)
}

Expand Down Expand Up @@ -130,7 +136,12 @@ impl AdditiveFFT {
}

/// Inverse additive FFT in the "novel polynomial basis"
pub fn inverse_afft(&self, data: &mut [Additive], size: usize, index: usize) {
///
/// # Safety
///
/// - caller must ensure that `size` is a power of two and that the length of the `data` slice is at least equal to `size`.
/// - caller must ensure that `index + size - 2` is less than or equal to 65534.
pub unsafe fn inverse_afft(&self, data: &mut [Additive], size: usize, index: usize) {
// All line references to Algorithm 2 page 6288 of
// https://www.citi.sinica.edu.tw/papers/whc/5524-F.pdf

Expand Down Expand Up @@ -167,20 +178,42 @@ impl AdditiveFFT {
// if depart_no >= 8 && false{
// data[i + depart_no] ^= dbg!(data[dbg!(i)]);
// } else {
data[i + depart_no] ^= data[i];
#[cfg(debug)]
{
data[i + depart_no] ^= data[i];
}

#[cfg(not(debug))]
{
// SAFETY
//
// j is smaller than size. depart_no is smaller than size.
// depart_no is always doubled, so it's always a power of two smaller than size.
// this means that depart_no is at most half of size, assuming size is a power of two.
//
// i is at most j - 1. j is greater than depart_no but is incremented by double of depart_no.
// for the max depart_no value of size/2, j will only have the one value of size/2,
// so the index will be size/2 - 1 + size/2, which is equal to size - 1, which is safe.
// i will always be smaller than i + depart_no, since they're positive integers. qed.
let local = unsafe { *data.get_unchecked(i) };
unsafe { *data.get_unchecked_mut(i + depart_no) ^= local };
}
// }
}

// Algorithm 2 indexes the skew factor in line 5 page 6288
// by i and \omega_{j 2^{i+1}}, but not by r explicitly.
// We further explore this confusion below. (TODO)
let skew =
// if depart_no >= 8 && false {
// dbg!(self.skews[j + index - 1])
// } else {
self.skews[j + index - 1]
// }
;
#[cfg(debug)]
let skew = self.skews[j + index - 1];

#[cfg(not(debug))]
// SAFETY:
//
// Safe because caller ensured that index + size - 2 is less than or equal to 65534 (the skew vector len).
// Since j is at most size - 1, this is safe.
let skew = unsafe { *self.skews.get_unchecked(j + index - 1) };

// It's reasonable to skip the loop if skew is zero, but doing so with
// all bits set requires justification. (TODO)
if skew.0 != ONEMASK {
Expand All @@ -191,7 +224,17 @@ impl AdditiveFFT {
// if depart_no >= 8 && false{
// data[i] ^= dbg!(dbg!(data[dbg!(i + depart_no)]).mul(skew));
// } else {
data[i] ^= data[i + depart_no].mul(skew);
#[cfg(debug)]
{
data[i] ^= data[i + depart_no].mul(skew);
}

#[cfg(not(debug))]
// Same safety principles as the first `for i in (j - depart_no)..j` loop.
{
let local = unsafe { *data.get_unchecked(i + depart_no) };
unsafe { *data.get_unchecked_mut(i) ^= local.mul(skew) };
}
// }
}
}
Expand Down Expand Up @@ -259,7 +302,12 @@ impl AdditiveFFT {
}

/// Additive FFT in the "novel polynomial basis"
pub fn afft(&self, data: &mut [Additive], size: usize, index: usize) {
///
/// # Safety
///
/// - caller must ensure that `size` is a power of two and that the length of the `data` slice is at least equal to `size`.
/// - caller must ensure that `index + size - 2` is less than or equal to 65534.
pub unsafe fn afft(&self, data: &mut [Additive], size: usize, index: usize) {
// All line references to Algorithm 1 page 6287 of
// https://www.citi.sinica.edu.tw/papers/whc/5524-F.pdf

Expand Down Expand Up @@ -291,24 +339,62 @@ impl AdditiveFFT {
// we think r actually appears but the skew factor repeats itself
// like in (19) in the proof of Lemma 4. (TODO)
// We should understand the rest of this basis story, like (8) too. (TODO)

#[cfg(debug)]
let skew = self.skews[j + index - 1];

#[cfg(not(debug))]
// SAFETY:
//
// Safe because caller ensured that index + size - 2 is less than or equal to 65534 (the skew vector len).
// Since j is at most size - 1, this is safe.
let skew = unsafe { *self.skews.get_unchecked(j + index - 1) };

// It's reasonable to skip the loop if skew is zero, but doing so with
// all bits set requires justification. (TODO)
if skew.0 != ONEMASK {
// Loop on line 5, except skew should depend upon i aka j in Algorithm 1 (TODO)
for i in (j - depart_no)..j {
// Line 6, explained by (28) page 6287, but
// adding depart_no acts like the r+2^i superscript.
data[i] ^= data[i + depart_no].mul(skew);

#[cfg(debug)]
{
data[i] ^= data[i + depart_no].mul(skew);
}

#[cfg(not(debug))]
{
// SAFETY
//
// j is smaller than size. depart_no is smaller than size/2 and it's always halved.
// this means that depart_no is at most half of size, assuming size is a power of two.
//
// i is at most j - 1. j is greater than depart_no but is incremented by double of depart_no.
// for the max depart_no value of size/2, j will only have the one value of size/2,
// so the index will be size/2 - 1 + size/2, which is equal to size - 1, which is safe.
// i will always be smaller than i + depart_no, since they're positive integers. qed.
let local = unsafe { *data.get_unchecked(i + depart_no) };
unsafe { *data.get_unchecked_mut(i) ^= local.mul(skew) };
}
}
}

// Again loop on line 5, so i corresponds to j in Algorithm 1
for i in (j - depart_no)..j {
// Line 7, explained by (31) page 6287, but
// adding depart_no acts like the r+2^i superscript.
data[i + depart_no] ^= data[i];
#[cfg(debug)]
{
data[i + depart_no] ^= data[i];
}

#[cfg(not(debug))]
{
// Same safety principles as the first `for i in (j - depart_no)..j` loop.
let local = unsafe { *data.get_unchecked(i) };
unsafe { *data.get_unchecked_mut(i + depart_no) ^= local };
}
}

// Increment by double depart_no in agreement with
Expand Down Expand Up @@ -484,7 +570,7 @@ pub mod test_utils {
let data = gen_plain::<R>(size);
gen_faster8_from_plain(data)
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
pub fn assert_plain_eq_faster8(plain: impl AsRef<[Additive]>, faster8: impl AsRef<[Additive]>) {
let plain = plain.as_ref();
Expand All @@ -502,7 +588,7 @@ mod afft_tests {
use super::super::*;
use super::super::test_utils::*;
use rand::rngs::SmallRng;

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[test]
fn afft_output_plain_eq_faster8_size_16() {
Expand Down Expand Up @@ -544,7 +630,7 @@ mod afft_tests {
println!(">>>>");
assert_plain_eq_faster8(data_plain, data_faster8);
}

#[cfg(all(target_feature = "avx", feature = "avx"))]
#[test]
fn afft_output_plain_eq_faster8_impulse_data() {
Expand Down
Loading

0 comments on commit e6827f3

Please sign in to comment.