Skip to content

Commit

Permalink
feat: Add more list operations
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Aug 27, 2024
1 parent 069abd7 commit 43c728f
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 12 deletions.
37 changes: 33 additions & 4 deletions hugr-core/src/std_extensions/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::extension::prelude::USIZE_T;
use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE};
use crate::ops::constant::ValueName;
use crate::ops::{OpName, Value};
use crate::std_extensions::arithmetic::int_types;
use crate::types::{TypeName, TypeRowRV};
use crate::{
extension::{
Expand Down Expand Up @@ -114,6 +116,20 @@ pub enum ListOp {
pop,
/// Push to end of list
push,
/// Concatenate two lists
concat,
/// Lookup an element in a list by index
lookup,
/// Get the length of a list
length,
/// Repeat a list a number of times
repeat,
/// Count the number of times an element appears in a list
count,
/// Reverse the elements in a list
reverse,
// TODO: `find`. It requires an optional result type.
// https://github.com/CQCL/hugr/issues/1473
}

impl ListOp {
Expand All @@ -134,10 +150,16 @@ impl ListOp {
let e = Type::new_var_use(0, TypeBound::Any);
let l = self.list_type(list_type_def, 0);
match self {
pop => self
.list_polytype(vec![l.clone()], vec![l.clone(), e.clone()])
.into(),
pop => self.list_polytype(vec![l.clone()], vec![l, e]).into(),
push => self.list_polytype(vec![l.clone(), e], vec![l]).into(),
concat => self
.list_polytype(vec![l.clone(), l.clone()], vec![l])
.into(),
lookup => self.list_polytype(vec![l, USIZE_T], vec![e]).into(),
length => self.list_polytype(vec![l], vec![USIZE_T]).into(),
repeat => self.list_polytype(vec![l.clone(), USIZE_T], vec![l]).into(),
count => self.list_polytype(vec![l, e], vec![USIZE_T]).into(),
reverse => self.list_polytype(vec![l.clone()], vec![l]).into(),
}
}

Expand Down Expand Up @@ -193,6 +215,12 @@ impl MakeOpDef for ListOp {
match self {
pop => "Pop from back of list",
push => "Push to back of list",
concat => "Concatenate two lists",
lookup => "Lookup an element in a list by index",
length => "Get the length of a list",
repeat => "Repeat a list a number of times",
count => "Count the number of times an element appears in a list",
reverse => "Reverse the elements in a list",
}
.into()
}
Expand All @@ -206,7 +234,8 @@ lazy_static! {
/// Extension for list operations.
pub static ref EXTENSION: Extension = {
println!("creating collections extension");
let mut extension = Extension::new(EXTENSION_ID, VERSION);
let mut extension = Extension::new(EXTENSION_ID, VERSION)
.with_reqs(int_types::EXTENSION_ID);

// The list type must be defined before the operations are added.
extension.add_type(
Expand Down
122 changes: 114 additions & 8 deletions hugr-core/src/std_extensions/collections/list_fold.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
//! Folding definitions for list operations.
use crate::extension::{ConstFold, OpDef};
use crate::ops;
use crate::extension::prelude::ConstUsize;
use crate::extension::{ConstFold, ConstFoldResult, OpDef};
use crate::ops::Value;
use crate::types::type_param::TypeArg;
use crate::utils::sorted_consts;
use crate::IncomingPort;

use super::{ListOp, ListValue};

pub(super) fn set_fold(op: &ListOp, def: &mut OpDef) {
match op {
ListOp::pop => def.set_constant_folder(PopFold),
ListOp::push => def.set_constant_folder(PushFold),
ListOp::concat => def.set_constant_folder(ConcatFold),
ListOp::lookup => def.set_constant_folder(LookupFold),
ListOp::length => def.set_constant_folder(LengthFold),
ListOp::repeat => def.set_constant_folder(RepeatFold),
ListOp::count => def.set_constant_folder(CountFold),
ListOp::reverse => def.set_constant_folder(ReverseFold),
}
}

Expand All @@ -20,9 +28,9 @@ impl ConstFold for PopFold {
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Value)],
) -> crate::extension::ConstFoldResult {
let [list]: [&ops::Value; 1] = sorted_consts(consts).try_into().ok()?;
consts: &[(crate::IncomingPort, Value)],
) -> ConstFoldResult {
let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
let elem = list.0.pop()?; // empty list fails to evaluate "pop"
Expand All @@ -37,13 +45,111 @@ impl ConstFold for PushFold {
fn fold(
&self,
_type_args: &[TypeArg],
consts: &[(crate::IncomingPort, ops::Value)],
) -> crate::extension::ConstFoldResult {
let [list, elem]: [&ops::Value; 2] = sorted_consts(consts).try_into().ok()?;
consts: &[(crate::IncomingPort, Value)],
) -> ConstFoldResult {
let [list, elem]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let mut list = list.clone();
list.0.push(elem.clone());

Some(vec![(0.into(), list.into())])
}
}

pub struct ConcatFold;

impl ConstFold for ConcatFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list1, list2]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list1: &ListValue = list1.get_custom_value().expect("Should be list value.");
let list2: &ListValue = list2.get_custom_value().expect("Should be list value.");
let mut list1 = list1.clone();
let mut list2 = list2.clone();
list1.0.append(&mut list2.0);

Some(vec![(0.into(), list1.into())])
}
}

pub struct LookupFold;

impl ConstFold for LookupFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, index]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let index: &ConstUsize = index.get_custom_value().expect("Should be int value.");
let idx = index.value() as usize;
let elem = list
.0
.get(idx)
.unwrap_or_else(|| panic!("Index {idx} out of bounds"));

Some(vec![(0.into(), elem.clone())])
}
}

pub struct LengthFold;

impl ConstFold for LengthFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let len = list.0.len();

Some(vec![(
0.into(),
Value::extension(ConstUsize::new(len as u64)),
)])
}
}

pub struct RepeatFold;

impl ConstFold for RepeatFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, count]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let count: &ConstUsize = count.get_custom_value().expect("Should be int value.");
let count = count.value() as usize;

let contents = list
.0
.iter()
.cycle()
.take(count * list.0.len())
.cloned()
.collect();
let new_list = ListValue(contents, list.1.clone());

Some(vec![(0.into(), new_list.into())])
}
}

pub struct CountFold;

impl ConstFold for CountFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list, elem]: [&Value; 2] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");
let count = list.0.iter().filter(|e| e == &elem).count();

Some(vec![(
0.into(),
Value::extension(ConstUsize::new(count as u64)),
)])
}
}

pub struct ReverseFold;

impl ConstFold for ReverseFold {
fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [list]: [&Value; 1] = sorted_consts(consts).try_into().ok()?;
let list: &ListValue = list.get_custom_value().expect("Should be list value.");

let contents = list.0.iter().rev().cloned().collect();
let new_list = ListValue(contents, list.1.clone());

Some(vec![(0.into(), new_list.into())])
}
}

0 comments on commit 43c728f

Please sign in to comment.