Support itk.Image use in Dask Array map_blocks #1091

Closed
3 of 4 tasks
thewtex opened this issue Jul 14, 2019 · 12 comments · Fixed by #2829
Labels
area:Python wrapping Python bindings for a class type:Enhancement Improvement of existing methods or implementation
Milestone

Comments

@thewtex
Member

thewtex commented Jul 14, 2019

To avoid explicit type coercion in Dask Array map_blocks, consider adding

  • dtype
  • shape
  • ndim
  • __array_function__

to itk.Image.

@jakirkham @mrocklin
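For readers less familiar with these protocols: dtype, shape, and ndim let generic array code inspect an object without converting it, and __array_function__ (NumPy's NEP 18 dispatch hook) routes NumPy functions called on the object to the object's own implementation. Below is a minimal sketch using a hypothetical ToyImage class (not ITK code) that wraps a NumPy buffer. Only top-level arguments are unwrapped, so functions that take sequences of arrays, such as np.concatenate, are not handled:

```python
import numpy as np


class ToyImage:
    """Hypothetical stand-in for itk.Image: a pixel buffer plus spatial metadata."""

    def __init__(self, pixels, spacing):
        self._pixels = np.asarray(pixels)
        self.spacing = tuple(spacing)

    # Array-like attributes that generic code can inspect without converting
    @property
    def dtype(self):
        return self._pixels.dtype

    @property
    def shape(self):
        return self._pixels.shape

    @property
    def ndim(self):
        return self._pixels.ndim

    def __array__(self, dtype=None, copy=None):
        # Used by np.asarray(); `copy` is accepted (and ignored) for NumPy 2 compatibility
        return self._pixels if dtype is None else self._pixels.astype(dtype)

    def __array_function__(self, func, types, args, kwargs):
        # NEP 18 hook: unwrap top-level ToyImage arguments, then defer to NumPy
        def unwrap(value):
            return value._pixels if isinstance(value, ToyImage) else value

        args = tuple(unwrap(a) for a in args)
        kwargs = {key: unwrap(value) for key, value in kwargs.items()}
        return func(*args, **kwargs)


image = ToyImage(np.arange(6, dtype=np.uint8).reshape(2, 3), spacing=(1.0, 2.0))
print(image.shape, image.ndim, image.dtype)  # (2, 3) 2 uint8
print(np.sum(image))  # 15
```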

@thewtex thewtex added type:Enhancement Improvement of existing methods or implementation area:Python wrapping Python bindings for a class labels Jul 14, 2019
@thewtex thewtex added this to the ITK v5.1b01 milestone Aug 9, 2019
@thewtex thewtex modified the milestones: ITK v5.1b01, ITK v5.1b02 Sep 10, 2019
@thewtex thewtex modified the milestones: ITK v5.1rc01, ITK v5.1rc02 Dec 5, 2019
@thewtex thewtex modified the milestones: ITK v5.1rc02, ITK v5.1rc03 Feb 21, 2020
@thewtex thewtex modified the milestones: ITK v5.1rc03, ITK v5.1.0 Apr 9, 2020
@thewtex
Member Author

thewtex commented Apr 23, 2020

ndim, dtype, shape added in #1780

@thewtex thewtex modified the milestones: ITK v5.1.0, ITK v5.2.0 Apr 23, 2020
@jakirkham
Member

Nice! Thanks for the update Matt 😄

@mrocklin
Contributor

mrocklin commented Apr 23, 2020 via email

@thewtex
Member Author

thewtex commented Apr 24, 2020

@mrocklin great idea -- I'll see about following up on our initial post with some denoising...

@stale

stale bot commented Aug 22, 2020

This issue has been automatically marked as stale because it has not had recent activity. Thank you for your contributions.

@stale stale bot added the status:Use_Milestone_Backlog Use "Backlog" milestone instead of label for issues without a fixed deadline label Aug 22, 2020
@thewtex thewtex modified the milestones: ITK v5.2.0, ITK v5.3.0 Nov 24, 2020
@stale stale bot removed the status:Use_Milestone_Backlog Use "Backlog" milestone instead of label for issues without a fixed deadline label Nov 24, 2020
@stale

stale bot commented Jun 11, 2021

This issue has been automatically marked as stale because it has not had recent activity. Thank you for your contributions.

@stale stale bot added the status:Use_Milestone_Backlog Use "Backlog" milestone instead of label for issues without a fixed deadline label Jun 11, 2021
@GenevieveBuckley
Contributor

Is this something I can help along?

@stale stale bot removed the status:Use_Milestone_Backlog Use "Backlog" milestone instead of label for issues without a fixed deadline label Jun 15, 2021
@thewtex
Member Author

thewtex commented Jun 17, 2021

Hi @GenevieveBuckley,

Help on this would be amazing! 🙏

We need to add pickle support to itk.Image.

This process could be a great basis for a Dask blog post on how to add the pickle support Dask requires to data structures from native (compiled) Python packages.

We need to get this test to pass:

import itk
import pickle
import numpy as np

# Create an itk.Image
array = np.random.randint(0, 256, (8, 12)).astype(np.uint8)
image = itk.image_from_array(array)
image.SetSpacing([1.0, 2.0])
image.SetOrigin([11.0, 4.0])
theta = np.radians(30)
cosine = np.cos(theta)
sine = np.sin(theta)
rotation = np.array(((cosine, -sine), (sine, cosine)))
image.SetDirection(rotation)

# Verify serialization works with a np.ndarray
serialized = pickle.dumps(array)
deserialized = pickle.loads(serialized)
assert np.array_equal(array, deserialized)

# How to check for image consistency
image_copy = itk.image_duplicator(image)
compared = itk.comparison_image_filter(image, image_copy)
assert np.sum(compared) == 0

# We need to add support for this
serialized = pickle.dumps(image)
deserialized = pickle.loads(serialized)
compared = itk.comparison_image_filter(image, deserialized)
assert np.sum(compared) == 0

As a first step, I would create a conda environment, pip install itk, and hack on site-packages/itk/itkImagePython.py (a generated file) to work out the details that make the test pass. For this test, the class in that file that needs the required dunder methods is itkImageUC2. The pickle process in Python is a bit complex: I found the CPython docs helpful, but stepping through the code with pdb was also required for my understanding.

We need to define __getstate__, __setstate__, and likely __reduce_ex__.
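To illustrate how the first two hooks cooperate, here is a minimal sketch with a hypothetical class (not ITK code) in which a threading.Lock stands in for the native, unpicklable handle a SWIG proxy carries. __getstate__ returns only picklable data; pickle recreates the instance without calling __init__ and passes that data to __setstate__, which must rebuild anything that was dropped:

```python
import pickle
import threading

import numpy as np


class HandleBackedImage:
    """Hypothetical class whose live handle cannot be pickled, loosely
    analogous to a SWIG proxy that owns a C++ pointer."""

    def __init__(self, pixels):
        self.pixels = np.asarray(pixels)
        self._handle = threading.Lock()  # stand-in for an unpicklable native resource

    def __getstate__(self):
        # Copy the instance dict and drop the handle; everything left pickles natively
        state = self.__dict__.copy()
        del state["_handle"]
        return state

    def __setstate__(self, state):
        # __init__ is not called on unpickling, so restore the data and rebuild the handle
        self.__dict__.update(state)
        self._handle = threading.Lock()


image = HandleBackedImage(np.arange(4).reshape(2, 2))
restored = pickle.loads(pickle.dumps(image))
```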

Here is what we did for the itk module:

    # For pickle support
    def __reduce_ex__(self, proto):
        state = self.__getstate__()
        return _lazy_itk_module_reconstructor, (self.__name__, state), state

    # For pickle support
    def __getstate__(self):
        state = self.__dict__.copy()
        lazy_modules = list()
        for key in self.itk_base_global_lazy_attributes:
            if isinstance(state[key], LazyITKModule):
                lazy_modules.append((key, state[key].itk_base_global_lazy_attributes))
            state[key] = not_loaded
        state["lazy_modules"] = lazy_modules
        return state

    # For pickle support
    def __setstate__(self, state):
        self.__dict__.update(state)
        for module_name, lazy_attributes in state["lazy_modules"]:
            self.__dict__.update(
                {module_name: LazyITKModule(module_name, lazy_attributes)}
            )
        for module in state["loaded_lazy_modules"]:
            namespace = {}
            base.itk_load_swig_module(module, namespace)
            for k, v in namespace.items():
                setattr(self, k, v)
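The itk module code returns a three-item tuple from __reduce_ex__: a module-level reconstructor, its arguments, and a state that pickle then passes to __setstate__. When constructor arguments alone describe the object, the two-item form is enough. Below is a sketch with hypothetical names (_rebuild_image and SpacedImage are illustrative only). The reconstructor must be importable by name, because pickle stores a reference to it rather than its code:

```python
import pickle

import numpy as np


def _rebuild_image(pixels, spacing):
    """Hypothetical module-level reconstructor; pickle records it by qualified name."""
    return SpacedImage(pixels, spacing)


class SpacedImage:
    """Hypothetical image type used only to illustrate __reduce_ex__."""

    def __init__(self, pixels, spacing):
        self.pixels = np.asarray(pixels)
        self.spacing = tuple(spacing)

    def __reduce_ex__(self, protocol):
        # Two-item form (callable, args): unpickling calls _rebuild_image(pixels, spacing)
        return _rebuild_image, (self.pixels, self.spacing)


original = SpacedImage(np.ones((2, 3), dtype=np.float32), spacing=(1.0, 2.0))
clone = pickle.loads(pickle.dumps(original))
```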

To get the state, here is an example that represents the itk.Image state as a Python dictionary of fundamental types plus an np.ndarray (which supports pickling): https://github.com/InsightSoftwareConsortium/itkwidgets/blob/c4f3b158719bbb2720ef2fadcc3df8990a3feb95/itkwidgets/trait_types.py#L163-L172.

Reconstructing the image will require:

image = itk.image_view_from_array(state.pixel_array, ttype=state.imageType)
image.SetOrigin(state.origin)
image.SetSpacing(state.spacing)
image.SetDirection(state.direction)
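Here is a NumPy-only sketch of that state round trip. It assumes a hypothetical ToyImage class that mimics the SetOrigin, SetSpacing, and SetDirection calls; with itk installed, the constructor call would be replaced by itk.image_view_from_array. Because image_view_from_array shares memory with its input, the reconstructed image has to keep the unpickled array alive, which ITK handles by keeping a reference to the array on the image:

```python
import pickle

import numpy as np


class ToyImage:
    """Hypothetical stand-in that mimics the itk.Image setters used above."""

    def __init__(self, pixel_array):
        # Like itk.image_view_from_array, keep the array itself rather than a copy
        self.pixel_array = pixel_array
        self.origin = None
        self.spacing = None
        self.direction = None

    def SetOrigin(self, origin):
        self.origin = tuple(float(v) for v in origin)

    def SetSpacing(self, spacing):
        self.spacing = tuple(float(v) for v in spacing)

    def SetDirection(self, direction):
        self.direction = np.asarray(direction, dtype=np.float64)


def state_from_image(image):
    # Only fundamental Python types plus ndarrays, all of which pickle natively
    return {
        "pixel_array": image.pixel_array,
        "origin": image.origin,
        "spacing": image.spacing,
        "direction": image.direction,
    }


def image_from_state(state):
    # Same order as the reconstruction steps above: pixel buffer first, then metadata
    image = ToyImage(state["pixel_array"])
    image.SetOrigin(state["origin"])
    image.SetSpacing(state["spacing"])
    image.SetDirection(state["direction"])
    return image


image = ToyImage(np.arange(96, dtype=np.uint8).reshape(8, 12))
image.SetOrigin([11.0, 4.0])
image.SetSpacing([1.0, 2.0])
image.SetDirection(np.eye(2))
restored = image_from_state(pickle.loads(pickle.dumps(state_from_image(image))))
```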

@jakirkham
Member

There is also some discussion about improving pickle support in ITK in issue #1948.

@GenevieveBuckley
Contributor

So I think I'm almost there.

Approach

Reusing the itkimage_to_json and itkimage_from_json functions from the link Matt gave above seems like the best approach.
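At its core, that approach stores a compressed copy of the raw pixel bytes plus enough metadata (size, component count, component type) to rebuild the buffer. Here is a minimal sketch of the round trip. It assumes the standard-library zlib in place of the zstandard package the original code imports, and a single-component uint8 image:

```python
import zlib
from functools import reduce

import numpy as np

# ITK reports size as (x, y); the NumPy buffer is indexed [y, x]
size = (12, 8)
components = 1
dtype = np.uint8
pixels = np.random.default_rng(0).integers(0, 256, size=size[::-1]).astype(dtype)

# Serialize: compress the raw pixel bytes (zlib stands in for zstd here)
compressed = zlib.compress(pixels.tobytes(), 3)

# Deserialize: the expected byte count follows from size, components, and itemsize
pixel_count = reduce(lambda x, y: x * y, size, 1)
number_of_bytes = pixel_count * components * np.dtype(dtype).itemsize
buffer = np.frombuffer(zlib.decompress(compressed), dtype=dtype)
assert buffer.nbytes == number_of_bytes, "unexpected decompressed size"
restored = buffer.reshape(size[::-1])
```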

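Details:
import os
from functools import reduce

import itk
import numpy as np
import six
import zstandard as zstd


def _image_to_type(itkimage):  # noqa: C901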
    component = itk.template(itkimage)[1][0]
    if component == itk.UL:
        if os.name == 'nt':
            return 'uint32_t', 1
        else:
            return 'uint64_t', 1
    mangle = None
    pixelType = 1
    if component == itk.SL:
        if os.name == 'nt':
            return 'int32_t', 1,
        else:
            return 'int64_t', 1,
    if component in (itk.SC, itk.UC, itk.SS, itk.US, itk.SI, itk.UI, itk.F,
            itk.D, itk.B):
        mangle = component
    elif component in [i[1] for i in itk.Vector.items()]:
        mangle = itk.template(component)[1][0]
        pixelType = 5
    elif component == itk.complex[itk.F]:
        # complex float
        return 'float', 10
    elif component == itk.complex[itk.D]:
        # complex double
        return 'double', 10
    elif component in [i[1] for i in itk.CovariantVector.items()]:
        # CovariantVector
        mangle = itk.template(component)[1][0]
        pixelType = 7
    elif component in [i[1] for i in itk.Offset.items()]:
        # Offset
        return 'int64_t', 4
    elif component in [i[1] for i in itk.FixedArray.items()]:
        # FixedArray
        mangle = itk.template(component)[1][0]
        pixelType = 11
    elif component in [i[1] for i in itk.RGBAPixel.items()]:
        # RGBA
        mangle = itk.template(component)[1][0]
        pixelType = 3
    elif component in [i[1] for i in itk.RGBPixel.items()]:
        # RGB
        mangle = itk.template(component)[1][0]
        pixelType = 2
    elif component in [i[1] for i in itk.SymmetricSecondRankTensor.items()]:
        # SymmetricSecondRankTensor
        mangle = itk.template(component)[1][0]
        pixelType = 8
    else:
        raise RuntimeError('Unrecognized component type: {0}'.format(str(component)))
    _python_to_js = {
        itk.SC: 'int8_t',
        itk.UC: 'uint8_t',
        itk.SS: 'int16_t',
        itk.US: 'uint16_t',
        itk.SI: 'int32_t',
        itk.UI: 'uint32_t',
        itk.F: 'float',
        itk.D: 'double',
        itk.B: 'uint8_t'
    }
    return _python_to_js[mangle], pixelType


def itkimage_to_json(itkimage, manager=None):
    """Serialize a Python itk.Image object.
    Attributes of this dictionary are to be passed to the JavaScript itkimage
    constructor.
    """
    if itkimage is None:
        return None
    else:
        direction = itkimage.GetDirection()
        directionMatrix = direction.GetVnlMatrix()
        directionList = []
        dimension = itkimage.GetImageDimension()
        pixel_arr = itk.array_view_from_image(itkimage)
        componentType, pixelType = _image_to_type(itkimage)
        if 'int64' in componentType:
            # JavaScript does not yet support 64-bit integers well
            if componentType == 'uint64_t':
                pixel_arr = pixel_arr.astype(np.uint32)
                componentType = 'uint32_t'
            else:
                pixel_arr = pixel_arr.astype(np.int32)
                componentType = 'int32_t'
        compressor = zstd.ZstdCompressor(level=3)
        compressed = compressor.compress(pixel_arr.data)
        pixel_arr_compressed = memoryview(compressed)
        for col in range(dimension):
            for row in range(dimension):
                directionList.append(directionMatrix.get(row, col))
        imageType = dict(
            dimension=dimension,
            componentType=componentType,
            pixelType=pixelType,
            components=itkimage.GetNumberOfComponentsPerPixel()
        )
        return dict(
            imageType=imageType,
            origin=tuple(itkimage.GetOrigin()),
            spacing=tuple(itkimage.GetSpacing()),
            size=tuple(itkimage.GetBufferedRegion().GetSize()),
            direction={'data': directionList,
                       'rows': dimension,
                       'columns': dimension},
            compressedData=compressed
        )



def _type_to_image(jstype):
    _pixelType_to_prefix = {
        1: '',
        2: 'RGB',
        3: 'RGBA',
        4: 'O',
        5: 'V',
        7: 'CV',
        8: 'SSRT',
        11: 'FA'
    }
    pixelType = jstype['pixelType']
    dimension = jstype['dimension']
    if pixelType == 10:
        if jstype['componentType'] == 'float':
            return itk.Image[itk.complex[itk.F], dimension], np.float32
        else:
            return itk.Image[itk.complex[itk.D], dimension], np.float64

    def _long_type():
        if os.name == 'nt':
            return 'LL'
        else:
            return 'L'
    prefix = _pixelType_to_prefix[pixelType]
    _js_to_python = {
        'int8_t': 'SC',
        'uint8_t': 'UC',
        'int16_t': 'SS',
        'uint16_t': 'US',
        'int32_t': 'SI',
        'uint32_t': 'UI',
        'int64_t': 'S' + _long_type(),
        'uint64_t': 'U' + _long_type(),
        'float': 'F',
        'double': 'D'
    }
    _js_to_numpy_dtype = {
        'int8_t': np.int8,
        'uint8_t': np.uint8,
        'int16_t': np.int16,
        'uint16_t': np.uint16,
        'int32_t': np.int32,
        'uint32_t': np.uint32,
        'int64_t': np.int64,
        'uint64_t': np.uint64,
        'float': np.float32,
        'double': np.float64
    }
    dtype = _js_to_numpy_dtype[jstype['componentType']]
    if pixelType != 4:
        prefix += _js_to_python[jstype['componentType']]
    if pixelType not in (1, 2, 3, 10):
        prefix += str(dimension)
    prefix += str(dimension)
    return getattr(itk.Image, prefix), dtype


def itkimage_from_json(js, manager=None):
    """Deserialize a Javascript itk.js Image object."""
    if js is None:
        return None
    else:
        ImageType, dtype = _type_to_image(js['imageType'])
        decompressor = zstd.ZstdDecompressor()
        if six.PY2:
            asBytes = js['compressedData'].tobytes()
            pixelBufferArrayCompressed = np.frombuffer(asBytes, dtype=np.uint8)
        else:
            pixelBufferArrayCompressed = np.frombuffer(js['compressedData'],
                                                    dtype=np.uint8)
        pixelCount = reduce(lambda x, y: x * y, js['size'], 1)
        numberOfBytes = pixelCount * \
            js['imageType']['components'] * np.dtype(dtype).itemsize
        pixelBufferArray = \
            np.frombuffer(decompressor.decompress(pixelBufferArrayCompressed,
                                                numberOfBytes),
                        dtype=dtype)
        pixelBufferArray.shape = js['size'][::-1]
        # Workaround for GetImageFromArray required until 5.0.1
        # and https://github.com/numpy/numpy/pull/11739
        pixelBufferArrayCopyToBeRemoved = pixelBufferArray.copy()
        # image = itk.PyBuffer[ImageType].GetImageFromArray(pixelBufferArray)
        image = itk.PyBuffer[ImageType].GetImageFromArray(
            pixelBufferArrayCopyToBeRemoved)
        Dimension = image.GetImageDimension()
        image.SetOrigin(js['origin'])
        image.SetSpacing(js['spacing'])
        direction = image.GetDirection()
        directionMatrix = direction.GetVnlMatrix()
        directionJs = js['direction']['data']
        for col in range(Dimension):
            for row in range(Dimension):
                directionMatrix.put(
                    row, col, directionJs[col + row * Dimension])
        return image

...then it looks like all we need is to add __getstate__ and __setstate__ methods like this:

    # For pickle support
    def __getstate__(self):
        state = itkimage_to_json(self)
        return state

    # For pickle support
    def __setstate__(self, state):
        deserialized = itkimage_from_json(state)
        self.__dict__['this'] = deserialized
        self.SetOrigin(state['origin'])
        self.SetSpacing(state['spacing'])

        # FIXME! Something is not quite right about the way I'm setting the direction here
        # ref https://discourse.itk.org/t/set-image-direction-from-numpy-array/844
        direction_data = np.reshape(state['direction']['data'], (state['direction']['rows'], state['direction']['columns']))
        direction_data_vnl = itk.GetVnlMatrixFromArray(direction_data)
        direction = self.GetDirection()
        direction.GetVnlMatrix().copy_in(direction_data_vnl.data_block())
        self.SetDirection(direction)
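Separately from how SetDirection is called, the two helpers above appear to disagree on the layout of the flattened direction. itkimage_to_json appends entries column by column, while itkimage_from_json, and the np.reshape call in this __setstate__, read them back row by row. A small NumPy check using the 30-degree rotation from the test script suggests that this round trip returns the transpose. For a rotation, the transpose is the inverse rotation rather than the identity, so this alone may not account for the error below:

```python
import numpy as np

theta = np.radians(30)
direction = np.array([[np.cos(theta), -np.sin(theta)],
                      [np.sin(theta), np.cos(theta)]])
dimension = direction.shape[0]

# Flatten as in itkimage_to_json: outer loop over columns, inner loop over rows
direction_list = []
for col in range(dimension):
    for row in range(dimension):
        direction_list.append(direction[row, col])

# Read back as in itkimage_from_json: (row, col) <- direction_list[col + row * dimension]
restored = np.empty((dimension, dimension))
for col in range(dimension):
    for row in range(dimension):
        restored[row, col] = direction_list[col + row * dimension]

# np.reshape, as used in __setstate__, is also row-major by default
reshaped = np.reshape(direction_list, (dimension, dimension))

print(np.allclose(restored, direction.T))  # True
print(np.allclose(restored, direction))  # False
print(np.allclose(reshaped, direction.T))  # True
```

Flattening in row-major order on both sides (for example with direction.ravel()), or pickling the direction as an ndarray directly, would avoid the mismatch.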

Results

When I run the example test script, the itk.comparison_image_filter call fails, I think because the image direction isn't being set correctly.

Details:
RuntimeError: /work/ITK-source/ITK/Modules/Core/Common/include/itkImageToImageFilter.hxx:219:
ITK ERROR: ComparisonImageFilter(0x55928d0ed180): Inputs do not occupy the same physical space! 
InputImage Direction: 1.0000000e+00 0.0000000e+00
0.0000000e+00 1.0000000e+00
, InputImageValidInput Direction: 8.6602540e-01 -5.0000000e-01
5.0000000e-01 8.6602540e-01

	Tolerance: 1.0000000e-06

The line assert np.sum(compared) == 0 does pass, which I'm hoping means the actual image array values are all correct. I'm not very familiar with itk objects.

Questions

  1. What am I doing wrong with setting the direction? I tried to use @thewtex's answer to this question as a model, but what I've got doesn't actually overwrite the direction.
  2. Where should these changes live? I've been hacking on the auto-generated Python file as suggested, but I take it there's a different process for permanent changes. (I added the new __getstate__ and __setstate__ methods to the itkImageUC2 class, because that's the image type created in the example test script.)

@thewtex
Member Author

thewtex commented Oct 20, 2021

@GenevieveBuckley thank you so much for your work on this! 🙏 Sorry for the late reply. It seems we are quite close!

What am I doing wrong with setting the direction?

We have improved interfaces to the direction. Following #2828, we can get the direction for serialization as a NumPy array with

direction = np.asarray(self.GetDirection())

ndarray ships with pickling support, so let's work with that.

Recent versions of itk support setting the direction with an ndarray directly, so we can use

# direction is an NxN ndarray
self.SetDirection(direction)

in __setstate__.
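Here is a sketch of that suggestion in isolation. It assumes a hypothetical direction_state helper (not ITK API) that normalizes the direction to a float64 N×N ndarray. Its shape check is slightly stricter than the loop in the SetDirection override, which checks each axis length but not the number of axes. Because pickle stores float64 arrays bit for bit, the rotation survives the round trip exactly:

```python
import pickle

import numpy as np


def direction_state(direction, dimension):
    """Hypothetical helper: return the direction as a float64 (dimension x dimension) ndarray."""
    array = np.asarray(direction, dtype=np.float64)
    if array.shape != (dimension, dimension):
        raise ValueError("Array does not have the expected shape")
    return array


theta = np.radians(30)
rotation = [[np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]]
state = {"direction": direction_state(rotation, 2)}
restored = pickle.loads(pickle.dumps(state))
```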

Where should these changes live?

__getstate__ and __setstate__ can go here:

%pythoncode %{
    def _SetBase(self, base):
        """Internal method to keep a reference when creating a view of a NumPy array."""
        self.base = base

    @property
    def ndim(self):
        """Equivalent to the np.ndarray ndim attribute when converted
        to an array with itk.array_view_from_image."""
        spatial_dims = self.GetImageDimension()
        if self.GetNumberOfComponentsPerPixel() > 1:
            return spatial_dims + 1
        else:
            return spatial_dims

    @property
    def shape(self):
        """Equivalent to the np.ndarray shape attribute when converted
        to an array with itk.array_view_from_image."""
        itksize = self.GetLargestPossibleRegion().GetSize()
        dim = len(itksize)
        result = [int(itksize[idx]) for idx in range(dim)]
        if self.GetNumberOfComponentsPerPixel() > 1:
            result = [self.GetNumberOfComponentsPerPixel(), ] + result
        result.reverse()
        return tuple(result)

    @property
    def dtype(self):
        """Equivalent to the np.ndarray dtype attribute when converted
        to an array with itk.array_view_from_image."""
        import itk
        first_template_arg = itk.template(self)[1][0]
        if hasattr(first_template_arg, 'dtype'):
            return first_template_arg.dtype
        else:
            # Multi-component pixel types, e.g. Vector,
            # CovariantVector, etc.
            return itk.template(first_template_arg)[1][0].dtype

    def astype(self, pixel_type):
        """Cast the image to the provided itk pixel type or equivalent NumPy dtype."""
        import itk
        import numpy as np
        from itk.support import types
        # if both a numpy dtype and a ctype exist, use the latter.
        if type(pixel_type) is type:
            c_pixel_type = types.itkCType.GetCTypeForDType(pixel_type)
            if c_pixel_type is not None:
                pixel_type = c_pixel_type
        # input_image_template is Image or VectorImage
        (input_image_template, (input_pixel_type, input_image_dimension)) = itk.template(self)
        if input_pixel_type is pixel_type:
            return self
        OutputImageType = input_image_template[pixel_type, input_image_dimension]
        cast = itk.cast_image_filter(self, ttype=(type(self), OutputImageType))
        return cast

    def SetDirection(self, direction):
        from itk.support import helpers
        if helpers.is_arraylike(direction):
            import itk
            import numpy as np
            array = np.asarray(direction).astype(np.float64)
            dimension = self.GetImageDimension()
            for dim in array.shape:
                if dim != dimension:
                    raise ValueError('Array does not have the expected shape')
            matrix = itk.matrix_from_array(array)
            self.__SetDirection_orig__(matrix)
        else:
            self.__SetDirection_orig__(direction)

    def keys(self):
        """Return keys related to the image's metadata.
        These keys are used in the dictionary resulting from dict(image).
        These keys include MetaDataDictionary keys along with
        'origin', 'spacing', and 'direction' keys, which
        correspond to the image's Origin, Spacing, and Direction. However,
        they are in (z, y, x) order as opposed to (x, y, z) order to
        correspond to the indexing of the shape of the pixel buffer
        array resulting from np.array(image).
        """
        meta_keys = self.GetMetaDataDictionary().GetKeys()
        # Ignore deprecated, legacy members that cause issues
        result = list(filter(lambda k: not k.startswith('ITK_original'), meta_keys))
        result.extend(['origin', 'spacing', 'direction'])
        return result

    def __getitem__(self, key):
        """Access metadata keys, see help(image.keys), for string
        keys, otherwise provide NumPy indexing to the pixel buffer
        array view. The index order follows NumPy array indexing
        order, i.e. [z, y, x] versus [x, y, z]."""
        import itk
        if isinstance(key, str):
            import numpy as np
            if key == 'origin':
                return np.flip(np.asarray(self.GetOrigin()), axis=None)
            elif key == 'spacing':
                return np.flip(np.asarray(self.GetSpacing()), axis=None)
            elif key == 'direction':
                return np.flip(itk.array_from_matrix(self.GetDirection()), axis=None)
            else:
                return self.GetMetaDataDictionary()[key]
        else:
            return itk.array_view_from_image(self).__getitem__(key)

    def __setitem__(self, key, value):
        """Set metadata keys, see help(image.keys), for string
        keys, otherwise provide NumPy indexing to the pixel buffer
        array view. The index order follows NumPy array indexing
        order, i.e. [z, y, x] versus [x, y, z]."""
        if isinstance(key, str):
            import numpy as np
            if key == 'origin':
                self.SetOrigin(np.flip(value, axis=None))
            elif key == 'spacing':
                self.SetSpacing(np.flip(value, axis=None))
            elif key == 'direction':
                self.SetDirection(np.flip(value, axis=None))
            else:
                self.GetMetaDataDictionary()[key] = value
        else:
            import itk
            itk.array_view_from_image(self).__setitem__(key, value)
%}
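The shape and ndim properties above encode the mapping between ITK's (x, y, z) size and NumPy's [z, y, x] indexing, with an extra trailing axis for multi-component pixels; this mapping is what array-consuming code such as Dask would see. Here is a sketch of that logic as plain functions (numpy_shape and numpy_ndim are hypothetical names, not ITK API), using the 8 × 12 array from the test script, whose ITK size is (12, 8):

```python
def numpy_shape(itk_size, components_per_pixel):
    """Mirror of the shape property: reverse the ITK size and, for
    multi-component pixels, put the component axis last."""
    result = [int(length) for length in itk_size]
    if components_per_pixel > 1:
        result = [components_per_pixel] + result
    result.reverse()
    return tuple(result)


def numpy_ndim(image_dimension, components_per_pixel):
    """Mirror of the ndim property: one extra axis for multi-component pixels."""
    return image_dimension + 1 if components_per_pixel > 1 else image_dimension


print(numpy_shape((12, 8), 1))  # (8, 12)
print(numpy_shape((4, 5, 6), 3))  # (6, 5, 4, 3)
print(numpy_ndim(3, 3))  # 4
```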

Feel free to use the CI to test builds / changes :-)

@GenevieveBuckley
Contributor

Excellent, I've opened a PR over at #2829, and outlined what the next steps are here.

4 participants