Skip to content

Commit

Permalink
Lints, tests, and more consistent use of LinkId.
Browse files Browse the repository at this point in the history
  • Loading branch information
zrho committed Dec 10, 2024
1 parent e26887e commit 21ee97c
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 58 deletions.
19 changes: 5 additions & 14 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,10 +824,7 @@ impl<'a> Context<'a> {
TypeEnum::Function(func) => self.export_func_type(func),
TypeEnum::Variable(index, _) => {
let node = self.local_scope.expect("local variable out of scope");
self.make_term(model::Term::Var {
node,
index: *index as _,
})
self.make_term(model::Term::Var(model::VarId(node, *index as _)))
}
TypeEnum::RowVar(rv) => self.export_row_var(rv.as_rv()),
TypeEnum::Sum(sum) => self.export_sum_type(sum),
Expand Down Expand Up @@ -876,18 +873,12 @@ impl<'a> Context<'a> {

pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> model::TermId {
let node = self.local_scope.expect("local variable out of scope");
self.make_term(model::Term::Var {
node,
index: var.index() as _,
})
self.make_term(model::Term::Var(model::VarId(node, var.index() as _)))
}

pub fn export_row_var(&mut self, t: &RowVariable) -> model::TermId {
let node = self.local_scope.expect("local variable out of scope");
self.make_term(model::Term::Var {
node,
index: t.0 as _,
})
self.make_term(model::Term::Var(model::VarId(node, t.0 as _)))
}

pub fn export_sum_type(&mut self, t: &SumType) -> model::TermId {
Expand Down Expand Up @@ -956,7 +947,7 @@ impl<'a> Context<'a> {
match t {
TypeParam::Type { b } => {
if let (Some((node, index)), TypeBound::Copyable) = (var, b) {
let term = self.make_term(model::Term::Var { node, index });
let term = self.make_term(model::Term::Var(model::VarId(node, index)));
let non_linear = self.make_term(model::Term::NonLinearConstraint { term });
self.local_constraints.push(non_linear);
}
Expand Down Expand Up @@ -999,7 +990,7 @@ impl<'a> Context<'a> {
match ext.parse::<u16>() {
Ok(index) => {
let node = self.local_scope.expect("local variable out of scope");
let term = self.make_term(model::Term::Var { node, index });
let term = self.make_term(model::Term::Var(model::VarId(node, index)));
parts.push(model::ExtSetPart::Splice(term));
}
Err(_) => parts.push(model::ExtSetPart::Extension(self.bump.alloc_str(ext))),
Expand Down
24 changes: 12 additions & 12 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -859,15 +859,15 @@ impl<'a> Context<'a> {
for constraint in decl.constraints {
match self.get_term(*constraint)? {
model::Term::NonLinearConstraint { term } => {
let model::Term::Var { node, index } = self.get_term(*term)? else {
let model::Term::Var(var) = self.get_term(*term)? else {
return Err(error_unsupported!(
"constraint on term that is not a variable"
));
};

self.local_vars
.get_mut(&model::VarId(*node, *index))
.ok_or_else(|| model::ModelError::InvalidVar(*node, *index))?
.get_mut(var)
.ok_or_else(|| model::ModelError::InvalidVar(*var))?
.bound = TypeBound::Copyable;
}
_ => return Err(error_unsupported!("constraint other than copy or discard")),
Expand Down Expand Up @@ -940,13 +940,13 @@ impl<'a> Context<'a> {
Err(error_uninferred!("application with implicit parameters"))
}

model::Term::Var { node, index } => {
let var = self
model::Term::Var(var) => {
let var_info = self
.local_vars
.get(&model::VarId(*node, *index))
.ok_or(model::ModelError::InvalidVar(*node, *index))?;
let decl = self.import_type_param(var.r#type, var.bound)?;
Ok(TypeArg::new_var_use(*index as _, decl))
.get(var)
.ok_or(model::ModelError::InvalidVar(*var))?;
let decl = self.import_type_param(var_info.r#type, var_info.bound)?;
Ok(TypeArg::new_var_use(var.1 as _, decl))
}

model::Term::List { .. } => {
Expand Down Expand Up @@ -1001,7 +1001,7 @@ impl<'a> Context<'a> {
match self.get_term(term_id)? {
model::Term::Wildcard => return Err(error_uninferred!("wildcard")),

model::Term::Var { index, .. } => {
model::Term::Var(model::VarId(_, index)) => {
es.insert_type_var(*index as _);
}

Expand Down Expand Up @@ -1067,7 +1067,7 @@ impl<'a> Context<'a> {
)))
}

model::Term::Var { index, .. } => {
model::Term::Var(model::VarId(_, index)) => {
Ok(TypeBase::new_var_use(*index as _, TypeBound::Copyable))
}

Expand Down Expand Up @@ -1201,7 +1201,7 @@ impl<'a> Context<'a> {
}
}
}
model::Term::Var { index, .. } => {
model::Term::Var(model::VarId(_, index)) => {
let var = RV::try_from_rv(RowVariable(*index as _, TypeBound::Any))
.map_err(|_| model::ModelError::TypeError(term_id))?;
types.push(TypeBase::new(TypeEnum::RowVar(var)));
Expand Down
2 changes: 1 addition & 1 deletion hugr-model/src/v0/binary/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult
Which::Variable(reader) => {
let node = model::NodeId(reader.get_variable_node());
let index = reader.get_variable_index();
model::Term::Var { node, index }
model::Term::Var(model::VarId(node, index))
}

Which::Apply(reader) => {
Expand Down
2 changes: 1 addition & 1 deletion hugr-model/src/v0/binary/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) {
model::Term::Type => builder.set_runtime_type(()),
model::Term::StaticType => builder.set_static_type(()),
model::Term::Constraint => builder.set_constraint(()),
model::Term::Var { node, index } => {
model::Term::Var(model::VarId(node, index)) => {
let mut builder = builder.init_variable();
builder.set_variable_node(node.0);
builder.set_variable_index(*index);
Expand Down
17 changes: 7 additions & 10 deletions hugr-model/src/v0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,13 @@ define_index! {
}

/// The id of a link consisting of its region and the link index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[display("{_0}#{_1}")]
pub struct LinkId(pub RegionId, pub LinkIndex);

/// The id of a variable consisting of its node and the variable index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[display("{_0}#{_1}")]
pub struct VarId(pub NodeId, pub VarIndex);

/// A module consisting of a hugr graph together with terms.
Expand Down Expand Up @@ -528,12 +530,7 @@ pub enum Term<'a> {
Constraint,

/// A local variable.
Var {
/// The node that defines the variable as a parameter.
node: NodeId,
/// The index of the variable in the parameter list of the node.
index: VarIndex,
},
Var(VarId),

/// A symbolic function application.
///
Expand Down Expand Up @@ -722,8 +719,8 @@ pub enum ModelError {
#[error("region not found: {0}")]
RegionNotFound(RegionId),
/// Invalid variable reference.
#[error("variable {0}#{1} invalid")]
InvalidVar(NodeId, VarIndex),
#[error("variable {0} invalid")]
InvalidVar(VarId),
/// Invalid symbol reference.
#[error("symbol reference {0} invalid")]
InvalidSymbol(NodeId),
Expand Down
2 changes: 1 addition & 1 deletion hugr-model/src/v0/scope/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl<'a> SymbolTable<'a> {
}
}

impl<'a> Default for SymbolTable<'a> {
impl Default for SymbolTable<'_> {
fn default() -> Self {
Self::new()
}
Expand Down
24 changes: 12 additions & 12 deletions hugr-model/src/v0/scope/vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use indexmap::IndexSet;
use std::hash::BuildHasherDefault;
use thiserror::Error;

use crate::v0::{NodeId, VarId, VarIndex};
use crate::v0::{NodeId, VarId};

type FxIndexSet<K> = IndexSet<K, BuildHasherDefault<FxHasher>>;

Expand All @@ -19,22 +19,22 @@ type FxIndexSet<K> = IndexSet<K, BuildHasherDefault<FxHasher>>;
/// # Examples
///
/// ```
/// # pub use hugr_model::v0::NodeId;
/// # pub use hugr_model::v0::{NodeId, VarId};
/// # pub use hugr_model::v0::scope::VarTable;
/// let mut vars = VarTable::new();
/// vars.enter(NodeId(0));
/// vars.insert("foo").unwrap();
/// assert_eq!(vars.resolve("foo").unwrap(), (NodeId(0), 0));
/// assert!(!vars.is_visible(NodeId(0), 1));
/// assert_eq!(vars.resolve("foo").unwrap(), VarId(NodeId(0), 0));
/// assert!(!vars.is_visible(VarId(NodeId(0), 1)));
/// vars.insert("bar").unwrap();
/// assert!(vars.is_visible(NodeId(0), 1));
/// assert_eq!(vars.resolve("bar").unwrap(), (NodeId(0), 1));
/// assert!(vars.is_visible(VarId(NodeId(0), 1)));
/// assert_eq!(vars.resolve("bar").unwrap(), VarId(NodeId(0), 1));
/// vars.enter(NodeId(1));
/// assert!(vars.resolve("foo").is_err());
/// assert!(!vars.is_visible(NodeId(0), 0));
/// assert!(!vars.is_visible(VarId(NodeId(0), 0)));
/// vars.exit();
/// assert_eq!(vars.resolve("foo").unwrap(), (NodeId(0), 0));
/// assert!(vars.is_visible(NodeId(0), 0));
/// assert_eq!(vars.resolve("foo").unwrap(), VarId(NodeId(0), 0));
/// assert!(vars.is_visible(VarId(NodeId(0), 0)));
/// ```
#[derive(Debug, Clone)]
pub struct VarTable<'a> {
Expand Down Expand Up @@ -79,14 +79,14 @@ impl<'a> VarTable<'a> {
/// # Panics
///
/// Panics if there are no open scopes.
pub fn resolve(&self, name: &'a str) -> Result<(NodeId, VarIndex), UnknownVarError<'a>> {
pub fn resolve(&self, name: &'a str) -> Result<VarId, UnknownVarError<'a>> {
let scope = self.scopes.last().unwrap();
let (set_index, _) = self
.vars
.get_full(&(scope.node, name))
.ok_or(UnknownVarError(scope.node, name))?;
let var_index = (set_index - scope.var_stack) as u16;
Ok((scope.node, var_index))
Ok(VarId(scope.node, var_index))
}

/// Check if a variable is visible in the current scope.
Expand Down Expand Up @@ -128,7 +128,7 @@ impl<'a> VarTable<'a> {
}
}

impl<'a> Default for VarTable<'a> {
impl Default for VarTable<'_> {
fn default() -> Self {
Self::new()
}
Expand Down
4 changes: 2 additions & 2 deletions hugr-model/src/v0/text/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ impl<'a> ParseContext<'a> {
let name_token = inner.next().unwrap();
let name = name_token.as_str();

let (node, index) = self.vars.resolve(name).map_err(|err| {
let var = self.vars.resolve(name).map_err(|err| {
ParseError::custom(&err.to_string(), name_token.as_span())
})?;

Term::Var { node, index }
Term::Var(var)
}

Rule::term_apply => {
Expand Down
10 changes: 5 additions & 5 deletions hugr-model/src/v0/text/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::borrow::Cow;

use crate::v0::{
ExtSetPart, LinkIndex, ListPart, MetaItem, ModelError, Module, NodeId, Operation, Param,
ParamSort, RegionId, RegionKind, Term, TermId, VarIndex,
ParamSort, RegionId, RegionKind, Term, TermId, VarId,
};

type PrintError = ModelError;
Expand Down Expand Up @@ -495,7 +495,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> {
self.print_text("constraint");
Ok(())
}
Term::Var { node, index } => self.print_var(*node, *index),
Term::Var(var) => self.print_var(*var),
Term::Apply { symbol, args } => {
if args.is_empty() {
self.print_symbol(*symbol)?;
Expand Down Expand Up @@ -631,9 +631,9 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> {
Ok(())
}

fn print_var(&mut self, node: NodeId, index: VarIndex) -> PrintResult<()> {
let Some(name) = self.locals.get(index as usize) else {
return Err(PrintError::InvalidVar(node, index));
fn print_var(&mut self, var: VarId) -> PrintResult<()> {
let Some(name) = self.locals.get(var.1 as usize) else {
return Err(PrintError::InvalidVar(var));
};

self.print_text(format!("?{}", name));
Expand Down

0 comments on commit 21ee97c

Please sign in to comment.