Skip to content

Commit

Permalink
Static type checks for Jet::Tensor data (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
brownj85 authored May 14, 2021
1 parent f5debfb commit b9eeabc
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 14 deletions.
7 changes: 4 additions & 3 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

### New features since last release

* Running CMake with `-DBUILD_PYTHON=ON` now generates Python bindings within a `jet` package. [(#1)](https://github.com/XanaduAI/jet/pull/1)

### Improvements

* `Tensor` class now checks data type at compile-time. [(#4)](https://github.com/XanaduAI/jet/pull/4)
* Running CMake with `-DBUILD_PYTHON=ON` now generates Python bindings within a `jet` package. [(#1)](https://github.com/XanaduAI/jet/pull/1)

### Breaking Changes

### Bug Fixes
Expand All @@ -16,7 +17,7 @@

This release contains contributions from (in alphabetical order):

[Mikhail Andrenkov](https://github.com/Mandrenkov) and [Jack Brown](https://github.com/brownj85).
[Mikhail Andrenkov](https://github.com/Mandrenkov), [Jack Brown](https://github.com/brownj85).

## Release 0.1.0 (current release)

Expand Down
6 changes: 5 additions & 1 deletion include/jet/Tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ namespace Jet {
*/
template <class T = std::complex<float>> class Tensor {

static_assert(TensorHelpers::is_supported_data_type<T>,
"Tensor data type must be one of std::complex<float>, "
"std::complex<double>");

private:
std::vector<std::string> indices_;
std::vector<size_t> shape_;
Expand Down Expand Up @@ -679,4 +683,4 @@ inline std::ostream &operator<<(std::ostream &out, const Tensor<T> &tensor)
return out;
}

}; // namespace Jet
}; // namespace Jet
26 changes: 16 additions & 10 deletions include/jet/TensorHelpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
namespace Jet {
namespace TensorHelpers {

/**
* If T is a supported data type for tensors, this expression will
* evaluate to `true`. Otherwise, it will evaluate to `false`.
*
* Supported data types are std::complex<float> and std::complex<double>.
*
* @tparam T candidate data type
*/
template <class T>
constexpr bool is_supported_data_type =
std::is_same_v<T, std::complex<float>> ||
std::is_same_v<T, std::complex<double>>;

/**
* @brief Compile-time binding for BLAS GEMM operation (matrix-matrix product).
*
Expand Down Expand Up @@ -45,9 +58,6 @@ gemmBinding(size_t m, size_t n, size_t k, ComplexPrecision alpha,
cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, &alpha,
A_data, std::max(1ul, k), B_data, std::max(1ul, n), &beta,
C_data, std::max(1ul, n));
else
JET_ABORT(
"Please use complex<float> or complex<double for Tensor data");
};

/**
Expand Down Expand Up @@ -75,9 +85,6 @@ gemvBinding(size_t m, size_t k, ComplexPrecision alpha, ComplexPrecision beta,
else if constexpr (std::is_same_v<ComplexPrecision, std::complex<double>>)
cblas_zgemv(CblasRowMajor, CblasNoTrans, m, k, (&alpha), (A_data),
std::max(1ul, k), (B_data), 1, (&beta), (C_data), 1);
else
JET_ABORT(
"Please use complex<float> or complex<double for Tensor data");
};

/**
Expand All @@ -100,9 +107,6 @@ constexpr void dotuBinding(size_t k, const ComplexPrecision *A_data,
cblas_cdotu_sub(k, (A_data), 1, (B_data), 1, (C_data));
else if constexpr (std::is_same_v<ComplexPrecision, std::complex<double>>)
cblas_zdotu_sub(k, (A_data), 1, (B_data), 1, (C_data));
else
JET_ABORT(
"Please use complex<float> or complex<double for Tensor data");
};

/**
Expand All @@ -120,7 +124,9 @@ constexpr void dotuBinding(size_t k, const ComplexPrecision *A_data,
* @param right_dim Columns in right tensor B and resulting tensor C.
* @param common_dim Rows in left tensor A and columns in right tensor B.
*/
template <typename ComplexPrecision>
template <
typename ComplexPrecision,
std::enable_if_t<is_supported_data_type<ComplexPrecision>, bool> = true>
inline void MultiplyTensorData(const std::vector<ComplexPrecision> &A,
const std::vector<ComplexPrecision> &B,
std::vector<ComplexPrecision> &C,
Expand Down

0 comments on commit b9eeabc

Please sign in to comment.