From 3b002fd684179bff2391975c8a935d6e9a6d23ee Mon Sep 17 00:00:00 2001 From: Xiran Date: Tue, 23 Jul 2024 18:09:44 +0200 Subject: [PATCH] Update nn.py --- murenn/dtcwt/nn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/murenn/dtcwt/nn.py b/murenn/dtcwt/nn.py index 89fc9de..58b4154 100644 --- a/murenn/dtcwt/nn.py +++ b/murenn/dtcwt/nn.py @@ -9,7 +9,7 @@ class MuReNNDirect(torch.nn.Module): """ Args: J (int): Number of levels (octaves) in the DTCWT decomposition. - Q (int, dict or list): Number of Conv1D filters per octave. + Q (int or list): Number of Conv1D filters per octave. T (int): Conv1D Kernel size multiplier. The Conv1d kernel size at scale j is equal to T * Q[j] where Q[j] is the number of filters. J_phi (int): Number of levels of downsampling. Stride is 2**J_phi. Default is J. @@ -21,11 +21,11 @@ def __init__(self, *, J, Q, T, in_channels, J_phi=None, padding_mode="symmetric" super().__init__() if isinstance(Q, int): self.Q = [Q for j in range(J)] - elif isinstance(Q, (dict, list)): + elif isinstance(Q, list): assert len(Q) == J self.Q = Q else: - raise TypeError(f"Q must to be int, dict or list, got {type(Q)}") + raise TypeError(f"Q must to be int or list, got {type(Q)}") if J_phi is None: J_phi = J if J_phi < J: