From d12675aa7b782e087f8e15a3b042ed31b861569a Mon Sep 17 00:00:00 2001
From: "Ryan M. Richard" 
Date: Mon, 3 Feb 2025 15:56:18 -0600
Subject: [PATCH 1/5] factor contraction logic

---
 .../buffer/contraction_planner.hpp            |  42 +++++++
 .../buffer/contraction_planner.cpp            | 108 ++++++++++++++++++
 2 files changed, 150 insertions(+)
 create mode 100644 src/tensorwrapper/buffer/contraction_planner.hpp
 create mode 100644 tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp

diff --git a/src/tensorwrapper/buffer/contraction_planner.hpp b/src/tensorwrapper/buffer/contraction_planner.hpp
new file mode 100644
index 00000000..b9247040
--- /dev/null
+++ b/src/tensorwrapper/buffer/contraction_planner.hpp
@@ -0,0 +1,42 @@
+#pragma once
+#include <tensorwrapper/dsl/dummy_indices.hpp>
+
+namespace tensorwrapper::buffer {
+
+/** @brief Class for working out details pertaining to a tensor contraction.
+ *
+ *
+ */
+class ContractionPlanner {
+public:
+    using string_type = std::string;
+    using label_type  = dsl::DummyIndices;
+
+    ContractionPlanner(string_type result, string_type lhs, string_type rhs) :
+      ContractionPlanner(label_type(result), label_type(lhs),
+                         label_type(rhs)) {}
+
+    ContractionPlanner(label_type result, label_type lhs, label_type rhs) :
+      m_result_(std::move(result)),
+      m_lhs_(std::move(lhs)),
+      m_rhs_(std::move(rhs)) {}
+
+    /// Labels in LHS that are NOT summed over
+    label_type lhs_free() const { return m_lhs_.intersection(m_result_); }
+
+    /// Labels in RHS that are NOT summed over
+    label_type rhs_free() const { return m_rhs_.intersection(m_result_); }
+
+    /// Labels in LHS that ARE summed over
+    label_type lhs_dummy() const { return m_lhs_.difference(m_result_); }
+
+    /// Labels in RHS that ARE summed over
+    label_type rhs_dummy() const { return m_rhs_.difference(m_result_) }
+
+private:
+    label_type m_result_;
+    label_type m_lhs_;
+    label_type m_rhs_;
+};
+
+} // namespace tensorwrapper::buffer
\ No newline at end of file
diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp
new file mode 100644
index 00000000..f42c3c9b
--- /dev/null
+++ b/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp
@@ -0,0 +1,108 @@
+/*
+ * Copyright 2025 NWChemEx-Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "../testing/testing.hpp" +#include +#include + +using namespace tensorwrapper; +using namespace buffer; + +TEST_CASE("ContractionPlanner") { + // All scalar + ContractionPlanner cp___("", "", ""); + + // Scalar times vector + ContractionPlanner cp___i("", "", "i"); + ContractionPlanner cp__i_("", "i", ""); + + // Vector times vector + ContractionPlanner cp__i_i("", "i", "i"); + ContractionPlanner cp_i_i_i("i", "i", "i"); + ContractionPlanner cp_i_i_j("i", "i", "j"); + ContractionPlanner cp_ij_i_j("i,j", "i", "j"); + ContractionPlanner cp_ji_i_j("j,i", "i", "j"); + ContractionPlanner cp_i_j_i("i", "j", "i"); + + // Vector times matrix + ContractionPlanner cp_i_i_ij("i", "i", "i,j"); + ContractionPlanner cp_j_i_ij("j", "i", "i,j"); + ContractionPlanner cp_i_i_ji("i", "i", "j,i"); + ContractionPlanner cp_j_i_ji("j", "i", "j,i"); + ContractionPlanner cp_ij_i_ij("i,j", "i", "i,j"); + ContractionPlanner cp_ji_i_ij("j,i", "i", "i,j"); + ContractionPlanner cp_ij_i_ji("i,j", "i", "j,i"); + ContractionPlanner cp_ji_i_ji("j,i", "i", "j,i"); + + // Tensor times tensor + ContractionPlanner cp_ij_ijk_ikj("i,j", "i,j,k", "i,k,j"); + ContractionPlanner cp_iljm_ikj_lmk("i,l,j,m", "i,k,j", "l,m,k"); + + // These are invalid + // ContractionPlanner cpi__("i", "", ""); + + SECTION("lhs_free") { + REQUIRE(cp___.lhs_free() == ""); + + REQUIRE(cp___i.lhs_free() == ""); + REQUIRE(cp__i_.lhs_free() == ""); + + REQUIRE(cp__i_i.lhs_free() == ""); + REQUIRE(cp_i_i_i.lhs_free() == "i"); + REQUIRE(cp_i_i_j.lhs_free() == "i"); + REQUIRE(cp_ij_i_j.lhs_free() == "i"); + REQUIRE(cp_ji_i_j.lhs_free() == "i"); + REQUIRE(cp_i_j_i.lhs_free() == ""); + + REQUIRE(cp_i_i_ij.lhs_free() == "i"); + REQUIRE(cp_j_i_ij.lhs_free() == ""); + REQUIRE(cp_i_i_ji.lhs_free() == "i"); + REQUIRE(cp_j_i_ji.lhs_free() == ""); + REQUIRE(cp_ij_i_ij.lhs_free() == "i"); + REQUIRE(cp_ji_i_ij.lhs_free() == "i"); + REQUIRE(cp_ij_i_ji.lhs_free() == "i"); + REQUIRE(cp_ji_i_ji.lhs_free() == "i"); + + REQUIRE(cp_ij_ijk_ikj.lhs_free() == "i,j"); + REQUIRE(cp_iljm_ikj_lmk.lhs_free() == "i,j"); + } + + SECTION("rhs_free") { + REQUIRE(cp___.rhs_free() == ""); + + REQUIRE(cp___i.rhs_free() == ""); + REQUIRE(cp__i_.rhs_free() == ""); + + REQUIRE(cp__i_i.rhs_free() == ""); + REQUIRE(cp_i_i_i.rhs_free() == "i"); + REQUIRE(cp_i_i_j.rhs_free() == ""); + REQUIRE(cp_ij_i_j.rhs_free() == "j"); + REQUIRE(cp_ji_i_j.rhs_free() == "j"); + REQUIRE(cp_i_j_i.rhs_free() == "i"); + + REQUIRE(cp_i_i_ij.rhs_free() == "i"); + REQUIRE(cp_j_i_ij.rhs_free() == "j"); + REQUIRE(cp_i_i_ji.rhs_free() == "i"); + REQUIRE(cp_j_i_ji.rhs_free() == "j"); + REQUIRE(cp_ij_i_ij.rhs_free() == "i,j"); + REQUIRE(cp_ji_i_ij.rhs_free() == "i,j"); + REQUIRE(cp_ij_i_ji.rhs_free() == "j,i"); + REQUIRE(cp_ji_i_ji.rhs_free() == "j,i"); + + REQUIRE(cp_ij_ijk_ikj.rhs_free() == "i,j"); + REQUIRE(cp_iljm_ikj_lmk.rhs_free() == "l,m"); + } +} From 0d351cfccda60032402553187b5371f42f3b7805 Mon Sep 17 00:00:00 2001 From: "Ryan M. 
Richard" Date: Tue, 4 Feb 2025 22:16:06 -0600 Subject: [PATCH 2/5] contract works --- include/tensorwrapper/buffer/eigen.hpp | 3 + include/tensorwrapper/dsl/dummy_indices.hpp | 18 +- .../buffer/contraction_planner.hpp | 66 +++++- src/tensorwrapper/buffer/eigen.cpp | 58 ++++++ .../buffer/contraction_planner.cpp | 188 +++++++++++++----- 5 files changed, 273 insertions(+), 60 deletions(-) diff --git a/include/tensorwrapper/buffer/eigen.hpp b/include/tensorwrapper/buffer/eigen.hpp index d9aca697..d57d7587 100644 --- a/include/tensorwrapper/buffer/eigen.hpp +++ b/include/tensorwrapper/buffer/eigen.hpp @@ -169,6 +169,9 @@ class Eigen : public Replicated { */ bool operator!=(const Eigen& rhs) const noexcept { return !(*this == rhs); } + FloatType* data() { return m_tensor_.data(); }; + const FloatType* data() const { return m_tensor_.data(); } + protected: /// Implements clone by calling copy ctor buffer_base_pointer clone_() const override { diff --git a/include/tensorwrapper/dsl/dummy_indices.hpp b/include/tensorwrapper/dsl/dummy_indices.hpp index 1a557dff..b4aac7f6 100644 --- a/include/tensorwrapper/dsl/dummy_indices.hpp +++ b/include/tensorwrapper/dsl/dummy_indices.hpp @@ -110,6 +110,15 @@ class DummyIndices DummyIndices( utilities::strings::split_string(remove_spaces_(dummy_indices), ",")) {} + /// Main ctor for setting the value, throws if any index is empty + explicit DummyIndices(split_string_type split_dummy_indices) : + m_dummy_indices_(std::move(split_dummy_indices)) { + for(const auto& x : m_dummy_indices_) + if(x.empty()) + throw std::runtime_error( + "Dummy index is not allowed to be empty"); + } + /** @brief Determines the number of unique indices in *this. * * A dummy index can be repeated if it is going to be summed over. This @@ -437,15 +446,6 @@ class DummyIndices } protected: - /// Main ctor for setting the value, throws if any index is empty - explicit DummyIndices(split_string_type split_dummy_indices) : - m_dummy_indices_(std::move(split_dummy_indices)) { - for(const auto& x : m_dummy_indices_) - if(x.empty()) - throw std::runtime_error( - "Dummy index is not allowed to be empty"); - } - /// Lets the base class get at these implementations friend base_type; diff --git a/src/tensorwrapper/buffer/contraction_planner.hpp b/src/tensorwrapper/buffer/contraction_planner.hpp index b9247040..d3bb8a11 100644 --- a/src/tensorwrapper/buffer/contraction_planner.hpp +++ b/src/tensorwrapper/buffer/contraction_planner.hpp @@ -5,12 +5,16 @@ namespace tensorwrapper::buffer { /** @brief Class for working out details pertaining to a tensor contraction. * - * + * N.B. Contraction covers direct product (which is a special case of + * contraction with 0 dummy indices). 
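+ *
+ * For example, when planning the matrix multiplication
+ * C("i,j") = A("i,k") * B("k,j"):
+ *
+ * @code
+ * ContractionPlanner plan("i,j", "i,k", "k,j");
+ * plan.lhs_free();  // == "i"
+ * plan.rhs_free();  // == "j"
+ * plan.lhs_dummy(); // == "k"
+ * plan.rhs_dummy(); // == "k"
+ * @endcode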
*/ class ContractionPlanner { public: + /// String type users use to label modes using string_type = std::string; - using label_type = dsl::DummyIndices; + + /// Type of the parsed labels + using label_type = dsl::DummyIndices; ContractionPlanner(string_type result, string_type lhs, string_type rhs) : ContractionPlanner(label_type(result), label_type(lhs), label_type(rhs)) { @@ -19,7 +23,11 @@ class ContractionPlanner { ContractionPlanner(label_type result, label_type lhs, label_type rhs) : m_result_(std::move(result)), m_lhs_(std::move(lhs)), - m_rhs_(std::move(rhs)) {} + m_rhs_(std::move(rhs)) { + assert_no_repeated_indices_(); + assert_dummy_indices_are_similar_(); + assert_no_shared_free_(); + } /// Labels in LHS that are NOT summed over label_type lhs_free() const { return m_lhs_.intersection(m_result_); } @@ -31,9 +39,59 @@ class ContractionPlanner { label_type lhs_dummy() const { return m_lhs_.difference(m_result_); } /// Labels in RHS that ARE summed over - label_type rhs_dummy() const { return m_rhs_.difference(m_result_) } + label_type rhs_dummy() const { return m_rhs_.difference(m_result_); } + + /** @brief LHS permuted so free indices are followed by dummy indices. */ + label_type lhs_permutation() const { + using split_string_type = typename label_type::split_string_type; + split_string_type rv; + auto lfree = lhs_free(); + auto ldummy = lhs_dummy(); + for(const auto& freei : m_result_) { + if(!lfree.count(freei)) continue; + rv.push_back(freei); + } + for(const auto& dummyi : ldummy) rv.push_back(dummyi); + return label_type(std::move(rv)); + } + + /** @brief RHS permuted so dummy indices are followed by free indices. */ + label_type rhs_permutation() const { + typename label_type::split_string_type rv; + auto rfree = rhs_free(); + auto rdummy = lhs_dummy(); // Use LHS dummy to get the same order! + for(const auto& dummyi : rdummy) + rv.push_back(dummyi); // Know it only appears 1x + for(const auto& freei : m_result_) { + if(!rfree.count(freei)) continue; + rv.push_back(freei); // Know it only appears 1x + } + return label_type(std::move(rv)); + } private: + /// Ensures no tensor contains a repeated label + void assert_no_repeated_indices_() const { + const bool result_good = !m_result_.has_repeated_indices(); + const bool lhs_good = !m_lhs_.has_repeated_indices(); + const bool rhs_good = !m_rhs_.has_repeated_indices(); + + if(result_good && lhs_good && rhs_good) return; + throw std::runtime_error("One or more terms contain repeated labels"); + } + + /// Ensures the dummy indices are permutations of each other + void assert_dummy_indices_are_similar_() const { + if(lhs_dummy().is_permutation(rhs_dummy())) return; + throw std::runtime_error("Dummy indices must appear in all terms"); + } + + /// Asserts LHS and RHS do not share free indices, which is Hadamard-product + void assert_no_shared_free_() const { + if(!lhs_free().intersection(rhs_free()).size()) return; + throw std::runtime_error("Contraction must sum repeated indices"); + } + label_type m_result_; label_type m_lhs_; label_type m_rhs_; diff --git a/src/tensorwrapper/buffer/eigen.cpp b/src/tensorwrapper/buffer/eigen.cpp index 9980ef5b..2abce3b5 100644 --- a/src/tensorwrapper/buffer/eigen.cpp +++ b/src/tensorwrapper/buffer/eigen.cpp @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
 */
+#include "contraction_planner.hpp"
 #include "eigen_contraction.hpp"
 #include <tensorwrapper/allocator/eigen.hpp>
 #include <tensorwrapper/buffer/eigen.hpp>
@@ -24,6 +25,20 @@ namespace tensorwrapper::buffer {
 #define TPARAMS template<typename FloatType, unsigned short Rank>
 #define EIGEN Eigen<FloatType, Rank>
 
+template<typename FloatType, unsigned short Rank, typename TensorType>
+FloatType* get_data(TensorType& tensor) {
+    using allocator_type = allocator::Eigen<FloatType, Rank>;
+    if constexpr(Rank > 10) {
+        throw std::runtime_error("Tensors with rank > 10 are not supported");
+    } else {
+        if(tensor.layout().rank() == Rank) {
+            return allocator_type::rebind(tensor).data();
+        } else {
+            return get_data<FloatType, Rank + 1>(tensor);
+        }
+    }
+}
+
 using const_labeled_reference = typename Eigen::const_labeled_reference;
 using dsl_reference = typename Eigen::dsl_reference;

@@ -231,6 +246,48 @@ TPARAMS typename EIGEN::dsl_reference EIGEN::contraction_(
     const auto& rlabels = rhs.labels();
     const auto& robject = rhs.object();
 
+    ContractionPlanner plan(this_labels, llabels, rlabels);
+    auto lt = lobject.clone();
+    auto rt = robject.clone();
+    lt->permute_assignment(plan.lhs_permutation(), lhs);
+    rt->permute_assignment(plan.rhs_permutation(), rhs);
+
+    const auto ndummy = plan.lhs_dummy().size();
+    const auto lshape = lt->layout().shape().as_smooth();
+    const auto rshape = rt->layout().shape().as_smooth();
+    const auto oshape = layout().shape().as_smooth();
+    const auto lfree  = lshape.rank() - ndummy;
+    std::size_t lrows = lshape.rank() ? 1 : 0;
+    std::size_t lcols = lshape.rank() ? 1 : 0;
+
+    for(std::size_t i = 0; i < lfree; ++i) lrows *= lshape.extent(i);
+    for(std::size_t i = lfree; i < lshape.rank(); ++i)
+        lcols *= lshape.extent(i);
+
+    std::size_t rrows = rshape.rank() ? 1 : 0;
+    std::size_t rcols = rshape.rank() ? 1 : 0;
+
+    for(std::size_t i = 0; i < ndummy; ++i) rrows *= rshape.extent(i);
+    for(std::size_t i = ndummy; i < rshape.rank(); ++i)
+        rcols *= rshape.extent(i);
+
+    using matrix_t = ::Eigen::Matrix<FloatType, ::Eigen::Dynamic,
+                                     ::Eigen::Dynamic, ::Eigen::RowMajor>;
+    using map_t = ::Eigen::Map<matrix_t>;
+
+    typename Eigen<FloatType, 2>::data_type buffer(lrows, rcols);
+
+    map_t lmatrix(get_data<FloatType, 0>(*lt), lrows, lcols);
+    map_t rmatrix(get_data<FloatType, 0>(*rt), rrows, rcols);
+    map_t omatrix(buffer.data(), lrows, rcols);
+    omatrix = lmatrix * rmatrix;
+
+    std::array<std::size_t, Rank> out_size;
+    for(std::size_t i = 0; i < Rank; ++i) out_size[i] = oshape.extent(i);
+    m_tensor_ = buffer.reshape(out_size);
+    return *this;
+
+    /* Doesn't work with Sigma
     // N.b.
is a pure contraction, so common indices are summed over auto common = llabels.intersection(rlabels); @@ -248,6 +305,7 @@ TPARAMS typename EIGEN::dsl_reference EIGEN::contraction_( } return eigen_contraction(*this, lobject, robject, modes); + */ } #undef EIGEN diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp index f42c3c9b..9658a8cd 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp @@ -25,84 +25,178 @@ TEST_CASE("ContractionPlanner") { // All scalar ContractionPlanner cp___("", "", ""); - // Scalar times vector - ContractionPlanner cp___i("", "", "i"); - ContractionPlanner cp__i_("", "i", ""); - // Vector times vector ContractionPlanner cp__i_i("", "i", "i"); - ContractionPlanner cp_i_i_i("i", "i", "i"); - ContractionPlanner cp_i_i_j("i", "i", "j"); ContractionPlanner cp_ij_i_j("i,j", "i", "j"); ContractionPlanner cp_ji_i_j("j,i", "i", "j"); - ContractionPlanner cp_i_j_i("i", "j", "i"); // Vector times matrix - ContractionPlanner cp_i_i_ij("i", "i", "i,j"); ContractionPlanner cp_j_i_ij("j", "i", "i,j"); - ContractionPlanner cp_i_i_ji("i", "i", "j,i"); ContractionPlanner cp_j_i_ji("j", "i", "j,i"); - ContractionPlanner cp_ij_i_ij("i,j", "i", "i,j"); - ContractionPlanner cp_ji_i_ij("j,i", "i", "i,j"); - ContractionPlanner cp_ij_i_ji("i,j", "i", "j,i"); - ContractionPlanner cp_ji_i_ji("j,i", "i", "j,i"); + ContractionPlanner cp_ijk_i_jk("i,j,k", "i", "j,k"); + ContractionPlanner cp_ijk_i_kj("i,j,k", "i", "k,j"); + + // Matrix times matrix + ContractionPlanner cp_ij_ik_kj("i,j", "i,k", "k,j"); + ContractionPlanner cp_ji_ik_kj("j,i", "i,k", "k,j"); + ContractionPlanner cp_ij_ik_jk("i,j", "i,k", "j,k"); + ContractionPlanner cp_ji_ik_jk("j,i", "i,k", "j,k"); + + // 3 times 3 + ContractionPlanner cp__ijk_ijk("", "i,j,k", "i,j,k"); + ContractionPlanner cp__ijk_jik("", "i,j,k", "j,i,k"); + ContractionPlanner cp_il_ijk_jkl("i,l", "i,j,k", "j,k,l"); + ContractionPlanner cp_il_ijk_klj("i,l", "i,j,k", "k,l,j"); + + SECTION("Ctors") { + using error_t = std::runtime_error; - // Tensor times tensor - ContractionPlanner cp_ij_ijk_ikj("i,j", "i,j,k", "i,k,j"); - ContractionPlanner cp_iljm_ikj_lmk("i,l,j,m", "i,k,j", "l,m,k"); + // Can't contain repeated indices + REQUIRE_THROWS_AS(ContractionPlanner("j", "i,i", "j"), error_t); + REQUIRE_THROWS_AS(ContractionPlanner("j", "j", "i,i"), error_t); - // These are invalid - // ContractionPlanner cpi__("i", "", ""); + // Can't require trace of a tensor + REQUIRE_THROWS_AS(ContractionPlanner("", "", "i"), error_t); + + // Can't contain Hadamard + REQUIRE_THROWS_AS(ContractionPlanner("i", "i", "i"), error_t); + } SECTION("lhs_free") { REQUIRE(cp___.lhs_free() == ""); - REQUIRE(cp___i.lhs_free() == ""); - REQUIRE(cp__i_.lhs_free() == ""); - REQUIRE(cp__i_i.lhs_free() == ""); - REQUIRE(cp_i_i_i.lhs_free() == "i"); - REQUIRE(cp_i_i_j.lhs_free() == "i"); REQUIRE(cp_ij_i_j.lhs_free() == "i"); REQUIRE(cp_ji_i_j.lhs_free() == "i"); - REQUIRE(cp_i_j_i.lhs_free() == ""); - REQUIRE(cp_i_i_ij.lhs_free() == "i"); REQUIRE(cp_j_i_ij.lhs_free() == ""); - REQUIRE(cp_i_i_ji.lhs_free() == "i"); REQUIRE(cp_j_i_ji.lhs_free() == ""); - REQUIRE(cp_ij_i_ij.lhs_free() == "i"); - REQUIRE(cp_ji_i_ij.lhs_free() == "i"); - REQUIRE(cp_ij_i_ji.lhs_free() == "i"); - REQUIRE(cp_ji_i_ji.lhs_free() == "i"); - - REQUIRE(cp_ij_ijk_ikj.lhs_free() == "i,j"); - REQUIRE(cp_iljm_ikj_lmk.lhs_free() == "i,j"); + 
REQUIRE(cp_ijk_i_jk.lhs_free() == "i"); + REQUIRE(cp_ijk_i_kj.lhs_free() == "i"); + + REQUIRE(cp_ij_ik_kj.lhs_free() == "i"); + REQUIRE(cp_ji_ik_kj.lhs_free() == "i"); + REQUIRE(cp_ij_ik_jk.lhs_free() == "i"); + REQUIRE(cp_ji_ik_jk.lhs_free() == "i"); + + REQUIRE(cp__ijk_ijk.lhs_free() == ""); + REQUIRE(cp__ijk_jik.lhs_free() == ""); + REQUIRE(cp_il_ijk_jkl.lhs_free() == "i"); + REQUIRE(cp_il_ijk_klj.lhs_free() == "i"); } SECTION("rhs_free") { REQUIRE(cp___.rhs_free() == ""); - REQUIRE(cp___i.rhs_free() == ""); - REQUIRE(cp__i_.rhs_free() == ""); - REQUIRE(cp__i_i.rhs_free() == ""); - REQUIRE(cp_i_i_i.rhs_free() == "i"); - REQUIRE(cp_i_i_j.rhs_free() == ""); REQUIRE(cp_ij_i_j.rhs_free() == "j"); REQUIRE(cp_ji_i_j.rhs_free() == "j"); - REQUIRE(cp_i_j_i.rhs_free() == "i"); - REQUIRE(cp_i_i_ij.rhs_free() == "i"); REQUIRE(cp_j_i_ij.rhs_free() == "j"); - REQUIRE(cp_i_i_ji.rhs_free() == "i"); REQUIRE(cp_j_i_ji.rhs_free() == "j"); - REQUIRE(cp_ij_i_ij.rhs_free() == "i,j"); - REQUIRE(cp_ji_i_ij.rhs_free() == "i,j"); - REQUIRE(cp_ij_i_ji.rhs_free() == "j,i"); - REQUIRE(cp_ji_i_ji.rhs_free() == "j,i"); + REQUIRE(cp_ijk_i_jk.rhs_free() == "j,k"); + REQUIRE(cp_ijk_i_kj.rhs_free() == "k,j"); + + REQUIRE(cp_ij_ik_kj.rhs_free() == "j"); + REQUIRE(cp_ji_ik_kj.rhs_free() == "j"); + REQUIRE(cp_ij_ik_jk.rhs_free() == "j"); + REQUIRE(cp_ji_ik_jk.rhs_free() == "j"); + + REQUIRE(cp__ijk_ijk.rhs_free() == ""); + REQUIRE(cp__ijk_jik.rhs_free() == ""); + REQUIRE(cp_il_ijk_jkl.rhs_free() == "l"); + REQUIRE(cp_il_ijk_klj.rhs_free() == "l"); + } + + SECTION("lhs_dummy") { + REQUIRE(cp___.lhs_dummy() == ""); + + REQUIRE(cp__i_i.lhs_dummy() == "i"); + REQUIRE(cp_ij_i_j.lhs_dummy() == ""); + REQUIRE(cp_ji_i_j.lhs_dummy() == ""); + + REQUIRE(cp_j_i_ij.lhs_dummy() == "i"); + REQUIRE(cp_j_i_ji.lhs_dummy() == "i"); + REQUIRE(cp_ijk_i_jk.lhs_dummy() == ""); + REQUIRE(cp_ijk_i_kj.lhs_dummy() == ""); + + REQUIRE(cp_ij_ik_kj.lhs_dummy() == "k"); + REQUIRE(cp_ji_ik_kj.lhs_dummy() == "k"); + REQUIRE(cp_ij_ik_jk.lhs_dummy() == "k"); + REQUIRE(cp_ji_ik_jk.lhs_dummy() == "k"); + + REQUIRE(cp__ijk_ijk.lhs_dummy() == "i,j,k"); + REQUIRE(cp__ijk_jik.lhs_dummy() == "i,j,k"); + REQUIRE(cp_il_ijk_jkl.lhs_dummy() == "j,k"); + REQUIRE(cp_il_ijk_klj.lhs_dummy() == "j,k"); + } + + SECTION("rhs_dummy") { + REQUIRE(cp___.rhs_dummy() == ""); + + REQUIRE(cp__i_i.rhs_dummy() == "i"); + REQUIRE(cp_ij_i_j.rhs_dummy() == ""); + REQUIRE(cp_ji_i_j.rhs_dummy() == ""); + + REQUIRE(cp_j_i_ij.rhs_dummy() == "i"); + REQUIRE(cp_j_i_ji.rhs_dummy() == "i"); + REQUIRE(cp_ijk_i_jk.rhs_dummy() == ""); + REQUIRE(cp_ijk_i_kj.rhs_dummy() == ""); + + REQUIRE(cp_ij_ik_kj.rhs_dummy() == "k"); + REQUIRE(cp_ji_ik_kj.rhs_dummy() == "k"); + REQUIRE(cp_ij_ik_jk.rhs_dummy() == "k"); + REQUIRE(cp_ji_ik_jk.rhs_dummy() == "k"); + + REQUIRE(cp__ijk_ijk.rhs_dummy() == "i,j,k"); + REQUIRE(cp__ijk_jik.rhs_dummy() == "j,i,k"); + REQUIRE(cp_il_ijk_jkl.rhs_dummy() == "j,k"); + REQUIRE(cp_il_ijk_klj.rhs_dummy() == "k,j"); + } + + SECTION("lhs_permutation") { + REQUIRE(cp___.lhs_permutation() == ""); + + REQUIRE(cp__i_i.lhs_permutation() == "i"); + REQUIRE(cp_ij_i_j.lhs_permutation() == "i"); + REQUIRE(cp_ji_i_j.lhs_permutation() == "i"); + + REQUIRE(cp_j_i_ij.lhs_permutation() == "i"); + REQUIRE(cp_j_i_ji.lhs_permutation() == "i"); + REQUIRE(cp_ijk_i_jk.lhs_permutation() == "i"); + REQUIRE(cp_ijk_i_kj.lhs_permutation() == "i"); + + REQUIRE(cp_ij_ik_kj.lhs_permutation() == "i,k"); + REQUIRE(cp_ji_ik_kj.lhs_permutation() == "i,k"); + REQUIRE(cp_ij_ik_jk.lhs_permutation() == "i,k"); + 
REQUIRE(cp_ji_ik_jk.lhs_permutation() == "i,k"); + + REQUIRE(cp__ijk_ijk.lhs_permutation() == "i,j,k"); + REQUIRE(cp__ijk_jik.lhs_permutation() == "i,j,k"); + REQUIRE(cp_il_ijk_jkl.lhs_permutation() == "i,j,k"); + REQUIRE(cp_il_ijk_klj.lhs_permutation() == "i,j,k"); + } + + SECTION("rhs_permutation") { + REQUIRE(cp___.rhs_permutation() == ""); + + REQUIRE(cp__i_i.rhs_permutation() == "i"); + REQUIRE(cp_ij_i_j.rhs_permutation() == "j"); + REQUIRE(cp_ji_i_j.rhs_permutation() == "j"); + + REQUIRE(cp_j_i_ij.rhs_permutation() == "i,j"); + REQUIRE(cp_j_i_ji.rhs_permutation() == "i,j"); + REQUIRE(cp_ijk_i_jk.rhs_permutation() == "j,k"); + REQUIRE(cp_ijk_i_kj.rhs_permutation() == "j,k"); + + REQUIRE(cp_ij_ik_kj.rhs_permutation() == "k,j"); + REQUIRE(cp_ji_ik_kj.rhs_permutation() == "k,j"); + REQUIRE(cp_ij_ik_jk.rhs_permutation() == "k,j"); + REQUIRE(cp_ji_ik_jk.rhs_permutation() == "k,j"); - REQUIRE(cp_ij_ijk_ikj.rhs_free() == "i,j"); - REQUIRE(cp_iljm_ikj_lmk.rhs_free() == "l,m"); + REQUIRE(cp__ijk_ijk.rhs_permutation() == "i,j,k"); + REQUIRE(cp__ijk_jik.rhs_permutation() == "i,j,k"); + REQUIRE(cp_il_ijk_jkl.rhs_permutation() == "j,k,l"); + REQUIRE(cp_il_ijk_klj.rhs_permutation() == "j,k,l"); } } From c794a0da72e04e3d70152c6a35e73642d7a84a4f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 5 Feb 2025 04:24:35 +0000 Subject: [PATCH 3/5] Committing clang-format changes --- src/tensorwrapper/buffer/contraction_planner.hpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/tensorwrapper/buffer/contraction_planner.hpp b/src/tensorwrapper/buffer/contraction_planner.hpp index d3bb8a11..20308f5f 100644 --- a/src/tensorwrapper/buffer/contraction_planner.hpp +++ b/src/tensorwrapper/buffer/contraction_planner.hpp @@ -1,3 +1,19 @@ +/* + * Copyright 2025 NWChemEx-Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #pragma once #include From 7018e25ca32968d29127c9ebba7d5beb88ae9f24 Mon Sep 17 00:00:00 2001 From: "Ryan M. 
Richard" Date: Thu, 6 Feb 2025 09:52:41 -0600 Subject: [PATCH 4/5] works (I think) --- include/tensorwrapper/buffer/eigen.hpp | 3 - include/tensorwrapper/detail_/dsl_base.hpp | 3 + include/tensorwrapper/dsl/dummy_indices.hpp | 39 +++- .../buffer/contraction_planner.hpp | 14 ++ src/tensorwrapper/buffer/eigen.cpp | 85 +------- .../buffer/eigen_contraction.cpp | 187 +++++++++--------- .../buffer/eigen_contraction.hpp | 14 +- .../buffer/contraction_planner.cpp | 48 +++++ .../buffer/eigen_contraction.cpp | 99 ++++++++-- .../tensorwrapper/testing/eigen_buffers.hpp | 45 +++-- 10 files changed, 321 insertions(+), 216 deletions(-) diff --git a/include/tensorwrapper/buffer/eigen.hpp b/include/tensorwrapper/buffer/eigen.hpp index d57d7587..d9aca697 100644 --- a/include/tensorwrapper/buffer/eigen.hpp +++ b/include/tensorwrapper/buffer/eigen.hpp @@ -169,9 +169,6 @@ class Eigen : public Replicated { */ bool operator!=(const Eigen& rhs) const noexcept { return !(*this == rhs); } - FloatType* data() { return m_tensor_.data(); }; - const FloatType* data() const { return m_tensor_.data(); } - protected: /// Implements clone by calling copy ctor buffer_base_pointer clone_() const override { diff --git a/include/tensorwrapper/detail_/dsl_base.hpp b/include/tensorwrapper/detail_/dsl_base.hpp index ea6e8e1a..fe044881 100644 --- a/include/tensorwrapper/detail_/dsl_base.hpp +++ b/include/tensorwrapper/detail_/dsl_base.hpp @@ -61,6 +61,9 @@ class DSLBase { /// Type of parsed labels using label_type = typename labeled_type::label_type; + /// Type of a mutable reference to a labeled_type object + using labeled_reference = labeled_type&; + /// Type of a read-only reference to a labeled_type object using const_labeled_reference = const labeled_const_type&; diff --git a/include/tensorwrapper/dsl/dummy_indices.hpp b/include/tensorwrapper/dsl/dummy_indices.hpp index b4aac7f6..dc64fe30 100644 --- a/include/tensorwrapper/dsl/dummy_indices.hpp +++ b/include/tensorwrapper/dsl/dummy_indices.hpp @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include #include @@ -221,12 +222,27 @@ class DummyIndices * * Each DummyIndices object is viewed as an ordered set of objects. If * two DummyIndices objects contain the same objects, but in a different - * order, we can convert either object into the other by permuting it. - * This method computes the permutation needed to change *this into - * @p other. More specifically the result of this method is a vector - * of length `size()` such that the `i`-th element is the offset of - * `(*this)[i]` in @p other, i.e., if `x` is the return then - * `other[x[i]] == (*this)[i]`. + * order, we can convert either object into the other by permuting it. This + * method determines the permutation necessary to convert *this into other; + * the reverse permutation, i.e., converting other to *this is obtained + * by `other.permutation(*this)`. + * + * For concreteness assume we have a set A=(i, j, k) and we want to + * permute it into the set B=(j,k,i), then there are two subtly different + * ways of writing the necessary permutation: + * 1. Write the target set in terms of the current offsets, e.g., A goes + * to B is written as P1=(1, 2, 0) + * 2. Write the current set in terms of the target offsets, e.g., A goes + * to B is written as P2=(2, 0, 1) + * + * Option one maps offsets in B to offsets in A, e.g., A[P1[0]] == B[0], + * whereas option two maps offsets in A to offsets in B, e.g., + * A[0] == B[P2[0]]. This method follows definition two. 
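+ *
+ * For example, using the A and B from above:
+ *
+ * @code
+ * DummyIndices a("i,j,k");
+ * DummyIndices b("j,k,i");
+ * auto p = a.permutation(b); // p == {2, 0, 1}, i.e., a[i] == b[p[i]]
+ * @endcode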
+ *
+ * @note The definitions are inverses of each other. So if you don't like
+ * our definition, and want the other one, flip *this and @p other, e.g.,
+ * permuting B to A using definition 1 yields P1'=(2, 0, 1)=P2 and
+ * permuting B to A using definition 2 yields P2'=(1, 2, 0)=P1.
  *
  * @param[in] other The order we want to permute *this to.
  *
@@ -471,6 +487,17 @@ class DummyIndices
     split_string_type m_dummy_indices_;
 };
 
+template<typename StringType>
+std::ostream& operator<<(std::ostream& os, const DummyIndices<StringType>& i) {
+    if(i.size() == 0) return os;
+    os << i.at(0);
+    for(std::size_t j = 1; j < i.size(); ++j) {
+        os << ",";
+        os << i.at(j);
+    }
+    return os;
+}
+
 template<typename StringType>
 bool DummyIndices<StringType>::is_hadamard_product(
   const DummyIndices& lhs, const DummyIndices& rhs) const noexcept {
diff --git a/src/tensorwrapper/buffer/contraction_planner.hpp b/src/tensorwrapper/buffer/contraction_planner.hpp
index d3bb8a11..db0c19f1 100644
--- a/src/tensorwrapper/buffer/contraction_planner.hpp
+++ b/src/tensorwrapper/buffer/contraction_planner.hpp
@@ -69,6 +69,20 @@ class ContractionPlanner {
         return label_type(std::move(rv));
     }
 
+    /** @brief Flattened result labels.
+     *
+     * After applying lhs_permutation to LHS to get A, and rhs_permutation to
+     * RHS to get B, A and B can be multiplied together with a gemm. The
+     * resulting matrix has indices given by concatenating the free indices of
+     * A with the free indices of B. This method returns those indices.
+     *
+     */
+    label_type result_matrix_labels() const {
+        const auto lhs = lhs_permutation();
+        const auto rhs = rhs_permutation();
+        return lhs.concatenation(rhs).difference(lhs_dummy());
+    }
+
 private:
     /// Ensures no tensor contains a repeated label
     void assert_no_repeated_indices_() const {
diff --git a/src/tensorwrapper/buffer/eigen.cpp b/src/tensorwrapper/buffer/eigen.cpp
index 2abce3b5..9f778512 100644
--- a/src/tensorwrapper/buffer/eigen.cpp
+++ b/src/tensorwrapper/buffer/eigen.cpp
@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "contraction_planner.hpp"
+
 #include "eigen_contraction.hpp"
 #include <tensorwrapper/allocator/eigen.hpp>
 #include <tensorwrapper/buffer/eigen.hpp>
@@ -25,20 +25,6 @@ namespace tensorwrapper::buffer {
 #define TPARAMS template<typename FloatType, unsigned short Rank>
 #define EIGEN Eigen<FloatType, Rank>
 
-template<typename FloatType, unsigned short Rank, typename TensorType>
-FloatType* get_data(TensorType& tensor) {
-    using allocator_type = allocator::Eigen<FloatType, Rank>;
-    if constexpr(Rank > 10) {
-        throw std::runtime_error("Tensors with rank > 10 are not supported");
-    } else {
-        if(tensor.layout().rank() == Rank) {
-            return allocator_type::rebind(tensor).data();
-        } else {
-            return get_data<FloatType, Rank + 1>(tensor);
-        }
-    }
-}
-
 using const_labeled_reference = typename Eigen::const_labeled_reference;
 using dsl_reference = typename Eigen::dsl_reference;

@@ -154,7 +140,8 @@ typename EIGEN::dsl_reference EIGEN::permute_assignment_(
     const auto& rlabels = rhs.labels();
 
     if(this_labels != rlabels) { // We need to permute rhs before assignment
-        auto r_to_l = rhs.labels().permutation(this_labels);
+        // Eigen adopts the opposite definition of permutation from us.
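+        // (I.e., this_labels[i] == rlabels[r_to_l[i]], so mode i of the
+        // result comes from mode r_to_l[i] of rhs, which is the convention
+        // Eigen's shuffle expects.)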
+ auto r_to_l = this_labels.permutation(rlabels); // Eigen wants int objects std::vector r_to_l2(r_to_l.begin(), r_to_l.end()); m_tensor_ = rhs_downcasted.value().shuffle(r_to_l2); @@ -241,71 +228,7 @@ typename EIGEN::dsl_reference EIGEN::hadamard_(label_type this_labels, TPARAMS typename EIGEN::dsl_reference EIGEN::contraction_( label_type this_labels, const_labeled_reference lhs, const_labeled_reference rhs) { - const auto& llabels = lhs.labels(); - const auto& lobject = lhs.object(); - const auto& rlabels = rhs.labels(); - const auto& robject = rhs.object(); - - ContractionPlanner plan(this_labels, llabels, rlabels); - auto lt = lobject.clone(); - auto rt = robject.clone(); - lt->permute_assignment(plan.lhs_permutation(), lhs); - rt->permute_assignment(plan.rhs_permutation(), rhs); - - const auto ndummy = plan.lhs_dummy().size(); - const auto lshape = lt->layout().shape().as_smooth(); - const auto rshape = rt->layout().shape().as_smooth(); - const auto oshape = layout().shape().as_smooth(); - const auto lfree = lshape.rank() - ndummy; - std::size_t lrows = lshape.rank() ? 1 : 0; - std::size_t lcols = lshape.rank() ? 1 : 0; - - for(std::size_t i = 0; i < lfree; ++i) lrows *= lshape.extent(i); - for(std::size_t i = lfree; i < lshape.rank(); ++i) - lcols *= lshape.extent(i); - - std::size_t rrows = rshape.rank() ? 1 : 0; - std::size_t rcols = rshape.rank() ? 1 : 0; - - for(std::size_t i = 0; i < ndummy; ++i) rrows *= rshape.extent(i); - for(std::size_t i = ndummy; i < rshape.rank(); ++i) - rcols *= rshape.extent(i); - - using matrix_t = ::Eigen::Matrix; - using map_t = ::Eigen::Map; - - typename Eigen::data_type buffer(lrows, rcols); - - map_t lmatrix(get_data(*lt), lrows, lcols); - map_t rmatrix(get_data(*rt), rrows, rcols); - map_t omatrix(buffer.data(), lrows, rcols); - omatrix = lmatrix * rmatrix; - - std::array out_size; - for(std::size_t i = 0; i < Rank; ++i) out_size[i] = oshape.extent(i); - m_tensor_ = buffer.reshape(out_size); - return *this; - - /* Doesn't work with Sigma - // N.b. is a pure contraction, so common indices are summed over - auto common = llabels.intersection(rlabels); - - // -- This block converts string indices to mode offsets - using rank_type = unsigned short; - using pair_type = std::pair; - std::vector modes; - auto rank = common.size(); - for(decltype(rank) i = 0; i < rank; ++i) { - const auto& index_i = common.at(i); - // N.b., pure contraction so there's no repeats within a tensor's label - auto lindex = llabels.find(index_i)[0]; - auto rindex = rlabels.find(index_i)[0]; - modes.push_back(pair_type(lindex, rindex)); - } - - return eigen_contraction(*this, lobject, robject, modes); - */ + return eigen_contraction(*this, this_labels, lhs, rhs); } #undef EIGEN diff --git a/src/tensorwrapper/buffer/eigen_contraction.cpp b/src/tensorwrapper/buffer/eigen_contraction.cpp index 23b729cf..295567ba 100644 --- a/src/tensorwrapper/buffer/eigen_contraction.cpp +++ b/src/tensorwrapper/buffer/eigen_contraction.cpp @@ -14,124 +14,125 @@ * limitations under the License. */ +#include "contraction_planner.hpp" #include "eigen_contraction.hpp" #include #include namespace tensorwrapper::buffer { -using rank_type = unsigned short; -using base_reference = BufferBase::base_reference; -using const_base_reference = BufferBase::const_base_reference; -using return_type = BufferBase::dsl_reference; -using pair_type = std::pair; -using vector_type = std::vector; - -// N.b. 
will create about max_rank**3 instantiations of eigen_contraction
-static constexpr unsigned int max_rank = 6;
-
-/// Wraps the contraction once we've worked out all of the template params.
-template
-return_type eigen_contraction(RVType&& rv, LHSType&& lhs, RHSType&& rhs,
-                              ModesType&& sum_modes) {
-    rv.value() = lhs.value().contract(rhs.value(), sum_modes);
-    return rv;
-}
+template<typename FloatType, unsigned short Rank, typename TensorType>
+FloatType* get_data(TensorType& tensor) {
+    using allocator_type = allocator::Eigen<FloatType, Rank>;
+    if constexpr(Rank > max_rank) {
+        const auto sr  = std::to_string(max_rank);
+        const auto msg = "Tensors with rank > " + sr + " are not supported";
+        throw std::runtime_error(msg);
+    } else {
+        if(tensor.layout().rank() == Rank) {
+            return allocator_type::rebind(tensor).value().data();
+        } else {
+            return get_data<FloatType, Rank + 1>(tensor);
+        }
+    }
+}
 
-/// This function converts @p sum_modes to a statically sized array
-template
-return_type n_contraction_modes(buffer::Eigen& rv,
-                                const buffer::Eigen& lhs,
-                                const buffer::Eigen& rhs,
-                                const vector_type& sum_modes) {
-    // Can't contract more modes than a tensor has (this is recursion end point)
-    constexpr auto max_n = std::min({LHSRank, RHSRank});
-    if constexpr(N == max_n + 1) {
-        throw std::runtime_error("Contracted more modes than a tensor has!!?");
-    } else {
-        if(N == sum_modes.size()) {
-            std::array temp;
-            for(std::size_t i = 0; i < temp.size(); ++i) temp[i] = sum_modes[i];
-            return eigen_contraction(rv, lhs, rhs, std::move(temp));
-        } else {
-            return n_contraction_modes(rv, lhs, rhs, sum_modes);
-        }
-    }
-}
+template<typename TensorType>
+auto matrix_size(TensorType&& t, std::size_t row_ranks) {
+    const auto shape = t.layout().shape().as_smooth();
+    std::size_t nrows = 1;
+    for(std::size_t i = 0; i < row_ranks; ++i) nrows *= shape.extent(i);
+
+    std::size_t ncols = 1;
+    const auto rank = shape.rank();
+    for(std::size_t i = row_ranks; i < rank; ++i) ncols *= shape.extent(i);
+    return std::make_pair(nrows, ncols);
+}
 
-/// This function works out the rank of RHS
-template
-return_type rhs_rank(buffer::Eigen& rv,
-                     const buffer::Eigen& lhs,
-                     const_base_reference rhs, const vector_type& sum_modes) {
-    if constexpr(RHSRank == max_rank + 1) {
-        throw std::runtime_error("RHS has rank > max_rank");
-    } else {
-        if(RHSRank == rhs.rank()) {
-            using allocator_type = allocator::Eigen;
-            const auto& rhs_eigen = allocator_type::rebind(rhs);
-            return n_contraction_modes(rv, lhs, rhs_eigen, sum_modes);
-        } else {
-            return rhs_rank(rv, lhs, rhs, sum_modes);
-        }
-    }
-}
-
-/// This function works out the rank of LHS
-template
-return_type lhs_rank(buffer::Eigen& rv,
-                     const_base_reference lhs, const_base_reference rhs,
-                     const vector_type& sum_modes) {
-    if constexpr(LHSRank == max_rank + 1) {
-        throw std::runtime_error("LHS has rank > max_rank");
-    } else {
-        if(LHSRank == lhs.rank()) {
-            using allocator_type = allocator::Eigen;
-            const auto& lhs_eigen = allocator_type::rebind(lhs);
-            return rhs_rank(rv, lhs_eigen, rhs, sum_modes);
-        } else {
-            return lhs_rank(rv, lhs, rhs, sum_modes);
-        }
-    }
-}
+template<typename FloatType, unsigned short Rank>
+return_type eigen_contraction(Eigen<FloatType, Rank>& result,
+                              label_type olabels, const_labeled_reference lhs,
+                              const_labeled_reference rhs) {
+    const auto& llabels = lhs.labels();
+    const auto& lobject = lhs.object();
+    const auto& rlabels = rhs.labels();
+    const auto& robject = rhs.object();
+
+    ContractionPlanner plan(olabels, llabels, rlabels);
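+
+    // Sketch of the strategy: lhs is permuted so its free indices come before
+    // its dummy indices, rhs is permuted so its dummy indices come before its
+    // free indices; each permuted buffer is then flattened to a matrix so the
+    // whole contraction reduces to a single matrix-matrix multiply.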
+    auto lt = lobject.clone();
+    auto rt = robject.clone();
+    lt->permute_assignment(plan.lhs_permutation(), lhs);
+    rt->permute_assignment(plan.rhs_permutation(), rhs);
+    const auto [lrows, lcols] = matrix_size(*lt, plan.lhs_free().size());
+    const auto [rrows, rcols] = matrix_size(*rt, plan.rhs_dummy().size());
+
+    // Work out the types of the matrix and a map
+    constexpr auto e_dyn       = ::Eigen::Dynamic;
+    constexpr auto e_row_major = ::Eigen::RowMajor;
+    using matrix_t = ::Eigen::Matrix<FloatType, e_dyn, e_dyn, e_row_major>;
+    using map_t    = ::Eigen::Map<matrix_t>;
+
+    typename Eigen<FloatType, 2>::data_type buffer(lrows, rcols);
+
+    map_t lmatrix(get_data<FloatType, 0>(*lt), lrows, lcols);
+    map_t rmatrix(get_data<FloatType, 0>(*rt), rrows, rcols);
+    map_t omatrix(buffer.data(), lrows, rcols);
+    omatrix = lmatrix * rmatrix;
+
+    auto mlabels = plan.result_matrix_labels();
+    auto oshape  = result.layout().shape()(olabels);
+
+    // oshape is the final shape; permute it to the shape omatrix is currently in
+    auto temp_shape = result.layout().shape().clone();
+    temp_shape->permute_assignment(mlabels, oshape);
+    auto mshape = temp_shape->as_smooth();
+
+    auto m_to_o = olabels.permutation(mlabels); // N.b. Eigen's def is the inverse of ours
+
+    std::array<std::size_t, Rank> out_size;
+    std::array<std::size_t, Rank> m_to_o_array;
+    for(std::size_t i = 0; i < Rank; ++i) {
+        out_size[i]     = mshape.extent(i);
+        m_to_o_array[i] = m_to_o[i];
+    }
+
+    auto tensor = buffer.reshape(out_size);
+    if constexpr(Rank > 0) {
+        result.value() = tensor.shuffle(m_to_o_array);
+    } else {
+        result.value() = tensor;
+    }
+    return result;
+}
 
-/// This function works out the rank of rv
-template
-return_type eigen_contraction_(base_reference rv, const_base_reference lhs,
-                               const_base_reference rhs,
-                               const vector_type& sum_modes) {
-    if constexpr(RVRank == max_rank + 1) {
-        throw std::runtime_error("Return has rank > max_rank");
-    } else {
-        if(RVRank == rv.rank()) {
-            using allocator_type = allocator::Eigen;
-            auto& rv_eigen = allocator_type::rebind(rv);
-            return lhs_rank(rv_eigen, lhs, rhs, sum_modes);
-        } else {
-            constexpr auto RVp1 = RVRank + 1;
-            return eigen_contraction_(rv, lhs, rhs, sum_modes);
-        }
-    }
-}
-
-template
-return_type eigen_contraction(base_reference rv, const_base_reference lhs,
-                              const_base_reference rhs,
-                              const vector_type& sum_modes) {
-    return eigen_contraction_(rv, lhs, rhs, sum_modes);
-}
+#define EIGEN_CONTRACTION_(FLOAT_TYPE, RANK)                      \
+    template return_type eigen_contraction(                      \
+        Eigen<FLOAT_TYPE, RANK>& result, label_type olabels,     \
+        const_labeled_reference lhs, const_labeled_reference rhs)
 
-#define EIGEN_CONTRACTION(FLOAT_TYPE)                                \
-    template return_type eigen_contraction(                         \
-        base_reference, const_base_reference, const_base_reference, \
-        const vector_type&)
+#define EIGEN_CONTRACTION(FLOAT_TYPE)  \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 0); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 1); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 2); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 3); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 4); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 5); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 6); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 7); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 8); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 9); \
+    EIGEN_CONTRACTION_(FLOAT_TYPE, 10)
 
 EIGEN_CONTRACTION(float);
 EIGEN_CONTRACTION(double);
 
 #undef EIGEN_CONTRACTION
-
+#undef EIGEN_CONTRACTION_
 } // namespace tensorwrapper::buffer
\ No newline at end of file
diff --git a/src/tensorwrapper/buffer/eigen_contraction.hpp b/src/tensorwrapper/buffer/eigen_contraction.hpp
index 5160cab3..5b4c066b 100644
--- a/src/tensorwrapper/buffer/eigen_contraction.hpp
+++ b/src/tensorwrapper/buffer/eigen_contraction.hpp
@@ -15,7 +15,7 @@
  */
 #pragma once
-#include +#include namespace tensorwrapper::buffer { @@ -27,15 +27,11 @@ namespace tensorwrapper::buffer { * instantiations for every combination of template parameters that Eigen may * end up seeing, that's what the functions in this header do. * -// * The entry point into this infrastructure is currently the return_rank - * method, which kicks the process off by working out the rank of the tensor - * which will - * */ -template +template BufferBase::dsl_reference eigen_contraction( - BufferBase::base_reference rv, BufferBase::const_base_reference lhs, - BufferBase::const_base_reference rhs, - const std::vector>& sum_modes); + Eigen& result, BufferBase::label_type olabels, + BufferBase::const_labeled_reference lhs, + BufferBase::const_labeled_reference rhs); } // namespace tensorwrapper::buffer \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp index 9658a8cd..47a26d2f 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/contraction_planner.cpp @@ -42,6 +42,10 @@ TEST_CASE("ContractionPlanner") { ContractionPlanner cp_ij_ik_jk("i,j", "i,k", "j,k"); ContractionPlanner cp_ji_ik_jk("j,i", "i,k", "j,k"); + // Matrix times rank 3 + ContractionPlanner cp_j_ki_jki("j", "k,i", "j,k,i"); + ContractionPlanner cp_jil_ki_jkl("j,i,l", "k,i", "j,k,l"); + // 3 times 3 ContractionPlanner cp__ijk_ijk("", "i,j,k", "i,j,k"); ContractionPlanner cp__ijk_jik("", "i,j,k", "j,i,k"); @@ -79,6 +83,9 @@ TEST_CASE("ContractionPlanner") { REQUIRE(cp_ij_ik_jk.lhs_free() == "i"); REQUIRE(cp_ji_ik_jk.lhs_free() == "i"); + REQUIRE(cp_j_ki_jki.lhs_free() == ""); + REQUIRE(cp_jil_ki_jkl.lhs_free() == "i"); + REQUIRE(cp__ijk_ijk.lhs_free() == ""); REQUIRE(cp__ijk_jik.lhs_free() == ""); REQUIRE(cp_il_ijk_jkl.lhs_free() == "i"); @@ -102,6 +109,9 @@ TEST_CASE("ContractionPlanner") { REQUIRE(cp_ij_ik_jk.rhs_free() == "j"); REQUIRE(cp_ji_ik_jk.rhs_free() == "j"); + REQUIRE(cp_j_ki_jki.rhs_free() == "j"); + REQUIRE(cp_jil_ki_jkl.rhs_free() == "j,l"); + REQUIRE(cp__ijk_ijk.rhs_free() == ""); REQUIRE(cp__ijk_jik.rhs_free() == ""); REQUIRE(cp_il_ijk_jkl.rhs_free() == "l"); @@ -125,6 +135,9 @@ TEST_CASE("ContractionPlanner") { REQUIRE(cp_ij_ik_jk.lhs_dummy() == "k"); REQUIRE(cp_ji_ik_jk.lhs_dummy() == "k"); + REQUIRE(cp_j_ki_jki.lhs_dummy() == "k,i"); + REQUIRE(cp_jil_ki_jkl.lhs_dummy() == "k"); + REQUIRE(cp__ijk_ijk.lhs_dummy() == "i,j,k"); REQUIRE(cp__ijk_jik.lhs_dummy() == "i,j,k"); REQUIRE(cp_il_ijk_jkl.lhs_dummy() == "j,k"); @@ -148,6 +161,9 @@ TEST_CASE("ContractionPlanner") { REQUIRE(cp_ij_ik_jk.rhs_dummy() == "k"); REQUIRE(cp_ji_ik_jk.rhs_dummy() == "k"); + REQUIRE(cp_j_ki_jki.rhs_dummy() == "k,i"); + REQUIRE(cp_jil_ki_jkl.rhs_dummy() == "k"); + REQUIRE(cp__ijk_ijk.rhs_dummy() == "i,j,k"); REQUIRE(cp__ijk_jik.rhs_dummy() == "j,i,k"); REQUIRE(cp_il_ijk_jkl.rhs_dummy() == "j,k"); @@ -171,6 +187,9 @@ TEST_CASE("ContractionPlanner") { REQUIRE(cp_ij_ik_jk.lhs_permutation() == "i,k"); REQUIRE(cp_ji_ik_jk.lhs_permutation() == "i,k"); + REQUIRE(cp_j_ki_jki.lhs_permutation() == "k,i"); + REQUIRE(cp_jil_ki_jkl.lhs_permutation() == "i,k"); + REQUIRE(cp__ijk_ijk.lhs_permutation() == "i,j,k"); REQUIRE(cp__ijk_jik.lhs_permutation() == "i,j,k"); REQUIRE(cp_il_ijk_jkl.lhs_permutation() == "i,j,k"); @@ -194,9 +213,38 @@ TEST_CASE("ContractionPlanner") { REQUIRE(cp_ij_ik_jk.rhs_permutation() == "k,j"); REQUIRE(cp_ji_ik_jk.rhs_permutation() == 
"k,j"); + REQUIRE(cp_j_ki_jki.rhs_permutation() == "k,i,j"); + REQUIRE(cp_jil_ki_jkl.rhs_permutation() == "k,j,l"); + REQUIRE(cp__ijk_ijk.rhs_permutation() == "i,j,k"); REQUIRE(cp__ijk_jik.rhs_permutation() == "i,j,k"); REQUIRE(cp_il_ijk_jkl.rhs_permutation() == "j,k,l"); REQUIRE(cp_il_ijk_klj.rhs_permutation() == "j,k,l"); } + + SECTION("result_matrix_labels") { + REQUIRE(cp___.result_matrix_labels() == ""); + + REQUIRE(cp__i_i.result_matrix_labels() == ""); + REQUIRE(cp_ij_i_j.result_matrix_labels() == "i,j"); + REQUIRE(cp_ji_i_j.result_matrix_labels() == "i,j"); + + REQUIRE(cp_j_i_ij.result_matrix_labels() == "j"); + REQUIRE(cp_j_i_ji.result_matrix_labels() == "j"); + REQUIRE(cp_ijk_i_jk.result_matrix_labels() == "i,j,k"); + REQUIRE(cp_ijk_i_kj.result_matrix_labels() == "i,j,k"); + + REQUIRE(cp_ij_ik_kj.result_matrix_labels() == "i,j"); + REQUIRE(cp_ji_ik_kj.result_matrix_labels() == "i,j"); + REQUIRE(cp_ij_ik_jk.result_matrix_labels() == "i,j"); + REQUIRE(cp_ji_ik_jk.result_matrix_labels() == "i,j"); + + REQUIRE(cp_j_ki_jki.result_matrix_labels() == "j"); + REQUIRE(cp_jil_ki_jkl.result_matrix_labels() == "i,j,l"); + + REQUIRE(cp__ijk_ijk.result_matrix_labels() == ""); + REQUIRE(cp__ijk_jik.result_matrix_labels() == ""); + REQUIRE(cp_il_ijk_jkl.result_matrix_labels() == "i,l"); + REQUIRE(cp_il_ijk_klj.result_matrix_labels() == "i,l"); + } } diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen_contraction.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen_contraction.cpp index 15ce98ca..c75fe2be 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen_contraction.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen_contraction.cpp @@ -23,9 +23,7 @@ using namespace buffer; TEMPLATE_TEST_CASE("eigen_contraction", "", float, double) { using float_t = TestType; - using mode_type = unsigned short; - using pair_type = std::pair; - using mode_array = std::vector; + using label_type = typename BufferBase::label_type; // Inputs auto scalar = testing::eigen_scalar(); @@ -33,10 +31,6 @@ TEMPLATE_TEST_CASE("eigen_contraction", "", float, double) { auto vector2 = testing::eigen_vector(2); auto matrix = testing::eigen_matrix(); - mode_array m00{pair_type{0, 0}}; - mode_array m11{pair_type{1, 1}}; - mode_array m00_11{pair_type{0, 0}, pair_type{1, 1}}; - auto scalar_corr = testing::eigen_scalar(); scalar_corr.value()() = 30.0; @@ -50,29 +44,112 @@ TEMPLATE_TEST_CASE("eigen_contraction", "", float, double) { matrix_corr.value()(1, 0) = 14.0; matrix_corr.value()(1, 1) = 20.0; + label_type l(""); + label_type j("j"); + label_type ij("i,j"); + + auto mij = matrix("i,j"); SECTION("vector with vector") { - auto& rv = eigen_contraction(scalar, vector, vector, m00); + auto vi = vector("i"); + auto& rv = eigen_contraction(scalar, l, vi, vi); REQUIRE(&rv == static_cast(&scalar)); REQUIRE(scalar_corr.are_equal(scalar)); } SECTION("ij,ij->") { - auto& rv = eigen_contraction(scalar, matrix, matrix, m00_11); + auto& rv = eigen_contraction(scalar, l, mij, mij); REQUIRE(&rv == static_cast(&scalar)); REQUIRE(scalar_corr.are_equal(scalar)); } SECTION("ki,kj->ij") { + auto mki = matrix("k,i"); + auto mkj = matrix("k,j"); auto buffer = testing::eigen_matrix(); - auto& rv = eigen_contraction(buffer, matrix, matrix, m00); + auto& rv = eigen_contraction(buffer, ij, mki, mkj); REQUIRE(&rv == static_cast(&buffer)); REQUIRE(matrix_corr.are_equal(buffer)); } SECTION("ij,i->j") { + auto vi = vector2("i"); auto buffer = testing::eigen_vector(2); - auto& rv = eigen_contraction(buffer, matrix, vector2, m00); + 
auto& rv = eigen_contraction(buffer, j, mij, vi); REQUIRE(&rv == static_cast(&buffer)); REQUIRE(vector_corr.are_equal(rv)); } + + SECTION("ki,jki->j") { + auto tensor = testing::eigen_tensor3(2); + auto matrix2 = testing::eigen_matrix(2); + auto buffer = testing::eigen_vector(2); + auto corr = testing::eigen_vector(2); + corr.value()(0) = 30; + corr.value()(1) = 70; + + auto tjki = tensor("j,k,i"); + auto mki = matrix2("k,i"); + auto& rv = eigen_contraction(buffer, j, mki, tjki); + REQUIRE(&rv == static_cast(&buffer)); + REQUIRE(corr.are_equal(rv)); + } + + SECTION("ki,jkl->jil") { + auto tensor = testing::eigen_tensor3(2); + auto matrix2 = testing::eigen_matrix(2); + auto buffer = testing::eigen_tensor3(2); + auto corr = testing::eigen_tensor3(); + corr.value()(0, 0, 0) = 10; + corr.value()(0, 0, 1) = 14; + corr.value()(0, 1, 0) = 14; + corr.value()(0, 1, 1) = 20; + + corr.value()(1, 0, 0) = 26; + corr.value()(1, 0, 1) = 30; + corr.value()(1, 1, 0) = 38; + corr.value()(1, 1, 1) = 44; + + auto tjki = tensor("j,k,l"); + auto mki = matrix2("k,i"); + label_type jil("j,i,l"); + auto& rv = eigen_contraction(buffer, jil, mki, tjki); + REQUIRE(&rv == static_cast(&buffer)); + REQUIRE(corr.are_equal(rv)); + } + + SECTION("kl,ijkl->ij") { + auto tensor = testing::eigen_tensor4(); + auto matrix2 = testing::eigen_matrix(2); + auto buffer = testing::eigen_matrix(2); + auto corr = testing::eigen_matrix(); + corr.value()(0, 0) = 30; + corr.value()(0, 1) = 70; + corr.value()(1, 0) = 110; + corr.value()(1, 1) = 150; + + auto lt = tensor("i,j,k,l"); + auto lm = matrix2("k,l"); + label_type jil("i,j"); + auto& rv = eigen_contraction(buffer, ij, lm, lt); + REQUIRE(&rv == static_cast(&buffer)); + REQUIRE(corr.are_equal(rv)); + } + + SECTION("kl,ilkj->ij") { + auto tensor = testing::eigen_tensor4(); + auto matrix2 = testing::eigen_matrix(2); + auto buffer = testing::eigen_matrix(2); + auto corr = testing::eigen_matrix(); + corr.value()(0, 0) = 48; + corr.value()(0, 1) = 58; + corr.value()(1, 0) = 128; + corr.value()(1, 1) = 138; + + auto lt = tensor("i,l,k,j"); + auto lm = matrix2("k,l"); + label_type jil("i,j"); + auto& rv = eigen_contraction(buffer, ij, lm, lt); + REQUIRE(&rv == static_cast(&buffer)); + REQUIRE(corr.are_equal(rv)); + } } \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/testing/eigen_buffers.hpp b/tests/cxx/unit_tests/tensorwrapper/testing/eigen_buffers.hpp index 97b5811c..8f6f2977 100644 --- a/tests/cxx/unit_tests/tensorwrapper/testing/eigen_buffers.hpp +++ b/tests/cxx/unit_tests/tensorwrapper/testing/eigen_buffers.hpp @@ -73,21 +73,40 @@ auto eigen_matrix(std::size_t n = 2, std::size_t m = 2) { } template -auto eigen_tensor3() { +auto eigen_tensor3(std::size_t n = 2, std::size_t m = 2, std::size_t l = 2) { using buffer_type = buffer::Eigen; using data_type = typename buffer_type::data_type; - shape::Smooth shape{2, 2, 2}; - layout::Physical l(shape); - data_type tensor(2, 2, 2); - tensor(0, 0, 0) = 1.0; - tensor(0, 0, 1) = 2.0; - tensor(0, 1, 0) = 3.0; - tensor(0, 1, 1) = 4.0; - tensor(1, 0, 0) = 5.0; - tensor(1, 0, 1) = 6.0; - tensor(1, 1, 0) = 7.0; - tensor(1, 1, 1) = 8.0; - return buffer_type(tensor, l); + int in = static_cast(n); + int im = static_cast(m); + int il = static_cast(l); + shape::Smooth shape{n, m, l}; + layout::Physical layout(shape); + data_type e_tensor(in, im, il); + double counter = 1.0; + for(decltype(in) i = 0; i < in; ++i) + for(decltype(im) j = 0; j < im; ++j) + for(decltype(il) k = 0; k < il; ++k) e_tensor(i, j, k) = counter++; + return 
buffer_type(e_tensor, layout);
+}
+
+template<typename FloatType>
+auto eigen_tensor4(std::array<std::size_t, 4> extents = {2, 2, 2, 2}) {
+    auto constexpr Rank = 4;
+    using buffer_type   = buffer::Eigen<FloatType, Rank>;
+    using data_type     = typename buffer_type::data_type;
+    std::array<int, Rank> iextents;
+    for(std::size_t i = 0; i < Rank; ++i) iextents[i] = extents[i];
+    shape::Smooth shape{extents[0], extents[1], extents[2], extents[3]};
+    layout::Physical layout(shape);
+    data_type e_tensor(iextents[0], iextents[1], iextents[2], iextents[3]);
+    double counter = 1.0;
+    std::array<int, Rank> i;
+    for(i[0] = 0; i[0] < iextents[0]; ++i[0])
+        for(i[1] = 0; i[1] < iextents[1]; ++i[1])
+            for(i[2] = 0; i[2] < iextents[2]; ++i[2])
+                for(i[3] = 0; i[3] < iextents[3]; ++i[3])
+                    e_tensor(i[0], i[1], i[2], i[3]) = counter++;
+    return buffer_type(e_tensor, layout);
+}
 
 } // namespace tensorwrapper::testing
\ No newline at end of file

From c9cbfa37861580b271650e4ac6f72a87508594d2 Mon Sep 17 00:00:00 2001
From: "Ryan M. Richard" 
Date: Thu, 6 Feb 2025 09:58:26 -0600
Subject: [PATCH 5/5] use int for Eigen

---
 src/tensorwrapper/buffer/eigen_contraction.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/tensorwrapper/buffer/eigen_contraction.cpp b/src/tensorwrapper/buffer/eigen_contraction.cpp
index 295567ba..353d8868 100644
--- a/src/tensorwrapper/buffer/eigen_contraction.cpp
+++ b/src/tensorwrapper/buffer/eigen_contraction.cpp
@@ -96,8 +96,8 @@ return_type eigen_contraction(Eigen<FloatType, Rank>& result,
 
     auto m_to_o = olabels.permutation(mlabels); // N.b. Eigen's def is the inverse of ours
 
-    std::array<std::size_t, Rank> out_size;
-    std::array<std::size_t, Rank> m_to_o_array;
+    std::array<int, Rank> out_size;
+    std::array<int, Rank> m_to_o_array;
     for(std::size_t i = 0; i < Rank; ++i) {
         out_size[i]     = mshape.extent(i);
         m_to_o_array[i] = m_to_o[i];
@@ -133,6 +133,11 @@ return_type eigen_contraction(Eigen<FloatType, Rank>& result,
 EIGEN_CONTRACTION(float);
 EIGEN_CONTRACTION(double);
 
+#ifdef ENABLE_SIGMA
+EIGEN_CONTRACTION(sigma::UFloat);
+EIGEN_CONTRACTION(sigma::UDouble);
+#endif
+
 #undef EIGEN_CONTRACTION
 #undef EIGEN_CONTRACTION_
 } // namespace tensorwrapper::buffer
\ No newline at end of file