diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index f6524840e1..21c64aa937 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -33,15 +33,16 @@ def __init__( range_end: int = 5, steps: int = 11, num_samples: int = 2, - normalize=False, + normalize: bool = True, ): """ Args: interpolate_epoch_interval: default 20 range_start: default -5 range_end: default 5 + steps: number of step between start and end num_samples: default 2 - normalize: default False + normalize: default True (change image to (0, 1) range) """ super().__init__() self.interpolate_epoch_interval = interpolate_epoch_interval