Skip to content

Commit

Permalink
Merge pull request #195 from NWChemEx/custom_contraction
Browse files Browse the repository at this point in the history
Custom contraction
  • Loading branch information
jwaldrop107 authored Feb 6, 2025
2 parents 3ec992a + c9cbfa3 commit d51768f
Show file tree
Hide file tree
Showing 9 changed files with 653 additions and 164 deletions.
3 changes: 3 additions & 0 deletions include/tensorwrapper/detail_/dsl_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class DSLBase {
/// Type of parsed labels
using label_type = typename labeled_type::label_type;

/// Type of a mutable reference to a labeled_type object
using labeled_reference = labeled_type&;

/// Type of a read-only reference to a labeled_type object
using const_labeled_reference = const labeled_const_type&;

Expand Down
57 changes: 42 additions & 15 deletions include/tensorwrapper/dsl/dummy_indices.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <ostream>
#include <set>
#include <string>
#include <utilities/containers/indexable_container_base.hpp>
Expand Down Expand Up @@ -110,6 +111,15 @@ class DummyIndices
DummyIndices(
utilities::strings::split_string(remove_spaces_(dummy_indices), ",")) {}

/** @brief Main ctor; takes ownership of an already-split set of indices.
 *
 *  @param[in] split_dummy_indices The individual dummy indices. The value is
 *                                 moved into *this.
 *
 *  @throw std::runtime_error if any of the provided indices is empty.
 */
explicit DummyIndices(split_string_type split_dummy_indices) :
  m_dummy_indices_(std::move(split_dummy_indices)) {
    for(const auto& index : m_dummy_indices_) {
        if(!index.empty()) continue;
        throw std::runtime_error("Dummy index is not allowed to be empty");
    }
}

/** @brief Determines the number of unique indices in *this.
*
* A dummy index can be repeated if it is going to be summed over. This
Expand Down Expand Up @@ -212,12 +222,27 @@ class DummyIndices
*
* Each DummyIndices object is viewed as an ordered set of objects. If
* two DummyIndices objects contain the same objects, but in a different
* order, we can convert either object into the other by permuting it.
* This method computes the permutation needed to change *this into
* @p other. More specifically the result of this method is a vector
* of length `size()` such that the `i`-th element is the offset of
* `(*this)[i]` in @p other, i.e., if `x` is the return then
* `other[x[i]] == (*this)[i]`.
* order, we can convert either object into the other by permuting it. This
* method determines the permutation necessary to convert *this into other;
* the reverse permutation, i.e., converting other to *this is obtained
* by `other.permutation(*this)`.
*
* For concreteness assume we have a set A=(i, j, k) and we want to
* permute it into the set B=(j,k,i), then there are two subtly different
* ways of writing the necessary permutation:
* 1. Write the target set in terms of the current offsets, e.g., A goes
* to B is written as P1=(1, 2, 0)
* 2. Write the current set in terms of the target offsets, e.g., A goes
* to B is written as P2=(2, 0, 1)
*
* Option one maps offsets in B to offsets in A, e.g., A[P1[0]] == B[0],
* whereas option two maps offsets in A to offsets in B, e.g.,
* A[0] == B[P2[0]]. This method follows definition two.
*
* @note The definitions are inverses of each other. So if you don't like
* our definition, and want the other one, flip *this and @p other, e.g.,
* permuting B to A using definition 1 yields P1'=(2, 0, 1)=P2 and
* permuting B to A using definition 2 yields P2'=(1, 2, 0)=P1.
*
* @param[in] other The order we want to permute *this to.
*
Expand Down Expand Up @@ -437,15 +462,6 @@ class DummyIndices
}

protected:
/** @brief Main ctor; takes ownership of an already-split set of indices.
 *
 *  @param[in] split_dummy_indices The individual dummy indices. The value is
 *                                 moved into *this.
 *
 *  @throw std::runtime_error if any of the provided indices is empty.
 */
explicit DummyIndices(split_string_type split_dummy_indices) :
  m_dummy_indices_(std::move(split_dummy_indices)) {
    for(const auto& index : m_dummy_indices_) {
        if(!index.empty()) continue;
        throw std::runtime_error("Dummy index is not allowed to be empty");
    }
}

/// Lets the base class get at these implementations
friend base_type;

Expand All @@ -471,6 +487,17 @@ class DummyIndices
split_string_type m_dummy_indices_;
};

/** @brief Prints a DummyIndices object as a comma-separated list.
 *
 *  The indices are written in order, separated by "," with no surrounding
 *  whitespace and no trailing separator. Nothing is written when @p i holds
 *  no indices.
 *
 *  @tparam StringType The string type @p i is templated on.
 *
 *  @param[in,out] os The stream to print to.
 *  @param[in] i The indices to print.
 *
 *  @return @p os, to support chaining.
 */
