diff --git a/examples/pytorch/image-classification/mnist.ipynb b/examples/pytorch/image-classification/mnist.ipynb new file mode 100644 index 0000000000..c01197cbff --- /dev/null +++ b/examples/pytorch/image-classification/mnist.ipynb @@ -0,0 +1,779 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# PyTorch DDP Fashion MNIST Training Example" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "This example demonstrates how to train a convolutional neural network to classify images using the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset and [PyTorch Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html).\n", + "\n", + "This notebook walks you through running that example locally and shows how to easily scale PyTorch DDP across multiple nodes with Kubeflow TrainJob." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Install the Kubeflow SDK" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You need to install the Kubeflow SDK to interact with Kubeflow APIs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO (astefanutti): Change to the Kubeflow SDK when it's available.\n", + "!pip install git+https://github.com/kubeflow/training-operator.git@master#subdirectory=sdk_v2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Install the PyTorch dependencies" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You also need to install PyTorch and Torchvision to be able to run the example locally:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install torch==2.5.1\n", + "!pip install torchvision==0.20.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the training function" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def train_fashion_mnist():\n", + "    import os\n", + "\n", + "    import torch\n", + "    import torch.distributed as dist\n", + "    import torch.nn.functional as F\n", + "    from torch import nn\n", + "    from torch.utils.data import DataLoader, DistributedSampler\n", + "    from torchvision import datasets, transforms\n", + "\n", + "    # Define the PyTorch CNN model to be trained\n", + "    class Net(nn.Module):\n", + "        def __init__(self):\n", + "            super(Net, self).__init__()\n", + "            self.conv1 = nn.Conv2d(1, 20, 5, 1)\n", + "            self.conv2 = nn.Conv2d(20, 50, 5, 1)\n", + "            self.fc1 = nn.Linear(4 * 4 * 50, 500)\n", + "            self.fc2 = nn.Linear(500, 10)\n", + "\n", + "        def forward(self, x):\n", + "            x = F.relu(self.conv1(x))\n", + "            x = F.max_pool2d(x, 2, 2)\n", + "            x = F.relu(self.conv2(x))\n", + "            x = F.max_pool2d(x, 2, 2)\n", + "            x = x.view(-1, 4 * 4 * 50)\n", + "            x = F.relu(self.fc1(x))\n", + "            x = self.fc2(x)\n", + "            return F.log_softmax(x, dim=1)\n", + "\n", + "    # Use NCCL if a GPU is available, otherwise use Gloo as the communication backend\n", + "    device, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n", + "\n", + "    print(f\"Using Device: {device}, Backend: {backend}\")\n", + "\n", + "    # Setup PyTorch Distributed\n", + "    local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n",
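+ "    # The default \"env://\" init method reads MASTER_ADDR, MASTER_PORT, RANK\n", + "    # and WORLD_SIZE from the environment, set manually in the local dry-run\n", + "    # below and automatically by Kubeflow Trainer in the TrainJob Pods\n",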
+ "    dist.init_process_group(backend=backend)\n", + "\n", + "    print(\n", + "        \"Distributed Training for WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}\".format(\n", + "            dist.get_world_size(),\n", + "            dist.get_rank(),\n", + "            local_rank,\n", + "        )\n", + "    )\n", + "\n", + "    # Create the model and load it into the device\n", + "    device = torch.device(f\"{device}:{local_rank}\")\n", + "    model = nn.parallel.DistributedDataParallel(Net().to(device))\n", + "\n", + "    # Retrieve the Fashion-MNIST dataset\n", + "    if local_rank == 0:\n", + "        # Only download the dataset from local rank 0\n", + "        dataset = datasets.FashionMNIST(\n", + "            \"./data\",\n", + "            train=True,\n", + "            download=True,\n", + "            transform=transforms.Compose([transforms.ToTensor()]),\n", + "        )\n", + "        dist.barrier()\n", + "    else:\n", + "        # Wait for local rank 0 to complete downloading the dataset and load it\n", + "        dist.barrier()\n", + "        dataset = datasets.FashionMNIST(\n", + "            \"./data\",\n", + "            train=True,\n", + "            download=False,\n", + "            transform=transforms.Compose([transforms.ToTensor()]),\n", + "        )\n", + "\n", + "    # Shard the dataset across workers\n", + "    train_loader = DataLoader(\n", + "        dataset,\n", + "        batch_size=100,\n", + "        sampler=DistributedSampler(dataset),\n", + "        pin_memory=torch.cuda.is_available(),\n", + "    )\n", + "\n", + "    # Setup the optimization loop\n", + "    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)\n", + "\n", + "    # TODO(astefanutti): add parameters to the training function\n", + "    for epoch in range(1, 5):\n", + "        model.train()\n", + "\n", + "        # Iterate over mini-batches from the training set\n", + "        for batch_idx, (inputs, labels) in enumerate(train_loader):\n", + "            # Copy the data to the GPU device if available\n", + "            inputs, labels = inputs.to(device), labels.to(device)\n", + "            # Forward pass\n", + "            outputs = model(inputs)\n", + "            loss = F.nll_loss(outputs, labels)\n", + "            # Backward pass\n", + "            optimizer.zero_grad()\n", + "            loss.backward()\n", + "            optimizer.step()\n", + "\n", + "            if batch_idx % 10 == 0 and dist.get_rank() == 0:\n", + "                print(\n", + "                    \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n", + "                        epoch,\n", + "                        batch_idx * len(inputs),\n", + "                        len(train_loader.dataset),\n", + "                        100.0 * batch_idx / len(train_loader),\n", + "                        loss.item(),\n", + "                    )\n", + "                )\n", + "\n", + "    # Wait for the distributed training to complete\n", + "    dist.barrier()\n", + "    if dist.get_rank() == 0:\n", + "        print(\"Training is finished\")\n", + "\n", + "    # Finally, clean up PyTorch distributed\n", + "    dist.destroy_process_group()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dry-run the training locally" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using Device: cpu, Backend: gloo\n", + "Distributed Training for WORLD_SIZE: 1, RANK: 0, LOCAL_RANK: 0\n", + "Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.316241\n", + "Train Epoch: 1 [1000/60000 (2%)]\tLoss: 1.992070\n", + "Train Epoch: 1 [2000/60000 (3%)]\tLoss: 1.172966\n", + "Train Epoch: 1 [3000/60000 (5%)]\tLoss: 1.099289\n", + "Train Epoch: 1 [4000/60000 (7%)]\tLoss: 0.905155\n", + "Train Epoch: 1 [5000/60000 (8%)]\tLoss: 0.924361\n", + "Train Epoch: 1 [6000/60000 (10%)]\tLoss: 0.826742\n", + "Train Epoch: 1 [7000/60000 (12%)]\tLoss: 0.711170\n", + "Train Epoch: 1 [8000/60000 (13%)]\tLoss: 0.624539\n", + "Train Epoch: 1 [9000/60000 (15%)]\tLoss: 0.634144\n", + "Train Epoch: 1 [10000/60000 
(17%)]\tLoss: 0.564117\n", + "Train Epoch: 1 [11000/60000 (18%)]\tLoss: 0.678713\n", + "Train Epoch: 1 [12000/60000 (20%)]\tLoss: 0.661701\n", + "Train Epoch: 1 [13000/60000 (22%)]\tLoss: 0.610292\n", + "Train Epoch: 1 [14000/60000 (23%)]\tLoss: 0.597939\n", + "Train Epoch: 1 [15000/60000 (25%)]\tLoss: 0.649877\n", + "Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.479480\n", + "Train Epoch: 1 [17000/60000 (28%)]\tLoss: 0.511836\n", + "Train Epoch: 1 [18000/60000 (30%)]\tLoss: 0.409321\n", + "Train Epoch: 1 [19000/60000 (32%)]\tLoss: 0.504975\n", + "Train Epoch: 1 [20000/60000 (33%)]\tLoss: 0.509290\n", + "Train Epoch: 1 [21000/60000 (35%)]\tLoss: 0.444297\n", + "Train Epoch: 1 [22000/60000 (37%)]\tLoss: 0.403273\n", + "Train Epoch: 1 [23000/60000 (38%)]\tLoss: 0.651570\n", + "Train Epoch: 1 [24000/60000 (40%)]\tLoss: 0.450566\n", + "Train Epoch: 1 [25000/60000 (42%)]\tLoss: 0.318915\n", + "Train Epoch: 1 [26000/60000 (43%)]\tLoss: 0.535619\n", + "Train Epoch: 1 [27000/60000 (45%)]\tLoss: 0.380061\n", + "Train Epoch: 1 [28000/60000 (47%)]\tLoss: 0.372009\n", + "Train Epoch: 1 [29000/60000 (48%)]\tLoss: 0.517455\n", + "Train Epoch: 1 [30000/60000 (50%)]\tLoss: 0.493992\n", + "Train Epoch: 1 [31000/60000 (52%)]\tLoss: 0.442904\n", + "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.378228\n", + "Train Epoch: 1 [33000/60000 (55%)]\tLoss: 0.324654\n", + "Train Epoch: 1 [34000/60000 (57%)]\tLoss: 0.502618\n", + "Train Epoch: 1 [35000/60000 (58%)]\tLoss: 0.583709\n", + "Train Epoch: 1 [36000/60000 (60%)]\tLoss: 0.523756\n", + "Train Epoch: 1 [37000/60000 (62%)]\tLoss: 0.493982\n", + "Train Epoch: 1 [38000/60000 (63%)]\tLoss: 0.496740\n", + "Train Epoch: 1 [39000/60000 (65%)]\tLoss: 0.401850\n", + "Train Epoch: 1 [40000/60000 (67%)]\tLoss: 0.274524\n", + "Train Epoch: 1 [41000/60000 (68%)]\tLoss: 0.493281\n", + "Train Epoch: 1 [42000/60000 (70%)]\tLoss: 0.408177\n", + "Train Epoch: 1 [43000/60000 (72%)]\tLoss: 0.353222\n", + "Train Epoch: 1 [44000/60000 (73%)]\tLoss: 0.436191\n", + "Train Epoch: 1 [45000/60000 (75%)]\tLoss: 0.256211\n", + "Train Epoch: 1 [46000/60000 (77%)]\tLoss: 0.482434\n", + "Train Epoch: 1 [47000/60000 (78%)]\tLoss: 0.406141\n", + "Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.530664\n", + "Train Epoch: 1 [49000/60000 (82%)]\tLoss: 0.433257\n", + "Train Epoch: 1 [50000/60000 (83%)]\tLoss: 0.399889\n", + "Train Epoch: 1 [51000/60000 (85%)]\tLoss: 0.504735\n", + "Train Epoch: 1 [52000/60000 (87%)]\tLoss: 0.383247\n", + "Train Epoch: 1 [53000/60000 (88%)]\tLoss: 0.379662\n", + "Train Epoch: 1 [54000/60000 (90%)]\tLoss: 0.344262\n", + "Train Epoch: 1 [55000/60000 (92%)]\tLoss: 0.372070\n", + "Train Epoch: 1 [56000/60000 (93%)]\tLoss: 0.409867\n", + "Train Epoch: 1 [57000/60000 (95%)]\tLoss: 0.451042\n", + "Train Epoch: 1 [58000/60000 (97%)]\tLoss: 0.325030\n", + "Train Epoch: 1 [59000/60000 (98%)]\tLoss: 0.255318\n", + "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.452428\n", + "Train Epoch: 2 [1000/60000 (2%)]\tLoss: 0.421046\n", + "Train Epoch: 2 [2000/60000 (3%)]\tLoss: 0.241406\n", + "Train Epoch: 2 [3000/60000 (5%)]\tLoss: 0.295971\n", + "Train Epoch: 2 [4000/60000 (7%)]\tLoss: 0.426782\n", + "Train Epoch: 2 [5000/60000 (8%)]\tLoss: 0.373231\n", + "Train Epoch: 2 [6000/60000 (10%)]\tLoss: 0.427039\n", + "Train Epoch: 2 [7000/60000 (12%)]\tLoss: 0.362823\n", + "Train Epoch: 2 [8000/60000 (13%)]\tLoss: 0.427223\n", + "Train Epoch: 2 [9000/60000 (15%)]\tLoss: 0.503178\n", + "Train Epoch: 2 [10000/60000 (17%)]\tLoss: 0.359559\n", + "Train Epoch: 2 [11000/60000 (18%)]\tLoss: 0.349066\n", + "Train 
Epoch: 2 [12000/60000 (20%)]\tLoss: 0.329017\n", + "Train Epoch: 2 [13000/60000 (22%)]\tLoss: 0.405490\n", + "Train Epoch: 2 [14000/60000 (23%)]\tLoss: 0.517647\n", + "Train Epoch: 2 [15000/60000 (25%)]\tLoss: 0.360733\n", + "Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.336746\n", + "Train Epoch: 2 [17000/60000 (28%)]\tLoss: 0.306503\n", + "Train Epoch: 2 [18000/60000 (30%)]\tLoss: 0.246597\n", + "Train Epoch: 2 [19000/60000 (32%)]\tLoss: 0.389185\n", + "Train Epoch: 2 [20000/60000 (33%)]\tLoss: 0.341517\n", + "Train Epoch: 2 [21000/60000 (35%)]\tLoss: 0.244054\n", + "Train Epoch: 2 [22000/60000 (37%)]\tLoss: 0.314584\n", + "Train Epoch: 2 [23000/60000 (38%)]\tLoss: 0.470367\n", + "Train Epoch: 2 [24000/60000 (40%)]\tLoss: 0.310524\n", + "Train Epoch: 2 [25000/60000 (42%)]\tLoss: 0.255482\n", + "Train Epoch: 2 [26000/60000 (43%)]\tLoss: 0.389009\n", + "Train Epoch: 2 [27000/60000 (45%)]\tLoss: 0.298264\n", + "Train Epoch: 2 [28000/60000 (47%)]\tLoss: 0.263448\n", + "Train Epoch: 2 [29000/60000 (48%)]\tLoss: 0.264673\n", + "Train Epoch: 2 [30000/60000 (50%)]\tLoss: 0.409005\n", + "Train Epoch: 2 [31000/60000 (52%)]\tLoss: 0.317575\n", + "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.340074\n", + "Train Epoch: 2 [33000/60000 (55%)]\tLoss: 0.213986\n", + "Train Epoch: 2 [34000/60000 (57%)]\tLoss: 0.429456\n", + "Train Epoch: 2 [35000/60000 (58%)]\tLoss: 0.464485\n", + "Train Epoch: 2 [36000/60000 (60%)]\tLoss: 0.371509\n", + "Train Epoch: 2 [37000/60000 (62%)]\tLoss: 0.307564\n", + "Train Epoch: 2 [38000/60000 (63%)]\tLoss: 0.409460\n", + "Train Epoch: 2 [39000/60000 (65%)]\tLoss: 0.351242\n", + "Train Epoch: 2 [40000/60000 (67%)]\tLoss: 0.234072\n", + "Train Epoch: 2 [41000/60000 (68%)]\tLoss: 0.443169\n", + "Train Epoch: 2 [42000/60000 (70%)]\tLoss: 0.349876\n", + "Train Epoch: 2 [43000/60000 (72%)]\tLoss: 0.272572\n", + "Train Epoch: 2 [44000/60000 (73%)]\tLoss: 0.345878\n", + "Train Epoch: 2 [45000/60000 (75%)]\tLoss: 0.246817\n", + "Train Epoch: 2 [46000/60000 (77%)]\tLoss: 0.420070\n", + "Train Epoch: 2 [47000/60000 (78%)]\tLoss: 0.313249\n", + "Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.440062\n", + "Train Epoch: 2 [49000/60000 (82%)]\tLoss: 0.320057\n", + "Train Epoch: 2 [50000/60000 (83%)]\tLoss: 0.253204\n", + "Train Epoch: 2 [51000/60000 (85%)]\tLoss: 0.400022\n", + "Train Epoch: 2 [52000/60000 (87%)]\tLoss: 0.271899\n", + "Train Epoch: 2 [53000/60000 (88%)]\tLoss: 0.333966\n", + "Train Epoch: 2 [54000/60000 (90%)]\tLoss: 0.305328\n", + "Train Epoch: 2 [55000/60000 (92%)]\tLoss: 0.333588\n", + "Train Epoch: 2 [56000/60000 (93%)]\tLoss: 0.359272\n", + "Train Epoch: 2 [57000/60000 (95%)]\tLoss: 0.458515\n", + "Train Epoch: 2 [58000/60000 (97%)]\tLoss: 0.303549\n", + "Train Epoch: 2 [59000/60000 (98%)]\tLoss: 0.250087\n", + "Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.304483\n", + "Train Epoch: 3 [1000/60000 (2%)]\tLoss: 0.365609\n", + "Train Epoch: 3 [2000/60000 (3%)]\tLoss: 0.194465\n", + "Train Epoch: 3 [3000/60000 (5%)]\tLoss: 0.229713\n", + "Train Epoch: 3 [4000/60000 (7%)]\tLoss: 0.378436\n", + "Train Epoch: 3 [5000/60000 (8%)]\tLoss: 0.331190\n", + "Train Epoch: 3 [6000/60000 (10%)]\tLoss: 0.352867\n", + "Train Epoch: 3 [7000/60000 (12%)]\tLoss: 0.319846\n", + "Train Epoch: 3 [8000/60000 (13%)]\tLoss: 0.323834\n", + "Train Epoch: 3 [9000/60000 (15%)]\tLoss: 0.438132\n", + "Train Epoch: 3 [10000/60000 (17%)]\tLoss: 0.310757\n", + "Train Epoch: 3 [11000/60000 (18%)]\tLoss: 0.336065\n", + "Train Epoch: 3 [12000/60000 (20%)]\tLoss: 0.273282\n", + "Train Epoch: 3 [13000/60000 (22%)]\tLoss: 
0.341864\n", + "Train Epoch: 3 [14000/60000 (23%)]\tLoss: 0.376471\n", + "Train Epoch: 3 [15000/60000 (25%)]\tLoss: 0.282714\n", + "Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.316221\n", + "Train Epoch: 3 [17000/60000 (28%)]\tLoss: 0.291025\n", + "Train Epoch: 3 [18000/60000 (30%)]\tLoss: 0.228876\n", + "Train Epoch: 3 [19000/60000 (32%)]\tLoss: 0.341082\n", + "Train Epoch: 3 [20000/60000 (33%)]\tLoss: 0.310737\n", + "Train Epoch: 3 [21000/60000 (35%)]\tLoss: 0.219182\n", + "Train Epoch: 3 [22000/60000 (37%)]\tLoss: 0.295618\n", + "Train Epoch: 3 [23000/60000 (38%)]\tLoss: 0.452618\n", + "Train Epoch: 3 [24000/60000 (40%)]\tLoss: 0.277424\n", + "Train Epoch: 3 [25000/60000 (42%)]\tLoss: 0.294226\n", + "Train Epoch: 3 [26000/60000 (43%)]\tLoss: 0.343396\n", + "Train Epoch: 3 [27000/60000 (45%)]\tLoss: 0.263564\n", + "Train Epoch: 3 [28000/60000 (47%)]\tLoss: 0.267713\n", + "Train Epoch: 3 [29000/60000 (48%)]\tLoss: 0.265478\n", + "Train Epoch: 3 [30000/60000 (50%)]\tLoss: 0.387819\n", + "Train Epoch: 3 [31000/60000 (52%)]\tLoss: 0.281409\n", + "Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.290360\n", + "Train Epoch: 3 [33000/60000 (55%)]\tLoss: 0.193808\n", + "Train Epoch: 3 [34000/60000 (57%)]\tLoss: 0.424527\n", + "Train Epoch: 3 [35000/60000 (58%)]\tLoss: 0.446750\n", + "Train Epoch: 3 [36000/60000 (60%)]\tLoss: 0.378367\n", + "Train Epoch: 3 [37000/60000 (62%)]\tLoss: 0.242524\n", + "Train Epoch: 3 [38000/60000 (63%)]\tLoss: 0.307956\n", + "Train Epoch: 3 [39000/60000 (65%)]\tLoss: 0.236581\n", + "Train Epoch: 3 [40000/60000 (67%)]\tLoss: 0.202924\n", + "Train Epoch: 3 [41000/60000 (68%)]\tLoss: 0.347900\n", + "Train Epoch: 3 [42000/60000 (70%)]\tLoss: 0.349966\n", + "Train Epoch: 3 [43000/60000 (72%)]\tLoss: 0.264343\n", + "Train Epoch: 3 [44000/60000 (73%)]\tLoss: 0.257069\n", + "Train Epoch: 3 [45000/60000 (75%)]\tLoss: 0.237410\n", + "Train Epoch: 3 [46000/60000 (77%)]\tLoss: 0.433061\n", + "Train Epoch: 3 [47000/60000 (78%)]\tLoss: 0.262928\n", + "Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.431371\n", + "Train Epoch: 3 [49000/60000 (82%)]\tLoss: 0.319234\n", + "Train Epoch: 3 [50000/60000 (83%)]\tLoss: 0.205424\n", + "Train Epoch: 3 [51000/60000 (85%)]\tLoss: 0.428523\n", + "Train Epoch: 3 [52000/60000 (87%)]\tLoss: 0.279334\n", + "Train Epoch: 3 [53000/60000 (88%)]\tLoss: 0.224919\n", + "Train Epoch: 3 [54000/60000 (90%)]\tLoss: 0.259237\n", + "Train Epoch: 3 [55000/60000 (92%)]\tLoss: 0.307156\n", + "Train Epoch: 3 [56000/60000 (93%)]\tLoss: 0.354070\n", + "Train Epoch: 3 [57000/60000 (95%)]\tLoss: 0.423041\n", + "Train Epoch: 3 [58000/60000 (97%)]\tLoss: 0.309722\n", + "Train Epoch: 3 [59000/60000 (98%)]\tLoss: 0.222640\n", + "Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.300421\n", + "Train Epoch: 4 [1000/60000 (2%)]\tLoss: 0.345936\n", + "Train Epoch: 4 [2000/60000 (3%)]\tLoss: 0.171191\n", + "Train Epoch: 4 [3000/60000 (5%)]\tLoss: 0.259701\n", + "Train Epoch: 4 [4000/60000 (7%)]\tLoss: 0.318586\n", + "Train Epoch: 4 [5000/60000 (8%)]\tLoss: 0.291326\n", + "Train Epoch: 4 [6000/60000 (10%)]\tLoss: 0.326561\n", + "Train Epoch: 4 [7000/60000 (12%)]\tLoss: 0.272973\n", + "Train Epoch: 4 [8000/60000 (13%)]\tLoss: 0.284000\n", + "Train Epoch: 4 [9000/60000 (15%)]\tLoss: 0.408247\n", + "Train Epoch: 4 [10000/60000 (17%)]\tLoss: 0.258859\n", + "Train Epoch: 4 [11000/60000 (18%)]\tLoss: 0.316111\n", + "Train Epoch: 4 [12000/60000 (20%)]\tLoss: 0.223681\n", + "Train Epoch: 4 [13000/60000 (22%)]\tLoss: 0.339430\n", + "Train Epoch: 4 [14000/60000 (23%)]\tLoss: 0.370453\n", + "Train Epoch: 4 
[15000/60000 (25%)]\tLoss: 0.262537\n", + "Train Epoch: 4 [16000/60000 (27%)]\tLoss: 0.288620\n", + "Train Epoch: 4 [17000/60000 (28%)]\tLoss: 0.260116\n", + "Train Epoch: 4 [18000/60000 (30%)]\tLoss: 0.198731\n", + "Train Epoch: 4 [19000/60000 (32%)]\tLoss: 0.278724\n", + "Train Epoch: 4 [20000/60000 (33%)]\tLoss: 0.346948\n", + "Train Epoch: 4 [21000/60000 (35%)]\tLoss: 0.311753\n", + "Train Epoch: 4 [22000/60000 (37%)]\tLoss: 0.250816\n", + "Train Epoch: 4 [23000/60000 (38%)]\tLoss: 0.387214\n", + "Train Epoch: 4 [24000/60000 (40%)]\tLoss: 0.284455\n", + "Train Epoch: 4 [25000/60000 (42%)]\tLoss: 0.217047\n", + "Train Epoch: 4 [26000/60000 (43%)]\tLoss: 0.263308\n", + "Train Epoch: 4 [27000/60000 (45%)]\tLoss: 0.252495\n", + "Train Epoch: 4 [28000/60000 (47%)]\tLoss: 0.254124\n", + "Train Epoch: 4 [29000/60000 (48%)]\tLoss: 0.244650\n", + "Train Epoch: 4 [30000/60000 (50%)]\tLoss: 0.343729\n", + "Train Epoch: 4 [31000/60000 (52%)]\tLoss: 0.256499\n", + "Train Epoch: 4 [32000/60000 (53%)]\tLoss: 0.260895\n", + "Train Epoch: 4 [33000/60000 (55%)]\tLoss: 0.189408\n", + "Train Epoch: 4 [34000/60000 (57%)]\tLoss: 0.414647\n", + "Train Epoch: 4 [35000/60000 (58%)]\tLoss: 0.422450\n", + "Train Epoch: 4 [36000/60000 (60%)]\tLoss: 0.330982\n", + "Train Epoch: 4 [37000/60000 (62%)]\tLoss: 0.237692\n", + "Train Epoch: 4 [38000/60000 (63%)]\tLoss: 0.273783\n", + "Train Epoch: 4 [39000/60000 (65%)]\tLoss: 0.238216\n", + "Train Epoch: 4 [40000/60000 (67%)]\tLoss: 0.216738\n", + "Train Epoch: 4 [41000/60000 (68%)]\tLoss: 0.351224\n", + "Train Epoch: 4 [42000/60000 (70%)]\tLoss: 0.323398\n", + "Train Epoch: 4 [43000/60000 (72%)]\tLoss: 0.253854\n", + "Train Epoch: 4 [44000/60000 (73%)]\tLoss: 0.248770\n", + "Train Epoch: 4 [45000/60000 (75%)]\tLoss: 0.241756\n", + "Train Epoch: 4 [46000/60000 (77%)]\tLoss: 0.391494\n", + "Train Epoch: 4 [47000/60000 (78%)]\tLoss: 0.256354\n", + "Train Epoch: 4 [48000/60000 (80%)]\tLoss: 0.308836\n", + "Train Epoch: 4 [49000/60000 (82%)]\tLoss: 0.249096\n", + "Train Epoch: 4 [50000/60000 (83%)]\tLoss: 0.216556\n", + "Train Epoch: 4 [51000/60000 (85%)]\tLoss: 0.304842\n", + "Train Epoch: 4 [52000/60000 (87%)]\tLoss: 0.245515\n", + "Train Epoch: 4 [53000/60000 (88%)]\tLoss: 0.177025\n", + "Train Epoch: 4 [54000/60000 (90%)]\tLoss: 0.219287\n", + "Train Epoch: 4 [55000/60000 (92%)]\tLoss: 0.282225\n", + "Train Epoch: 4 [56000/60000 (93%)]\tLoss: 0.310966\n", + "Train Epoch: 4 [57000/60000 (95%)]\tLoss: 0.338995\n", + "Train Epoch: 4 [58000/60000 (97%)]\tLoss: 0.265767\n", + "Train Epoch: 4 [59000/60000 (98%)]\tLoss: 0.263115\n", + "Training is finished\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "# Set the Torch Distributed env variables so the training function can be run in the notebook\n", + "# See https://pytorch.org/docs/stable/elastic/run.html#environment-variables\n", + "os.environ[\"RANK\"] = \"0\"\n", + "os.environ[\"LOCAL_RANK\"] = \"0\"\n", + "os.environ[\"WORLD_SIZE\"] = \"1\"\n", + "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", + "os.environ[\"MASTER_PORT\"] = \"1234\"\n", + "\n", + "# Run the training function locally\n", + "train_fashion_mnist()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Scale PyTorch DDP with Kubeflow TrainJob" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can use `TrainingClient()` from the Kubeflow SDK to communicate with Kubeflow APIs and scale your training function across multiple PyTorch training nodes.\n", + "\n", + "Kubeflow Trainer creates a 
`TrainJob` resource and automatically sets the appropriate environment variables to set up PyTorch in a distributed environment." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "from kubeflow.training import Trainer, TrainingClient\n", + "client = TrainingClient()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## List the Training Runtimes" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can get the list of available Training Runtimes to start your TrainJob:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Runtime(name='torch-distributed', phase='pre-training', accelerator='Unknown', accelerator_count='Unknown')\n" + ] + } + ], + "source": [ + "for runtime in client.list_runtimes():\n", + "    print(runtime)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each Training Runtime shows whether you can use it for pre-training or post-training.\n", + "Additionally, it shows the available accelerator type and the number of available resources." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the distributed TrainJob" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "job_name = client.train(\n", + "    # Use one of the training runtimes installed on your Kubernetes cluster\n", + "    runtime_ref=\"torch-distributed\",\n", + "    trainer=Trainer(\n", + "        func=train_fashion_mnist,\n", + "        # Set how many worker Pods you want the job to be distributed across\n", + "        num_nodes=4,\n", + "        # Set the resources for each worker Pod\n", + "        resources_per_node={\n", + "            \"cpu\": 1,\n", + "            \"memory\": \"16Gi\",\n", + "            # Uncomment to distribute the TrainJob on nodes with GPUs\n", + "            #\"nvidia.com/gpu\": 1,\n", + "        },\n", + "    ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check the TrainJob components" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can check the details of the TrainJob that's created:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TrainJob(name='t4a503bc3394', runtime_ref='torch-distributed', creation_timestamp=datetime.datetime(2025, 1, 30, 15, 1, 46, tzinfo=tzutc()), components=[Component(name='trainer-node-0', status='Running', device='gpu', device_count='1', pod_name='t4a503bc3394-trainer-node-0-0-xh4mb'), Component(name='trainer-node-1', status='Running', device='gpu', device_count='1', pod_name='t4a503bc3394-trainer-node-0-1-4vjkq'), Component(name='trainer-node-2', status='Running', device='gpu', device_count='1', pod_name='t4a503bc3394-trainer-node-0-2-f422f'), Component(name='trainer-node-3', status='Running', device='gpu', device_count='1', pod_name='t4a503bc3394-trainer-node-0-3-grcdm')], status='Created')\n" + ] + } + ], + "source": [ + "job = client.get_job(job_name)\n", + "print(job)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since the TrainJob is distributed across 4 nodes, it creates 4 components: `trainer-node-0`, ..., `trainer-node-3`, and you can get the individual status of each of these components."
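+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For example, you can print the status of each component. This is a minimal sketch, assuming the `TrainJob` object exposes the `components` list with the `name`, `status` and `pod_name` attributes shown in the output above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Print the individual status of every TrainJob component\n", + "for component in client.get_job(job_name).components:\n", + "    print(f\"{component.name}: {component.status} (Pod: {component.pod_name})\")"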
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Watch the TrainJob logs" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[trainer-node]: Using Device: cuda, Backend: nccl\n", + "[trainer-node]: Distributed Training for WORLD_SIZE: 4, RANK: 0, LOCAL_RANK: 0\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz\n", + "100%|██████████| 26.4M/26.4M [00:01<00:00, 15.2MB/s]\n", + "[trainer-node]: Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz\n", + "100%|██████████| 29.5k/29.5k [00:00<00:00, 324kB/s]\n", + "[trainer-node]: Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz\n", + "100%|██████████| 4.42M/4.42M [00:00<00:00, 4.85MB/s]\n", + "[trainer-node]: Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n", + "[trainer-node]: Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz\n", + "100%|██████████| 5.15k/5.15k [00:00<00:00, 51.3MB/s]\n", + "[trainer-node]: Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw\n", + "[trainer-node]: Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.306468\n", + "[trainer-node]: Train Epoch: 1 [1000/60000 (7%)]\tLoss: 1.956554\n", + "[trainer-node]: Train Epoch: 1 [2000/60000 (13%)]\tLoss: 2.061135\n", + "[trainer-node]: Train Epoch: 1 [3000/60000 (20%)]\tLoss: 2.437864\n", + "[trainer-node]: Train Epoch: 1 [4000/60000 (27%)]\tLoss: 1.225560\n", + "[trainer-node]: Train Epoch: 1 [5000/60000 (33%)]\tLoss: 0.861721\n", + "[trainer-node]: Train Epoch: 1 [6000/60000 (40%)]\tLoss: 0.782600\n", + "[trainer-node]: Train Epoch: 1 [7000/60000 (47%)]\tLoss: 0.782177\n", + "[trainer-node]: Train Epoch: 1 [8000/60000 (53%)]\tLoss: 0.622933\n", + "[trainer-node]: Train Epoch: 1 [9000/60000 (60%)]\tLoss: 0.644298\n", + "[trainer-node]: Train Epoch: 1 [10000/60000 (67%)]\tLoss: 0.466137\n", + "[trainer-node]: Train Epoch: 1 [11000/60000 (73%)]\tLoss: 0.585689\n", + "[trainer-node]: Train Epoch: 1 [12000/60000 (80%)]\tLoss: 0.505467\n", + "[trainer-node]: Train Epoch: 1 [13000/60000 (87%)]\tLoss: 0.546330\n", + "[trainer-node]: Train Epoch: 1 [14000/60000 (93%)]\tLoss: 0.453699\n", + "[trainer-node]: Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.569546\n", + "[trainer-node]: Train Epoch: 2 [1000/60000 (7%)]\tLoss: 0.419295\n", + "[trainer-node]: Train Epoch: 2 [2000/60000 
(13%)]\tLoss: 0.432655\n", + "[trainer-node]: Train Epoch: 2 [3000/60000 (20%)]\tLoss: 0.339359\n", + "[trainer-node]: Train Epoch: 2 [4000/60000 (27%)]\tLoss: 0.546911\n", + "[trainer-node]: Train Epoch: 2 [5000/60000 (33%)]\tLoss: 0.419830\n", + "[trainer-node]: Train Epoch: 2 [6000/60000 (40%)]\tLoss: 0.526430\n", + "[trainer-node]: Train Epoch: 2 [7000/60000 (47%)]\tLoss: 0.486774\n", + "[trainer-node]: Train Epoch: 2 [8000/60000 (53%)]\tLoss: 0.362716\n", + "[trainer-node]: Train Epoch: 2 [9000/60000 (60%)]\tLoss: 0.446857\n", + "[trainer-node]: Train Epoch: 2 [10000/60000 (67%)]\tLoss: 0.322937\n", + "[trainer-node]: Train Epoch: 2 [11000/60000 (73%)]\tLoss: 0.411337\n", + "[trainer-node]: Train Epoch: 2 [12000/60000 (80%)]\tLoss: 0.381369\n", + "[trainer-node]: Train Epoch: 2 [13000/60000 (87%)]\tLoss: 0.442482\n", + "[trainer-node]: Train Epoch: 2 [14000/60000 (93%)]\tLoss: 0.328791\n", + "[trainer-node]: Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.494806\n", + "[trainer-node]: Train Epoch: 3 [1000/60000 (7%)]\tLoss: 0.414654\n", + "[trainer-node]: Train Epoch: 3 [2000/60000 (13%)]\tLoss: 0.361151\n", + "[trainer-node]: Train Epoch: 3 [3000/60000 (20%)]\tLoss: 0.281028\n", + "[trainer-node]: Train Epoch: 3 [4000/60000 (27%)]\tLoss: 0.397668\n", + "[trainer-node]: Train Epoch: 3 [5000/60000 (33%)]\tLoss: 0.277901\n", + "[trainer-node]: Train Epoch: 3 [6000/60000 (40%)]\tLoss: 0.420168\n", + "[trainer-node]: Train Epoch: 3 [7000/60000 (47%)]\tLoss: 0.502875\n", + "[trainer-node]: Train Epoch: 3 [8000/60000 (53%)]\tLoss: 0.292459\n", + "[trainer-node]: Train Epoch: 3 [9000/60000 (60%)]\tLoss: 0.382577\n", + "[trainer-node]: Train Epoch: 3 [10000/60000 (67%)]\tLoss: 0.345345\n", + "[trainer-node]: Train Epoch: 3 [11000/60000 (73%)]\tLoss: 0.335450\n", + "[trainer-node]: Train Epoch: 3 [12000/60000 (80%)]\tLoss: 0.365036\n", + "[trainer-node]: Train Epoch: 3 [13000/60000 (87%)]\tLoss: 0.314009\n", + "[trainer-node]: Train Epoch: 3 [14000/60000 (93%)]\tLoss: 0.309430\n", + "[trainer-node]: Train Epoch: 4 [0/60000 (0%)]\tLoss: 0.399208\n", + "[trainer-node]: Train Epoch: 4 [1000/60000 (7%)]\tLoss: 0.379463\n", + "[trainer-node]: Train Epoch: 4 [2000/60000 (13%)]\tLoss: 0.336221\n", + "[trainer-node]: Train Epoch: 4 [3000/60000 (20%)]\tLoss: 0.266147\n", + "[trainer-node]: Train Epoch: 4 [4000/60000 (27%)]\tLoss: 0.327334\n", + "[trainer-node]: Train Epoch: 4 [5000/60000 (33%)]\tLoss: 0.264875\n", + "[trainer-node]: Train Epoch: 4 [6000/60000 (40%)]\tLoss: 0.416835\n", + "[trainer-node]: Train Epoch: 4 [7000/60000 (47%)]\tLoss: 0.476572\n", + "[trainer-node]: Train Epoch: 4 [8000/60000 (53%)]\tLoss: 0.298670\n", + "[trainer-node]: Train Epoch: 4 [9000/60000 (60%)]\tLoss: 0.325634\n", + "[trainer-node]: Train Epoch: 4 [10000/60000 (67%)]\tLoss: 0.268218\n", + "[trainer-node]: Train Epoch: 4 [11000/60000 (73%)]\tLoss: 0.294167\n", + "[trainer-node]: Train Epoch: 4 [12000/60000 (80%)]\tLoss: 0.302991\n", + "[trainer-node]: Train Epoch: 4 [13000/60000 (87%)]\tLoss: 0.303185\n", + "[trainer-node]: Train Epoch: 4 [14000/60000 (93%)]\tLoss: 0.288636\n", + "[trainer-node]: Training is finished\n" + ] + } + ], + "source": [ + "_ = client.get_job_logs(job_name, follow=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each node processes its assigned shard of the Fashion-MNIST dataset.\n", + "As the `TrainJob` is distributed across 4 nodes and the dataset contains a total of 60,000 samples, each node processes 15,000 samples."
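+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can verify the shard size locally. The following is a minimal sketch that uses PyTorch's `DistributedSampler` with explicit `num_replicas` and `rank` arguments to compute the number of samples assigned to one node:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DistributedSampler\n", + "from torchvision import datasets, transforms\n", + "\n", + "# Reuse the training set downloaded during the local dry-run\n", + "dataset = datasets.FashionMNIST(\n", + "    \"./data\",\n", + "    train=True,\n", + "    download=True,\n", + "    transform=transforms.Compose([transforms.ToTensor()]),\n", + ")\n", + "\n", + "# Each of the 4 ranks is assigned ceil(60000 / 4) = 15000 samples\n", + "sampler = DistributedSampler(dataset, num_replicas=4, rank=0)\n", + "print(len(sampler))"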
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Delete the TrainJob" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "client.delete_job(job_name)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}