Skip to content

Commit

Permalink
Add random_layout function and Pos2DMapping
Browse files Browse the repository at this point in the history
This commit adds a new function random_layout() which is used to
generate a random layout for a graph that can be used in visualization.
This is necessary for building a matplotlib drawer (issue Qiskit#298 and a
first draft of the implementation Qiskit#304). To make the function more
efficient it also adds a new custom return type Pos2DMapping which is
used to build an imutable readonly dict compatible result container for
the output type from this function.

Related to Qiskit#280
  • Loading branch information
mtreinish committed Apr 7, 2021
1 parent 77ebc40 commit 1375aa2
Show file tree
Hide file tree
Showing 7 changed files with 578 additions and 3 deletions.
3 changes: 3 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ Specific Graph Type Methods
retworkx.digraph_core_number
retworkx.graph_complement
retworkx.digraph_complement
retworkx.graph_random_layout
retworkx.digraph_random_layout

.. _universal-functions:

Expand Down Expand Up @@ -129,6 +131,7 @@ type functions in the algorithms API but can be run with a
retworkx.is_isomorphic_node_match
retworkx.transitivity
retworkx.core_number
retworkx.random_layout

Exceptions
----------
Expand Down
25 changes: 25 additions & 0 deletions retworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,28 @@ def _digraph_complement(graph):
@complement.register(PyGraph)
def _graph_complement(graph):
return graph_complement(graph)


@functools.singledispatch
def random_layout(graph, center=None, seed=None):
"""Generate a random layout
:param PyGraph graph: The graph to generate the layout for
:param tuple center: An optional center position. This is a 2 tuple of two
``float`` values for the center position
:param int seed: An optional seed to set for the random number generator.
:returns: The complement of the graph.
:rtype: Pos2DMapping
"""
raise TypeError("Invalid Input Type %s for graph" % type(graph))


@random_layout.register(PyDiGraph)
def _digraph_random_layout(graph, center=None, seed=None):
return digraph_random_layout(graph, center=center, seed=seed)


@random_layout.register(PyGraph)
def _graph_random_layout(graph, center=None, seed=None):
return graph_random_layout(graph, center=center, seed=seed)
242 changes: 240 additions & 2 deletions src/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ use std::collections::hash_map::DefaultHasher;
use std::convert::TryInto;
use std::hash::Hasher;

use pyo3::class::{PyObjectProtocol, PySequenceProtocol};
use pyo3::exceptions::{PyIndexError, PyNotImplementedError};
use hashbrown::HashMap;

use pyo3::class::iter::{IterNextOutput, PyIterProtocol};
use pyo3::class::{PyMappingProtocol, PyObjectProtocol, PySequenceProtocol};
use pyo3::exceptions::{PyIndexError, PyKeyError, PyNotImplementedError};
use pyo3::gc::{PyGCProtocol, PyVisit};
use pyo3::prelude::*;
use pyo3::types::PySequence;
Expand Down Expand Up @@ -552,3 +555,238 @@ impl PyGCProtocol for WeightedEdgeList {
self.edges = Vec::new();
}
}

/// A class representing a mapping of node indices to 2D positions
///
/// This class is equivalent to having a dict of the form::
///
/// {1: [0, 1], 3: [0.5, 1.2]}
///
/// It is used to efficiently represent a retworkx generated 2D layout for a
/// graph. It behaves as a drop in replacement for a readonly ``dict``.
#[pyclass(module = "retworkx", gc)]
pub struct Pos2DMapping {
pub pos_map: HashMap<usize, [f64; 2]>,
}

#[pymethods]
impl Pos2DMapping {
#[new]
fn new() -> Pos2DMapping {
Pos2DMapping {
pos_map: HashMap::new(),
}
}

fn __getstate__(&self) -> HashMap<usize, [f64; 2]> {
self.pos_map.clone()
}

fn __setstate__(&mut self, state: HashMap<usize, [f64; 2]>) {
self.pos_map = state;
}

fn keys(&self) -> Pos2DMappingKeys {
Pos2DMappingKeys {
pos_keys: self.pos_map.keys().copied().collect(),
iter_pos: 0,
}
}

fn values(&self) -> Pos2DMappingValues {
Pos2DMappingValues {
pos_values: self.pos_map.values().copied().collect(),
iter_pos: 0,
}
}

fn items(&self) -> Pos2DMappingItems {
let items: Vec<(usize, [f64; 2])> =
self.pos_map.iter().map(|(k, v)| (*k, *v)).collect();
Pos2DMappingItems {
pos_items: items,
iter_pos: 0,
}
}
}

