Skip to content

Commit

Permalink
Fix isort config and files
Browse files (browse the repository at this point in the history)
Add isort config based on the fms setup.
Fix imports where required. Include the root folder.

Signed-off-by: Andrea Frittoli <[email protected]>
  • Branch information
afrittoli committed Feb 9, 2024
1 parent cd6134b commit 86b43f9
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ jobs:
version: "~= 23.3.0"
- uses: isort/isort-action@master
with:
sort-paths: pretraining
sort-paths: .
requirementsFiles: "requirements.txt" # We don't need extra test requirements for linting
8 changes: 8 additions & 0 deletions .isort.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[settings]
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True
profile = black
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
# fms-pretrain
# pretrain
1 change: 1 addition & 0 deletions pretraining/policies/ac_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
checkpoint_wrapper,
)


non_reentrant_wrapper = partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
Expand Down
1 change: 1 addition & 0 deletions pretraining/policies/mixed_precision.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from torch.distributed.fsdp import MixedPrecision


fpSixteen = MixedPrecision(
param_dtype=torch.float16,
reduce_dtype=torch.float16,
Expand Down
10 changes: 1 addition & 9 deletions pretraining/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,7 @@
import os
import random
import sys
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
Union,
)
from typing import Any, Callable, Dict, List, Optional, Type, Union

import pyarrow as pa
import torch
Expand Down
5 changes: 4 additions & 1 deletion pretraining/utils/train_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os


try:
import packaging.version
except ImportError:
from pkg_resources import packaging # type: ignore

import time

import torch.cuda.nccl as nccl
import torch.distributed as dist

from packaging import version
from torch.distributed.fsdp import ShardingStrategy

from pretraining.policies import *
Expand Down
3 changes: 0 additions & 3 deletions test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,3 @@ mypy-extensions==1.0.0
pyarrow-stubs==10.0.1.7
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125

# Install ibm-fms from the main branch for testing purposes
ibm-fms @ git+https://github.com/foundation-model-stack/foundation-model-stack@main

0 comments on commit 86b43f9

Please sign in to comment.