diff --git a/rustworkx-core/src/generators/full_rary_tree_graph.rs b/rustworkx-core/src/generators/full_rary_tree_graph.rs new file mode 100644 index 0000000000..181491ddbe --- /dev/null +++ b/rustworkx-core/src/generators/full_rary_tree_graph.rs @@ -0,0 +1,168 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +use std::collections::VecDeque; +use std::hash::Hash; + +use petgraph::data::{Build, Create}; +use petgraph::visit::{Data, NodeIndexable}; + +use super::InvalidInputError; + +/// Creates a full r-ary tree of `n` nodes. +/// Sometimes called a k-ary, n-ary, or m-ary tree. +/// +/// * `branching factor` - The number of children at each node. +/// * `num_nodes` - The number of nodes in the graph. +/// * `weights` - A list of node weights. If the number of weights is +/// less than n, extra nodes with with None weight will be appended. +/// * `default_node_weight` - A callable that will return the weight to use +/// for newly created nodes. This is ignored if `weights` is specified, +/// as the weights from that argument will be used instead. +/// * `default_edge_weight` - A callable that will return the weight object +/// to use for newly created edges. +/// * `bidirectional` - Whether edges are added bidirectionally, if set to +/// `true` then for any edge `(u, v)` an edge `(v, u)` will also be added. +/// If the graph is undirected this will result in a pallel edge. +/// +/// # Example +/// ```rust +/// use rustworkx_core::petgraph; +/// use rustworkx_core::generators::full_rary_tree_graph; +/// use rustworkx_core::petgraph::visit::EdgeRef; +/// +/// let expected_edge_list = vec![ +/// (0, 1), +/// (0, 2), +/// (1, 3), +/// (1, 4), +/// (2, 5), +/// (2, 6), +/// (3, 7), +/// (3, 8), +/// (4, 9), +/// ]; +/// let g: petgraph::graph::UnGraph<(), ()> = full_rary_tree_graph( +/// 2, +/// 10, +/// None, +/// || {()}, +/// || {()}, +/// ).unwrap(); +/// assert_eq!( +/// expected_edge_list, +/// g.edge_references() +/// .map(|edge| (edge.source().index(), edge.target().index())) +/// .collect::>(), +/// ) +/// ``` +pub fn full_rary_tree_graph( + branching_factor: usize, + num_nodes: usize, + weights: Option>, + mut default_node_weight: F, + mut default_edge_weight: H, +) -> Result +where + G: Build + Create + Data + NodeIndexable, + F: FnMut() -> T, + H: FnMut() -> M, + G::NodeId: Eq + Hash, +{ + if let Some(wt) = weights.as_ref() { + if wt.len() > num_nodes { + return Err(InvalidInputError {}); + } + } + let mut graph = G::with_capacity(num_nodes, num_nodes * branching_factor); + + let nodes: Vec = match weights { + Some(weights) => { + let mut node_list: Vec = Vec::with_capacity(num_nodes); + let node_count = num_nodes - weights.len(); + for weight in weights { + let index = graph.add_node(weight); + node_list.push(index); + } + for _ in 0..node_count { + let index = graph.add_node(default_node_weight()); + node_list.push(index); + } + node_list + } + None => (0..num_nodes) + .map(|_| graph.add_node(default_node_weight())) + .collect(), + }; + if !nodes.is_empty() { + let mut parents = VecDeque::from(vec![graph.to_index(nodes[0])]); + let mut nod_it: usize = 1; + + while !parents.is_empty() { + let source: usize = parents.pop_front().unwrap(); //If is empty it will never try to pop + for _ in 0..branching_factor { + if nod_it < num_nodes { + let target: usize = graph.to_index(nodes[nod_it]); + parents.push_back(target); + nod_it += 1; + graph.add_edge(nodes[source], nodes[target], default_edge_weight()); + } + } + } + } + Ok(graph) +} + +#[cfg(test)] +mod tests { + use crate::generators::full_rary_tree_graph; + use crate::generators::InvalidInputError; + use crate::petgraph; + use crate::petgraph::visit::EdgeRef; + + #[test] + fn test_full_rary_graph() { + let expected_edge_list = vec![ + (0, 1), + (0, 2), + (1, 3), + (1, 4), + (2, 5), + (2, 6), + (3, 7), + (3, 8), + (4, 9), + ]; + let g: petgraph::graph::UnGraph<(), ()> = + full_rary_tree_graph(2, 10, None, || (), || ()).unwrap(); + assert_eq!( + expected_edge_list, + g.edge_references() + .map(|edge| (edge.source().index(), edge.target().index())) + .collect::>(), + ); + } + + #[test] + fn test_full_rary_error() { + match full_rary_tree_graph::, (), _, _, ()>( + 3, + 2, + Some(vec![(), (), (), ()]), + || (), + || (), + ) { + Ok(_) => panic!("Returned a non-error"), + Err(e) => assert_eq!(e, InvalidInputError), + }; + } +} diff --git a/rustworkx-core/src/generators/mod.rs b/rustworkx-core/src/generators/mod.rs index d6991c22a9..40237685d6 100644 --- a/rustworkx-core/src/generators/mod.rs +++ b/rustworkx-core/src/generators/mod.rs @@ -16,6 +16,7 @@ mod barbell_graph; mod binomial_tree_graph; mod complete_graph; mod cycle_graph; +mod full_rary_tree_graph; mod grid_graph; mod heavy_hex_graph; mod heavy_square_graph; @@ -46,6 +47,7 @@ pub use barbell_graph::barbell_graph; pub use binomial_tree_graph::binomial_tree_graph; pub use complete_graph::complete_graph; pub use cycle_graph::cycle_graph; +pub use full_rary_tree_graph::full_rary_tree_graph; pub use grid_graph::grid_graph; pub use heavy_hex_graph::heavy_hex_graph; pub use heavy_square_graph::heavy_square_graph; diff --git a/src/generators.rs b/src/generators.rs index 944256677f..18e378afbd 100644 --- a/src/generators.rs +++ b/src/generators.rs @@ -10,11 +10,9 @@ // License for the specific language governing permissions and limitations // under the License. -use std::collections::VecDeque; use std::iter; use petgraph::algo; -use petgraph::graph::NodeIndex; use petgraph::prelude::*; use petgraph::Undirected; @@ -719,9 +717,11 @@ pub fn directed_binomial_tree_graph( /// Creates a full r-ary tree of `n` nodes. /// Sometimes called a k-ary, n-ary, or m-ary tree. /// -/// :param int order: Order of the tree. +/// :param int branching factor: The number of children at each node. +/// :param int num_nodes: The number of nodes in the graph. /// :param list weights: A list of node weights. If the number of weights is -/// less than n, extra nodes with with None will be appended. +/// less than num_nodes, extra nodes with with None will be appended. The +/// number of weights cannot exceed num_nodes. /// :param bool multigraph: When set to False the output /// :class:`~rustworkx.PyGraph` object will not be not be a multigraph and /// won't allow parallel edges to be added. Instead @@ -743,50 +743,26 @@ pub fn directed_binomial_tree_graph( #[pyo3(text_signature = "(branching_factor, num_nodes, /, weights=None, multigraph=True)")] pub fn full_rary_tree( py: Python, - branching_factor: u32, + branching_factor: usize, num_nodes: usize, weights: Option>, multigraph: bool, ) -> PyResult { - let mut graph = StablePyGraph::::default(); - - let nodes: Vec = match weights { - Some(weights) => { - let mut node_list: Vec = Vec::with_capacity(num_nodes); - if weights.len() > num_nodes { - return Err(PyIndexError::new_err("weights can't be greater than nodes")); - } - let node_count = num_nodes - weights.len(); - for weight in weights { - let index = graph.add_node(weight); - node_list.push(index); - } - for _ in 0..node_count { - let index = graph.add_node(py.None()); - node_list.push(index); - } - node_list + let default_fn = || py.None(); + let graph: StablePyGraph = match core_generators::full_rary_tree_graph( + branching_factor, + num_nodes, + weights, + default_fn, + default_fn, + ) { + Ok(graph) => graph, + Err(_) => { + return Err(PyIndexError::new_err( + "The number of weights cannot exceed num_nodes.", + )) } - None => (0..num_nodes).map(|_| graph.add_node(py.None())).collect(), }; - - if num_nodes > 0 { - let mut parents = VecDeque::from(vec![nodes[0].index()]); - let mut nod_it: usize = 1; - - while !parents.is_empty() { - let source: usize = parents.pop_front().unwrap(); //If is empty it will never try to pop - for _ in 0..branching_factor { - if nod_it < num_nodes { - let target: usize = nodes[nod_it].index(); - parents.push_back(target); - nod_it += 1; - graph.add_edge(nodes[source], nodes[target], py.None()); - } - } - } - } - Ok(graph::PyGraph { graph, node_removed: false,