Skip to content

Commit

Permalink
make shl/shr and fastpow generics
Browse files Browse the repository at this point in the history
  • Loading branch information
TAdev0 committed Aug 9, 2024
1 parent 9f1e27c commit 8ecb8cc
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 23 deletions.
94 changes: 77 additions & 17 deletions src/utils.cairo
Original file line number Diff line number Diff line change
@@ -1,34 +1,94 @@
use core::traits::Into;
use core::traits::TryInto;
use core::num::traits::{Zero, One, BitSize};
use core::panic_with_felt252;
use core::starknet::secp256_trait::Secp256PointTrait;

// Bitwise shift left for u256
pub fn shl(value: u256, shift: u32) -> u256 {
value * fast_pow(2.into(), shift.into())

pub trait Bitshift<T, U> {
fn shl(self: T, shift: U) -> T;
fn shr(self: T, shift: U) -> T;
}

// Bitwise shift right for u256
pub fn shr(value: u256, shift: u32) -> u256 {
value / fast_pow(2.into(), shift.into())
pub impl BitshiftImpl<
T,
U,
+Zero<T>,
+Zero<U>,
+One<T>,
+One<U>,
+Add<T>,
+Add<U>,
+Sub<T>,
+Sub<U>,
+Div<T>,
+Mul<T>,
+Rem<U>,
+Div<U>,
+Copy<T>,
+Copy<U>,
+Drop<T>,
+Drop<U>,
+PartialOrd<T>,
+PartialOrd<U>,
+PartialEq<U>,
+BitSize<T>,
+TryInto<usize, T>,
+Into<usize, U>
> of Bitshift<T, U> {
fn shl(self: T, shift: U) -> T {
if shift > BitSize::<T>::bits().into() - One::one() {
panic_with_felt252('mul Overflow');
}
let two = One::one() + One::one();
self * fast_pow(two, shift)
}

fn shr(self: T, shift: U) -> T {
if shift > BitSize::<T>::bits().try_into().unwrap() - One::one() {
panic_with_felt252('mul Overflow');
}
let two = One::one() + One::one();
self / fast_pow(two, shift)
}
}

// Fast exponentiation using the square-and-multiply algorithm
// Reference:
// https://github.com/keep-starknet-strange/alexandria/blob/bcdca70afdf59c9976148e95cebad5cf63d75a7f/packages/math/src/fast_power.cairo#L12
pub fn fast_pow(base: u256, exp: u32) -> u256 {
if exp == 0 {
return 1_u256;
pub fn fast_pow<
T,
U,
+Copy<U>,
+Copy<T>,
+Drop<T>,
+Drop<U>,
+Zero<T>,
+One<T>,
+Zero<U>,
+One<U>,
+PartialEq<U>,
+Add<U>,
+Mul<T>,
+Rem<U>,
+Div<U>
>(
base: T, exp: U
) -> T {
if exp == Zero::zero() {
return Zero::zero();
}

let mut res: u256 = 1_u256;
let mut base: u256 = base;
let mut exp: u32 = exp;
let mut res: T = One::one();
let mut base: T = base;
let mut exp: U = exp;

let two: U = One::one() + One::one();

loop {
if exp % 2 == 1 {
if exp % two == One::one() {
res = res * base;
}
exp = exp / 2;
if exp == 0 {
exp = exp / two;
if exp == Zero::zero() {
break res;
}
base = base * base;
Expand Down
12 changes: 6 additions & 6 deletions src/validation.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::state::{Block, ChainState, Transaction, UtreexoState};
use super::utils::{shl, shr};
use super::utils::Bitshift;

const MAX_TARGET: u256 = 0x00000000FFFF0000000000000000000000000000000000000000000000000000;

Expand Down Expand Up @@ -126,10 +126,10 @@ pub fn bits_to_target(bits: u32) -> Result<u256, felt252> {
} else if exponent <= 3 {
// For exponents 1, 2, and 3, divide by 256^(3 - exponent) i.e right shift
let shift = 8 * (3 - exponent);
target = shr(target, shift);
target = Bitshift::shr(target, shift);
} else {
let shift = 8 * (exponent - 3);
target = shl(target, shift);
target = Bitshift::shl(target, shift);
}

// Ensure the target doesn't exceed the maximum allowed value
Expand All @@ -154,12 +154,12 @@ pub fn target_to_bits(target: u256) -> Result<u32, felt252> {
let mut compact = target;

// Count leading zero bytes by finding the first non-zero byte
while size > 1 && shr(compact, (size - 1) * 8) == 0 {
while size > 1 && Bitshift::shr(compact, (size - 1) * 8) == 0 {
size -= 1;
};

// Extract mantissa (most significant 3 bytes)
let mut mantissa: u32 = shr(compact, (size - 3) * 8).try_into().unwrap();
let mut mantissa: u32 = Bitshift::shr(compact, (size - 3) * 8).try_into().unwrap();

// Normalize
if mantissa > 0x7fffff {
Expand All @@ -179,7 +179,7 @@ pub fn target_to_bits(target: u256) -> Result<u32, felt252> {
let size_u256: u256 = size.into();

// Combine size and mantissa
let result: u32 = (shl(size_u256, 24) + mantissa.into()).try_into().unwrap();
let result: u32 = (Bitshift::shl(size_u256, 24_u32) + mantissa.into()).try_into().unwrap();

Result::Ok(result)
}
Expand Down

0 comments on commit 8ecb8cc

Please sign in to comment.