Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve QOL when reshaping array symbols #218

Merged
merged 5 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/reference/symbols.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ are inherited by the :ref:`model symbols <symbols_model_symbols>`.

~ArraySymbol.all
~ArraySymbol.any
~ArraySymbol.copy
~ArraySymbol.flatten
~ArraySymbol.has_state
~ArraySymbol.max
~ArraySymbol.maybe_equals
Expand Down
2 changes: 2 additions & 0 deletions dwave/optimization/_model.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class ArraySymbol(Symbol):
def __truediv__(self, rhs: ArraySymbol) -> Divide: ...
def all(self) -> All: ...
def any(self) -> Any: ...
def copy(self) -> Copy: ...
def flatten(self) -> Reshape: ...
def max(self) -> Max: ...
def min(self) -> Min: ...
def ndim(self) -> int: ...
Expand Down
29 changes: 25 additions & 4 deletions dwave/optimization/_model.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,24 @@ cdef class ArraySymbol(Symbol):
from dwave.optimization.symbols import Any # avoid circular import
return Any(self)

def copy(self):
    """Return an array symbol that is a copy of the array.

    Returns:
        A :class:`~dwave.optimization.symbols.Copy` symbol whose state is a
        contiguous copy of this symbol's values.

    See Also:
        :class:`~dwave.optimization.symbols.Copy` Equivalent class.

    .. versionadded:: 0.5.1
    """
    from dwave.optimization.symbols import Copy  # avoid circular import
    return Copy(self)

def flatten(self):
    """Return an array symbol collapsed into one dimension.

    Equivalent to ``symbol.reshape(-1)``.

    See Also:
        :meth:`.reshape`
    """
    return self.reshape(-1)

def max(self):
"""Create a :class:`~dwave.optimization.symbols.Max` symbol.

Expand Down Expand Up @@ -1350,11 +1368,14 @@ cdef class ArraySymbol(Symbol):
(1, 3)
"""
from dwave.optimization.symbols import Reshape # avoid circular import
if len(shape) > 1:
return Reshape(self, shape)
else:
return Reshape(self, shape[0])
if len(shape) <= 1:
shape = shape[0]

if not self.array_ptr.contiguous():
return Reshape(self.copy(), shape)

return Reshape(self, shape)

