Skip to content

Commit

Permalink
Make Lightning checkpoint demo work with Bernard's GKE framework and …
Browse files Browse the repository at this point in the history
…with FSDP strategy (#86)

* FSDP demo

* Saving work for GKE

* Working for multi-node

* Update demo to use full strategy

* Add README and put placeholders in yaml file

* Mention gcsfs in the README limitations

* Remove bucket from readme + yaml file
  • Loading branch information
MattIrv authored Aug 8, 2024
1 parent 6fa9631 commit 1a5ebd8
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 47 deletions.
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ COPY ./demo/lightning/text-based/distributed/requirements.txt requirements-1.txt
RUN pip install --no-cache-dir -r requirements-1.txt
COPY ./demo/lightning/image-segmentation/requirements.txt requirements-2.txt
RUN pip install --no-cache-dir -r requirements-2.txt
COPY ./demo/lightning/checkpoint/requirements.txt requirements-3.txt
RUN pip install --no-cache-dir -r requirements-3.txt

# Copy the code.
COPY ./ ./
Expand Down
4 changes: 2 additions & 2 deletions dataflux_pytorch/lightning/dataflux_lightning_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from dataflux_core import user_agent
from google.cloud import storage
from lightning.pytorch.plugins.io import CheckpointIO
from pathlib import Path


class DatafluxLightningCheckpoint(CheckpointIO):
Expand All @@ -28,7 +28,7 @@ def _process_input_path(self, path: Union[str, Path]) -> str:
elif isinstance(path, Path):
# When casting from Path object to string, it considers cloud URLs as Network URLs and gets rid of //
scheme, rest = str(path).split(":/")
return str(scheme)+"://"+str(rest)
return str(scheme) + "://" + str(rest)
else:
raise TypeError(
"path argument must be of type string or pathlib.Path object")
Expand Down
41 changes: 41 additions & 0 deletions demo/lightning/checkpoint/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Checkpoint Demo for PyTorch Lightning

The code in this folder provides a training demo for checkpointing with PyTorch Lightning. This demo is under development.

## Limitations

* The demo currently only runs with [`state_dict_type="full"`](https://lightning.ai/docs/pytorch/stable/common/checkpointing_expert.html#save-a-distributed-checkpoint) when using FSDP.
* `requirements.txt` includes gcsfs because even though it is not used for checkpointing, PyTorch Lightning's default logger also writes to the root directory where checkpoints are saved.

## Running locally

1. Set the environment variables required to run the demo. These include:
* `PROJECT`: The GCP project you are using
* `CKPT_DIR_PATH`: The full path of the directory in which to save checkpoints, in the format `gs://<bucket>/<directory>/`
2. Set the optional environment variables, if desired:
* `NUM_LAYERS`: The number of layers in the model, which affects the size of the model and therefore the size of the checkpoints
* `ACCELERATOR`: Set to `gpu` if running on a GPU, or `cpu` if running on a CPU (default)
* If running on a GPU, you also must set `PJRT_DEVICE` to `CUDA`.
* `TRAIN_STRATEGY`: Set to `fsdp` to use the FSDP strategy. The default is `ddp`. If using FSDP, you must use GPUs
4. Install requirements: `pip install -r demo/lightning/checkpoint/requirements.txt`; `pip install .`
3. Run the binary: `python3 -m demo.lightning.checkpoint.train`

## Running on GKE

These instructions assume you have an existing GKE cluster with Kueue and Jobset installed. These are installed by default if you create the cluster using [xpk](https://github.com/google/xpk).

### Build and push the Docker container

```
docker build -t my-container .
docker tag my-container gcr.io/<PROJECT_NAME>/my-container
docker push gcr.io/<PROJECT_NAME>/my-container
```

Make sure to update the container name in the yaml config file to match the one you're using.

### Run the workload on GKE

1. Connect to your GKE cluster: `gcloud container clusters get-credentials <CLUSTER_NAME> --region=<COMPUTE_REGION>`
2. Make a copy of `demo/lightning/checkpoint/example-deploy.yaml` and update the placeholders and environment variables as needed
3. Run `kubectl -f apply <path-to-your-yaml-file>`
102 changes: 102 additions & 0 deletions demo/lightning/checkpoint/example-deploy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: my-job-run
labels:
kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool # 1:1 job replica to node pool assignment
spec:
failurePolicy:
maxRestarts: 0
replicatedJobs:
- name: checkpoint-job
replicas: 1
template:
spec:
parallelism: 2 # Equal to the number of VMs per slice
completions: 2 # Same as the above.
backoffLimit: 0 # When any pod fails, the job is failed
template:
spec:
schedulerName: default-scheduler
restartPolicy: Never

priorityClassName: medium
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
terminationGracePeriodSeconds: 30
containers:
- name: checkpoint-gpu
image: gcr.io/my-project/my-container

resources:
limits:
nvidia.com/gpu: 1

env:
- name: REPLICATED_JOB_NAME
valueFrom:
fieldRef:
fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
- name: JOB_INDEX
valueFrom:
fieldRef:
fieldPath: metadata.annotations['jobset.sigs.k8s.io/job-index']
- name: JOB_COMPLETION_INDEX
valueFrom:
fieldRef:
fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
- name: PROCESSES_IN_JOB
value: "2"
- name: WORLD_SIZE
value: "2"

- name: JOBSET_NAME
value: "my-job-run"
- name: COORDINATOR_ADDRESS
value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
- name: MASTER_PORT
value: "1234"
- name: PROJECT
value: "my-project"
- name: CKPT_DIR_PATH
value: "gs://my-bucket/checkpoint-path/"
- name: PJRT_DEVICE
value: "CUDA"
- name: NCCL_SOCKET_IFNAME
value: "eth0"
- name: NCCL_DEBUG
value: "WARN"
- name: NUM_LAYERS
value: "100"
- name: TRAIN_STRATEGY
value: "fsdp"
- name: ACCELERATOR
value: "gpu"

ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 1234
securityContext: {}
# privileged: true
command:
- bash
- -c
- |
python3 -u /app/demo/lightning/checkpoint/train.py;
1 change: 1 addition & 0 deletions demo/lightning/checkpoint/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gcsfs
101 changes: 101 additions & 0 deletions demo/lightning/checkpoint/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os
import socket
import time

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos import (LightningTransformer, Transformer,
WikiText2)
from torch.utils.data import DataLoader

from dataflux_pytorch.lightning import DatafluxLightningCheckpoint


def configure_master_addr():
"""Get coordinator IP Address with retries"""
coordinator_address = ""
coordinator_ip_address = ""
if os.environ.get("COORDINATOR_ADDRESS") is not None:
coordinator_address = os.environ.get("COORDINATOR_ADDRESS")
coordinator_found = False
lookup_attempt = 1
max_coordinator_lookups = 50
while not coordinator_found and lookup_attempt <= max_coordinator_lookups:
try:
coordinator_ip_address = socket.gethostbyname(
coordinator_address)
coordinator_found = True
except socket.gaierror:
print(
f"Failed to recognize coordinator address {coordinator_address} on"
f" attempt {lookup_attempt}, retrying...")
lookup_attempt += 1
time.sleep(5)
print(f"Coordinator IP address: {coordinator_ip_address}")
os.environ["MASTER_ADDR"] = str(coordinator_ip_address)


def init_processes():
"""Initializes the distributed environment."""
# Get the necessary environment variables from the GKE environment
world_size = int(os.environ["WORLD_SIZE"])

job_index = int(os.environ.get("JOB_INDEX"))
job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX"))
processes_in_job = int(os.environ.get("PROCESSES_IN_JOB"))
rank = job_index * processes_in_job + job_completion_index
os.environ["NODE_RANK"] = str(rank)

configure_master_addr()


def main(project: str, ckpt_dir_path: str, save_only_latest: bool):
if os.environ.get("COORDINATOR_ADDRESS"):
init_processes()
dataset = WikiText2()
dataloader = DataLoader(dataset, num_workers=1)

model = DemoTransformer(vocab_size=dataset.vocab_size,
nlayers=int(os.environ.get("NUM_LAYERS", 2)))
dataflux_ckpt = DatafluxLightningCheckpoint(project_name=project)
# Save once per step, and if `save_only_latest`, replace the last checkpoint each time.
# Replacing is implemented by saving the new checkpoint, and then deleting the previous one.
# If `save_only_latest` is False, a new checkpoint is created for each step.
checkpoint_callback = ModelCheckpoint(
save_top_k=1 if save_only_latest else -1,
every_n_train_steps=1,
filename="checkpoint-{epoch:02d}-{step:02d}",
enable_version_counter=True,
)
strategy = os.environ.get("TRAIN_STRATEGY", "ddp")
accelerator = os.environ.get("ACCELERATOR", "cpu")
trainer = Trainer(default_root_dir=ckpt_dir_path,
plugins=[dataflux_ckpt],
callbacks=[checkpoint_callback],
min_epochs=4,
max_epochs=5,
max_steps=3,
accelerator=accelerator,
strategy=strategy,
num_nodes=int(os.environ.get("WORLD_SIZE", 1)))
trainer.fit(model, dataloader)


class DemoTransformer(LightningTransformer):

def __init__(
self,
vocab_size: int = 33278,
nlayers: int = 2,
) -> None:
super().__init__()
self.model = Transformer(vocab_size=vocab_size, nlayers=nlayers)


if __name__ == "__main__":

main(
os.getenv("PROJECT"),
os.getenv("CKPT_DIR_PATH"),
os.getenv("SAVE_ONLY_LATEST") == "1",
)
45 changes: 0 additions & 45 deletions demo/lightning/lightning_checkpoint.py

This file was deleted.

0 comments on commit 1a5ebd8

Please sign in to comment.