Commit: Make Lightning checkpoint demo work with Bernard's GKE framework and with FSDP strategy (#86)

* FSDP demo
* Saving work for GKE
* Working for multi-node
* Update demo to use full strategy
* Add README and put placeholders in yaml file
* Mention gcsfs in the README limitations
* Remove bucket from readme + yaml file
Showing 7 changed files with 249 additions and 47 deletions.
demo/lightning/checkpoint/README.md (new file, +41 lines)
# Checkpoint Demo for PyTorch Lightning

The code in this folder provides a training demo for checkpointing with PyTorch Lightning. This demo is under development.

## Limitations

* The demo currently only runs with [`state_dict_type="full"`](https://lightning.ai/docs/pytorch/stable/common/checkpointing_expert.html#save-a-distributed-checkpoint) when using FSDP.
* `requirements.txt` includes `gcsfs` because, although it is not used for checkpointing, PyTorch Lightning's default logger also writes to the root directory where checkpoints are saved.

## Running locally

1. Set the environment variables required to run the demo. These include:
   * `PROJECT`: The GCP project you are using
   * `CKPT_DIR_PATH`: The full path of the directory in which to save checkpoints, in the format `gs://<bucket>/<directory>/`
2. Set the optional environment variables, if desired:
   * `NUM_LAYERS`: The number of layers in the model, which affects the size of the model and therefore the size of the checkpoints
   * `ACCELERATOR`: Set to `gpu` if running on a GPU, or `cpu` if running on a CPU (the default)
     * If running on a GPU, you must also set `PJRT_DEVICE` to `CUDA`.
   * `TRAIN_STRATEGY`: Set to `fsdp` to use the FSDP strategy; the default is `ddp`. FSDP requires GPUs.
3. Install the requirements: `pip install -r demo/lightning/checkpoint/requirements.txt`, then `pip install .`
4. Run the binary: `python3 -m demo.lightning.checkpoint.train`
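
Putting the local steps together, here is a minimal sketch. The angle-bracket values and `NUM_LAYERS=10` are illustrative placeholders, not values shipped with the demo; the stated defaults come from `train.py`.

```
# Required settings (replace the placeholders with your own values).
export PROJECT=<your-gcp-project>
export CKPT_DIR_PATH=gs://<bucket>/<directory>/

# Optional settings; the defaults are NUM_LAYERS=2, ACCELERATOR=cpu, TRAIN_STRATEGY=ddp.
export NUM_LAYERS=10
export ACCELERATOR=cpu   # use "gpu" (and also set PJRT_DEVICE=CUDA) on GPU machines

# Install and run.
pip install -r demo/lightning/checkpoint/requirements.txt
pip install .
python3 -m demo.lightning.checkpoint.train
```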
## Running on GKE

These instructions assume you have an existing GKE cluster with Kueue and JobSet installed. Both are installed by default if you create the cluster using [xpk](https://github.com/google/xpk).

### Build and push the Docker container

```
docker build -t my-container .
docker tag my-container gcr.io/<PROJECT_NAME>/my-container
docker push gcr.io/<PROJECT_NAME>/my-container
```

Make sure to update the container name in the yaml config file to match the one you're using.

### Run the workload on GKE

1. Connect to your GKE cluster: `gcloud container clusters get-credentials <CLUSTER_NAME> --region=<COMPUTE_REGION>`
2. Make a copy of `demo/lightning/checkpoint/example-deploy.yaml` and update the placeholders and environment variables as needed
3. Run `kubectl apply -f <path-to-your-yaml-file>`
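
For reference, the three steps above as one sequence; the cluster, region, and copied file name are placeholders:

```
# Placeholders throughout; substitute your own cluster, region, and project.
gcloud container clusters get-credentials <CLUSTER_NAME> --region=<COMPUTE_REGION>

# Copy the example manifest, then edit the container image, project, bucket
# path, and any other environment variables before applying it.
cp demo/lightning/checkpoint/example-deploy.yaml my-deploy.yaml
kubectl apply -f my-deploy.yaml
```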
demo/lightning/checkpoint/example-deploy.yaml (new file, +102 lines)
# Copyright 2024 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: my-job-run
  labels:
    kueue.x-k8s.io/queue-name: multislice-queue  # Name of the LocalQueue
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool  # 1:1 job replica to node pool assignment
spec:
  failurePolicy:
    maxRestarts: 0
  replicatedJobs:
    - name: checkpoint-job
      replicas: 1
      template:
        spec:
          parallelism: 2   # Equal to the number of VMs per slice
          completions: 2   # Same as the above.
          backoffLimit: 0  # When any pod fails, the job is failed
          template:
            spec:
              schedulerName: default-scheduler
              restartPolicy: Never

              priorityClassName: medium
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              terminationGracePeriodSeconds: 30
              containers:
                - name: checkpoint-gpu
                  image: gcr.io/my-project/my-container

                  resources:
                    limits:
                      nvidia.com/gpu: 1

                  env:
                    - name: REPLICATED_JOB_NAME
                      valueFrom:
                        fieldRef:
                          fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
                    - name: JOB_INDEX
                      valueFrom:
                        fieldRef:
                          fieldPath: metadata.annotations['jobset.sigs.k8s.io/job-index']
                    - name: JOB_COMPLETION_INDEX
                      valueFrom:
                        fieldRef:
                          fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
                    - name: PROCESSES_IN_JOB
                      value: "2"
                    - name: WORLD_SIZE
                      value: "2"

                    - name: JOBSET_NAME
                      value: "my-job-run"
                    - name: COORDINATOR_ADDRESS
                      value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
                    - name: MASTER_PORT
                      value: "1234"
                    - name: PROJECT
                      value: "my-project"
                    - name: CKPT_DIR_PATH
                      value: "gs://my-bucket/checkpoint-path/"
                    - name: PJRT_DEVICE
                      value: "CUDA"
                    - name: NCCL_SOCKET_IFNAME
                      value: "eth0"
                    - name: NCCL_DEBUG
                      value: "WARN"
                    - name: NUM_LAYERS
                      value: "100"
                    - name: TRAIN_STRATEGY
                      value: "fsdp"
                    - name: ACCELERATOR
                      value: "gpu"

                  ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    - containerPort: 1234
                  securityContext: {}
                    # privileged: true
                  command:
                    - bash
                    - -c
                    - |
                      python3 -u /app/demo/lightning/checkpoint/train.py;
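
Once the manifest is applied, a few commands for checking on the run. These assume the default `metadata.name` of `my-job-run` from the file above and a cluster with the JobSet CRD installed; the pod label selector shown is the one JobSet normally applies to child pods.

```
# Check the JobSet and its child pods (requires the JobSet CRD).
kubectl get jobset my-job-run
kubectl get pods -l jobset.sigs.k8s.io/jobset-name=my-job-run

# Follow training output from one of the pods.
kubectl logs -f <pod-name>

# Tear the run down when finished.
kubectl delete jobset my-job-run
```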
demo/lightning/checkpoint/requirements.txt (new file, +1 line)
gcsfs
demo/lightning/checkpoint/train.py (new file, +101 lines)
import os
import socket
import time

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos import (LightningTransformer, Transformer,
                                     WikiText2)
from torch.utils.data import DataLoader

from dataflux_pytorch.lightning import DatafluxLightningCheckpoint


def configure_master_addr():
    """Get coordinator IP Address with retries"""
    coordinator_address = ""
    coordinator_ip_address = ""
    if os.environ.get("COORDINATOR_ADDRESS") is not None:
        coordinator_address = os.environ.get("COORDINATOR_ADDRESS")
        coordinator_found = False
        lookup_attempt = 1
        max_coordinator_lookups = 50
        while not coordinator_found and lookup_attempt <= max_coordinator_lookups:
            try:
                coordinator_ip_address = socket.gethostbyname(
                    coordinator_address)
                coordinator_found = True
            except socket.gaierror:
                print(
                    f"Failed to recognize coordinator address {coordinator_address} on"
                    f" attempt {lookup_attempt}, retrying...")
                lookup_attempt += 1
                time.sleep(5)
    print(f"Coordinator IP address: {coordinator_ip_address}")
    os.environ["MASTER_ADDR"] = str(coordinator_ip_address)


def init_processes():
    """Initializes the distributed environment."""
    # Get the necessary environment variables from the GKE environment
    world_size = int(os.environ["WORLD_SIZE"])

    job_index = int(os.environ.get("JOB_INDEX"))
    job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX"))
    processes_in_job = int(os.environ.get("PROCESSES_IN_JOB"))
    rank = job_index * processes_in_job + job_completion_index
    os.environ["NODE_RANK"] = str(rank)

    configure_master_addr()


def main(project: str, ckpt_dir_path: str, save_only_latest: bool):
    if os.environ.get("COORDINATOR_ADDRESS"):
        init_processes()
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=1)

    model = DemoTransformer(vocab_size=dataset.vocab_size,
                            nlayers=int(os.environ.get("NUM_LAYERS", 2)))
    dataflux_ckpt = DatafluxLightningCheckpoint(project_name=project)
    # Save once per step, and if `save_only_latest`, replace the last checkpoint each time.
    # Replacing is implemented by saving the new checkpoint, and then deleting the previous one.
    # If `save_only_latest` is False, a new checkpoint is created for each step.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1 if save_only_latest else -1,
        every_n_train_steps=1,
        filename="checkpoint-{epoch:02d}-{step:02d}",
        enable_version_counter=True,
    )
    strategy = os.environ.get("TRAIN_STRATEGY", "ddp")
    accelerator = os.environ.get("ACCELERATOR", "cpu")
    trainer = Trainer(default_root_dir=ckpt_dir_path,
                      plugins=[dataflux_ckpt],
                      callbacks=[checkpoint_callback],
                      min_epochs=4,
                      max_epochs=5,
                      max_steps=3,
                      accelerator=accelerator,
                      strategy=strategy,
                      num_nodes=int(os.environ.get("WORLD_SIZE", 1)))
    trainer.fit(model, dataloader)


class DemoTransformer(LightningTransformer):

    def __init__(
        self,
        vocab_size: int = 33278,
        nlayers: int = 2,
    ) -> None:
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size, nlayers=nlayers)


if __name__ == "__main__":

    main(
        os.getenv("PROJECT"),
        os.getenv("CKPT_DIR_PATH"),
        os.getenv("SAVE_ONLY_LATEST") == "1",
    )
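
To make the environment handling in `init_processes` concrete, here is a hedged sketch of the variables the script reads on one node of a two-node run. The coordinator hostname follows the pattern set in `example-deploy.yaml`; every concrete value and placeholder is illustrative, since on GKE the manifest injects these automatically.

```
# Illustrative only: environment for node 0 of a two-node run.
export WORLD_SIZE=2            # number of nodes; passed to Trainer(num_nodes=...)
export PROCESSES_IN_JOB=2
export JOB_INDEX=0
export JOB_COMPLETION_INDEX=0  # 1 on the second node, giving NODE_RANK = 0*2+1 = 1
export COORDINATOR_ADDRESS=my-job-run-checkpoint-job-0-0.my-job-run
export MASTER_PORT=1234

export PROJECT=<your-gcp-project>
export CKPT_DIR_PATH=gs://<bucket>/<directory>/
export TRAIN_STRATEGY=fsdp     # FSDP requires GPUs
export ACCELERATOR=gpu
export PJRT_DEVICE=CUDA

python3 -u demo/lightning/checkpoint/train.py
```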
One additional file in this commit was deleted; its contents are not shown.