Add GKE deployment for MaxText Parquet training benchmark #91

Merged · 5 commits · Aug 15, 2024
8 changes: 5 additions & 3 deletions Dockerfile
@@ -7,12 +7,14 @@ WORKDIR /app
 # Install additional requirements for running demos.
 # Do this before copying the code so that these commands are still cached
 # by Docker even if the code changes.
-COPY ./demo/lightning/text-based/distributed/requirements.txt requirements-1.txt
+COPY ./dataflux_pytorch/benchmark/standalone_dataloader/requirements.txt requirements-1.txt
 RUN pip install --no-cache-dir -r requirements-1.txt
-COPY ./demo/lightning/image-segmentation/requirements.txt requirements-2.txt
+COPY ./demo/lightning/text-based/distributed/requirements.txt requirements-2.txt
 RUN pip install --no-cache-dir -r requirements-2.txt
-COPY ./demo/lightning/checkpoint/requirements.txt requirements-3.txt
+COPY ./demo/lightning/image-segmentation/requirements.txt requirements-3.txt
 RUN pip install --no-cache-dir -r requirements-3.txt
+COPY ./demo/lightning/checkpoint/requirements.txt requirements-4.txt
+RUN pip install --no-cache-dir -r requirements-4.txt

 # Copy the code.
 COPY ./ ./
16 changes: 16 additions & 0 deletions dataflux_pytorch/benchmark/standalone_dataloader/README.md
@@ -31,3 +31,19 @@ These instructions are run relative to this directory.
```sh
JAX_PLATFORMS=cpu python3 standalone_dataloader.py maxtext/MaxText/configs/base.yml ${BENCHMARK_RUN_FLAGS} ${COMMON_RUN_FLAGS}
```

## Running on GKE

### Build the Docker image

In the following commands, update `gcs-tess` to your project ID as needed; a parameterized version of the same steps is sketched after the list.

1. `docker build -t dataflux-list-and-download .`
2. `docker tag dataflux-list-and-download gcr.io/gcs-tess/dataflux-maxtext`
3. `docker push gcr.io/gcs-tess/dataflux-maxtext`
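The same build-and-push steps with the project ID factored into a variable (a sketch only; `PROJECT_ID` is a placeholder, and `gcloud auth configure-docker` is needed only if Docker is not yet authorized to push to `gcr.io`):

```sh
export PROJECT_ID=<YOUR-PROJECT-ID>

# One-time setup: let Docker push to Container Registry.
gcloud auth configure-docker

docker build -t dataflux-list-and-download .
docker tag dataflux-list-and-download gcr.io/${PROJECT_ID}/dataflux-maxtext
docker push gcr.io/${PROJECT_ID}/dataflux-maxtext
```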

### Run the benchmark

1. Update any flags or configs you need in `dataflux_pytorch/benchmark/standalone_dataloader/deployment.yaml`
   * Notably the job name, completions/parallelism, image name, and any benchmark flags
2. `kubectl apply -f dataflux_pytorch/benchmark/standalone_dataloader/deployment.yaml` (monitoring and cleanup commands are sketched below)
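Once the manifest is applied, the run can be watched and cleaned up with standard `kubectl` commands. A minimal sketch, assuming the JobSet controller and Kueue are installed in the cluster (the manifest references both) and the default workload name is kept:

```sh
# Overall JobSet status and the pods it created (the label is set in deployment.yaml).
kubectl get jobset dataflux-maxtext-workload
kubectl get pods -l xpk.google.com/workload=dataflux-maxtext-workload

# Stream logs from one worker pod (substitute a pod name from the listing above).
kubectl logs -f <pod-name>

# Tear the workload down when the benchmark is done.
kubectl delete -f dataflux_pytorch/benchmark/standalone_dataloader/deployment.yaml
```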
109 changes: 109 additions & 0 deletions dataflux_pytorch/benchmark/standalone_dataloader/deployment.yaml
@@ -0,0 +1,109 @@
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  # Modify this name to distinguish your workload from others.
  # Make sure to modify all occurrences of the name in this file.
  name: dataflux-maxtext-workload
  labels:
    kueue.x-k8s.io/queue-name: multislice-queue # Name of the LocalQueue
    xpk.google.com/workload: dataflux-maxtext-workload
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool # 1:1 job replica to node pool assignment
spec:
  failurePolicy:
    maxRestarts: 0
  replicatedJobs:
  - name: slice-job
    replicas: 1
    template:
      spec:
        parallelism: 10 # Equal to the number of VMs per slice
        completions: 10 # Same as the above.
        backoffLimit: 0 # When any pod fails, the job fails.
        template:
          metadata:
            labels:
              xpk.google.com/workload: dataflux-maxtext-workload
          spec:
            schedulerName: default-scheduler
            restartPolicy: Never
            affinity:
              nodeAffinity:
                requiredDuringSchedulingIgnoredDuringExecution:
                  nodeSelectorTerms:
                  - matchExpressions:
                    - key: cloud.google.com/gke-nodepool
                      operator: NotIn
                      values:
                      - default-pool
            priorityClassName: medium
            hostNetwork: true
            dnsPolicy: ClusterFirstWithHostNet
            terminationGracePeriodSeconds: 30
            containers:
            - name: jax-cpu
              image: gcr.io/gcs-tess/dataflux-maxtext
              env:
              - name: REPLICATED_JOB_NAME
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
              - name: JOB_INDEX
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['jobset.sigs.k8s.io/job-index']
              - name: JOB_COMPLETION_INDEX
                valueFrom:
                  fieldRef:
                    fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index']
              # Modify the following two values too if you intend to run the workload at a smaller scale.
              - name: PROCESSES_IN_JOB
                value: "10"
              - name: JAX_PROCESS_COUNT
                value: "10"
              - name: JOBSET_NAME
                value: "dataflux-maxtext-workload"
              - name: JAX_COORDINATOR_ADDRESS
                value: "$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)"
              ports:
              - containerPort: 8471
              - containerPort: 8080
              - containerPort: 1234
              securityContext:
                privileged: true
              command:
              - bash
              - -c
              - |
                # Modify the parameters here.
                # See the parameter documentation at https://github.com/google/maxtext/blob/gcs-distributed-training-benchmark/MaxText/configs/base.yml#L359.
                # Please modify RUN_NAME to distinguish your run from others and add identifiers
                # for the storage team (HdML, PStore, GCS, etc).

                export RUN_NAME=<YOUR-NAME>-dataflux-maxtext-$(date +"%Y-%m-%d")

                export PROJECT="<YOUR-PROJECT>"
                export BUCKET="<YOUR-BUCKET>"
                export PREFIX="<DATA-PREFIX>"
                export EPOCHS=2
                export MAX_STEPS=-1
                export LOCAL_BATCH_SIZE=32
                export PREFETCH_FACTOR=2
                export DATA_LOADER_NUM_WORKERS=10
                export PER_STEP_INTERVAL=0.1
                export GCS_METRICS_BUCKET="<METRICS-BUCKET>"

                export COMMON_RUN_FLAGS="enable_checkpointing=False hardware=cpu"
                export BENCHMARK_RUN_FLAGS="run_name=${RUN_NAME} dataset_directory=${DATASET_DIRECTORY} epochs=${EPOCHS} max_steps=${MAX_STEPS} local_batch_size=${LOCAL_BATCH_SIZE} prefetch_factor=${PREFETCH_FACTOR} data_loader_num_workers=${DATA_LOADER_NUM_WORKERS} per_step_interval=${PER_STEP_INTERVAL} gcs_metrics_bucket=${GCS_METRICS_BUCKET}"

                echo XPK Start: $(date)
                _sigterm() ( kill -SIGTERM $! 2>/dev/null;)
                trap _sigterm SIGTERM
                (JAX_PLATFORMS=cpu python3 dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py dataflux_pytorch/benchmark/standalone_dataloader/maxtext/MaxText/configs/base.yml ${BENCHMARK_RUN_FLAGS} ${COMMON_RUN_FLAGS}) & PID=$!
                while kill -0 $PID 2>/dev/null; do sleep 5; done
                wait $PID; EXIT_CODE=$?
                echo XPK End: $(date); echo EXIT_CODE=$EXIT_CODE
              resources:
                requests:
                  # Requesting 20 CPU cores; the node machine is n2-32, so this
                  # ensures that one pod is scheduled per node.
                  cpu: 20000m
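The env block above hands each pod its JobSet coordinates (`REPLICATED_JOB_NAME`, `JOB_INDEX`, `JOB_COMPLETION_INDEX`) plus the coordinator address and total process count. The benchmark's real initialization lives in `standalone_dataloader.py`/MaxText and may differ; the following is only a hypothetical sketch of how such variables are commonly mapped onto `jax.distributed.initialize` (the helper name, the port choice, and the process-id formula are assumptions, not taken from this diff):

```python
# Hypothetical sketch: mapping the JobSet-injected env vars onto JAX's
# multi-process initialization. The port (1234, one of the exposed
# containerPorts) and the helper name are assumptions.
import os

import jax


def init_jax_from_jobset_env(coordinator_port: int = 1234) -> None:
    processes_in_job = int(os.environ["PROCESSES_IN_JOB"])
    job_index = int(os.environ["JOB_INDEX"])
    completion_index = int(os.environ["JOB_COMPLETION_INDEX"])

    # One globally unique id per pod: jobs are numbered by JOB_INDEX and each
    # job runs PROCESSES_IN_JOB completions, indexed by JOB_COMPLETION_INDEX.
    process_id = job_index * processes_in_job + completion_index

    jax.distributed.initialize(
        coordinator_address=f"{os.environ['JAX_COORDINATOR_ADDRESS']}:{coordinator_port}",
        num_processes=int(os.environ["JAX_PROCESS_COUNT"]),
        process_id=process_id,
    )


if __name__ == "__main__":
    init_jax_from_jobset_env()
    print(f"process {jax.process_index()} of {jax.process_count()} is up")
```

With `replicas: 1`, `completions: 10`, and `JAX_PROCESS_COUNT=10`, this mapping yields process ids 0-9, one per pod.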
dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py
@@ -8,16 +8,15 @@
 import datetime
 import random
 import time
-from types import MappingProxyType
-from typing import Sequence, Type
+from typing import Sequence

 import jax
 import pyarrow as pa
 import pyarrow.parquet as pq
 from absl import app
 from maxtext.MaxText import max_logging, pyconfig, storage_utils, train
 from torch.utils import data
-from torch.utils.data import DataLoader, IterableDataset
+from torch.utils.data import DataLoader

 from dataflux_pytorch import dataflux_iterable_dataset

@@ -29,51 +28,6 @@
 STEP_BARRIER_MSG = "Synchronize all processes within a step"

-
-def split_list(lst, n):
-    """Splits a list into roughly equal sized sublists and pads.
-
-    Args:
-      lst: The list to split.
-      n: The desired number of sublists.
-
-    Returns:
-      A list of sublists.
-    """
-    # Calculate the size of each sublist.
-    size = len(lst) // n
-
-    # Create the sublists.
-    sublists = [lst[i * size:(i + 1) * size] for i in range(n)]
-
-    last_idx = n * size
-
-    if last_idx >= len(lst):
-        return sublists
-
-    remainder = len(lst) - last_idx
-
-    for i in range(remainder):
-        sublists[i].append(lst[last_idx + i])
-
-    # Padding to make sure all nodes are loading the same amount of
-    # files. Needed to make sure no deadlocking when the workload
-    # is distributed unevenly.
-    max_length = max([len(each) for each in sublists])
-    for each in sublists:
-        while len(each) < max_length:
-            each.append(random.choice(lst))
-
-    return sublists
-
-
-def list_files_walk(start_path='.'):
-    dataset_files = []
-    for root, _, files in os.walk(start_path):
-        for file in files:
-            dataset_files.append(os.path.join(root, file))
-    return sorted(dataset_files)
-

 def parquet_data_loader(config):
     batch_size = config.local_batch_size
     worker_id = jax.process_index()
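The removed `split_list` helper equalized per-worker file counts by padding shorter sublists with randomly chosen duplicates, which its comment notes is needed to avoid deadlocks when files are distributed unevenly. A condensed, behaviorally equivalent sketch with a small worked example (the file names here are made up):

```python
# Illustration of the removed split_list helper: 10 files split across 3
# workers. The padding elements are chosen by random.choice, so only the
# resulting lengths are deterministic.
import random


def split_list(lst, n):
    """Condensed version of the removed helper: split lst into n sublists,
    hand out the remainder round-robin, then pad every sublist to the same
    length with random elements of lst."""
    size = len(lst) // n
    sublists = [lst[i * size:(i + 1) * size] for i in range(n)]
    last_idx = n * size
    if last_idx < len(lst):
        for i in range(len(lst) - last_idx):
            sublists[i].append(lst[last_idx + i])
    max_length = max(len(each) for each in sublists)
    for each in sublists:
        while len(each) < max_length:
            each.append(random.choice(lst))
    return sublists


files = [f"shard-{i}.parquet" for i in range(10)]
parts = split_list(files, 3)
# Worker 0 gets shard-0..2 plus the leftover shard-9; workers 1 and 2 get
# their 3 shards plus one random (duplicate) shard as padding.
assert all(len(p) == 4 for p in parts)
```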