diff --git a/docs/source/using_simulator.rst b/docs/source/using_simulator.rst
index 25688ba3..6b257470 100644
--- a/docs/source/using_simulator.rst
+++ b/docs/source/using_simulator.rst
@@ -399,6 +399,21 @@ instead of manually specifying a ``RPU Configuration``::
 
     tile = AnalogTile(10, 20, rpu_config=TikiTakaEcRamPreset())
 
+Working with half-precision training
+------------------------------------
+
+The simulator supports half-precision training. It can be enabled by setting the
+runtime ``data_type`` of the configuration to ``RPUDataType.HALF``::
+
+    from aihwkit.simulator.configs import InferenceRPUConfig
+    from aihwkit.simulator.parameters.enums import RPUDataType
+
+    rpu_config = InferenceRPUConfig()  # or TorchInferenceRPUConfig()
+    rpu_config.runtime.data_type = RPUDataType.HALF
+
+For more information, see :py:class:`aihwkit.simulator.parameters.enums.RPUDataType`.
+
+
 .. _Gokmen & Haensch 2020: https://www.frontiersin.org/articles/10.3389/fnins.2020.00103/full
 .. _Example 7: https://github.com/IBM/aihwkit/blob/master/examples/07_simple_layer_with_other_devices.py
 .. _Example 8: https://github.com/IBM/aihwkit/blob/master/examples/08_simple_layer_with_tiki_taka.py
diff --git a/examples/34_half_precision_training.py b/examples/34_half_precision_training.py
new file mode 100644
index 00000000..203a8157
--- /dev/null
+++ b/examples/34_half_precision_training.py
@@ -0,0 +1,81 @@
+# type: ignore
+# pylint: disable-all
+# -*- coding: utf-8 -*-
+
+# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
+#
+# Licensed under the MIT license. See LICENSE file in the project root for details.
+
+"""aihwkit example 34: Using half-precision training.
+
+This example demonstrates how to use half-precision training with aihwkit.
+
+"""
+# pylint: disable=invalid-name
+
+import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import datasets, transforms
+from aihwkit.simulator.configs import TorchInferenceRPUConfig
+from aihwkit.nn.conversion import convert_to_analog
+from aihwkit.optim import AnalogSGD
+from aihwkit.simulator.parameters.enums import RPUDataType
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+
+if __name__ == "__main__":
+    model = Net()
+    rpu_config = TorchInferenceRPUConfig()
+    rpu_config.runtime.data_type = RPUDataType.HALF
+    model = convert_to_analog(model, rpu_config)
+    nll_loss = torch.nn.NLLLoss()
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+    )
+    dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset, batch_size=32)
+
+    model = model.to(device=device, dtype=torch.bfloat16)
+    optimizer = AnalogSGD(model.parameters(), lr=0.1)
+    model = model.train()
+
+    pbar = tqdm.tqdm(enumerate(train_loader))
+    for batch_idx, (data, target) in pbar:
+        data, target = data.to(device=device, dtype=torch.bfloat16), target.to(
+            device=device
+        )
+        optimizer.zero_grad()
+        output = model(data)
+        loss = nll_loss(output.float(), target)
+        loss.backward()
+        optimizer.step()
+        pbar.set_description(f"Loss {loss:.4f}")
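
The training script above never evaluates the trained network. The following is only a rough
sketch of how such a check could look; it assumes the ``model``, ``device``, ``transform`` and
imports from ``examples/34_half_precision_training.py`` are in scope, uses the MNIST test split,
and mirrors the ``torch.bfloat16`` input cast applied in the training loop::

    test_set = datasets.MNIST("data", train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=128)

    model = model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # Cast the inputs to the same reduced-precision dtype used during training.
            data = data.to(device=device, dtype=torch.bfloat16)
            target = target.to(device=device)
            # The log-probabilities come back in reduced precision; argmax works on them directly.
            correct += (model(data).argmax(dim=1) == target).sum().item()

    print(f"Test accuracy: {correct / len(test_set):.4f}")

Keeping the accuracy bookkeeping in Python integers avoids any reduced-precision accumulation
error, in the same spirit as the training loop computing the loss on ``output.float()``.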