Skip to content

Commit

Permalink
Merge pull request #218 from arcondello/feature/ArraySymbol.flatten
Browse files Browse the repository at this point in the history
Improve QOL when reshaping array symbols
  • Loading branch information
arcondello authored Jan 31, 2025
2 parents 08e6608 + 6439503 commit a000537
Show file tree
Hide file tree
Showing 14 changed files with 470 additions and 18 deletions.
2 changes: 2 additions & 0 deletions docs/reference/symbols.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ are inherited by the :ref:`model symbols <symbols_model_symbols>`.

~ArraySymbol.all
~ArraySymbol.any
~ArraySymbol.copy
~ArraySymbol.flatten
~ArraySymbol.has_state
~ArraySymbol.max
~ArraySymbol.maybe_equals
Expand Down
2 changes: 2 additions & 0 deletions dwave/optimization/_model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class ArraySymbol(Symbol):
def __truediv__(self, rhs: ArraySymbol) -> Divide: ...
def all(self) -> All: ...
def any(self) -> Any: ...
def copy(self) -> Copy: ...
def flatten(self) -> Reshape: ...
def max(self) -> Max: ...
def min(self) -> Min: ...
def ndim(self) -> int: ...
Expand Down
29 changes: 25 additions & 4 deletions dwave/optimization/_model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,24 @@ cdef class ArraySymbol(Symbol):
from dwave.optimization.symbols import Any # avoid circular import
return Any(self)

def copy(self):
"""Return an array symbol that is a copy of the array.
See Also:
:class:`~dwave.optimization.symbols.Copy` Equivalent class.
.. versionadded:: 0.5.1
"""
from dwave.optimization.symbols import Copy # avoid circular import
return Copy(self)

def flatten(self):
"""Return an array symbol collapsed into one dimension.
Equivalent to ``symbol.reshape(-1)``.
"""
return self.reshape(-1)

def max(self):
"""Create a :class:`~dwave.optimization.symbols.Max` symbol.
Expand Down Expand Up @@ -1350,11 +1368,14 @@ cdef class ArraySymbol(Symbol):
(1, 3)
"""
from dwave.optimization.symbols import Reshape # avoid circular import
if len(shape) > 1:
return Reshape(self, shape)
else:
return Reshape(self, shape[0])
if len(shape) <= 1:
shape = shape[0]

if not self.array_ptr.contiguous():
return Reshape(self.copy(), shape)

return Reshape(self, shape)

