Skip to content

Commit

Permalink
Add Copy symbol
Browse files Browse the repository at this point in the history
Closes #16
  • Loading branch information
arcondello committed Jan 30, 2025
1 parent d46e2d6 commit cb0fb2f
Show file tree
Hide file tree
Showing 12 changed files with 283 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/reference/symbols.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ are inherited by the :ref:`model symbols <symbols_model_symbols>`.

~ArraySymbol.all
~ArraySymbol.any
~ArraySymbol.copy
~ArraySymbol.flatten
~ArraySymbol.has_state
~ArraySymbol.max
~ArraySymbol.maybe_equals
Expand Down
1 change: 1 addition & 0 deletions dwave/optimization/_model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class ArraySymbol(Symbol):
def __truediv__(self, rhs: ArraySymbol) -> Divide: ...
def all(self) -> All: ...
def any(self) -> Any: ...
def copy(self) -> Copy: ...
def flatten(self) -> Reshape: ...
def max(self) -> Max: ...
def min(self) -> Min: ...
Expand Down
11 changes: 11 additions & 0 deletions dwave/optimization/_model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,17 @@ cdef class ArraySymbol(Symbol):
from dwave.optimization.symbols import Any # avoid circular import
return Any(self)

def copy(self):
"""Return an array symbol that is a copy of the array.
See Also:
:class:`~dwave.optimization.symbols.Copy` Equivalent class.
.. versionadded:: 0.5.1
"""
from dwave.optimization.symbols import Copy # avoid circular import
return Copy(self)

