Skip to content

Commit

Permalink
to_conv1d, figbug
Browse files Browse the repository at this point in the history
  • Loading branch information
xir4n committed Nov 22, 2024
1 parent de41e0a commit 69fb5e1
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 62 deletions.
73 changes: 44 additions & 29 deletions murenn/dtcwt/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class MuReNNDirect(torch.nn.Module):
Args:
J (int): Number of levels (octaves) in the DTCWT decomposition.
Q (int or list): Number of Conv1D filters per octave.
T (int): Conv1D Kernel size.
T (int): The Conv1d kernel size.
J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J.
in_channels (int): Number of channels in the input signal.
padding_mode (str): One of 'symmetric' (default), 'zeros', 'replicate',
Expand All @@ -29,30 +29,29 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric"
J_phi = J
if J_phi < J:
raise ValueError("J_phi must be greater or equal to J")
self.T = [T for j in range(J)]
self.T = T
self.in_channels = in_channels
self.padding_mode = padding_mode
down = []
conv1d = []
self.dtcwt = murenn.DTCWT(
J=J,
padding_mode=padding_mode,
alternate_gh=False,
)

for j in range(J):
conv1d_j = torch.nn.Conv1d(
in_channels=in_channels,
out_channels=self.Q[j]*in_channels,
kernel_size=self.T[j],
kernel_size=self.T,
bias=False,
groups=in_channels,
padding="same",
)
torch.nn.init.normal_(conv1d_j.weight)
conv1d.append(conv1d_j)

down_j = Downsampling(J_phi - j)
down_j = Downsampling(J_phi - j -1)
down.append(down_j)

