Skip to content

Commit

Permalink
fix serialize impl of SIMD types (#188)
Browse files Browse the repository at this point in the history
* fix serialize impl of SIMD types

change serialization of SIMD types to use serde::Serializer::serialize_tuple

The Serialize implementation of all types used Serializer::serialize_seq
whereas the deserialize implementation delegated to that of the
corresponding [T; LEN] type, which is implemented by serde using
serialize_tuple. Formats like bincode will emit the lenght of the
sequence but not of a tuple, leading to wrong deserialization.

* Add test cases for serde impl
  • Loading branch information
robinhundt authored Jan 8, 2025
1 parent 8fa775a commit fb55237
Show file tree
Hide file tree
Showing 21 changed files with 197 additions and 4 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ serde = ["dep:serde"]
safe_arch = { version = "0.7", features = ["bytemuck"] }
serde = { version = "1", default-features = false, optional = true }
bytemuck = "1"

[dev-dependencies]
bincode = { version = "1.3.3" }
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use safe_arch::*;
use bytemuck::*;

#[cfg(feature = "serde")]
use serde::{ser::SerializeSeq, Deserialize, Serialize};
use serde::{ser::SerializeTuple, Deserialize, Serialize};

#[macro_use]
mod macros;
Expand Down Expand Up @@ -926,7 +926,7 @@ bulk_impl_const_rhs_op!((CmpLe, cmp_le) => [(f64x4, f64), (f64x2, f64), (f32x4,f
bulk_impl_const_rhs_op!((CmpGe, cmp_ge) => [(f64x4, f64), (f64x2, f64), (f32x4,f32), (f32x8,f32),]);

macro_rules! impl_serde {
($i:ident, $t:ty) => {
($i:ident, [$t:ty; $len:expr]) => {
#[cfg(feature = "serde")]
impl Serialize for $i {
#[inline]
Expand All @@ -935,7 +935,7 @@ macro_rules! impl_serde {
S: serde::Serializer,
{
let array = self.as_array_ref();
let mut seq = serializer.serialize_seq(Some(array.len()))?;
let mut seq = serializer.serialize_tuple($len)?;
for e in array {
seq.serialize_element(e)?;
}
Expand All @@ -950,7 +950,7 @@ macro_rules! impl_serde {
where
D: serde::Deserializer<'de>,
{
Ok(<$t>::deserialize(deserializer)?.into())
Ok(<[$t; $len]>::deserialize(deserializer)?.into())
}
}
};
Expand Down
10 changes: 10 additions & 0 deletions tests/all_tests/t_f32x4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,3 +839,13 @@ fn impl_f32x4_from_i32x4() {
let f = f32x4::from([1.0, 2.0, 3.0, 4.0]);
assert_eq!(f32x4::from_i32x4(i), f)
}

#[cfg(feature = "serde")]
#[test]
fn impl_f32x4_ser_de_roundtrip() {
let serialized =
bincode::serialize(&f32x4::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(f32x4::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_f32x8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -956,3 +956,13 @@ fn impl_f32x8_from_i32x8() {
let f = f32x8::from([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
assert_eq!(f32x8::from_i32x8(i), f)
}

#[cfg(feature = "serde")]
#[test]
fn impl_f32x8_ser_de_roundtrip() {
let serialized =
bincode::serialize(&f32x8::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(f32x8::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_f64x2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,13 @@ fn impl_f64x2_from_i32x4() {
let f = f64x2::from([1.0, 2.0]);
assert_eq!(f64x2::from_i32x4_lower2(i), f)
}

#[cfg(feature = "serde")]
#[test]
fn impl_f64x2_ser_de_roundtrip() {
let serialized =
bincode::serialize(&f64x2::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(f64x2::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_f64x4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,3 +735,13 @@ fn impl_f64x4_from_i32x4() {
assert_eq!(f64x4::from(i), f);
assert_eq!(f64x4::from_i32x4(i), f);
}

#[cfg(feature = "serde")]
#[test]
fn impl_f64x4_ser_de_roundtrip() {
let serialized =
bincode::serialize(&f64x4::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(f64x4::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_i16x16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,13 @@ fn impl_i16x16_reduce_max() {
assert_eq!(p.reduce_min(), i16::MIN);
}
}

#[cfg(feature = "serde")]
#[test]
fn impl_i16x16_ser_de_roundtrip() {
let serialized =
bincode::serialize(&i16x16::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(i16x16::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_i16x8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,13 @@ fn impl_i16x8_mul_widen() {
|a, b| i32::from(a) * i32::from(b),
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_i16x8_ser_de_roundtrip() {
let serialized =
bincode::serialize(&i16x8::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(i16x8::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_i32x4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,13 @@ fn impl_i32x4_mul_widen() {
|a, b| a as i64 * b as i64,
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_i32x4_ser_de_roundtrip() {
let serialized =
bincode::serialize(&i32x4::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(i32x4::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_i32x8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,13 @@ fn impl_i32x8_shl_each() {
|a, b| a.wrapping_shl(b as u32),
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_i32x8_ser_de_roundtrip() {
let serialized =
bincode::serialize(&i32x8::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(i32x8::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_i64x2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,13 @@ fn test_i64x2_move_mask() {
|acc, a, idx| acc | if a < 0 { 1 << idx } else { 0 },
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_i64x2_ser_de_roundtrip() {
let serialized =
bincode::serialize(&i64x2::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(i64x2::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_i64x4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,13 @@ fn test_i32x4_none() {
|acc, a, _idx| acc & !(a < 0),
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_i64x4_ser_de_roundtrip() {
let serialized =
bincode::serialize(&i64x4::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(i64x4::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_i8x16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,13 @@ fn test_i8x16_swizzle_relaxed() {
let actual = a.swizzle_relaxed(b);
assert_eq!(expected, actual);
}

#[cfg(feature = "serde")]
#[test]
fn impl_i8x16_ser_de_roundtrip() {
let serialized =
bincode::serialize(&i8x16::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(i8x16::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_i8x32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,13 @@ fn test_i8x32_swizzle_half() {
let actual = a.swizzle_half(b);
assert_eq!(expected, actual);
}

#[cfg(feature = "serde")]
#[test]
fn impl_i8x32_ser_de_roundtrip() {
let serialized =
bincode::serialize(&i8x32::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(i8x32::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_u16x16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,13 @@ fn impl_mul_for_u16x16() {
|a, b| a.wrapping_mul(b),
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_u16x16_ser_de_roundtrip() {
let serialized =
bincode::serialize(&u16x16::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(u16x16::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_u16x8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,13 @@ fn impl_u16x8_mul_widen() {
|a, b| u32::from(a) * u32::from(b),
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_u16x8_ser_de_roundtrip() {
let serialized =
bincode::serialize(&u16x8::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(u16x8::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_u32x4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,13 @@ fn impl_u32x4_mul_keep_high() {
|a, b| ((u64::from(a) * u64::from(b)) >> 32) as u32,
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_u32x4_ser_de_roundtrip() {
let serialized =
bincode::serialize(&u32x4::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(u32x4::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_u32x8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,13 @@ fn impl_u32x8_mul_keep_high() {
|a, b| ((u64::from(a) * u64::from(b)) >> 32) as u32,
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_u32x8_ser_de_roundtrip() {
let serialized =
bincode::serialize(&u32x8::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(u32x8::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_u64x2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,13 @@ fn impl_u64x2_cmp_lt() {
|a, b| if a < b { u64::MAX } else { 0 },
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_u64x2_ser_de_roundtrip() {
let serialized =
bincode::serialize(&u64x2::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(u64x2::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_u64x4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,13 @@ fn impl_u64x4_cmp_lt() {
|a, b| if a < b { u64::MAX } else { 0 },
);
}

#[cfg(feature = "serde")]
#[test]
fn impl_u64x4_ser_de_roundtrip() {
let serialized =
bincode::serialize(&u64x4::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(u64x4::ZERO, deserialized);
}
10 changes: 10 additions & 0 deletions tests/all_tests/t_u8x16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,13 @@ fn impl_narrow_i16x8() {
let c: [u8; 16] = u8x16::narrow_i16x8(a, b).into();
assert_eq!(c, [0, 2, 0, 4, 0, 6, 0, 8, 9, 10, 11, 12, 13, 0, 15, 0]);
}

#[cfg(feature = "serde")]
#[test]
fn impl_u8x16_ser_de_roundtrip() {
let serialized =
bincode::serialize(&u8x16::ZERO).expect("serialization failed");
let deserialized =
bincode::deserialize(&serialized).expect("deserializaion failed");
assert_eq!(u8x16::ZERO, deserialized);
}

0 comments on commit fb55237

Please sign in to comment.