Skip to content

Commit

Permalink
chacha20: Parallelize AVX2 backend
Browse files Browse the repository at this point in the history
The AVX2 backend was previously computing two ChaCha blocks in parallel,
then throwing one away.

This updates the implementation to always compute two blocks in parallel
when the AVX2 backend is enabled, resulting in a ~2X speedup.

Unfortunately for `cipher.rs`, originally adapted from the `ctr` crate,
I deleted the original parallel computation code, and in lieu of that
the implementation diverges from what was originally in `ctr`. See here
for a reference:

https://github.com/RustCrypto/stream-ciphers/blob/907e94b/ctr/src/lib.rs#L73

Ideally we can come up with some generic counter management and
buffering abstraction in the `ctr` crate which works in all cases.
  • Loading branch information
tarcieri committed Jan 16, 2020
1 parent 4821eed commit 4d98819
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 49 deletions.
6 changes: 5 additions & 1 deletion chacha20/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ xchacha20 = ["stream-cipher"]
rng = ["rand_core"]

[[bench]]
name = "chacha20"
name = "stream_cipher"
harness = false

[[bench]]
name = "rng"
harness = false

[package.metadata.docs.rs]
Expand Down
8 changes: 5 additions & 3 deletions chacha20/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ This crate contains the following implementations of ChaCha20, all of which
work on stable Rust with the following `RUSTFLAGS`:

- `x86` / `x86_64`
- `sse2`: `-Ctarget-feature=+sse2` (on by default on x86 CPUs)
- `avx2`: `-Ctarget-cpu=haswell -Ctarget-feature=+avx2`
- `avx2`: (~1.4cpb) `-Ctarget-cpu=haswell -Ctarget-feature=+avx2`
- `sse2`: (~2.5cpb) `-Ctarget-feature=+sse2` (on by default on x86 CPUs)
- Portable
- `soft`
- `soft`: (~5 cpb on x86/x86_64)

NOTE: cpb = cycles per byte (smaller is better)

## Security Warning

Expand Down
36 changes: 36 additions & 0 deletions chacha20/benches/rng.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//! `ChaCha20Rng` benchmark
#[cfg(not(feature = "rng"))]
compile_error!("run benchmarks with `cargo bench --all-features`");

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion_cycles_per_byte::CyclesPerByte;

use chacha20::ChaCha20Rng;
use rand_core::{RngCore, SeedableRng};

const KB: usize = 1024;

fn bench(c: &mut Criterion<CyclesPerByte>) {
let mut group = c.benchmark_group("rng");

for size in &[KB, 2 * KB, 4 * KB, 8 * KB, 16 * KB] {
let mut buf = vec![0u8; *size];

group.throughput(Throughput::Bytes(*size as u64));

group.bench_function(BenchmarkId::new("apply_keystream", size), |b| {
let mut rng = ChaCha20Rng::from_seed(Default::default());
b.iter(|| rng.fill_bytes(&mut buf));
});
}

group.finish();
}

criterion_group!(
name = benches;
config = Criterion::default().with_measurement(CyclesPerByte);
targets = bench
);
criterion_main!(benches);
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
//! ChaCha20 `stream-cipher` benchmark
#[cfg(not(feature = "stream-cipher"))]
compile_error!("run benchmarks with `cargo bench --all-features`");

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion_cycles_per_byte::CyclesPerByte;

Expand All @@ -9,7 +14,7 @@ use chacha20::{
const KB: usize = 1024;

fn bench(c: &mut Criterion<CyclesPerByte>) {
let mut group = c.benchmark_group("chacha20");
let mut group = c.benchmark_group("stream-cipher");

for size in &[KB, 2 * KB, 4 * KB, 8 * KB, 16 * KB] {
let mut buf = vec![0u8; *size];
Expand Down
6 changes: 3 additions & 3 deletions chacha20/src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ mod avx2;
any(target_arch = "x86", target_arch = "x86_64"),
any(target_feature = "sse2", target_feature = "avx2")
)))]
pub(crate) use self::soft::Block;
pub(crate) use self::soft::{Block, BUFFER_SIZE};

