Skip to content

Commit

Permalink
feat: Add more list operations
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Aug 29, 2024
1 parent 9766c48 commit 7ba766a
Show file tree
Hide file tree
Showing 7 changed files with 796 additions and 186 deletions.
23 changes: 10 additions & 13 deletions hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,30 +458,27 @@ mod test {
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG};
use crate::std_extensions::collections;
use crate::types::{Signature, Type, TypeArg, TypeRow};
use crate::std_extensions::collections::{self, list_type, ListOp};
use crate::types::{Signature, Type, TypeRow};
use crate::utils::depth;
use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort};

use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement};

#[test]
#[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-'
fn cfg() -> Result<(), Box<dyn std::error::Error>> {
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()])
.unwrap();
let listy = Type::new_extension(
collections::EXTENSION
.get_type(&collections::LIST_TYPENAME)
.unwrap()
.instantiate([TypeArg::Type { ty: USIZE_T }])
.unwrap(),
);
let pop: ExtensionOp = collections::EXTENSION
.instantiate_extension_op("pop", [TypeArg::Type { ty: USIZE_T }], &reg)
let listy = list_type(USIZE_T);
let pop: ExtensionOp = ListOp::pop
.with_type(USIZE_T)
.to_extension_op(&reg)
.unwrap();
let push: ExtensionOp = collections::EXTENSION
.instantiate_extension_op("push", [TypeArg::Type { ty: USIZE_T }], &reg)
let push: ExtensionOp = ListOp::push
.with_type(USIZE_T)
.to_extension_op(&reg)
.unwrap();
let just_list = TypeRow::from(vec![listy.clone()]);
let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]);
Expand Down
132 changes: 124 additions & 8 deletions hugr-core/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::extension::prelude::{option_type, result_type, USIZE_T};
use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE};
use crate::ops::constant::ValueName;
Expand Down Expand Up @@ -110,17 +111,30 @@ impl CustomConst for ListValue {
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum ListOp {
/// Pop from end of list
/// Pop from the end of list. Return an optional value.
pop,
/// Push to end of list
/// Push to end of list. Return the new list.
push,
/// Lookup an element in a list by index.
get,
/// Replace the element at index `i` with value `v`, and return the old value.
///
/// If the index is out of bounds, returns the input value as an error.
set,
/// Insert an element at index `i`.
///
/// Elements at higher indices are shifted one position to the right.
/// Returns an Err with the element if the index is out of bounds.
insert,
/// Get the length of a list.
length,
}

impl ListOp {
/// Type parameter used in the list types.
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };

/// Instantiate a list operation with an `element_type`
/// Instantiate a list operation with an `element_type`.
pub fn with_type(self, element_type: Type) -> ListOpInst {
ListOpInst {
elem_type: element_type,
Expand All @@ -135,9 +149,25 @@ impl ListOp {
let l = self.list_type(list_type_def, 0);
match self {
pop => self
.list_polytype(vec![l.clone()], vec![l.clone(), e.clone()])
.list_polytype(vec![l.clone()], vec![l, Type::from(option_type(e))])
.into(),
push => self.list_polytype(vec![l.clone(), e], vec![l]).into(),
get => self
.list_polytype(vec![l, USIZE_T], vec![Type::from(option_type(e))])
.into(),
set => self
.list_polytype(
vec![l.clone(), USIZE_T, e.clone()],
vec![l, Type::from(result_type(e.clone(), e))],
)
.into(),
insert => self
.list_polytype(
vec![l.clone(), USIZE_T, e.clone()],
vec![l, result_type(Type::UNIT, e).into()],
)
.into(),
length => self.list_polytype(vec![l], vec![USIZE_T]).into(),
}
}

Expand Down Expand Up @@ -191,8 +221,12 @@ impl MakeOpDef for ListOp {
use ListOp::*;

match self {
pop => "Pop from back of list",
push => "Push to back of list",
pop => "Pop from the back of list. Returns an optional value.",
push => "Push to the back of list",
get => "Lookup an element in a list by index. Panics if the index is out of bounds.",
set => "Replace the element at index `i` with value `v`.",
insert => "Insert an element at index `i`. Elements at higher indices are shifted one position to the right. Panics if the index is out of bounds.",
length => "Get the length of a list",
}
.into()
}
Expand All @@ -205,7 +239,6 @@ impl MakeOpDef for ListOp {
lazy_static! {
/// Extension for list operations.
pub static ref EXTENSION: Extension = {
println!("creating collections extension");
let mut extension = Extension::new(EXTENSION_ID, VERSION);

// The list type must be defined before the operations are added.
Expand Down Expand Up @@ -323,7 +356,13 @@ impl ListOpInst {

#[cfg(test)]
mod test {
use rstest::rstest;

use crate::extension::prelude::{
const_err_tuple, const_none, const_ok_tuple, const_some_tuple,
};
use crate::ops::OpTrait;
use crate::PortIndex;
use crate::{
extension::{
prelude::{ConstUsize, QB_T, USIZE_T},
Expand Down Expand Up @@ -378,7 +417,7 @@ mod test {

let list_t = list_type(QB_T);

let both_row: TypeRow = vec![list_t.clone(), QB_T].into();
let both_row: TypeRow = vec![list_t.clone(), option_type(QB_T).into()].into();
let just_list_row: TypeRow = vec![list_t].into();
assert_eq!(pop_sig.input(), &just_list_row);
assert_eq!(pop_sig.output(), &both_row);
Expand All @@ -396,4 +435,81 @@ mod test {
assert_eq!(push_sig.input(), &both_row);
assert_eq!(push_sig.output(), &just_list_row);
}

/// Values used in the `list_fold` test cases.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TestVal {
Idx(usize),
List(Vec<usize>),
Elem(usize),
Some(Vec<TestVal>),
None(TypeRow),
Ok(Vec<TestVal>, TypeRow),
Err(TypeRow, Vec<TestVal>),
}

impl TestVal {
fn to_value(&self) -> Value {
match self {
TestVal::Idx(i) => Value::extension(ConstUsize::new(*i as u64)),
TestVal::Elem(e) => Value::extension(ConstUsize::new(*e as u64)),
TestVal::List(l) => {
let elems = l
.iter()
.map(|&i| Value::extension(ConstUsize::new(i as u64)))
.collect();
Value::extension(ListValue(elems, USIZE_T))
}
TestVal::Some(l) => {
let elems = l.iter().map(TestVal::to_value);
const_some_tuple(elems)
}
TestVal::None(tr) => const_none(tr.clone()),
TestVal::Ok(l, tr) => {
let elems = l.iter().map(TestVal::to_value);
const_ok_tuple(elems, tr.clone())
}
TestVal::Err(tr, l) => {
let elems = l.iter().map(TestVal::to_value);
const_err_tuple(tr.clone(), elems)
}
}
}
}

#[rstest]
#[case::pop(ListOp::pop, &[TestVal::List(vec![77,88, 42])], &[TestVal::List(vec![77,88]), TestVal::Some(vec![TestVal::Elem(42)])])]
#[case::pop_empty(ListOp::pop, &[TestVal::List(vec![])], &[TestVal::List(vec![]), TestVal::None(vec![USIZE_T].into())])]
#[case::push(ListOp::push, &[TestVal::List(vec![77,88]), TestVal::Elem(42)], &[TestVal::List(vec![77,88,42])])]
#[case::get(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1)], &[TestVal::Some(vec![TestVal::Elem(88)])])]
#[case::get_invalid(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(99)], &[TestVal::None(vec![USIZE_T].into())])]
#[case::insert(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,88,42]), TestVal::Ok(vec![], vec![USIZE_T].into())])]
#[case::insert_invalid(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(52), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(Type::UNIT.into(), vec![TestVal::Elem(99)])])]
#[case::length(ListOp::length, &[TestVal::List(vec![77,88,42])], &[TestVal::Elem(3)])]
fn list_fold(#[case] op: ListOp, #[case] inputs: &[TestVal], #[case] outputs: &[TestVal]) {
let consts: Vec<_> = inputs
.iter()
.enumerate()
.map(|(i, x)| (i.into(), x.to_value()))
.collect();

let res = op
.with_type(USIZE_T)
.to_extension_op(&COLLECTIONS_REGISTRY)
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, expected) in outputs.iter().enumerate() {
let expected = expected.to_value();
let res_val = res
.iter()
.find(|(port, _)| port.index() == i)
.unwrap()
.1
.clone();

assert_eq!(res_val, expected);
}
}
}
110 changes: 100 additions & 10 deletions hugr-core/src/std_extensions/collections/list_fold.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
//! Folding definitions for list operations.
use crate::extension::{ConstFold, OpDef};
use crate::ops;
use crate::extension::prelude::{
const_err, const_none, const_ok, const_ok_tuple, const_some, ConstUsize,
};
use crate::extension::{ConstFold, ConstFoldResult, OpDef};
use crate::ops::Value;
use crate::types::type_param::TypeArg;
use crate::types::Type;
use crate::utils::sorted_consts;
use crate::IncomingPort;

use super::{ListOp, ListValue};

pub(super) fn set_fold(op: &ListOp, def: &mut OpDef) {
match op {
ListOp::pop => def.set_constant_folder(PopFold),
ListOp::push => def.set_constant_folder(PushFold),
ListOp::get => def.set_constant_folder(GetFold),
ListOp::set => def.set_constant_folder(SetFold),
ListOp::insert => def.set_constant_folder(InsertFold),
ListOp::length => def.set_constant_folder(LengthFold),
}
}

Expand All @@ -20,14 +29,22 @@ impl ConstFold for PopFold {
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Value)],
) -> crate::extension::ConstFoldResult {
let [list]: [&ops::Value; 1] = sorted_consts(consts).try_into().ok()?;
consts: &[(crate::IncomingPort, Value)],
) -> ConstFoldResult {
let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
let elem = list.0.pop()?; // empty list fails to evaluate "pop"

Some(vec![(0.into(), list.into()), (1.into(), elem)])
match list.0.pop() {
Some(elem) => Some(vec![(0.into(), list.into()), (1.into(), const_some(elem))]),
None => {
let elem_type = list.1.clone();
Some(vec![
(0.into(), list.into()),
(1.into(), const_none(elem_type)),
])
}
}
}
}

Expand All @@ -37,13 +54,86 @@ impl ConstFold for PushFold {
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Value)],
) -> crate::extension::ConstFoldResult {
let [list, elem]: [&ops::Value; 2] = sorted_consts(consts).try_into().ok()?;
consts: &[(crate::IncomingPort, Value)],
) -> ConstFoldResult {
let [list, elem]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
list.0.push(elem.clone());

Some(vec![(0.into(), list.into())])
}
}

pub struct GetFold;

impl ConstFold for GetFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, index]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let index: &ConstUsize = index.get_custom_value().expect("Should be int value.");
let idx = index.value() as usize;

match list.0.get(idx) {
Some(elem) => Some(vec![(0.into(), const_some(elem.clone()))]),
None => Some(vec![(0.into(), const_none(list.1.clone()))]),
}
}
}