self.down = torch.nn.ModuleList(down)
Expand Down Expand Up @@ -95,44 +94,58 @@ def to_conv1d(self):
"""

device = self.conv1d[0].weight.data.device
# T the filter length
T = max(self.T)
# J the number of levels of decompostion
J = self.dtcwt.J
T = self.T # Filter length
J = self.dtcwt.J # Number of levels of decomposition
N = 2 ** J * T # Length of the impulse signal

# Generate the impulse signal
N = 2 ** J * T
x = torch.zeros(1, self.in_channels, N).to(device)
x[:, :, N//2] = 1
inv = murenn.IDTCWT(
J = J,
alternate_gh=False
).to(device)

# Initialize the inverse DTCWT
inv = murenn.IDTCWT(J=J, alternate_gh=False).to(device)

# Get DTCWT impulse reponses
phi, psis = self.dtcwt(x)
# Set phi to a zero valued tensor
zeros_phi = phi.new_zeros(1, 1, phi.shape[-1])
ws = []

ws_r, ws_i, ws = [], [], []
for j in range(J):
Wpsi_jr = self.conv1d[j](psis[j].real).reshape(self.in_channels, self.Q[j], -1)
Wpsi_ji = self.conv1d[j](psis[j].imag).reshape(self.in_channels, self.Q[j], -1)
for q in range(self.Q[j]):
Wpsi_jqr = Wpsi_jr[0, q, :].reshape(1,1,-1)
Wpsi_jqi = Wpsi_ji[0, q, :].reshape(1,1,-1)
Wpsi_jq = torch.complex(Wpsi_jqr, Wpsi_jqi)

Wpsis_r = [Wpsi_jqr * (1+0j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]
Wpsis_i = [Wpsi_jqi * (0+1j) if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]
w_r = inv(zeros_phi, Wpsis_r)
w_i = inv(zeros_phi, Wpsis_i)
ws.append(torch.complex(w_r, w_i))
Wpsis = [Wpsi_jq if k == j else psis[k].new_zeros(1,1,psis[k].shape[-1]) for k in range(J)]

ws_r.append(inv(zeros_phi, Wpsis_r))
ws_i.append(inv(zeros_phi, Wpsis_i))
ws.append(inv(zeros_phi, Wpsis))

ws = torch.cat(ws, dim=0)
conv1d = torch.nn.Conv1d(
in_channels=1,
out_channels=ws.shape[0],
kernel_size=N,
bias=False,
padding="same",
)
conv1d.weight.data = torch.nn.parameter.Parameter(ws)
return conv1d
ws_r = torch.cat(ws_r, dim=0)
ws_i = torch.cat(ws_i, dim=0)

def create_conv1d(weight):
conv1d = torch.nn.Conv1d(
in_channels=1,
out_channels=weight.shape[0],
kernel_size=N,
bias=False,
padding="same",
)
conv1d.weight.data = torch.nn.parameter.Parameter(weight)
return conv1d

return {
"complex": create_conv1d(ws),
"real": create_conv1d(ws_r),
"imag": create_conv1d(ws_i),
}


class ModulusStable(torch.autograd.Function):
Expand Down Expand Up @@ -179,10 +192,12 @@ def __init__(self, J_phi):
level1="near_sym_b",
skip_hps=True,
)
self.relu = torch.nn.ReLU()


def forward(self, x):
for j in range(self.J_phi):
x, _ = self.phi(x)
# Normalize the coefficients
x = x[:,:,::2]
return x
return self.relu(x)
2 changes: 1 addition & 1 deletion murenn/dtcwt/transform1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
J=8,
skip_hps=False,
include_scale=False,
alternate_gh=True,
alternate_gh=False,
padding_mode="symmetric",
normalize=True,
):
Expand Down
1 change: 1 addition & 0 deletions tests/test_dtcwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def test_avrg_energy(J, alternate_gh):
P_phi = torch.linalg.norm(phi) ** 2 / phi.shape[-1]
P_Ux = P_Ux + P_phi
for psi in psis:
psi = psi / math.sqrt(2)
Ppsi_j = torch.linalg.norm(torch.abs(psi)) ** 2 / psi.shape[-1]
P_Ux = P_Ux + Ppsi_j
ratio = P_Ux / P_x
Expand Down
38 changes: 6 additions & 32 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,42 +91,16 @@ def test_modulus():

@pytest.mark.parametrize("Q", [1, 2])
@pytest.mark.parametrize("T", [1, 2])
def test_toconv1d_shape(Q, T):
def test_toconv1d(Q, T):
J = 4
tfm = murenn.MuReNNDirect(
J=J,
Q=Q,
T=T,
in_channels=2,
)
conv1d = tfm.to_conv1d()
assert isinstance(conv1d, torch.nn.Conv1d)

@pytest.mark.parametrize("J", range(3))
@pytest.mark.parametrize("alternate_gh", [True, False])
def test_avrg_energy(J, alternate_gh):
'''
Test the power of the signals for normalization case.
'''
tfm = murenn.DTCWT(J=J+1, alternate_gh=alternate_gh, normalize=True)
N = 2**15
x = torch.randn(1 ,1, N)
P_x = torch.linalg.norm(x) ** 2 / x.shape[-1]
P_Ux = 0
phi, psis = tfm(x)
P_phi = torch.linalg.norm(phi) ** 2 / phi.shape[-1]
P_Ux = P_Ux + P_phi
for psi in psis:
Ppsi_j = torch.linalg.norm(torch.abs(psi)) ** 2 / psi.shape[-1]
P_Ux = P_Ux + Ppsi_j
ratio = P_Ux / P_x
assert torch.abs(ratio - 1) <= 0.01


@pytest.mark.parametrize("J_phi", range(3))
def test_down(J_phi):
N = 2**15
x = torch.ones(1, 1, N)
down = murenn.dtcwt.nn.Downsampling(J_phi)
x_down = down(x)
assert torch.allclose(x_down, torch.ones(1, 1, N // 2**J_phi))
x = torch.zeros(1, 1, 2**J)
conv1ds = tfm.to_conv1d()
for k, v in conv1ds.items():
assert isinstance(v, torch.nn.Conv1d)
y = v(x)

0 comments on commit 69fb5e1

Please sign in to comment.