From 08a2bd6f44548202160a925d28bd0d5df65b256d Mon Sep 17 00:00:00 2001
From: Seyon Sivarajah <seyon.sivarajah@quantinuum.com>
Date: Tue, 9 Jan 2024 12:04:34 +0000
Subject: [PATCH] feat: constant folding for list operations

+ utility enums and structs for dealing with list ops and types
---
 src/algorithm/const_fold.rs       |  43 +++++-
 src/std_extensions/collections.rs | 224 ++++++++++++++++++++++++------
 2 files changed, 225 insertions(+), 42 deletions(-)

diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs
index 50ea430c4..f994a75fb 100644
--- a/src/algorithm/const_fold.rs
+++ b/src/algorithm/const_fold.rs
@@ -220,12 +220,14 @@ mod test {
     use crate::extension::prelude::{sum_with_error, BOOL_T};
     use crate::extension::{ExtensionRegistry, PRELUDE};
     use crate::ops::OpType;
-    use crate::std_extensions::arithmetic;
     use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
     use crate::std_extensions::arithmetic::float_ops::FloatOps;
     use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
     use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};
+    use crate::std_extensions::collections::{make_list_const, ListOp, ListValue};
     use crate::std_extensions::logic::{self, const_from_bool, NaryLogic};
+    use crate::std_extensions::{arithmetic, collections};
+    use crate::types::TypeArg;
     use rstest::rstest;
 
     /// int to constant
@@ -332,6 +334,45 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn test_list_ops() -> Result<(), Box<dyn std::error::Error>> {
+        let reg = ExtensionRegistry::try_new([
+            PRELUDE.to_owned(),
+            logic::EXTENSION.to_owned(),
+            collections::EXTENSION.to_owned(),
+        ])
+        .unwrap();
+        let list = make_list_const(
+            ListValue::new(vec![Value::unit_sum(1)]),
+            &[TypeArg::Type { ty: BOOL_T }],
+        );
+        let mut build = DFGBuilder::new(FunctionType::new(
+            type_row![],
+            vec![list.const_type().clone()],
+        ))
+        .unwrap();
+
+        let list_wire = build.add_load_const(list.clone())?;
+
+        let pop = build.add_dataflow_op(
+            ListOp::Pop.with_type(BOOL_T).to_extension_op(&reg).unwrap(),
+            [list_wire],
+        )?;
+
+        let push = build.add_dataflow_op(
+            ListOp::Push
+                .with_type(BOOL_T)
+                .to_extension_op(&reg)
+                .unwrap(),
+            pop.outputs(),
+        )?;
+        let mut h = build.finish_hugr_with_outputs(push.outputs(), &reg)?;
+        constant_fold_pass(&mut h, &reg);
+
+        assert_fully_folded(&h, &list);
+        Ok(())
+    }
+
     fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
         // check the hugr just loads and returns a single const
         let mut node_count = 0;
diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs
index ebec9bda7..5e2c4e76d 100644
--- a/src/std_extensions/collections.rs
+++ b/src/std_extensions/collections.rs
@@ -5,7 +5,13 @@ use serde::{Deserialize, Serialize};
 use smol_str::SmolStr;
 
 use crate::{
-    extension::{ExtensionId, ExtensionSet, TypeDef, TypeDefBound},
+    algorithm::const_fold::sorted_consts,
+    extension::{
+        simple_op::{MakeExtensionOp, OpLoadError},
+        ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef,
+        TypeDefBound,
+    },
+    ops::{self, custom::ExtensionOp, OpName},
     types::{
         type_param::{TypeArg, TypeParam},
         CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound,
@@ -47,7 +53,9 @@ impl CustomConst for ListValue {
             CustomCheckFailure::Message("List type check fail.".to_string())
         };
 
-        get_type(&LIST_TYPENAME)
+        EXTENSION
+            .get_type(&LIST_TYPENAME)
+            .unwrap()
             .check_custom(typ)
             .map_err(|_| error())?;
 
@@ -72,6 +80,55 @@ impl CustomConst for ListValue {
             .union(&ExtensionSet::singleton(&EXTENSION_NAME))
     }
 }
+
+struct PopFold;
+
+impl ConstFold for PopFold {
+    fn fold(
+        &self,
+        type_args: &[TypeArg],
+        consts: &[(crate::IncomingPort, ops::Const)],
+    ) -> crate::extension::ConstFoldResult {
+        let [TypeArg::Type { ty }] = type_args else {
+            return None;
+        };
+        let [list]: [&ops::Const; 1] = sorted_consts(consts).try_into().ok()?;
+        let list: &ListValue = list.get_custom_value().expect("Should be list value.");
+        let mut list = list.clone();
+        let elem = list.0.pop()?; // empty list fails to evaluate "pop"
+        let list = make_list_const(list, type_args);
+        let elem = ops::Const::new(elem, ty.clone()).unwrap();
+
+        Some(vec![(0.into(), list), (1.into(), elem)])
+    }
+}
+
+pub(crate) fn make_list_const(list: ListValue, type_args: &[TypeArg]) -> ops::Const {
+    let list_type_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap();
+    ops::Const::new(
+        list.into(),
+        Type::new_extension(list_type_def.instantiate(type_args).unwrap()),
+    )
+    .unwrap()
+}
+
+struct PushFold;
+
+impl ConstFold for PushFold {
+    fn fold(
+        &self,
+        type_args: &[TypeArg],
+        consts: &[(crate::IncomingPort, ops::Const)],
+    ) -> crate::extension::ConstFoldResult {
+        let [list, elem]: [&ops::Const; 2] = sorted_consts(consts).try_into().ok()?;
+        let list: &ListValue = list.get_custom_value().expect("Should be list value.");
+        let mut list = list.clone();
+        list.0.push(elem.value().clone());
+        let list = make_list_const(list, type_args);
+
+        Some(vec![(0.into(), list)])
+    }
+}
 const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
 
 fn extension() -> Extension {
@@ -87,7 +144,7 @@ fn extension() -> Extension {
         .unwrap();
     let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap();
 
-    let (l, e) = list_and_elem_type(list_type_def);
+    let (l, e) = list_and_elem_type_vars(list_type_def);
     extension
         .add_op(
             POP_NAME,
@@ -97,14 +154,17 @@ fn extension() -> Extension {
                 FunctionType::new(vec![l.clone()], vec![l.clone(), e.clone()]),
             ),
         )
-        .unwrap();
+        .unwrap()
+        .set_constant_folder(PopFold);
     extension
         .add_op(
             PUSH_NAME,
             "Push to back of list".into(),
             PolyFuncType::new(vec![TP], FunctionType::new(vec![l.clone(), e], vec![l])),
         )
-        .unwrap();
+        .unwrap()
+        .set_constant_folder(PushFold);
+
     extension
 }
 
@@ -113,11 +173,18 @@ lazy_static! {
     pub static ref EXTENSION: Extension = extension();
 }
 
-fn get_type(name: &str) -> &TypeDef {
-    EXTENSION.get_type(name).unwrap()
+/// Get the type of a list of `elem_type`
+pub fn list_type(elem_type: Type) -> Type {
+    Type::new_extension(
+        EXTENSION
+            .get_type(&LIST_TYPENAME)
+            .unwrap()
+            .instantiate(vec![TypeArg::Type { ty: elem_type }])
+            .unwrap(),
+    )
 }
 
-fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) {
+fn list_and_elem_type_vars(list_type_def: &TypeDef) -> (Type, Type) {
     let elem_type = Type::new_var_use(0, TypeBound::Any);
     let list_type = Type::new_extension(
         list_type_def
@@ -126,22 +193,107 @@ fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) {
     );
     (list_type, elem_type)
 }
+
+/// A list operation
+#[derive(Debug, Clone, PartialEq)]
+pub enum ListOp {
+    /// Pop from end of list
+    Pop,
+    /// Push to end of list
+    Push,
+}
+
+impl ListOp {
+    /// Instantiate a list operation with an `element_type`
+    pub fn with_type(self, element_type: Type) -> ListOpInst {
+        ListOpInst {
+            elem_type: element_type,
+            op: self,
+        }
+    }
+}
+
+/// A list operation with a concrete element type.
+#[derive(Debug, Clone, PartialEq)]
+pub struct ListOpInst {
+    op: ListOp,
+    elem_type: Type,
+}
+
+impl OpName for ListOpInst {
+    fn name(&self) -> SmolStr {
+        match self.op {
+            ListOp::Pop => POP_NAME,
+            ListOp::Push => PUSH_NAME,
+        }
+    }
+}
+
+impl MakeExtensionOp for ListOpInst {
+    fn from_extension_op(
+        ext_op: &ExtensionOp,
+    ) -> Result<Self, crate::extension::simple_op::OpLoadError> {
+        let [TypeArg::Type { ty }] = ext_op.args() else {
+            return Err(SignatureError::InvalidTypeArgs.into());
+        };
+        let name = ext_op.def().name();
+        let op = match name {
+            // can't use const SmolStr in pattern
+            _ if name == &POP_NAME => ListOp::Pop,
+            _ if name == &PUSH_NAME => ListOp::Push,
+            _ => return Err(OpLoadError::NotMember(name.to_string())),
+        };
+
+        Ok(Self {
+            elem_type: ty.clone(),
+            op,
+        })
+    }
+
+    fn type_args(&self) -> Vec<TypeArg> {
+        vec![TypeArg::Type {
+            ty: self.elem_type.clone(),
+        }]
+    }
+}
+
+impl ListOpInst {
+    /// Convert this list operation to an [`ExtensionOp`] by providing a
+    /// registry to validate the element tyoe against.
+    pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option<ExtensionOp> {
+        let registry = ExtensionRegistry::try_new(
+            elem_type_registry
+                .clone()
+                .into_iter()
+                // ignore self if already in registry
+                .filter_map(|(_, ext)| (ext.name() != EXTENSION.name()).then_some(ext))
+                .chain(std::iter::once(EXTENSION.to_owned())),
+        )
+        .unwrap();
+        ExtensionOp::new(
+            registry.get(&EXTENSION_NAME)?.get_op(&self.name())?.clone(),
+            self.type_args(),
+            &registry,
+        )
+        .ok()
+    }
+}
+
 #[cfg(test)]
 mod test {
     use crate::{
         extension::{
             prelude::{ConstUsize, QB_T, USIZE_T},
-            ExtensionRegistry, OpDef, PRELUDE,
+            ExtensionRegistry, PRELUDE,
         },
+        ops::OpTrait,
         std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE},
-        types::{type_param::TypeArg, Type, TypeRow},
+        types::{type_param::TypeArg, TypeRow},
         Extension,
     };
 
     use super::*;
-    fn get_op(name: &str) -> &OpDef {
-        EXTENSION.get_op(name).unwrap()
-    }
+
     #[test]
     fn test_extension() {
         let r: Extension = extension();
@@ -174,40 +326,30 @@ mod test {
 
     #[test]
     fn test_list_ops() {
-        let reg = ExtensionRegistry::try_new([
-            EXTENSION.to_owned(),
-            PRELUDE.to_owned(),
-            float_types::EXTENSION.to_owned(),
-        ])
-        .unwrap();
-        let pop_sig = get_op(&POP_NAME)
-            .compute_signature(&[TypeArg::Type { ty: QB_T }], &reg)
-            .unwrap();
+        let reg =
+            ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()])
+                .unwrap();
+        let pop_op = ListOp::Pop.with_type(QB_T);
+        let pop_ext = pop_op.clone().to_extension_op(&reg).unwrap();
+        assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op);
+        let pop_sig = pop_ext.dataflow_signature().unwrap();
 
-        let list_type = Type::new_extension(CustomType::new(
-            LIST_TYPENAME,
-            vec![TypeArg::Type { ty: QB_T }],
-            EXTENSION_NAME,
-            TypeBound::Any,
-        ));
+        let list_t = list_type(QB_T);
 
-        let both_row: TypeRow = vec![list_type.clone(), QB_T].into();
-        let just_list_row: TypeRow = vec![list_type].into();
+        let both_row: TypeRow = vec![list_t.clone(), QB_T].into();
+        let just_list_row: TypeRow = vec![list_t].into();
         assert_eq!(pop_sig.input(), &just_list_row);
         assert_eq!(pop_sig.output(), &both_row);
 
-        let push_sig = get_op(&PUSH_NAME)
-            .compute_signature(&[TypeArg::Type { ty: FLOAT64_TYPE }], &reg)
-            .unwrap();
+        let push_op = ListOp::Push.with_type(FLOAT64_TYPE);
+        let push_ext = push_op.clone().to_extension_op(&reg).unwrap();
+        assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op);
+        let push_sig = push_ext.dataflow_signature().unwrap();
 
-        let list_type = Type::new_extension(CustomType::new(
-            LIST_TYPENAME,
-            vec![TypeArg::Type { ty: FLOAT64_TYPE }],
-            EXTENSION_NAME,
-            TypeBound::Copyable,
-        ));
-        let both_row: TypeRow = vec![list_type.clone(), FLOAT64_TYPE].into();
-        let just_list_row: TypeRow = vec![list_type].into();
+        let list_t = list_type(FLOAT64_TYPE);
+
+        let both_row: TypeRow = vec![list_t.clone(), FLOAT64_TYPE].into();
+        let just_list_row: TypeRow = vec![list_t].into();
 
         assert_eq!(push_sig.input(), &both_row);
         assert_eq!(push_sig.output(), &just_list_row);