diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 5b2dafc9b04..64fad524d94 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4996,8 +4996,14 @@ def __init__( ) kwargs = primers if not isinstance(kwargs, Composite): - kwargs = Composite(**kwargs) - self.primers = kwargs + shape = kwargs.pop("shape", None) + device = kwargs.pop("device", None) + if "batch_size" in kwargs.keys(): + extra_kwargs = {"batch_size": kwargs.pop("batch_size")} + else: + extra_kwargs = {} + primers = Composite(kwargs, device=device, shape=shape, **extra_kwargs) + self.primers = primers self.expand_specs = expand_specs if random and default_value: