From cb0fb2f1bf4871c428f72eb2f5025b36dddd4ff7 Mon Sep 17 00:00:00 2001 From: Alexander Condello Date: Wed, 29 Jan 2025 17:41:08 -0800 Subject: [PATCH] Add Copy symbol Closes https://github.com/dwavesystems/dwave-optimization/issues/16 --- docs/reference/symbols.rst | 2 + dwave/optimization/_model.pyi | 1 + dwave/optimization/_model.pyx | 11 ++ .../dwave-optimization/nodes/manipulation.hpp | 49 +++++++++ dwave/optimization/libcpp/nodes.pxd | 3 + dwave/optimization/src/nodes/_state.hpp | 8 ++ dwave/optimization/src/nodes/manipulation.cpp | 44 ++++++++ dwave/optimization/symbols.pyi | 4 + dwave/optimization/symbols.pyx | 35 ++++++ ...-ArraySymbol-flatten-a499f7edf2e28185.yaml | 3 + tests/cpp/nodes/test_manipulation.cpp | 104 ++++++++++++++++++ tests/test_symbols.py | 19 ++++ 12 files changed, 283 insertions(+) diff --git a/docs/reference/symbols.rst b/docs/reference/symbols.rst index 4613df8f..f1dc4efc 100644 --- a/docs/reference/symbols.rst +++ b/docs/reference/symbols.rst @@ -47,6 +47,8 @@ are inherited by the :ref:`model symbols `. ~ArraySymbol.all ~ArraySymbol.any + ~ArraySymbol.copy + ~ArraySymbol.flatten ~ArraySymbol.has_state ~ArraySymbol.max ~ArraySymbol.maybe_equals diff --git a/dwave/optimization/_model.pyi b/dwave/optimization/_model.pyi index bc48f4cb..bb0c4bf3 100644 --- a/dwave/optimization/_model.pyi +++ b/dwave/optimization/_model.pyi @@ -103,6 +103,7 @@ class ArraySymbol(Symbol): def __truediv__(self, rhs: ArraySymbol) -> Divide: ... def all(self) -> All: ... def any(self) -> Any: ... + def copy(self) -> Copy: ... def flatten(self) -> Reshape: ... def max(self) -> Max: ... def min(self) -> Min: ... diff --git a/dwave/optimization/_model.pyx b/dwave/optimization/_model.pyx index ca6437df..849d6f48 100644 --- a/dwave/optimization/_model.pyx +++ b/dwave/optimization/_model.pyx @@ -1280,6 +1280,17 @@ cdef class ArraySymbol(Symbol): from dwave.optimization.symbols import Any # avoid circular import return Any(self) + def copy(self): + """Return an array symbol that is a copy of the array. + + See Also: + :class:`~dwave.optimization.symbols.Copy` Equivalent class. + + .. versionadded:: 0.5.1 + """ + from dwave.optimization.symbols import Copy # avoid circular import + return Copy(self) + def flatten(self): """Return an array symbol collapsed into one dimension. diff --git a/dwave/optimization/include/dwave-optimization/nodes/manipulation.hpp b/dwave/optimization/include/dwave-optimization/nodes/manipulation.hpp index 2b802eaf..5e69c758 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/manipulation.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/manipulation.hpp @@ -46,6 +46,55 @@ class ConcatenateNode : public ArrayOutputMixin { std::vector array_starts_; }; +/// An array node that is a contiguous copy of its predecessor. +class CopyNode : public ArrayOutputMixin { + public: + explicit CopyNode(ArrayNode* array_ptr); + + /// @copydoc Array::buff() + double const* buff(const State& state) const override; + + /// @copydoc Node::commit() + void commit(State& state) const override; + + /// @copydoc Array::diff() + std::span diff(const State& state) const override; + + /// @copydoc Node::initialize_state() + void initialize_state(State& state) const override; + + /// @copydoc Array::integral() + bool integral() const override; + + /// @copydoc Array::max() + double max() const override; + + /// @copydoc Array::min() + double min() const override; + + /// @copydoc Node::propagate() + void propagate(State& state) const override; + + /// @copydoc Node::revert() + void revert(State& state) const override; + + using ArrayOutputMixin::shape; + + /// @copydoc Array::shape() + std::span shape(const State& state) const override; + + using ArrayOutputMixin::size; + + /// @copydoc Array::size() + ssize_t size(const State& state) const override; + + /// @copydoc Array::size_diff() + ssize_t size_diff(const State& state) const override; + + private: + const Array* array_ptr_; +}; + /// Replaces specified elements of an array with the given values. /// /// The indexing works on the flattened array. Translated to NumPy, PutNode is diff --git a/dwave/optimization/libcpp/nodes.pxd b/dwave/optimization/libcpp/nodes.pxd index 41e9e0c9..eaa977e0 100644 --- a/dwave/optimization/libcpp/nodes.pxd +++ b/dwave/optimization/libcpp/nodes.pxd @@ -78,6 +78,9 @@ cdef extern from "dwave-optimization/nodes/manipulation.hpp" namespace "dwave::o cdef cppclass ConcatenateNode(ArrayNode): Py_ssize_t axis() + cdef cppclass CopyNode(ArrayNode): + pass + cdef cppclass PutNode(ArrayNode): pass diff --git a/dwave/optimization/src/nodes/_state.hpp b/dwave/optimization/src/nodes/_state.hpp index fa0e6386..4ea5eef3 100644 --- a/dwave/optimization/src/nodes/_state.hpp +++ b/dwave/optimization/src/nodes/_state.hpp @@ -35,6 +35,10 @@ class ArrayStateData { explicit ArrayStateData(std::vector&& values) noexcept : buffer(std::move(values)), previous_size_(buffer.size()) {} + template + explicit ArrayStateData(Range&& values) noexcept + : ArrayStateData(std::vector(values.begin(), values.end())) {} + // Assign new values to the state, tracking the changes from the previous state to the new // one. Including resizes. bool assign(std::ranges::sized_range auto&& values) { @@ -193,6 +197,10 @@ class ArrayNodeStateData: public ArrayStateData, public NodeStateData { explicit ArrayNodeStateData(std::vector&& values) noexcept : ArrayStateData(std::move(values)), NodeStateData() {} + template + explicit ArrayNodeStateData(Range&& values) noexcept + : ArrayNodeStateData(std::vector(values.begin(), values.end())) {} + std::unique_ptr copy() const override { return std::make_unique(*this); } diff --git a/dwave/optimization/src/nodes/manipulation.cpp b/dwave/optimization/src/nodes/manipulation.cpp index c6c3ae80..2bf45074 100644 --- a/dwave/optimization/src/nodes/manipulation.cpp +++ b/dwave/optimization/src/nodes/manipulation.cpp @@ -156,6 +156,50 @@ void ConcatenateNode::propagate(State& state) const { void ConcatenateNode::revert(State& state) const { data_ptr(state)->revert(); } +CopyNode::CopyNode(ArrayNode* array_ptr) + : ArrayOutputMixin(array_ptr->shape()), array_ptr_(array_ptr) { + this->add_predecessor(array_ptr); +} + +double const* CopyNode::buff(const State& state) const { + return data_ptr(state)->buff(); +} + +void CopyNode::commit(State& state) const { data_ptr(state)->commit(); } + +std::span CopyNode::diff(const State& state) const { + return data_ptr(state)->diff(); +} + +bool CopyNode::integral() const { return array_ptr_->integral(); } + +void CopyNode::initialize_state(State& state) const { + int index = this->topological_index(); + assert(index >= 0 && "must be topologically sorted"); + assert(static_cast(state.size()) > index && "unexpected state length"); + assert(state[index] == nullptr && "already initialized state"); + + state[index] = std::make_unique(array_ptr_->view(state)); +} + +double CopyNode::max() const { return array_ptr_->max(); } + +double CopyNode::min() const { return array_ptr_->min(); } + +void CopyNode::propagate(State& state) const { + data_ptr(state)->update(array_ptr_->diff(state)); +} + +void CopyNode::revert(State& state) const { data_ptr(state)->revert(); } + +std::span CopyNode::shape(const State& state) const { + return array_ptr_->shape(state); +} + +ssize_t CopyNode::size(const State& state) const { return array_ptr_->size(state); } + +ssize_t CopyNode::size_diff(const State& state) const { return array_ptr_->size_diff(state); } + // A PutNode needs to track its buffer as well as a mask of which elements in the // original array are currently overwritten. // We use ArrayStateData for the buffer diff --git a/dwave/optimization/symbols.pyi b/dwave/optimization/symbols.pyi index e15cb608..cf50a07f 100644 --- a/dwave/optimization/symbols.pyi +++ b/dwave/optimization/symbols.pyi @@ -76,6 +76,10 @@ class Constant(ArraySymbol): def __le__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ... +class Copy(ArraySymbol): + ... + + class DisjointBitSets(Symbol): def set_state(self, index: int, state: numpy.typing.ArrayLike): ... diff --git a/dwave/optimization/symbols.pyx b/dwave/optimization/symbols.pyx index 8cfe4e96..2538cb52 100644 --- a/dwave/optimization/symbols.pyx +++ b/dwave/optimization/symbols.pyx @@ -56,6 +56,7 @@ from dwave.optimization.libcpp.nodes cimport ( BinaryNode as cppBinaryNode, ConcatenateNode as cppConcatenateNode, ConstantNode as cppConstantNode, + CopyNode as cppCopyNode, DisjointBitSetNode as cppDisjointBitSetNode, DisjointBitSetsNode as cppDisjointBitSetsNode, DisjointListNode as cppDisjointListNode, @@ -110,6 +111,7 @@ __all__ = [ "BinaryVariable", "Concatenate", "Constant", + "Copy", "DisjointBitSets", "DisjointBitSet", "DisjointLists", @@ -1048,6 +1050,39 @@ cdef class Constant(ArraySymbol): _register(Constant, typeid(cppConstantNode)) +cdef class Copy(ArraySymbol): + """An array symbol that is a copy of another array symbol. + + See Also: + :meth:`ArraySymbol.copy` Equivalent method. + + .. versionadded:: 0.5.1 + """ + def __init__(self, ArraySymbol node): + cdef _Graph model = node.model + + self.ptr = model._graph.emplace_node[cppCopyNode]( + node.array_ptr, + ) + + self.initialize_arraynode(model, self.ptr) + + @staticmethod + def _from_symbol(Symbol symbol): + cdef cppCopyNode* ptr = dynamic_cast_ptr[cppCopyNode](symbol.node_ptr) + if not ptr: + raise TypeError("given symbol cannot be used to construct a Copy") + + cdef Copy m = Copy.__new__(Copy) + m.ptr = ptr + m.initialize_arraynode(symbol.model, ptr) + return m + + cdef cppCopyNode* ptr + +_register(Copy, typeid(cppCopyNode)) + + cdef class DisjointBitSets(Symbol): """Disjoint-sets decision-variable symbol. diff --git a/releasenotes/notes/feature-ArraySymbol-flatten-a499f7edf2e28185.yaml b/releasenotes/notes/feature-ArraySymbol-flatten-a499f7edf2e28185.yaml index e0f4c69f..9a82cd30 100644 --- a/releasenotes/notes/feature-ArraySymbol-flatten-a499f7edf2e28185.yaml +++ b/releasenotes/notes/feature-ArraySymbol-flatten-a499f7edf2e28185.yaml @@ -7,6 +7,9 @@ features: - | Support inferring the shape of one axis when reshaping array symbols by providing ``-1`` for the dimension's shape. + - Support reshaping non-contiguous array symbols. + - Add C++ ``CopyNode``. See `#16 `_. + - Add ``Copy`` symbol. See `#16 `_. fixes: - | Add missing C++ ``ReshapeNode::max()``, ``::min()``, and ``::integral()`` methods. diff --git a/tests/cpp/nodes/test_manipulation.cpp b/tests/cpp/nodes/test_manipulation.cpp index 657b42fe..9c856d2c 100644 --- a/tests/cpp/nodes/test_manipulation.cpp +++ b/tests/cpp/nodes/test_manipulation.cpp @@ -278,6 +278,110 @@ TEST_CASE("ConcatenateNode") { } } +TEST_CASE("CopyNode") { + GIVEN("x = IntegerNode({6}, 0, 10); y = x[::2]; c = CopyNode(y)") { + auto graph = Graph(); + + auto x_ptr = graph.emplace_node(std::initializer_list{6}, 0, 10); + auto y_ptr = graph.emplace_node(x_ptr, Slice(0, 10, 2)); + auto c_ptr = graph.emplace_node(y_ptr); + + graph.emplace_node(c_ptr); + + auto state = graph.empty_state(); + x_ptr->initialize_state(state, {0, 1, 2, 3, 4, 5}); + graph.initialize_state(state); + + THEN("c has the same shape as y and the same values") { + CHECK(std::ranges::equal(c_ptr->shape(state), y_ptr->shape(state))); + CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state))); + + CHECK(!y_ptr->contiguous()); + CHECK(c_ptr->contiguous()); + + CHECK(c_ptr->max() == y_ptr->max()); + CHECK(c_ptr->min() == y_ptr->min()); + CHECK(c_ptr->integral() == y_ptr->integral()); + } + + WHEN("We mutate the state of x") { + x_ptr->set_value(state, 2, 10); + graph.propagate(state, graph.descendants(state, {x_ptr})); + + THEN("c has the same shape as y and the same values") { + CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state))); + } + + AND_WHEN("we commit") { + graph.commit(state, graph.descendants(state, {x_ptr})); + + THEN("c has the same shape as y and the same values") { + CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state))); + } + } + + AND_WHEN("we revert") { + graph.revert(state, graph.descendants(state, {x_ptr})); + + THEN("c has the same shape as y and the same values") { + CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state))); + } + } + } + } + + GIVEN("x = SetNode(6); y = x[::2]; c = CopyNode(y)") { + auto graph = Graph(); + + auto x_ptr = graph.emplace_node(6); + auto y_ptr = graph.emplace_node(x_ptr, Slice(0, 10, 2)); + auto c_ptr = graph.emplace_node(y_ptr); + + graph.emplace_node(c_ptr); + + auto state = graph.empty_state(); + x_ptr->initialize_state(state, {0, 1, 2}); + graph.initialize_state(state); + + THEN("c has the same shape as y and the same values") { + CHECK(std::ranges::equal(c_ptr->shape(state), y_ptr->shape(state))); + CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state))); + + CHECK(!y_ptr->contiguous()); + CHECK(c_ptr->contiguous()); + + CHECK(c_ptr->max() == y_ptr->max()); + CHECK(c_ptr->min() == y_ptr->min()); + CHECK(c_ptr->integral() == y_ptr->integral()); + } + + WHEN("We mutate the state of x") { + x_ptr->grow(state); + graph.propagate(state, graph.descendants(state, {x_ptr})); + + THEN("c has the same shape as y and the same values") { + CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state))); + } + + AND_WHEN("we commit") { + graph.commit(state, graph.descendants(state, {x_ptr})); + + THEN("c has the same shape as y and the same values") { + CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state))); + } + } + + AND_WHEN("we revert") { + graph.revert(state, graph.descendants(state, {x_ptr})); + + THEN("c has the same shape as y and the same values") { + CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state))); + } + } + } + } +} + TEST_CASE("PutNode") { SECTION("a = [0, 1, 2, 3, 4], ind = [0, 2], v = [-44, -55], b = PutNode(a, ind, v)") { auto graph = Graph(); diff --git a/tests/test_symbols.py b/tests/test_symbols.py index d6a3719e..cbd21c02 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -959,6 +959,25 @@ def test_nonfinite_values(self): model.constant(np.array([0, 5, np.nan])) +class TestCopy(utils.SymbolTests): + def generate_symbols(self): + model = Model() + c = model.constant(np.arange(25).reshape(5, 5)) + c_copy = c.copy() + c_indexed_copy = c[::2, 1:4].copy() + with model.lock(): + yield c_copy + yield c_indexed_copy + + def test_simple(self): + model = Model() + c = model.constant(np.arange(25).reshape(5, 5)) + copy = c[::2, 1:4].copy() + model.states.resize(1) + with model.lock(): + np.testing.assert_array_equal(copy.state(), np.arange(25).reshape(5, 5)[::2, 1:4]) + + class TestDisjointBitSetsVariable(utils.SymbolTests): def test_inequality(self): # TODO re-enable this once equality has been fixed