diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index c69337b9..b0cd32fb 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -2,10 +2,11 @@ ### New features since last release -* Running CMake with `-DBUILD_PYTHON=ON` now generates Python bindings within a `jet` package. [(#1)](https://github.com/XanaduAI/jet/pull/1) - ### Improvements +* `Tensor` class now checks data type at compile-time. [(#4)](https://github.com/XanaduAI/jet/pull/4) +* Running CMake with `-DBUILD_PYTHON=ON` now generates Python bindings within a `jet` package. [(#1)](https://github.com/XanaduAI/jet/pull/1) + ### Breaking Changes ### Bug Fixes @@ -16,7 +17,7 @@ This release contains contributions from (in alphabetical order): -[Mikhail Andrenkov](https://github.com/Mandrenkov) and [Jack Brown](https://github.com/brownj85). +[Mikhail Andrenkov](https://github.com/Mandrenkov), [Jack Brown](https://github.com/brownj85). ## Release 0.1.0 (current release) diff --git a/include/jet/Tensor.hpp b/include/jet/Tensor.hpp index b58203bc..0c091f1f 100644 --- a/include/jet/Tensor.hpp +++ b/include/jet/Tensor.hpp @@ -37,6 +37,10 @@ namespace Jet { */ template > class Tensor { + static_assert(TensorHelpers::is_supported_data_type, + "Tensor data type must be one of std::complex, " + "std::complex"); + private: std::vector indices_; std::vector shape_; @@ -679,4 +683,4 @@ inline std::ostream &operator<<(std::ostream &out, const Tensor &tensor) return out; } -}; // namespace Jet \ No newline at end of file +}; // namespace Jet diff --git a/include/jet/TensorHelpers.hpp b/include/jet/TensorHelpers.hpp index 19d9947c..99074f1d 100644 --- a/include/jet/TensorHelpers.hpp +++ b/include/jet/TensorHelpers.hpp @@ -17,6 +17,19 @@ namespace Jet { namespace TensorHelpers { +/** + * If T is a supported data type for tensors, this expression will + * evaluate to `true`. Otherwise, it will evaluate to `false`. + * + * Supported data types are std::complex and std::complex. + * + * @tparam T candidate data type + */ +template +constexpr bool is_supported_data_type = + std::is_same_v> || + std::is_same_v>; + /** * @brief Compile-time binding for BLAS GEMM operation (matrix-matrix product). * @@ -45,9 +58,6 @@ gemmBinding(size_t m, size_t n, size_t k, ComplexPrecision alpha, cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, &alpha, A_data, std::max(1ul, k), B_data, std::max(1ul, n), &beta, C_data, std::max(1ul, n)); - else - JET_ABORT( - "Please use complex or complex>) cblas_zgemv(CblasRowMajor, CblasNoTrans, m, k, (&alpha), (A_data), std::max(1ul, k), (B_data), 1, (&beta), (C_data), 1); - else - JET_ABORT( - "Please use complex or complex>) cblas_zdotu_sub(k, (A_data), 1, (B_data), 1, (C_data)); - else - JET_ABORT( - "Please use complex or complex +template < + typename ComplexPrecision, + std::enable_if_t, bool> = true> inline void MultiplyTensorData(const std::vector &A, const std::vector &B, std::vector &C,