Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Option / Result helpers #1481

Merged
merged 7 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 143 additions & 26 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
//! Prelude extension - available in all contexts, defining common types,
//! operations and constants.
use itertools::Itertools;
use lazy_static::lazy_static;

use crate::extension::simple_op::MakeOpDef;
use crate::ops::constant::{CustomCheckFailure, ValueName};
use crate::ops::{ExtensionOp, OpName};
use crate::types::{FuncValueType, SumType, TypeName, TypeRV};
use crate::{
extension::{ExtensionId, TypeDefBound},
ops::constant::CustomConst,
type_row,
types::{
type_param::{TypeArg, TypeParam},
CustomType, PolyFuncTypeRV, Signature, Type, TypeBound,
},
Extension,
use crate::extension::const_fold::fold_out_row;
use crate::extension::simple_op::{
try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
};
use crate::extension::{
ConstFold, ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDefBound,
};
use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName};
use crate::ops::{ExtensionOp, NamedOp, OpName, Value};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{
CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound,
TypeName, TypeRV, TypeRow, TypeRowRV,
};
use crate::utils::sorted_consts;
use crate::{type_row, Extension};

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::{
extension::{
const_fold::fold_out_row,
simple_op::{try_from_name, MakeExtensionOp, MakeRegisteredOp, OpLoadError},
ConstFold, ExtensionSet, OpDef, SignatureError, SignatureFunc,
},
ops::{NamedOp, Value},
types::{PolyFuncType, TypeRow},
utils::sorted_consts,
};

use super::{ExtensionRegistry, SignatureFromArgs};
struct ArrayOpCustom;

Expand Down Expand Up @@ -255,8 +247,102 @@ pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE);
pub const ERROR_TYPE_NAME: TypeName = TypeName::new_inline("error");

