Skip to content

Commit

Permalink
Python 2 compat
Browse files Browse the repository at this point in the history
  • Loading branch information
jadball committed Aug 23, 2024
1 parent 4138623 commit c3f92e1
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions ImageD11/sinograms/tensor_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def to_h5(self, h5file, h5group='TensorMap'):
# store the step sizes
parent_group.create_dataset("step", data=np.array(self.steps))


def to_paraview(self, h5name, h5group='TensorMap'):
"""Exports to H5, then writes an XDMF file that lets you read the data with ParaView"""
# Write H5 first
Expand Down Expand Up @@ -364,7 +363,7 @@ def from_h5(cls, h5file, h5group='TensorMap'):
return tensor_map

@classmethod
def from_pbpmap(cls, pbpmap, steps=None):
def from_pbpmap(cls, pbpmap, steps=None, phases=None):
"""Create TensorMap from a pbpmap object"""

maps = dict()
Expand All @@ -374,19 +373,22 @@ def from_pbpmap(cls, pbpmap, steps=None):
ubi_map = pbpmap.ubibest
else:
ubi_map = pbpmap.ubi

# create a mask from ubi_map
ubi_mask = np.where(np.isnan(ubi_map[:, :, 0, 0]), 0, 1).astype(bool)

# reshape ubi map and add it to the dict
maps['UBI'] = cls.recon_order_to_map_order(ubi_map)

# add npks to the dict
if hasattr(pbpmap, 'npks'):
maps['npks'] = cls.recon_order_to_map_order(pbpmap.npks)
maps['npks'] = cls.recon_order_to_map_order(np.where(ubi_mask, pbpmap.npks, 0))

# add nuniq to the dict
if hasattr(pbpmap, 'nuniq'):
maps['nuniq'] = cls.recon_order_to_map_order(pbpmap.nuniq)
maps['nuniq'] = cls.recon_order_to_map_order(np.where(ubi_mask, pbpmap.nuniq, 0))

tensor_map = cls(maps=maps, steps=steps)
tensor_map = cls(maps=maps, steps=steps, phases=phases)

return tensor_map

Expand Down Expand Up @@ -569,8 +571,13 @@ def from_combine_phases(cls, tensormaps):
new_arr = tm[map_name]

# selectively overwrite base_arr with new_arr
# the array slicing stuff below is to allow rightward broadcasting
base_arr = np.where((tm['phase_ids'] > -1)[(...,) + (np.newaxis,) * (base_arr.ndim - 3)], new_arr, base_arr)
# the array slicing stuff below is to allow arbitrary rightward broadcasting
# phase_id is (NZ, NY, NX) but base_arr might be UBI for example (NZ, NY, NX, 3, 3)
# Numpy can auto-broadcast leftwards (e.g (NZ, NY, NX) to (3, 3, NZ, NY, NX))
# but not rightwards!
# so we need to slice like this (NZ, NY, NX)[..., np.newaxis, np.newaxis]
# In Python 2, slicing grammar is different, so we can't invoke ... directly inside a tuple
base_arr = np.where((tm['phase_ids'] > -1)[(Ellipsis,) + (np.newaxis,) * (base_arr.ndim - 3)], new_arr, base_arr)

combined_maps[map_name] = base_arr

Expand Down

0 comments on commit c3f92e1

Please sign in to comment.