Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Option / Result helpers #1481

Merged
merged 7 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 133 additions & 26 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
//! Prelude extension - available in all contexts, defining common types,
//! operations and constants.
use itertools::Itertools;
use lazy_static::lazy_static;

use crate::extension::simple_op::MakeOpDef;
use crate::ops::constant::{CustomCheckFailure, ValueName};
use crate::ops::{ExtensionOp, OpName};
use crate::types::{FuncValueType, SumType, TypeName, TypeRV};
use crate::{
extension::{ExtensionId, TypeDefBound},
ops::constant::CustomConst,
type_row,
types::{
type_param::{TypeArg, TypeParam},
CustomType, PolyFuncTypeRV, Signature, Type, TypeBound,
},
Extension,
use crate::extension::const_fold::fold_out_row;
use crate::extension::simple_op::{
try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
};
use crate::extension::{
ConstFold, ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDefBound,
};
use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName};
use crate::ops::{ExtensionOp, NamedOp, OpName, Value};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{
CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeBound,
TypeName, TypeRV, TypeRow, TypeRowRV,
};
use crate::utils::sorted_consts;
use crate::{type_row, Extension};

use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::{
extension::{
const_fold::fold_out_row,
simple_op::{try_from_name, MakeExtensionOp, MakeRegisteredOp, OpLoadError},
ConstFold, ExtensionSet, OpDef, SignatureError, SignatureFunc,
},
ops::{NamedOp, Value},
types::{PolyFuncType, TypeRow},
utils::sorted_consts,
};

use super::{ExtensionRegistry, SignatureFromArgs};
struct ArrayOpCustom;

Expand Down Expand Up @@ -255,8 +247,92 @@ pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE);
pub const ERROR_TYPE_NAME: TypeName = TypeName::new_inline("error");

/// Return a Sum type with the first variant as the given type and the second an Error.
pub fn sum_with_error(ty: Type) -> SumType {
SumType::new([ty, ERROR_TYPE])
pub fn sum_with_error(ty: impl Into<TypeRowRV>) -> SumType {
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
result_type(ty, ERROR_TYPE)
}

/// An optional type, i.e. a Sum type with the first variant as the given type and the second as an empty tuple.
#[inline]
pub fn option_type(ty: impl Into<TypeRowRV>) -> SumType {
result_type(ty, TypeRow::new())
}

/// A result type, i.e. a two-element Sum type where the first variant
/// represents the "Ok" value, and the second is the "Error" value.
#[inline]
pub fn result_type(ty_ok: impl Into<TypeRowRV>, ty_err: impl Into<TypeRowRV>) -> SumType {
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
SumType::new([ty_ok.into(), ty_err.into()])
}

/// A constant optional value with a given value.
///
/// See [option_type].
pub fn const_some(value: Value) -> Value {
const_some_tuple([value])
}

/// A constant optional value with a row of values.
///
/// For single values, use [const_some].
///
/// See [option_type].
pub fn const_some_tuple(values: impl IntoIterator<Item = Value>) -> Value {
const_ok_tuple(values, TypeRow::new())
}

/// A constant optional value with no value.
///
/// See [option_type].
pub fn const_none(ty: impl Into<TypeRowRV>) -> Value {
const_err_tuple(ty, [])
}

/// A constant result value with an Ok value.
///
/// See [result_type].
pub fn const_ok(value: Value, ty_err: impl Into<TypeRowRV>) -> Value {
const_ok_tuple([value], ty_err)
}

/// A constant result value with a row of Ok values.
///
/// See [result_type].
pub fn const_ok_tuple(
values: impl IntoIterator<Item = Value>,
ty_err: impl Into<TypeRowRV>,
) -> Value {
let values = values.into_iter().collect_vec();
let types: TypeRowRV = values
.iter()
.map(|v| TypeRV::from(v.get_type()))
.collect_vec()
.into();
let typ = result_type(types, ty_err);
Value::sum(0, values, typ).unwrap()
}

/// A constant result value with an Err value.
///
/// See [result_type].
pub fn const_err(ty_ok: impl Into<TypeRowRV>, value: Value) -> Value {
const_err_tuple(ty_ok, [value])
}

/// A constant result value with a row of Err values.
///
/// See [result_type].
pub fn const_err_tuple(
ty_ok: impl Into<TypeRowRV>,
values: impl IntoIterator<Item = Value>,
) -> Value {
let values = values.into_iter().collect_vec();
let types: TypeRowRV = values
.iter()
.map(|v| TypeRV::from(v.get_type()))
.collect_vec()
.into();
let typ = result_type(ty_ok, types);
Value::sum(1, values, typ).unwrap()
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
Expand Down Expand Up @@ -806,6 +882,8 @@ impl MakeRegisteredOp for Lift {

#[cfg(test)]
mod test {
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::{
builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr},
utils::test_quantum_extension::cx_gate,
Expand Down Expand Up @@ -897,6 +975,35 @@ mod test {
b.finish_prelude_hugr_with_outputs(out.outputs()).unwrap();
}

#[test]
fn test_option() {
let typ: Type = option_type(BOOL_T).into();
let const_val1 = const_some(Value::true_val());
let const_val2 = const_none(BOOL_T);

let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap();

let some = b.add_load_value(const_val1);
let none = b.add_load_value(const_val2);

b.finish_prelude_hugr_with_outputs([some, none]).unwrap();
}

#[test]
fn test_result() {
let typ: Type = result_type(BOOL_T, FLOAT64_TYPE).into();
let const_bool = const_ok(Value::true_val(), FLOAT64_TYPE);
let const_float = const_err(BOOL_T, ConstF64::new(0.5).into());

let mut b = DFGBuilder::new(inout_sig(type_row![], vec![typ.clone(), typ])).unwrap();

let bool = b.add_load_value(const_bool);
let float = b.add_load_value(const_float);

b.finish_hugr_with_outputs([bool, float], &FLOAT_OPS_REGISTRY)
.unwrap();
}

#[test]
/// test the prelude error type and panic op.
fn test_error_type() {
Expand Down
37 changes: 37 additions & 0 deletions hugr-py/src/hugr/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from hugr.utils import ser_it

if TYPE_CHECKING:
from collections.abc import Iterable

from hugr import ext


Expand Down Expand Up @@ -303,6 +305,41 @@ def __repr__(self) -> str:
return f"Tuple{tuple(self.variant_rows[0])}"


@dataclass(eq=False)
class Option(Sum):
"""Optional tuple of elements.

Instances of this type correspond to :class:`Sum` with two variants.
The first variant is the tuple of elements, the second is empty.
"""

def __init__(self, *tys: Type):
self.variant_rows = [list(tys), []]

def __repr__(self) -> str:
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
return f"Option({', '.join(map(repr, self.variant_rows[0]))})"


@dataclass(eq=False)
class Result(Sum):
"""Fallible tuple of elements.

Instances of this type correspond to :class:`Sum` with two variants. The
first variant is a tuple of elements representing the successful state, the
second is a tuple of elements representing failure.
"""

def __init__(self, ok: Iterable[Type], err: Iterable[Type]):
self.variant_rows = [list(ok), list(err)]

def __repr__(self) -> str:
ok = self.variant_rows[0]
err = self.variant_rows[1]
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
ok_str = ok[0] if len(ok) == 1 else tuple(ok)
err_str = err[0] if len(err) == 1 else tuple(err)
return f"Result({ok_str}, {err_str})"


@dataclass(frozen=True)
class Variable(Type):
"""A type variable with a given bound, identified by index."""
Expand Down
110 changes: 110 additions & 0 deletions hugr-py/src/hugr/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from hugr.utils import ser_it

if TYPE_CHECKING:
from collections.abc import Iterable

from hugr.hugr import Hugr


Expand Down Expand Up @@ -149,6 +151,114 @@ def __repr__(self) -> str:
return f"Tuple({', '.join(map(repr, self.vals))})"


@dataclass
class Some(Sum):
"""Optional tuple of value, containing a list of values.
Internally a :class:`Sum` with two variant rows.

Example:
>>> some = Some(TRUE, FALSE)
>>> some
Some(TRUE, FALSE)
>>> some.type_()
Option(Bool, Bool)

"""

#: The values of this tuple.
vals: list[Value]

def __init__(self, *vals: Value):
val_list = list(vals)
super().__init__(
tag=0, typ=tys.Option(*(v.type_() for v in val_list)), vals=val_list
)

def __repr__(self) -> str:
return f"Some({', '.join(map(repr, self.vals))})"


@dataclass
class None_(Sum):
"""Optional tuple of value, containing no values.
Internally a :class:`Sum` with two variant rows.

Example:
>>> none = None_(tys.Bool)
>>> none
None(Bool)
>>> none.type_()
Option(Bool)

"""

def __init__(self, *types: tys.Type):
super().__init__(tag=1, typ=tys.Option(*types), vals=[])

def __repr__(self) -> str:
return f"None({', '.join(map(repr, self.typ.variant_rows[0]))})"


@dataclass
class Ok(Sum):
"""Success variant of a :class:`tys.Result` type, containing a list of values.

Internally a :class:`Sum` with two variant rows.

Example:
>>> ok = Ok([TRUE, FALSE], [tys.Bool])
>>> ok
Ok((TRUE, FALSE), Bool)
>>> ok.type_()
Result((Bool, Bool), Bool)
"""

#: The values of this tuple.
vals: list[Value]

def __init__(self, vals: Iterable[Value], err_typ: Iterable[tys.Type]):
val_list = list(vals)
super().__init__(
tag=0, typ=tys.Result([v.type_() for v in val_list], err_typ), vals=val_list
)

def __repr__(self) -> str:
vals_str = self.vals[0] if len(self.vals) == 1 else tuple(self.vals)
err = self.typ.variant_rows[1]
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
err_str = err[0] if len(err) == 1 else tuple(err)
return f"Ok({vals_str}, {err_str})"
ss2165 marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class Err(Sum):
"""Error variant of a :class:`tys.Result` type, containing a list of values.

Internally a :class:`Sum` with two variant rows.

Example:
>>> err = Err([tys.Bool, tys.Bool], [TRUE, FALSE])
>>> err
Err((Bool, Bool), (TRUE, FALSE))
>>> err.type_()
Result((Bool, Bool), (Bool, Bool))
"""

#: The values of this tuple.
vals: list[Value]

def __init__(self, ok_typ: Iterable[tys.Type], vals: Iterable[Value]):
val_list = list(vals)
super().__init__(
tag=1, typ=tys.Result(ok_typ, [v.type_() for v in val_list]), vals=val_list
)

def __repr__(self) -> str:
ok = self.typ.variant_rows[1]
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
ok_str = ok[0] if len(ok) == 1 else tuple(ok)
vals_str = self.vals[0] if len(self.vals) == 1 else tuple(self.vals)
return f"Err({ok_str}, {vals_str})"


@dataclass
class Function(Value):
"""Higher order function value, defined by a :class:`Hugr <hugr.hugr.HUGR>`."""
Expand Down
Loading