#[cfg(all(
any(target_arch = "x86", target_arch = "x86_64"),
target_feature = "sse2",
not(target_feature = "avx2")
))]
pub(crate) use self::sse2::Block;
pub(crate) use self::sse2::{Block, BUFFER_SIZE};

#[cfg(all(
any(target_arch = "x86", target_arch = "x86_64"),
target_feature = "avx2"
))]
pub(crate) use self::avx2::Block;
pub(crate) use self::avx2::{Block, BUFFER_SIZE};

use core::fmt::{self, Debug};

Expand Down
49 changes: 31 additions & 18 deletions chacha20/src/block/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,20 @@
//! Goll, M., and Gueron,S.: Vectorization of ChaCha Stream Cipher. Cryptology ePrint Archive,
//! Report 2013/759, November, 2013, <https://eprint.iacr.org/2013/759.pdf>
use crate::{CONSTANTS, IV_SIZE, KEY_SIZE};
use crate::{BLOCK_SIZE, CONSTANTS, IV_SIZE, KEY_SIZE};
use core::convert::TryInto;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Size of buffers passed to `generate` and `apply_keystream` for this
/// backend, which operates on two blocks in parallel for optimal performance.
pub(crate) const BUFFER_SIZE: usize = BLOCK_SIZE * 2;

/// The ChaCha20 block function (AVX2 accelerated implementation for x86/x86_64)
// TODO(tarcieri): zeroize?
#[derive(Clone)]
pub(crate) struct Block {
v0: __m256i,
Expand Down Expand Up @@ -62,16 +68,24 @@ impl Block {
#[inline]
#[allow(clippy::cast_ptr_alignment)] // loadu/storeu support unaligned loads/stores
pub(crate) fn apply_keystream(&self, counter: u64, output: &mut [u8]) {
debug_assert_eq!(output.len(), BUFFER_SIZE);

unsafe {
let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
let mut v3 = iv_setup(self.iv, counter);
self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);

for (chunk, a) in output.chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
for (chunk, a) in output[..BLOCK_SIZE].chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
let b = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
let out = _mm_xor_si128(_mm256_castsi256_si128(*a), b);
_mm_storeu_si128(chunk.as_mut_ptr() as *mut __m128i, out);
}

for (chunk, a) in output[BLOCK_SIZE..].chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
let b = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
let out = _mm_xor_si128(_mm256_extractf128_si256(*a, 1), b);
_mm_storeu_si128(chunk.as_mut_ptr() as *mut __m128i, out);
}
}
}

Expand Down Expand Up @@ -132,22 +146,21 @@ unsafe fn iv_setup(iv: [i32; 2], counter: u64) -> __m256i {
#[target_feature(enable = "avx2")]
#[allow(clippy::cast_ptr_alignment)] // storeu supports unaligned stores
unsafe fn store(v0: __m256i, v1: __m256i, v2: __m256i, v3: __m256i, output: &mut [u8]) {
_mm_storeu_si128(
output.as_mut_ptr().offset(0x00) as *mut __m128i,
_mm256_castsi256_si128(v0),
);
_mm_storeu_si128(
output.as_mut_ptr().offset(0x10) as *mut __m128i,
_mm256_castsi256_si128(v1),
);
_mm_storeu_si128(
output.as_mut_ptr().offset(0x20) as *mut __m128i,
_mm256_castsi256_si128(v2),
);
_mm_storeu_si128(
output.as_mut_ptr().offset(0x30) as *mut __m128i,
_mm256_castsi256_si128(v3),
);
debug_assert_eq!(output.len(), BUFFER_SIZE);

for (chunk, v) in output[..BLOCK_SIZE].chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
_mm_storeu_si128(
chunk.as_mut_ptr() as *mut __m128i,
_mm256_castsi256_si128(*v),
);
}

for (chunk, v) in output[BLOCK_SIZE..].chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
_mm_storeu_si128(
chunk.as_mut_ptr() as *mut __m128i,
_mm256_extractf128_si256(*v, 1),
);
}
}

#[inline]
Expand Down
14 changes: 7 additions & 7 deletions chacha20/src/block/soft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
use crate::{BLOCK_SIZE, CONSTANTS, IV_SIZE, KEY_SIZE, STATE_WORDS};
use core::{convert::TryInto, mem};

