Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RAII issue by introducing wrapper classes for backend plans #208

Merged
merged 66 commits into from
Jan 7, 2025
Merged
Changes from 1 commit
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
8fb15ac
fix: conflicts
Dec 5, 2024
737f3ca
Fix and wrapper for FFTW handle
Dec 3, 2024
95419cb
Wrapper for cufft handle
Dec 3, 2024
9baa59d
fix: conflicts
Dec 5, 2024
8d952e5
Wrapper for rocfft handle
Dec 3, 2024
dd907a8
fix: conflicts
Dec 3, 2024
6b1059b
Cleanup plan class based on the introduction of wrappers
Dec 3, 2024
f46c56d
fix: conflicts
Dec 5, 2024
5116c1c
fix: conflicts
Dec 5, 2024
b93c07c
fix: unused variable
Dec 3, 2024
1b908fe
fix: work buffer allocation
Dec 3, 2024
4010401
remove unused variable
Dec 3, 2024
ac5b34d
remove unused lines
Dec 3, 2024
734a55a
Add missing include header file in KokkosFFT_ROCM_types.hpp
Dec 3, 2024
f161d46
fix: fftwHandle type in SYCL types
Dec 4, 2024
31cd597
Do not return const plan type for fftw
Dec 4, 2024
2030d04
fix: remove const
Dec 4, 2024
a8741bd
fix: fftw plan creation
Dec 4, 2024
a3c94e3
fix: set created
Dec 4, 2024
0d8a616
fix: cleanup
Dec 4, 2024
f1f4f30
fix constructor of fftw wrapper
Dec 4, 2024
2b75678
fix: conflicts
Dec 5, 2024
88310c9
Remove non-default constructors from FFTW wrapper
Dec 5, 2024
995bb4a
Remove non-default constructors from cufft wrapper
Dec 5, 2024
751810c
Remove non-default constructors from hipfft wrapper
Dec 5, 2024
4ba04f0
Remove non-default constructors from rocfft wrapper
Dec 5, 2024
2eac65c
update FFTW wrapper class name
Dec 5, 2024
35afb3d
fix: host plan type
Dec 5, 2024
5c25bce
fix: fftw wrapper name in ROCM_types
Dec 5, 2024
c363d3d
update cuda backend based on reviews
Dec 17, 2024
975b2f1
update hip backend based on reviews
Dec 17, 2024
405b36f
update rocm backend based on reviews
Dec 17, 2024
d394672
update host backend based on reviews
Dec 17, 2024
7c085eb
fix: Rocm types
Dec 17, 2024
a6ccd56
fix: ROCM types
Dec 17, 2024
701136e
fix: Rocm types
Dec 17, 2024
622ac0e
fix: header files
Dec 17, 2024
1d33d7a
fix: rocm types
Dec 17, 2024
34ed2b3
fix: rocm types
Dec 18, 2024
45c489b
remove unused lines
Dec 18, 2024
3fdd5bf
fix: rocm types
Dec 18, 2024
645696c
Improve the cleanup logic for cufft plan
Dec 19, 2024
a350598
Improve the cleanup logic for hipfft plan
Dec 19, 2024
27501ac
Improve the cleanup logic for rocfft plan
Dec 19, 2024
4bde4d2
simplify fftw plan wrapper
Dec 19, 2024
6f3c535
fix: rocm types
Dec 19, 2024
5e81354
fix: scoped rocfft plan type
Dec 19, 2024
0d9119d
return execution_info by value in scoped rocfft plan
Dec 19, 2024
f181066
Add commit method to scoped cufft plan
Dec 20, 2024
f7944c8
Add commit method to scoped hipfft plan
Dec 20, 2024
0c6b33b
Add commit method to scoped rocfft plan
Dec 20, 2024
944f246
Add const qualifier for host transforms
Dec 20, 2024
eaf5354
fix: ROCM types
Dec 20, 2024
b405320
fix cleanup of ScopedCufft and ScopedHIPfft plan
Jan 2, 2025
28d9891
Add ScopedExecutionInfo for rocm backend
Jan 2, 2025
c7dd94d
fix KokkosFFT_ROCM_types.hpp
Jan 2, 2025
e3e7f0e
fix: KokkosFFT_ROCM_types.hpp
Jan 2, 2025
88d4c43
make commit method const
Jan 6, 2025
46d2faf
call fftw_cleanup_threads only once
Jan 6, 2025
edab676
remove static from init and cleanup methods
Jan 6, 2025
154c12d
use local static object for global initialization and finalization
Jan 6, 2025
e8aa2ec
remove cleanup threads for safety
Jan 7, 2025
c03d035
remove unused header from KokkosFFT_FFTW_Types.hpp
Jan 7, 2025
8df33b0
delete non-default constructors for Rocfft wrappers
Jan 7, 2025
c891b7a
fix: KokkosFFT_ROCM_types.hpp
Jan 7, 2025
0119961
Add Thomas as a co-author
Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix and wrapper for FFTW handle
Yuuichi Asahi committed Dec 17, 2024
commit 737f3ca614b67bbcca90f0f69db4108a61daf6b3
152 changes: 152 additions & 0 deletions fft/src/KokkosFFT_FFTW_Types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// SPDX-FileCopyrightText: (C) The kokkos-fft development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