def flatten(self):
"""Return an array symbol collapsed into one dimension.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,55 @@ class ConcatenateNode : public ArrayOutputMixin<ArrayNode> {
std::vector<ssize_t> array_starts_;
};

/// An array node that is a contiguous copy of its predecessor.
class CopyNode : public ArrayOutputMixin<ArrayNode> {
public:
explicit CopyNode(ArrayNode* array_ptr);

/// @copydoc Array::buff()
double const* buff(const State& state) const override;

/// @copydoc Node::commit()
void commit(State& state) const override;

/// @copydoc Array::diff()
std::span<const Update> diff(const State& state) const override;

/// @copydoc Node::initialize_state()
void initialize_state(State& state) const override;

/// @copydoc Array::integral()
bool integral() const override;

/// @copydoc Array::max()
double max() const override;

/// @copydoc Array::min()
double min() const override;

/// @copydoc Node::propagate()
void propagate(State& state) const override;

/// @copydoc Node::revert()
void revert(State& state) const override;

using ArrayOutputMixin::shape;

/// @copydoc Array::shape()
std::span<const ssize_t> shape(const State& state) const override;

using ArrayOutputMixin::size;

/// @copydoc Array::size()
ssize_t size(const State& state) const override;

/// @copydoc Array::size_diff()
ssize_t size_diff(const State& state) const override;

private:
const Array* array_ptr_;
};

/// Replaces specified elements of an array with the given values.
///
/// The indexing works on the flattened array. Translated to NumPy, PutNode is
Expand Down
3 changes: 3 additions & 0 deletions dwave/optimization/libcpp/nodes.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ cdef extern from "dwave-optimization/nodes/manipulation.hpp" namespace "dwave::o
cdef cppclass ConcatenateNode(ArrayNode):
Py_ssize_t axis()

cdef cppclass CopyNode(ArrayNode):
pass

cdef cppclass PutNode(ArrayNode):
pass

Expand Down
8 changes: 8 additions & 0 deletions dwave/optimization/src/nodes/_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class ArrayStateData {
explicit ArrayStateData(std::vector<double>&& values) noexcept
: buffer(std::move(values)), previous_size_(buffer.size()) {}

template <std::ranges::range Range>
explicit ArrayStateData(Range&& values) noexcept
: ArrayStateData(std::vector<double>(values.begin(), values.end())) {}

// Assign new values to the state, tracking the changes from the previous state to the new
// one. Including resizes.
bool assign(std::ranges::sized_range auto&& values) {
Expand Down Expand Up @@ -193,6 +197,10 @@ class ArrayNodeStateData: public ArrayStateData, public NodeStateData {
explicit ArrayNodeStateData(std::vector<double>&& values) noexcept
: ArrayStateData(std::move(values)), NodeStateData() {}

template <std::ranges::range Range>
explicit ArrayNodeStateData(Range&& values) noexcept
: ArrayNodeStateData(std::vector<double>(values.begin(), values.end())) {}

std::unique_ptr<NodeStateData> copy() const override {
return std::make_unique<ArrayNodeStateData>(*this);
}
Expand Down
44 changes: 44 additions & 0 deletions dwave/optimization/src/nodes/manipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,50 @@ void ConcatenateNode::propagate(State& state) const {

void ConcatenateNode::revert(State& state) const { data_ptr<ArrayNodeStateData>(state)->revert(); }

CopyNode::CopyNode(ArrayNode* array_ptr)
: ArrayOutputMixin(array_ptr->shape()), array_ptr_(array_ptr) {
this->add_predecessor(array_ptr);
}

double const* CopyNode::buff(const State& state) const {
return data_ptr<ArrayNodeStateData>(state)->buff();
}

void CopyNode::commit(State& state) const { data_ptr<ArrayNodeStateData>(state)->commit(); }

std::span<const Update> CopyNode::diff(const State& state) const {
return data_ptr<ArrayNodeStateData>(state)->diff();
}

bool CopyNode::integral() const { return array_ptr_->integral(); }

void CopyNode::initialize_state(State& state) const {
int index = this->topological_index();
assert(index >= 0 && "must be topologically sorted");
assert(static_cast<int>(state.size()) > index && "unexpected state length");
assert(state[index] == nullptr && "already initialized state");

state[index] = std::make_unique<ArrayNodeStateData>(array_ptr_->view(state));
}

double CopyNode::max() const { return array_ptr_->max(); }

double CopyNode::min() const { return array_ptr_->min(); }

void CopyNode::propagate(State& state) const {
data_ptr<ArrayNodeStateData>(state)->update(array_ptr_->diff(state));
}

void CopyNode::revert(State& state) const { data_ptr<ArrayNodeStateData>(state)->revert(); }

std::span<const ssize_t> CopyNode::shape(const State& state) const {
return array_ptr_->shape(state);
}

ssize_t CopyNode::size(const State& state) const { return array_ptr_->size(state); }

ssize_t CopyNode::size_diff(const State& state) const { return array_ptr_->size_diff(state); }

// A PutNode needs to track its buffer as well as a mask of which elements in the
// original array are currently overwritten.
// We use ArrayStateData for the buffer
Expand Down
4 changes: 4 additions & 0 deletions dwave/optimization/symbols.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class Constant(ArraySymbol):
def __le__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ...


class Copy(ArraySymbol):
...


class DisjointBitSets(Symbol):
def set_state(self, index: int, state: numpy.typing.ArrayLike): ...

Expand Down
35 changes: 35 additions & 0 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ from dwave.optimization.libcpp.nodes cimport (
BinaryNode as cppBinaryNode,
ConcatenateNode as cppConcatenateNode,
ConstantNode as cppConstantNode,
CopyNode as cppCopyNode,
DisjointBitSetNode as cppDisjointBitSetNode,
DisjointBitSetsNode as cppDisjointBitSetsNode,
DisjointListNode as cppDisjointListNode,
Expand Down Expand Up @@ -110,6 +111,7 @@ __all__ = [
"BinaryVariable",
"Concatenate",
"Constant",
"Copy",
"DisjointBitSets",
"DisjointBitSet",
"DisjointLists",
Expand Down Expand Up @@ -1048,6 +1050,39 @@ cdef class Constant(ArraySymbol):
_register(Constant, typeid(cppConstantNode))


cdef class Copy(ArraySymbol):
"""An array symbol that is a copy of another array symbol.
See Also:
:meth:`ArraySymbol.copy` Equivalent method.
.. versionadded:: 0.5.1
"""
def __init__(self, ArraySymbol node):
cdef _Graph model = node.model

self.ptr = model._graph.emplace_node[cppCopyNode](
node.array_ptr,
)

self.initialize_arraynode(model, self.ptr)

@staticmethod
def _from_symbol(Symbol symbol):
cdef cppCopyNode* ptr = dynamic_cast_ptr[cppCopyNode](symbol.node_ptr)
if not ptr:
raise TypeError("given symbol cannot be used to construct a Copy")

cdef Copy m = Copy.__new__(Copy)
m.ptr = ptr
m.initialize_arraynode(symbol.model, ptr)
return m

cdef cppCopyNode* ptr

_register(Copy, typeid(cppCopyNode))


cdef class DisjointBitSets(Symbol):
"""Disjoint-sets decision-variable symbol.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ features:
- |
Support inferring the shape of one axis when reshaping array symbols by providing
``-1`` for the dimension's shape.
- Support reshaping non-contiguous array symbols.
- Add C++ ``CopyNode``. See `#16 <https://github.com/dwavesystems/dwave-optimization/issues/16>`_.
- Add ``Copy`` symbol. See `#16 <https://github.com/dwavesystems/dwave-optimization/issues/16>`_.
fixes:
- |
Add missing C++ ``ReshapeNode::max()``, ``::min()``, and ``::integral()`` methods.
Expand Down
104 changes: 104 additions & 0 deletions tests/cpp/nodes/test_manipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,110 @@ TEST_CASE("ConcatenateNode") {
}
}

