Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Instantiate inferred extensions #461

Merged
merged 7 commits into from
Aug 31, 2023
23 changes: 15 additions & 8 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,21 +881,28 @@ mod test {
let [w] = mult.outputs_arr();

builder.set_outputs([w])?;
let hugr = builder.base;
// TODO: when we put new extensions onto the graph after inference, we
// can call `finish_hugr` and just look at the graph
let (solution, extra) = infer_extensions(&hugr)?;
assert!(extra.is_empty());
let mut hugr = builder.base;
let closure = hugr.infer_extensions()?;
assert!(closure.is_empty());
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
assert_eq!(
*solution.get(&(src.node(), Direction::Outgoing)).unwrap(),
hugr.get_nodetype(src.node())
.signature()
.unwrap()
.output_extensions(),
rs
);
assert_eq!(
*solution.get(&(mult.node(), Direction::Incoming)).unwrap(),
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.input_extensions,
rs
);
assert_eq!(
*solution.get(&(mult.node(), Direction::Outgoing)).unwrap(),
hugr.get_nodetype(mult.node())
.signature()
.unwrap()
.output_extensions(),
rs
);
Ok(())
Expand Down
81 changes: 76 additions & 5 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ impl Hugr {
rw.apply(self)
}

/// Infer extension requirements
/// Infer extension requirements and add new information to `op_types` field
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a doclink to infer::infer_extensions here?
And, should the return type here be ExtensionSolution? (A typedef to the same type as here)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be, yes. I was trying to distinguish the genuine solution (as ExtensionSolution) from the closure, but they really should be the same type regardless

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've now addressed both of these in 8e55229 and 1802ccd, respectively

pub fn infer_extensions(
&mut self,
) -> Result<HashMap<(Node, Direction), ExtensionSet>, InferExtensionError> {
Expand All @@ -202,9 +202,22 @@ impl Hugr {
Ok(extension_closure)
}

/// TODO: Write this
fn instantiate_extensions(&mut self, _solution: ExtensionSolution) {
//todo!()
/// Add extension requirement information to the hugr in place.
fn instantiate_extensions(&mut self, solution: ExtensionSolution) {
// We only care about inferred _input_ extensions, because `NodeType`
// uses those to infer the output extensions
for ((node, _), input_extensions) in solution
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth separating solution into two parts? Or only returning the input extensions? (Is there any use for the output ones?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's worth only returning the input extensions. As you say, there's no use for the rest

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll open a separate PR for this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In which case - this PR looks ok as it stands, if that's a followup; or do you want to roll that in here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nah, I'll do it as a followup, no need for it to block this

.iter()
.filter(|((_, dir), _)| *dir == Direction::Incoming)
{
let nodetype = self.op_types.try_get_mut(node.index).unwrap();
match &nodetype.input_extensions {
None => nodetype.input_extensions = Some(input_extensions.clone()),
Some(existing_ext_reqs) => {
acl-cqc marked this conversation as resolved.
Show resolved Hide resolved
debug_assert_eq!(existing_ext_reqs, input_extensions)
}
}
}
}
}

Expand Down Expand Up @@ -428,7 +441,14 @@ impl From<HugrError> for PyErr {

#[cfg(test)]
mod test {
use super::Hugr;
use super::{Hugr, HugrView, NodeType};
use crate::extension::ExtensionSet;
use crate::hugr::HugrMut;
use crate::ops;
use crate::type_row;
use crate::types::{FunctionType, Type};

use std::error::Error;

#[test]
fn impls_send_and_sync() {
Expand All @@ -447,4 +467,55 @@ mod test {
let hugr = simple_dfg_hugr();
assert_matches!(hugr.get_io(hugr.root()), Some(_));
}

#[test]
fn extension_instantiation() -> Result<(), Box<dyn Error>> {
const BIT: Type = crate::extension::prelude::USIZE_T;
let r = ExtensionSet::singleton(&"R".into());

let root = NodeType::pure(ops::DFG {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

driveby nit: NodeType::pure seems a bit misnamed. The so-called "purity" is not about the nodetype - it's about what's already present on the inputs before we even get to that node. A "pure" nodetype would be one with an empty delta; pure should be called new_no_extensions or something...."no extensions" == unextended(??) with synonyms "brief", "truncated", "compact" and "concise". I admit none of these are great. "Basic"? "Retracted" (!)?? Maybe just NodeType::no_extensions paralleling open_extensions ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair, but I'm not sure what the better name is. Maybe open_extensions => open and pure => closed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes - let's leave renaming to another PR, but I see two reasonable schemes: open_extensions and no_extensions or open and closed (where currently we have open_extensions and pure)

signature: FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&r),
});
let mut hugr = Hugr::new(root);
let input = hugr.add_node_with_parent(
hugr.root(),
NodeType::pure(ops::Input {
types: type_row![BIT],
}),
)?;
let output = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Output {
types: type_row![BIT],
}),
)?;
let lift = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![BIT],
new_extension: "R".into(),
}),
)?;
hugr.connect(input, 0, lift, 0)?;
hugr.connect(lift, 0, output, 0)?;
hugr.infer_extensions()?;

assert_eq!(
hugr.op_types
.get(lift.index)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
assert_eq!(
hugr.op_types
.get(output.index)
.signature()
.unwrap()
.input_extensions,
r
);
Ok(())
}
}