Skip to content

Commit

Permalink
Updated inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylammie committed Aug 7, 2024
1 parent 69e524f commit ecab267
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/aihwkit/simulator/tiles/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def init_mapping_scales(self) -> None:
self.set_mapping_scales(mapping_scales)

@no_grad()
def _forward_drift_readout_tensor(self, reset_if: bool = False) -> Optional[Tensor]:
def _forward_drift_readout_tensor(self, reset_if: bool = False, is_perfect=False) -> Optional[Tensor]:
"""Perform a forward pass using the drift read-out tensor.
Args:
Expand Down Expand Up @@ -126,9 +126,16 @@ def _forward_drift_readout_tensor(self, reset_if: bool = False) -> Optional[Tens

# We need to take the bias as a common column here, also we do
# not want to use indexed.
return self.tile.forward(
is_perfect_state = self.rpu_config.forward.is_perfect
tile_rpu_config = self.rpu_config
tile_rpu_config.forward.is_perfect = True
self.tile.set_config(tile_rpu_config)
output = self.tile.forward(
self.drift_readout_tensor, False, self.in_trans, self.out_trans, True, self.non_blocking
)
tile_rpu_config.forward.is_perfect = is_perfect_state
self.tile.set_config(tile_rpu_config)
return output

@no_grad()
def program_weights(
Expand Down Expand Up @@ -184,7 +191,7 @@ def program_weights(
hasattr(self.rpu_config, "drift_compensation")
and self.rpu_config.drift_compensation is not None
):
forward_output = self._forward_drift_readout_tensor(True)
forward_output = self._forward_drift_readout_tensor(True, is_perfect=True)
self.drift_baseline = self.rpu_config.drift_compensation.init_baseline(forward_output)

@no_grad()
Expand Down

0 comments on commit ecab267

Please sign in to comment.