pub struct SetFold;

impl ConstFold for SetFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, idx, elem]: [&Value; 3] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");

let idx: &ConstUsize = idx.get_custom_value().expect("Should be int value.");
let idx = idx.value() as usize;

let mut list = list.clone();
let mut elem = elem.clone();
let res_elem: Value = match list.0.get_mut(idx) {
Some(old_elem) => {
std::mem::swap(old_elem, &mut elem);
const_ok(elem, list.1.clone())
}
None => const_err(list.1.clone(), elem),
};
Some(vec![(0.into(), list.into()), (1.into(), res_elem)])
}
}

pub struct InsertFold;

impl ConstFold for InsertFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, idx, elem]: [&Value; 3] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");

let idx: &ConstUsize = idx.get_custom_value().expect("Should be int value.");
let idx = idx.value() as usize;

let mut list = list.clone();
let elem = elem.clone();
let res_elem: Value = if list.0.len() > idx {
list.0.insert(idx, elem);
const_ok_tuple([], list.1.clone())
} else {
const_err(Type::UNIT, elem)
};
Some(vec![(0.into(), list.into()), (1.into(), res_elem)])
}
}

pub struct LengthFold;

impl ConstFold for LengthFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let len = list.0.len();

Some(vec![(0.into(), ConstUsize::new(len as u64).into())])
}
}
Loading

0 comments on commit 7ba766a

Please sign in to comment.