Fix bug making TaskBasedCpuContractor different from TensorNetwork #6

Merged · 5 commits · May 17, 2021
11 changes: 6 additions & 5 deletions include/jet/TensorHelpers.hpp
@@ -72,18 +72,19 @@ gemmBinding(size_t m, size_t n, size_t k, ComplexPrecision alpha,
* @param A_data Complex data matrix A
* @param B_data Complex data vector B
* @param C_data Output data vector
+ * @param transpose Transpose flag for matrix A
*/
template <typename ComplexPrecision>
constexpr void
gemvBinding(size_t m, size_t k, ComplexPrecision alpha, ComplexPrecision beta,
const ComplexPrecision *A_data, const ComplexPrecision *B_data,
- ComplexPrecision *C_data)
+ ComplexPrecision *C_data, const CBLAS_TRANSPOSE &transpose)
{
if constexpr (std::is_same_v<ComplexPrecision, std::complex<float>>)
- cblas_cgemv(CblasRowMajor, CblasNoTrans, m, k, (&alpha), (A_data),
+ cblas_cgemv(CblasRowMajor, transpose, m, k, (&alpha), (A_data),
std::max(1ul, k), (B_data), 1, (&beta), (C_data), 1);
else if constexpr (std::is_same_v<ComplexPrecision, std::complex<double>>)
- cblas_zgemv(CblasRowMajor, CblasNoTrans, m, k, (&alpha), (A_data),
+ cblas_zgemv(CblasRowMajor, transpose, m, k, (&alpha), (A_data),
std::max(1ul, k), (B_data), 1, (&beta), (C_data), 1);
};
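For readers skimming the diff, the sketch below spells out the two row-major products that the new transpose argument selects between when it is forwarded to cblas_cgemv/cblas_zgemv. It is an illustration only, written with plain loops instead of CBLAS; the helper names gemv_notrans and gemv_trans are hypothetical and not part of Jet.

#include <complex>
#include <cstddef>
#include <vector>

using c64 = std::complex<double>;

// CblasNoTrans: y = A * x, where A is m x k (row-major), x has length k, y has length m.
std::vector<c64> gemv_notrans(std::size_t m, std::size_t k,
                              const std::vector<c64> &A, const std::vector<c64> &x)
{
    std::vector<c64> y(m, c64{0, 0});
    for (std::size_t i = 0; i < m; i++)
        for (std::size_t j = 0; j < k; j++)
            y[i] += A[i * k + j] * x[j]; // y_i = sum_j A(i,j) x_j
    return y;
}

// CblasTrans: y = A^T * x, where A is m x k (row-major), x has length m, y has length k.
std::vector<c64> gemv_trans(std::size_t m, std::size_t k,
                            const std::vector<c64> &A, const std::vector<c64> &x)
{
    std::vector<c64> y(k, c64{0, 0});
    for (std::size_t i = 0; i < m; i++)
        for (std::size_t j = 0; j < k; j++)
            y[j] += A[i * k + j] * x[i]; // y_j = sum_i A(i,j) x_i
    return y;
}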

@@ -153,12 +154,12 @@ inline void MultiplyTensorData(const std::vector<ComplexPrecision> &A,
else if (left_indices.size() > 0 && right_indices.size() == 0) {
size_t m = left_dim;
size_t k = common_dim;
- gemvBinding(m, k, alpha, beta, A_data, B_data, C_data);
+ gemvBinding(m, k, alpha, beta, A_data, B_data, C_data, CblasNoTrans);
}
else if (left_indices.size() == 0 && right_indices.size() > 0) {
size_t n = right_dim;
size_t k = common_dim;
- gemvBinding(k, n, alpha, beta, B_data, A_data, C_data);
+ gemvBinding(k, n, alpha, beta, B_data, A_data, C_data, CblasTrans);
}
else if (left_indices.size() == 0 && right_indices.size() == 0) {
size_t k = common_dim;
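A note on the two call sites above: when only left indices remain, MultiplyTensorData computes C(a) = sum_c A(a, c) * B(c), which is the plain matrix-vector product and keeps CblasNoTrans. When only right indices remain, it needs C(b) = sum_c A(c) * B(c, b), i.e. the vector multiplies the matrix from the left; with row-major storage that is the transposed product B^T * A, hence CblasTrans. The old code issued an untransposed GEMV in that branch, which appears to be the source of the discrepancy between TaskBasedCpuContractor and TensorNetwork that this PR fixes.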
7 changes: 0 additions & 7 deletions include/jet/TensorNetwork.hpp
@@ -388,13 +388,6 @@ template <class Tensor> class TensorNetwork {
*/
size_t ContractNodes_(size_t node_id_1, size_t node_id_2) noexcept
{
- // Make sure node 1 has at least as many indices as node 2.
- const size_t node_1_size = nodes_[node_id_1].tensor.GetIndices().size();
- const size_t node_2_size = nodes_[node_id_2].tensor.GetIndices().size();
- if (node_1_size <= node_2_size) {
-     std::swap(node_id_1, node_id_2);
- }
-
auto &node_1 = nodes_[node_id_1];
auto &node_2 = nodes_[node_id_2];
const auto tensor_3 = ContractTensors(node_1.tensor, node_2.tensor);
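With the GEMV binding now handling the vector-times-matrix case via CblasTrans, this reordering is no longer needed. Removing it also makes the result ordering predictable: previously the larger tensor was always moved to the left operand, so the contracted tensor's index order and the derived node name depended on the operands' relative sizes; now they simply follow the order in which the nodes were passed, matching what TaskBasedCpuContractor produces. The test updates below (for example the intermediate node named A0C2 rather than C2A0) reflect this.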
40 changes: 32 additions & 8 deletions test/Test_Tensor.cpp
@@ -460,22 +460,22 @@ TEMPLATE_TEST_CASE("ContractTensors", "[Tensor]", c_fp32, c_fp64)
Approx(expected_rji_si.GetData()[1].imag()));

CHECK(con_si_rij.GetData()[0].real() ==
- Approx(expected_rji_si.GetData()[0].real()));
+ Approx(expected_rij_si.GetData()[0].real()));
CHECK(con_si_rij.GetData()[0].imag() ==
- Approx(expected_rji_si.GetData()[0].imag()));
+ Approx(expected_rij_si.GetData()[0].imag()));
CHECK(con_si_rij.GetData()[1].real() ==
- Approx(expected_rji_si.GetData()[1].real()));
+ Approx(expected_rij_si.GetData()[1].real()));
CHECK(con_si_rij.GetData()[1].imag() ==
- Approx(expected_rji_si.GetData()[1].imag()));
+ Approx(expected_rij_si.GetData()[1].imag()));

CHECK(con_si_rji.GetData()[0].real() ==
- Approx(expected_rij_si.GetData()[0].real()));
+ Approx(expected_rji_si.GetData()[0].real()));
CHECK(con_si_rji.GetData()[0].imag() ==
- Approx(expected_rij_si.GetData()[0].imag()));
+ Approx(expected_rji_si.GetData()[0].imag()));
CHECK(con_si_rji.GetData()[1].real() ==
- Approx(expected_rij_si.GetData()[1].real()));
+ Approx(expected_rji_si.GetData()[1].real()));
CHECK(con_si_rji.GetData()[1].imag() ==
- Approx(expected_rij_si.GetData()[1].imag()));
+ Approx(expected_rji_si.GetData()[1].imag()));
}

SECTION("Contract T0(a,b) and T1(b) -> T2(a)")
@@ -502,6 +502,30 @@ TEMPLATE_TEST_CASE("ContractTensors", "[Tensor]", c_fp32, c_fp64)
CHECK(tensor3 == tensor4);
}

SECTION("Contract T0(a) and T1(a,b) -> T2(b)")
{
std::vector<std::size_t> t_shape1{2};
std::vector<std::size_t> t_shape2{2, 2};
std::vector<std::size_t> t_shape3{2};

std::vector<std::string> t_indices1{"a"};
std::vector<std::string> t_indices2{"a", "b"};

std::vector<TestType> t_data1{TestType(0.0, 0.0), TestType(1.0, 0.0)};
std::vector<TestType> t_data2{TestType(0.0, 0.0), TestType(1.0, 0.0),
TestType(2.0, 0.0), TestType(3.0, 0.0)};
std::vector<TestType> t_data_expect{TestType(2.0, 0.0),
TestType(3.0, 0.0)};

Tensor<TestType> tensor1(t_indices1, t_shape1, t_data1);
Tensor<TestType> tensor2(t_indices2, t_shape2, t_data2);

Tensor<TestType> tensor3 = ContractTensors(tensor1, tensor2);
Tensor<TestType> tensor4({"b"}, {2}, t_data_expect);

CHECK(tensor3 == tensor4);
}

SECTION("Contract T0(a,b,c) and T1(b,c,d) -> T2(a,d)")
{
std::vector<std::size_t> t_shape1{2, 3, 4};
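As a quick sanity check of the expected data in the new T0(a) and T1(a,b) section above: with x = (0, 1) and the row-major matrix M = [[0, 1], [2, 3]] (all values real here), T2(b) = sum_a x_a * M(a, b) gives T2 = (0*0 + 1*2, 0*1 + 1*3) = (2, 3), which is exactly t_data_expect. This vector-first contraction is the pattern that exercises the newly transposed GEMV path.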
52 changes: 30 additions & 22 deletions test/Test_TensorNetwork.cpp
@@ -37,7 +37,8 @@ tensor_t MakeTensor(const indices_t &indices, const shape_t &shape)
if (!shape.empty()) {
for (size_t i = 0; i < tensor.GetSize(); i++) {
const auto index = Jet::Utilities::UnravelIndex(i, shape);
- tensor.SetValue(index, i);
+ tensor.SetValue(index, complex_t{static_cast<float>(i),
+ static_cast<float>(2 * i)});
}
}
return tensor;
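The test tensors are now filled with genuinely complex values: entry n becomes n * (1 + 2i) instead of n. That is why every expected array below changes: a single tensor's entries become pairs {n, 2n}, a contraction of two such tensors scales the old integer result by (1 + 2i)^2 = -3 + 4i (for example 5 becomes {-15, 20} and 55 becomes {-165, 220}), and a contraction of three tensors scales it by (1 + 2i)^3 = -11 - 2i (for example 28 becomes {-308, -56}). The snippet below is only a scratch check of those factors; it is not part of the test suite.

#include <complex>
#include <iostream>

int main()
{
    const std::complex<double> w{1.0, 2.0}; // the factor applied to every entry
    std::cout << w * w << "\n";             // prints (-3,4)
    std::cout << w * w * w << "\n";         // prints (-11,-2)
    return 0;
}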
@@ -208,9 +209,11 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
- const data_t want_tensor_data = {0, 1, 2, 3, 4, 5, 6, 7,
- 8, 9, 10, 11, 12, 13, 14, 15,
- 16, 17, 18, 19, 20, 21, 22, 23};
+ const data_t want_tensor_data = {
+ {0, 0}, {1, 2}, {2, 4}, {3, 6}, {4, 8}, {5, 10},
+ {6, 12}, {7, 14}, {8, 16}, {9, 18}, {10, 20}, {11, 22},
+ {12, 24}, {13, 26}, {14, 28}, {15, 30}, {16, 32}, {17, 34},
+ {18, 36}, {19, 38}, {20, 40}, {21, 42}, {22, 44}, {23, 46}};
CHECK(have_tensor_data == want_tensor_data);
}

@@ -236,7 +239,9 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
- const data_t want_tensor_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+ const data_t want_tensor_data = {{0, 0}, {1, 2}, {2, 4}, {3, 6},
+ {4, 8}, {5, 10}, {6, 12}, {7, 14},
+ {8, 16}, {9, 18}, {10, 20}, {11, 22}};
CHECK(have_tensor_data == want_tensor_data);
}

@@ -262,8 +267,9 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
- const data_t want_tensor_data = {12, 13, 14, 15, 16, 17,
- 18, 19, 20, 21, 22, 23};
+ const data_t want_tensor_data = {
+ {12, 24}, {13, 26}, {14, 28}, {15, 30}, {16, 32}, {17, 34},
+ {18, 36}, {19, 38}, {20, 40}, {21, 42}, {22, 44}, {23, 46}};
CHECK(have_tensor_data == want_tensor_data);
}

@@ -289,7 +295,7 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
- const data_t want_tensor_data = {14, 18, 22};
+ const data_t want_tensor_data = {{14, 28}, {18, 36}, {22, 44}};
CHECK(have_tensor_data == want_tensor_data);
}

@@ -315,7 +321,7 @@ TEST_CASE("TensorNetwork::SliceIndices", "[TensorNetwork]")
CHECK(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = node.tensor.GetData();
- const data_t want_tensor_data = {23};
+ const data_t want_tensor_data = {{23, 46}};
CHECK(have_tensor_data == want_tensor_data);
}
}
@@ -336,7 +342,7 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const data_t have_tensor_data = result.GetData();
- const data_t want_tensor_data = {0, 1};
+ const data_t want_tensor_data = {{0, 0}, {1, 2}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
@@ -365,7 +371,7 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const data_t have_tensor_data = {result.GetValue({})};
- const data_t want_tensor_data = {5};
+ const data_t want_tensor_data = {{-15, 20}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
@@ -402,7 +408,7 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const data_t have_tensor_data = result.GetData();
- const data_t want_tensor_data = {10, 13};
+ const data_t want_tensor_data = {{-30, 40}, {-39, 52}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
@@ -439,7 +445,7 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const data_t have_tensor_data = {result.GetValue({})};
- const data_t want_tensor_data = {55};
+ const data_t want_tensor_data = {{-165, 220}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
@@ -476,11 +482,12 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
CHECK(have_path == want_path);

const shape_t have_tensor_shape = result.GetShape();
- const shape_t want_tensor_shape = {3, 2};
+ const shape_t want_tensor_shape = {2, 3};
REQUIRE(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = result.GetData();
- const data_t want_tensor_data = {5, 14, 14, 50, 23, 86};
+ const data_t want_tensor_data = {{-15, 20}, {-42, 56}, {-69, 92},
+ {-42, 56}, {-150, 200}, {-258, 344}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
@@ -492,8 +499,8 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
{
const auto &node = nodes[2];
CHECK(node.id == 2);
CHECK(node.name == "C2A0");
CHECK(node.indices == indices_t{"C2", "A0"});
CHECK(node.name == "A0C2");
CHECK(node.indices == indices_t{"A0", "C2"});
CHECK(node.contracted == false);
}
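Because ContractNodes_ no longer swaps its operands, the contracted tensor keeps the left operand's free indices first: the result checked above is indexed {A0, C2} with shape {2, 3} (the transpose of the previous {3, 2} layout, with the entries additionally scaled by (1 + 2i)^2 as noted earlier), and the derived node name is A0C2. The later hunks show the same reordering for the three-tensor case, where the intermediate nodes are now B1D3 and A0D3.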

@@ -524,7 +531,8 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
REQUIRE(have_tensor_shape == want_tensor_shape);

const data_t have_tensor_data = result.GetData();
- const data_t want_tensor_data = {28, 100, 47, 164};
+ const data_t want_tensor_data = {
+ {-308, -56}, {-517, -94}, {-1100, -200}, {-1804, -328}};
CHECK(have_tensor_data == want_tensor_data);

const auto &nodes = tn.GetNodes();
@@ -537,15 +545,15 @@ TEST_CASE("TensorNetwork::Contract", "[TensorNetwork]")
{
const auto &node = nodes[3];
CHECK(node.id == 3);
CHECK(node.name == "D3B1");
CHECK(node.indices == indices_t{"D3", "B1"});
CHECK(node.name == "B1D3");
CHECK(node.indices == indices_t{"B1", "D3"});
CHECK(node.contracted == true);
}
{
const auto &node = nodes[4];
CHECK(node.id == 4);
CHECK(node.name == "D3A0");
CHECK(node.indices == indices_t{"D3", "A0"});
CHECK(node.name == "A0D3");
CHECK(node.indices == indices_t{"A0", "D3"});
CHECK(node.contracted == false);
}