#ifndef KOKKOSFFT_FFTW_TYPES_HPP
#define KOKKOSFFT_FFTW_TYPES_HPP

#include <fftw3.h>
#include <Kokkos_Core.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_utils.hpp"

// Compile-time layout checks between FFTW's complex types and Kokkos::complex:
// the sizes must match exactly and FFTW's alignment requirement must not
// exceed that of the corresponding Kokkos type.
// NOTE(review): presumably this is what allows Kokkos::complex buffers to be
// handed to FFTW as fftw(f)_complex pointers -- confirm against the callers.
static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex<float>));
static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex<float>));

static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex<double>));
static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex<double>));

namespace KokkosFFT {
namespace Impl {

/// \brief Tags for the FFTW transform kinds selected by fftw_transform_type.
///
/// R2C / D2Z: real to complex, dispatched to the fftwf_ / fftw_ planners.
/// C2R / Z2D: complex to real, dispatched to the fftwf_ / fftw_ planners.
/// C2C / Z2Z: complex to complex, dispatched to the fftwf_ / fftw_ planners.
enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z };

/// \brief Maps an (input, output) scalar type pair to a FFTWTransformType.
///
/// This primary template is the fallback used when neither T1 nor T2 is a
/// Kokkos::complex, i.e. for real-to-real requests. It deliberately provides
/// no type() member.
///
/// NOTE(review): the assertion below only fires when T1 and T2 differ; for
/// identical real types it passes and the error surfaces later as a missing
/// type() member. Confirm whether an unconditional failure was intended.
///
/// \tparam ExecutionSpace The type of Kokkos execution space (unused here)
/// \tparam T1 The scalar type of the input
/// \tparam T2 The scalar type of the output
template <typename ExecutionSpace, typename T1, typename T2>
struct fftw_transform_type {
  static_assert(std::is_same<T1, T2>::value,
                "Real to real transform is unavailable");
};

/// \brief Real input, complex output: selects R2C for float and D2Z for any
/// other precision.
///
/// \tparam ExecutionSpace The type of Kokkos execution space (unused here)
/// \tparam T1 The real scalar type of the input
/// \tparam T2 The value type of the complex output; must equal T1
template <typename ExecutionSpace, typename T1, typename T2>
struct fftw_transform_type<ExecutionSpace, T1, Kokkos::complex<T2>> {
  static_assert(std::is_same<T1, T2>::value,
                "T1 and T2 should have the same precision");

