Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work on BIP340 #420

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions onchain/cairo/afk/src/bip340.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use core::ec::{EcPoint, EcPointTrait, ec_point_unwrap, NonZeroEcPoint, EcState,
use core::poseidon::PoseidonTrait;
use core::hash::{HashStateTrait, HashStateExTrait};
use core::math::u256_mul_mod_n;
use super::request::SocialRequest;
use afk::utils::{shl, shr, compute_sha256_byte_array};
//! bip340 implementation

Expand Down Expand Up @@ -38,6 +39,31 @@ impl PartialEqImpl of PartialEq<EcPoint> {
let (lhs_x, lhs_y): (felt252, felt252) = ec_point_unwrap((*lhs).try_into().unwrap());
let (rhs_x, rhs_y): (felt252, felt252) = ec_point_unwrap((*rhs).try_into().unwrap());

/// Represents a Schnorr signature
struct SchnorrSignature {
s: u256,
R: u256,
}
// pub impl EcPointDisplay of Display<Secp256k1Point> {
// fn fmt(self: @EcPoint, ref f: Formatter) -> Result<(), Error> {
// let non_zero: NonZeroEcPoint = (*self).try_into().unwrap();
// let (x, y): (u256, u256) = ec_point_unwrap(non_zero);
// writeln!(f, "Point ({x}, {y})")
// }
// }

// impl PartialEqImpl of PartialEq<EcPoint> {
// fn eq(lhs: @EcPoint, rhs: @EcPoint) -> bool {
// let (lhs_x, lhs_y): (u256, u256) = ec_point_unwrap((*lhs).try_into().unwrap());
// let (rhs_x, rhs_y): (felt252, felt252) = ec_point_unwrap((*rhs).try_into().unwrap());

// if ((rhs_x == lhs_x) && (rhs_y == lhs_y)) {
// true
// } else {
// false
// }
// }
// }
if ((rhs_x == lhs_x) && (rhs_y == lhs_y)) {
true
} else {
Expand Down Expand Up @@ -219,6 +245,120 @@ fn verify_sig(public_key: EcPoint, message: felt252, signature: SchnorrSignature
(rhs_x == s_Gx) && (s_Gy == rhs_y)
}

fn count_digits(mut num: u256) -> (u32, felt252) {
let mut count: u32 = 0;
while num > 0 {
num = num / BASE;
count = count + 1;
};
let res: felt252 = count.try_into().unwrap();
(count, res)
}

fn encodeSocialRequest<C>(request: SocialRequest<C>) -> ByteArray {
let mut ba: ByteArray = "";

// Encode public_key
let (pk_count, pk_count_felt252) = count_digits(request.public_key);
let pk_felt252: felt252 = request.public_key.try_into().unwrap();
ba.append_word(pk_count_felt252, 1_u32);
ba.append_word(pk_felt252, pk_count);

// Encode created_at
let created_at_u256: u256 = request.created_at.into();
let (created_count, created_count_felt252) = count_digits(created_at_u256);
let created_felt252: felt252 = created_at_u256.try_into().unwrap();
ba.append_word(created_count_felt252, 1_u32);
ba.append_word(created_felt252, created_count);

// Encode kind
let kind_u256: u256 = request.kind.into();
let (kind_count, kind_count_felt252) = count_digits(kind_u256);
let kind_felt252: felt252 = kind_u256.try_into().unwrap();
ba.append_word(kind_count_felt252, 1_u32);
ba.append_word(kind_felt252, kind_count);

// Encode tags directly
ba.append(request.tags);

// Encode content (assuming it can be converted to ByteArray) check needed
let content_bytes = ByteArray::from(request.content);
ba.append(content_bytes);

let (rx, _) = request.sig.R.get_coordinates().unwrap_syscall();
ba.append_word(rx.high.into(), 16);
ba.append_word(rx.low.into(), 16);
ba.append_word(request.sig.s.high.into(), 16);
ba.append_word(request.sig.s.low.into(), 16);

ba
}

/// Generates a key pair (private key, public key) for Schnorr signatures
fn generate_keypair() -> (u256, Secp256k1Point) {
let G = Secp256Trait::<Secp256k1Point>::get_generator_point();
let private_key: u256 = 0x859825214214312162317391210310_u256; // VRF needed
let public_key = G.mul(private_key).unwrap_syscall();

(private_key, public_key)
}

/// Generates a nonce and corresponding R point for signature
fn generate_nonce_point() -> (u256, Secp256k1Point) {
let G = Secp256Trait::<Secp256k1Point>::get_generator_point();
let nonce: u256 = 0x46952909012476409278523962123414653_u256; // VRF needed
let R = G.mul(nonce).unwrap_syscall();

(nonce, R)
}

/// Computes the challenge hash e using Poseidon
fn compute_challenge(R: EcPoint, public_key: EcPoint, message: ByteArray) -> felt252 {
let (rx, _) = R.get_coordinates().unwrap_syscall();
let (px, _) = public_key.get_coordinates().unwrap_syscall();

hash_challenge(rx, px, message)

}

/fn sign(private_key: u256, message: ByteArray) -> SchnorrSignature {
let (nonce, R) = generate_nonce_point();
let G = Secp256Trait::<Secp256k1Point>::get_generator_point();
let public_key = G.mul(private_key).unwrap_syscall();
let (s_G_x, s_G_y) = public_key.get_coordinates().unwrap_syscall();
let s_G_x = r;
let e = compute_challenge(R, public_key, message);
let n = Secp256Trait::<Secp256k1Point>::get_curve_size();

// s = nonce + private_key * e mod n
let s = (nonce + (private_key * e)) % n;

SchnorrSignature { s, r }
}

/// Verifies a Schnorr signature
fn verify_sig(public_key: Secp256k1Point, message: u256, signature: SchnorrSignature) -> bool {
let G = Secp256Trait::<Secp256k1Point>::get_generator_point();
let e = compute_challenge(signature.R, public_key, message);
let n = Secp256Trait::<Secp256k1Point>::get_curve_size();

// Check that s is within valid range
if signature.s >= n {
return false;
}
// Verify s⋅G = R + e⋅P
let s_G = G.mul(signature.s).unwrap_syscall();
let e_P = public_key.mul(e).unwrap_syscall();
let R_plus_eP = signature.R.add(e_P).unwrap_syscall();

// Compare the points
let (s_G_x, s_G_y) = s_G.get_coordinates().unwrap_syscall();
let (rhs_x, rhs_y) = R_plus_eP.get_coordinates().unwrap_syscall();

s_G_x == rhs_x && s_G_y == rhs_y
}


#[cfg(test)]
mod tests {
use core::byte_array::ByteArrayTrait;
Expand Down
Loading