Skip to content

Commit

Permalink
Implement csqrt (#619)
Browse files Browse the repository at this point in the history
Adding csqrt implementation with casts as c99 csqrt can not be used with nvcc
  • Loading branch information
tylera-nvidia authored May 13, 2024
1 parent 46cc004 commit 480dab7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
17 changes: 17 additions & 0 deletions include/matx/operators/scalar_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,23 @@ template <typename T> struct SqrtF {

template <typename T> using SqrtOp = UnOp<T, SqrtF<T>>;

template <typename T>
static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto _internal_csqrt(T v1)
{
static_assert(std::is_floating_point_v<T>, "csqrt() only supports non-complex floating point inputs");
return sqrt(static_cast<cuda::std::complex<T>>(v1));
}

template <typename T> struct CSqrtF {
static __MATX_INLINE__ std::string str() { return "csqrt"; }
static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto op(T v1)
{
return _internal_csqrt(v1);
}
};


template <typename T> using CsqrtOp = UnOp<T, CSqrtF<T>>;

template <typename T>
static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto _internal_conj(T v1)
Expand Down
1 change: 1 addition & 0 deletions include/matx/operators/unary_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ namespace matx

#else
DEFINE_UNARY_OP(sqrt, detail::SqrtOp);
DEFINE_UNARY_OP(csqrt, detail::CsqrtOp);
DEFINE_UNARY_OP(exp, detail::ExpOp);
DEFINE_UNARY_OP(expj, detail::ExpjOp);
DEFINE_UNARY_OP(log10, detail::Log10Op);
Expand Down

0 comments on commit 480dab7

Please sign in to comment.