Skip to content

Commit

Permalink
fix: Use internal tag for SumType enum serialisation (#462)
Browse files Browse the repository at this point in the history
This makes it easier to handle the serialised format with pydantic.

BREAKING CHANGE: Turn `SumType.General` and `SumType.Simple` enum
variants into struct variants
  • Loading branch information
mark-koch authored Aug 29, 2023
1 parent ae81e42 commit 550df7f
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,31 +94,36 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator<Item = TypeBound>) -> Ty
}

#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display, Serialize, Deserialize)]
#[serde(tag = "s")]
/// Representation of a Sum type.
/// Either store the types of the variants, or in the special (but common) case
/// of a "simple predicate" (sum over empty tuples), store only the size of the predicate.
enum SumType {
#[display(fmt = "SimplePredicate({})", "_0")]
Simple(u8),
General(TypeRow),
#[display(fmt = "SimplePredicate({})", "size")]
Simple {
size: u8,
},
General {
row: TypeRow,
},
}

impl SumType {
fn new(types: impl Into<TypeRow>) -> Self {
let row: TypeRow = types.into();

let len = row.len();
let len: usize = row.len();
if len <= (u8::MAX as usize) && row.iter().all(|t| *t == Type::UNIT) {
Self::Simple(len as u8)
Self::Simple { size: len as u8 }
} else {
Self::General(row)
Self::General { row }
}
}

fn get_variant(&self, tag: usize) -> Option<&Type> {
match self {
SumType::Simple(size) if tag < (*size as usize) => Some(Type::UNIT_REF),
SumType::General(row) => row.get(tag),
SumType::Simple { size } if tag < (*size as usize) => Some(Type::UNIT_REF),
SumType::General { row } => row.get(tag),
_ => None,
}
}
Expand All @@ -127,8 +132,8 @@ impl SumType {
impl From<SumType> for Type {
fn from(sum: SumType) -> Type {
match sum {
SumType::Simple(size) => Type::new_simple_predicate(size),
SumType::General(types) => Type::new_sum(types),
SumType::Simple { size } => Type::new_simple_predicate(size),
SumType::General { row } => Type::new_sum(row),
}
}
}
Expand All @@ -147,9 +152,9 @@ impl TypeEnum {
fn least_upper_bound(&self) -> TypeBound {
match self {
TypeEnum::Prim(p) => p.bound(),
TypeEnum::Sum(SumType::Simple(_)) => TypeBound::Eq,
TypeEnum::Sum(SumType::General(ts)) => {
least_upper_bound(ts.iter().map(Type::least_upper_bound))
TypeEnum::Sum(SumType::Simple { size: _ }) => TypeBound::Eq,
TypeEnum::Sum(SumType::General { row }) => {
least_upper_bound(row.iter().map(Type::least_upper_bound))
}
TypeEnum::Tuple(ts) => least_upper_bound(ts.iter().map(Type::least_upper_bound)),
}
Expand Down Expand Up @@ -237,7 +242,7 @@ impl Type {
/// New simple predicate with empty Tuple variants
pub const fn new_simple_predicate(size: u8) -> Self {
// should be the only way to avoid going through SumType::new
Self(TypeEnum::Sum(SumType::Simple(size)), TypeBound::Eq)
Self(TypeEnum::Sum(SumType::Simple { size }), TypeBound::Eq)
}

/// Report the least upper TypeBound, if there is one.
Expand Down Expand Up @@ -305,7 +310,7 @@ pub(crate) mod test {

assert_eq!(pred1, pred2);

let pred_direct = SumType::Simple(2);
let pred_direct = SumType::Simple { size: 2 };
assert_eq!(pred1, pred_direct.into())
}
}

0 comments on commit 550df7f

Please sign in to comment.