/// Return a Sum type with the first variant as the given type and the second an Error.
pub fn sum_with_error(ty: Type) -> SumType {
SumType::new([ty, ERROR_TYPE])
pub fn sum_with_error(ty: impl Into<TypeRowRV>) -> SumType {
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
either_type(ty, ERROR_TYPE)
}

/// An optional type, i.e. a Sum type with the first variant as the given type and the second as an empty tuple.
#[inline]
pub fn option_type(ty: impl Into<TypeRowRV>) -> SumType {
either_type(ty, TypeRow::new())
}

/// An "either" type, i.e. a Sum type with a "left" and a "right" variant.
///
/// When used as a fallible value, the "left" variant represents a successful computation,
/// and the "right" variant represents a failure.
#[inline]
pub fn either_type(ty_ok: impl Into<TypeRowRV>, ty_err: impl Into<TypeRowRV>) -> SumType {
SumType::new([ty_ok.into(), ty_err.into()])
}

/// A constant optional value with a given value.
///
/// See [option_type].
pub fn const_some(value: Value) -> Value {
const_some_tuple([value])
}

/// A constant optional value with a row of values.
///
/// For single values, use [const_some].
///
/// See [option_type].
pub fn const_some_tuple(values: impl IntoIterator<Item = Value>) -> Value {
const_left_tuple(values, TypeRow::new())
}

/// A constant optional value with no value.
///
/// See [option_type].
pub fn const_none(ty: impl Into<TypeRowRV>) -> Value {
const_right_tuple(ty, [])
}

/// A constant Either value with a left variant.
///
/// In fallible computations, this represents a successful result.
///
/// See [either_type].
pub fn const_left(value: Value, ty_right: impl Into<TypeRowRV>) -> Value {
const_left_tuple([value], ty_right)
}

/// A constant Either value with a row of left values.
///
/// In fallible computations, this represents a successful result.
///
/// See [either_type].
pub fn const_left_tuple(
values: impl IntoIterator<Item = Value>,
ty_right: impl Into<TypeRowRV>,
) -> Value {
let values = values.into_iter().collect_vec();
let types: TypeRowRV = values
.iter()
.map(|v| TypeRV::from(v.get_type()))
.collect_vec()
.into();
let typ = either_type(types, ty_right);
Value::sum(0, values, typ).unwrap()
}

/// A constant Either value with a right variant.
///
/// In fallible computations, this represents a failure.
///
/// See [either_type].
pub fn const_right(ty_left: impl Into<TypeRowRV>, value: Value) -> Value {
const_right_tuple(ty_left, [value])
}

/// A constant Either value with a row of right values.
///
/// In fallible computations, this represents a failure.
///
/// See [either_type].
pub fn const_right_tuple(
ty_left: impl Into<TypeRowRV>,
values: impl IntoIterator<Item = Value>,
) -> Value {
let values = values.into_iter().collect_vec();
let types: TypeRowRV = values
.iter()
.map(|v| TypeRV::from(v.get_type()))
.collect_vec()
.into();
let typ = either_type(ty_left, types);
Value::sum(1, values, typ).unwrap()
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -806,6 +892,8 @@ impl MakeRegisteredOp for Lift {

#[cfg(test)]
mod test {
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::{
builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr},
utils::test_quantum_extension::cx_gate,
Expand Down Expand Up @@ -897,6 +985,35 @@ mod test {
b.finish_prelude_hugr_with_outputs(out.outputs()).unwrap();
}

#[test]
fn test_option() {
let typ: Type = option_type(BOOL_T).into();
let const_val1 = const_some(Value::true_val());
let const_val2 = const_none(BOOL_T);

let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap();

let some = b.add_load_value(const_val1);
let none = b.add_load_value(const_val2);

b.finish_prelude_hugr_with_outputs([some, none]).unwrap();
}

#[test]
fn test_result() {
let typ: Type = either_type(BOOL_T, FLOAT64_TYPE).into();
let const_bool = const_left(Value::true_val(), FLOAT64_TYPE);
let const_float = const_right(BOOL_T, ConstF64::new(0.5).into());

let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap();

let bool = b.add_load_value(const_bool);
let float = b.add_load_value(const_float);

b.finish_hugr_with_outputs([bool, float], &FLOAT_OPS_REGISTRY)
.unwrap();
}

#[test]
/// test the prelude error type and panic op.
fn test_error_type() {
Expand Down
48 changes: 48 additions & 0 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from hugr.utils import ser_it

if TYPE_CHECKING:
from collections.abc import Iterable

from hugr import ext


Expand Down Expand Up @@ -303,6 +305,52 @@ def __repr__(self) -> str:
return f"Tuple{tuple(self.variant_rows[0])}"


@dataclass(eq=False)
class Option(Sum):
"""Optional tuple of elements.

Instances of this type correspond to :class:`Sum` with two variants.
The first variant is the tuple of elements, the second is empty.
"""

def __init__(self, *tys: Type):
self.variant_rows = [list(tys), []]

def __repr__(self) -> str:
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
return f"Option({', '.join(map(repr, self.variant_rows[0]))})"


@dataclass(eq=False)
class Either(Sum):
"""Two-variant tuple of elements.

Instances of this type correspond to :class:`Sum` with a Left and a Right variant.

In fallible contexts, the Left variant is used to represent success, and the
Right variant is used to represent failure.

Example:
>>> either = Either([Bool, Bool], [Bool])
>>> either
Either(left=[Bool, Bool], right=[Bool])
>>> str(either)
'Either((Bool, Bool), Bool)'
"""

def __init__(self, left: Iterable[Type], right: Iterable[Type]):
self.variant_rows = [list(left), list(right)]

def __repr__(self) -> str: # pragma: no cover
left, right = self.variant_rows
return f"Either(left={left}, right={right})"

def __str__(self) -> str:
left, right = self.variant_rows
left_str = left[0] if len(left) == 1 else tuple(left)
right_str = right[0] if len(right) == 1 else tuple(right)
return f"Either({left_str}, {right_str})"


@dataclass(frozen=True)
class Variable(Type):
"""A type variable with a given bound, identified by index."""
Expand Down
Loading
Loading