  // Anything that is not float falls through to the double-precision tag.
  static constexpr FFTWTransformType m_type =
      !std::is_same<T1, float>::value ? FFTWTransformType::D2Z
                                      : FFTWTransformType::R2C;

  /// \brief Returns the selected transform tag.
  static constexpr FFTWTransformType type() { return m_type; }
};

/// \brief Complex input, real output: selects C2R for float and Z2D for any
/// other precision.
///
/// \tparam ExecutionSpace The type of Kokkos execution space (unused here)
/// \tparam T1 The value type of the complex input; must equal T2
/// \tparam T2 The real scalar type of the output
template <typename ExecutionSpace, typename T1, typename T2>
struct fftw_transform_type<ExecutionSpace, Kokkos::complex<T1>, T2> {
  static_assert(std::is_same<T1, T2>::value,
                "T1 and T2 should have the same precision");

  // Anything that is not float falls through to the double-precision tag.
  static constexpr FFTWTransformType m_type =
      !std::is_same<T2, float>::value ? FFTWTransformType::Z2D
                                      : FFTWTransformType::C2R;

  /// \brief Returns the selected transform tag.
  static constexpr FFTWTransformType type() { return m_type; }
};

/// \brief Complex input, complex output: selects C2C for float and Z2Z for
/// any other precision.
///
/// \tparam ExecutionSpace The type of Kokkos execution space (unused here)
/// \tparam T1 The value type of the complex input
/// \tparam T2 The value type of the complex output; must equal T1
template <typename ExecutionSpace, typename T1, typename T2>
struct fftw_transform_type<ExecutionSpace, Kokkos::complex<T1>,
                           Kokkos::complex<T2>> {
  static_assert(std::is_same<T1, T2>::value,
                "T1 and T2 should have the same precision");

  // Anything that is not float falls through to the double-precision tag.
  static constexpr FFTWTransformType m_type =
      !std::is_same<T1, float>::value ? FFTWTransformType::Z2Z
                                      : FFTWTransformType::C2C;

  /// \brief Returns the selected transform tag.
  static constexpr FFTWTransformType type() { return m_type; }
};

/// \brief A class that wraps fftw_plan and fftwf_plan for RAII
///
/// The handle type is fftwf_plan when the base floating point type of T1 is
/// float and fftw_plan otherwise. A plan built by create() is destroyed when
/// the wrapper is destroyed, or when create() is called again.
///
/// NOTE(review): the wrapper is still implicitly copyable and a copy shares
/// the same FFTW handle, so both objects would destroy it. Always pass it by
/// reference (or hold it through a pointer).
///
/// \tparam ExecutionSpace The type of Kokkos execution space
/// \tparam T1 The scalar type of the input
/// \tparam T2 The scalar type of the output
template <typename ExecutionSpace, typename T1, typename T2>
struct ScopedFFTWPlanType {
  using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T1>;
  using plan_type =
      std::conditional_t<std::is_same_v<floating_point_type, float>, fftwf_plan,
                         fftw_plan>;

  // Null until create() succeeds, so plan() never exposes an indeterminate
  // handle.
  plan_type m_plan = nullptr;
  bool m_is_created = false;

  ScopedFFTWPlanType() {}

  /// \brief Destroys the owned plan, if any.
  ///
  /// fftw(f)_cleanup_threads() is intentionally NOT called here: FFTW
  /// documents that no previously created plan may be used after it, so
  /// calling it from one wrapper would invalidate every other live plan (and,
  /// when called before fftw(f)_destroy_plan(), this wrapper's own plan).
  ~ScopedFFTWPlanType() { destroy_plan(); }

  /// \brief Returns the wrapped FFTW handle (null before create()).
  const plan_type &plan() const { return m_plan; }

