Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved SSE2 throughput by up to +90% #380

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chacha20/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ work on stable Rust with the following `RUSTFLAGS`:

- `x86` / `x86_64`
- `avx2`: (~1.4cpb) `-Ctarget-cpu=haswell -Ctarget-feature=+avx2`
- `sse2`: (~2.5cpb) `-Ctarget-feature=+sse2` (on by default on x86 CPUs)
- `sse2`: (~1.6cpb) `-Ctarget-feature=+sse2` (on by default on x86 CPUs)
- `aarch64`
- `neon` (~2-3x faster than `soft`) requires Rust 1.61+ and the `neon` feature enabled
- Portable
Expand Down
122 changes: 76 additions & 46 deletions chacha20/src/backends/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{ChaChaCore, Variant};
use crate::{chacha::Block, STATE_WORDS};
#[cfg(feature = "cipher")]
use cipher::{
consts::{U1, U64},
consts::{U4, U64},
BlockSizeUser, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure,
};
use core::marker::PhantomData;
Expand All @@ -17,6 +17,8 @@ use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

const PAR_BLOCKS: usize = 4;

#[inline]
#[target_feature(enable = "sse2")]
#[cfg(feature = "cipher")]
Expand Down Expand Up @@ -53,7 +55,7 @@ impl<R: Rounds> BlockSizeUser for Backend<R> {

#[cfg(feature = "cipher")]
impl<R: Rounds> ParBlocksSizeUser for Backend<R> {
type ParBlocksSize = U1;
type ParBlocksSize = U4;
}

#[cfg(feature = "cipher")]
Expand All @@ -66,7 +68,21 @@ impl<R: Rounds> StreamCipherBackend for Backend<R> {

let block_ptr = block.as_mut_ptr() as *mut __m128i;
for i in 0..4 {
_mm_storeu_si128(block_ptr.add(i), res[i]);
_mm_storeu_si128(block_ptr.add(i), res[0][i]);
}
}
}
#[inline(always)]
fn gen_par_ks_blocks(&mut self, blocks: &mut cipher::ParBlocks<Self>) {
    // Generate PAR_BLOCKS (4) keystream blocks in one call, then advance the
    // 32-bit block counter (low lane of v[3]) past all of them at once.
    unsafe {
        // `rounds` returns one finished 64-byte state per block, with the
        // per-block counter offsets already folded in.
        let res = rounds::<R>(&self.v);
        self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, PAR_BLOCKS as i32));

        let blocks_ptr = blocks.as_mut_ptr() as *mut __m128i;
        for block in 0..PAR_BLOCKS {
            for i in 0..4 {
                // Each 64-byte block is 4 x __m128i, so the per-block stride here
                // is conceptually 4; `block * PAR_BLOCKS` is only correct because
                // PAR_BLOCKS == 4. NOTE(review): consider `block * 4` so the
                // stride does not silently break if PAR_BLOCKS ever changes.
                _mm_storeu_si128(blocks_ptr.add(i + block * PAR_BLOCKS), res[block][i]);
            }
        }
    }
}
Expand All @@ -91,47 +107,55 @@ where
_pd: PhantomData,
};

for i in 0..4 {
backend.gen_ks_block(&mut buffer[i << 4..(i + 1) << 4]);
}
backend.gen_ks_blocks(buffer);

core.state[12] = _mm_cvtsi128_si32(backend.v[3]) as u32;
}

#[cfg(feature = "rng")]
impl<R: Rounds> Backend<R> {
#[inline(always)]
fn gen_ks_block(&mut self, block: &mut [u32]) {
fn gen_ks_blocks(&mut self, block: &mut [u32]) {
unsafe {
let res = rounds::<R>(&self.v);
self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, 1));
self.v[3] = _mm_add_epi32(self.v[3], _mm_set_epi32(0, 0, 0, PAR_BLOCKS as i32));

let block_ptr = block.as_mut_ptr() as *mut __m128i;
for i in 0..4 {
_mm_storeu_si128(block_ptr.add(i), res[i]);
let blocks_ptr = block.as_mut_ptr() as *mut __m128i;
for block in 0..PAR_BLOCKS {
for i in 0..4 {
_mm_storeu_si128(blocks_ptr.add(i + block * PAR_BLOCKS), res[block][i]);
}
}
}
}
}

