diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index b8d7ffb90..9b98be9d7 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1069,10 +1069,12 @@ def __init__( except Exception: raise ValueError(self._OUT_KEY_ERR) - if type(module) is type or not callable(module): + if type(module) is type or (method is None and not callable(module)): raise ValueError( f"Module {module} if type {type(module)} is not callable. " - f"Typical accepted types are nn.Module or TensorDictModule." + f"Typical accepted types are nn.Module or TensorDictModule. " + f"If you need to call a specific method from your module, pass the " + f"`method` keyword argument to the TensorDictModule constructor." ) self.out_keys = out_keys self.in_keys = in_keys