Skip to content

Commit

Permalink
revert RandShiftIntensity
Browse files Browse the repository at this point in the history
Signed-off-by: KumoLiu <[email protected]>
  • Loading branch information
KumoLiu committed Jul 28, 2023
1 parent c976790 commit bea398c
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,17 +255,13 @@ class RandShiftIntensity(RandomizableTransform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self, offsets: tuple[float, float] | float, safe: bool = False, channel_wise: bool = False, prob: float = 0.1
) -> None:
def __init__(self, offsets: tuple[float, float] | float, safe: bool = False, prob: float = 0.1) -> None:
"""
Args:
offsets: offset range to randomly shift.
if single number, offset value is picked from (-offsets, offsets).
safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
channel_wise: if True, calculate on each channel separately. Please ensure
that the first dimension represents the channel of the image if True.
prob: probability of shift.
"""
RandomizableTransform.__init__(self, prob)
Expand All @@ -276,17 +272,13 @@ def __init__(
else:
self.offsets = (min(offsets), max(offsets))
self._offset = self.offsets[0]
self.channel_wise = channel_wise
self._shifter = ShiftIntensity(self._offset, safe)

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
if not self._do_transform:
return None
if self.channel_wise:
self._offset = [self.R.uniform(low=self.offsets[0], high=self.offsets[1]) for _ in range(data.shape[0])]
else:
self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1])
self._offset = self.R.uniform(low=self.offsets[0], high=self.offsets[1])

def __call__(self, img: NdarrayOrTensor, factor: float | None = None, randomize: bool = True) -> NdarrayOrTensor:
"""
Expand All @@ -300,17 +292,12 @@ def __call__(self, img: NdarrayOrTensor, factor: float | None = None, randomize:
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
if randomize:
self.randomize(img)
self.randomize()

if not self._do_transform:
return img

if self.channel_wise:
for i, d in enumerate(img):
img[i] = self._shifter(d, self._offset[i] if factor is None else self._offset * factor)
else:
img = self._shifter(img, self._offset if factor is None else self._offset * factor)
return img
return self._shifter(img, self._offset if factor is None else self._offset * factor)


class StdShiftIntensity(Transform):
Expand Down

0 comments on commit bea398c

Please sign in to comment.