diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index c97427eb7..36b65e064 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -824,10 +824,7 @@ impl<'a> Context<'a> { TypeEnum::Function(func) => self.export_func_type(func), TypeEnum::Variable(index, _) => { let node = self.local_scope.expect("local variable out of scope"); - self.make_term(model::Term::Var { - node, - index: *index as _, - }) + self.make_term(model::Term::Var(model::VarId(node, *index as _))) } TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()), TypeEnum::Sum(sum) => self.export_sum_type(sum), @@ -876,18 +873,12 @@ impl<'a> Context<'a> { pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> model::TermId { let node = self.local_scope.expect("local variable out of scope"); - self.make_term(model::Term::Var { - node, - index: var.index() as _, - }) + self.make_term(model::Term::Var(model::VarId(node, var.index() as _))) } pub fn export_row_var(&mut self, t: &RowVariable) -> model::TermId { let node = self.local_scope.expect("local variable out of scope"); - self.make_term(model::Term::Var { - node, - index: t.0 as _, - }) + self.make_term(model::Term::Var(model::VarId(node, t.0 as _))) } pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId { @@ -956,7 +947,7 @@ impl<'a> Context<'a> { match t { TypeParam::Type { b } => { if let (Some((node, index)), TypeBound::Copyable) = (var, b) { - let term = self.make_term(model::Term::Var { node, index }); + let term = self.make_term(model::Term::Var(model::VarId(node, index))); let non_linear = self.make_term(model::Term::NonLinearConstraint { term }); self.local_constraints.push(non_linear); } @@ -999,7 +990,7 @@ impl<'a> Context<'a> { match ext.parse::() { Ok(index) => { let node = self.local_scope.expect("local variable out of scope"); - let term = self.make_term(model::Term::Var { node, index }); + let term = self.make_term(model::Term::Var(model::VarId(node, index))); parts.push(model::ExtSetPart::Splice(term)); } Err(_) => parts.push(model::ExtSetPart::Extension(self.bump.alloc_str(ext))), diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 60a047524..db72664e6 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -859,15 +859,15 @@ impl<'a> Context<'a> { for constraint in decl.constraints { match self.get_term(*constraint)? { model::Term::NonLinearConstraint { term } => { - let model::Term::Var { node, index } = self.get_term(*term)? else { + let model::Term::Var(var) = self.get_term(*term)? else { return Err(error_unsupported!( "constraint on term that is not a variable" )); }; self.local_vars - .get_mut(&model::VarId(*node, *index)) - .ok_or_else(|| model::ModelError::InvalidVar(*node, *index))? + .get_mut(var) + .ok_or_else(|| model::ModelError::InvalidVar(*var))? .bound = TypeBound::Copyable; } _ => return Err(error_unsupported!("constraint other than copy or discard")), @@ -940,13 +940,13 @@ impl<'a> Context<'a> { Err(error_uninferred!("application with implicit parameters")) } - model::Term::Var { node, index } => { - let var = self + model::Term::Var(var) => { + let var_info = self .local_vars - .get(&model::VarId(*node, *index)) - .ok_or(model::ModelError::InvalidVar(*node, *index))?; - let decl = self.import_type_param(var.r#type, var.bound)?; - Ok(TypeArg::new_var_use(*index as _, decl)) + .get(var) + .ok_or(model::ModelError::InvalidVar(*var))?; + let decl = self.import_type_param(var_info.r#type, var_info.bound)?; + Ok(TypeArg::new_var_use(var.1 as _, decl)) } model::Term::List { .. } => { @@ -1001,7 +1001,7 @@ impl<'a> Context<'a> { match self.get_term(term_id)? { model::Term::Wildcard => return Err(error_uninferred!("wildcard")), - model::Term::Var { index, .. } => { + model::Term::Var(model::VarId(_, index)) => { es.insert_type_var(*index as _); } @@ -1067,7 +1067,7 @@ impl<'a> Context<'a> { ))) } - model::Term::Var { index, .. } => { + model::Term::Var(model::VarId(_, index)) => { Ok(TypeBase::new_var_use(*index as _, TypeBound::Copyable)) } @@ -1201,7 +1201,7 @@ impl<'a> Context<'a> { } } } - model::Term::Var { index, .. } => { + model::Term::Var(model::VarId(_, index)) => { let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Any)) .map_err(|_| model::ModelError::TypeError(term_id))?; types.push(TypeBase::new(TypeEnum::RowVar(var))); diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 50191682d..b14ca4482 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -257,7 +257,7 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::Variable(reader) => { let node = model::NodeId(reader.get_variable_node()); let index = reader.get_variable_index(); - model::Term::Var { node, index } + model::Term::Var(model::VarId(node, index)) } Which::Apply(reader) => { diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index d5a9e84f1..ea495db54 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -149,7 +149,7 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { model::Term::Type => builder.set_runtime_type(()), model::Term::StaticType => builder.set_static_type(()), model::Term::Constraint => builder.set_constraint(()), - model::Term::Var { node, index } => { + model::Term::Var(model::VarId(node, index)) => { let mut builder = builder.init_variable(); builder.set_variable_node(node.0); builder.set_variable_index(*index); diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 01e90e0e7..c13d8a589 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -157,11 +157,13 @@ define_index! { } /// The id of a link consisting of its region and the link index. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[display("{_0}#{_1}")] pub struct LinkId(pub RegionId, pub LinkIndex); /// The id of a variable consisting of its node and the variable index. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[display("{_0}#{_1}")] pub struct VarId(pub NodeId, pub VarIndex); /// A module consisting of a hugr graph together with terms. @@ -528,12 +530,7 @@ pub enum Term<'a> { Constraint, /// A local variable. - Var { - /// The node that defines the variable as a parameter. - node: NodeId, - /// The index of the variable in the parameter list of the node. - index: VarIndex, - }, + Var(VarId), /// A symbolic function application. /// @@ -722,8 +719,8 @@ pub enum ModelError { #[error("region not found: {0}")] RegionNotFound(RegionId), /// Invalid variable reference. - #[error("variable {0}#{1} invalid")] - InvalidVar(NodeId, VarIndex), + #[error("variable {0} invalid")] + InvalidVar(VarId), /// Invalid symbol reference. #[error("symbol reference {0} invalid")] InvalidSymbol(NodeId), diff --git a/hugr-model/src/v0/scope/symbol.rs b/hugr-model/src/v0/scope/symbol.rs index bf29c4c07..5c54d930a 100644 --- a/hugr-model/src/v0/scope/symbol.rs +++ b/hugr-model/src/v0/scope/symbol.rs @@ -163,7 +163,7 @@ impl<'a> SymbolTable<'a> { } } -impl<'a> Default for SymbolTable<'a> { +impl Default for SymbolTable<'_> { fn default() -> Self { Self::new() } diff --git a/hugr-model/src/v0/scope/vars.rs b/hugr-model/src/v0/scope/vars.rs index 2ddec349d..b649bc491 100644 --- a/hugr-model/src/v0/scope/vars.rs +++ b/hugr-model/src/v0/scope/vars.rs @@ -3,7 +3,7 @@ use indexmap::IndexSet; use std::hash::BuildHasherDefault; use thiserror::Error; -use crate::v0::{NodeId, VarId, VarIndex}; +use crate::v0::{NodeId, VarId}; type FxIndexSet = IndexSet>; @@ -19,22 +19,22 @@ type FxIndexSet = IndexSet>; /// # Examples /// /// ``` -/// # pub use hugr_model::v0::NodeId; +/// # pub use hugr_model::v0::{NodeId, VarId}; /// # pub use hugr_model::v0::scope::VarTable; /// let mut vars = VarTable::new(); /// vars.enter(NodeId(0)); /// vars.insert("foo").unwrap(); -/// assert_eq!(vars.resolve("foo").unwrap(), (NodeId(0), 0)); -/// assert!(!vars.is_visible(NodeId(0), 1)); +/// assert_eq!(vars.resolve("foo").unwrap(), VarId(NodeId(0), 0)); +/// assert!(!vars.is_visible(VarId(NodeId(0), 1))); /// vars.insert("bar").unwrap(); -/// assert!(vars.is_visible(NodeId(0), 1)); -/// assert_eq!(vars.resolve("bar").unwrap(), (NodeId(0), 1)); +/// assert!(vars.is_visible(VarId(NodeId(0), 1))); +/// assert_eq!(vars.resolve("bar").unwrap(), VarId(NodeId(0), 1)); /// vars.enter(NodeId(1)); /// assert!(vars.resolve("foo").is_err()); -/// assert!(!vars.is_visible(NodeId(0), 0)); +/// assert!(!vars.is_visible(VarId(NodeId(0), 0))); /// vars.exit(); -/// assert_eq!(vars.resolve("foo").unwrap(), (NodeId(0), 0)); -/// assert!(vars.is_visible(NodeId(0), 0)); +/// assert_eq!(vars.resolve("foo").unwrap(), VarId(NodeId(0), 0)); +/// assert!(vars.is_visible(VarId(NodeId(0), 0))); /// ``` #[derive(Debug, Clone)] pub struct VarTable<'a> { @@ -79,14 +79,14 @@ impl<'a> VarTable<'a> { /// # Panics /// /// Panics if there are no open scopes. - pub fn resolve(&self, name: &'a str) -> Result<(NodeId, VarIndex), UnknownVarError<'a>> { + pub fn resolve(&self, name: &'a str) -> Result> { let scope = self.scopes.last().unwrap(); let (set_index, _) = self .vars .get_full(&(scope.node, name)) .ok_or(UnknownVarError(scope.node, name))?; let var_index = (set_index - scope.var_stack) as u16; - Ok((scope.node, var_index)) + Ok(VarId(scope.node, var_index)) } /// Check if a variable is visible in the current scope. @@ -128,7 +128,7 @@ impl<'a> VarTable<'a> { } } -impl<'a> Default for VarTable<'a> { +impl Default for VarTable<'_> { fn default() -> Self { Self::new() } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 4b9ee0c7a..de928affd 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -130,11 +130,11 @@ impl<'a> ParseContext<'a> { let name_token = inner.next().unwrap(); let name = name_token.as_str(); - let (node, index) = self.vars.resolve(name).map_err(|err| { + let var = self.vars.resolve(name).map_err(|err| { ParseError::custom(&err.to_string(), name_token.as_span()) })?; - Term::Var { node, index } + Term::Var(var) } Rule::term_apply => { diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 5da36e80b..5581ff9fd 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use crate::v0::{ ExtSetPart, LinkIndex, ListPart, MetaItem, ModelError, Module, NodeId, Operation, Param, - ParamSort, RegionId, RegionKind, Term, TermId, VarIndex, + ParamSort, RegionId, RegionKind, Term, TermId, VarId, }; type PrintError = ModelError; @@ -495,7 +495,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { self.print_text("constraint"); Ok(()) } - Term::Var { node, index } => self.print_var(*node, *index), + Term::Var(var) => self.print_var(*var), Term::Apply { symbol, args } => { if args.is_empty() { self.print_symbol(*symbol)?; @@ -631,9 +631,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) } - fn print_var(&mut self, node: NodeId, index: VarIndex) -> PrintResult<()> { - let Some(name) = self.locals.get(index as usize) else { - return Err(PrintError::InvalidVar(node, index)); + fn print_var(&mut self, var: VarId) -> PrintResult<()> { + let Some(name) = self.locals.get(var.1 as usize) else { + return Err(PrintError::InvalidVar(var)); }; self.print_text(format!("?{}", name));