-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for spreadinterponly
in finufft
#599
base: master
Are you sure you want to change the base?
Changes from 6 commits
838245b
a033f3b
a67b717
d77cbb1
2094338
d0d60fe
305482b
09a9d0c
90a0675
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -530,6 +530,7 @@ void finufft_default_opts_t(finufft_opts *o) | |
o->spread_kerpad = 1; | ||
o->upsampfac = 0.0; | ||
o->spread_thread = 0; | ||
o->spreadinterponly = 0; | ||
o->maxbatchsize = 0; | ||
o->spread_nthr_atomic = -1; | ||
o->spread_max_sp_size = 0; | ||
|
@@ -560,8 +561,9 @@ FINUFFT_PLAN_T<TF>::FINUFFT_PLAN_T(int type_, int dim_, const BIGINT *n_modes, i | |
printf("[%s] new plan: FINUFFT version " FINUFFT_VER " .................\n", | ||
__func__); | ||
|
||
fftPlan = std::make_unique<Finufft_FFT_plan<TF>>( | ||
opts.fftw_lock_fun, opts.fftw_unlock_fun, opts.fftw_lock_data); | ||
if (!opts.spreadinterponly) // Dont make plans if only spread or interpolate | ||
fftPlan = std::make_unique<Finufft_FFT_plan<TF>>( | ||
opts.fftw_lock_fun, opts.fftw_unlock_fun, opts.fftw_lock_data); | ||
|
||
if ((type != 1) && (type != 2) && (type != 3)) { | ||
fprintf(stderr, "[%s] Invalid type (%d), should be 1, 2 or 3.\n", __func__, type); | ||
|
@@ -668,66 +670,72 @@ FINUFFT_PLAN_T<TF>::FINUFFT_PLAN_T(int type_, int dim_, const BIGINT *n_modes, i | |
__func__, (double)(EPSILON * mu)); | ||
} | ||
|
||
// determine fine grid sizes, sanity check.. | ||
int nfier = set_nf_type12(ms, opts, spopts, &nf1); | ||
if (nfier) throw nfier; // nf too big; we're done | ||
phiHat1.resize(nf1 / 2 + 1); | ||
if (dim > 1) { | ||
nfier = set_nf_type12(mt, opts, spopts, &nf2); | ||
if (nfier) throw nfier; | ||
phiHat2.resize(nf2 / 2 + 1); | ||
} | ||
if (dim > 2) { | ||
nfier = set_nf_type12(mu, opts, spopts, &nf3); | ||
if (nfier) throw nfier; | ||
phiHat3.resize(nf3 / 2 + 1); | ||
} | ||
|
||
if (opts.debug) { // "long long" here is to avoid warnings with printf... | ||
printf("[%s] %dd%d: (ms,mt,mu)=(%lld,%lld,%lld) " | ||
"(nf1,nf2,nf3)=(%lld,%lld,%lld)\n ntrans=%d nthr=%d " | ||
"batchSize=%d ", | ||
__func__, dim, type, (long long)ms, (long long)mt, (long long)mu, | ||
(long long)nf1, (long long)nf2, (long long)nf3, ntrans, nthr, batchSize); | ||
if (batchSize == 1) // spread_thread has no effect in this case | ||
printf("\n"); | ||
else | ||
printf(" spread_thread=%d\n", opts.spread_thread); | ||
} | ||
if(!opts.spreadinterponly) // We dont need fseries if it is spreadinterponly. | ||
{ | ||
// determine fine grid sizes, sanity check.. | ||
phiHat1.resize(nf1 / 2 + 1); | ||
if (dim > 1) { | ||
phiHat2.resize(nf2 / 2 + 1); | ||
} | ||
if (dim > 2) { | ||
phiHat3.resize(nf3 / 2 + 1); | ||
} | ||
|
||
// STEP 0: get Fourier coeffs of spreading kernel along each fine grid dim | ||
CNTime timer; | ||
timer.start(); | ||
onedim_fseries_kernel(nf1, phiHat1, spopts); | ||
if (dim > 1) onedim_fseries_kernel(nf2, phiHat2, spopts); | ||
if (dim > 2) onedim_fseries_kernel(nf3, phiHat3, spopts); | ||
if (opts.debug) | ||
printf("[%s] kernel fser (ns=%d):\t\t%.3g s\n", __func__, spopts.nspread, | ||
timer.elapsedsec()); | ||
if (opts.debug) { // "long long" here is to avoid warnings with printf... | ||
printf("[%s] %dd%d: (ms,mt,mu)=(%lld,%lld,%lld) " | ||
"(nf1,nf2,nf3)=(%lld,%lld,%lld)\n ntrans=%d nthr=%d " | ||
"batchSize=%d ", | ||
__func__, dim, type, (long long)ms, (long long)mt, (long long)mu, | ||
(long long)nf1, (long long)nf2, (long long)nf3, ntrans, nthr, batchSize); | ||
if (batchSize == 1) // spread_thread has no effect in this case | ||
printf("\n"); | ||
else | ||
printf(" spread_thread=%d\n", opts.spread_thread); | ||
} | ||
|
||
nf = nf1 * nf2 * nf3; // fine grid total number of points | ||
if (nf * batchSize > MAX_NF) { | ||
fprintf( | ||
stderr, | ||
"[%s] fwBatch would be bigger than MAX_NF, not attempting memory allocation!\n", | ||
__func__); | ||
throw int(FINUFFT_ERR_MAXNALLOC); | ||
// STEP 0: get Fourier coeffs of spreading kernel along each fine grid dim | ||
CNTime timer; | ||
timer.start(); | ||
onedim_fseries_kernel(nf1, phiHat1, spopts); | ||
if (dim > 1) onedim_fseries_kernel(nf2, phiHat2, spopts); | ||
if (dim > 2) onedim_fseries_kernel(nf3, phiHat3, spopts); | ||
if (opts.debug) | ||
printf("[%s] kernel fser (ns=%d):\t\t%.3g s\n", __func__, spopts.nspread, | ||
timer.elapsedsec()); | ||
|
||
nf = nf1 * nf2 * nf3; // fine grid total number of points | ||
if (nf * batchSize > MAX_NF) { | ||
fprintf( | ||
stderr, | ||
"[%s] fwBatch would be bigger than MAX_NF, not attempting memory allocation!\n", | ||
__func__); | ||
throw int(FINUFFT_ERR_MAXNALLOC); | ||
} | ||
|
||
timer.restart(); | ||
fwBatch.resize(nf * batchSize); // the big workspace | ||
if (opts.debug) | ||
printf("[%s] fwBatch %.2fGB alloc: \t%.3g s\n", __func__, | ||
(double)1E-09 * sizeof(std::complex<TF>) * nf * batchSize, | ||
timer.elapsedsec()); | ||
|
||
timer.restart(); // plan the FFTW | ||
const auto ns = gridsize_for_fft(this); | ||
fftPlan->plan(ns, batchSize, fwBatch.data(), fftSign, opts.fftw, nthr_fft); | ||
if (opts.debug) | ||
printf("[%s] FFT plan (mode %d, nthr=%d):\t%.3g s\n", __func__, opts.fftw, nthr_fft, | ||
timer.elapsedsec()); | ||
} | ||
|
||
timer.restart(); | ||
fwBatch.resize(nf * batchSize); // the big workspace | ||
if (opts.debug) | ||
printf("[%s] fwBatch %.2fGB alloc: \t%.3g s\n", __func__, | ||
(double)1E-09 * sizeof(std::complex<TF>) * nf * batchSize, | ||
timer.elapsedsec()); | ||
|
||
timer.restart(); // plan the FFTW | ||
const auto ns = gridsize_for_fft(this); | ||
fftPlan->plan(ns, batchSize, fwBatch.data(), fftSign, opts.fftw, nthr_fft); | ||
if (opts.debug) | ||
printf("[%s] FFT plan (mode %d, nthr=%d):\t%.3g s\n", __func__, opts.fftw, nthr_fft, | ||
timer.elapsedsec()); | ||
|
||
} else { // -------------------------- type 3 (no planning) ------------ | ||
|
||
if (opts.debug) printf("[%s] %dd%d: ntrans=%d\n", __func__, dim, type, ntrans); | ||
|
@@ -1041,19 +1049,30 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk) { | |
// STEP 1: (varies by type) | ||
timer.restart(); | ||
if (type == 1) { // type 1: spread NU pts X, weights cj, to fw grid | ||
if (opts.spreadinterponly) | ||
wrapArrayInVector(fkb, thisBatchSize*N, this->fwBatch); | ||
spreadinterpSortedBatch<TF>(thisBatchSize, this, cjb); | ||
t_sprint += timer.elapsedsec(); | ||
} else { // type 2: amplify Fourier coeffs fk into 0-padded fw | ||
// Stop here if it is spread interp only. | ||
if (opts.spreadinterponly) | ||
{ | ||
releaseVectorWrapper(this->fwBatch); | ||
continue; | ||
} | ||
} else if(!opts.spreadinterponly) { // type 2: amplify Fourier coeffs fk into 0-padded fw, but dont do it if it is spread interp only. | ||
deconvolveBatch<TF>(thisBatchSize, this, fkb); | ||
t_deconv += timer.elapsedsec(); | ||
} | ||
|
||
// STEP 2: call the FFT on this batch | ||
timer.restart(); | ||
do_fft(this); | ||
t_fft += timer.elapsedsec(); | ||
if (opts.debug > 1) printf("\tFFT exec:\t\t%.3g s\n", timer.elapsedsec()); | ||
|
||
if (!opts.spreadinterponly) // Do FFT only if its not spread interp only. | ||
{ | ||
// STEP 2: call the FFT on this batch | ||
timer.restart(); | ||
do_fft(this); | ||
t_fft += timer.elapsedsec(); | ||
if (opts.debug > 1) printf("\tFFT exec:\t\t%.3g s\n", timer.elapsedsec()); | ||
} | ||
else | ||
wrapArrayInVector(fkb, thisBatchSize*N, this->fwBatch); | ||
// STEP 3: (varies by type) | ||
timer.restart(); | ||
if (type == 1) { // type 1: deconvolve (amplify) fw and shuffle to fk | ||
|
@@ -1063,6 +1082,9 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk) { | |
spreadinterpSortedBatch<TF>(thisBatchSize, this, cjb); | ||
t_sprint += timer.elapsedsec(); | ||
} | ||
// Release the fwBatch vector to prevent double freeing of memory. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain this comment - is there anything to do, given it was not allocated? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Testing with the MATLAB interface, it segfaults... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Segfaults fixed - the cause was the incorrect nf1 = upsampfac * ms, etc.; rather, nf1 should match the user's N1 grid size. However, could you explain the comment? Since fwBatch is a std::vector field of the plan class, there's no freeing to worry about, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ahh I think this comment was from earlier code and we can remove this now that I updated the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But I agree with you, we don't need to free it manually, as fwBatch is a std::vector. |
||
if(opts.spreadinterponly) | ||
releaseVectorWrapper(this->fwBatch); | ||
} // ........end b loop | ||
|
||
if (opts.debug) { // report total times in their natural order... | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to rely on a hacky way to move pointer data into a std::vector without any copies. Does anyone have ideas on how to make this better? Also, it seems to fail on macOS, but I don't know whether std:: behaves differently on Mac; I am not used to it. Can someone help here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wanted to use std::reference_wrapper, but I think that still won't work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need to create a vector from a pointer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think what I need to do is point the incoming memory pointer fk to fwBatch, so that the spread / interpolate happens directly on the input / output. In the cufinufft versions, this was fairly straightforward, as the memory for fwBatch was not a std::vector. Any idea why you chose it here?
I see the underlying kernels use pointers anyway exposed through .data().
Another option can be to overload the function and add support for when we send pointer to memory, although the current spread / interpolate doesn't directly take the fwbatch vector as input, but rather the plan itself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either we template more, so that functions can accept both pointers and vectors, as both have the [] operator, or this is a use case for a span. A vector is owning, so it should not be used this way.
We might need to write a simple span implementation, since std::span is only available from C++20 onwards.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We aim to remove all pointers, as C++ data structures are safer. Pointers can cause memory leaks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree about leaks; I have suffered fixing them in gpuNUFFT. Did you happen to use std::reference_wrapper? I think it's built for this, but I somehow am not able to get it to work for our use case:
https://en.cppreference.com/w/cpp/utility/functional/reference_wrapper
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact, a reference wrapper is cleaner and way less hacky than what I have now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you remove all of this? I cannot find the hacky code anymore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup :) No hacks, only clean code :P ... Just changed the function API.