TEST_CASE("CopyNode") {
GIVEN("x = IntegerNode({6}, 0, 10); y = x[::2]; c = CopyNode(y)") {
auto graph = Graph();

auto x_ptr = graph.emplace_node<IntegerNode>(std::initializer_list<ssize_t>{6}, 0, 10);
auto y_ptr = graph.emplace_node<BasicIndexingNode>(x_ptr, Slice(0, 10, 2));
auto c_ptr = graph.emplace_node<CopyNode>(y_ptr);

graph.emplace_node<ArrayValidationNode>(c_ptr);

auto state = graph.empty_state();
x_ptr->initialize_state(state, {0, 1, 2, 3, 4, 5});
graph.initialize_state(state);

THEN("c has the same shape as y and the same values") {
CHECK(std::ranges::equal(c_ptr->shape(state), y_ptr->shape(state)));
CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state)));

CHECK(!y_ptr->contiguous());
CHECK(c_ptr->contiguous());

CHECK(c_ptr->max() == y_ptr->max());
CHECK(c_ptr->min() == y_ptr->min());
CHECK(c_ptr->integral() == y_ptr->integral());
}

WHEN("We mutate the state of x") {
x_ptr->set_value(state, 2, 10);
graph.propagate(state, graph.descendants(state, {x_ptr}));

THEN("c has the same shape as y and the same values") {
CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state)));
}

AND_WHEN("we commit") {
graph.commit(state, graph.descendants(state, {x_ptr}));

THEN("c has the same shape as y and the same values") {
CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state)));
}
}

AND_WHEN("we revert") {
graph.revert(state, graph.descendants(state, {x_ptr}));

THEN("c has the same shape as y and the same values") {
CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state)));
}
}
}
}

GIVEN("x = SetNode(6); y = x[::2]; c = CopyNode(y)") {
auto graph = Graph();

auto x_ptr = graph.emplace_node<SetNode>(6);
auto y_ptr = graph.emplace_node<BasicIndexingNode>(x_ptr, Slice(0, 10, 2));
auto c_ptr = graph.emplace_node<CopyNode>(y_ptr);

graph.emplace_node<ArrayValidationNode>(c_ptr);

auto state = graph.empty_state();
x_ptr->initialize_state(state, {0, 1, 2});
graph.initialize_state(state);

THEN("c has the same shape as y and the same values") {
CHECK(std::ranges::equal(c_ptr->shape(state), y_ptr->shape(state)));
CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state)));

CHECK(!y_ptr->contiguous());
CHECK(c_ptr->contiguous());

CHECK(c_ptr->max() == y_ptr->max());
CHECK(c_ptr->min() == y_ptr->min());
CHECK(c_ptr->integral() == y_ptr->integral());
}

WHEN("We mutate the state of x") {
x_ptr->grow(state);
graph.propagate(state, graph.descendants(state, {x_ptr}));

THEN("c has the same shape as y and the same values") {
CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state)));
}

AND_WHEN("we commit") {
graph.commit(state, graph.descendants(state, {x_ptr}));

THEN("c has the same shape as y and the same values") {
CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state)));
}
}

AND_WHEN("we revert") {
graph.revert(state, graph.descendants(state, {x_ptr}));

THEN("c has the same shape as y and the same values") {
CHECK(std::ranges::equal(c_ptr->view(state), y_ptr->view(state)));
}
}
}
}
}

TEST_CASE("PutNode") {
SECTION("a = [0, 1, 2, 3, 4], ind = [0, 2], v = [-44, -55], b = PutNode(a, ind, v)") {
auto graph = Graph();
Expand Down
19 changes: 19 additions & 0 deletions tests/test_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,25 @@ def test_nonfinite_values(self):
model.constant(np.array([0, 5, np.nan]))


class TestCopy(utils.SymbolTests):
def generate_symbols(self):
model = Model()
c = model.constant(np.arange(25).reshape(5, 5))
c_copy = c.copy()
c_indexed_copy = c[::2, 1:4].copy()
with model.lock():
yield c_copy
yield c_indexed_copy

def test_simple(self):
model = Model()
c = model.constant(np.arange(25).reshape(5, 5))
copy = c[::2, 1:4].copy()
model.states.resize(1)
with model.lock():
np.testing.assert_array_equal(copy.state(), np.arange(25).reshape(5, 5)[::2, 1:4])


class TestDisjointBitSetsVariable(utils.SymbolTests):
def test_inequality(self):
# TODO re-enable this once equality has been fixed
Expand Down

0 comments on commit cb0fb2f

Please sign in to comment.