  /// \brief Builds the FFTW plan matching the (T1, T2) transform kind.
  ///
  /// All arguments except exec_space are forwarded unchanged to the matching
  /// fftw(f)_plan_many_dft, fftw(f)_plan_many_dft_r2c or
  /// fftw(f)_plan_many_dft_c2r planner. A plan created by an earlier call is
  /// destroyed first.
  ///
  /// \param exec_space [in] Execution space; its concurrency() sets the FFTW
  ///        thread count when the OpenMP or Threads backend is enabled
  /// \param sign [in] Transform direction; only used by the
  ///        complex-to-complex planners
  /// \param flags [in] FFTW planner flags
  template <typename InScalarType, typename OutScalarType>
  void create(const ExecutionSpace &exec_space, int rank, const int *n,
              int howmany, InScalarType *in, const int *inembed, int istride,
              int idist, OutScalarType *out, const int *onembed, int ostride,
              int odist, [[maybe_unused]] int sign, unsigned flags) {
    // Release a plan left over from a previous create() so it is not leaked.
    destroy_plan();

    init_threads<floating_point_type>(exec_space);

    constexpr auto type = fftw_transform_type<ExecutionSpace, T1, T2>::type();

    if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
      m_plan =
          fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
                                  out, onembed, ostride, odist, flags);
    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) {
      m_plan =
          fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
                                 out, onembed, ostride, odist, flags);
    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) {
      m_plan =
          fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
                                  out, onembed, ostride, odist, flags);
    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) {
      m_plan =
          fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
                                 out, onembed, ostride, odist, flags);
    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) {
      m_plan =
          fftwf_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
                              out, onembed, ostride, odist, sign, flags);
    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) {
      m_plan = fftw_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
                                  out, onembed, ostride, odist, sign, flags);
    }
    m_is_created = true;
  }

 private:
  /// \brief Destroys the current plan (if one was created) and resets the
  /// wrapper to its empty state.
  void destroy_plan() {
    if (!m_is_created) return;
    if constexpr (std::is_same_v<floating_point_type, float>) {
      fftwf_destroy_plan(m_plan);
    } else {
      fftw_destroy_plan(m_plan);
    }
    m_plan       = nullptr;
    m_is_created = false;
  }

  /// \brief Initializes FFTW threading and sets the planner thread count.
  ///
  /// Compiles to a no-op unless the Kokkos OpenMP or Threads backend is
  /// enabled.
  /// NOTE(review): the status returned by fftw(f)_init_threads() is ignored.
  template <typename T>
  void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) {
#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS)
    int nthreads = exec_space.concurrency();

    if constexpr (std::is_same_v<T, float>) {
      fftwf_init_threads();
      fftwf_plan_with_nthreads(nthreads);
    } else {
      fftw_init_threads();
      fftw_plan_with_nthreads(nthreads);
    }
#endif
  }
};

} // namespace Impl
} // namespace KokkosFFT