/// The ChaCha20 block function
///
/// While ChaCha20 is a stream cipher, not a block cipher, its core
/// primitive is a function which acts on a 512-bit block
// TODO(tarcieri): zeroize? need to make sure we're actually copying first
/// Size of buffers passed to `generate` and `apply_keystream` for this backend
pub(crate) const BUFFER_SIZE: usize = BLOCK_SIZE;

/// The ChaCha20 block function (portable software implementation)
// TODO(tarcieri): zeroize?
#[allow(dead_code)]
#[derive(Clone)]
pub(crate) struct Block {
Expand Down Expand Up @@ -49,7 +49,7 @@ impl Block {

/// Generate output, overwriting data already in the buffer
pub(crate) fn generate(&mut self, counter: u64, output: &mut [u8]) {
debug_assert_eq!(output.len(), BLOCK_SIZE);
debug_assert_eq!(output.len(), BUFFER_SIZE);
self.counter_setup(counter);

let mut state = self.state;
Expand All @@ -62,7 +62,7 @@ impl Block {

/// Apply generated keystream to the output buffer
pub(crate) fn apply_keystream(&mut self, counter: u64, output: &mut [u8]) {
debug_assert_eq!(output.len(), BLOCK_SIZE);
debug_assert_eq!(output.len(), BUFFER_SIZE);
self.counter_setup(counter);

let mut state = self.state;
Expand Down
11 changes: 10 additions & 1 deletion chacha20/src/block/sse2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
//!
//! SSE2-optimized implementation for x86/x86-64 CPUs.
use crate::{CONSTANTS, IV_SIZE, KEY_SIZE};
use crate::{BLOCK_SIZE, CONSTANTS, IV_SIZE, KEY_SIZE};
use core::convert::TryInto;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Size of buffers passed to `generate` and `apply_keystream` for this backend
pub(crate) const BUFFER_SIZE: usize = BLOCK_SIZE;

/// The ChaCha20 block function (SSE2 accelerated implementation for x86/x86_64)
// TODO(tarcieri): zeroize?
#[derive(Clone)]
pub(crate) struct Block {
v0: __m128i,
Expand Down Expand Up @@ -47,6 +52,8 @@ impl Block {

#[inline]
pub(crate) fn generate(&self, counter: u64, output: &mut [u8]) {
debug_assert_eq!(output.len(), BUFFER_SIZE);

unsafe {
let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
let mut v3 = iv_setup(self.iv, counter);
Expand All @@ -58,6 +65,8 @@ impl Block {
#[inline]
#[allow(clippy::cast_ptr_alignment)] // loadu/storeu support unaligned loads/stores
pub(crate) fn apply_keystream(&self, counter: u64, output: &mut [u8]) {
debug_assert_eq!(output.len(), BUFFER_SIZE);

unsafe {
let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
let mut v3 = iv_setup(self.iv, counter);
Expand Down
33 changes: 21 additions & 12 deletions chacha20/src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,24 @@
// TODO(tarcieri): figure out how to unify this with the `ctr` crate

use crate::{block::Block, BLOCK_SIZE};
use crate::{
block::{Block, BUFFER_SIZE},
BLOCK_SIZE,
};
use core::{
cmp,
fmt::{self, Debug},
};
use stream_cipher::{LoopError, SyncStreamCipher, SyncStreamCipherSeek};

/// Internal buffer
type Buffer = [u8; BLOCK_SIZE];
type Buffer = [u8; BUFFER_SIZE];

/// How much to increment the counter by for each buffer we generate.
/// Normally this is 1 but the AVX2 backend uses double-wide buffers.
// TODO(tarcieri): support a parallel blocks count like the `ctr` crate
// See: <https://github.com/RustCrypto/stream-ciphers/blob/907e94b/ctr/src/lib.rs#L73>
const COUNTER_INCR: u64 = (BUFFER_SIZE as u64) / (BLOCK_SIZE as u64);

/// ChaCha20 as a counter mode stream cipher
pub(crate) struct Cipher {
Expand All @@ -39,7 +48,7 @@ impl Cipher {
pub fn new(block: Block, counter_offset: u64) -> Self {
Self {
block,
buffer: [0u8; BLOCK_SIZE],
buffer: [0u8; BUFFER_SIZE],
buffer_pos: None,
counter: 0,
counter_offset,
Expand All @@ -63,7 +72,7 @@ impl SyncStreamCipher for Cipher {
if let Some(pos) = self.buffer_pos {
let pos = pos as usize;

if data.len() >= BLOCK_SIZE - pos {
if data.len() >= BUFFER_SIZE - pos {
let buf = &self.buffer[pos..];
let (r, l) = data.split_at_mut(buf.len());
data = l;
Expand All @@ -79,20 +88,20 @@ impl SyncStreamCipher for Cipher {

let mut counter = self.counter;

while data.len() >= BLOCK_SIZE {
let (l, r) = { data }.split_at_mut(BLOCK_SIZE);
while data.len() >= BUFFER_SIZE {
let (l, r) = { data }.split_at_mut(BUFFER_SIZE);
data = r;

// TODO(tarcieri): double check this should be checked and not wrapping
let counter_with_offset = self.counter_offset.checked_add(counter).unwrap();
self.block.apply_keystream(counter_with_offset, l);

counter = counter.checked_add(1).unwrap();
counter = counter.checked_add(COUNTER_INCR).unwrap();
}

if !data.is_empty() {
self.generate_block(counter);
counter = counter.checked_add(1).unwrap();
counter = counter.checked_add(COUNTER_INCR).unwrap();
let n = data.len();
xor(data, &self.buffer[..n]);
self.buffer_pos = Some(n as u8);
Expand Down Expand Up @@ -126,7 +135,7 @@ impl SyncStreamCipherSeek for Cipher {
self.buffer_pos = None;
} else {
self.generate_block(self.counter);
self.counter = self.counter.checked_add(1).unwrap();
self.counter = self.counter.checked_add(COUNTER_INCR).unwrap();
self.buffer_pos = Some(rem as u8);
}
}
Expand All @@ -137,12 +146,12 @@ impl Cipher {
let dlen = data.len()
- self
.buffer_pos
.map(|pos| cmp::min(BLOCK_SIZE - pos as usize, data.len()))
.map(|pos| cmp::min(BUFFER_SIZE - pos as usize, data.len()))
.unwrap_or_default();

let data_buffers = dlen / BLOCK_SIZE + if data.len() % BLOCK_SIZE != 0 { 1 } else { 0 };
let data_blocks = dlen / BLOCK_SIZE + if data.len() % BLOCK_SIZE != 0 { 1 } else { 0 };

if self.counter.checked_add(data_buffers as u64).is_some() {
if self.counter.checked_add(data_blocks as u64).is_some() {
Ok(())
} else {
Err(LoopError)
Expand Down
9 changes: 6 additions & 3 deletions chacha20/src/rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use core::slice;
use rand_core::block::{BlockRng, BlockRngCore};
use rand_core::{Error, RngCore, SeedableRng};

use crate::{block::Block, BLOCK_SIZE, KEY_SIZE, STATE_WORDS};
use crate::{
block::{Block, BUFFER_SIZE},
KEY_SIZE,
};

macro_rules! impl_chacha_rng {
($name:ident, $core:ident, $rounds:expr, $doc:expr) => {
Expand Down Expand Up @@ -63,12 +66,12 @@ macro_rules! impl_chacha_rng {

impl BlockRngCore for $core {
type Item = u32;
type Results = [u32; STATE_WORDS];
type Results = [u32; BUFFER_SIZE / 4];

fn generate(&mut self, results: &mut Self::Results) {
// TODO(tarcieri): eliminate unsafety (replace w\ [u8; BLOCK_SIZE)
self.block.generate(self.counter, unsafe {
slice::from_raw_parts_mut(results.as_mut_ptr() as *mut u8, BLOCK_SIZE)
slice::from_raw_parts_mut(results.as_mut_ptr() as *mut u8, BUFFER_SIZE)
});
self.counter += 1;
}
Expand Down

0 comments on commit 4d98819

Please sign in to comment.