ENH ModulusStable + fix_length (#46)
xir4n authored May 31, 2024
1 parent 5b363eb commit fbc1445
Showing 7 changed files with 146 additions and 30 deletions.
68 changes: 46 additions & 22 deletions murenn/dtcwt/nn.py
@@ -1,12 +1,15 @@
import torch
import murenn
import math

from .utils import fix_length


class MuReNNDirect(torch.nn.Module):
"""
Args:
J (int): Number of levels of DTCWT decomposition.
Q (int): Number of Conv1D filters at each level.
J (int): Number of levels (octaves) in the DTCWT decomposition.
Q (int): Number of Conv1D filters per octave.
in_channels (int): Number of channels in the input signal.
padding_mode (str): One of 'symmetric' (default), 'zeros', 'replicate',
and 'circular'. Padding scheme for the DTCWT decomposition.
@@ -40,36 +43,57 @@ def __init__(self, *, J, Q, T, in_channels, padding_mode="symmetric"):
)
torch.nn.init.normal_(conv1d_j.weight)
conv1d.append(conv1d_j)

self.down = torch.nn.ModuleList(down)
self.conv1d = torch.nn.ParameterList(conv1d)


def forward(self, x):
"""
Args:
x (PyTorch tensor): Input data. Should be a tensor of shape
`(B, C, T)` where B is the batch size, C is the number of
channels and T is the number of time samples.
Note that T must be a multiple of 2**J, where J is the number
of wavelet scales (see documentation of MuReNNDirect constructor).
x (PyTorch tensor): A tensor of shape `(B, C, T)`, where B is the batch size,
C is the number of channels, and T is the length of the signal.
Returns:
y (PyTorch tensor): A tensor of shape `(B, C, Q, J, T/(2**J))`
y (PyTorch tensor): A tensor of shape `(B, C, Q, J, T_out)`
"""
assert self.C == x.shape[1]
_, bps = self.dtcwt(x)
ys = []

lp, bps = self.dtcwt(x)
output = []
for j in range(self.dtcwt.J):
Wx_r = self.conv1d[j](bps[j].real)
Wx_i = self.conv1d[j](bps[j].imag)
Ux = Wx_r ** 2 + Wx_i ** 2
y_j, _ = self.down[j](Ux)

B, _, N = y_j.shape
Wx_j_r = self.conv1d[j](bps[j].real)
Wx_j_i = self.conv1d[j](bps[j].imag)
UWx_j = ModulusStable.apply(Wx_j_r, Wx_j_i)
UWx_j, _ = self.down[j](UWx_j)
B, _, N = UWx_j.shape
# reshape from (B, C*Q, N) to (B, C, Q, N)
y_j = y_j.view(B, self.C, self.Q, N)
ys.append(y_j)
UWx_j = UWx_j.view(B, self.C, self.Q, N)
output.append(UWx_j)
return torch.stack(output, dim=3)
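
A minimal usage sketch of the layer defined above (illustrative values; the murenn.MuReNNDirect export and the role of T as the Conv1D kernel length are assumptions inferred from this file and the tests below):

import torch
import murenn

x = torch.randn(2, 3, 2**8)  # (B, C, T_in): batch of 2, 3 channels, 256 samples
layer = murenn.MuReNNDirect(J=4, Q=6, T=16, in_channels=3)  # T assumed to be the kernel length
y = layer(x)
print(y.shape)  # expected (B, C, Q, J, T_out) = (2, 3, 6, 4, T_out), per the docstring above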


class ModulusStable(torch.autograd.Function):
"""Stable complex modulus
Computes the modulus of a complex number given as real and imaginary parts,
in a way that is numerically stable for inputs close to zero: the backward
pass returns a zero gradient instead of NaN where the modulus vanishes.
Adapted from Kymatio.
"""
@staticmethod
def forward(ctx, x_r, x_i):
output = (x_r ** 2 + x_i ** 2).sqrt()
ctx.save_for_backward(x_r, x_i, output)
return output

y = torch.stack(ys, dim=3)
return y
@staticmethod
def backward(ctx, grad_output):
x_r, x_i, output = ctx.saved_tensors
dxr, dxi = None, None
if ctx.needs_input_grad[0]:
dxr = x_r.mul(grad_output).div(output)
dxi = x_i.mul(grad_output).div(output)
dxr.masked_fill_(output == 0, 0)
dxi.masked_fill_(output == 0, 0)
return dxr, dxi
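
A short sketch of ModulusStable in use, mirroring test_modulus below: the forward pass equals sqrt(x_r**2 + x_i**2), and the backward pass returns a zero gradient (instead of NaN) wherever the modulus is zero:

import torch
from murenn.dtcwt.nn import ModulusStable

x_r = torch.zeros(4, requires_grad=True)
x_i = torch.zeros(4, requires_grad=True)
Ux = ModulusStable.apply(x_r, x_i)  # forward: sqrt(x_r**2 + x_i**2)
Ux.sum().backward()
print(x_r.grad)  # zeros; a plain torch.sqrt(x_r**2 + x_i**2) would propagate NaN gradients here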
5 changes: 5 additions & 0 deletions murenn/dtcwt/transform1d.py
@@ -4,6 +4,7 @@

from murenn.dtcwt.lowlevel import prep_filt
from murenn.dtcwt.transform_funcs import FWD_J1, FWD_J2PLUS, INV_J1, INV_J2PLUS
from .utils import fix_length


class DTCWT(torch.nn.Module):
@@ -226,6 +227,7 @@ def __init__(
alternate_gh=True,
padding_mode="symmetric",
normalize=True,
length=None,
):
if padding_mode != "symmetric":
raise NotImplementedError(
@@ -241,6 +243,7 @@ def __init__(
padding_mode=padding_mode,
normalize=normalize,
)
self.length = length

def forward(self, yl, yh):
"""
@@ -301,4 +304,6 @@ def forward(self, yl, yh):
x_phi = INV_J1.apply(
x_phi, x_psi_r, x_psi_i, self.g0o, self.g1o, self.padding_mode
)
if self.length:
x_phi = fix_length(x_phi, size=self.length)
return x_phi
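
The two lines added above simply clamp the reconstruction to a caller-specified length. A standalone sketch of that step, with x_phi and the target length as stand-ins for the values inside forward:

import torch
from murenn.dtcwt.utils import fix_length

x_phi = torch.randn(2, 3, 2**6 + 1)   # stand-in reconstruction, one sample too long
x_phi = fix_length(x_phi, size=2**6)  # trims (or zero-pads) the last dimension
print(x_phi.shape)                    # torch.Size([2, 3, 64])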
2 changes: 1 addition & 1 deletion murenn/dtcwt/transform_funcs.py
@@ -47,7 +47,7 @@ def forward(ctx, x, h0, h1, skip_hps, padding_mode):
hi = torch.nn.functional.conv1d(
pad_(x, h1, padding_mode), h1_rep, groups=ch)

# Return low-pass (x_phi), real and imaginary part of high-pass (x_psi)
return lo, hi[:,:,::2], hi[:,:,1::2]

@staticmethod
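A toy sketch of the even/odd de-interleaving used in the return statement above (the index convention is taken directly from that line):

import torch

hi = torch.arange(8.0).reshape(1, 1, 8)  # stand-in for the high-pass filter output
x_psi_r = hi[:, :, ::2]                  # even samples -> real part: [0., 2., 4., 6.]
x_psi_i = hi[:, :, 1::2]                 # odd samples  -> imaginary part: [1., 3., 5., 7.]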
42 changes: 41 additions & 1 deletion murenn/dtcwt/utils.py
@@ -63,4 +63,44 @@ def int_to_mode(mode):
elif mode == 3:
return 'circular'
else:
raise ValueError("Unknown pad type: {}".format(mode))

def fix_length(x, *, size, **kwargs):
"""Fix the length of a tensor ``x`` to exactly ``size`` along the last dimension.
If ``x.shape[-1] < size``, pad according to the provided kwargs.
By default, ``x`` is padded with trailing zeros.
Parameters
----------
x : torch.Tensor
tensor to be length-adjusted
size : int >= 0 [scalar]
desired length
**kwargs : additional keyword arguments
Parameters to ``torch.nn.functional.pad``
Returns
-------
x_fixed : torch.Tensor [shape=x.shape]
``x`` either trimmed or padded to length ``size``
along the last dimension.
See Also
--------
torch.nn.functional.pad
Adapted from librosa.
"""
kwargs.setdefault("mode", "constant")
n = x.shape[-1]
if n > size:
slices = [slice(None)] * x.ndim
slices[-1] = slice(0, size)
return x[tuple(slices)]

elif n < size:
length = size - n
return torch.nn.functional.pad(x, (0, length), **kwargs)
return x
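
A quick usage sketch of fix_length (values are illustrative; extra keyword arguments are forwarded to torch.nn.functional.pad):

import torch
from murenn.dtcwt.utils import fix_length

x = torch.arange(6.0)
print(fix_length(x, size=4))             # tensor([0., 1., 2., 3.])  -- trimmed
print(fix_length(x, size=8))             # zero-padded on the right to length 8
print(fix_length(x, size=8, value=1.0))  # padded with ones instead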
3 changes: 1 addition & 2 deletions tests/test_dtcwt.py
@@ -59,10 +59,9 @@ def test_inv(level1, qshift, J, T, alternate_gh, normalize):
include_scale=False,
padding_mode="symmetric",
normalize=normalize,
length=T,
)
X_rec = inv(lp, bp)
if T % 2 != 0:
X_rec = X_rec[:, :, :-1]
torch.testing.assert_close(Xt, X_rec)


41 changes: 37 additions & 4 deletions tests/test_nn.py
@@ -2,8 +2,15 @@
import torch
import murenn

from murenn.dtcwt.nn import ModulusStable

@pytest.mark.parametrize("J", list(range(1, 3)))

if torch.cuda.is_available():
dev = torch.device("cuda")
else:
dev = torch.device("cpu")

@pytest.mark.parametrize("J", list(range(2, 4)))
@pytest.mark.parametrize("Q", [3, 4])
@pytest.mark.parametrize("T", [8, 16])
@pytest.mark.parametrize("padding_mode", ["symmetric", "zeros"])
@@ -38,11 +45,11 @@ def test_direct_diff():
assert conv1d.weight.grad != None


@pytest.mark.parametrize("J", list(range(1, 3)))
@pytest.mark.parametrize("Q", [3, 4])
@pytest.mark.parametrize("T", [8, 16])
@pytest.mark.parametrize("N", list(range(5)))
def test_multi_layers(J, Q, T, N):
def test_multi_layers(Q, T, N):
J = 2
B, C, L = 2, 3, 2**J+N
x = torch.zeros(B, C, L)
for i in range(3):
@@ -53,4 +60,30 @@ def test_multi_layers(J, Q, T, N):
T=T,
in_channels=x.shape[1],
)
x = layer_i(x)


def test_modulus():
# check the value
x_r = torch.randn(2, 2, 2**5, device=dev, requires_grad=True)
x_i = torch.randn(2, 2, 2**5, device=dev, requires_grad=True)
Ux = ModulusStable.apply(x_r, x_i)
assert torch.allclose(Ux, torch.sqrt(x_r ** 2 + x_i ** 2))
# check the gradient
loss = torch.sum(Ux)
loss.backward()
Ux2 = Ux.clone()
x_r2 = x_r.clone()
x_i2 = x_i.clone()
xr_grad = x_r2 / Ux2
xi_grad = x_i2 / Ux2
assert torch.allclose(x_r.grad, xr_grad, atol = 1e-4)
assert torch.allclose(x_i.grad, xi_grad, atol = 1e-4)
# Test the differentiation with a vector made of zeros
x0r = torch.zeros(2, 2, 2**5, device=dev, requires_grad=True)
x0i = torch.zeros(2, 2, 2**5, device=dev, requires_grad=True)
Ux0 = ModulusStable.apply(x0r, x0i)
loss0 = torch.sum(Ux0)
loss0.backward()
assert torch.max(torch.abs(x0r.grad)) <= 1e-7
assert torch.max(torch.abs(x0i.grad)) <= 1e-7
15 changes: 15 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,15 @@
import pytest

from murenn.dtcwt.utils import *

@pytest.mark.parametrize("y", [torch.ones((16,)), torch.ones((16, 16))])
@pytest.mark.parametrize("m", [-5, 0, 5])
def test_fix_length(y, m):
n = m + y.shape[-1]
y_out = fix_length(y, size=n)
assert y_out.shape[-1] == n
eq_slice = [slice(None)] * y.ndim
eq_slice[-1] = slice(min(n, y.shape[-1]))
if n > y.shape[-1]:
    assert torch.allclose(y, y_out[tuple(eq_slice)])
else:
    assert torch.allclose(y_out, y[tuple(eq_slice)])
