Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Add more list operations #1474

Merged
merged 3 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,30 +458,27 @@ mod test {
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle};
use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG};
use crate::std_extensions::collections;
use crate::types::{Signature, Type, TypeArg, TypeRow};
use crate::std_extensions::collections::{self, list_type, ListOp};
use crate::types::{Signature, Type, TypeRow};
use crate::utils::depth;
use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort};

use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement};

#[test]
#[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-'
fn cfg() -> Result<(), Box<dyn std::error::Error>> {
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()])
.unwrap();
let listy = Type::new_extension(
collections::EXTENSION
.get_type(&collections::LIST_TYPENAME)
.unwrap()
.instantiate([TypeArg::Type { ty: USIZE_T }])
.unwrap(),
);
let pop: ExtensionOp = collections::EXTENSION
.instantiate_extension_op("pop", [TypeArg::Type { ty: USIZE_T }], &reg)
let listy = list_type(USIZE_T);
let pop: ExtensionOp = ListOp::pop
.with_type(USIZE_T)
.to_extension_op(&reg)
.unwrap();
let push: ExtensionOp = collections::EXTENSION
.instantiate_extension_op("push", [TypeArg::Type { ty: USIZE_T }], &reg)
let push: ExtensionOp = ListOp::push
.with_type(USIZE_T)
.to_extension_op(&reg)
.unwrap();
let just_list = TypeRow::from(vec![listy.clone()]);
let intermed = TypeRow::from(vec![listy.clone(), USIZE_T]);
Expand Down
134 changes: 126 additions & 8 deletions hugr-core/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::extension::prelude::{either_type, option_type, USIZE_T};
use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE};
use crate::ops::constant::ValueName;
Expand Down Expand Up @@ -110,17 +111,30 @@ impl CustomConst for ListValue {
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum ListOp {
/// Pop from end of list
/// Pop from the end of list. Return an optional value.
pop,
/// Push to end of list
/// Push to end of list. Return the new list.
push,
/// Lookup an element in a list by index.
get,
/// Replace the element at index `i` with value `v`, and return the old value.
///
/// If the index is out of bounds, returns the input value as an error.
set,
/// Insert an element at index `i`.
///
/// Elements at higher indices are shifted one position to the right.
/// Returns an Err with the element if the index is out of bounds.
insert,
/// Get the length of a list.
length,
}

impl ListOp {
/// Type parameter used in the list types.
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };

/// Instantiate a list operation with an `element_type`
/// Instantiate a list operation with an `element_type`.
pub fn with_type(self, element_type: Type) -> ListOpInst {
ListOpInst {
elem_type: element_type,
Expand All @@ -135,9 +149,25 @@ impl ListOp {
let l = self.list_type(list_type_def, 0);
match self {
pop => self
.list_polytype(vec![l.clone()], vec![l.clone(), e.clone()])
.list_polytype(vec![l.clone()], vec![l, Type::from(option_type(e))])
.into(),
push => self.list_polytype(vec![l.clone(), e], vec![l]).into(),
get => self
.list_polytype(vec![l, USIZE_T], vec![Type::from(option_type(e))])
.into(),
set => self
.list_polytype(
vec![l.clone(), USIZE_T, e.clone()],
vec![l, Type::from(either_type(e.clone(), e))],
)
.into(),
insert => self
.list_polytype(
vec![l.clone(), USIZE_T, e.clone()],
vec![l, either_type(Type::UNIT, e).into()],
)
.into(),
length => self.list_polytype(vec![l], vec![USIZE_T]).into(),
}
}

Expand Down Expand Up @@ -191,8 +221,12 @@ impl MakeOpDef for ListOp {
use ListOp::*;

match self {
pop => "Pop from back of list",
push => "Push to back of list",
pop => "Pop from the back of list. Returns an optional value.",
push => "Push to the back of list",
get => "Lookup an element in a list by index. Panics if the index is out of bounds.",
set => "Replace the element at index `i` with value `v`.",
insert => "Insert an element at index `i`. Elements at higher indices are shifted one position to the right. Panics if the index is out of bounds.",
length => "Get the length of a list",
}
.into()
}
Expand All @@ -205,7 +239,6 @@ impl MakeOpDef for ListOp {
lazy_static! {
/// Extension for list operations.
pub static ref EXTENSION: Extension = {
println!("creating collections extension");
let mut extension = Extension::new(EXTENSION_ID, VERSION);

// The list type must be defined before the operations are added.
Expand Down Expand Up @@ -323,7 +356,13 @@ impl ListOpInst {

#[cfg(test)]
mod test {
use rstest::rstest;

use crate::extension::prelude::{
const_left_tuple, const_none, const_right_tuple, const_some_tuple,
};
use crate::ops::OpTrait;
use crate::PortIndex;
use crate::{
extension::{
prelude::{ConstUsize, QB_T, USIZE_T},
Expand Down Expand Up @@ -378,7 +417,7 @@ mod test {

let list_t = list_type(QB_T);

let both_row: TypeRow = vec![list_t.clone(), QB_T].into();
let both_row: TypeRow = vec![list_t.clone(), option_type(QB_T).into()].into();
let just_list_row: TypeRow = vec![list_t].into();
assert_eq!(pop_sig.input(), &just_list_row);
assert_eq!(pop_sig.output(), &both_row);
Expand All @@ -396,4 +435,83 @@ mod test {
assert_eq!(push_sig.input(), &both_row);
assert_eq!(push_sig.output(), &just_list_row);
}

/// Values used in the `list_fold` test cases.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TestVal {
Idx(usize),
List(Vec<usize>),
Elem(usize),
Some(Vec<TestVal>),
None(TypeRow),
Ok(Vec<TestVal>, TypeRow),
Err(TypeRow, Vec<TestVal>),
}

impl TestVal {
fn to_value(&self) -> Value {
match self {
TestVal::Idx(i) => Value::extension(ConstUsize::new(*i as u64)),
TestVal::Elem(e) => Value::extension(ConstUsize::new(*e as u64)),
TestVal::List(l) => {
let elems = l
.iter()
.map(|&i| Value::extension(ConstUsize::new(i as u64)))
.collect();
Value::extension(ListValue(elems, USIZE_T))
}
TestVal::Some(l) => {
let elems = l.iter().map(TestVal::to_value);
const_some_tuple(elems)
}
TestVal::None(tr) => const_none(tr.clone()),
TestVal::Ok(l, tr) => {
let elems = l.iter().map(TestVal::to_value);
const_left_tuple(elems, tr.clone())
}
TestVal::Err(tr, l) => {
let elems = l.iter().map(TestVal::to_value);
const_right_tuple(tr.clone(), elems)
}
}
}
}

#[rstest]
#[case::pop(ListOp::pop, &[TestVal::List(vec![77,88, 42])], &[TestVal::List(vec![77,88]), TestVal::Some(vec![TestVal::Elem(42)])])]
#[case::pop_empty(ListOp::pop, &[TestVal::List(vec![])], &[TestVal::List(vec![]), TestVal::None(vec![USIZE_T].into())])]
#[case::push(ListOp::push, &[TestVal::List(vec![77,88]), TestVal::Elem(42)], &[TestVal::List(vec![77,88,42])])]
#[case::set(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,42]), TestVal::Ok(vec![TestVal::Elem(88)], vec![USIZE_T].into())])]
#[case::set_invalid(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(123), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(vec![USIZE_T].into(), vec![TestVal::Elem(99)])])]
#[case::get(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1)], &[TestVal::Some(vec![TestVal::Elem(88)])])]
#[case::get_invalid(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(99)], &[TestVal::None(vec![USIZE_T].into())])]
#[case::insert(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,88,42]), TestVal::Ok(vec![], vec![USIZE_T].into())])]
#[case::insert_invalid(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(52), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(Type::UNIT.into(), vec![TestVal::Elem(99)])])]
#[case::length(ListOp::length, &[TestVal::List(vec![77,88,42])], &[TestVal::Elem(3)])]
fn list_fold(#[case] op: ListOp, #[case] inputs: &[TestVal], #[case] outputs: &[TestVal]) {
let consts: Vec<_> = inputs
.iter()
.enumerate()
.map(|(i, x)| (i.into(), x.to_value()))
.collect();

let res = op
.with_type(USIZE_T)
.to_extension_op(&COLLECTIONS_REGISTRY)
.unwrap()
.constant_fold(&consts)
.unwrap();

for (i, expected) in outputs.iter().enumerate() {
let expected = expected.to_value();
let res_val = res
.iter()
.find(|(port, _)| port.index() == i)
.unwrap()
.1
.clone();

assert_eq!(res_val, expected);
}
}
}
110 changes: 100 additions & 10 deletions hugr-core/src/std_extensions/collections/list_fold.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
//! Folding definitions for list operations.

use crate::extension::{ConstFold, OpDef};
use crate::ops;
use crate::extension::prelude::{
const_left, const_left_tuple, const_none, const_right, const_some, ConstUsize,
};
use crate::extension::{ConstFold, ConstFoldResult, OpDef};
use crate::ops::Value;
use crate::types::type_param::TypeArg;
use crate::types::Type;
use crate::utils::sorted_consts;
use crate::IncomingPort;

use super::{ListOp, ListValue};

pub(super) fn set_fold(op: &ListOp, def: &mut OpDef) {
match op {
ListOp::pop => def.set_constant_folder(PopFold),
ListOp::push => def.set_constant_folder(PushFold),
ListOp::get => def.set_constant_folder(GetFold),
ListOp::set => def.set_constant_folder(SetFold),
ListOp::insert => def.set_constant_folder(InsertFold),
ListOp::length => def.set_constant_folder(LengthFold),
}
}

Expand All @@ -20,14 +29,22 @@ impl ConstFold for PopFold {
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Value)],
) -> crate::extension::ConstFoldResult {
let [list]: [&ops::Value; 1] = sorted_consts(consts).try_into().ok()?;
consts: &[(crate::IncomingPort, Value)],
) -> ConstFoldResult {
let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
let elem = list.0.pop()?; // empty list fails to evaluate "pop"

Some(vec![(0.into(), list.into()), (1.into(), elem)])
match list.0.pop() {
Some(elem) => Some(vec![(0.into(), list.into()), (1.into(), const_some(elem))]),
None => {
let elem_type = list.1.clone();
Some(vec![
(0.into(), list.into()),
(1.into(), const_none(elem_type)),
])
}
}
}
}

Expand All @@ -37,13 +54,86 @@ impl ConstFold for PushFold {
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Value)],
) -> crate::extension::ConstFoldResult {
let [list, elem]: [&ops::Value; 2] = sorted_consts(consts).try_into().ok()?;
consts: &[(crate::IncomingPort, Value)],
) -> ConstFoldResult {
let [list, elem]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
list.0.push(elem.clone());

Some(vec![(0.into(), list.into())])
}
}

pub struct GetFold;

impl ConstFold for GetFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, index]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let index: &ConstUsize = index.get_custom_value().expect("Should be int value.");
let idx = index.value() as usize;

match list.0.get(idx) {
Some(elem) => Some(vec![(0.into(), const_some(elem.clone()))]),
None => Some(vec![(0.into(), const_none(list.1.clone()))]),
}
}
}