def shape(self):
"""Return the shape of the symbol.

Expand Down
9 changes: 5 additions & 4 deletions dwave/optimization/include/dwave-optimization/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,10 +915,11 @@ class ArrayOutputMixin : public Base {
explicit ArrayOutputMixin(ssize_t n) : ArrayOutputMixin({n}) {}

explicit ArrayOutputMixin(std::initializer_list<ssize_t> shape)
: ndim_(shape.size()), shape_(make_shape(shape)) {}
: ArrayOutputMixin(std::span(shape)) {}

explicit ArrayOutputMixin(std::span<const ssize_t> shape)
: ndim_(shape.size()), shape_(make_shape(shape)) {}
template <std::ranges::sized_range Range>
explicit ArrayOutputMixin(Range&& shape)
: ndim_(shape.size()), shape_(make_shape(std::forward<Range>(shape))) {}

ssize_t ndim() const noexcept final { return ndim_; }

Expand All @@ -941,7 +942,7 @@ class ArrayOutputMixin : public Base {
constexpr bool contiguous() const noexcept final { return true; }

private:
template <class Range>
template <std::ranges::sized_range Range>
static std::unique_ptr<ssize_t[]> make_shape(Range&& shape) noexcept {
if (shape.size() == 0) return nullptr;
auto ptr = std::make_unique<ssize_t[]>(shape.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,55 @@ class ConcatenateNode : public ArrayOutputMixin<ArrayNode> {
std::vector<ssize_t> array_starts_;
};

/// An array node that is a contiguous copy of its predecessor.
class CopyNode : public ArrayOutputMixin<ArrayNode> {
 public:
    /// Construct a node whose output is a contiguous copy of ``array_ptr``'s
    /// values. ``array_ptr`` is also registered as this node's predecessor.
    explicit CopyNode(ArrayNode* array_ptr);

    /// @copydoc Array::buff()
    double const* buff(const State& state) const override;

    /// @copydoc Node::commit()
    void commit(State& state) const override;

    /// @copydoc Array::diff()
    std::span<const Update> diff(const State& state) const override;

    /// @copydoc Node::initialize_state()
    void initialize_state(State& state) const override;

    /// @copydoc Array::integral()
    bool integral() const override;

    /// @copydoc Array::max()
    double max() const override;

    /// @copydoc Array::min()
    double min() const override;

    /// @copydoc Node::propagate()
    void propagate(State& state) const override;

    /// @copydoc Node::revert()
    void revert(State& state) const override;

    using ArrayOutputMixin::shape;

    /// @copydoc Array::shape()
    std::span<const ssize_t> shape(const State& state) const override;

    using ArrayOutputMixin::size;

    /// @copydoc Array::size()
    ssize_t size(const State& state) const override;

    /// @copydoc Array::size_diff()
    ssize_t size_diff(const State& state) const override;

 private:
    // Non-owning pointer to the predecessor array whose values are copied.
    const Array* array_ptr_;
};

/// Replaces specified elements of an array with the given values.
///
/// The indexing works on the flattened array. Translated to NumPy, PutNode is
Expand Down Expand Up @@ -104,14 +153,43 @@ class PutNode : public ArrayOutputMixin<ArrayNode> {
const Array* values_ptr_;
};


/// Propagates the values of its predecessor, interpreted into a different shape.
class ReshapeNode : public ArrayOutputMixin<ArrayNode> {
public:
ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape);
/// Constructor for ReshapeNode.
///
/// @param array_ptr The array to be reshaped. May not be dynamic.
/// @param shape The new shape. Must have the same size as the original shape.
ReshapeNode(ArrayNode* array_ptr, std::vector<ssize_t>&& shape);

/// Constructor for ReshapeNode.
///
/// @param array_ptr The array to be reshaped. May not be dynamic.
/// @param shape The new shape. Must have the same size as the original shape.
template <std::ranges::range Range>
ReshapeNode(ArrayNode* node_ptr, Range&& shape)
: ReshapeNode(node_ptr, std::vector<ssize_t>(shape.begin(), shape.end())) {}

/// @copydoc Array::buff()
double const* buff(const State& state) const override;

/// @copydoc Node::commit()
void commit(State& state) const override;

/// @copydoc Array::diff()
std::span<const Update> diff(const State& state) const override;

/// @copydoc Array::integral()
bool integral() const override;

/// @copydoc Array::max()
double max() const override;

/// @copydoc Array::min()
double min() const override;

/// @copydoc Node::revert()
void revert(State& state) const override;

private:
Expand Down
1 change: 1 addition & 0 deletions dwave/optimization/libcpp/array.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ __all__ = ["Array"]
cdef extern from "dwave-optimization/array.hpp" namespace "dwave::optimization" nogil:
cdef cppclass Array:
double* buff(State&)
bint contiguous() const
bint dynamic() const
const string& format() const
Py_ssize_t itemsize() const
Expand Down
3 changes: 3 additions & 0 deletions dwave/optimization/libcpp/nodes.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ cdef extern from "dwave-optimization/nodes/manipulation.hpp" namespace "dwave::o
cdef cppclass ConcatenateNode(ArrayNode):
Py_ssize_t axis()

cdef cppclass CopyNode(ArrayNode):
pass

cdef cppclass PutNode(ArrayNode):
pass

Expand Down
8 changes: 8 additions & 0 deletions dwave/optimization/src/nodes/_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class ArrayStateData {
explicit ArrayStateData(std::vector<double>&& values) noexcept
: buffer(std::move(values)), previous_size_(buffer.size()) {}

template <std::ranges::range Range>
explicit ArrayStateData(Range&& values) noexcept
: ArrayStateData(std::vector<double>(values.begin(), values.end())) {}

// Assign new values to the state, tracking the changes from the previous state to the new
// one. Including resizes.
bool assign(std::ranges::sized_range auto&& values) {
Expand Down Expand Up @@ -193,6 +197,10 @@ class ArrayNodeStateData: public ArrayStateData, public NodeStateData {
explicit ArrayNodeStateData(std::vector<double>&& values) noexcept
: ArrayStateData(std::move(values)), NodeStateData() {}

template <std::ranges::range Range>
explicit ArrayNodeStateData(Range&& values) noexcept
: ArrayNodeStateData(std::vector<double>(values.begin(), values.end())) {}

std::unique_ptr<NodeStateData> copy() const override {
return std::make_unique<ArrayNodeStateData>(*this);
}
Expand Down
89 changes: 84 additions & 5 deletions dwave/optimization/src/nodes/manipulation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,50 @@ void ConcatenateNode::propagate(State& state) const {

void ConcatenateNode::revert(State& state) const { data_ptr<ArrayNodeStateData>(state)->revert(); }

// Inherit the predecessor's (compile-time) shape and register it as a predecessor.
CopyNode::CopyNode(ArrayNode* array_ptr)
        : ArrayOutputMixin(array_ptr->shape()), array_ptr_(array_ptr) {
    this->add_predecessor(array_ptr);
}

// Unlike ReshapeNode, CopyNode owns its buffer (via its state), which is what
// makes its output contiguous regardless of the predecessor's layout.
double const* CopyNode::buff(const State& state) const {
    return data_ptr<ArrayNodeStateData>(state)->buff();
}

void CopyNode::commit(State& state) const { data_ptr<ArrayNodeStateData>(state)->commit(); }

std::span<const Update> CopyNode::diff(const State& state) const {
    return data_ptr<ArrayNodeStateData>(state)->diff();
}

// A copy is integral exactly when its predecessor is.
bool CopyNode::integral() const { return array_ptr_->integral(); }

// Materialize the state as a copy of the predecessor's current values.
void CopyNode::initialize_state(State& state) const {
    int index = this->topological_index();
    assert(index >= 0 && "must be topologically sorted");
    assert(static_cast<int>(state.size()) > index && "unexpected state length");
    assert(state[index] == nullptr && "already initialized state");

    // view(state) iterates the predecessor's values in logical order, so the
    // resulting buffer is a contiguous copy.
    state[index] = std::make_unique<ArrayNodeStateData>(array_ptr_->view(state));
}

// Value bounds are inherited from the predecessor unchanged.
double CopyNode::max() const { return array_ptr_->max(); }

double CopyNode::min() const { return array_ptr_->min(); }

// Apply the predecessor's pending updates to our own buffer.
void CopyNode::propagate(State& state) const {
    data_ptr<ArrayNodeStateData>(state)->update(array_ptr_->diff(state));
}

void CopyNode::revert(State& state) const { data_ptr<ArrayNodeStateData>(state)->revert(); }

// The runtime shape/size always track the predecessor's.
std::span<const ssize_t> CopyNode::shape(const State& state) const {
    return array_ptr_->shape(state);
}

ssize_t CopyNode::size(const State& state) const { return array_ptr_->size(state); }

ssize_t CopyNode::size_diff(const State& state) const { return array_ptr_->size_diff(state); }

// A PutNode needs to track its buffer as well as a mask of which elements in the
// original array are currently overwritten.
// We use ArrayStateData for the buffer
Expand Down Expand Up @@ -442,8 +486,34 @@ void PutNode::propagate(State& state) const {

void PutNode::revert(State& state) const { return data_ptr<PutNodeState>(state)->revert(); }

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
: ArrayOutputMixin(shape), array_ptr_(node_ptr) {
// Reshape allows one shape dimension to be -1. In that case the size is inferred.
// We do that inference here.
//
// @param array_ptr The array being reshaped.
// @param shape The requested shape. At most one dimension may be -1; any other
//     invalid shape is returned unchanged and left for downstream checks.
// @returns The shape with a single -1 (if any) replaced by the inferred extent.
std::vector<ssize_t> infer_reshape(Array* array_ptr, std::vector<ssize_t>&& shape) {
    // if the base array is dynamic, we might allow the first dimension to be
    // negative 1. So let's defer to the various constructors.
    if (array_ptr->dynamic()) return shape;

    // Check if there are any -1s, and if not fall back to other input checking.
    auto it = std::ranges::find(shape, -1);
    if (it == shape.end()) return shape;

    // Get the product of the shape and negate it (to exclude the -1).
    // Accumulate as ssize_t: an `int` init value would make std::reduce
    // accumulate (and potentially overflow/truncate) in int.
    const auto prod = -std::reduce(shape.begin(), shape.end(), ssize_t{1},
                                   std::multiplies<ssize_t>());

    // If the product is <= 0, then we have another negative number or a 0. In
    // which case we just fall back to other error checking.
    if (prod <= 0) return shape;

    // Ok, we can officially overwrite the -1.
    // Don't worry about the case that prod doesn't divide array_ptr->size(),
    // other error checking will catch that case.
    *it = array_ptr->size() / prod;

    return shape;
}

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::vector<ssize_t>&& shape)
: ArrayOutputMixin(infer_reshape(node_ptr, std::move(shape))), array_ptr_(node_ptr) {
// Don't (yet) support non-contiguous predecessors.
// In some cases with non-contiguous predecessors we need to make a copy.
// See https://github.com/dwavesystems/dwave-optimization/issues/200
Expand All @@ -468,6 +538,12 @@ ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
throw std::invalid_argument("cannot reshape to a dynamic array");
}

// one -1 was already replaced by infer_shape
if (std::ranges::any_of(this->shape() | std::views::drop(1),
[](const ssize_t& dim) { return dim < 0; })) {
throw std::invalid_argument("can only specify one unknown dimension");
}

if (this->size() != array_ptr_->size()) {
// Use the same error message as NumPy
throw std::invalid_argument("cannot reshape array of size " +
Expand All @@ -478,9 +554,6 @@ ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::span<const ssize_t> shape)
this->add_predecessor(node_ptr);
}

ReshapeNode::ReshapeNode(ArrayNode* node_ptr, std::vector<ssize_t>&& shape)
: ReshapeNode(node_ptr, std::span(shape)) {}

// Reshape is a view: it exposes the predecessor's buffer directly.
double const* ReshapeNode::buff(const State& state) const { return array_ptr_->buff(state); }

void ReshapeNode::commit(State& state) const {}  // stateless node
Expand All @@ -489,6 +562,12 @@ std::span<const Update> ReshapeNode::diff(const State& state) const {
return array_ptr_->diff(state);
}

// Reshaping reinterprets the values without changing them, so these
// properties are forwarded from the predecessor unchanged.
bool ReshapeNode::integral() const { return array_ptr_->integral(); }

double ReshapeNode::max() const { return array_ptr_->max(); }

double ReshapeNode::min() const { return array_ptr_->min(); }

void ReshapeNode::revert(State& state) const {}  // stateless node

class SizeNodeData : public ScalarNodeStateData {
Expand Down
4 changes: 4 additions & 0 deletions dwave/optimization/symbols.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class Constant(ArraySymbol):
def __le__(self, rhs: numpy.typing.ArrayLike) -> numpy.typing.NDArray[numpy.bool]: ...


class Copy(ArraySymbol):
    """An array symbol that is a contiguous copy of its predecessor."""
    ...


class DisjointBitSets(Symbol):
def set_state(self, index: int, state: numpy.typing.ArrayLike): ...

Expand Down
Loading