Static type checks for Jet::Tensor data #4

Merged
merged 10 commits on May 14, 2021
6 changes: 5 additions & 1 deletion include/jet/Tensor.hpp
@@ -37,6 +37,10 @@ namespace Jet {
*/
template <class T = std::complex<float>> class Tensor {

static_assert(TensorHelpers::is_supported_data_type_v<T>,
"Tensor data type must be one of std::complex<float>, "
"std::complex<double>");

private:
std::vector<std::string> indices_;
std::vector<size_t> shape_;
@@ -679,4 +683,4 @@ inline std::ostream &operator<<(std::ostream &out, const Tensor<T> &tensor)
return out;
}

}; // namespace Jet
}; // namespace Jet
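Note (not part of the diff): with the static_assert above, an unsupported element type is now rejected when Tensor is instantiated, instead of reaching a runtime JET_ABORT inside the BLAS bindings. A minimal sketch of the effect, assuming the header is available as "jet/Tensor.hpp" and that Tensor is default-constructible:

```cpp
#include <complex>
#include "jet/Tensor.hpp"

int main()
{
    // Supported element types behave as before.
    Jet::Tensor<std::complex<float>> a;
    Jet::Tensor<std::complex<double>> b;

    // An unsupported type now fails to compile with the message:
    //   "Tensor data type must be one of std::complex<float>, std::complex<double>"
    // Jet::Tensor<float> c;
}
```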
31 changes: 21 additions & 10 deletions include/jet/TensorHelpers.hpp
@@ -17,6 +17,24 @@
namespace Jet {
namespace TensorHelpers {

/**
* @brief Determines whether T is a supported tensor data type. If it is, the
* static member constant value is true; otherwise, value is false.
*
* Supported data types are std::complex<float> and std::complex<double>.
*
* @tparam T Candidate data type.
*/
template <class T> struct is_supported_data_type {

static const bool value = (std::is_same_v<T, std::complex<float>> ||
std::is_same_v<T, std::complex<double>>);
};

template <class T>
constexpr bool is_supported_data_type_v = is_supported_data_type<T>::value;
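A quick illustration of how the trait evaluates (a sketch, not part of the diff; assumes the header is included as "jet/TensorHelpers.hpp"):

```cpp
#include <complex>
#include "jet/TensorHelpers.hpp"

using Jet::TensorHelpers::is_supported_data_type_v;

static_assert(is_supported_data_type_v<std::complex<float>>, "supported");
static_assert(is_supported_data_type_v<std::complex<double>>, "supported");
static_assert(!is_supported_data_type_v<float>, "real types are not supported");
static_assert(!is_supported_data_type_v<double>, "real types are not supported");
```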

/**
* @brief Compile-time binding for BLAS GEMM operation (matrix-matrix product).
*
@@ -45,9 +63,6 @@ gemmBinding(size_t m, size_t n, size_t k, ComplexPrecision alpha,
cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, &alpha,
A_data, std::max(1ul, k), B_data, std::max(1ul, n), &beta,
C_data, std::max(1ul, n));
else
JET_ABORT(
"Please use complex<float> or complex<double for Tensor data");
};
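The else/JET_ABORT branches removed here (and in the gemv and dotu bindings below) were only reachable for unsupported element types, which the new trait now rules out before these functions can be instantiated. A simplified, self-contained sketch of the same if constexpr dispatch pattern, with a hypothetical scaleBinding standing in for the CBLAS calls:

```cpp
#include <complex>
#include <cstddef>
#include <type_traits>

template <typename ComplexPrecision>
void scaleBinding(ComplexPrecision *data, std::size_t n, ComplexPrecision alpha)
{
    // Exactly one branch survives per instantiation; the discarded branch is
    // never compiled for that type.
    if constexpr (std::is_same_v<ComplexPrecision, std::complex<float>>) {
        // Single-precision path (the real binding calls cblas_cgemm here).
        for (std::size_t i = 0; i < n; ++i)
            data[i] *= alpha;
    } else if constexpr (std::is_same_v<ComplexPrecision,
                                        std::complex<double>>) {
        // Double-precision path (the real binding calls cblas_zgemm here).
        for (std::size_t i = 0; i < n; ++i)
            data[i] *= alpha;
    }
    // No trailing else: unsupported types are rejected earlier by
    // is_supported_data_type_v, so the old JET_ABORT branch is dead code.
}
```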

/**
@@ -75,9 +90,6 @@ gemvBinding(size_t m, size_t k, ComplexPrecision alpha, ComplexPrecision beta,
else if constexpr (std::is_same_v<ComplexPrecision, std::complex<double>>)
cblas_zgemv(CblasRowMajor, CblasNoTrans, m, k, (&alpha), (A_data),
std::max(1ul, k), (B_data), 1, (&beta), (C_data), 1);
else
JET_ABORT(
"Please use complex<float> or complex<double for Tensor data");
};

/**
@@ -100,9 +112,6 @@ constexpr void dotuBinding(size_t k, const ComplexPrecision *A_data,
cblas_cdotu_sub(k, (A_data), 1, (B_data), 1, (C_data));
else if constexpr (std::is_same_v<ComplexPrecision, std::complex<double>>)
cblas_zdotu_sub(k, (A_data), 1, (B_data), 1, (C_data));
else
JET_ABORT(
"Please use complex<float> or complex<double for Tensor data");
};

/**
@@ -120,7 +129,9 @@ constexpr void dotuBinding(size_t k, const ComplexPrecision *A_data,
* @param right_dim Columns in right tensor B and resulting tensor C.
* @param common_dim Rows in left tensor A and columns in right tensor B.
*/
template <typename ComplexPrecision>
template <
typename ComplexPrecision,
std::enable_if_t<is_supported_data_type_v<ComplexPrecision>, bool> = true>
inline void MultiplyTensorData(const std::vector<ComplexPrecision> &A,
const std::vector<ComplexPrecision> &B,
std::vector<ComplexPrecision> &C,
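For readers unfamiliar with the constraint style used above: a defaulted non-type template parameter whose type comes from std::enable_if_t removes the template from overload resolution when the trait is false, so misuse becomes a compile-time "no matching function" error rather than a runtime abort. A minimal, self-contained sketch using a hypothetical Scale function (not the real MultiplyTensorData signature, which continues below the fold):

```cpp
#include <complex>
#include <type_traits>
#include <vector>

// Same shape as is_supported_data_type_v above, repeated here so the sketch
// is self-contained.
template <class T>
constexpr bool is_supported_v = std::is_same_v<T, std::complex<float>> ||
                                std::is_same_v<T, std::complex<double>>;

// The constraint style adopted in this PR: enable_if_t<...> names bool only
// when the trait holds; otherwise substitution fails and the overload vanishes.
template <class T, std::enable_if_t<is_supported_v<T>, bool> = true>
void Scale(std::vector<T> &data, T alpha)
{
    for (auto &x : data)
        x *= alpha;
}

int main()
{
    std::vector<std::complex<float>> v(4, {1.0f, 0.0f});
    Scale(v, std::complex<float>{2.0f, 0.0f}); // OK: supported element type

    // std::vector<float> w(4, 1.0f);
    // Scale(w, 2.0f); // error: no matching function for call to 'Scale'
}
```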