#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rounds<R: Rounds>(v: &[__m128i; 4]) -> [__m128i; 4] {
let mut res = *v;
unsafe fn rounds<R: Rounds>(v: &[__m128i; 4]) -> [[__m128i; 4]; PAR_BLOCKS] {
let mut res = [*v; 4];
for block in 1..PAR_BLOCKS {
res[block][3] = _mm_add_epi32(res[block][3], _mm_set_epi32(0, 0, 0, block as i32));
}

for _ in 0..R::COUNT {
double_quarter_round(&mut res);
}

for i in 0..4 {
res[i] = _mm_add_epi32(res[i], v[i]);
for block in 0..PAR_BLOCKS {
for i in 0..4 {
res[block][i] = _mm_add_epi32(res[block][i], v[i]);
}
// add the counter since `v` is lacking updated counter values
res[block][3] = _mm_add_epi32(res[block][3], _mm_set_epi32(0, 0, 0, block as i32));
}

res
}

#[inline]
#[target_feature(enable = "sse2")]
unsafe fn double_quarter_round(v: &mut [__m128i; 4]) {
unsafe fn double_quarter_round(v: &mut [[__m128i; 4]; PAR_BLOCKS]) {
add_xor_rot(v);
rows_to_cols(v);
add_xor_rot(v);
Expand Down Expand Up @@ -175,11 +199,13 @@ unsafe fn double_quarter_round(v: &mut [__m128i; 4]) {
/// - https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rows_to_cols([a, _, c, d]: &mut [__m128i; 4]) {
// c >>>= 32; d >>>= 64; a >>>= 96;
*c = _mm_shuffle_epi32(*c, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
*d = _mm_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
*a = _mm_shuffle_epi32(*a, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
unsafe fn rows_to_cols(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
    // Rotate the b, c and d rows of every block so the next add_xor_rot pass
    // operates on the state's diagonals instead of its columns.
    for [a, _, c, d] in blocks.iter_mut() {
        // c >>>= 32; d >>>= 64; a >>>= 96;
        *c = _mm_shuffle_epi32(*c, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
        *d = _mm_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm_shuffle_epi32(*a, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
    }
}

/// The goal of this function is to transform the state words from:
Expand All @@ -201,33 +227,37 @@ unsafe fn rows_to_cols([a, _, c, d]: &mut [__m128i; 4]) {
/// reversing the transformation of [`rows_to_cols`].
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn cols_to_rows([a, _, c, d]: &mut [__m128i; 4]) {
// c <<<= 32; d <<<= 64; a <<<= 96;
*c = _mm_shuffle_epi32(*c, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
*d = _mm_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
*a = _mm_shuffle_epi32(*a, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
unsafe fn cols_to_rows(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
    // Inverse of `rows_to_cols`: rotate each block's rows back so the state
    // words return to their original column positions.
    for [a, _, c, d] in blocks.iter_mut() {
        // c <<<= 32; d <<<= 64; a <<<= 96;
        *c = _mm_shuffle_epi32(*c, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
        *d = _mm_shuffle_epi32(*d, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm_shuffle_epi32(*a, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
    }
}

#[inline]
#[target_feature(enable = "sse2")]
unsafe fn add_xor_rot([a, b, c, d]: &mut [__m128i; 4]) {
// a += b; d ^= a; d <<<= (16, 16, 16, 16);
*a = _mm_add_epi32(*a, *b);
*d = _mm_xor_si128(*d, *a);
*d = _mm_xor_si128(_mm_slli_epi32(*d, 16), _mm_srli_epi32(*d, 16));

// c += d; b ^= c; b <<<= (12, 12, 12, 12);
*c = _mm_add_epi32(*c, *d);
*b = _mm_xor_si128(*b, *c);
*b = _mm_xor_si128(_mm_slli_epi32(*b, 12), _mm_srli_epi32(*b, 20));

// a += b; d ^= a; d <<<= (8, 8, 8, 8);
*a = _mm_add_epi32(*a, *b);
*d = _mm_xor_si128(*d, *a);
*d = _mm_xor_si128(_mm_slli_epi32(*d, 8), _mm_srli_epi32(*d, 24));

// c += d; b ^= c; b <<<= (7, 7, 7, 7);
*c = _mm_add_epi32(*c, *d);
*b = _mm_xor_si128(*b, *c);
*b = _mm_xor_si128(_mm_slli_epi32(*b, 7), _mm_srli_epi32(*b, 25));
unsafe fn add_xor_rot(blocks: &mut [[__m128i; 4]; PAR_BLOCKS]) {
    // The four add/xor/rotate steps of a ChaCha quarter-round, applied to all
    // four columns of every block at once (one column per 32-bit lane).
    // SSE2 has no vector rotate, so each rotate-left is emulated as
    // (x << n) ^ (x >> (32 - n)). The statement order below IS the algorithm;
    // do not reorder.
    for [a, b, c, d] in blocks.iter_mut() {
        // a += b; d ^= a; d <<<= (16, 16, 16, 16);
        *a = _mm_add_epi32(*a, *b);
        *d = _mm_xor_si128(*d, *a);
        *d = _mm_xor_si128(_mm_slli_epi32(*d, 16), _mm_srli_epi32(*d, 16));

        // c += d; b ^= c; b <<<= (12, 12, 12, 12);
        *c = _mm_add_epi32(*c, *d);
        *b = _mm_xor_si128(*b, *c);
        *b = _mm_xor_si128(_mm_slli_epi32(*b, 12), _mm_srli_epi32(*b, 20));

        // a += b; d ^= a; d <<<= (8, 8, 8, 8);
        *a = _mm_add_epi32(*a, *b);
        *d = _mm_xor_si128(*d, *a);
        *d = _mm_xor_si128(_mm_slli_epi32(*d, 8), _mm_srli_epi32(*d, 24));

        // c += d; b ^= c; b <<<= (7, 7, 7, 7);
        *c = _mm_add_epi32(*c, *d);
        *b = _mm_xor_si128(*b, *c);
        *b = _mm_xor_si128(_mm_slli_epi32(*b, 7), _mm_srli_epi32(*b, 25));
    }
}
12 changes: 8 additions & 4 deletions chacha20/src/rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use core::fmt::Debug;

use rand_core::{
block::{BlockRng, BlockRngCore, CryptoBlockRng},
impl_try_rng_from_rng_core, CryptoRng, RngCore, SeedableRng,
CryptoRng, RngCore, SeedableRng,
};

#[cfg(feature = "serde1")]
Expand All @@ -32,7 +32,7 @@ const BLOCK_WORDS: u8 = 16;

/// The seed for ChaCha20. Implements ZeroizeOnDrop when the
/// zeroize feature is enabled.
#[derive(PartialEq, Eq, Default)]
#[derive(PartialEq, Eq, Default, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Seed([u8; 32]);

Expand All @@ -42,6 +42,12 @@ impl AsRef<[u8; 32]> for Seed {
}
}

// Borrow the seed as an unsized byte slice, in addition to the existing
// fixed-size `AsRef<[u8; 32]>` impl, for callers that take `impl AsRef<[u8]>`.
impl AsRef<[u8]> for Seed {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

impl AsMut<[u8]> for Seed {
fn as_mut(&mut self) -> &mut [u8] {
self.0.as_mut()
Expand Down Expand Up @@ -384,8 +390,6 @@ macro_rules! impl_chacha_rng {
}
}

impl_try_rng_from_rng_core!($ChaChaXRng);

impl $ChaChaXRng {
// The buffer is a 4-block window, i.e. it is always at a block-aligned position in the
// stream but if the stream has been sought it may not be self-aligned.
Expand Down
Loading