diff --git a/crates/graph_search/src/breadth_first_search.rs b/crates/graph_search/src/breadth_first_search.rs new file mode 100644 index 000000000..34d304f73 --- /dev/null +++ b/crates/graph_search/src/breadth_first_search.rs @@ -0,0 +1,29 @@ +use crate::search_index::SearchIndex; +use crate::search_iterator::SearchIterator; +use std::mem::swap; +use std::vec::IntoIter; + +pub(crate) struct BreadthFirstSearch { + stack_iterator: IntoIter, +} + +impl BreadthFirstSearch { + fn take_stack(stack: &mut Vec) -> Vec { + let mut res = Vec::::new(); + swap(&mut res, stack); + + res + } +} + +impl SearchIterator for BreadthFirstSearch { + fn new(stack: &mut Vec) -> Self { + Self { + stack_iterator: Self::take_stack(stack).into_iter(), + } + } + + fn next(&mut self) -> Option { + self.stack_iterator.next() + } +} diff --git a/crates/graph_search/src/depth_first_search.rs b/crates/graph_search/src/depth_first_search.rs new file mode 100644 index 000000000..787f1528d --- /dev/null +++ b/crates/graph_search/src/depth_first_search.rs @@ -0,0 +1,16 @@ +use crate::search_index::SearchIndex; +use crate::search_iterator::SearchIterator; + +pub(crate) struct DepthFirstSearch { + index: Option, +} + +impl SearchIterator for DepthFirstSearch { + fn new(stack: &mut Vec) -> Self { + Self { index: stack.pop() } + } + + fn next(&mut self) -> Option { + self.index.take() + } +} diff --git a/crates/graph_search/src/graph_search.rs b/crates/graph_search/src/graph_search.rs index 1b456d34e..c2026148c 100644 --- a/crates/graph_search/src/graph_search.rs +++ b/crates/graph_search/src/graph_search.rs @@ -1,11 +1,8 @@ -use std::mem::swap; - -use crate::graph_search_index::GraphSearchIndex; -use crate::search_control::SearchControl; +use crate::breadth_first_search::BreadthFirstSearch; +use crate::depth_first_search::DepthFirstSearch; use crate::search_handler::SearchHandler; -use agdb_bit_set::BitSet; +use crate::search_impl::SearchImpl; use agdb_graph::GraphData; -use agdb_graph::GraphEdgeIterator; use agdb_graph::GraphImpl; use agdb_graph::GraphIndex; @@ -14,9 +11,6 @@ where Data: GraphData, { pub(crate) graph: &'a GraphImpl, - pub(crate) stack: Vec, - pub(crate) visited: BitSet, - pub(crate) result: Vec, } impl<'a, Data> GraphSearch<'a, Data> @@ -28,121 +22,27 @@ where index: &GraphIndex, handler: &Handler, ) -> Vec { - self.clear(); - self.add_index_to_stack(index.clone(), 0); - self.search(handler); - - self.take_result() - } - - fn add_edges_to_stack(&mut self, edges: GraphEdgeIterator, distance: u64) { - for edge in edges { - self.stack.push(GraphSearchIndex { - index: edge.index(), - distance, - }); - } - } - - fn add_index_to_stack(&mut self, index: GraphIndex, distance: u64) { - if self.validate_index(&index) { - self.stack.push(GraphSearchIndex { index, distance }); - } - } - - fn clear(&mut self) { - self.result.clear(); - self.stack.clear(); - self.visited.clear(); - } - - fn expand_index(&mut self, index: &GraphSearchIndex) { - if let Some(node) = self.graph.node(&index.index) { - self.add_edges_to_stack(node.edge_from_iter(), index.distance + 1); - } else if let Some(edge) = self.graph.edge(&index.index) { - self.add_index_to_stack(edge.to_index(), index.distance + 1); - } - } - - fn process_index( - &mut self, - index: GraphSearchIndex, - handler: &Handler, - ) -> bool { - if !self.visit_index(&index) { - self.process_unvisited_index(index, handler) + if self.validate_index(index) { + SearchImpl::<'a, Data, BreadthFirstSearch>::new(self.graph, index.clone()) + .search(handler) } else { - true + vec![] } } - fn process_unvisited_index( + pub fn depth_first_search( &mut self, - index: GraphSearchIndex, + index: &GraphIndex, handler: &Handler, - ) -> bool { - let add_index; - let result; - - match handler.process(&index.index, &index.distance) { - SearchControl::Continue(add) => { - self.expand_index(&index); - add_index = add; - result = true; - } - SearchControl::Finish(add) => { - add_index = add; - result = false; - } - SearchControl::Stop(add) => { - add_index = add; - result = true; - } - } - - if add_index { - self.result.push(index.index); - } - - result - } - - fn process_stack(&mut self, handler: &Handler) -> bool { - for i in self.take_stack() { - if !self.process_index(i, handler) { - return false; - } + ) -> Vec { + if self.validate_index(index) { + SearchImpl::<'a, Data, DepthFirstSearch>::new(self.graph, index.clone()).search(handler) + } else { + vec![] } - - true - } - - fn search(&mut self, handler: &T) { - while !self.stack.is_empty() && self.process_stack(handler) {} - } - - fn take_result(&mut self) -> Vec { - let mut res = Vec::::new(); - swap(&mut res, &mut self.result); - - res - } - - fn take_stack(&mut self) -> Vec { - let mut res = Vec::::new(); - swap(&mut res, &mut self.stack); - - res } fn validate_index(&self, index: &GraphIndex) -> bool { self.graph.node(index).is_some() || self.graph.edge(index).is_some() } - - fn visit_index(&mut self, index: &GraphSearchIndex) -> bool { - let visited = self.visited.value(index.index.as_u64()); - self.visited.insert(index.index.as_u64()); - - visited - } } diff --git a/crates/graph_search/src/graph_search_from.rs b/crates/graph_search/src/graph_search_from.rs index 926fb9152..7569bbc16 100644 --- a/crates/graph_search/src/graph_search_from.rs +++ b/crates/graph_search/src/graph_search_from.rs @@ -1,5 +1,4 @@ use crate::graph_search::GraphSearch; -use agdb_bit_set::BitSet; use agdb_graph::GraphData; use agdb_graph::GraphImpl; @@ -8,11 +7,6 @@ where Data: GraphData, { fn from(graph: &'a GraphImpl) -> Self { - GraphSearch { - graph, - visited: BitSet::new(), - stack: vec![], - result: vec![], - } + GraphSearch { graph } } } diff --git a/crates/graph_search/src/lib.rs b/crates/graph_search/src/lib.rs index ce98169e8..58c2b1efc 100644 --- a/crates/graph_search/src/lib.rs +++ b/crates/graph_search/src/lib.rs @@ -1,8 +1,12 @@ +mod breadth_first_search; +mod depth_first_search; mod graph_search; mod graph_search_from; -mod graph_search_index; mod search_control; mod search_handler; +mod search_impl; +mod search_index; +mod search_iterator; pub use graph_search::GraphSearch; pub use search_control::SearchControl; diff --git a/crates/graph_search/src/search_impl.rs b/crates/graph_search/src/search_impl.rs new file mode 100644 index 000000000..f5baecdae --- /dev/null +++ b/crates/graph_search/src/search_impl.rs @@ -0,0 +1,135 @@ +use crate::search_index::SearchIndex; +use crate::search_iterator::SearchIterator; +use crate::SearchControl; +use crate::SearchHandler; +use agdb_bit_set::BitSet; +use agdb_graph::GraphData; +use agdb_graph::GraphEdgeIterator; +use agdb_graph::GraphImpl; +use agdb_graph::GraphIndex; +use std::marker::PhantomData; +use std::mem::swap; + +pub(crate) struct SearchImpl<'a, Data, SearchIt> +where + Data: GraphData, + SearchIt: SearchIterator, +{ + pub(crate) graph: &'a GraphImpl, + pub(crate) stack: Vec, + pub(crate) visited: BitSet, + pub(crate) result: Vec, + pub(crate) algorithm: PhantomData, +} + +impl<'a, Data, SearchIt> SearchImpl<'a, Data, SearchIt> +where + Data: GraphData, + SearchIt: SearchIterator, +{ + pub(crate) fn new(graph: &'a GraphImpl, index: GraphIndex) -> Self { + Self { + graph, + stack: vec![SearchIndex { index, distance: 0 }], + visited: BitSet::new(), + result: vec![], + algorithm: PhantomData, + } + } + + pub(crate) fn search(&mut self, handler: &Handler) -> Vec { + while !self.stack.is_empty() && self.process_stack(handler) {} + + self.take_result() + } + + fn add_edges_to_stack(&mut self, edges: GraphEdgeIterator, distance: u64) { + for edge in edges { + self.stack.push(SearchIndex { + index: edge.index(), + distance, + }); + } + } + + fn add_index_to_stack(&mut self, index: GraphIndex, distance: u64) { + self.stack.push(SearchIndex { index, distance }); + } + + fn expand_index(&mut self, index: &SearchIndex) { + if let Some(node) = self.graph.node(&index.index) { + self.add_edges_to_stack(node.edge_from_iter(), index.distance + 1); + } else if let Some(edge) = self.graph.edge(&index.index) { + self.add_index_to_stack(edge.to_index(), index.distance + 1); + } + } + + fn process_index( + &mut self, + index: SearchIndex, + handler: &Handler, + ) -> bool { + if !self.visit_index(&index) { + self.process_unvisited_index(index, handler) + } else { + true + } + } + + fn process_stack(&mut self, handler: &Handler) -> bool { + let mut it = SearchIt::new(&mut self.stack); + + while let Some(i) = it.next() { + if !self.process_index(i, handler) { + return false; + } + } + + true + } + + fn process_unvisited_index( + &mut self, + index: SearchIndex, + handler: &Handler, + ) -> bool { + let add_index; + let result; + + match handler.process(&index.index, &index.distance) { + SearchControl::Continue(add) => { + self.expand_index(&index); + add_index = add; + result = true; + } + SearchControl::Finish(add) => { + add_index = add; + result = false; + } + SearchControl::Stop(add) => { + add_index = add; + result = true; + } + } + + if add_index { + self.result.push(index.index); + } + + result + } + + fn take_result(&mut self) -> Vec { + let mut res = Vec::::new(); + swap(&mut res, &mut self.result); + + res + } + + fn visit_index(&mut self, index: &SearchIndex) -> bool { + let visited = self.visited.value(index.index.as_u64()); + self.visited.insert(index.index.as_u64()); + + visited + } +} diff --git a/crates/graph_search/src/graph_search_index.rs b/crates/graph_search/src/search_index.rs similarity index 71% rename from crates/graph_search/src/graph_search_index.rs rename to crates/graph_search/src/search_index.rs index 4d96c1926..c15c12f19 100644 --- a/crates/graph_search/src/graph_search_index.rs +++ b/crates/graph_search/src/search_index.rs @@ -1,6 +1,6 @@ use agdb_graph::GraphIndex; -pub(crate) struct GraphSearchIndex { +pub(crate) struct SearchIndex { pub(crate) index: GraphIndex, pub(crate) distance: u64, } diff --git a/crates/graph_search/src/search_iterator.rs b/crates/graph_search/src/search_iterator.rs new file mode 100644 index 000000000..ec0812fa8 --- /dev/null +++ b/crates/graph_search/src/search_iterator.rs @@ -0,0 +1,6 @@ +use crate::search_index::SearchIndex; + +pub(crate) trait SearchIterator { + fn new(stack: &mut Vec) -> Self; + fn next(&mut self) -> Option; +} diff --git a/crates/graph_search/tests/breadth_first_search_test.rs b/crates/graph_search/tests/breadth_first_search_test.rs index 796cfb392..afc123a80 100644 --- a/crates/graph_search/tests/breadth_first_search_test.rs +++ b/crates/graph_search/tests/breadth_first_search_test.rs @@ -164,8 +164,9 @@ fn stop_at_distance() { let node3 = graph.insert_node().unwrap(); let edge1 = graph.insert_edge(&node1, &node2).unwrap(); - let _edge2 = graph.insert_edge(&node2, &node3).unwrap(); - let _edge3 = graph.insert_edge(&node3, &node1).unwrap(); + let edge2 = graph.insert_edge(&node1, &node2).unwrap(); + let _edge3 = graph.insert_edge(&node2, &node3).unwrap(); + let _edge4 = graph.insert_edge(&node3, &node1).unwrap(); let result = GraphSearch::from(&graph).breadth_first_search( &node1, @@ -180,5 +181,5 @@ fn stop_at_distance() { }, ); - assert_eq!(result, vec![node1, edge1, node2]); + assert_eq!(result, vec![node1, edge2, edge1, node2]); } diff --git a/crates/graph_search/tests/depth_first_search_test.rs b/crates/graph_search/tests/depth_first_search_test.rs new file mode 100644 index 000000000..84b7d0789 --- /dev/null +++ b/crates/graph_search/tests/depth_first_search_test.rs @@ -0,0 +1,185 @@ +use agdb_graph::Graph; +use agdb_graph::GraphIndex; +use agdb_graph_search::GraphSearch; +use agdb_graph_search::SearchControl; +use agdb_graph_search::SearchHandler; + +struct Handler { + pub processor: fn(&GraphIndex, &u64) -> SearchControl, +} + +impl Default for Handler { + fn default() -> Self { + Self { + processor: |_index: &GraphIndex, _distance: &u64| SearchControl::Continue(true), + } + } +} + +impl SearchHandler for Handler { + fn process(&self, index: &GraphIndex, distance: &u64) -> SearchControl { + (self.processor)(index, distance) + } +} + +#[test] +fn empty_graph() { + let graph = Graph::new(); + + let result = + GraphSearch::from(&graph).depth_first_search(&GraphIndex::default(), &Handler::default()); + + assert_eq!(result, vec![]); +} + +#[test] +fn cyclic_graph() { + let mut graph = Graph::new(); + + let node1 = graph.insert_node().unwrap(); + let node2 = graph.insert_node().unwrap(); + let node3 = graph.insert_node().unwrap(); + + let edge1 = graph.insert_edge(&node1, &node2).unwrap(); + let edge2 = graph.insert_edge(&node1, &node2).unwrap(); + let edge3 = graph.insert_edge(&node2, &node3).unwrap(); + let edge4 = graph.insert_edge(&node2, &node3).unwrap(); + let edge5 = graph.insert_edge(&node3, &node1).unwrap(); + let edge6 = graph.insert_edge(&node3, &node1).unwrap(); + + let result = GraphSearch::from(&graph).depth_first_search(&node1, &Handler::default()); + + assert_eq!( + result, + vec![node1, edge1, node2, edge3, node3, edge5, edge6, edge4, edge2] + ); +} + +#[test] +fn full_search() { + let mut graph = Graph::new(); + + let node1 = graph.insert_node().unwrap(); + let node2 = graph.insert_node().unwrap(); + let node3 = graph.insert_node().unwrap(); + let node4 = graph.insert_node().unwrap(); + + let edge1 = graph.insert_edge(&node1, &node2).unwrap(); + let edge2 = graph.insert_edge(&node1, &node3).unwrap(); + let edge3 = graph.insert_edge(&node1, &node4).unwrap(); + + let result = GraphSearch::from(&graph).depth_first_search(&node1, &Handler::default()); + + assert_eq!( + result, + vec![node1, edge1, node2, edge2, node3, edge3, node4] + ); +} + +#[test] +fn filter_edges() { + let mut graph = Graph::new(); + + let node1 = graph.insert_node().unwrap(); + let node2 = graph.insert_node().unwrap(); + let node3 = graph.insert_node().unwrap(); + let node4 = graph.insert_node().unwrap(); + + graph.insert_edge(&node1, &node2).unwrap(); + graph.insert_edge(&node1, &node3).unwrap(); + graph.insert_edge(&node1, &node4).unwrap(); + + let result = GraphSearch::from(&graph).depth_first_search( + &node1, + &Handler { + processor: |index: &GraphIndex, _distance: &u64| { + SearchControl::Continue(index.is_node()) + }, + }, + ); + + assert_eq!(result, vec![node1, node2, node3, node4]); +} + +#[test] +fn finish_search() { + let mut graph = Graph::new(); + + let node1 = graph.insert_node().unwrap(); + let node2 = graph.insert_node().unwrap(); + let node3 = graph.insert_node().unwrap(); + + graph.insert_edge(&node1, &node2).unwrap(); + graph.insert_edge(&node1, &node2).unwrap(); + graph.insert_edge(&node2, &node3).unwrap(); + graph.insert_edge(&node2, &node3).unwrap(); + graph.insert_edge(&node3, &node1).unwrap(); + graph.insert_edge(&node3, &node1).unwrap(); + + let result = GraphSearch::from(&graph).depth_first_search( + &node1, + &Handler { + processor: |index: &GraphIndex, _distance: &u64| { + if index.value() == 2 { + SearchControl::Finish(true) + } else { + SearchControl::Continue(false) + } + }, + }, + ); + + assert_eq!(result, vec![node2]); +} + +#[test] +fn search_twice() { + let mut graph = Graph::new(); + + let node1 = graph.insert_node().unwrap(); + let node2 = graph.insert_node().unwrap(); + let node3 = graph.insert_node().unwrap(); + let node4 = graph.insert_node().unwrap(); + + let edge1 = graph.insert_edge(&node1, &node2).unwrap(); + let edge2 = graph.insert_edge(&node1, &node3).unwrap(); + let edge3 = graph.insert_edge(&node1, &node4).unwrap(); + + let mut result = GraphSearch::from(&graph).depth_first_search(&node1, &Handler::default()); + let expected = vec![node1.clone(), edge1, node2, edge2, node3, edge3, node4]; + + assert_eq!(result, expected); + + result = GraphSearch::from(&graph).depth_first_search(&node1, &Handler::default()); + + assert_eq!(result, expected); +} + +#[test] +fn stop_at_distance() { + let mut graph = Graph::new(); + + let node1 = graph.insert_node().unwrap(); + let node2 = graph.insert_node().unwrap(); + let node3 = graph.insert_node().unwrap(); + + let edge1 = graph.insert_edge(&node1, &node2).unwrap(); + let edge2 = graph.insert_edge(&node1, &node2).unwrap(); + let _edge3 = graph.insert_edge(&node2, &node3).unwrap(); + let _edge4 = graph.insert_edge(&node3, &node1).unwrap(); + + let result = GraphSearch::from(&graph).depth_first_search( + &node1, + &Handler { + processor: |_index: &GraphIndex, distance: &u64| { + if *distance == 2 { + SearchControl::Stop(true) + } else { + SearchControl::Continue(true) + } + }, + }, + ); + + assert_eq!(result, vec![node1, edge1, node2, edge2]); +}