def shape(self):
"""Return the shape of the symbol.
Expand Down
9 changes: 5 additions & 4 deletions dwave/optimization/include/dwave-optimization/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,10 +915,11 @@ class ArrayOutputMixin : public Base {
explicit ArrayOutputMixin(ssize_t n) : ArrayOutputMixin({n}) {}

explicit ArrayOutputMixin(std::initializer_list<ssize_t> shape)
: ndim_(shape.size()), shape_(make_shape(shape)) {}
: ArrayOutputMixin(std::span(shape)) {}

explicit ArrayOutputMixin(std::span<const ssize_t> shape)
: ndim_(shape.size()), shape_(make_shape(shape)) {}
template <std::ranges::sized_range Range>
explicit ArrayOutputMixin(Range&& shape)
: ndim_(shape.size()), shape_(make_shape(std::forward<Range>(shape))) {}

ssize_t ndim() const noexcept final { return ndim_; }

Expand All @@ -941,7 +942,7 @@ class ArrayOutputMixin : public Base {
constexpr bool contiguous() const noexcept final { return true; }

private:
template <class Range>
template <std::ranges::sized_range Range>
static std::unique_ptr<ssize_t[]> make_shape(Range&& shape) noexcept {
if (shape.size() == 0) return nullptr;
auto ptr = std::make_unique<ssize_t[]>(shape.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,55 @@ class ConcatenateNode : public ArrayOutputMixin<ArrayNode> {
std::vector<ssize_t> array_starts_;
};

/// An array node that is a contiguous copy of its predecessor.
class CopyNode : public ArrayOutputMixin<ArrayNode> {
public:
explicit CopyNode(ArrayNode* array_ptr);

/// @copydoc Array::buff()
double const* buff(const State& state) const override;

/// @copydoc Node::commit()
void commit(State& state) const override;

/// @copydoc Array::diff()
std::span<const Update> diff(const State& state) const override;

/// @copydoc Node::initialize_state()
void initialize_state(State& state) const override;

/// @copydoc Array::integral()
bool integral() const override;

/// @copydoc Array::max()
double max() const override;

/// @copydoc Array::min()
double min() const override;

/// @copydoc Node::propagate()
void propagate(State& state) const override;

/// @copydoc Node::revert()
void revert(State& state) const override;

using ArrayOutputMixin::shape;

/// @copydoc Array::shape()
std::span<const ssize_t> shape(const State& state) const override;

using ArrayOutputMixin::size;

/// @copydoc Array::size()
ssize_t size(const State& state) const override;

/// @copydoc Array::size_diff()
ssize_t size_diff(const State& state) const override;

private:
const Array* array_ptr_;
};

/// Replaces specified elements of an array with the given values.
///
/// The indexing works on the flattened array. Translated to NumPy, PutNode is
Expand Down Expand Up @@ -104,14 +153,43 @@ class PutNode : public ArrayOutputMixin<ArrayNode> {
const Array* values_ptr_;
};


/// Propagates the values of its predecessor, interpreted into a different shape.
class ReshapeNode : public ArrayOutputMixin<ArrayNode> {
public:
ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape);
/// Constructor for ReshapeNode.
///
/// @param array_ptr The array to be reshaped. May not be dynamic.
/// @param shape The new shape. Must have the same size as the original shape.
ReshapeNode(ArrayNode* array_ptr, std::vector<ssize_t>&& shape);

/// Constructor for ReshapeNode.
///
/// @param array_ptr The array to be reshaped. May not be dynamic.
/// @param shape The new shape. Must have the same size as the original shape.
template <std::ranges::range Range>
ReshapeNode(ArrayNode* node_ptr, Range&& shape)
: ReshapeNode(node_ptr, std::vector<ssize_t>(shape.begin(), shape.end())) {}

/// @copydoc Array::buff()
double const* buff(const State& state) const override;

/// @copydoc Node::commit()
void commit(State& state) const override;

/// @copydoc Array::diff()
std::span<const Update> diff(const State& state) const override;

/// @copydoc Array::integral()
bool integral() const override;

/// @copydoc Array::max()
double max() const override;

/// @copydoc Array::min()
double min() const override;

/// @copydoc Node::revert()
void revert(State& state) const override;

private:
Expand Down
1 change: 1 addition & 0 deletions dwave/optimization/libcpp/array.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ __all__ = ["Array"]
cdef extern from "dwave-optimization/array.hpp" namespace "dwave::optimization" nogil:
cdef cppclass Array:
double* buff(State&)
bint contiguous() const
bint dynamic() const
const string& format() const
Py_ssize_t itemsize() const
Expand Down
3 changes: 3 additions & 0 deletions dwave/optimization/libcpp/nodes.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ cdef extern from "dwave-optimization/nodes/manipulation.hpp" namespace "dwave::o
cdef cppclass ConcatenateNode(ArrayNode):
Py_ssize_t axis()

cdef cppclass CopyNode(ArrayNode):
pass

cdef cppclass PutNode(ArrayNode):
pass

Expand Down
8 changes: 8 additions & 0 deletions dwave/optimization/src/nodes/_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class ArrayStateData {
explicit ArrayStateData(std::vector<double>&& values) noexcept
: buffer(std::move(values)), previous_size_(buffer.size()) {}

template <std::ranges::range Range>
explicit ArrayStateData(Range&& values) noexcept
: ArrayStateData(std::vector<double>(values.begin(), values.end())) {}

// Assign new values to the state, tracking the changes from the previous state to the new
// one. Including resizes.
bool assign(std::ranges::sized_range auto&& values) {
Expand Down Expand Up @@ -193,6 +197,10 @@ class ArrayNodeStateData: public ArrayStateData, public NodeStateData {
explicit ArrayNodeStateData(std::vector<double>&& values) noexcept
: ArrayStateData(std::move(values)), NodeStateData() {}

template <std::ranges::range Range>
explicit ArrayNodeStateData(Range&& values) noexcept
: ArrayNodeStateData(std::vector<double>(values.begin(), values.end())) {}

std::unique_ptr<NodeStateData> copy() const override {
return std::make_unique<ArrayNodeStateData>(*this);
}
Expand Down
89 changes: 84 additions & 5 deletions dwave/optimization/src/nodes/manipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,50 @@ void ConcatenateNode::propagate(State& state) const {

void ConcatenateNode::revert(State& state) const { data_ptr<ArrayNodeStateData>(state)->revert(); }

CopyNode::CopyNode(ArrayNode* array_ptr)
: ArrayOutputMixin(array_ptr->shape()), array_ptr_(array_ptr) {
this->add_predecessor(array_ptr);
}

double const* CopyNode::buff(const State& state) const {
return data_ptr<ArrayNodeStateData>(state)->buff();
}

void CopyNode::commit(State& state) const { data_ptr<ArrayNodeStateData>(state)->commit(); }

std::span<const Update> CopyNode::diff(const State& state) const {
return data_ptr<ArrayNodeStateData>(state)->diff();
}

bool CopyNode::integral() const { return array_ptr_->integral(); }

void CopyNode::initialize_state(State& state) const {
int index = this->topological_index();
assert(index >= 0 && "must be topologically sorted");
assert(static_cast<int>(state.size()) > index && "unexpected state length");
assert(state[index] == nullptr && "already initialized state");

state[index] = std::make_unique<ArrayNodeStateData>(array_ptr_->view(state));
}

double CopyNode::max() const { return array_ptr_->max(); }

double CopyNode::min() const { return array_ptr_->min(); }

void CopyNode::propagate(State& state) const {
data_ptr<ArrayNodeStateData>(state)->update(array_ptr_->diff(state));
}

void CopyNode::revert(State& state) const { data_ptr<ArrayNodeStateData>(state)->revert(); }

std::span<const ssize_t> CopyNode::shape(const State& state) const {
return array_ptr_->shape(state);
}

ssize_t CopyNode::size(const State& state) const { return array_ptr_->size(state); }

ssize_t CopyNode::size_diff(const State& state) const { return array_ptr_->size_diff(state); }

// A PutNode needs to track its buffer as well as a mask of which elements in the
// original array are currently overwritten.
// We use ArrayStateData for the buffer
Expand Down Expand Up @@ -442,8 +486,34 @@ void PutNode::propagate(State& state) const {

void PutNode::revert(State& state) const { return data_ptr<PutNodeState>(state)->revert(); }

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
: ArrayOutputMixin(shape), array_ptr_(node_ptr) {
// Reshape allows one shape dimension to be -1. In that case the size is inferred.
// We do that inference here.
std::vector<ssize_t> infer_reshape(Array* array_ptr, std::vector<ssize_t>&& shape) {
// if the base array is dynamic, we might allow the first dimension to be negative
// 1. So let's defer to the various constructors.
if (array_ptr->dynamic()) return shape;

// Check if there are any -1s, and if not fallback to other input checking.
auto it = std::ranges::find(shape, -1);
if (it == shape.end()) return shape;

// Get the product of the shape and negate it (to exclude the -1)
auto prod = -std::reduce(shape.begin(), shape.end(), 1, std::multiplies<ssize_t>());

// If the product is <=0, then we have another negative number or a 0. In which
// case we just fall back to other error checking.
if (prod <= 0) return shape;

// Ok, we can officially overwrite the -1.
// Don't worry about the case that prod doesn't divide array_ptr->size(), other
// error checking will catch that case.
*it = array_ptr->size() / prod;

return shape;
}

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::vector<ssize_t>&& shape)
: ArrayOutputMixin(infer_reshape(node_ptr, std::move(shape))), array_ptr_(node_ptr) {
// Don't (yet) support non-contiguous predecessors.
// In some cases with non-contiguous predecessors we need to make a copy.
// See https://github.com/dwavesystems/dwave-optimization/issues/200
Expand All @@ -468,6 +538,12 @@ ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
throw std::invalid_argument("cannot reshape to a dynamic array");
}

// one -1 was already replaced by infer_shape
if (std::ranges::any_of(this->shape() | std::views::drop(1),
[](const ssize_t& dim) { return dim < 0; })) {
throw std::invalid_argument("can only specify one unknown dimension");
}

if (this->size() != array_ptr_->size()) {
// Use the same error message as NumPy
throw std::invalid_argument("cannot reshape array of size " +
Expand All @@ -478,9 +554,6 @@ ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
this->add_predecessor(node_ptr);
}

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::vector<ssize_t>&& shape)
: ReshapeNode(node_ptr, std::span(shape)) {}

double const* ReshapeNode::buff(const State& state) const { return array_ptr_->buff(state); }

void ReshapeNode::commit(State& state) const {} // stateless node
Expand All @@ -489,6 +562,12 @@ std::span<const Update> ReshapeNode::diff(const State& state) const {
return array_ptr_->diff(state);
}

bool ReshapeNode::integral() const { return array_ptr_->integral(); }

double ReshapeNode::max() const { return array_ptr_->max(); }

double ReshapeNode::min() const { return array_ptr_->min(); }

void ReshapeNode::revert(State& state) const {} // stateless node

class SizeNodeData : public ScalarNodeStateData {
Expand Down
4 changes: 4 additions & 0 deletions dwave/optimization/symbols.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class Constant(ArraySymbol):
def __le__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ...


class Copy(ArraySymbol):
...


class DisjointBitSets(Symbol):
def set_state(self, index: int, state: numpy.typing.ArrayLike): ...

Expand Down
Loading

0 comments on commit a000537

Please sign in to comment.