Skip to content

Commit

Permalink
Update rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
berkaysynnada committed Feb 27, 2024
1 parent db9664f commit 56e72d8
Showing 1 changed file with 33 additions and 160 deletions.
193 changes: 33 additions & 160 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,27 +95,6 @@ macro_rules! handle_transform_recursion_down {
};
}

/// This macro is used to determine continuation during combined transforming traversals.
///
/// After the bottom-up closure returns with [`Transformed`] depending on the returned
/// [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion
/// continuation and if [`TreeNodeRecursion`] state propagation is needed.
/// And then after recursing into children returns with [`Transformed`] depending on the
/// returned [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion
/// continuation and [`TreeNodeRecursion`] state propagation.
#[macro_export]
macro_rules! handle_transform_recursion {
($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => {
$F_DOWN?.try_transform_node_with(
|n| {
n.map_children($F_SELF)?
.try_transform_node_with($F_UP, Some(TreeNodeRecursion::Jump))
},
Some(TreeNodeRecursion::Continue),
)
};
}

/// This macro is used to determine continuation during bottom-up transforming traversals.
///
/// After recursing into children returns with [`Transformed`] depending on the returned
Expand Down Expand Up @@ -213,9 +192,34 @@ pub trait TreeNode: Sized {
self,
rewriter: &mut R,
) -> Result<Transformed<Self>> {
handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| {
rewriter.f_up(n)
})
let pre_visited = rewriter.f_down(self)?;
match pre_visited.tnr {
TreeNodeRecursion::Continue => {
let with_updated_children = pre_visited
.data
.map_children(|c| c.rewrite(rewriter))?
.try_transform_node_with(
|n| rewriter.f_up(n),
Some(TreeNodeRecursion::Jump),
)?;
Ok(Transformed {
transformed: with_updated_children.transformed
|| pre_visited.transformed,
..with_updated_children
})
}
TreeNodeRecursion::Jump => {
let pre_visited_transformed = pre_visited.transformed;
let post_visited = rewriter.f_up(pre_visited.data)?;

Ok(Transformed {
tnr: TreeNodeRecursion::Continue,
transformed: post_visited.transformed || pre_visited_transformed,
data: post_visited.data,
})
}
TreeNodeRecursion::Stop => Ok(pre_visited),
}
}

/// Applies `f` to the node and its children. `f` is applied in a preoder way,
Expand All @@ -232,41 +236,6 @@ pub trait TreeNode: Sized {
self.apply_children(&mut |n| n.apply(f))
}

/// Transforms the tree using `f_down` while traversing the tree top-down
/// (pre-preorder) and using `f_up` while traversing the tree bottom-up (post-order).
///
/// E.g. for an tree such as:
/// ```text
/// ParentNode
/// left: ChildNode1
/// right: ChildNode2
/// ```
///
/// The nodes are visited using the following order:
/// ```text
/// f_down(ParentNode)
/// f_down(ChildNode1)
/// f_up(ChildNode1)
/// f_down(ChildNode2)
/// f_up(ChildNode2)
/// f_up(ParentNode)
/// ```
///
/// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled.
///
/// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately.
fn transform<FD, FU>(
self,
f_down: &mut FD,
f_up: &mut FU,
) -> Result<Transformed<Self>>
where
FD: FnMut(Self) -> Result<Transformed<Self>>,
FU: FnMut(Self) -> Result<Transformed<Self>>,
{
handle_transform_recursion!(f_down(self), |c| c.transform(f_down, f_up), f_up)
}

/// Convenience utils for writing optimizers rule: recursively apply the given 'f' to the node and all of its
/// children(Preorder Traversal).
/// When the `f` does not apply to a given node, it is left unchanged.
Expand Down Expand Up @@ -390,7 +359,9 @@ pub enum TreeNodeRecursion {
/// In bottom-up traversals, bypass calling bottom-up closures till the next leaf node.
///
/// In combined traversals, if it is "f_down" (pre-order) phase, execution "jumps" to
/// next "f_up" (post_order) phase, or vice versa.
/// next "f_up" (post_order) phase by shortcutting its children. If it is "f_up" (pre-order)
/// phase, execution "jumps" to next "f_down" (pre_order) phase by shortcutting its parent
/// nodes until the first parent node having unvisited children path.
Jump,

/// Stop recursion.
Expand Down Expand Up @@ -814,21 +785,6 @@ mod tests {
.collect()
}

fn f_down_jump_on_a_transformed_tree() -> TestTreeNode<String> {
let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string());
let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string());
let node_c =
TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
let node_f =
TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string());
let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string());
TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string())
}

fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode<String> {
let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string());
let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string());
Expand Down Expand Up @@ -868,7 +824,7 @@ mod tests {
let node_b = TestTreeNode::new(vec![], "b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string());
let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string());
let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string());
let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string());
let node_f =
Expand Down Expand Up @@ -1314,18 +1270,6 @@ mod tests {
};
}

macro_rules! transform_test {
($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => {
#[test]
fn $NAME() -> Result<()> {
let tree = test_tree();
assert_eq!(tree.transform(&mut $F_DOWN, &mut $F_UP,)?, $EXPECTED_TREE);

Ok(())
}
};
}

macro_rules! transform_down_test {
($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => {
#[test]
Expand Down Expand Up @@ -1432,7 +1376,7 @@ mod tests {
test_rewrite_f_down_jump_on_a,
transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
transform_yes("f_up"),
Transformed::yes(f_down_jump_on_a_transformed_tree())
Transformed::yes(transformed_tree())
);
rewrite_test!(
test_rewrite_f_down_jump_on_e,
Expand Down Expand Up @@ -1493,77 +1437,6 @@ mod tests {
)
);

transform_test!(
test_transform,
transform_yes("f_down"),
transform_yes("f_up"),
Transformed::yes(transformed_tree())
);
transform_test!(
test_transform_f_down_jump_on_a,
transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump),
transform_yes("f_up"),
Transformed::yes(f_down_jump_on_a_transformed_tree())
);
transform_test!(
test_transform_f_down_jump_on_e,
transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump),
transform_yes("f_up"),
Transformed::yes(f_down_jump_on_e_transformed_tree())
);
transform_test!(
test_transform_f_up_jump_on_a,
transform_yes("f_down"),
transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump),
Transformed::yes(f_up_jump_on_a_transformed_tree())
);
transform_test!(
test_transform_f_up_jump_on_e,
transform_yes("f_down"),
transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump),
Transformed::yes(f_up_jump_on_e_transformed_tree())
);
transform_test!(
test_transform_f_down_stop_on_a,
transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop),
transform_yes("f_up"),
Transformed::new(
f_down_stop_on_a_transformed_tree(),
true,
TreeNodeRecursion::Stop
)
);
transform_test!(
test_transform_f_down_stop_on_e,
transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop),
transform_yes("f_up"),
Transformed::new(
f_down_stop_on_e_transformed_tree(),
true,
TreeNodeRecursion::Stop
)
);
transform_test!(
test_transform_f_up_stop_on_a,
transform_yes("f_down"),
transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop),
Transformed::new(
f_up_stop_on_a_transformed_tree(),
true,
TreeNodeRecursion::Stop
)
);
transform_test!(
test_transform_f_up_stop_on_e,
transform_yes("f_down"),
transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop),
Transformed::new(
f_up_stop_on_e_transformed_tree(),
true,
TreeNodeRecursion::Stop
)
);

transform_down_test!(
test_transform_down,
transform_yes("f_down"),
Expand Down

0 comments on commit 56e72d8

Please sign in to comment.