diff --git a/Cargo.lock b/Cargo.lock
index 959c32034f..25806f0540 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -633,6 +633,7 @@ dependencies = [
"fixedbitset",
"hashbrown 0.14.5",
"indexmap 2.2.6",
+ "ndarray",
"num-traits",
"petgraph",
"priority-queue",
diff --git a/Cargo.toml b/Cargo.toml
index 1e50fea81c..c3bb086937 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -28,6 +28,7 @@ ahash = "0.8.6"
fixedbitset = "0.4.2"
hashbrown = { version = ">=0.13, <0.15", features = ["rayon"] }
indexmap = { version = ">=1.9, <3", features = ["rayon"] }
+ndarray = { version = "0.15.6", features = ["rayon"] }
num-traits = "0.2"
numpy = "0.21.0"
petgraph = "0.6.5"
@@ -44,6 +45,7 @@ ahash.workspace = true
fixedbitset.workspace = true
hashbrown.workspace = true
indexmap.workspace = true
+ndarray.workspace = true
ndarray-stats = "0.5.1"
num-bigint = "0.4"
num-complex = "0.4"
@@ -63,10 +65,6 @@ rustworkx-core = { path = "rustworkx-core", version = "=0.15.0" }
version = "0.21.2"
features = ["abi3-py38", "extension-module", "hashbrown", "num-bigint", "num-complex", "indexmap"]
-[dependencies.ndarray]
-version = "^0.15.6"
-features = ["rayon"]
-
[dependencies.sprs]
version = "^0.11"
features = ["multi_thread"]
diff --git a/docs/source/api/random_graph_generator_functions.rst b/docs/source/api/random_graph_generator_functions.rst
index 4bc52096ec..4c0a33c5f5 100644
--- a/docs/source/api/random_graph_generator_functions.rst
+++ b/docs/source/api/random_graph_generator_functions.rst
@@ -10,6 +10,8 @@ Random Graph Generator Functions
rustworkx.undirected_gnp_random_graph
rustworkx.directed_gnm_random_graph
rustworkx.undirected_gnm_random_graph
+ rustworkx.directed_sbm_random_graph
+ rustworkx.undirected_sbm_random_graph
rustworkx.random_geometric_graph
rustworkx.hyperbolic_random_graph
rustworkx.barabasi_albert_graph
diff --git a/releasenotes/notes/sbm-random-graph-bf7ccd8e938f4218.yaml b/releasenotes/notes/sbm-random-graph-bf7ccd8e938f4218.yaml
new file mode 100644
index 0000000000..8ec9490a47
--- /dev/null
+++ b/releasenotes/notes/sbm-random-graph-bf7ccd8e938f4218.yaml
@@ -0,0 +1,9 @@
+features:
+ - |
+ Adds new random graph generator in rustworkx for the stochastic block model.
+ There is a generator for directed :func:`.directed_sbm_random_graph` and
+ undirected graphs :func:`.undirected_sbm_random_graph`.
+ - |
+ Adds new function ``sbm_random_graph`` to the rustworkx-core module
+ ``rustworkx_core::generators`` that samples a graph from the stochastic
+ block model.
diff --git a/rustworkx-core/Cargo.toml b/rustworkx-core/Cargo.toml
index 781a9fbf5f..c8d292627f 100644
--- a/rustworkx-core/Cargo.toml
+++ b/rustworkx-core/Cargo.toml
@@ -16,6 +16,7 @@ ahash.workspace = true
fixedbitset.workspace = true
hashbrown.workspace = true
indexmap.workspace = true
+ndarray.workspace = true
num-traits.workspace = true
petgraph.workspace = true
priority-queue = "2.0"
diff --git a/rustworkx-core/src/generators/mod.rs b/rustworkx-core/src/generators/mod.rs
index 04d672749f..3034baab13 100644
--- a/rustworkx-core/src/generators/mod.rs
+++ b/rustworkx-core/src/generators/mod.rs
@@ -62,4 +62,5 @@ pub use random_graph::gnp_random_graph;
pub use random_graph::hyperbolic_random_graph;
pub use random_graph::random_bipartite_graph;
pub use random_graph::random_geometric_graph;
+pub use random_graph::sbm_random_graph;
pub use star_graph::star_graph;
diff --git a/rustworkx-core/src/generators/random_graph.rs b/rustworkx-core/src/generators/random_graph.rs
index 1768619d53..edea398fb2 100644
--- a/rustworkx-core/src/generators/random_graph.rs
+++ b/rustworkx-core/src/generators/random_graph.rs
@@ -14,6 +14,7 @@
use std::hash::Hash;
+use ndarray::ArrayView2;
use petgraph::data::{Build, Create};
use petgraph::visit::{
Data, EdgeRef, GraphBase, GraphProp, IntoEdgeReferences, IntoEdgesDirected,
@@ -305,6 +306,131 @@ where
Ok(graph)
}
+/// Generate a graph from the stochastic block model.
+///
+/// The stochastic block model is a generalization of the Gnp random graph
+/// (see [gnp_random_graph] ). The connection probability of
+/// nodes `u` and `v` depends on their block and is given by
+/// `probabilities[blocks[u]][blocks[v]]`, where `blocks[u]` is the block membership
+/// of vertex `u`. The number of nodes and the number of blocks are inferred from
+/// `sizes`.
+///
+/// Arguments:
+///
+/// * `sizes` - Number of nodes in each block.
+/// * `probabilities` - B x B array that contains the connection probability between
+/// nodes of different blocks. Must be symmetric for undirected graphs.
+/// * `loops` - Determines whether the graph can have loops or not.
+/// * `seed` - An optional seed to use for the random number generator.
+/// * `default_node_weight` - A callable that will return the weight to use
+/// for newly created nodes.
+/// * `default_edge_weight` - A callable that will return the weight object
+/// to use for newly created edges.
+///
+/// # Example
+/// ```rust
+/// use ndarray::arr2;
+/// use rustworkx_core::petgraph;
+/// use rustworkx_core::generators::sbm_random_graph;
+///
+/// let g = sbm_random_graph::, (), _, _, ()>(
+/// &vec![1, 2],
+/// &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
+/// true,
+/// Some(10),
+/// || (),
+/// || (),
+/// )
+/// .unwrap();
+/// assert_eq!(g.node_count(), 3);
+/// assert_eq!(g.edge_count(), 6);
+/// ```
+pub fn sbm_random_graph(
+ sizes: &[usize],
+ probabilities: &ndarray::ArrayView2,
+ loops: bool,
+ seed: Option,
+ mut default_node_weight: F,
+ mut default_edge_weight: H,
+) -> Result
+where
+ G: Build + Create + Data + NodeIndexable + GraphProp,
+ F: FnMut() -> T,
+ H: FnMut() -> M,
+ G::NodeId: Eq + Hash,
+{
+ let num_nodes: usize = sizes.iter().sum();
+ if num_nodes == 0 {
+ return Err(InvalidInputError {});
+ }
+ let num_communities = sizes.len();
+ if probabilities.nrows() != num_communities
+ || probabilities.ncols() != num_communities
+ || probabilities.iter().any(|&x| !(0. ..=1.).contains(&x))
+ {
+ return Err(InvalidInputError {});
+ }
+
+ let mut graph = G::with_capacity(num_nodes, num_nodes);
+ let directed = graph.is_directed();
+ if !directed && !symmetric_array(probabilities) {
+ return Err(InvalidInputError {});
+ }
+
+ for _ in 0..num_nodes {
+ graph.add_node(default_node_weight());
+ }
+ let mut rng: Pcg64 = match seed {
+ Some(seed) => Pcg64::seed_from_u64(seed),
+ None => Pcg64::from_entropy(),
+ };
+ let mut blocks = Vec::new();
+ {
+ let mut block = 0;
+ let mut vertices_left = sizes[0];
+ for _ in 0..num_nodes {
+ while vertices_left == 0 {
+ block += 1;
+ vertices_left = sizes[block];
+ }
+ blocks.push(block);
+ vertices_left -= 1;
+ }
+ }
+
+ let between = Uniform::new(0.0, 1.0);
+ for v in 0..(if directed || loops {
+ num_nodes
+ } else {
+ num_nodes - 1
+ }) {
+ for w in ((if directed { 0 } else { v })..num_nodes).filter(|&w| w != v || loops) {
+ if &between.sample(&mut rng)
+ < probabilities.get((blocks[v], blocks[w])).unwrap_or(&0_f64)
+ {
+ graph.add_edge(
+ graph.from_index(v),
+ graph.from_index(w),
+ default_edge_weight(),
+ );
+ }
+ }
+ }
+ Ok(graph)
+}
+
+fn symmetric_array(mat: &ArrayView2) -> bool {
+ let n = mat.nrows();
+ for (i, row) in mat.rows().into_iter().enumerate().take(n - 1) {
+ for (j, m_ij) in row.iter().enumerate().skip(i + 1) {
+ if m_ij != mat.get((j, i)).unwrap() {
+ return false;
+ }
+ }
+ }
+ true
+}
+
#[inline]
fn pnorm(x: f64, p: f64) -> f64 {
if p == 1.0 || p == std::f64::INFINITY {
@@ -749,7 +875,7 @@ mod tests {
use crate::generators::InvalidInputError;
use crate::generators::{
barabasi_albert_graph, gnm_random_graph, gnp_random_graph, hyperbolic_random_graph,
- path_graph, random_bipartite_graph, random_geometric_graph,
+ path_graph, random_bipartite_graph, random_geometric_graph, sbm_random_graph,
};
use crate::petgraph;
@@ -879,6 +1005,165 @@ mod tests {
};
}
+ // Test sbm_random_graph
+ #[test]
+ fn test_sbm_directed_complete_blocks_loops() {
+ let g = sbm_random_graph::, (), _, _, ()>(
+ &vec![1, 2],
+ &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
+ true,
+ Some(10),
+ || (),
+ || (),
+ )
+ .unwrap();
+ assert_eq!(g.node_count(), 3);
+ assert_eq!(g.edge_count(), 6);
+ for (u, v) in [(1, 1), (1, 2), (2, 1), (2, 2), (0, 1), (0, 2)] {
+ assert_eq!(g.contains_edge(u.into(), v.into()), true);
+ }
+ assert_eq!(g.contains_edge(1.into(), 0.into()), false);
+ assert_eq!(g.contains_edge(2.into(), 0.into()), false);
+ }
+
+ #[test]
+ fn test_sbm_undirected_complete_blocks_loops() {
+ let g = sbm_random_graph::, (), _, _, ()>(
+ &vec![1, 2],
+ &ndarray::arr2(&[[0., 1.], [1., 1.]]).view(),
+ true,
+ Some(10),
+ || (),
+ || (),
+ )
+ .unwrap();
+ assert_eq!(g.node_count(), 3);
+ assert_eq!(g.edge_count(), 5);
+ for (u, v) in [(1, 1), (1, 2), (2, 2), (0, 1), (0, 2)] {
+ assert_eq!(g.contains_edge(u.into(), v.into()), true);
+ }
+ assert_eq!(g.contains_edge(0.into(), 0.into()), false);
+ }
+
+ #[test]
+ fn test_sbm_directed_complete_blocks_noloops() {
+ let g = sbm_random_graph::, (), _, _, ()>(
+ &vec![1, 2],
+ &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
+ false,
+ Some(10),
+ || (),
+ || (),
+ )
+ .unwrap();
+ assert_eq!(g.node_count(), 3);
+ assert_eq!(g.edge_count(), 4);
+ for (u, v) in [(1, 2), (2, 1), (0, 1), (0, 2)] {
+ assert_eq!(g.contains_edge(u.into(), v.into()), true);
+ }
+ assert_eq!(g.contains_edge(1.into(), 0.into()), false);
+ assert_eq!(g.contains_edge(2.into(), 0.into()), false);
+ for u in 0..2 {
+ assert_eq!(g.contains_edge(u.into(), u.into()), false);
+ }
+ }
+
+ #[test]
+ fn test_sbm_undirected_complete_blocks_noloops() {
+ let g = sbm_random_graph::, (), _, _, ()>(
+ &vec![1, 2],
+ &ndarray::arr2(&[[0., 1.], [1., 1.]]).view(),
+ false,
+ Some(10),
+ || (),
+ || (),
+ )
+ .unwrap();
+ assert_eq!(g.node_count(), 3);
+ assert_eq!(g.edge_count(), 3);
+ for (u, v) in [(1, 2), (0, 1), (0, 2)] {
+ assert_eq!(g.contains_edge(u.into(), v.into()), true);
+ }
+ for u in 0..2 {
+ assert_eq!(g.contains_edge(u.into(), u.into()), false);
+ }
+ }
+
+ #[test]
+ fn test_sbm_bad_array_rows_error() {
+ match sbm_random_graph::, (), _, _, ()>(
+ &vec![1, 2],
+ &ndarray::arr2(&[[0., 1.], [1., 1.], [1., 1.]]).view(),
+ true,
+ Some(10),
+ || (),
+ || (),
+ ) {
+ Ok(_) => panic!("Returned a non-error"),
+ Err(e) => assert_eq!(e, InvalidInputError),
+ };
+ }
+ #[test]
+
+ fn test_sbm_bad_array_cols_error() {
+ match sbm_random_graph::, (), _, _, ()>(
+ &vec![1, 2],
+ &ndarray::arr2(&[[0., 1., 1.], [1., 1., 1.]]).view(),
+ true,
+ Some(10),
+ || (),
+ || (),
+ ) {
+ Ok(_) => panic!("Returned a non-error"),
+ Err(e) => assert_eq!(e, InvalidInputError),
+ };
+ }
+
+ #[test]
+ fn test_sbm_asymmetric_array_error() {
+ match sbm_random_graph::, (), _, _, ()>(
+ &vec![1, 2],
+ &ndarray::arr2(&[[0., 1.], [0., 1.]]).view(),
+ true,
+ Some(10),
+ || (),
+ || (),
+ ) {
+ Ok(_) => panic!("Returned a non-error"),
+ Err(e) => assert_eq!(e, InvalidInputError),
+ };
+ }
+
+ #[test]
+ fn test_sbm_invalid_probability_error() {
+ match sbm_random_graph::, (), _, _, ()>(
+ &vec![1, 2],
+ &ndarray::arr2(&[[0., 1.], [0., -1.]]).view(),
+ true,
+ Some(10),
+ || (),
+ || (),
+ ) {
+ Ok(_) => panic!("Returned a non-error"),
+ Err(e) => assert_eq!(e, InvalidInputError),
+ };
+ }
+
+ #[test]
+ fn test_sbm_empty_error() {
+ match sbm_random_graph::, (), _, _, ()>(
+ &vec![],
+ &ndarray::arr2(&[[]]).view(),
+ true,
+ Some(10),
+ || (),
+ || (),
+ ) {
+ Ok(_) => panic!("Returned a non-error"),
+ Err(e) => assert_eq!(e, InvalidInputError),
+ };
+ }
+
// Test random_geometric_graph
#[test]
diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi
index f253879754..f499b98130 100644
--- a/rustworkx/__init__.pyi
+++ b/rustworkx/__init__.pyi
@@ -127,6 +127,8 @@ from .rustworkx import directed_gnm_random_graph as directed_gnm_random_graph
from .rustworkx import undirected_gnm_random_graph as undirected_gnm_random_graph
from .rustworkx import directed_gnp_random_graph as directed_gnp_random_graph
from .rustworkx import undirected_gnp_random_graph as undirected_gnp_random_graph
+from .rustworkx import directed_sbm_random_graph as directed_sbm_random_graph
+from .rustworkx import undirected_sbm_random_graph as undirected_sbm_random_graph
from .rustworkx import random_geometric_graph as random_geometric_graph
from .rustworkx import hyperbolic_random_graph as hyperbolic_random_graph
from .rustworkx import barabasi_albert_graph as barabasi_albert_graph
diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi
index bebe0520ee..5229629944 100644
--- a/rustworkx/rustworkx.pyi
+++ b/rustworkx/rustworkx.pyi
@@ -549,6 +549,20 @@ def undirected_gnp_random_graph(
/,
seed: int | None = ...,
) -> PyGraph: ...
+def directed_sbm_random_graph(
+ sizes: list[int],
+ probabilities: np.ndarray,
+ loops: bool,
+ /,
+ seed: int | None = ...,
+) -> PyDiGraph: ...
+def undirected_sbm_random_graph(
+ sizes: list[int],
+ probabilities: np.ndarray,
+ loops: bool,
+ /,
+ seed: int | None = ...,
+) -> PyGraph: ...
def random_geometric_graph(
num_nodes: int,
radius: float,
diff --git a/src/lib.rs b/src/lib.rs
index ce0843b8ea..164b713c5b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -520,6 +520,8 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(undirected_gnp_random_graph))?;
m.add_wrapped(wrap_pyfunction!(directed_gnm_random_graph))?;
m.add_wrapped(wrap_pyfunction!(undirected_gnm_random_graph))?;
+ m.add_wrapped(wrap_pyfunction!(undirected_sbm_random_graph))?;
+ m.add_wrapped(wrap_pyfunction!(directed_sbm_random_graph))?;
m.add_wrapped(wrap_pyfunction!(random_geometric_graph))?;
m.add_wrapped(wrap_pyfunction!(hyperbolic_random_graph))?;
m.add_wrapped(wrap_pyfunction!(barabasi_albert_graph))?;
diff --git a/src/random_graph.rs b/src/random_graph.rs
index f0e6ee679b..8360c0ff96 100644
--- a/src/random_graph.rs
+++ b/src/random_graph.rs
@@ -23,6 +23,8 @@ use petgraph::algo;
use petgraph::graph::NodeIndex;
use petgraph::prelude::*;
+use numpy::PyReadonlyArray2;
+
use rand::distributions::{Distribution, Uniform};
use rand::prelude::*;
use rand_pcg::Pcg64;
@@ -273,6 +275,116 @@ pub fn undirected_gnm_random_graph(
})
}
+/// Return a directed graph from the stochastic block model.
+///
+/// The stochastic block model is a generalization of the :math:`G(n,p)` random graph
+/// (see :func:`~rustworkx.directed_gnp_random_graph`). The connection probability of
+/// nodes ``u`` and ``v`` depends on their block (or community) and is given by
+/// ``probabilities[blocks[u]][blocks[v]]``, where ``blocks[u]`` is the block
+/// membership of node ``u``. The number of nodes and the number of blocks are
+/// inferred from ``sizes``.
+///
+/// This algorithm has a time complexity of :math:`O(n^2)` for :math:`n` nodes.
+///
+/// Arguments:
+///
+/// :param list[int] sizes: Number of nodes in each block.
+/// :param np.ndarray probabilities: B x B array that contains the connection
+/// probability between nodes of different blocks.
+/// :param bool loops: Determines whether the graph can have loops or not.
+/// :param int seed: An optional seed to use for the random number generator.
+///
+/// :return: A PyDiGraph object
+/// :rtype: PyDiGraph
+#[pyfunction]
+#[pyo3(text_signature = "(sizes, probabilities, loops, /, seed=None)")]
+pub fn directed_sbm_random_graph<'p>(
+ py: Python<'p>,
+ sizes: Vec,
+ probabilities: PyReadonlyArray2<'p, f64>,
+ loops: bool,
+ seed: Option,
+) -> PyResult {
+ let default_fn = || py.None();
+ let graph: StablePyGraph = match core_generators::sbm_random_graph(
+ &sizes,
+ &probabilities.as_array(),
+ loops,
+ seed,
+ default_fn,
+ default_fn,
+ ) {
+ Ok(graph) => graph,
+ Err(_) => {
+ return Err(PyValueError::new_err(
+ "invalid blocks or probabilities input",
+ ))
+ }
+ };
+ Ok(digraph::PyDiGraph {
+ graph,
+ node_removed: false,
+ check_cycle: false,
+ cycle_state: algo::DfsSpace::default(),
+ multigraph: false,
+ attrs: py.None(),
+ })
+}
+
+/// Return an undirected graph from the stochastic block model.
+///
+/// The stochastic block model is a generalization of the :math:`G(n,p)` random graph
+/// (see :func:`~rustworkx.undirected_gnp_random_graph`). The connection probability of
+/// nodes ``u`` and ``v`` depends on their block (or community) and is given by
+/// ``probabilities[blocks[u]][blocks[v]]``, where ``blocks[u]`` is the block membership
+/// of node ``u``. The number of nodes and the number of blocks are inferred from
+/// ``sizes``.
+///
+/// This algorithm has a time complexity of :math:`O(n^2)` for :math:`n` nodes.
+///
+/// Arguments:
+///
+/// :param list[int] sizes: Number of nodes in each block.
+/// :param np.ndarray probabilities: Symmetric B x B array that contains the
+/// connection probability between nodes of different blocks.
+/// :param bool loops: Determines whether the graph can have loops or not.
+/// :param int seed: An optional seed to use for the random number generator.
+///
+/// :return: A PyGraph object
+/// :rtype: PyGraph
+#[pyfunction]
+#[pyo3(text_signature = "(sizes, probabilities, loops, /, seed=None)")]
+pub fn undirected_sbm_random_graph<'p>(
+ py: Python<'p>,
+ sizes: Vec,
+ probabilities: PyReadonlyArray2<'p, f64>,
+ loops: bool,
+ seed: Option,
+) -> PyResult {
+ let default_fn = || py.None();
+ let graph: StablePyGraph = match core_generators::sbm_random_graph(
+ &sizes,
+ &probabilities.as_array(),
+ loops,
+ seed,
+ default_fn,
+ default_fn,
+ ) {
+ Ok(graph) => graph,
+ Err(_) => {
+ return Err(PyValueError::new_err(
+ "invalid blocks or probabilities input",
+ ))
+ }
+ };
+ Ok(graph::PyGraph {
+ graph,
+ node_removed: false,
+ multigraph: false,
+ attrs: py.None(),
+ })
+}
+
#[inline]
fn pnorm(x: f64, p: f64) -> f64 {
if p == 1.0 || p == std::f64::INFINITY {
diff --git a/tests/test_random.py b/tests/test_random.py
index 02cfcd36aa..74f7668bbf 100644
--- a/tests/test_random.py
+++ b/tests/test_random.py
@@ -14,6 +14,7 @@
import random
import math
+import numpy as np
import rustworkx
@@ -177,6 +178,87 @@ def test_random_gnm_undirected_payload(self):
self.assertEqual(graph.nodes(), [0, 1, 2])
+class TestRandomSBM(unittest.TestCase):
+ def test_undirected_sbm_complete_blocks_loops(self):
+ graph = rustworkx.undirected_sbm_random_graph(
+ [2, 1], np.array([[1, 1], [1, 0]], dtype=float), True
+ )
+ self.assertEqual(len(graph), 3)
+ self.assertEqual(len(graph.edges()), 5)
+ for i in range(2):
+ for j in range(i, 2):
+ if (i, j) != (2, 2):
+ self.assertTrue(graph.has_edge(i, j))
+ self.assertFalse(graph.has_edge(2, 2))
+
+ def test_directed_sbm_complete_blocks_loops(self):
+ graph = rustworkx.directed_sbm_random_graph(
+ [2, 1], np.array([[0, 0], [1, 1]], dtype=float), True
+ )
+ self.assertEqual(len(graph), 3)
+ self.assertEqual(len(graph.edges()), 3)
+ self.assertEqual(set(graph.edge_list()), set([(2, 2), (2, 0), (2, 1)]))
+
+ def test_undirected_sbm_complete_blocks_noloops(self):
+ graph = rustworkx.undirected_sbm_random_graph(
+ [2, 1], np.array([[1, 1], [1, 0]], dtype=float), False
+ )
+ self.assertEqual(len(graph), 3)
+ self.assertEqual(len(graph.edges()), 3)
+ for i in range(2):
+ for j in range(i, 2):
+ if i != j:
+ self.assertTrue(graph.has_edge(i, j))
+
+ def test_directed_sbm_complete_blocks_noloops(self):
+ graph = rustworkx.directed_sbm_random_graph(
+ [2, 1], np.array([[0, 0], [1, 1]], dtype=float), False
+ )
+ self.assertEqual(len(graph), 3)
+ self.assertEqual(len(graph.edges()), 2)
+ self.assertEqual(set(graph.edge_list()), set([(2, 0), (2, 1)]))
+
+ def test_undirected_sbm_asymmetric_probabilities_error(self):
+ with self.assertRaises(ValueError):
+ rustworkx.undirected_sbm_random_graph(
+ [2, 1], np.array([[0, 0], [1, 1]], dtype=float), True
+ )
+
+ def test_sbm_invalid_matrix_dim(self):
+ with self.assertRaises(ValueError):
+ rustworkx.undirected_sbm_random_graph(
+ [2, 1], np.array([[1, 0], [0, 1], [0, 1]], dtype=float), True
+ )
+ with self.assertRaises(ValueError):
+ rustworkx.directed_sbm_random_graph(
+ [2, 1], np.array([[1, 0, 1], [0, 1, 0]], dtype=float), True
+ )
+
+ def test_sbm_invalid_probabilities(self):
+ with self.assertRaises(ValueError):
+ rustworkx.undirected_sbm_random_graph(
+ [2, 1], np.array([[1, 0], [0, 1.5]], dtype=float), True
+ )
+ with self.assertRaises(ValueError):
+ rustworkx.undirected_sbm_random_graph(
+ [2, 1], np.array([[-1, 0], [0, 1]], dtype=float), True
+ )
+ with self.assertRaises(ValueError):
+ rustworkx.directed_sbm_random_graph(
+ [2, 1], np.array([[1, 0], [0, 1.5]], dtype=float), True
+ )
+ with self.assertRaises(ValueError):
+ rustworkx.directed_sbm_random_graph(
+ [2, 1], np.array([[-1, 0], [0, 1]], dtype=float), True
+ )
+
+ def test_sbm_empty(self):
+ with self.assertRaises(ValueError):
+ rustworkx.undirected_sbm_random_graph([], np.array([[]]), True)
+ with self.assertRaises(ValueError):
+ rustworkx.directed_sbm_random_graph([], np.array([[]]), True)
+
+
class TestGeometricRandomGraph(unittest.TestCase):
def test_random_geometric_empty(self):
graph = rustworkx.random_geometric_graph(20, 0)