Skip to content

Commit

Permalink
feat: Constant values in hugr-model (#1838)
Browse files Browse the repository at this point in the history
Import and export constant values via `hugr-model`.

Includes terms corresponding to `Value::Function` and allows exporting;
importing these is left for a later PR since sum/extension constants are
the most important for now. In the future, we should move from the
JSON/typetag based model for extension constants to one based on custom
constructors.
  • Loading branch information
zrho authored Jan 7, 2025
1 parent 09cbc6a commit b36d97d
Show file tree
Hide file tree
Showing 33 changed files with 685 additions and 253 deletions.
301 changes: 162 additions & 139 deletions hugr-core/src/export.rs

Large diffs are not rendered by default.

190 changes: 164 additions & 26 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ use crate::{
extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError},
hugr::{HugrMut, IdentList},
ops::{
AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, DataflowBlock, ExitBlock,
FuncDecl, FuncDefn, Input, LoadFunction, Module, OpType, OpaqueOp, Output, Tag, TailLoop,
CFG, DFG,
constant::{CustomConst, CustomSerialized, OpaqueValue},
AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock,
ExitBlock, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, Module, OpType, OpaqueOp,
Output, Tag, TailLoop, Value, CFG, DFG,
},
types::{
type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV,
Expand All @@ -28,6 +29,7 @@ use smol_str::{SmolStr, ToSmolStr};
use thiserror::Error;

const TERM_JSON: &str = "prelude.json";
const TERM_JSON_CONST: &str = "prelude.const-json";

/// Error during import.
#[derive(Debug, Clone, Error)]
Expand Down Expand Up @@ -172,7 +174,7 @@ impl<'a> Context<'a> {
for meta_item in node_data.meta {
// TODO: For now we expect all metadata to be JSON since this is how
// it is handled in `hugr-core`.
let value = self.import_json_value(meta_item.value)?;
let value = self.import_json_meta(meta_item.value)?;
self.hugr.set_metadata(node, meta_item.name, value);
}

Expand Down Expand Up @@ -442,12 +444,6 @@ impl<'a> Context<'a> {

let node = self.make_node(node_id, optype, parent)?;

match node_data.regions {
[] => {}
[region] => self.import_dfg_region(node_id, *region, node)?,
_ => return Err(error_unsupported!("multiple regions in custom operation")),
}

Ok(Some(node))
}

Expand Down Expand Up @@ -508,6 +504,36 @@ impl<'a> Context<'a> {

model::Operation::DeclareConstructor { .. } => Ok(None),
model::Operation::DeclareOperation { .. } => Ok(None),

model::Operation::Const { value } => {
let signature = node_data
.signature
.ok_or_else(|| error_uninferred!("node signature"))?;
let (_, outputs, _) = self.get_func_type(signature)?;
let outputs = self.import_closed_list(outputs)?;
let output = outputs
.first()
.ok_or(model::ModelError::TypeError(signature))?;
let datatype = self.import_type(*output)?;

let imported_value = self.import_value(value, *output)?;

let load_const_node = self.make_node(
node_id,
OpType::LoadConstant(LoadConstant {
datatype: datatype.clone(),
}),
parent,
)?;

let const_node = self
.hugr
.add_node_with_parent(parent, OpType::Const(Const::new(imported_value)));

self.hugr.connect(const_node, 0, load_const_node, 0);

Ok(Some(load_const_node))
}
}
}

Expand Down Expand Up @@ -897,7 +923,7 @@ impl<'a> Context<'a> {
model::Term::Apply { .. } => Err(error_unsupported!("custom type as `TypeParam`")),
model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")),

model::Term::Quote { .. } => Err(error_unsupported!("`(quote ...)` as `TypeParam`")),
model::Term::Const { .. } => Err(error_unsupported!("`(const ...)` as `TypeParam`")),
model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")),

model::Term::ListType { item_type } => {
Expand All @@ -918,9 +944,9 @@ impl<'a> Context<'a> {
| model::Term::ExtSet { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. }
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
| model::Term::NonLinearConstraint { .. }
| model::Term::ConstFunc { .. }
| model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()),

model::Term::ControlType => {
Err(error_unsupported!("type of control types as `TypeParam`"))
Expand Down Expand Up @@ -959,9 +985,6 @@ impl<'a> Context<'a> {
arg: value.to_string(),
}),

model::Term::Quote { .. } => Ok(TypeArg::Type {
ty: self.import_type(term_id)?,
}),
model::Term::Nat(value) => Ok(TypeArg::BoundedNat { n: *value }),
model::Term::ExtSet { .. } => Ok(TypeArg::Extensions {
es: self.import_extension_set(term_id)?,
Expand All @@ -976,6 +999,11 @@ impl<'a> Context<'a> {
model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")),
model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")),
model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")),
model::Term::Const { .. } => Err(error_unsupported!("`const` as `TypeArg`")),
model::Term::ConstAdt { .. } => Err(error_unsupported!("adt constant as `TypeArg`")),
model::Term::ConstFunc { .. } => {
Err(error_unsupported!("function constant as `TypeArg`"))
}

model::Term::FuncType { .. }
| model::Term::Adt { .. }
Expand Down Expand Up @@ -1045,12 +1073,12 @@ impl<'a> Context<'a> {
let (extension, id) = self.import_custom_name(name)?;

let extension_ref =
self.extensions.get(&extension.to_string()).ok_or_else(|| {
ImportError::Extension {
self.extensions
.get(&extension)
.ok_or_else(|| ImportError::Extension {
missing_ext: extension.clone(),
available: self.extensions.ids().cloned().collect(),
}
})?;
})?;

Ok(TypeBase::new_extension(CustomType::new(
id,
Expand Down Expand Up @@ -1090,16 +1118,16 @@ impl<'a> Context<'a> {
| model::Term::StaticType
| model::Term::Type
| model::Term::Constraint
| model::Term::Quote { .. }
| model::Term::Const { .. }
| model::Term::Str(_)
| model::Term::ExtSet { .. }
| model::Term::List { .. }
| model::Term::Control { .. }
| model::Term::ControlType
| model::Term::Nat(_)
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
| model::Term::NonLinearConstraint { .. }
| model::Term::ConstFunc { .. }
| model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()),
}
}

Expand Down Expand Up @@ -1234,7 +1262,7 @@ impl<'a> Context<'a> {
}
}

fn import_json_value(
fn import_json_meta(
&mut self,
term_id: model::TermId,
) -> Result<serde_json::Value, ImportError> {
Expand Down Expand Up @@ -1263,6 +1291,116 @@ impl<'a> Context<'a> {

Ok(json_value)
}

fn import_value(
&mut self,
term_id: model::TermId,
type_id: model::TermId,
) -> Result<Value, ImportError> {
let term_data = self.get_term(term_id)?;

match term_data {
model::Term::Wildcard => Err(error_uninferred!("wildcard")),
model::Term::Apply { .. } => {
Err(error_uninferred!("application with implicit parameters"))
}
model::Term::Var(_) => Err(error_unsupported!("constant value containing a variable")),

model::Term::ApplyFull { symbol, args } => {
let symbol_name = self.get_symbol_name(*symbol)?;

if symbol_name == TERM_JSON_CONST {
let value = args.get(1).ok_or(model::ModelError::TypeError(term_id))?;

let model::Term::Str(json) = self.get_term(*value)? else {
return Err(model::ModelError::TypeError(term_id).into());
};

// We attempt to deserialize as the custom const directly.
// This might fail due to the custom const struct not being included when
// this code was compiled; in that case, we fall back to the serialized form.
let value: Option<Box<dyn CustomConst>> = serde_json::from_str(json).ok();

if let Some(value) = value {
let opaque_value = OpaqueValue::from(value);
return Ok(Value::Extension { e: opaque_value });
} else {
let runtime_type =
args.first().ok_or(model::ModelError::TypeError(term_id))?;
let runtime_type = self.import_type(*runtime_type)?;

let extensions =
args.get(2).ok_or(model::ModelError::TypeError(term_id))?;
let extensions = self.import_extension_set(*extensions)?;

let value: serde_json::Value = serde_json::from_str(json)
.map_err(|_| model::ModelError::TypeError(term_id))?;
let custom_const = CustomSerialized::new(runtime_type, value, extensions);
let opaque_value = OpaqueValue::new(custom_const);
return Ok(Value::Extension { e: opaque_value });
}
}

Err(error_unsupported!("constant value that is not JSON data"))
// TODO: This should ultimately include the following cases:
// - function definitions
// - custom constructors for values
}

model::Term::StaticType
| model::Term::Constraint
| model::Term::Const { .. }
| model::Term::List { .. }
| model::Term::ListType { .. }
| model::Term::Str(_)
| model::Term::StrType
| model::Term::Nat(_)
| model::Term::NatType
| model::Term::ExtSet { .. }
| model::Term::ExtSetType
| model::Term::Adt { .. }
| model::Term::FuncType { .. }
| model::Term::Control { .. }
| model::Term::ControlType
| model::Term::Type
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}

model::Term::ConstFunc { .. } => Err(error_unsupported!("constant function value")),

model::Term::ConstAdt { tag, values } => {
let model::Term::Adt { variants } = self.get_term(type_id)? else {
return Err(model::ModelError::TypeError(term_id).into());
};

let values = self.import_closed_list(*values)?;
let variants = self.import_closed_list(*variants)?;

let variant = variants
.get(*tag as usize)
.ok_or(model::ModelError::TypeError(term_id))?;
let variant = self.import_closed_list(*variant)?;

let items = values
.iter()
.zip(variant.iter())
.map(|(value, typ)| self.import_value(*value, *typ))
.collect::<Result<Vec<_>, _>>()?;

let typ = {
// TODO: Import as a `SumType` directly and avoid the copy.
let typ: Type = self.import_type(type_id)?;
match typ.as_type_enum() {
TypeEnum::Sum(sum) => sum.clone(),
_ => unreachable!(),
}
};

Ok(Value::sum(*tag as _, items, typ).unwrap())
}
}
}
}

/// Information about a local variable.
Expand Down
6 changes: 6 additions & 0 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,12 @@ impl<CC: CustomConst> From<CC> for OpaqueValue {
}
}

impl From<Box<dyn CustomConst>> for OpaqueValue {
fn from(value: Box<dyn CustomConst>) -> Self {
Self { v: value }
}
}

impl PartialEq for OpaqueValue {
fn eq(&self, other: &Self) -> bool {
self.value().equal_consts(other.value())
Expand Down
7 changes: 7 additions & 0 deletions hugr-core/tests/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,10 @@ pub fn test_roundtrip_constraints() {
"../../hugr-model/tests/fixtures/model-constraints.edn"
)));
}

#[test]
pub fn test_roundtrip_const() {
insta::assert_snapshot!(roundtrip(include_str!(
"../../hugr-model/tests/fixtures/model-const.edn"
)));
}
4 changes: 2 additions & 2 deletions hugr-core/tests/snapshots/model__roundtrip_add.snap
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add.
(dfg
[%0 %1] [%2]
(signature
(fn
(->
[(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext)))
((@ arithmetic.int.iadd) [%0 %1] [%2]
(signature
(fn
(->
[(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))))))
2 changes: 1 addition & 1 deletion hugr-core/tests/snapshots/model__roundtrip_alias.snap
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alia

(define-alias local.int type (@ arithmetic.int.types.int))

(define-alias local.endo type (fn [] [] (ext)))
(define-alias local.endo type (-> [] [] (ext)))
17 changes: 9 additions & 8 deletions hugr-core/tests/snapshots/model__roundtrip_call.snap
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,39 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call
(dfg
[%0] [%1]
(signature
(fn
(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int)))
(call (@ example.callee (ext)) [%0] [%1]
(signature
(fn
(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))))))

(define-func example.load
[]
[(fn
[(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))]
(ext)
(dfg
[] [%0]
(signature
(fn
(->
[]
[(fn
[(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))]
(ext)))
(load-func (@ example.caller)
(load-func (@ example.caller) [] [%0]
(signature
(fn
(->
[]
[(fn
[(->
[(@ arithmetic.int.types.int)]
[(@ arithmetic.int.types.int)]
(ext arithmetic.int))]
Expand Down
Loading

0 comments on commit b36d97d

Please sign in to comment.