Fix mypy issues
Based on the output of mypy, fixed the following:
- Added ibm-fms to requirements
- Ran no_implicit_optional, since PEP 484 prohibits implicit Optional (see the sketch after this list)
- Excluded files that depend on untyped packages
- Fixed typing issues in dataset_utils
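
A minimal sketch of what that rewrite looks like (the function and parameter
names below are illustrative, not taken from this repository):

    from typing import List, Optional

    # Before: a None default on a bare List annotation is an implicit Optional,
    # which PEP 484 disallows and which recent mypy versions reject by default.
    # def build_buffer(task_tokens: List = None) -> None: ...

    # After running no_implicit_optional, the Optional is spelled out:
    def build_buffer(task_tokens: Optional[List] = None) -> None:
        ...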

Signed-off-by: Andrea Frittoli <[email protected]>
afrittoli committed Feb 7, 2024
1 parent 240a274 commit d81a3ce
Showing 5 changed files with 75 additions and 41 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/mypy.yml
@@ -44,7 +44,8 @@ jobs:
- name: Test with pytest
run: |
mypy pretraining
# No type stubs available for "fire" and "transformers"
mypy --exclude fms_to_hf.py --exclude main_training.py .
- name: Save Virtualenv
id: cache-venv-save
92 changes: 62 additions & 30 deletions pretraining/utils/dataset_utils.py
@@ -3,7 +3,17 @@
import math
import os
import random
from typing import Any, Callable, List, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Protocol,
runtime_checkable,
Type,
Union,
)

import pyarrow as pa
import torch
@@ -45,8 +55,8 @@ def __init__(
assert (
worldsize > rank
), f"Worldsize {worldsize} must be greater than rank {rank}"
self.state_params = []
self.reshard_params = []
self.state_params: List[str] = []
self.reshard_params: List[str] = []
self.rank = rank
self.worldsize = worldsize
self.load_worldsize = (
@@ -261,11 +271,11 @@ def __init__(self, dataset: _Stateful_Dataset, window_size: int):
super().__init__(dataset)
assert (
window_size > 1
), f"Window size {f} must be greater than 1 for shuffling to occur"
), f"Window size {window_size} must be greater than 1 for shuffling to occur"
self.window_size = window_size
self.g_state = None
self.generator = torch.Generator().manual_seed(self.rank)
self.buffer = []
self.buffer: List[str] = []
self.state_params = ["g_state"]
self.reshard_params = ["buffer"]

@@ -423,8 +433,8 @@ def __init__(
dataset: _Stateful_Dataset,
task_seq_lens: List[int],
pack_hard: bool,
task_tokens: List = None,
task_weights: List[Union[int, float]] = None,
task_tokens: Optional[List] = None,
task_weights: Optional[List[Union[int, float]]] = None,
bos_token=None,
eos_token=None,
pad_token=None,
@@ -453,7 +463,7 @@ def __init__(
self.choice = list(range(n_tasks))

# Buffer args
self.buffer = []
self.buffer: List[str] = []
self.bos = bos_token
self.eos = eos_token
self.pad = pad_token
@@ -576,8 +586,8 @@ def __init__(
delimiter_token: Any,
trainsplit: float = 1,
is_val: bool = False,
datasets: List[str] = None,
weights: List[int] = None,
datasets: Optional[List[str]] = None,
weights: Optional[List[int]] = None,
seed: int = 42,
min_length: int = 1,
testrun_data_index: int = -100,
@@ -667,7 +677,7 @@ def __init__(
# Read shardfrags:
last_shard = ""
ndocs = -1
shardset = []
shardset: List[Any] = []
for i, (shard, frag) in enumerate(shardfrags):
# On new shard, wrap up shardset
if shard != last_shard:
@@ -728,7 +738,7 @@ def __init__(
self.dataset_tokens_seen = {d: 0 for d in self.datasets}
self.dataset_docs_seen = {d: 0 for d in self.datasets}
self.dataset_percent_seen = {d: 0 for d in self.datasets}
self.docs_seen = {} # (dataset, shard, i) -> # times seen
self.docs_seen: Dict[Any, int] = {} # (dataset, shard, i) -> # times seen

self.state_params = [
"docset_index",
@@ -821,6 +831,16 @@ def load_state_dict(self, state_dicts, sharded_input=False):
return super().load_state_dict(state_dicts, sharded_input)


@runtime_checkable
class _With_Datasets(Protocol):
datasets: List[Any]


@runtime_checkable
class _With_Docset(Protocol):
docset: List[Any]


class Sampling_Dataset(_Stateful_Dataset):
"""
A _Stateful_Dataset implementing percentage-based sampling: weights can be floats and the number of tokens seen from each subdataset will match those weights as closely as possible.
@@ -908,7 +928,10 @@ def __init__(
# Build subdataset iterators
self.data = []
for i, d in enumerate(self.datasets):
self.data.append(dataset(**passthrough_args, datasets=[d]))
stateful_dataset = dataset(**passthrough_args)
if isinstance(stateful_dataset, _With_Datasets):
stateful_dataset.datasets = [d]
self.data.append(stateful_dataset)
if verbose:
logging.info(
f"Worker {rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}"
@@ -1106,7 +1129,7 @@ def __init__(

super().__init__(rank, worldsize)
self.data = []
self.docset = []
self.docset: List[Any] = []
self.n_logicals = n_logical_shards // worldsize
self.total_shards = n_logical_shards
self.delimiter = delimiter_token
@@ -1147,7 +1170,9 @@ def __init__(
)

# Fetch logical shard sampling stats
self.n_docs_remaining = [len(d.docset) for d in self.data]
self.n_docs_remaining = [
len(d.docset) if isinstance(d, _With_Docset) else 0 for d in self.data
]

# Position "state", used only for maintaining order when n_workers is unchanged
# For scaling up or down, logical position is meaningless, and reset
@@ -1260,8 +1285,8 @@ def __init__(
delimiter_token: Any,
trainsplit: float = 1,
is_val: bool = False,
datasets: List[str] = None,
weights: List[int] = None,
datasets: Optional[List[str]] = None,
weights: Optional[List[int]] = None,
seed: int = 42,
min_length: int = 1,
testrun_data_index: int = -100,
@@ -1317,8 +1342,6 @@ def __init__(
for dataset in self.datasets:
if verbose:
logging.info(f"Worker {rank} fetching dataset {dataset}")
docset = []
docset_slim = {}

# Listdir, assemble shardfraglist (ind -> shard, frag)
shards = [
@@ -1337,8 +1360,9 @@

# Read shardfrags:
last_shard = ""
docset_slim = {}
docset_slim: Dict[Any, Any] = {}
reader = None
shard_file = None
if verbose:
logging.info(
f" Worker {rank} ingesting {len(shardfrags)} shard fragments"
@@ -1350,18 +1374,26 @@
)
# Grab new reader if new
if shard != last_shard:
# Close previous file
if shard_file is not None:
shard_file.close()
path = os.path.join(datapath, dataset, shard)
reader = pa.ipc.open_file(path)
shard_file = open(path) # Closed below
reader = pa.ipc.open_file(shard_file)
last_shard = shard
shardcount += 1
ndocs = reader.num_record_batches
doc_start = (ndocs * frag) // worldsize
doc_end = (ndocs * frag + ndocs) // worldsize
# Read into temp docset_slim
for i in range(doc_start, doc_end):
doc = reader.get_batch(i)["tokens"] # .to_pylist()
docset_slim[(dataset, shardcount, i)] = doc
del reader
if reader is not None:
ndocs = reader.num_record_batches
doc_start = (ndocs * frag) // worldsize
doc_end = (ndocs * frag + ndocs) // worldsize
# Read into temp docset_slim
for i in range(doc_start, doc_end):
doc = reader.get_batch(i)["tokens"] # .to_pylist()
docset_slim[(dataset, shardcount, i)] = doc
if reader is not None:
del reader
if shard_file is not None:
shard_file.close()

# Shuffle, partition docs into train/val
keylist = list(docset_slim.keys())
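
The reworked ingestion loop above opens each shard file explicitly, reuses the
reader across consecutive fragments of the same shard, and closes the previous
file when the shard changes; the None guards keep mypy satisfied before the
first shard is opened. A condensed sketch of the same read pattern for a single
fragment (a hypothetical helper, not part of the commit):

    import pyarrow as pa

    def read_fragment(path: str, frag: int, worldsize: int) -> list:
        """Read one fragment's worth of documents from an Arrow shard file."""
        with open(path, "rb") as shard_file:
            reader = pa.ipc.open_file(shard_file)
            ndocs = reader.num_record_batches
            start = (ndocs * frag) // worldsize
            end = (ndocs * frag + ndocs) // worldsize
            return [reader.get_batch(i)["tokens"] for i in range(start, end)]

A per-fragment with-block like this would not be a drop-in replacement in the
actual loop, since the same reader stays open while several fragments of one
shard are consumed, which is why the commit tracks shard_file and closes it
manually instead.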
@@ -1399,7 +1431,7 @@ def __init__(
self.dataset_tokens_seen = {d: 0 for d in self.datasets}
self.dataset_docs_seen = {d: 0 for d in self.datasets}
self.dataset_percent_seen = {d: 0 for d in self.datasets}
self.docs_seen = {} # (dataset, shard, i) -> # times seen
self.docs_seen: Dict[Any, int] = {} # (dataset, shard, i) -> # times seen

self.state_params = [
"docset_index",
8 changes: 2 additions & 6 deletions pretraining/utils/train_utils.py
@@ -1,14 +1,10 @@
import os
from packaging import version
import time

import torch.cuda.nccl as nccl
import torch.distributed as dist

try:
import packaging.version
except ImportError:
from pkg_resources import packaging

from torch.distributed.fsdp import ShardingStrategy

from pretraining.policies import *
@@ -115,7 +111,7 @@ def get_policies(cfg, rank):
verify_bfloat_support = (
torch.version.cuda
and torch.cuda.is_bf16_supported()
and packaging.version.parse(torch.version.cuda).release >= (11, 0)
and version.parse(torch.version.cuda).release >= (11, 0)
and dist.is_nccl_available()
and nccl.version() >= (2, 10)
)
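
The replacement import uses the standalone packaging distribution directly,
and the tuple comparison works because Version.release is a tuple of integers.
A quick illustration (the CUDA version string here is hypothetical):

    from packaging import version

    cuda = version.parse("11.8")
    print(cuda.release)             # (11, 8)
    print(cuda.release >= (11, 0))  # True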
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,4 +1,5 @@
torch>=2.1.0
fire
pyarrow
transformers
fire==0.5.0
pyarrow==15.0.0
transformers==4.37.2
ibm-fms>=0.3.0
6 changes: 5 additions & 1 deletion test-requirements.txt
@@ -10,4 +10,8 @@ mypy-extensions==1.0.0

# Types packages
pyarrow-stubs==10.0.1.7
types-requests==2.31.0.20240125
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125

# Install ibm-fms from the main branch for testing purposes
ibm-fms @ git+https://github.com/foundation-model-stack/foundation-model-stack@main
