diff --git a/reed-solomon-novelpoly/src/field/inc_afft.rs b/reed-solomon-novelpoly/src/field/inc_afft.rs index 5032cd2..7f7c555 100644 --- a/reed-solomon-novelpoly/src/field/inc_afft.rs +++ b/reed-solomon-novelpoly/src/field/inc_afft.rs @@ -14,16 +14,16 @@ pub struct AdditiveFFT { } /// Formal derivative of polynomial in the new?? basis -pub fn formal_derivative(cos: &mut [Additive], size: usize) { - for i in 1..size { +pub fn formal_derivative(cos: &mut [Additive]) { + for i in 1..cos.len() { let length = ((i ^ (i - 1)) + 1) >> 1; for j in (i - length)..i { cos[j] ^= cos.get(j + length).copied().unwrap_or(Additive::ZERO); } } - let mut i = size; + let mut i = cos.len(); while i < FIELD_SIZE && i < cos.len() { - for j in 0..size { + for j in 0..cos.len() { cos[j] ^= cos.get(j + i).copied().unwrap_or(Additive::ZERO); } i <<= 1; @@ -32,9 +32,11 @@ pub fn formal_derivative(cos: &mut [Additive], size: usize) { /// Formal derivative of polynomial in tweaked?? basis #[allow(non_snake_case)] -pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) { +pub fn tweaked_formal_derivative(codeword: &mut [Additive]) { #[cfg(b_is_not_one)] let B = unsafe { &AFFT.B }; + #[cfg(b_is_not_one)] + let n = codeword.len(); // We change nothing when multiplying by b from B. #[cfg(b_is_not_one)] @@ -44,7 +46,7 @@ pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) { codeword[i + 1] = codeword[i + 1].mul(b); } - formal_derivative(codeword, n); + formal_derivative(codeword); // Again changes nothing by multiplying by b although b differs here. #[cfg(b_is_not_one)] @@ -86,7 +88,9 @@ fn b_is_one() { // We're hunting for the differences and trying to undersrtand the algorithm. /// Inverse additive FFT in the "novel polynomial basis" -pub fn inverse_afft(data: &mut [Additive], size: usize, index: usize) { +/// +/// # Safety: See safety section of `AdditiveFFT::inverse_afft`. +pub unsafe fn inverse_afft(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.inverse_afft(data, size, index) } @@ -96,7 +100,9 @@ pub fn inverse_afft_faster8(data: &mut [Additive], size: usize, index: usize) { } /// Additive FFT in the "novel polynomial basis" -pub fn afft(data: &mut [Additive], size: usize, index: usize) { +/// +/// # Safety: See safety section of `AdditiveFFT::afft`. +pub unsafe fn afft(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.afft(data, size, index) } @@ -130,7 +136,12 @@ impl AdditiveFFT { } /// Inverse additive FFT in the "novel polynomial basis" - pub fn inverse_afft(&self, data: &mut [Additive], size: usize, index: usize) { + /// + /// # Safety + /// + /// - caller must ensure than `size` is a power of two and that the length of the `data` slice is at least equal to `size`. + /// - caller must ensure that `index + size - 2` is less than or equal to 65534. + pub unsafe fn inverse_afft(&self, data: &mut [Additive], size: usize, index: usize) { // All line references to Algorithm 2 page 6288 of // https://www.citi.sinica.edu.tw/papers/whc/5524-F.pdf @@ -167,20 +178,42 @@ impl AdditiveFFT { // if depart_no >= 8 && false{ // data[i + depart_no] ^= dbg!(data[dbg!(i)]); // } else { - data[i + depart_no] ^= data[i]; + #[cfg(debug)] + { + data[i + depart_no] ^= data[i]; + } + + #[cfg(not(debug))] + { + // SAFETY + // + // j is smaller than size. depart_no is smaller than size. + // depart_no is always doubled, so it's always a power of two smaller than size. + // this means that depart_no is at most half of size, assuming size is a power of two. + // + // i is at most j - 1. j is greater than depart_no but is incremented by double of depart_no. + // for the max depart_no value of size/2, j will only have the one value of size/2, + // so the index will be size/2 - 1 + size/2, which is equal to size - 1, which is safe. + // i will always be smaller than i + depart_no, since they're positive integers. qed. + let local = unsafe { *data.get_unchecked(i) }; + unsafe { *data.get_unchecked_mut(i + depart_no) ^= local }; + } // } } // Algorithm 2 indexs the skew factor in line 5 page 6288 // by i and \omega_{j 2^{i+1}}, but not by r explicitly. // We further explore this confusion below. (TODO) - let skew = - // if depart_no >= 8 && false { - // dbg!(self.skews[j + index - 1]) - // } else { - self.skews[j + index - 1] - // } - ; + #[cfg(debug)] + let skew = self.skews[j + index - 1]; + + #[cfg(not(debug))] + // SAFETY: + // + // Safe because caller ensured that index + size - 2 is less than or equal to 65534 (the skew vector len). + // Since, j is at most size - 1, this is safe. + let skew = unsafe { *self.skews.get_unchecked(j + index - 1) }; + // It's reasonale to skip the loop if skew is zero, but doing so with // all bits set requires justification. (TODO) if skew.0 != ONEMASK { @@ -191,7 +224,17 @@ impl AdditiveFFT { // if depart_no >= 8 && false{ // data[i] ^= dbg!(dbg!(data[dbg!(i + depart_no)]).mul(skew)); // } else { - data[i] ^= data[i + depart_no].mul(skew); + #[cfg(debug)] + { + data[i] ^= data[i + depart_no].mul(skew); + } + + #[cfg(not(debug))] + // Same safety princicples as the first `for i in (j - depart_no)..j` loop. + { + let local = unsafe { *data.get_unchecked(i + depart_no) }; + unsafe { *data.get_unchecked_mut(i) ^= local.mul(skew) }; + } // } } } @@ -259,7 +302,12 @@ impl AdditiveFFT { } /// Additive FFT in the "novel polynomial basis" - pub fn afft(&self, data: &mut [Additive], size: usize, index: usize) { + /// + /// # Safety + /// + /// - caller must ensure than `size` is a power of two and that the length of the `data` slice is at least equal to `size`. + /// - caller must ensure that `index + size - 2` is less than or equal to 65534. + pub unsafe fn afft(&self, data: &mut [Additive], size: usize, index: usize) { // All line references to Algorithm 1 page 6287 of // https://www.citi.sinica.edu.tw/papers/whc/5524-F.pdf @@ -291,8 +339,17 @@ impl AdditiveFFT { // we think r actually appears but the skew factor repeats itself // like in (19) in the proof of Lemma 4. (TODO) // We should understand the rest of this basis story, like (8) too. (TODO) + + #[cfg(debug)] let skew = self.skews[j + index - 1]; + #[cfg(not(debug))] + // SAFETY: + // + // Safe because caller ensured that index + size - 2 is less than or equal to 65534 (the skew vector len). + // Since, j is at most size - 1, this is safe. + let skew = unsafe { *self.skews.get_unchecked(j + index - 1) }; + // It's reasonale to skip the loop if skew is zero, but doing so with // all bits set requires justification. (TODO) if skew.0 != ONEMASK { @@ -300,7 +357,26 @@ impl AdditiveFFT { for i in (j - depart_no)..j { // Line 6, explained by (28) page 6287, but // adding depart_no acts like the r+2^i superscript. - data[i] ^= data[i + depart_no].mul(skew); + + #[cfg(debug)] + { + data[i] ^= data[i + depart_no].mul(skew); + } + + #[cfg(not(debug))] + { + // SAFETY + // + // j is smaller than size. depart_no is smaller than size/2 and it's always halved. + // this means that depart_no is at most half of size, assuming size is a power of two. + // + // i is at most j - 1. j is greater than depart_no but is incremented by double of depart_no. + // for the max depart_no value of size/2, j will only have the one value of size/2, + // so the index will be size/2 - 1 + size/2, which is equal to size - 1, which is safe. + // i will always be smaller than i + depart_no, since they're positive integers. qed. + let local = unsafe { *data.get_unchecked(i + depart_no) }; + unsafe { *data.get_unchecked_mut(i) ^= local.mul(skew) }; + } } } @@ -308,7 +384,17 @@ impl AdditiveFFT { for i in (j - depart_no)..j { // Line 7, explained by (31) page 6287, but // adding depart_no acts like the r+2^i superscript. - data[i + depart_no] ^= data[i]; + #[cfg(debug)] + { + data[i + depart_no] ^= data[i]; + } + + #[cfg(not(debug))] + { + // Same safety princicples as the first `for i in (j - depart_no)..j` loop. + let local = unsafe { *data.get_unchecked(i) }; + unsafe { *data.get_unchecked_mut(i + depart_no) ^= local }; + } } // Increment by double depart_no in agreement with @@ -484,7 +570,7 @@ pub mod test_utils { let data = gen_plain::(size); gen_faster8_from_plain(data) } - + #[cfg(all(target_feature = "avx", feature = "avx"))] pub fn assert_plain_eq_faster8(plain: impl AsRef<[Additive]>, faster8: impl AsRef<[Additive]>) { let plain = plain.as_ref(); @@ -502,7 +588,7 @@ mod afft_tests { use super::super::*; use super::super::test_utils::*; use rand::rngs::SmallRng; - + #[cfg(all(target_feature = "avx", feature = "avx"))] #[test] fn afft_output_plain_eq_faster8_size_16() { @@ -544,7 +630,7 @@ mod afft_tests { println!(">>>>"); assert_plain_eq_faster8(data_plain, data_faster8); } - + #[cfg(all(target_feature = "avx", feature = "avx"))] #[test] fn afft_output_plain_eq_faster8_impulse_data() { diff --git a/reed-solomon-novelpoly/src/field/inc_encode.rs b/reed-solomon-novelpoly/src/field/inc_encode.rs index 6e30bf0..acc7b79 100644 --- a/reed-solomon-novelpoly/src/field/inc_encode.rs +++ b/reed-solomon-novelpoly/src/field/inc_encode.rs @@ -29,20 +29,36 @@ pub fn encode_low_plain(data: &[Additive], k: usize, codeword: &mut [Additive], // split after the first k let (codeword_first_k, codeword_skip_first_k) = codeword.split_at_mut(k); - inverse_afft(codeword_first_k, k, 0); + // - safe because codeword_first_k is exactly k elements and k is a power of two. + // - safe because `index + size - 2` is `k - 2`. k is at most n/2 and n is at most 65536. Therefore, + // k is at most 65536/2-2 = 32766 (smaller than 65535). qed. + unsafe { inverse_afft(codeword_first_k, k, 0) }; // dbg!(&codeword_first_k); // the first codeword is now the basis for the remaining transforms // denoted `M_topdash` for shift in (k..n).step_by(k) { + #[cfg(debug)] let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift]; + + #[cfg(not(debug))] + // SAFETY + // + // n is i*k, with i at least 2. shift is at most (i-1)*k. + // (i-1) * k will always be smaller than i*k for all i greater than 2. + // Similarly, shift - k will always be smaller than shift, since they're positive integers + // and shift is at least equal to k. + let codeword_at_shift = unsafe { codeword_skip_first_k.get_unchecked_mut((shift - k)..shift) }; + // copy `M_topdash` to the position we are currently at, the n transform codeword_at_shift.copy_from_slice(codeword_first_k); - // dbg!(&codeword_at_shift); - afft(codeword_at_shift, k, shift); - // let post = &codeword_at_shift; - // dbg!(post); + // SAFETY + // + // - safe because codeword_first_k is exactly k elements and k is a power of two. + // - k is at most n/2 (32768). `index + size - 2` is therefore equal to 2*k - 2 = 65534 which + // is less than or equal to 65534. qed. + unsafe { afft(codeword_at_shift, k, shift) }; } // restore `M` from the derived ones @@ -78,12 +94,21 @@ pub fn encode_low_faster8(data: &[Additive], k: usize, codeword: &mut [Additive] // denoted `M_topdash` for shift in (k..n).step_by(k) { + #[cfg(debug)] let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift]; + + #[cfg(not(debug))] + // SAFETY + // + // n is i*k, with i at least 2. shift is at most (i-1)*k. + // (i-1) * k will always be smaller than i*k for all i greater than 2. + // Similarly, shift - k will always be smaller than shift, since they're positive integers + // and shift is at least equal to k. + let codeword_at_shift = unsafe { codeword_skip_first_k.get_unchecked_mut((shift - k)..shift) }; + // copy `M_topdash` to the position we are currently at, the n transform codeword_at_shift.copy_from_slice(codeword_first_k); - afft_faster8(codeword_at_shift, k, shift); - // let post = &codeword8x_at_shift; } // restore `M` from the derived ones @@ -93,63 +118,68 @@ pub fn encode_low_faster8(data: &[Additive], k: usize, codeword: &mut [Additive] //data: message array. parity: parity array. mem: buffer(size>= n-k) //Encoding alg for k/n>0.5: parity is a power of two. -#[inline(always)] -pub fn encode_high(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { - #[cfg(all(target_feature = "avx", feature = "avx"))] - if (n - k) % Additive8x::LANE == 0 && n % Additive8x::LANE == 0 && k % Additive8x::LANE == 0 { - encode_high_faster8(data, k, parity, mem, n); - } else { - encode_high_plain(data, k, parity, mem, n); - } - #[cfg(not(target_feature = "avx"))] - encode_high_plain(data, k, parity, mem, n); -} +// Function is not exposed/tested. Consider the safety guidelines of the *afft functions before using. +// #[inline(always)] +// pub fn encode_high(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { +// #[cfg(all(target_feature = "avx", feature = "avx"))] +// if (n - k) % Additive8x::LANE == 0 && n % Additive8x::LANE == 0 && k % Additive8x::LANE == 0 { +// encode_high_faster8(data, k, parity, mem, n); +// } else { +// encode_high_plain(data, k, parity, mem, n); +// } +// #[cfg(not(target_feature = "avx"))] +// encode_high_plain(data, k, parity, mem, n); +// } //data: message array. parity: parity array. mem: buffer(size>= n-k) //Encoding alg for k/n>0.5: parity is a power of two. -pub fn encode_high_plain(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { - let t: usize = n - k; - - // mem_zero(&mut parity[0..t]); - for i in 0..t { - parity[i] = Additive(0); - } - - let mut i = t; - while i < n { - mem[..t].copy_from_slice(&data[(i - t)..t]); - - inverse_afft(mem, t, i); - for j in 0..t { - parity[j] ^= mem[j]; - } - i += t; - } - afft(parity, t, 0); -} - -#[cfg(all(target_feature = "avx", feature = "avx"))] -pub fn encode_high_faster8(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { - let t: usize = n - k; - assert!(t >= 8); - assert_eq!(t % 8, 0); - - for i in 0..t { - parity[i] = Additive::zero(); - } - - let mut i = t; - while i < n { - mem[..t].copy_from_slice(&data[(i - t)..t]); - - inverse_afft_faster8(mem, t, i); - for j in 0..t { - parity[j] ^= mem[j]; - } - i += t; - } - afft_faster8(parity, t, 0); -} +// Function is not exposed/tested. Consider the safety guidelines of the *afft functions before using. +// pub fn encode_high_plain(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { +// assert!(is_power_of_2(n)); + +// let t: usize = n - k; + +// // mem_zero(&mut parity[0..t]); +// for i in 0..t { +// parity[i] = Additive(0); +// } + +// let mut i = t; +// while i < n { +// mem[..t].copy_from_slice(&data[(i - t)..t]); + +// unsafe { inverse_afft(mem, t, i) }; +// for j in 0..t { +// parity[j] ^= mem[j]; +// } +// i += t; +// } +// unsafe { afft(parity, t, 0) }; +// } + +// #[cfg(all(target_feature = "avx", feature = "avx"))] +// Function is not exposed/tested. Consider the safety guidelines of the *afft functions before using. +// pub fn encode_high_faster8(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { +// let t: usize = n - k; +// assert!(t >= 8); +// assert_eq!(t % 8, 0); + +// for i in 0..t { +// parity[i] = Additive::zero(); +// } + +// let mut i = t; +// while i < n { +// mem[..t].copy_from_slice(&data[(i - t)..t]); + +// inverse_afft_faster8(mem, t, i); +// for j in 0..t { +// parity[j] ^= mem[j]; +// } +// i += t; +// } +// afft_faster8(parity, t, 0); +// } pub fn encode_sub(bytes: &[u8], n: usize, k: usize) -> Result> { #[cfg(all(target_feature = "avx", feature = "avx"))] @@ -191,10 +221,26 @@ pub fn encode_sub_plain(bytes: &[u8], n: usize, k: usize) -> Result Result= recover_up_to); assert_eq!(erasure.len(), n); - for i in 0..n { + for i in 0..codeword.len() { codeword[i] = if erasure[i] { Additive(0) } else { codeword[i].mul(log_walsh2[i]) }; } - inverse_afft(codeword, n, 0); + // SAFETY + // + // - safe because we check in `reconstruct_sub` that n is a power of two and we also check that + // codeword.len() is equal to n. + // - n is at most 65536. `index + size - 2` is therefore equal to 65536 - 2 = 65534 which + // is less than or equal to 65534. qed. + unsafe { inverse_afft(codeword, n, 0) }; - tweaked_formal_derivative(codeword, n); + tweaked_formal_derivative(codeword); - afft(codeword, n, 0); + // SAFETY + // + // - safe because we check in `reconstruct_sub` that n is a power of two and we also check that + // codeword.len() is equal to n. + // - n is at most 65536. `index + size - 2` is therefore equal to 65536 - 2 = 65534 which + // is less than or equal to 65534. qed. + unsafe { afft(codeword, n, 0) }; for i in 0..recover_up_to { codeword[i] = if erasure[i] { codeword[i].mul(log_walsh2[i]) } else { Additive(0) }; diff --git a/reed-solomon-novelpoly/src/novel_poly_basis/tests.rs b/reed-solomon-novelpoly/src/novel_poly_basis/tests.rs index 9c3101a..c60d90a 100644 --- a/reed-solomon-novelpoly/src/novel_poly_basis/tests.rs +++ b/reed-solomon-novelpoly/src/novel_poly_basis/tests.rs @@ -69,12 +69,12 @@ fn flt_back_and_forth() { let mut data = (0..N).map(|_x| rand_gf_element()).collect::>(); let expected = data.clone(); - afft(&mut data, N, N / 4); + unsafe { afft(&mut data, N, N / 4) }; // make sure something is done assert!(data.iter().zip(expected.iter()).filter(|(a, b)| { a != b }).count() > 0); - inverse_afft(&mut data, N, N / 4); + unsafe { inverse_afft(&mut data, N, N / 4) }; itertools::assert_equal(data, expected); } @@ -313,7 +313,7 @@ fn flt_roundtrip_small() { let mut data = EXPECTED; - f2e16::afft(&mut data, N, N / 4); + unsafe { f2e16::afft(&mut data, N, N / 4) }; println!("novel basis(rust):"); data.iter().for_each(|sym| { @@ -321,7 +321,7 @@ fn flt_roundtrip_small() { }); println!(); - f2e16::inverse_afft(&mut data, N, N / 4); + unsafe { f2e16::inverse_afft(&mut data, N, N / 4) }; itertools::assert_equal(data.iter(), EXPECTED.iter()); } @@ -351,12 +351,7 @@ fn ported_c_test() { //---------encoding---------- let mut codeword = [Additive(0); N]; - if K + K > N && false { - let (data_till_t, data_skip_t) = data.split_at_mut(N - K); - f2e16::encode_high(data_skip_t, K, data_till_t, &mut codeword[..], N); - } else { - f2e16::encode_low(&data[..], K, &mut codeword[..], N); - } + f2e16::encode_low(&data[..], K, &mut codeword[..], N); // println!("Codeword:"); // for i in K..(K+100) {