template<typename StringType>
std::ostream& operator<<(std::ostream& os, const DummyIndices<StringType>& i) {
    for(std::size_t j = 0; j < i.size(); ++j) {
        // Every index except the first is preceded by a separator
        if(j > 0) os << ",";
        os << i.at(j);
    }
    return os;
}

template<typename StringType>
bool DummyIndices<StringType>::is_hadamard_product(
const DummyIndices& lhs, const DummyIndices& rhs) const noexcept {
Expand Down
130 changes: 130 additions & 0 deletions src/tensorwrapper/buffer/contraction_planner.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright 2025 NWChemEx-Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include <tensorwrapper/dsl/dummy_indices.hpp>

namespace tensorwrapper::buffer {

/** @brief Works out the label bookkeeping for a tensor contraction.
 *
 *  Given the labels of the result and of the two operands, this class sorts
 *  each operand's labels into "free" labels (those also found in the result)
 *  and "dummy" labels (those missing from the result). It also provides
 *  re-ordered label sets for the operands and the corresponding labels of
 *  their product.
 *
 *  N.B. Direct products are handled too; they are simply contractions with
 *  zero dummy indices.
 */
class ContractionPlanner {
public:
    /// String type users use to label modes
    using string_type = std::string;

    /// Type of the parsed labels
    using label_type = dsl::DummyIndices<string_type>;

    /** @brief Plans a contraction from unparsed labels.
     *
     *  Each string is parsed into a label_type object and then forwarded to
     *  the label_type based ctor.
     *
     *  @param[in] result The labels of the result.
     *  @param[in] lhs The labels of the left operand.
     *  @param[in] rhs The labels of the right operand.
     *
     *  @throw std::runtime_error under the same conditions as the
     *                            label_type based ctor.
     */
    ContractionPlanner(string_type result, string_type lhs, string_type rhs) :
      ContractionPlanner(label_type(result), label_type(lhs),
                         label_type(rhs)) {}

    /** @brief Plans a contraction from already parsed labels.
     *
     *  @param[in] result The labels of the result.
     *  @param[in] lhs The labels of the left operand.
     *  @param[in] rhs The labels of the right operand.
     *
     *  @throw std::runtime_error if any term repeats a label, if the dummy
     *                            labels of @p lhs and @p rhs are not
     *                            permutations of each other, or if @p lhs and
     *                            @p rhs have a free label in common.
     */
    ContractionPlanner(label_type result, label_type lhs, label_type rhs) :
      m_result_(std::move(result)),
      m_lhs_(std::move(lhs)),
      m_rhs_(std::move(rhs)) {
        assert_no_repeated_indices_();
        assert_dummy_indices_are_similar_();
        assert_no_shared_free_();
    }

    /// Labels of the LHS that also appear in the result (NOT summed over)
    label_type lhs_free() const {
        auto in_both = m_lhs_.intersection(m_result_);
        return in_both;
    }

    /// Labels of the RHS that also appear in the result (NOT summed over)
    label_type rhs_free() const {
        auto in_both = m_rhs_.intersection(m_result_);
        return in_both;
    }

    /// Labels of the LHS that are absent from the result (ARE summed over)
    label_type lhs_dummy() const {
        auto lhs_only = m_lhs_.difference(m_result_);
        return lhs_only;
    }

    /// Labels of the RHS that are absent from the result (ARE summed over)
    label_type rhs_dummy() const {
        auto rhs_only = m_rhs_.difference(m_result_);
        return rhs_only;
    }

    /** @brief LHS labels re-ordered as: free labels, then dummy labels.
     *
     *  The free labels are emitted in the order they are encountered when
     *  looping over the result's labels. The dummy labels follow in the order
     *  they are encountered when looping over lhs_dummy().
     *
     *  @return The re-ordered labels of the LHS.
     */
    label_type lhs_permutation() const {
        auto free_labels  = lhs_free();
        auto dummy_labels = lhs_dummy();

        typename label_type::split_string_type ordered;
        for(const auto& label : m_result_) {
            if(free_labels.count(label)) ordered.push_back(label);
        }
        for(const auto& label : dummy_labels) { ordered.push_back(label); }
        return label_type(std::move(ordered));
    }

    /** @brief RHS labels re-ordered as: dummy labels, then free labels.
     *
     *  The dummy labels are deliberately taken from lhs_dummy(), not
     *  rhs_dummy(), so that they come out in the same order
     *  lhs_permutation() uses. The free labels follow in the order they are
     *  encountered when looping over the result's labels.
     *
     *  @return The re-ordered labels of the RHS.
     */
    label_type rhs_permutation() const {
        auto free_labels  = rhs_free();
        auto dummy_labels = lhs_dummy(); // LHS's dummies fix the order

        typename label_type::split_string_type ordered;
        // The ctor guarantees no term repeats a label, so each label is
        // appended exactly once
        for(const auto& label : dummy_labels) { ordered.push_back(label); }
        for(const auto& label : m_result_) {
            if(free_labels.count(label)) ordered.push_back(label);
        }
        return label_type(std::move(ordered));
    }

    /** @brief Flattened result labels.
     *
     *  Permuting the LHS per lhs_permutation() gives A, and permuting the RHS
     *  per rhs_permutation() gives B. A and B can then be multiplied with a
     *  gemm. The product's labels are the free labels of A followed by the
     *  free labels of B, which is what this method returns.
     *
     *  @return The concatenation of lhs_permutation() and rhs_permutation()
     *          with the labels in lhs_dummy() removed.
     */
    label_type result_matrix_labels() const {
        const auto a_labels  = lhs_permutation();
        const auto b_labels  = rhs_permutation();
        const auto ab_labels = a_labels.concatenation(b_labels);
        return ab_labels.difference(lhs_dummy());
    }

private:
    /// Throws if the result, LHS, or RHS uses the same label more than once
    void assert_no_repeated_indices_() const {
        const bool has_repeat = m_result_.has_repeated_indices() ||
                                m_lhs_.has_repeated_indices() ||
                                m_rhs_.has_repeated_indices();

        if(has_repeat)
            throw std::runtime_error(
              "One or more terms contain repeated labels");
    }

    /// Throws unless LHS's and RHS's dummy labels are permutations
    void assert_dummy_indices_are_similar_() const {
        if(!lhs_dummy().is_permutation(rhs_dummy()))
            throw std::runtime_error("Dummy indices must appear in all terms");
    }

    /// Throws if LHS and RHS share a free label (i.e., a Hadamard product)
    void assert_no_shared_free_() const {
        const auto n_shared = lhs_free().intersection(rhs_free()).size();
        if(n_shared)
            throw std::runtime_error("Contraction must sum repeated indices");
    }

    /// Labels of the result
    label_type m_result_;

    /// Labels of the left operand
    label_type m_lhs_;

    /// Labels of the right operand
    label_type m_rhs_;
};

} // namespace tensorwrapper::buffer
27 changes: 4 additions & 23 deletions src/tensorwrapper/buffer/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "eigen_contraction.hpp"
#include <sstream>
#include <tensorwrapper/allocator/eigen.hpp>
Expand Down Expand Up @@ -139,7 +140,8 @@ typename EIGEN::dsl_reference EIGEN::permute_assignment_(
const auto& rlabels = rhs.labels();

if(this_labels != rlabels) { // We need to permute rhs before assignment
auto r_to_l = rhs.labels().permutation(this_labels);
// Eigen adopts the opposite definition of permutation from us.
auto r_to_l = this_labels.permutation(rlabels);
// Eigen wants int objects
std::vector<int> r_to_l2(r_to_l.begin(), r_to_l.end());
m_tensor_ = rhs_downcasted.value().shuffle(r_to_l2);
Expand Down Expand Up @@ -226,28 +228,7 @@ typename EIGEN::dsl_reference EIGEN::hadamard_(label_type this_labels,
/* Implements contraction for this buffer by delegating to eigen_contraction.
 *
 * this_labels: labels for *this; only used by the final return statement.
 * lhs: labeled left operand.
 * rhs: labeled right operand.
 *
 * Returns whatever eigen_contraction returns.
 *
 * NOTE(review): this span contains two consecutive return statements. It
 * looks like the pre- and post-change bodies of a diff shown together;
 * confirm which one is intended before relying on it.
 */
TPARAMS typename EIGEN::dsl_reference EIGEN::contraction_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) {
// Unpack the labels and the wrapped object of each operand
const auto& llabels = lhs.labels();
const auto& lobject = lhs.object();
const auto& rlabels = rhs.labels();
const auto& robject = rhs.object();

// N.b. is a pure contraction, so common indices are summed over
auto common = llabels.intersection(rlabels);

// -- This block converts string indices to mode offsets
using rank_type = unsigned short;
using pair_type = std::pair<rank_type, rank_type>;
// One (lhs offset, rhs offset) pair per label in common
std::vector<pair_type> modes;
auto rank = common.size();
for(decltype(rank) i = 0; i < rank; ++i) {
const auto& index_i = common.at(i);
// N.b., pure contraction so there's no repeats within a tensor's label
// Only the first element returned by find is used
auto lindex = llabels.find(index_i)[0];
auto rindex = rlabels.find(index_i)[0];
modes.push_back(pair_type(lindex, rindex));
}

// Forwards the wrapped objects and the mode pairs built above
return eigen_contraction<FloatType>(*this, lobject, robject, modes);
// NOTE(review): unreachable as written, the return above always executes.
// This overload takes the labels and labeled operands directly.
return eigen_contraction(*this, this_labels, lhs, rhs);
}

#undef EIGEN
Expand Down
Loading

0 comments on commit d51768f

Please sign in to comment.