refactor: add bitwise_memcpy
abstract copy and shift bits logic
gabriele-0201 committed Nov 19, 2024
1 parent 8d54511 commit fdd04bb
Showing 1 changed file with 126 additions and 48 deletions.
nomt/src/beatree/ops/bit_ops.rs
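For illustration, a minimal usage sketch of the new `bitwise_memcpy` helper introduced by this commit (the buffer contents and bit offsets below are made up, not taken from the commit):

// Illustrative only: copy 12 bits, starting at bit 4 of `src`, into `dst`
// starting at bit 6. `src` must be a multiple of 8 bytes long, while `dst`
// only needs room for the shifted bits ((6 + 12 + 7) / 8 = 3 bytes here).
let src = [0b0000_1111u8, 0b1010_0000, 0, 0, 0, 0, 0, 0];
let mut dst = [0u8; 3];
bitwise_memcpy(&mut dst, 6, &src, 4, 12);
// afterwards `dst` holds the 12 copied bits at positions 6..18; the other
// bits keep their previous values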
@@ -64,64 +64,141 @@ pub fn reconstruct_key(maybe_prefix: Option<RawPrefix>, separator: RawSeparator)
let (separator_bytes, separator_bit_init, separator_bit_len) = separator;

// where the separator will start to be stored
let mut key_offset = match prefix_byte_len {
let init_destination = match prefix_byte_len {
0 => 0,
len if prefix_end_bit_offset == 0 => len,
// overlap between the end of the prefix and the beginning of the separator
len => len - 1,
};

// shift the separator and store it in the key
bitwise_memcpy(
&mut key[init_destination..],
prefix_end_bit_offset,
separator_bytes,
separator_bit_init,
separator_bit_len,
);

if prefix_byte_len != 0 {
// UNWRAP: prefix_byte_len can be different than 0 only if maybe_prefix is Some
let prefix = maybe_prefix.unwrap();

// copy the prefix into the key up to the penultimate byte
key[0..prefix_byte_len - 1].copy_from_slice(&prefix.0[0..prefix_byte_len - 1]);

// copy the last byte of the prefix without interfering with the separator
let mask_shift = 8 - (prefix_bit_len % 8) as u32;
let mask = 1u8.checked_shl(mask_shift).map(|m| !(m - 1)).unwrap_or(255);
key[prefix_byte_len - 1] |= prefix.0[prefix_byte_len - 1] & mask;
}

key
}

fn first_chunk_mask(bit_init: usize) -> u64 {
let mask_shift = (7 - bit_init) as u32 + 1 + (8 * 7);
1u64.checked_shl(mask_shift)
.map(|m| m - 1)
.unwrap_or(u64::MAX)
}

fn last_chunk_mask(bit_init: usize, bit_len: usize, n_chunks: usize) -> u64 {
let unused_last_bits = (bit_init + bit_len).saturating_sub((n_chunks - 1) * 64);
// saturating_sub is necessary to prevent overflow in the rare case in which a right
// shift may shift some bits onto the `n_chunks` chunk, leading to more than 64 bits
// on the `n_chunks - 1` chunk.
1u64.checked_shl(64u32.saturating_sub(unused_last_bits as u32))
.map(|m| !(m - 1))
.unwrap_or(0)
}

/// Copy `source_bit_len` bits from the source bytes, starting from the `source_bit_init`-th bit
/// of the first byte, into the destination, starting from the `destination_bit_init`-th bit.
///
/// `destination` must be long enough to store the bits, also accounting for possible shifts.
/// `source` must have a length that is a multiple of 8 bytes.
///
/// All the bits in `destination` not involved in the copy will be left unchanged.
pub fn bitwise_memcpy(
destination: &mut [u8],
destination_bit_init: usize,
source: &[u8],
source_bit_init: usize,
source_bit_len: usize,
) {
if source_bit_len == 0 {
return;
}

enum Shift {
Left(usize), // amount
Right(usize, Option<u8>, Option<u8>), // amount, prev_remainder, curr_remainder
}

let mut shift = match prefix_end_bit_offset as isize - separator_bit_init as isize {
let mut shift = match destination_bit_init as isize - source_bit_init as isize {
0 => None,
shift if shift < 0 => Some(Shift::Left(shift.abs() as usize)),
shift => Some(Shift::Right(shift as usize, None, None)),
};

// chunk is an 8-byte slice of the separator which will be cast to a u64 to simplify shifting
let n_chunks = separator_bytes.len() / 8;
let n_chunks = source.len() / 8;
let bytes_to_write = (destination_bit_init + source_bit_len + 7) / 8;

let last_chunk_mask = || -> u64 {
let unused_last_bits = separator_bit_init + separator_bit_len - ((n_chunks - 1) * 64);
1u64.checked_shl(64 - unused_last_bits as u32)
.map(|m| !(m - 1))
.unwrap_or(0)
};
// container of all bits that will not be overwritten
let mut chunk_data: Option<u64> = None;

let mut destination_offset = 0;
for chunk_index in 0..n_chunks {
let chunk_start = chunk_index * 8;

if let Some(Shift::Right(amount, _prev_remainder, curr_remainder)) = &mut shift {
// store bits that will be covered by the right shift
let mask = (1 << *amount) - 1;
let bits = separator_bytes[chunk_start + 7] & mask;
let bits = source[chunk_start + 7] & mask;
*curr_remainder = Some(bits << (8 - *amount));
}

let mut chunk = u64::from_be_bytes(
separator_bytes[chunk_start..chunk_start + 8]
.try_into()
.unwrap(),
);
let mut chunk =
u64::from_be_bytes(source[chunk_start..chunk_start + 8].try_into().unwrap());

// maybe_masks could contain two masks:
// + mask_from, used to extract only the useful bits from the source bytes
// + mask_to, used to keep untouched bits unchanged in the destination
let mut maybe_masks = None;
if chunk_index == 0 {
// first chunk will probably have garbage at the beginning of the first byte
let mask_shift = (7 - separator.1) as u32 + 1 + (8 * 7);
let mask = 1u64
.checked_shl(mask_shift)
.map(|m| m - 1)
.unwrap_or(u64::MAX);

chunk &= mask;
maybe_masks = Some((
first_chunk_mask(source_bit_init),
first_chunk_mask(destination_bit_init),
))
}

if chunk_index == n_chunks - 1 {
// last chunk will probably have garbage at end
chunk &= last_chunk_mask();
let (mask_from, mask_to) = maybe_masks.take().unwrap_or((u64::MAX, u64::MAX));
maybe_masks = Some((
mask_from & last_chunk_mask(source_bit_init, source_bit_len, n_chunks),
mask_to & last_chunk_mask(destination_bit_init, source_bit_len, n_chunks),
));
}

if let Some((mask_from, mask_to)) = maybe_masks {
let destination_chunk = {
// the destination could have fewer than 8 bytes available
let n_byte = std::cmp::min(8, bytes_to_write - destination_offset);
let destination_bytes = if n_byte < 8 {
let mut buf = [0u8; 8];
buf[..n_byte].copy_from_slice(
&destination[destination_offset..destination_offset + n_byte],
);
buf
} else {
destination[destination_offset..destination_offset + n_byte]
.try_into()
.unwrap()
};
u64::from_be_bytes(destination_bytes)
};
chunk_data = Some(destination_chunk & !mask_to);
chunk &= mask_from;
}

match &mut shift {
@@ -130,6 +207,10 @@ pub fn reconstruct_key(maybe_prefix: Option<RawPrefix>, separator: RawSeparator)
_ => (),
};

if let Some(data) = chunk_data.take() {
chunk |= data;
}

let mut chunk_shifted = chunk.to_be_bytes();

// move bits remainder between chunk boundaries
@@ -138,11 +219,11 @@ pub fn reconstruct_key(maybe_prefix: Option<RawPrefix>, separator: RawSeparator)
// this mask removes possible garbage from the last remainder
let mut mask = 255;
if n_chunks > 1 && chunk_index == n_chunks - 2 {
mask = last_chunk_mask().to_be_bytes()[0];
mask =
last_chunk_mask(source_bit_init, source_bit_len, n_chunks).to_be_bytes()[0];
}

let remainder_bits =
(separator_bytes[(chunk_index + 1) * 8] & mask) >> (8 - *amount);
let remainder_bits = (source[(chunk_index + 1) * 8] & mask) >> (8 - *amount);

chunk_shifted[7] |= remainder_bits;
}
@@ -156,30 +237,27 @@ pub fn reconstruct_key(maybe_prefix: Option<RawPrefix>, separator: RawSeparator)
};

// store the shifted chunk into the key
let n_byte = std::cmp::min(8, 32 - key_offset);
key[key_offset..key_offset + n_byte].copy_from_slice(&chunk_shifted[..n_byte]);
key_offset += n_byte;
let n_byte = std::cmp::min(8, destination.len() - destination_offset);
destination[destination_offset..destination_offset + n_byte]
.copy_from_slice(&chunk_shifted[..n_byte]);
destination_offset += n_byte;

// break if the separtor is already entirely being written
if key_offset == 32 {
// break if the destination is already entirely being written
if destination_offset >= bytes_to_write {
break;
}
}

if prefix_byte_len != 0 {
// UNWRAP: prefix_byte_len can be different than 0 only if maybe_prefix is Some
let prefix = maybe_prefix.unwrap();

// copy the prefix into the key up to the penultimate byte
key[0..prefix_byte_len - 1].copy_from_slice(&prefix.0[0..prefix_byte_len - 1]);

// copy the last byte of the prefix without interfering with the separator
let mask_shift = 8 - (prefix_bit_len % 8) as u32;
let mask = 1u8.checked_shl(mask_shift).map(|m| !(m - 1)).unwrap_or(255);
key[prefix_byte_len - 1] |= prefix.0[prefix_byte_len - 1] & mask;
// handle a possible right-shift remainder that has not been applied yet
if destination_offset < bytes_to_write {
// There could be at most one byte which still needs to be written
assert_eq!(bytes_to_write - destination_offset, 1);
if let Some(Shift::Right(amount, Some(prev_remainder), _)) = &mut shift {
let mask = (1 << (8 - *amount)) - 1;
destination[destination_offset] &= mask;
destination[destination_offset] |= *prev_remainder;
}
}

key
}

#[cfg(feature = "benchmarks")]
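A rough sanity check one could write against the extracted helper (a hypothetical test, not part of this commit; the module and test names are assumed): a byte-aligned copy with both bit offsets at zero should behave like a plain byte copy.

#[cfg(test)]
mod bitwise_memcpy_sketch {
    use super::bitwise_memcpy;

    // Hypothetical test: with both bit offsets at zero and a whole number of
    // bytes to copy, the helper should reduce to a plain byte copy.
    #[test]
    fn byte_aligned_copy_matches_plain_copy() {
        let source = [0xAAu8, 0xBB, 0xCC, 0xDD, 0x11, 0x22, 0x33, 0x44];
        let mut destination = [0u8; 8];
        bitwise_memcpy(&mut destination, 0, &source, 0, 5 * 8);
        assert_eq!(&destination[..5], &source[..5]);
        // the trailing bytes started out as zero and are expected to stay zero
        assert_eq!(&destination[5..], &[0u8; 3]);
    }
}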
