From 50726f3e1b2f7b4d50f6755ed30a1ae3108583bf Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Tue, 7 Nov 2023 18:41:14 +0100 Subject: [PATCH] fix(interpreter): Stack `push_slice` fix and dup with pointers (#837) * fix: avoid UB in `Stack::dup` * fix: correctly write full words in `Stack::push_slice` --- crates/interpreter/src/interpreter/stack.rs | 103 +++++++++++++++++--- 1 file changed, 91 insertions(+), 12 deletions(-) diff --git a/crates/interpreter/src/interpreter/stack.rs b/crates/interpreter/src/interpreter/stack.rs index fd5b58d520..f64b8a0f9f 100644 --- a/crates/interpreter/src/interpreter/stack.rs +++ b/crates/interpreter/src/interpreter/stack.rs @@ -204,7 +204,8 @@ impl Stack { } else { // SAFETY: check for out of bounds is done above and it makes this safe to do. unsafe { - *self.data.get_unchecked_mut(len) = *self.data.get_unchecked(len - N); + let data = self.data.as_mut_ptr(); + core::ptr::copy_nonoverlapping(data.add(len - N), data.add(len), 1); self.data.set_len(len + 1); } Ok(()) @@ -240,33 +241,49 @@ impl Stack { // SAFETY: length checked above. unsafe { let dst = self.data.as_mut_ptr().add(self.data.len()).cast::(); + self.data.set_len(new_len); + let mut i = 0; // write full words - let limbs = slice.rchunks_exact(8); - let rem = limbs.remainder(); - for limb in limbs { - *dst.add(i) = u64::from_be_bytes(limb.try_into().unwrap()); + let words = slice.chunks_exact(32); + let partial_last_word = words.remainder(); + for word in words { + // Note: we unroll `U256::from_be_bytes` here to write directly into the buffer, + // instead of creating a 32 byte array on the stack and then copying it over. + for l in word.rchunks_exact(8) { + dst.add(i).write(u64::from_be_bytes(l.try_into().unwrap())); + i += 1; + } + } + + if partial_last_word.is_empty() { + return Ok(()); + } + + // write limbs of partial last word + let limbs = partial_last_word.rchunks_exact(8); + let partial_last_limb = limbs.remainder(); + for l in limbs { + dst.add(i).write(u64::from_be_bytes(l.try_into().unwrap())); i += 1; } - // write remainder by padding with zeros - if !rem.is_empty() { + // write partial last limb by padding with zeros + if !partial_last_limb.is_empty() { let mut tmp = [0u8; 8]; - tmp[8 - rem.len()..].copy_from_slice(rem); - *dst.add(i) = u64::from_be_bytes(tmp); + tmp[8 - partial_last_limb.len()..].copy_from_slice(partial_last_limb); + dst.add(i).write(u64::from_be_bytes(tmp)); i += 1; } - debug_assert_eq!((i + 3) / 4, n_words, "wrote beyond end of stack"); + debug_assert_eq!((i + 3) / 4, n_words, "wrote too much"); // zero out upper bytes of last word let m = i % 4; // 32 / 8 if m != 0 { dst.add(i).write_bytes(0, 4 - m); } - - self.data.set_len(new_len); } Ok(()) @@ -286,3 +303,65 @@ impl Stack { } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn run(f: impl FnOnce(&mut Stack)) { + let mut stack = Stack::new(); + // fill capacity with non-zero values + unsafe { + stack.data.set_len(STACK_LIMIT); + stack.data.fill(U256::MAX); + stack.data.set_len(0); + } + f(&mut stack); + } + + #[test] + fn push_slices() { + // no-op + run(|stack| { + stack.push_slice(b"").unwrap(); + assert_eq!(stack.data, []); + }); + + // one word + run(|stack| { + stack.push_slice(&[42]).unwrap(); + assert_eq!(stack.data, [U256::from(42)]); + }); + + let n = 0x1111_2222_3333_4444_5555_6666_7777_8888_u128; + run(|stack| { + stack.push_slice(&n.to_be_bytes()).unwrap(); + assert_eq!(stack.data, [U256::from(n)]); + }); + + // more than one word + run(|stack| { + let b = [U256::from(n).to_be_bytes::<32>(); 2].concat(); + stack.push_slice(&b).unwrap(); + assert_eq!(stack.data, [U256::from(n); 2]); + }); + + run(|stack| { + let b = [&[0; 32][..], &[42u8]].concat(); + stack.push_slice(&b).unwrap(); + assert_eq!(stack.data, [U256::ZERO, U256::from(42)]); + }); + + run(|stack| { + let b = [&[0; 32][..], &n.to_be_bytes()].concat(); + stack.push_slice(&b).unwrap(); + assert_eq!(stack.data, [U256::ZERO, U256::from(n)]); + }); + + run(|stack| { + let b = [&[0; 64][..], &n.to_be_bytes()].concat(); + stack.push_slice(&b).unwrap(); + assert_eq!(stack.data, [U256::ZERO, U256::ZERO, U256::from(n)]); + }); + } +}