Skip to content

Commit

Permalink
lehmer: borrow lhs/rhs to be able to claim them back later
Browse files Browse the repository at this point in the history
Removes the need to use raw pointers and fixes undefined behavior detected by miri
  • Loading branch information
eduardosm committed Mar 23, 2024
1 parent ae3b576 commit 2341ab5
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions integer/src/gcd/lehmer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use alloc::alloc::Layout;
use core::{mem, ptr, slice};
use core::borrow::BorrowMut as _;
use core::mem;
use dashu_base::{ExtendedGcd, Gcd};

use crate::{
Expand Down Expand Up @@ -347,12 +348,12 @@ pub fn gcd_ext_in_place(
rhs: &mut [Word],
memory: &mut Memory,
) -> (usize, usize, Sign) {
let (lhs_len, rhs_len) = (lhs.len(), rhs.len());
let (lhs_ptr, rhs_ptr) = (lhs.as_mut_ptr(), rhs.as_mut_ptr());
let lhs_len = lhs.len();

// keep x >= y though the algorithm, and track the source of x and y using the swapped flag
debug_assert!(cmp_in_place(lhs, rhs).is_ge());
let (mut x, mut y) = (lhs, rhs);
// Use `borrow_mut` to be able to claim back `lhs` and `rhs`
let (mut x, mut y) = (lhs.borrow_mut(), rhs.borrow_mut());
let mut swapped = false;

// the normal way is to have four variables s0, s1, t0, t1 and keep gcd(x, y) = gcd(lhs, rhs),
Expand Down Expand Up @@ -449,25 +450,19 @@ pub fn gcd_ext_in_place(
// If y is zero, then the gcd result is in x now.
// Note that y.len() == 0 is equivalent to y == 0, which is guaranteed by trim_leading_zeros.
if y.is_empty() {
// SAFETY: see the comments in the block. The safety here need to be carefully managed.
unsafe {
if !swapped {
// if not swapped, then x is originated from lhs, copy it to rhs
debug_assert!(x.as_ptr() == lhs_ptr);
debug_assert!(x.len() <= rhs_len);

// SAFETY: at this point, x should be from lhs, so it's not overlapping with rhs
ptr::copy_nonoverlapping(x.as_ptr(), rhs_ptr, x.len());
}
// SAFETY: t0 is temporarily allocated space, it won't overlap with lhs or rhs
ptr::copy_nonoverlapping(t0.as_ptr(), lhs_ptr, t0_len);
let x_len = x.len();
// We are not using the borrwed `x` / `y` anymore, so we can
// claim back the original `lhs` / `rhs`.
if !swapped {
rhs[..x_len].copy_from_slice(&lhs[..x_len]);
}
lhs[..t0_len].copy_from_slice(&t0[..t0_len]);
let sign = if swapped {
Sign::Positive
} else {
Sign::Negative
};
return (x.len(), t0_len, sign);
return (x_len, t0_len, sign);
}

// before forwarding to single word gcd, first reduce x by y:
Expand All @@ -491,15 +486,8 @@ pub fn gcd_ext_in_place(
// let lhs stores |b| = |cx| * t0 + |cy| * t1
// by now, number of words in |b| should be close to lhs

// SAFETY: we don't hold any reference to lhs and rhs now, so there will be no
// data racing. The pointer and length are from the original slice, so the slice
// will be valid.
let (lhs, rhs) = unsafe {
(
slice::from_raw_parts_mut(lhs_ptr, lhs_len),
slice::from_raw_parts_mut(rhs_ptr, rhs_len),
)
};
// We are not using the borrwed `x` / `y` anymore, so we can
// claim back the original `lhs` / `rhs`.
*rhs.first_mut().unwrap() = g_word;
lhs.fill(0);

Expand Down

0 comments on commit 2341ab5

Please sign in to comment.