diff --git a/src/aihwkit/simulator/tiles/inference.py b/src/aihwkit/simulator/tiles/inference.py index b16ce862..94865e22 100644 --- a/src/aihwkit/simulator/tiles/inference.py +++ b/src/aihwkit/simulator/tiles/inference.py @@ -98,7 +98,7 @@ def init_mapping_scales(self) -> None: self.set_mapping_scales(mapping_scales) @no_grad() - def _forward_drift_readout_tensor(self, reset_if: bool = False) -> Optional[Tensor]: + def _forward_drift_readout_tensor(self, reset_if: bool = False, is_perfect=False) -> Optional[Tensor]: """Perform a forward pass using the drift read-out tensor. Args: @@ -126,9 +126,16 @@ def _forward_drift_readout_tensor(self, reset_if: bool = False) -> Optional[Tens # We need to take the bias as a common column here, also we do # not want to use indexed. - return self.tile.forward( + is_perfect_state = self.rpu_config.forward.is_perfect + tile_rpu_config = self.rpu_config + tile_rpu_config.forward.is_perfect = True + self.tile.set_config(tile_rpu_config) + output = self.tile.forward( self.drift_readout_tensor, False, self.in_trans, self.out_trans, True, self.non_blocking ) + tile_rpu_config.forward.is_perfect = is_perfect_state + self.tile.set_config(tile_rpu_config) + return output @no_grad() def program_weights( @@ -184,7 +191,7 @@ def program_weights( hasattr(self.rpu_config, "drift_compensation") and self.rpu_config.drift_compensation is not None ): - forward_output = self._forward_drift_readout_tensor(True) + forward_output = self._forward_drift_readout_tensor(True, is_perfect=True) self.drift_baseline = self.rpu_config.drift_compensation.init_baseline(forward_output) @no_grad()