#[pyproto]
impl<'p> PyObjectProtocol<'p> for Pos2DMapping {
fn __richcmp__(
&self,
other: PyObject,
op: pyo3::basic::CompareOp,
) -> PyResult<bool> {
let compare = |other: PyObject| -> PyResult<bool> {
let gil = Python::acquire_gil();
let py = gil.python();
let other_ref = other.as_ref(py);
if other_ref.len()? != self.pos_map.len() {
return Ok(false);
}
for (key, value) in &self.pos_map {
match other_ref.get_item(key) {
Ok(other_raw) => {
let other_value: [f64; 2] = other_raw.extract()?;
if other_value != *value {
return Ok(false);
}
}
Err(ref err)
if Python::with_gil(|py| {
err.is_instance::<PyKeyError>(py)
}) =>
{
return Ok(false);
}
Err(err) => return Err(err),
}
}
Ok(true)
};
match op {
pyo3::basic::CompareOp::Eq => compare(other),
pyo3::basic::CompareOp::Ne => match compare(other) {
Ok(res) => Ok(!res),
Err(err) => Err(err),
},
_ => Err(PyNotImplementedError::new_err(
"Comparison not implemented",
)),
}
}

fn __str__(&self) -> PyResult<String> {
let mut str_vec: Vec<String> = Vec::with_capacity(self.pos_map.len());
for path in &self.pos_map {
str_vec.push(format!("{}: ({}, {})", path.0, path.1[0], path.1[1]));
}
Ok(format!("Pos2DMapping{{{}}}", str_vec.join(", ")))
}

fn __hash__(&self) -> PyResult<u64> {
let mut hasher = DefaultHasher::new();
for index in &self.pos_map {
hasher.write_usize(*index.0);
hasher.write(&index.1[0].to_be_bytes());
hasher.write(&index.1[1].to_be_bytes());
}
Ok(hasher.finish())
}
}

#[pyproto]
impl PySequenceProtocol for Pos2DMapping {
fn __len__(&self) -> PyResult<usize> {
Ok(self.pos_map.len())
}

fn __contains__(&self, index: usize) -> PyResult<bool> {
Ok(self.pos_map.contains_key(&index))
}
}

#[pyproto]
impl PyMappingProtocol for Pos2DMapping {
/// Return the number of nodes in the graph
fn __len__(&self) -> PyResult<usize> {
Ok(self.pos_map.len())
}
fn __getitem__(&'p self, idx: usize) -> PyResult<[f64; 2]> {
match self.pos_map.get(&idx) {
Some(data) => Ok(*data),
None => Err(PyIndexError::new_err("No node found for index")),
}
}
}

#[pyproto]
impl PyIterProtocol for Pos2DMapping {
fn __iter__(slf: PyRef<Self>) -> Pos2DMappingKeys {
Pos2DMappingKeys {
pos_keys: slf.pos_map.keys().copied().collect(),
iter_pos: 0,
}
}
}

#[pyproto]
impl PyGCProtocol for Pos2DMapping {
fn __traverse__(&self, _visit: PyVisit) -> Result<(), PyTraverseError> {
Ok(())
}

fn __clear__(&mut self) {}
}

#[pyclass(module = "retworkx")]
pub struct Pos2DMappingKeys {
pub pos_keys: Vec<usize>,
iter_pos: usize,
}

#[pyproto]
impl PyIterProtocol for Pos2DMappingKeys {
fn __iter__(slf: PyRef<Self>) -> Py<Pos2DMappingKeys> {
slf.into()
}
fn __next__(
mut slf: PyRefMut<Self>,
) -> IterNextOutput<usize, &'static str> {
if slf.iter_pos < slf.pos_keys.len() {
let res = IterNextOutput::Yield(slf.pos_keys[slf.iter_pos]);
slf.iter_pos += 1;
res
} else {
IterNextOutput::Return("Ended")
}
}
}

#[pyclass(module = "retworkx")]
pub struct Pos2DMappingValues {
pub pos_values: Vec<[f64; 2]>,
iter_pos: usize,
}

#[pyproto]
impl PyIterProtocol for Pos2DMappingValues {
fn __iter__(slf: PyRef<Self>) -> Py<Pos2DMappingValues> {
slf.into()
}
fn __next__(
mut slf: PyRefMut<Self>,
) -> IterNextOutput<[f64; 2], &'static str> {
if slf.iter_pos < slf.pos_values.len() {
let res = IterNextOutput::Yield(slf.pos_values[slf.iter_pos]);
slf.iter_pos += 1;
res
} else {
IterNextOutput::Return("Ended")
}
}
}

#[pyclass(module = "retworkx")]
pub struct Pos2DMappingItems {
pub pos_items: Vec<(usize, [f64; 2])>,
iter_pos: usize,
}

#[pyproto]
impl PyIterProtocol for Pos2DMappingItems {
fn __iter__(slf: PyRef<Self>) -> Py<Pos2DMappingItems> {
slf.into()
}
fn __next__(
mut slf: PyRefMut<Self>,
) -> IterNextOutput<(usize, [f64; 2]), &'static str> {
if slf.iter_pos < slf.pos_items.len() {
let res = IterNextOutput::Yield(slf.pos_items[slf.iter_pos]);
slf.iter_pos += 1;
res
} else {
IterNextOutput::Return("Ended")
}
}
}
Loading

0 comments on commit 1375aa2

Please sign in to comment.