diff --git a/sha1/src/compress.rs b/sha1/src/compress.rs index d1032bbf..72d0a393 100644 --- a/sha1/src/compress.rs +++ b/sha1/src/compress.rs @@ -4,7 +4,11 @@ cfg_if::cfg_if! { if #[cfg(feature = "force-soft")] { mod soft; use soft::compress as compress_inner; - } else if #[cfg(target_arch = "loongarch64")] { + } else if #[cfg(all(target_arch = "aarch64"))] { + mod soft; + mod aarch64; + use aarch64::compress as compress_inner; + } else if #[cfg(target_arch = "loongarch64")] { mod loongarch64_asm; use loongarch64_asm::compress as compress_inner; } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { diff --git a/sha1/src/compress/aarch64.rs b/sha1/src/compress/aarch64.rs index 5952d1f6..f757f7f2 100644 --- a/sha1/src/compress/aarch64.rs +++ b/sha1/src/compress/aarch64.rs @@ -1,5 +1,7 @@ //! SHA-1 `aarch64` backend. +use crate::K; + // Per rustc target feature docs for `aarch64-unknown-linux-gnu` and // `aarch64-apple-darwin` platforms, the `sha2` target feature enables // SHA-1 as well: @@ -7,11 +9,175 @@ // > Enable SHA1 and SHA256 support. cpufeatures::new!(sha1_hwcap, "sha2"); +// note that `sha2` implicitly enables `neon` +#[target_feature(enable = "sha2")] +unsafe fn compress_sha1_neon(state: &mut [u32; 5], blocks: &[[u8; 64]]) { + use core::arch::aarch64::*; + + let mut abcd = vld1q_u32(state.as_ptr()); + let mut e0 = state[4]; + let [k0, k1, k2, k3] = K.map(|k| vdupq_n_u32(k)); + let (mut e1, mut tmp0, mut tmp1); + + for block in blocks { + let abcd_cpy = abcd; + let e0_cpy = e0; + + // Load and reverse byte order + let [mut msg0, mut msg1, mut msg2, mut msg3] = [0, 1, 2, 3].map(|i| { + let p = block.as_ptr().add(16 * i); + vreinterpretq_u32_u8(vrev32q_u8(vld1q_u8(p))) + }); + + tmp0 = vaddq_u32(msg0, k0); + tmp1 = vaddq_u32(msg1, k0); + + // Rounds 0-3 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, k0); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 4-7 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, k0); + msg0 = vsha1su1q_u32(msg0, msg3); + msg1 = vsha1su0q_u32(msg1, msg2, msg3); + + // Rounds 8-11 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg0, k0); + msg1 = vsha1su1q_u32(msg1, msg0); + msg2 = vsha1su0q_u32(msg2, msg3, msg0); + + // Rounds 12-15 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg1, k1); + msg2 = vsha1su1q_u32(msg2, msg1); + msg3 = vsha1su0q_u32(msg3, msg0, msg1); + + // Rounds 16-19 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1cq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, k1); + msg3 = vsha1su1q_u32(msg3, msg2); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 20-23 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, k1); + msg0 = vsha1su1q_u32(msg0, msg3); + msg1 = vsha1su0q_u32(msg1, msg2, msg3); + + // Rounds 24-27 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg0, k1); + msg1 = vsha1su1q_u32(msg1, msg0); + msg2 = vsha1su0q_u32(msg2, msg3, msg0); + + // Rounds 28-31 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg1, k1); + msg2 = vsha1su1q_u32(msg2, msg1); + msg3 = vsha1su0q_u32(msg3, msg0, msg1); + + // Rounds 32-35 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, k2); + msg3 = vsha1su1q_u32(msg3, msg2); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 36-39 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, k2); + msg0 = vsha1su1q_u32(msg0, msg3); + msg1 = vsha1su0q_u32(msg1, msg2, msg3); + + // Rounds 40-43 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg0, k2); + msg1 = vsha1su1q_u32(msg1, msg0); + msg2 = vsha1su0q_u32(msg2, msg3, msg0); + + // Rounds 44-47 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg1, k2); + msg2 = vsha1su1q_u32(msg2, msg1); + msg3 = vsha1su0q_u32(msg3, msg0, msg1); + + // Rounds 48-51 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, k2); + msg3 = vsha1su1q_u32(msg3, msg2); + msg0 = vsha1su0q_u32(msg0, msg1, msg2); + + // Rounds 52-55 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, k3); + msg0 = vsha1su1q_u32(msg0, msg3); + msg1 = vsha1su0q_u32(msg1, msg2, msg3); + + // Rounds 56-59 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1mq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg0, k3); + msg1 = vsha1su1q_u32(msg1, msg0); + msg2 = vsha1su0q_u32(msg2, msg3, msg0); + + // Rounds 60-63 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg1, k3); + msg2 = vsha1su1q_u32(msg2, msg1); + msg3 = vsha1su0q_u32(msg3, msg0, msg1); + + // Rounds 64-67 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e0, tmp0); + tmp0 = vaddq_u32(msg2, k3); + msg3 = vsha1su1q_u32(msg3, msg2); + + // Rounds 68-71 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + tmp1 = vaddq_u32(msg3, k3); + + // Rounds 72-75 + e1 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e0, tmp0); + + // Rounds 76-79 + e0 = vsha1h_u32(vgetq_lane_u32(abcd, 0)); + abcd = vsha1pq_u32(abcd, e1, tmp1); + + // Update state + abcd = vaddq_u32(abcd_cpy, abcd); + e0 = e0.wrapping_add(e0_cpy); + } + + // Save state + vst1q_u32(state.as_mut_ptr(), abcd); + state[4] = e0; +} + pub fn compress(state: &mut [u32; 5], blocks: &[[u8; 64]]) { - // TODO: Replace with https://github.com/rust-lang/rfcs/pull/2725 - // after stabilization + // TODO: Replace with https://github.com/rust-lang/rfcs/pull/2725 after stabilization if sha1_hwcap::get() { - sha1_asm::compress(state, blocks); + unsafe { + compress_sha1_neon(state, blocks); + } } else { super::soft::compress(state, blocks); } diff --git a/sha1/src/compress/loongarch64_asm.rs b/sha1/src/compress/loongarch64_asm.rs index facef1b1..69133e0d 100644 --- a/sha1/src/compress/loongarch64_asm.rs +++ b/sha1/src/compress/loongarch64_asm.rs @@ -1,9 +1,8 @@ //! LoongArch64 assembly backend +use crate::K; use core::arch::asm; -const K: [u32; 4] = [0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6]; - macro_rules! c { ($($l:expr)*) => { concat!($($l ,)*) diff --git a/sha1/src/compress/soft.rs b/sha1/src/compress/soft.rs index 0e28e5e2..b2a1be7b 100644 --- a/sha1/src/compress/soft.rs +++ b/sha1/src/compress/soft.rs @@ -1,7 +1,6 @@ #![allow(clippy::many_single_char_names)] use super::BLOCK_SIZE; - -const K: [u32; 4] = [0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6]; +use crate::K; #[inline(always)] fn add(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { diff --git a/sha1/src/lib.rs b/sha1/src/lib.rs index e77e9bfd..5361f604 100644 --- a/sha1/src/lib.rs +++ b/sha1/src/lib.rs @@ -33,6 +33,9 @@ pub use compress::compress; const STATE_LEN: usize = 5; const BLOCK_SIZE: usize = ::BlockSize::USIZE; +const H0: [u32; STATE_LEN] = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0]; +#[allow(dead_code)] +const K: [u32; 4] = [0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6]; /// Core SHA-1 hasher state. #[derive(Clone)] @@ -85,7 +88,7 @@ impl Default for Sha1Core { #[inline] fn default() -> Self { Self { - h: [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0], + h: H0, block_len: 0, } }