Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding rank attr to gaussian noise layer to enable input shape determ… #52

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 50 additions & 40 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,14 @@ def __init__(self, pool_size, strides=None, padding='valid', sigma=1,
def _make_2D_gaussian_kernel(edge_len, sigma=1.):
"""Creates 2D gaussian kernel with side length `edge_len` and a sigma
of `sigma`

Parameters
----------
edge_len : int
Edge size of the kernel
sigma : float
Sigma parameter for gaussian distribution

Returns
-------
kernel : np.ndarray
Expand Down Expand Up @@ -213,10 +215,12 @@ def get_config(self):

def call(self, x):
"""Operates on x with the specified function

Parameters
----------
x : tf.Tensor
Input tensor

Returns
-------
x : tf.Tensor
Expand All @@ -240,23 +244,31 @@ def __init__(self, axis, mean=1, stddev=0.1):
"""
Parameters
----------
axis : int
Axis to apply random noise across. All other axis will have the
same noise. For example, for a 5D spatiotemporal tensor with axis=3
(the time axis), this layer will apply a single random number to
every unique index of axis=3.
axis : int | list | tuple
Axes to apply random noise across. All other axes will have the
same noise. For example, for a 5D spatiotemporal tensor with
axis=(1, 2, 3) (both spatial axes and the temporal axis), this
layer will apply a single random number to every unique index of
axis=(1, 2, 3).
mean : float
The mean of the normal distribution.
stddev : float
The standard deviation of the normal distribution.
"""

super().__init__()
self._axis = axis
self._rand_shape = None
self.rank = None
self._axis = axis if isinstance(axis, (tuple, list)) else [axis]
self._mean = tf.constant(mean, dtype=tf.dtypes.float32)
self._stddev = tf.constant(stddev, dtype=tf.dtypes.float32)

def _get_rand_shape(self, x):
"""Get shape of random noise along the specified axes."""
shape = np.ones(len(x.shape), dtype=np.int32)
for ax in self._axis:
shape[ax] = x.shape[ax]
return tf.constant(shape, dtype=tf.dtypes.int32)

def build(self, input_shape):
"""Custom implementation of the tf layer build method.

Expand All @@ -267,9 +279,7 @@ def build(self, input_shape):
input_shape : tuple
Shape tuple of the input
"""
shape = np.ones(len(input_shape), dtype=np.int32)
shape[self._axis] = input_shape[self._axis]
self._rand_shape = tf.constant(shape, dtype=tf.dtypes.int32)
self.rank = len(input_shape)

def call(self, x):
"""Calls the tile operation
Expand All @@ -285,11 +295,11 @@ def call(self, x):
Output tensor with noise applied to the requested axis.
"""

rand_tensor = tf.random.normal(self._rand_shape,
rand_tensor = tf.random.normal(self._get_rand_shape(x),
mean=self._mean,
stddev=self._stddev,
dtype=tf.dtypes.float32)
return x * rand_tensor
return x + rand_tensor


class FlattenAxis(tf.keras.layers.Layer):
Expand Down Expand Up @@ -351,7 +361,7 @@ def __init__(self, spatial_mult=1):
"""
Parameters
----------
spatial_multiplier : int
spatial_mult : int
Number of times to multiply the spatial dimensions. Note that the
spatial expansion is an un-packing of the feature dimension. For
example, if the input layer has shape (123, 5, 5, 16) with
Expand Down Expand Up @@ -435,14 +445,14 @@ def __init__(self, spatial_mult=1, temporal_mult=1,
"""
Parameters
----------
spatial_multiplier : int
spatial_mult : int
Number of times to multiply the spatial dimensions. Note that the
spatial expansion is an un-packing of the feature dimension. For
example, if the input layer has shape (123, 5, 5, 24, 16) with
multiplier=2 the output shape will be (123, 10, 10, 24, 4). The
input feature dimension must be divisible by the spatial multiplier
squared.
temporal_multiplier : int
temporal_mult : int
Number of times to multiply the temporal dimension. For example,
if the input layer has shape (123, 5, 5, 24, 2) with multiplier=2
the output shape will be (123, 5, 5, 48, 2).
Expand Down Expand Up @@ -603,18 +613,17 @@ def call(self, x):
if self._cache is None:
self._cache = x
return x
try:
out = tf.add(x, self._cache)
except Exception as e:
msg = ('Could not add SkipConnection "{}" data cache of '
'shape {} to input of shape {}.'
.format(self._name, self._cache.shape, x.shape))
logger.error(msg)
raise RuntimeError(msg) from e
else:
try:
out = tf.add(x, self._cache)
except Exception as e:
msg = ('Could not add SkipConnection "{}" data cache of '
'shape {} to input of shape {}.'
.format(self._name, self._cache.shape, x.shape))
logger.error(msg)
raise RuntimeError(msg) from e
else:
self._cache = None
return out
self._cache = None
return out


class SqueezeAndExcitation(tf.keras.layers.Layer):
Expand Down Expand Up @@ -834,7 +843,8 @@ def __init__(self, name=None):
"""
super().__init__(name=name)

def call(self, x, hi_res_adder):
@staticmethod
def call(x, hi_res_adder):
"""Adds hi-resolution data to the input tensor x in the middle of a
sup3r resolution network.

Expand Down Expand Up @@ -869,7 +879,8 @@ def __init__(self, name=None):
"""
super().__init__(name=name)

def call(self, x, hi_res_feature):
@staticmethod
def call(x, hi_res_feature):
"""Concatenates a hi-resolution feature to the input tensor x in the
middle of a sup3r resolution network.

Expand Down Expand Up @@ -940,7 +951,8 @@ class SigLin(tf.keras.layers.Layer):
y = x + 0.5 where x>=0.5
"""

def call(self, x):
@staticmethod
def call(x):
"""Operates on x with SigLin

Parameters
Expand Down Expand Up @@ -1002,8 +1014,7 @@ def build(self, input_shape):
def _logt(self, x):
if not self.inverse:
return tf.math.log(x + self.adder) * self.scalar
else:
return tf.math.exp(x / self.scalar) - self.adder
return tf.math.exp(x / self.scalar) - self.adder

def call(self, x):
"""Operates on x with (inverse) log transform
Expand All @@ -1021,16 +1032,15 @@ def call(self, x):

if self.idf is None:
return self._logt(x)
else:
out = []
for idf in range(x.shape[-1]):
if idf in self.idf:
out.append(self._logt(x[..., idf:idf + 1]))
else:
out.append(x[..., idf:idf + 1])
out = []
for idf in range(x.shape[-1]):
if idf in self.idf:
out.append(self._logt(x[..., idf:idf + 1]))
else:
out.append(x[..., idf:idf + 1])

out = tf.concat(out, -1, name='concat')
return out
out = tf.concat(out, -1, name='concat')
return out


class UnitConversion(tf.keras.layers.Layer):
Expand Down
Loading