pub struct SetFold;

impl ConstFold for SetFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, idx, elem]: [&Value; 3] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");

let idx: &ConstUsize = idx.get_custom_value().expect("Should be int value.");
let idx = idx.value() as usize;

let mut list = list.clone();
let mut elem = elem.clone();
let res_elem: Value = match list.0.get_mut(idx) {
Some(old_elem) => {
std::mem::swap(old_elem, &mut elem);
const_left(elem, list.1.clone())
}
None => const_right(list.1.clone(), elem),
};
Some(vec![(0.into(), list.into()), (1.into(), res_elem)])
}
}

pub struct InsertFold;

impl ConstFold for InsertFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, idx, elem]: [&Value; 3] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");

let idx: &ConstUsize = idx.get_custom_value().expect("Should be int value.");
let idx = idx.value() as usize;

let mut list = list.clone();
let elem = elem.clone();
let res_elem: Value = if list.0.len() > idx {
list.0.insert(idx, elem);
const_left_tuple([], list.1.clone())
} else {
const_right(Type::UNIT, elem)
};
Some(vec![(0.into(), list.into()), (1.into(), res_elem)])
}
}

pub struct LengthFold;

impl ConstFold for LengthFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let len = list.0.len();

Some(vec![(0.into(), ConstUsize::new(len as u64).into())])
}
}
Loading
Loading