Skip to content

Commit

Permalink
Update nn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xir4n committed Jul 23, 2024
1 parent 001ad4b commit 3b002fd
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions murenn/dtcwt/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class MuReNNDirect(torch.nn.Module):
"""
Args:
J (int): Number of levels (octaves) in the DTCWT decomposition.
Q (int, dict or list): Number of Conv1D filters per octave.
Q (int or list): Number of Conv1D filters per octave.
T (int): Conv1D Kernel size multiplier. The Conv1d kernel size at scale j is equal to
T * Q[j] where Q[j] is the number of filters.
J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J.
Expand All @@ -21,11 +21,11 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric"
super().__init__()
if isinstance(Q, int):
self.Q = [Q for j in range(J)]
elif isinstance(Q, (dict, list)):
elif isinstance(Q, list):
assert len(Q) == J
self.Q = Q
else:
raise TypeError(f"Q must to be int, dict or list, got {type(Q)}")
raise TypeError(f"Q must to be int or list, got {type(Q)}")
if J_phi is None:
J_phi = J
if J_phi < J:
Expand Down

0 comments on commit 3b002fd

Please sign in to comment.