Merge pull request #14 from foundation-model-stack/fix-minors
Fix some minor issues.
lchu6 authored Feb 16, 2024
2 parents 162128f + 021b736 commit 1cf8b1f
Showing 3 changed files with 9 additions and 14 deletions.
19 changes: 7 additions & 12 deletions fms_fsdp/config/training.py
@@ -3,15 +3,12 @@

 @dataclass
 class train_config:
-    # seed
-    seed: int = 2023
-
     # model
     model_variant: str = "7b"
     ckpt_load_path: str = "/lustre/pretrain/ckpt"
     ckpt_save_path: str = "/lustre/pretrain/ckpt"

-    # data and dataloader
+    # dataset and dataloader
     use_dummy_dataset: bool = False
     data_path: str = "/lustre/data"
     seq_length: int = 4096
@@ -20,26 +17,24 @@ class train_config:
     weights: str = "7700,500,550,28,17,22,25,8,100,500,175,250,100,25"
     logical_shards: int = 768
-
-    # compile
-    use_torch_compile: bool = False
-
-    # profiler
-    use_profiler: bool = False
-
     # fsdp policies
     mixed_precision: bool = True
     fsdp_activation_checkpointing: bool = False
     selective_checkpointing: int = 1
     sharding_strategy: str = "hsdp"
     sharding_group_size: int = 8
     low_cpu_fsdp: bool = False

     # training spec
+    seed: int = 2023
     batch_size: int = 2
     num_steps: int = 2000000
     learning_rate: float = 3e-4
     grad_clip_thresh: float = 1.0

-    # reporting
+    # profiling and reporting
+    use_profiler: bool = False
     report_interval: int = 200
     checkpoint_interval: int = 20000
+
+    # compile
+    use_torch_compile: bool = False
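
For reference, the reorganized train_config reads roughly as follows after this commit. This is a sketch assembled from the two hunks above; fields hidden by the collapsed diff context are elided rather than guessed:

from dataclasses import dataclass


@dataclass
class train_config:
    # model
    model_variant: str = "7b"
    ckpt_load_path: str = "/lustre/pretrain/ckpt"
    ckpt_save_path: str = "/lustre/pretrain/ckpt"

    # dataset and dataloader
    use_dummy_dataset: bool = False
    data_path: str = "/lustre/data"
    seq_length: int = 4096
    # ... fields hidden by the collapsed diff context ...
    weights: str = "7700,500,550,28,17,22,25,8,100,500,175,250,100,25"
    logical_shards: int = 768

    # fsdp policies
    mixed_precision: bool = True
    fsdp_activation_checkpointing: bool = False
    selective_checkpointing: int = 1
    sharding_strategy: str = "hsdp"
    sharding_group_size: int = 8
    low_cpu_fsdp: bool = False

    # training spec
    seed: int = 2023
    batch_size: int = 2
    num_steps: int = 2000000
    learning_rate: float = 3e-4
    grad_clip_thresh: float = 1.0

    # profiling and reporting
    use_profiler: bool = False
    report_interval: int = 200
    checkpoint_interval: int = 20000

    # compile
    use_torch_compile: bool = False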
2 changes: 1 addition & 1 deletion fms_fsdp/utils/config_utils.py
@@ -72,6 +72,6 @@ def get_model_config(model_variant):
             max_expected_seq_len=2048,
         )
     else:
-        raise ValueError(f"model variant {cfg.model_variant} not supported.")
+        raise ValueError(f"model variant {model_variant} not supported.")

     return llama_config
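
The one-line change above fixes a scoping bug: get_model_config takes model_variant as its parameter and has no cfg in scope, so the old error path raised NameError instead of the intended ValueError whenever an unsupported variant was passed. A minimal sketch of the failure mode, with a hypothetical stand-in for the real config builder:

def get_model_config_old(model_variant):
    if model_variant == "7b":
        return {"variant": "7b"}  # stand-in for the real llama_config
    # Bug: cfg is not defined in this function, so evaluating the f-string
    # raises NameError before the intended ValueError is ever constructed.
    raise ValueError(f"model variant {cfg.model_variant} not supported.")

try:
    get_model_config_old("13b")
except NameError as err:
    print(err)  # name 'cfg' is not defined

With the fix, the same call raises the intended ValueError naming the unsupported variant.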
2 changes: 1 addition & 1 deletion fms_fsdp/utils/train_utils.py
@@ -34,7 +34,7 @@ def train(
     start = time.time()
     loop_start = time.time()
     for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1):
-        if batch_idx == cfg.num_steps:
+        if batch_idx > cfg.num_steps:
             break
         input = input.to(local_rank)
         label = label.to(local_rank)
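
This comparison change fixes an off-by-one in the training loop: batch_idx starts at start_step + 1 and the break is checked before the step body runs, so the old == test stopped a fresh run after num_steps - 1 completed steps, and never fired at all when resuming from a checkpoint already past num_steps. A toy loop (hypothetical steps_run helper, with a finite range standing in for the real dataloader) illustrating both cases:

def steps_run(start_step, num_steps, should_stop):
    # Counts loop-body executions, mirroring the break placement in train().
    run = 0
    for batch_idx in range(start_step + 1, start_step + 100):  # stand-in loader
        if should_stop(batch_idx, num_steps):
            break
        run += 1
    return run

old = lambda i, n: i == n  # the pre-fix condition
new = lambda i, n: i > n   # the post-fix condition

print(steps_run(0, 5, old))   # 4  -- one step short of num_steps
print(steps_run(0, 5, new))   # 5  -- steps 1..num_steps all run
print(steps_run(10, 5, old))  # 99 -- == never fires; a real run would never stop
print(steps_run(10, 5, new))  # 0  -- stops immediately when already past num_steps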