#endif
9 changes: 4 additions & 5 deletions fft/src/KokkosFFT_Host_plans.hpp
Original file line number Diff line number Diff line change
@@ -14,15 +14,14 @@ namespace KokkosFFT {
namespace Impl {
// batched transform, over ND Views
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
std::size_t fft_rank = 1,
typename OutViewType, std::size_t fft_rank = 1,
std::enable_if_t<is_AnyHostSpace_v<ExecutionSpace>, std::nullptr_t> =
nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s, bool is_inplace) {
const OutViewType& out, Direction direction,
axis_type<fft_rank> axes, shape_type<fft_rank> s,
bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
48 changes: 24 additions & 24 deletions fft/src/KokkosFFT_Host_transform.hpp
Original file line number Diff line number Diff line change
@@ -9,40 +9,40 @@

namespace KokkosFFT {
namespace Impl {
template <typename PlanType, typename... Args>
void exec_plan(PlanType& plan, float* idata, fftwf_complex* odata,
int /*direction*/, Args...) {
fftwf_execute_dft_r2c(plan, idata, odata);
template <typename ScopedPlanType>
yasahi-hpc marked this conversation as resolved.
Show resolved Hide resolved
void exec_plan(ScopedPlanType& scoped_plan, float* idata, fftwf_complex* odata,
yasahi-hpc marked this conversation as resolved.
Show resolved Hide resolved
int /*direction*/) {
fftwf_execute_dft_r2c(scoped_plan.plan(), idata, odata);
}

template <typename PlanType, typename... Args>
void exec_plan(PlanType& plan, double* idata, fftw_complex* odata,
int /*direction*/, Args...) {
fftw_execute_dft_r2c(plan, idata, odata);
/// \brief Runs a double-precision real-to-complex transform.
///
/// \param scoped_plan Plan wrapper whose plan() handle is given to FFTW
/// \param idata Real input buffer
/// \param odata Complex output buffer
/// The trailing direction argument is ignored.
template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, double* idata, fftw_complex* odata,
               int /*direction*/) {
  const auto& handle = scoped_plan.plan();
  fftw_execute_dft_r2c(handle, idata, odata);
}

template <typename PlanType, typename... Args>
void exec_plan(PlanType& plan, fftwf_complex* idata, float* odata,
int /*direction*/, Args...) {
fftwf_execute_dft_c2r(plan, idata, odata);
/// \brief Runs a single-precision complex-to-real transform.
///
/// \param scoped_plan Plan wrapper whose plan() handle is given to FFTW
/// \param idata Complex input buffer
/// \param odata Real output buffer
/// The trailing direction argument is ignored.
template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, fftwf_complex* idata, float* odata,
               int /*direction*/) {
  const auto& handle = scoped_plan.plan();
  fftwf_execute_dft_c2r(handle, idata, odata);
}

template <typename PlanType, typename... Args>
void exec_plan(PlanType& plan, fftw_complex* idata, double* odata,
int /*direction*/, Args...) {
fftw_execute_dft_c2r(plan, idata, odata);
/// \brief Runs a double-precision complex-to-real transform.
///
/// \param scoped_plan Plan wrapper whose plan() handle is given to FFTW
/// \param idata Complex input buffer
/// \param odata Real output buffer
/// The trailing direction argument is ignored.
template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, fftw_complex* idata, double* odata,
               int /*direction*/) {
  const auto& handle = scoped_plan.plan();
  fftw_execute_dft_c2r(handle, idata, odata);
}

template <typename PlanType, typename... Args>
void exec_plan(PlanType& plan, fftwf_complex* idata, fftwf_complex* odata,
int /*direction*/, Args...) {
fftwf_execute_dft(plan, idata, odata);
/// \brief Runs a single-precision complex-to-complex transform.
///
/// \param scoped_plan Plan wrapper whose plan() handle is given to FFTW
/// \param idata Complex input buffer
/// \param odata Complex output buffer
/// The trailing direction argument is ignored.
template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, fftwf_complex* idata,
               fftwf_complex* odata, int /*direction*/) {
  const auto& handle = scoped_plan.plan();
  fftwf_execute_dft(handle, idata, odata);
}

template <typename PlanType, typename... Args>
void exec_plan(PlanType plan, fftw_complex* idata, fftw_complex* odata,
int /*direction*/, Args...) {
fftw_execute_dft(plan, idata, odata);
/// \brief Runs a double-precision complex-to-complex transform.
///
/// The plan wrapper is taken by const reference, never by value: a by-value
/// parameter would be a temporary copy sharing the caller's FFTW handle, and
/// that copy's destructor would destroy the plan when this function returns.
///
/// \param scoped_plan Plan wrapper whose plan() handle is given to FFTW
/// \param idata Complex input buffer
/// \param odata Complex output buffer
/// The trailing direction argument is ignored.
template <typename ScopedPlanType>
void exec_plan(const ScopedPlanType& scoped_plan, fftw_complex* idata,
               fftw_complex* odata, int /*direction*/) {
  fftw_execute_dft(scoped_plan.plan(), idata, odata);
}
} // namespace Impl
} // namespace KokkosFFT
107 changes: 4 additions & 103 deletions fft/src/KokkosFFT_Host_types.hpp
Original file line number Diff line number Diff line change
@@ -5,29 +5,12 @@
#ifndef KOKKOSFFT_HOST_TYPES_HPP
#define KOKKOSFFT_HOST_TYPES_HPP

#include <fftw3.h>
#include <Kokkos_Core.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_utils.hpp"

// Check the size of complex type
static_assert(sizeof(fftwf_complex) == sizeof(Kokkos::complex<float>));
static_assert(alignof(fftwf_complex) <= alignof(Kokkos::complex<float>));

static_assert(sizeof(fftw_complex) == sizeof(Kokkos::complex<double>));
static_assert(alignof(fftw_complex) <= alignof(Kokkos::complex<double>));
#include "KokkosFFT_FFTW_Types.hpp"

namespace KokkosFFT {
namespace Impl {

using FFTDirectionType = int;

// Unused
template <typename ExecutionSpace>
using FFTInfoType = int;

enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z };

template <typename ExecutionSpace>
struct FFTDataType {
using float32 = float;
@@ -39,6 +22,7 @@ struct FFTDataType {
template <typename ExecutionSpace>
using TransformType = FFTWTransformType;

/*
// Define fft transform types
template <typename ExecutionSpace, typename T1, typename T2>
struct transform_type {
@@ -76,93 +60,10 @@ struct transform_type<ExecutionSpace, Kokkos::complex<T1>,
: FFTWTransformType::Z2Z;
static constexpr FFTWTransformType type() { return m_type; };
};
*/

/// \brief A class that wraps fftw_plan and fftwf_plan for RAII
template <typename ExecutionSpace, typename T1, typename T2>
struct ScopedFFTWPlanType {
using floating_point_type = KokkosFFT::Impl::base_floating_point_type<T1>;
using plan_type =
std::conditional_t<std::is_same_v<floating_point_type, float>, fftwf_plan,
fftw_plan>;
plan_type m_plan;
bool m_is_created = false;

ScopedFFTWPlanType() {}
~ScopedFFTWPlanType() {
cleanup_threads<floating_point_type>();
if constexpr (std::is_same_v<floating_point_type, float>) {
if (m_is_created) fftwf_destroy_plan(m_plan);
} else {
if (m_is_created) fftw_destroy_plan(m_plan);
}
}

const plan_type &plan() const { return m_plan; }

template <typename InScalarType, typename OutScalarType>
void create(const ExecutionSpace &exec_space, int rank, const int *n,
int howmany, InScalarType *in, const int *inembed, int istride,
int idist, OutScalarType *out, const int *onembed, int ostride,
int odist, [[maybe_unused]] int sign, unsigned flags) {
init_threads<floating_point_type>(exec_space);

constexpr auto type =
KokkosFFT::Impl::transform_type<ExecutionSpace, T1, T2>::type();

if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
m_plan =
fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) {
m_plan =
fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) {
m_plan =
fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) {
m_plan =
fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) {
m_plan =
fftwf_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, sign, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) {
m_plan = fftw_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, sign, flags);
}
m_is_created = true;
}

private:
template <typename T>
void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) {
#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS)
int nthreads = exec_space.concurrency();

if constexpr (std::is_same_v<T, float>) {
fftwf_init_threads();
fftwf_plan_with_nthreads(nthreads);
} else {
fftw_init_threads();
fftw_plan_with_nthreads(nthreads);
}
#endif
}

template <typename T>
void cleanup_threads() {
#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS)
if constexpr (std::is_same_v<T, float>) {
fftwf_cleanup_threads();
} else {
fftw_cleanup_threads();
}
#endif
}
};
using transform_type = fftw_transform_type<ExecutionSpace, T1, T2>;

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {