Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permutation #189

Draft
wants to merge 30 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
b17edb8
backup [skip ci]
ryanmrichard Dec 5, 2024
e5cd537
backup [skip ci]
ryanmrichard Dec 6, 2024
c5b1612
downcast works
ryanmrichard Dec 8, 2024
1a923c8
starts dispatcher test
ryanmrichard Dec 8, 2024
0857a7c
Committing clang-format changes
github-actions[bot] Dec 8, 2024
4b06596
backup
ryanmrichard Dec 9, 2024
bd66615
removes old attempt
ryanmrichard Dec 9, 2024
0316e9c
refactor changes
ryanmrichard Dec 9, 2024
3beda27
adds rebind to the allocator
ryanmrichard Dec 10, 2024
ede020c
addition_assignment tested
ryanmrichard Dec 10, 2024
763de80
Committing clang-format changes
github-actions[bot] Dec 10, 2024
195509c
buffer base works for addition
ryanmrichard Dec 10, 2024
8bb09ee
Merge branch 'basic_math' of https://github.com/NWChemEx/TensorWrappe…
ryanmrichard Dec 10, 2024
1a68ba1
permutation almost works...
ryanmrichard Dec 11, 2024
b27c4d1
Committing clang-format changes
github-actions[bot] Dec 11, 2024
4919325
punt on permutation
ryanmrichard Dec 12, 2024
e3d921a
remove unused variable
ryanmrichard Dec 12, 2024
5ab8615
rename parser, try to fix gcc error
ryanmrichard Dec 12, 2024
5fd4cc7
and another conversion...
ryanmrichard Dec 12, 2024
ad900a7
refactor
ryanmrichard Dec 12, 2024
0532864
Merge branch 'master' into permutation
ryanmrichard Dec 12, 2024
31758f3
fix conflict
ryanmrichard Dec 12, 2024
cb9cc4e
backup [skip ci]
ryanmrichard Dec 12, 2024
8ed497a
back up [skip ci]
ryanmrichard Dec 13, 2024
eb6354d
Merge branch 'master' into permutation
ryanmrichard Dec 13, 2024
2382e84
Committing clang-format changes
github-actions[bot] Dec 14, 2024
acf1bf5
backup [skip ci]
ryanmrichard Dec 15, 2024
d0901b6
Merge branch 'permutation' of https://github.com/NWChemEx/TensorWrapp…
ryanmrichard Dec 15, 2024
dcda919
backup [skip ci]
ryanmrichard Dec 31, 2024
2c142f9
backup [skip ci]
ryanmrichard Jan 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/tensorwrapper/allocator/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ class Eigen : public Replicated {
return std::make_unique<my_type>(*this);
}

base_reference assign_(const_base_reference rhs) override {
return my_base_type::assign_impl_<my_type>(rhs);
}

/// Implements are_equal, by deferring to the base's operator==
bool are_equal_(const_base_reference rhs) const noexcept override {
return my_base_type::are_equal_impl_<my_type>(rhs);
Expand Down
132 changes: 71 additions & 61 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/dsl/labeled.hpp>
#include <tensorwrapper/layout/layout_base.hpp>
Expand All @@ -25,7 +26,8 @@ namespace tensorwrapper::buffer {
*
* All classes which wrap existing tensor libraries derive from this class.
*/
class BufferBase : public detail_::PolymorphicBase<BufferBase> {
class BufferBase : public detail_::PolymorphicBase<BufferBase>,
public detail_::DSLBase<BufferBase> {
private:
/// Type of *this
using my_type = BufferBase;
Expand Down Expand Up @@ -60,13 +62,16 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
using layout_pointer = typename layout_type::layout_pointer;

/// Type of labels for making a labeled buffer
using label_type = std::string;
using string_type = std::string;

/// Type of a labeled buffer
using labeled_buffer_type = dsl::Labeled<buffer_base_type, label_type>;
using labeled_buffer_type = dsl::Labeled<buffer_base_type, string_type>;

/// Type of a labeled read-only buffer (n.b. labels are mutable)
using labeled_const_buffer_type = dsl::Labeled<const buffer_base_type>;
using labeled_const_buffer_type =
dsl::Labeled<const buffer_base_type, string_type>;

using label_type = typename labeled_buffer_type::label_type;

/// Type of a read-only reference to a labeled_buffer_type object
using const_labeled_buffer_reference = const labeled_const_buffer_type&;
Expand Down Expand Up @@ -116,13 +121,13 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
* @param[in] this_labels The labels to associate with the modes of *this.
* @param[in] rhs The buffer to add into *this.
*
* @throws std::runtimer_error if *this does not have a layout. Strong
* throw guarantee.
* @throws ??? Throws if the derived class's implementation throws. Same
* throw guarantee.
*/
buffer_base_reference addition_assignment(
label_type this_labels, const_labeled_buffer_reference rhs) {
return addition_assignment_(std::move(this_labels), rhs);
}
// buffer_base_reference addition_assignment(
// label_type this_labels, const_labeled_buffer_reference rhs);

/** @brief Returns the result of *this + rhs.
*
Expand All @@ -139,12 +144,12 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
* @throw ??? If addition_assignment throws when adding @p rhs to the
* copy of *this. Same throw guarantee.
*/
buffer_base_pointer addition(label_type this_labels,
const_labeled_buffer_reference rhs) const {
auto pthis = clone();
pthis->addition_assignment(std::move(this_labels), rhs);
return pthis;
}
// buffer_base_pointer addition(label_type this_labels,
// const_labeled_buffer_reference rhs) const {
// auto pthis = clone();
// pthis->addition_assignment(std::move(this_labels), rhs);
// return pthis;
// }

/** @brief Sets *this to a permutation of @p rhs.
*
Expand All @@ -166,10 +171,8 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
* @throw ??? If the derived class's implementation of permute_assignment_
* throws. Same throw guarantee.
*/
buffer_base_reference permute_assignment(
label_type this_labels, const_labeled_buffer_reference rhs) {
return permute_assignment_(std::move(this_labels), rhs);
}
// buffer_base_reference permute_assignment(
// label_type this_labels, const_labeled_buffer_reference rhs);

/** @brief Returns a copy of *this obtained by permuting *this.
*
Expand All @@ -186,44 +189,17 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
* @throw ??? If the derived class's implementation of permute_assignment_
* throws. Same throw guarantee.
*/
buffer_base_pointer permute(label_type this_labels,
label_type out_labels) const {
auto pthis = clone();
pthis->permute_assignment(std::move(out_labels), (*this)(this_labels));
return pthis;
}
// buffer_base_pointer permute(label_type this_labels,
// label_type out_labels) const {
// auto pthis = clone();
// pthis->permute_assignment(std::move(out_labels),
// (*this)(this_labels)); return pthis;
// }

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------

/** @brief Associates labels with the modes of *this.
*
* This method is used to create a labeled buffer object by pairing *this
* with the provided labels. The resulting object is capable of being
* composed via the DSL.
*
* @param[in] labels The indices to associate with the modes of *this.
*
* @return A DSL term pairing *this with @p labels.
*
* @throw None No throw guarantee.
*/
labeled_buffer_type operator()(label_type labels);

/** @brief Associates labels with the modes of *this.
*
* This method is the same as the non-const version except that the result
* contains a read-only reference to *this.
*
* @param[in] labels The labels to associate with *this.
*
* @return A DSL term pairing *this with @p labels.
*
* @throw None No throw guarantee.
*/
labeled_const_buffer_type operator()(label_type labels) const;

/** @brief Is *this value equal to @p rhs?
*
* Two BufferBase objects are value equal if the layouts they contain are
Expand Down Expand Up @@ -321,17 +297,51 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *this;
}

/// Derived class should overwrite to implement addition_assignment
virtual buffer_base_reference addition_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) {
throw std::runtime_error("Addition assignment NYI");
}
/** @brief Overridden by derived classes to implement addition_assignment
*
* BufferBase will take care of addition_assignment on the layout member.
* The derived class is responsible for performing the addition_assignment
* on the actual tensor elements in a manner that is consistent with the
* description of addition_assignment.
*
* @param[in] this_labels The dummy indices for *this.
* @param[in] rhs The labeled buffer to add into *this.
*
* @return *this after the operation.
*
* @throw std::runtime_error if the default implementation is called
* because the derived class does not overload it
* (or because the derived class directly called
* it). Strong throw guarantee (in the first
* scenario).
*/
// virtual buffer_base_reference addition_assignment_(
// label_type this_labels, const_labeled_buffer_reference rhs) {
// throw std::runtime_error("Addition assignment NYI");
// }

/// Derived class should overwrite to implement permute_assignment
virtual buffer_base_reference permute_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) {
throw std::runtime_error("Permute assignment NYI");
}
/** @brief Overridden by derived classes to implement permute_assignment
*
* BufferBase will take care of permute_assignment on the layout member.
* The derived class is responsible for performing the permute_assignment
* on the actual tensor elements in a manner that is consistent with the
* description of permute_assignment.
*
* @param[in] this_labels The dummy indices for *this.
* @param[in] rhs The labeled buffer to permute into *this.
*
* @return *this after the operation.
*
* @throw std::runtime_error if the default implementation is called
* because the derived class does not overload it
* (or because the derived class directly called
* it). Strong throw guarantee (in the first
* scenario).
*/
// virtual buffer_base_reference permute_assignment_(
// label_type this_labels, const_labeled_buffer_reference rhs) {
// throw std::runtime_error("Permute assignment NYI");
// }

private:
/// Throws std::runtime_error when there is no layout
Expand Down
15 changes: 10 additions & 5 deletions include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ class Eigen : public Replicated {
/// Pull in base class's types
using typename my_base_type::buffer_base_pointer;
using typename my_base_type::const_buffer_base_reference;
using typename my_base_type::const_labeled_buffer_reference;
using typename my_base_type::const_labeled_reference;
using typename my_base_type::const_layout_reference;
using typename my_base_type::dsl_reference;
using typename my_base_type::label_type;

/// Type of a rank @p Rank tensor using floats of type @p FloatType
Expand Down Expand Up @@ -177,18 +178,22 @@ class Eigen : public Replicated {
return std::make_unique<my_type>(*this);
}

buffer_base_reference assign_(const_buffer_base_reference rhs) override {
return my_base_type::assign_impl_<my_type>(rhs);
}

/// Implements are_equal by calling are_equal_impl_
bool are_equal_(const_buffer_base_reference rhs) const noexcept override {
return my_base_type::are_equal_impl_<my_type>(rhs);
}

/// Implements addition_assignment by rebinding rhs to an Eigen buffer
buffer_base_reference addition_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) override;
dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

/// Implements permute assignment by deferring to Eigen's shuffle command.
buffer_base_reference permute_assignment_(
label_type this_labels, const_labeled_buffer_reference rhs) override;
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

/// Implements to_string
typename my_base_type::string_type to_string_() const override;
Expand Down
Loading