Skip to content

Commit

Permalink
Finish up dask with multi gpu
Browse files Browse the repository at this point in the history
Erase the richardson_lucy_dask_multi_gpu.py file and integrate the multi-GPU
version with the single-GPU version.  Also make sure we call cleanup
to clean up clFFT resources when we are finished with the native library.
  • Loading branch information
bnorthan committed Dec 27, 2023
1 parent a32f26e commit acc24fe
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 234 deletions.
16 changes: 14 additions & 2 deletions python/clij2fft/richardson_lucy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,19 @@ def richardson_lucy(img, psf, numiterations, regularizationfactor=0, first_guess
if (lib==None):
print('get lib')
lib = getlib()
cleanup = True
else:
cleanup = False

# deconvolution using clij2fft
if regularizationfactor==0:
lib.deconv3d_32f(numiterations, int(img.shape[2]), int(img.shape[1]), int(img.shape[0]), img, shifted_psf, result, normal, platform, device)
else:
lib.deconv3d_32f_tv(numiterations, regularizationfactor, int(img.shape[2]), int(img.shape[1]), int(img.shape[0]), img, shifted_psf, result, normal, platform, device)


if cleanup==True:
lib.cleanup()

# unpad and return
return unpad(result, original_size)

Expand Down Expand Up @@ -123,6 +129,9 @@ def richardson_lucy_nc(img, psf, numiterations, regularizationfactor=0, lib=None
if (lib==None):
print('get lib')
lib = getlib()
cleanup = True
else:
cleanup = False

print('calling convcorr',platform, device)
# the normalization factor is the valid region correlated with the PSF
Expand All @@ -136,7 +145,10 @@ def richardson_lucy_nc(img, psf, numiterations, regularizationfactor=0, lib=None
lib.deconv3d_32f(numiterations, int(img.shape[2]), int(img.shape[1]), int(img.shape[0]), img, shifted_psf, result, normal, platform, device)
else:
lib.deconv3d_32f_tv(numiterations, regularizationfactor, int(img.shape[2]), int(img.shape[1]), int(img.shape[0]), img, shifted_psf, result, normal, platform, device)


if cleanup==True:
lib.cleanup()

# unpad and return
return unpad(result, original_size)

Expand Down
92 changes: 75 additions & 17 deletions python/clij2fft/richardson_lucy_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from clij2fft.richardson_lucy import richardson_lucy_nc, richardson_lucy
import numpy as np
import pyopencl as cl
from clij2fft.pad import get_next_smooth
from clij2fft.libs import getlib
from multiprocessing import Queue

bytes_per_gb = 1024 * 1024 * 1024

Expand Down Expand Up @@ -97,8 +98,12 @@ def chunk_factor(img, psf, depth, mem_to_use=-1):
cf = np.sqrt(4 ** k)
return cf

def richardson_lucy_dask(img, psf, numiterations, regularizationfactor, non_circulant=True, overlap=10, mem_to_use=-1):
""" perform Richardson-Lucy using dask
def richardson_lucy_dask(img, psf, numiterations, regularizationfactor, non_circulant=True, overlap=10, mem_to_use=-1, debug=False, platform=0, num_devices=1):
""" perform Richardson-Lucy using dask.
If there are multiple devices compute() will be called with num_workers=num_devices and the devices ids (assumed to be 0 to num_devices-1) will be put
in a Queue and each dask task will pop a device id from the queue. When the task is complete the device id will be put back in the queue.
This will allow multiple GPUs to be used without conflicts.
Args:
img (numpy.ndarray): image to be deconvolved
Expand All @@ -108,21 +113,32 @@ def richardson_lucy_dask(img, psf, numiterations, regularizationfactor, non_circ
non_circulant (bool, optional): If True use non-circulant Richardson Lucy. Defaults to True.
overlap (int, optional): Overlap between blocks. Defaults to 10.
mem_to_use (int, optional): GPU memory to use in GB. If -1 use full GPU memory, otherwise limit GPU memory to mem_to_use. Defaults to -1.
debug (bool, optional): If True print debug information. Defaults to False.
platform (int, optional): OpenCL platform to use. Defaults to 0.
num_devices (int, optional): Number of GPUs to use. Defaults to 1.
Returns:
numpy.ndarray: deconvolved image
"""
print('image size',img.shape)
print('psf size', psf.shape)
if debug:
print('richardson_lucy_dask')
print()
print('image size',img.shape)
print('psf size', psf.shape)

gpu_mem_ = gpu_mem()/bytes_per_gb
print('gpu mem is ', gpu_mem_)

if debug:
print('gpu mem is ', gpu_mem_)

rl_mem_ = rl_mem_footprint(img, psf, depth=(0, overlap, overlap))/bytes_per_gb
print('rl mem is ', rl_mem_)

if debug:
print('rl mem is ', rl_mem_)

k = chunk_factor(img, psf, depth=(0, overlap, overlap), mem_to_use=mem_to_use)
print('chunk factor is ', k)

if debug:
print('chunk factor is ', k)

if img.shape[1] % k != 0:
y_chunk_size = img.shape[1] // k + 1
Expand All @@ -135,17 +151,59 @@ def richardson_lucy_dask(img, psf, numiterations, regularizationfactor, non_circ
x_chunk_size = img.shape[2] // k

chunk_size = (img.shape[0], y_chunk_size, x_chunk_size)
print('chunk size is',chunk_size)

if debug:
print('chunk size is',chunk_size)

dimg = da.from_array(img,chunks=(img.shape[0], y_chunk_size, x_chunk_size))

if non_circulant:
rl_func = richardson_lucy_nc
else:
rl_func = richardson_lucy

out = dimg.map_overlap(rl_func, depth={0:0, 1:overlap, 2:overlap}, dtype=np.float32, psf=psf, numiterations=numiterations, regularizationfactor=regularizationfactor)
return out.compute(num_workers=1)
queue = Queue()

for i in range(num_devices):
queue.put(i)

lib = getlib()
import traceback

def rl_dask_task(img, psf, numiterations, regularizationfactor=0, lib=None, block_info=None, block_id=None, thread_id=None):
    """Deconvolve one dask block on a device id borrowed from the shared queue.

    A device id is taken from ``queue`` before the work starts and is always
    put back afterwards, so concurrently running tasks never share a device.
    ``queue``, ``debug``, ``non_circulant``, ``platform`` and ``traceback``
    are read from the enclosing scope.

    Args:
        img: image block to deconvolve.
        psf: point spread function forwarded to the Richardson-Lucy routine.
        numiterations: number of Richardson-Lucy iterations.
        regularizationfactor: forwarded unchanged to the Richardson-Lucy
            routine. Defaults to 0.
        lib: native library handle forwarded to the Richardson-Lucy routine.
        block_info: only printed in debug mode (presumably filled in by dask
            when the function is mapped over blocks; confirm against the
            dask documentation).
        block_id: only printed in debug mode (same assumption as block_info).
        thread_id: unused; kept so the signature stays unchanged.

    Returns:
        The deconvolved block returned by ``richardson_lucy_nc`` or
        ``richardson_lucy``.

    Raises:
        Exception: any error raised by the deconvolution is printed and then
            re-raised so the failure reaches the caller of ``compute()``
            instead of silently turning into a ``None`` block.
    """
    print()

    # Acquire the device *before* the try block: if this call fails there is
    # nothing to hand back, and the finally clause below can rely on
    # device_num being bound.
    device_num = queue.get()

    try:
        if debug:
            print('start rlnc')
            print('gpu num is', device_num)
            print('block id is', block_id)
            print('block info is', block_info)

        rl_func = richardson_lucy_nc if non_circulant else richardson_lucy
        return rl_func(img, psf, numiterations, regularizationfactor=regularizationfactor, lib=lib, platform=platform, device=device_num)
    except Exception as e:
        traceback.print_exc()
        if debug:
            print()
            print("EXCEPTION", e)
        # Propagate the failure; swallowing it would make this task return
        # None and corrupt or confuse the assembled result.
        raise
    finally:
        # Always return the device id so other tasks are not starved.
        if debug:
            print('putting gpu num back', device_num)
        queue.put(device_num)

import time

start_time = time.time()
out = dimg.map_overlap(rl_dask_task, depth={0:0, 1:overlap, 2:overlap}, dtype=np.float32, psf=psf, numiterations=numiterations, regularizationfactor=regularizationfactor, lib=lib)
out_img = out.compute(num_workers=num_devices)
end_time = time.time()
execution_time = end_time - start_time
if debug:
print(f"Execution time of rl dask multi gpu: {execution_time} seconds")
lib.cleanup()
return out_img



Expand Down
209 changes: 0 additions & 209 deletions python/clij2fft/richardson_lucy_dask_multi_gpu.py

This file was deleted.

Loading

0 comments on commit acc24fe

Please sign in to comment.