diff --git a/Cargo.lock b/Cargo.lock index 7b96883a81f6..52dc254280ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2369,6 +2369,7 @@ dependencies = [ "ahash 0.6.3", "criterion", "cs_serde_bytes", + "cs_serde_cbor", "forest_encoding", "rand 0.7.3", "rand_xorshift", diff --git a/utils/bitfield/Cargo.toml b/utils/bitfield/Cargo.toml index 9c0937680818..d86674b111b6 100644 --- a/utils/bitfield/Cargo.toml +++ b/utils/bitfield/Cargo.toml @@ -19,6 +19,10 @@ rand_xorshift = "0.2.0" rand = "0.7.3" criterion = "0.3" serde_json = "1.0" +# TODO remove fork in future (allowing non utf8 strings to be cbor deserialized) +serde_cbor = { package = "cs_serde_cbor", version = "0.12", features = [ + "tags" +] } [features] json = [] diff --git a/utils/bitfield/src/lib.rs b/utils/bitfield/src/lib.rs index e0074d81c02c..0b1922cbc0ce 100644 --- a/utils/bitfield/src/lib.rs +++ b/utils/bitfield/src/lib.rs @@ -16,6 +16,12 @@ use std::{ type Result = std::result::Result; +// MaxEncodedSize is the maximum encoded size of a bitfield. When expanded into +// a slice of runs, a bitfield of this size should not exceed 2MiB of memory. +// +// This bitfield can fit at least 3072 sparse elements. +const MAX_ENCODED_SIZE: usize = 32 << 10; + /// A bit field with buffered insertion/removal that serializes to/from RLE+. Similar to /// `HashSet`, but more memory-efficient when long runs of 1s and 0s are present. #[derive(Debug, Default, Clone)] diff --git a/utils/bitfield/src/rleplus/mod.rs b/utils/bitfield/src/rleplus/mod.rs index 0e67957b88e3..c5d2db85bfe4 100644 --- a/utils/bitfield/src/rleplus/mod.rs +++ b/utils/bitfield/src/rleplus/mod.rs @@ -67,15 +67,11 @@ mod writer; pub use reader::BitReader; pub use writer::BitWriter; -use super::{BitField, Result}; +use super::{BitField, Result, MAX_ENCODED_SIZE}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::borrow::Cow; -// MaxEncodedSize is the maximum encoded size of a bitfield. When expanded into -// a slice of runs, a bitfield of this size should not exceed 2MiB of memory. -// -// This bitfield can fit at least 3072 sparse elements. -const MAX_ENCODED_SIZE: usize = 32 << 10; +pub const VERSION: u8 = 0; impl Serialize for BitField { fn serialize(&self, serializer: S) -> std::result::Result @@ -99,6 +95,12 @@ impl<'de> Deserialize<'de> for BitField { D: Deserializer<'de>, { let bytes: Cow<'de, [u8]> = serde_bytes::deserialize(deserializer)?; + if bytes.len() > MAX_ENCODED_SIZE { + return Err(serde::de::Error::custom(format!( + "decoded bitfield was too large {}", + bytes.len() + ))); + } Self::from_bytes(&bytes).map_err(serde::de::Error::custom) } } @@ -115,7 +117,7 @@ impl BitField { let mut reader = BitReader::new(bytes); let version = reader.read(2); - if version != 0 { + if version != VERSION { return Err("incorrect version"); } diff --git a/utils/bitfield/src/unvalidated.rs b/utils/bitfield/src/unvalidated.rs index 3579066a12d2..5358e1737ee4 100644 --- a/utils/bitfield/src/unvalidated.rs +++ b/utils/bitfield/src/unvalidated.rs @@ -1,7 +1,8 @@ // Copyright 2019-2022 ChainSafe Systems // SPDX-License-Identifier: Apache-2.0, MIT -use super::{BitField, Result}; +use super::rleplus::VERSION; +use super::{BitField, Result, MAX_ENCODED_SIZE}; use encoding::serde_bytes; use serde::{Deserialize, Deserializer, Serialize}; @@ -62,6 +63,15 @@ impl<'de> Deserialize<'de> for UnvalidatedBitField { D: Deserializer<'de>, { let bytes: Vec = serde_bytes::deserialize(deserializer)?; + if bytes.len() > MAX_ENCODED_SIZE { + return Err(serde::de::Error::custom(format!( + "decoded bitfield was too large {}", + bytes.len() + ))); + } + if !bytes.is_empty() && bytes[0] & 3 != VERSION { + return Err(serde::de::Error::custom("invalid RLE+ version".to_string())); + } Ok(Self::Unvalidated(bytes)) } } diff --git a/utils/bitfield/tests/bitfield_tests.rs b/utils/bitfield/tests/bitfield_tests.rs index 50b0e88db929..918438ca7c68 100644 --- a/utils/bitfield/tests/bitfield_tests.rs +++ b/utils/bitfield/tests/bitfield_tests.rs @@ -2,9 +2,10 @@ // SPDX-License-Identifier: Apache-2.0, MIT use ahash::AHashSet; -use forest_bitfield::{bitfield, BitField}; +use forest_bitfield::{bitfield, BitField, UnvalidatedBitField}; use rand::{Rng, SeedableRng}; use rand_xorshift::XorShiftRng; +use serde_cbor::ser::Serializer; use std::iter::FromIterator; fn random_indices(range: usize, seed: u64) -> Vec { @@ -223,3 +224,59 @@ fn padding() { let deserialized: BitField = encoding::from_slice(&cbor).unwrap(); assert_eq!(deserialized, bf); } + +#[test] +fn bitfield_deserialize() { + // Set alternating bits for worst-case size performance + let mut bf = BitField::new(); + let mut i = 0; + while i < 262_142 { + bf.set(i); + i += 2; + } + let bytes = bf.to_bytes(); + let mut cbor = Vec::new(); + serde_bytes::serialize(&bytes, &mut Serializer::new(&mut cbor)).unwrap(); + let res: Result = encoding::from_slice(&cbor); + assert!(res.is_ok()); + + // Set alternating bits for worst-case size performance + let mut bf = BitField::new(); + let mut i = 0; + while i < 262_143 { + bf.set(i); + i += 2; + } + let bytes = bf.to_bytes(); + let mut cbor = Vec::new(); + serde_bytes::serialize(&bytes, &mut Serializer::new(&mut cbor)).unwrap(); + let res: Result = encoding::from_slice(&cbor); + assert!(res.is_err()); +} + +#[test] +fn unvalidated_deserialize() { + // Set alternating bits for worst-case size performance + let mut bf = BitField::new(); + let mut i = 0; + while i < 262_143 { + bf.set(i); + i += 2; + } + let bytes = bf.to_bytes(); + let mut cbor = Vec::new(); + serde_bytes::serialize(&bytes, &mut Serializer::new(&mut cbor)).unwrap(); + let res: Result = encoding::from_slice(&cbor); + assert!(res.is_err()); +} + +#[test] +fn unvalidated_deserialize_version() { + let bf = bitfield![1, 1, 1, 1, 1, 1, 1, 1]; + let mut bytes = bf.to_bytes(); + bytes[0] |= 0x1; // flip bit to corrupt version + let mut cbor = Vec::new(); + serde_bytes::serialize(&bytes, &mut Serializer::new(&mut cbor)).unwrap(); + let res: Result = encoding::from_slice(&cbor); + assert!(res.is_err()); +}