Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use type schemes in extension definitions wherever possible #678

Merged
merged 15 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,23 @@ impl Extension {
SignatureFunc::TypeScheme(type_scheme),
)
}

/// Create an OpDef with a signature (inputs+outputs) read from e.g.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

/// declarative YAML; and no "misc" or "lowering functions" defined.
pub fn add_op_type_scheme_simple(
&mut self,
name: SmolStr,
description: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, how about impl Into<String> ? Think that would let you avoid the to_owneds

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the other methods take String so I think I'd rather leave that for a
separate issue

type_scheme: PolyFuncType,
) -> Result<&OpDef, ExtensionBuildError> {
self.add_op(
name,
description,
Default::default(),
vec![],
SignatureFunc::TypeScheme(type_scheme),
)
}
}

#[cfg(test)]
Expand Down
5 changes: 5 additions & 0 deletions src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple(
/// The string name of the error type.
pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error");

/// Return a Sum type with the first variant as the given type and the second an Error.
pub fn sum_with_error(ty: Type) -> Type {
Type::new_sum(vec![ty, ERROR_TYPE])
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// Structure for holding constant usize values.
pub struct ConstUsize(u64);
Expand Down
10 changes: 5 additions & 5 deletions src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::macros::const_extension_ids;
use crate::ops::dataflow::IOTrait;
use crate::ops::{self, LeafOp, OpType};
use crate::std_extensions::logic;
use crate::std_extensions::logic::test::{and_op, not_op};
use crate::std_extensions::logic::test::{and_op, not_op, or_op};
use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
use crate::types::{CustomType, FunctionType, Type, TypeBound, TypeRow};
use crate::{type_row, Direction, IncomingPort, Node};
Expand Down Expand Up @@ -612,12 +612,12 @@ fn dfg_with_cycles() -> Result<(), HugrError> {
type_row![BOOL_T],
));
let [input, output] = h.get_io(h.root()).unwrap();
let and = h.add_node_with_parent(h.root(), and_op())?;
let or = h.add_node_with_parent(h.root(), or_op())?;
let not1 = h.add_node_with_parent(h.root(), not_op())?;
let not2 = h.add_node_with_parent(h.root(), not_op())?;
h.connect(input, 0, and, 0)?;
h.connect(and, 0, not1, 0)?;
h.connect(not1, 0, and, 1)?;
h.connect(input, 0, or, 0)?;
h.connect(or, 0, not1, 0)?;
h.connect(not1, 0, or, 1)?;
h.connect(input, 1, not2, 0)?;
h.connect(not2, 0, output, 0)?;
// The graph contains a cycle:
Expand Down
61 changes: 30 additions & 31 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,34 @@
//! Conversions between integer and floating-point values.

use crate::{
extension::{ExtensionId, ExtensionSet, SignatureError},
extension::{
prelude::sum_with_error, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError,
PRELUDE,
},
type_row,
types::{type_param::TypeArg, FunctionType, Type},
utils::collect_array,
types::{FunctionType, PolyFuncType},
Extension,
};

use super::int_types::int_type;
use super::int_types::int_type_var;
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};

/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");

fn ftoi_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok(FunctionType::new(
fn ftoi_sig(temp_reg: &ExtensionRegistry) -> Result<PolyFuncType, SignatureError> {
let body = FunctionType::new(
type_row![FLOAT64_TYPE],
vec![Type::new_sum(vec![
int_type(arg.clone()),
crate::extension::prelude::ERROR_TYPE,
])],
))
vec![sum_with_error(int_type_var(0))],
);

PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg)
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
}

fn itof_sig(arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
let [arg] = collect_array(arg_values);
Ok(FunctionType::new(
vec![int_type(arg.clone())],
type_row![FLOAT64_TYPE],
))
fn itof_sig(temp_reg: &ExtensionRegistry) -> Result<PolyFuncType, SignatureError> {
let body = FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]);

PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg)
}

/// Extension for basic arithmetic operations.
Expand All @@ -42,37 +40,38 @@ pub fn extension() -> Extension {
super::float_types::EXTENSION_ID,
]),
);

let temp_reg: ExtensionRegistry = [
super::int_types::EXTENSION.to_owned(),
super::float_types::extension(),
PRELUDE.to_owned(),
]
.into();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"trunc_u".into(),
"float to unsigned int".to_owned(),
vec![LOG_WIDTH_TYPE_PARAM],
ftoi_sig,
ftoi_sig(&temp_reg).unwrap(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"trunc_s".into(),
"float to signed int".to_owned(),
vec![LOG_WIDTH_TYPE_PARAM],
ftoi_sig,
ftoi_sig(&temp_reg).unwrap(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"convert_u".into(),
"unsigned int to float".to_owned(),
vec![LOG_WIDTH_TYPE_PARAM],
itof_sig,
itof_sig(&temp_reg).unwrap(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"convert_s".into(),
"signed int to float".to_owned(),
vec![LOG_WIDTH_TYPE_PARAM],
itof_sig,
itof_sig(&temp_reg).unwrap(),
)
.unwrap();

Expand Down
79 changes: 31 additions & 48 deletions src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Basic floating-point operations.

use crate::{
extension::{ExtensionId, ExtensionSet, SignatureError},
extension::{ExtensionId, ExtensionSet},
type_row,
types::{type_param::TypeArg, FunctionType},
types::{FunctionType, PolyFuncType},
Extension,
};

Expand All @@ -12,106 +12,89 @@ use super::float_types::FLOAT64_TYPE;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");

fn fcmp_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE; 2],
type_row![crate::extension::prelude::BOOL_T],
))
}

fn fbinop_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE; 2],
type_row![FLOAT64_TYPE],
))
}

fn funop_sig(_arg_values: &[TypeArg]) -> Result<FunctionType, SignatureError> {
Ok(FunctionType::new(
type_row![FLOAT64_TYPE],
type_row![FLOAT64_TYPE],
))
}

/// Extension for basic arithmetic operations.
pub fn extension() -> Extension {
let mut extension = Extension::new_with_reqs(
EXTENSION_ID,
ExtensionSet::singleton(&super::float_types::EXTENSION_ID),
);

let fcmp_sig: PolyFuncType = FunctionType::new(
type_row![FLOAT64_TYPE; 2],
type_row![crate::extension::prelude::BOOL_T],
)
.into();
let fbinop_sig: PolyFuncType =
FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into();
let funop_sig: PolyFuncType =
FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into();
extension
.add_op_custom_sig_simple("feq".into(), "equality test".to_owned(), vec![], fcmp_sig)
.add_op_type_scheme_simple("feq".into(), "equality test".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fne".into(), "inequality test".to_owned(), vec![], fcmp_sig)
.add_op_type_scheme_simple("fne".into(), "inequality test".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("flt".into(), "\"less than\"".to_owned(), vec![], fcmp_sig)
.add_op_type_scheme_simple("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fgt".into(),
"\"greater than\"".to_owned(),
vec![],
fcmp_sig,
fcmp_sig.clone(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fle".into(),
"\"less than or equal\"".to_owned(),
vec![],
fcmp_sig,
fcmp_sig.clone(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fge".into(),
"\"greater than or equal\"".to_owned(),
vec![],
fcmp_sig,
)
.unwrap();
extension
.add_op_custom_sig_simple("fmax".into(), "maximum".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fmax".into(), "maximum".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fmin".into(), "minimum".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fmin".into(), "minimum".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fadd".into(), "addition".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fadd".into(), "addition".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fsub".into(), "subtraction".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fneg".into(), "negation".to_owned(), vec![], funop_sig)
.add_op_type_scheme_simple("fneg".into(), "negation".to_owned(), funop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fabs".into(),
"absolute value".to_owned(),
vec![],
funop_sig,
funop_sig.clone(),
)
.unwrap();
extension
.add_op_custom_sig_simple(
.add_op_type_scheme_simple(
"fmul".into(),
"multiplication".to_owned(),
vec![],
fbinop_sig,
fbinop_sig.clone(),
)
.unwrap();
extension
.add_op_custom_sig_simple("fdiv".into(), "division".to_owned(), vec![], fbinop_sig)
.add_op_type_scheme_simple("fdiv".into(), "division".to_owned(), fbinop_sig)
.unwrap();
extension
.add_op_custom_sig_simple("ffloor".into(), "floor".to_owned(), vec![], funop_sig)
.add_op_type_scheme_simple("ffloor".into(), "floor".to_owned(), funop_sig.clone())
.unwrap();
extension
.add_op_custom_sig_simple("fceil".into(), "ceiling".to_owned(), vec![], funop_sig)
.add_op_type_scheme_simple("fceil".into(), "ceiling".to_owned(), funop_sig)
.unwrap();

extension
Expand Down
5 changes: 5 additions & 0 deletions src/std_extensions/arithmetic/float_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ mod test {
fn test_float_consts() {
let const_f64_1 = ConstF64::new(1.0);
let const_f64_2 = ConstF64::new(2.0);

assert_eq!(const_f64_1.value(), 1.0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated extra test coverage - doesn't need your custom consts PR or anything?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, the commits are separate so i could cherry pick them, but its "related" in the sense that the changes in this PR highlight the remaining missing coverage better

assert_eq!(*const_f64_2, 2.0);
assert_eq!(const_f64_1.name(), "f64(1)");
assert!(const_f64_1.equal_consts(&ConstF64::new(1.0)));
assert_ne!(const_f64_1, const_f64_2);
assert_eq!(const_f64_1, ConstF64::new(1.0));
}
Expand Down
Loading