diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 89fc9de..58b4154 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -9,7 +9,7 @@ class MuReNNDirect(torch.nn.Module): """ Args: J (int): Number of levels (octaves) in the DTCWT decomposition. - Q (int, dict or list): Number of Conv1D filters per octave. + Q (int or list): Number of Conv1D filters per octave. T (int): Conv1D Kernel size multiplier. The Conv1d kernel size at scale j is equal to T * Q[j] where Q[j] is the number of filters. J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J. @@ -21,11 +21,11 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric" super().__init__() if isinstance(Q, int): self.Q = [Q for j in range(J)] - elif isinstance(Q, (dict, list)): + elif isinstance(Q, list): assert len(Q) == J self.Q = Q else: - raise TypeError(f"Q must to be int, dict or list, got {type(Q)}") + raise TypeError(f"Q must to be int or list, got {type(Q)}") if J_phi is None: J_phi = J if J_phi < J: