Skip to content

Commit

Permalink
Add ArraySymbol.flatten() method
Browse files Browse the repository at this point in the history
  • Loading branch information
arcondello committed Jan 30, 2025
1 parent 08e6608 commit d46e2d6
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 14 deletions.
1 change: 1 addition & 0 deletions dwave/optimization/_model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class ArraySymbol(Symbol):
def __truediv__(self, rhs: ArraySymbol) -> Divide: ...
def all(self) -> All: ...
def any(self) -> Any: ...
def flatten(self) -> Reshape: ...
def max(self) -> Max: ...
def min(self) -> Min: ...
def ndim(self) -> int: ...
Expand Down
7 changes: 7 additions & 0 deletions dwave/optimization/_model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,13 @@ cdef class ArraySymbol(Symbol):
from dwave.optimization.symbols import Any # avoid circular import
return Any(self)

def flatten(self):
    """Collapse the array symbol into a single dimension.

    This is a convenience method; it is equivalent to calling
    ``symbol.reshape(-1)``.
    """
    # -1 asks reshape() to infer the (single) dimension's size.
    return self.reshape(-1)

def max(self):
"""Create a :class:`~dwave.optimization.symbols.Max` symbol.
Expand Down
9 changes: 5 additions & 4 deletions dwave/optimization/include/dwave-optimization/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,10 +915,11 @@ class ArrayOutputMixin : public Base {
explicit ArrayOutputMixin(ssize_t n) : ArrayOutputMixin({n}) {}

explicit ArrayOutputMixin(std::initializer_list<ssize_t> shape)
: ndim_(shape.size()), shape_(make_shape(shape)) {}
: ArrayOutputMixin(std::span(shape)) {}

explicit ArrayOutputMixin(std::span<const ssize_t> shape)
: ndim_(shape.size()), shape_(make_shape(shape)) {}
template <std::ranges::sized_range Range>
explicit ArrayOutputMixin(Range&& shape)
: ndim_(shape.size()), shape_(make_shape(std::forward<Range>(shape))) {}

ssize_t ndim() const noexcept final { return ndim_; }

Expand All @@ -941,7 +942,7 @@ class ArrayOutputMixin : public Base {
constexpr bool contiguous() const noexcept final { return true; }

private:
template <class Range>
template <std::ranges::sized_range Range>
static std::unique_ptr<ssize_t[]> make_shape(Range&& shape) noexcept {
if (shape.size() == 0) return nullptr;
auto ptr = std::make_unique<ssize_t[]>(shape.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,43 @@ class PutNode : public ArrayOutputMixin<ArrayNode> {
const Array* values_ptr_;
};


/// Propagates the values of its predecessor, interpreted into a different shape.
class ReshapeNode : public ArrayOutputMixin<ArrayNode> {
public:
ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape);
/// Constructor for ReshapeNode.
///
/// @param array_ptr The array to be reshaped. May not be dynamic.
/// @param shape The new shape. Must have the same size as the original shape.
ReshapeNode(ArrayNode* array_ptr, std::vector<ssize_t>&& shape);

/// Constructor for ReshapeNode.
///
/// @param array_ptr The array to be reshaped. May not be dynamic.
/// @param shape The new shape. Must have the same size as the original shape.
template <std::ranges::range Range>
ReshapeNode(ArrayNode* array_ptr, Range&& shape)
        : ReshapeNode(array_ptr, std::vector<ssize_t>(shape.begin(), shape.end())) {}

/// @copydoc Array::buff()
double const* buff(const State& state) const override;

/// @copydoc Node::commit()
void commit(State& state) const override;

/// @copydoc Array::diff()
std::span<const Update> diff(const State& state) const override;

/// @copydoc Array::integral()
bool integral() const override;

/// @copydoc Array::max()
double max() const override;

/// @copydoc Array::min()
double min() const override;

/// @copydoc Node::revert()
void revert(State& state) const override;

private:
Expand Down
39 changes: 34 additions & 5 deletions dwave/optimization/src/nodes/manipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,34 @@ void PutNode::propagate(State& state) const {

// Roll this node's per-state data back to its last committed values.
void PutNode::revert(State& state) const { return data_ptr<PutNodeState>(state)->revert(); }

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
: ArrayOutputMixin(shape), array_ptr_(node_ptr) {
// Reshape allows one shape dimension to be -1. In that case the size is inferred.
// We do that inference here.
// File-local helper, so give it internal linkage.
static std::vector<ssize_t> infer_reshape(Array* array_ptr, std::vector<ssize_t>&& shape) {
    // If the base array is dynamic, we might allow the first dimension to be
    // -1. So let's defer to the various constructors.
    if (array_ptr->dynamic()) return std::move(shape);

    // Check if there are any -1s, and if not fall back to other input checking.
    auto it = std::ranges::find(shape, -1);
    if (it == shape.end()) return std::move(shape);

    // Get the product of the shape and negate it (to exclude the -1).
    // Accumulate as ssize_t (not int, which an `1` init would deduce) so large
    // shapes don't overflow the intermediate product.
    const auto prod =
            -std::reduce(shape.begin(), shape.end(), ssize_t{1}, std::multiplies<ssize_t>());

    // If the product is <= 0, then we have another negative number or a 0. In
    // which case we just fall back to other error checking.
    if (prod <= 0) return std::move(shape);

    // Ok, we can officially overwrite the -1.
    // Don't worry about the case that prod doesn't divide array_ptr->size();
    // other error checking will catch that case.
    *it = array_ptr->size() / prod;

    return std::move(shape);
}

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::vector<ssize_t>&& shape)
: ArrayOutputMixin(infer_reshape(node_ptr, std::move(shape))), array_ptr_(node_ptr) {
// Don't (yet) support non-contiguous predecessors.
// In some cases with non-contiguous predecessors we need to make a copy.
// See https://github.com/dwavesystems/dwave-optimization/issues/200
Expand Down Expand Up @@ -478,9 +504,6 @@ ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
this->add_predecessor(node_ptr);
}

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::vector<ssize_t>&& shape)
: ReshapeNode(node_ptr, std::span(shape)) {}

// Expose the predecessor's buffer directly -- the reshape stores no data of its own.
double const* ReshapeNode::buff(const State& state) const { return array_ptr_->buff(state); }

// Nothing to do on commit: ReshapeNode keeps no per-state data of its own.
void ReshapeNode::commit(State& state) const {} // stateless node
Expand All @@ -489,6 +512,12 @@ std::span<const Update> ReshapeNode::diff(const State& state) const {
return array_ptr_->diff(state);
}

// Reshaping does not change the values, so integrality is inherited from the predecessor.
bool ReshapeNode::integral() const { return array_ptr_->integral(); }

// Reshaping does not change the values, so the upper bound is the predecessor's.
double ReshapeNode::max() const { return array_ptr_->max(); }

// Reshaping does not change the values, so the lower bound is the predecessor's.
double ReshapeNode::min() const { return array_ptr_->min(); }

// Nothing to do on revert: ReshapeNode keeps no per-state data of its own.
void ReshapeNode::revert(State& state) const {} // stateless node

class SizeNodeData : public ScalarNodeStateData {
Expand Down
8 changes: 4 additions & 4 deletions dwave/optimization/symbols.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@ cdef object symbol_from_ptr(_Graph model, cppNode* node_ptr):
return cls._from_symbol(Symbol.from_ptr(model, node_ptr))


cdef vector[Py_ssize_t] _as_cppshape(object shape):
cdef vector[Py_ssize_t] _as_cppshape(object shape, bint nonnegative = True):
"""Convert a shape specified as a python object to a C++ vector."""

# Use the same error messages as NumPy

if isinstance(shape, numbers.Integral):
return _as_cppshape((shape,))
return _as_cppshape((shape,), nonnegative=nonnegative)

if not isinstance(shape, collections.abc.Sequence):
raise TypeError(f"expected a sequence of integers or a single integer, got '{repr(shape)}'")
Expand All @@ -212,7 +212,7 @@ cdef vector[Py_ssize_t] _as_cppshape(object shape):
if not all(isinstance(x, numbers.Integral) for x in shape):
raise ValueError(f"expected a sequence of integers or a single integer, got '{repr(shape)}'")

if any(x < 0 for x in shape):
if nonnegative and any(x < 0 for x in shape):
raise ValueError("negative dimensions are not allowed")

return shape
Expand Down Expand Up @@ -2829,7 +2829,7 @@ cdef class Reshape(ArraySymbol):

self.ptr = model._graph.emplace_node[cppReshapeNode](
node.array_ptr,
_as_cppshape(shape),
_as_cppshape(shape, nonnegative=False),
)

self.initialize_arraynode(model, self.ptr)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
features:
- |
Support both lvalue and rvalue ranges in the constructor of the C++
``ArrayOutputMixin`` class.
- Rework C++ ``ReshapeNode`` constructors to be more general.
- |
Support inferring the shape of one axis when reshaping array symbols by providing
``-1`` for the dimension's shape.
fixes:
- |
Add missing C++ ``ReshapeNode::max()``, ``::min()``, and ``::integral()`` methods.
Thereby allowing reshaped arrays to be used as indices.
43 changes: 43 additions & 0 deletions tests/cpp/nodes/test_manipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,49 @@ TEST_CASE("ReshapeNode") {
THEN("It has the shape/size/etc we expect") {
CHECK(B.ndim() == 1);
CHECK(std::ranges::equal(B.shape(), std::vector{12}));

CHECK(B.max() == A.max());
CHECK(B.min() == A.min());
CHECK(B.integral() == A.integral());
}
}

WHEN("It is reshaped without specifying the size of axis 0") {
auto B = ReshapeNode(&A, {-1});

THEN("It has the shape/size/etc we expect") {
CHECK(B.ndim() == 1);
CHECK(std::ranges::equal(B.shape(), std::vector{12}));

CHECK(B.max() == A.max());
CHECK(B.min() == A.min());
CHECK(B.integral() == A.integral());
}
}

WHEN("We reshape it into a 3x4 array explicitly") {
auto B = ReshapeNode(&A, {3, 4});

THEN("It has the shape/size/etc we expect") {
CHECK(B.ndim() == 2);
CHECK(std::ranges::equal(B.shape(), std::vector{3, 4}));

CHECK(B.max() == A.max());
CHECK(B.min() == A.min());
CHECK(B.integral() == A.integral());
}
}

WHEN("We reshape it into a 3x4 array implicitly") {
auto B = ReshapeNode(&A, {3, -1});

THEN("It has the shape/size/etc we expect") {
CHECK(B.ndim() == 2);
CHECK(std::ranges::equal(B.shape(), std::vector{3, 4}));

CHECK(B.max() == A.max());
CHECK(B.min() == A.min());
CHECK(B.integral() == A.integral());
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions tests/test_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,6 +2232,24 @@ def generate_symbols(self):
model.lock()
yield from syms

def test_implicit_reshape(self):
    """Reshaping with a -1 axis infers that axis's size from the array size."""
    model = Model()
    base = model.constant(np.arange(12).reshape(3, 4))
    inferred_cols = base.reshape(2, -1)   # columns inferred as 6
    inferred_rows = base.reshape(-1, 6)   # rows inferred as 2
    model.states.resize(1)
    expected = np.arange(12).reshape(2, 6)
    with model.lock():
        np.testing.assert_array_equal(inferred_cols.state(), expected)
        np.testing.assert_array_equal(inferred_rows.state(), expected)

def test_flatten(self):
    """flatten() collapses a 2d constant into the equivalent 1d array."""
    model = Model()
    matrix = model.constant(np.arange(25).reshape(5, 5))
    flat = matrix.flatten()
    model.states.resize(1)
    with model.lock():
        np.testing.assert_array_equal(flat.state(), np.arange(25))


class TestRint(utils.SymbolTests):
rng = np.random.default_rng(1)
Expand Down

0 comments on commit d46e2d6

Please sign in to comment.