Merge branch 'dev' into bump/ci
j316chuck authored Jan 13, 2024
2 parents 06e16cb + 56fa4bd commit f6eb1c4
Showing 78 changed files with 761 additions and 395 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -110,7 +110,7 @@ repos:
types: [python]
pass_filenames: false
args: [--warnings]
additional_dependencies: ["pyright@1.1.256"]
additional_dependencies: ["pyright@1.1.310"]
- repo: https://github.com/trufflesecurity/trufflehog.git
rev: v3.40.0
hooks:
@@ -6,7 +6,8 @@
from composer.utils import MissingConditionalImportError

try:
from composer.algorithms.alibi.attention_surgery_functions import _bert, _gpt2 # pyright: reportUnusedImport=none
from composer.algorithms.alibi.attention_surgery_functions import _bert # pyright: ignore[reportUnusedImport]
from composer.algorithms.alibi.attention_surgery_functions import _gpt2 # pyright: ignore[reportUnusedImport]
from composer.algorithms.alibi.attention_surgery_functions.utils import policy_registry
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e
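The hunk above replaces a file-wide "# pyright: reportUnusedImport=none" comment with per-line "# pyright: ignore[...]" suppressions, which silence only the named rule on that line. A minimal sketch of the same pattern, using standard-library modules as stand-ins for the surgery submodules:

    # Each import exists only for its side effects, so pyright would normally
    # flag it as unused; the targeted ignore suppresses just that one rule.
    try:
        import encodings.idna  # pyright: ignore[reportUnusedImport]
        import encodings.punycode  # pyright: ignore[reportUnusedImport]
    except ImportError as e:
        raise RuntimeError('optional codec modules are unavailable') from e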
15 changes: 10 additions & 5 deletions composer/algorithms/colout/colout.py
@@ -29,10 +29,12 @@
__all__ = ['ColOut', 'ColOutTransform', 'colout_batch']


def colout_batch(sample: Union[ImgT, Tuple[ImgT, ImgT]],
p_row: float = 0.15,
p_col: float = 0.15,
resize_target: Union[bool, str] = 'auto') -> Union[ImgT, Tuple[ImgT, ImgT]]:
def colout_batch(
sample: Union[ImgT, Tuple[ImgT, ImgT]],
p_row: float = 0.15,
p_col: float = 0.15,
resize_target: Union[bool,
str] = 'auto') -> Union[torch.Tensor, ImgT, Tuple[Tensor, Tensor], Tuple[ImgT, ImgT]]:
"""Applies ColOut augmentation to a batch of images and (optionally) targets,
dropping the same random rows and columns from all images and targets in a batch.
@@ -136,7 +138,10 @@ def __init__(self, p_row: float = 0.15, p_col: float = 0.15, resize_target: Unio
self.p_col = p_col
self.resize_target = resize_target

def __call__(self, sample: Union[ImgT, Tuple[ImgT, ImgT]]) -> Union[ImgT, Tuple[ImgT, ImgT]]:
def __call__(
self, sample: Union[ImgT,
Tuple[ImgT,
ImgT]]) -> Union[torch.Tensor, ImgT, Tuple[Tensor, Tensor], Tuple[ImgT, ImgT]]:
"""Drops random rows and columns from up to two images.
Args:
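For context, a rough usage sketch of the reformatted colout_batch signature above, assuming composer is installed; the exact output size depends on which random rows and columns get dropped:

    import torch
    from composer.algorithms.colout import colout_batch

    images = torch.rand(8, 3, 64, 64)                       # a fake NCHW batch
    smaller = colout_batch(images, p_row=0.15, p_col=0.15)  # drops ~15% of rows and columns
    print(smaller.shape)                                     # e.g. torch.Size([8, 3, 55, 55])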
17 changes: 9 additions & 8 deletions composer/algorithms/factorize/factorize_modules.py
@@ -327,8 +327,8 @@ def solution_for_rank(self, input: torch.Tensor, rank: int) -> LowRankSolution:

def apply_solution(self, solution: LowRankSolution):
self.latent_size = solution.rank
self.module0.out_channels = solution.rank
self.module1.in_channels = solution.rank
self.module0.out_channels = solution.rank # pyright: ignore[reportGeneralTypeIssues]
self.module1.in_channels = solution.rank # pyright: ignore[reportGeneralTypeIssues]
_apply_solution_to_module_parameters(solution, self.module0, self.module1, transpose=False)

@staticmethod
@@ -452,8 +452,8 @@ def solution_for_rank(self, input: torch.Tensor, rank: int) -> LowRankSolution:

def apply_solution(self, solution: LowRankSolution) -> None:
self.latent_size = solution.rank
self.module0.out_features = solution.rank
self.module1.in_features = solution.rank
self.module0.out_features = solution.rank # pyright: ignore[reportGeneralTypeIssues]
self.module1.in_features = solution.rank # pyright: ignore[reportGeneralTypeIssues]
_apply_solution_to_module_parameters(solution, self.module0, self.module1, transpose=True)

@staticmethod
@@ -471,9 +471,10 @@ def max_allowed_latent_channels(in_features: int, out_features: int) -> int:

@staticmethod
def from_linear(module: torch.nn.Linear, module_ix: int = -1, **kwargs) -> FactorizedLinear:
ret = FactorizedLinear(in_features=module.in_features,
out_features=module.out_features,
bias=((module.bias is not None) and (module.bias is not False)),
**kwargs)
ret = FactorizedLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=(module.bias is not None and module.bias is not False), # pyright: ignore[reportUnnecessaryComparison]
**kwargs)
ret.reset_parameters()
return ret
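As a usage sketch of the reformatted from_linear constructor above (assuming composer is installed and the default latent size applies), the factorized module keeps the original layer's input/output shape:

    import torch
    from composer.algorithms.factorize import FactorizedLinear

    dense = torch.nn.Linear(512, 512)
    # Copies in/out features (and the bias setting) from the dense layer and
    # replaces it with two smaller matrices sharing a low-rank latent space.
    factorized = FactorizedLinear.from_linear(dense)
    print(factorized(torch.rand(4, 512)).shape)  # torch.Size([4, 512])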
17 changes: 11 additions & 6 deletions composer/algorithms/gated_linear_units/gated_linear_units.py
@@ -36,7 +36,8 @@ def from_BertOutput(layer: torch.nn.Module,
non_gated_layer_bias: bool = False) -> BERTGatedFFOutput:
"""Defines a replacement policy from a :class:`transformers.models.bert.modeling_bert.BertOutput` to a :class:`composer.algorithms.gated_linear_units.gated_linear_unit_layers.BERTGatedFFOutput`"""
assert isinstance(
layer, BertOutput
layer,
BertOutput # pyright: ignore[reportUnboundVariable]
), 'The replacement policy requires an instance of transformers.models.bert.modeling_bert.BertOutput for the necessary fields to be defined.'
return BERTGatedFFOutput(
d_embed=layer.dense.out_features, #type: ignore dense.out_features member of BertOutput
@@ -85,16 +86,20 @@ def apply_gated_linear_units(model: torch.nn.Module,
unwrapped_model = model.model if isinstance(model, HuggingFaceModel) else model

# ensure that the model is an instance of a Hugging Face BertPreTrainedModel class, since our replacement policy is only defined for BERTs
if not isinstance(unwrapped_model, BertPreTrainedModel):
if not isinstance(unwrapped_model, BertPreTrainedModel): # pyright: ignore[reportUnboundVariable]
raise TypeError(
'Gated Linear Units only has a surgery policy defined for subclasses of transformers.BertPreTrainedModel')

# Early exit if nothing to replace
if module_surgery.count_module_instances(module=model, module_class=BertIntermediate) == 0:
if module_surgery.count_module_instances(
module=model, module_class=BertIntermediate) == 0: # pyright: ignore[reportUnboundVariable]
return

if act_fn is None:
intermediate_modules = {module for module in model.modules() if isinstance(module, BertIntermediate)}
intermediate_modules = {
module for module in model.modules()
if isinstance(module, BertIntermediate) # pyright: ignore[reportUnboundVariable]
} # pyright: ignore[reportUnboundVariable]
if len(intermediate_modules) == 0:
warnings.warn(
NoEffectWarning('No instances of BertIntermediate were found so Gated Linear Units will be skipped '
@@ -130,8 +135,8 @@ def from_bound_BertOutput(layer: torch.nn.Module, module_index: int) -> BERTGate

# prepare the replacement policy and perform replacement
policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
BertIntermediate: from_BertIntermediate,
BertOutput: from_bound_BertOutput
BertIntermediate: from_BertIntermediate, # pyright: ignore[reportUnboundVariable]
BertOutput: from_bound_BertOutput # pyright: ignore[reportUnboundVariable]
}
replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
if len(replaced_instances) == 0:
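A rough end-to-end sketch of the function these hunks touch, assuming composer[nlp] and transformers are available; apply_gated_linear_units only accepts BertPreTrainedModel subclasses, per the isinstance check above:

    import composer.functional as cf
    from transformers import BertConfig, BertForMaskedLM

    # A tiny BERT so the surgery has BertIntermediate/BertOutput modules to replace.
    config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2, intermediate_size=128)
    model = BertForMaskedLM(config)
    cf.apply_gated_linear_units(model, optimizers=None)  # swaps the FFN blocks for gated (GLU) variants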
4 changes: 2 additions & 2 deletions composer/algorithms/ghost_batchnorm/ghost_batchnorm.py
@@ -152,7 +152,7 @@ def __init__(self, base_batchnorm: _TORCH_BATCHNORM_BASE_CLASS, ghost_batch_size
super().__init__()
self.ghost_batch_size = ghost_batch_size
self.batchnorm = base_batchnorm
self.batchnorm._already_ghost_batchnormed = True # Mark to avoid rewrapping on duplicate calls
self.batchnorm._already_ghost_batchnormed = True # Mark to avoid rewrapping on duplicate calls # pyright: ignore[reportGeneralTypeIssues]

def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
batch_size = input.shape[0]
@@ -161,7 +161,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
raise ValueError(f'Worker batch size {batch_size} < ghost_batch_size {self.ghost_batch_size}')

nchunks: int = int(math.ceil(batch_size / self.ghost_batch_size))
has_momentum = self.batchnorm.momentum is not None
has_momentum: bool = hasattr(self.batchnorm, 'momentum')
original_momentum: float = self.batchnorm.momentum

if self.training and has_momentum:
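These hunks adjust how the wrapper detects and rescales batchnorm momentum. The core ghost-batch idea, as an illustrative stand-alone sketch (not the composer class itself), is simply to run the same norm over smaller chunks of the batch:

    import math
    import torch

    def ghost_batchnorm_forward(bn: torch.nn.BatchNorm2d, x: torch.Tensor, ghost_batch_size: int = 32) -> torch.Tensor:
        # Normalization statistics come from each small chunk instead of the full batch;
        # the real GhostBatchNorm additionally rescales momentum across the chunks.
        nchunks = int(math.ceil(x.shape[0] / ghost_batch_size))
        return torch.cat([bn(chunk) for chunk in x.chunk(nchunks, dim=0)], dim=0)

    out = ghost_batchnorm_forward(torch.nn.BatchNorm2d(3), torch.rand(128, 3, 8, 8))
    print(out.shape)  # torch.Size([128, 3, 8, 8])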
@@ -86,8 +86,10 @@ def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, device=None
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
downcast_weight = _cast_if_autocast_enabled(
self.weight) if self.weight is not None else self.weight # pyright: ignore[reportUnnecessaryComparison]
downcast_bias = _cast_if_autocast_enabled(
self.bias) if self.bias is not None else self.bias # pyright: ignore[reportUnnecessaryComparison]
with torch.autocast(enabled=False, device_type=module_device.type):
return F.group_norm(downcast_x, self.num_groups, downcast_weight, downcast_bias, self.eps)

@@ -111,11 +113,11 @@ def _to_LPGroupNorm(layer: torch.nn.Module, module_index: int) -> LPGroupNorm:
lp_groupnorm = LPGroupNorm(layer.num_groups, layer.num_channels, layer.eps, layer.affine)

with torch.no_grad():
if layer.weight is None:
if layer.weight is None: # pyright: ignore[reportUnnecessaryComparison]
lp_groupnorm.register_parameter('weight', None)
else:
lp_groupnorm.weight.copy_(layer.weight) # type: ignore
if layer.bias is None:
if layer.bias is None: # pyright: ignore[reportUnnecessaryComparison]
lp_groupnorm.register_parameter('bias', None)
else:
lp_groupnorm.bias.copy_(layer.bias) # type: ignore
@@ -109,8 +109,10 @@ def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
downcast_weight = _cast_if_autocast_enabled(
self.weight) if self.weight is not None else self.weight # pyright: ignore[reportUnnecessaryComparison]
downcast_bias = _cast_if_autocast_enabled(
self.bias) if self.bias is not None else self.bias # pyright: ignore[reportUnnecessaryComparison]
with torch.autocast(enabled=False, device_type=module_device.type):
return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)

@@ -141,11 +143,11 @@ def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
lp_layernorm = LPLayerNorm(layer.normalized_shape, layer.eps, layer.elementwise_affine)

with torch.no_grad():
if layer.weight is None:
if hasattr(layer, 'weight'):
lp_layernorm.register_parameter('weight', None)
else:
lp_layernorm.weight.copy_(layer.weight) # type: ignore
if layer.bias is None:
if layer.bias is None: # pyright: ignore[reportUnnecessaryComparison]
lp_layernorm.register_parameter('bias', None)
else:
lp_layernorm.bias.copy_(layer.bias) # type: ignore
@@ -160,12 +162,12 @@ def _to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLa
fused_layernorm = APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)

with torch.no_grad():
if layer.weight is None:
fused_layernorm.weight = None
if layer.weight is None: # pyright: ignore[reportUnnecessaryComparison]
fused_layernorm.weight = None # pyright: ignore[reportGeneralTypeIssues]
else:
fused_layernorm.weight.copy_(layer.weight)
if layer.bias is None:
fused_layernorm.bias = None
if layer.bias is None: # pyright: ignore[reportUnnecessaryComparison]
fused_layernorm.bias = None # pyright: ignore[reportGeneralTypeIssues]
else:
fused_layernorm.bias.copy_(layer.bias)

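Both low-precision norm modules above follow the same pattern these hunks annotate: weight and bias are typed as parameters, so pyright flags the None checks as unnecessary even though affine-less norms really do carry None. An illustrative sketch of the pattern in plain torch (assumptions: not the composer classes, CUDA autocast only):

    import torch
    import torch.nn.functional as F

    class LPLayerNormSketch(torch.nn.LayerNorm):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Downcast the input and the (optional) affine parameters when autocast is
            # active, then run layer_norm itself outside of autocast in that dtype.
            if torch.is_autocast_enabled():
                dtype = torch.get_autocast_gpu_dtype()
                x = x.to(dtype)
                weight = self.weight.to(dtype) if self.weight is not None else None
                bias = self.bias.to(dtype) if self.bias is not None else None
            else:
                weight, bias = self.weight, self.bias
            with torch.autocast(enabled=False, device_type=x.device.type):
                return F.layer_norm(x, self.normalized_shape, weight, bias, self.eps)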
8 changes: 4 additions & 4 deletions composer/algorithms/sam/sam.py
@@ -50,7 +50,7 @@ def __init__(self,
defaults = {'rho': rho, 'epsilon': epsilon, **kwargs}
super(SAMOptimizer, self).__init__(self.base_optimizer.param_groups, defaults)

@torch.no_grad()
@torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def sub_e_w(self):
for group in self.param_groups:
for p in group['params']:
@@ -59,7 +59,7 @@ def sub_e_w(self):
e_w = self.state[p]['e_w'] # retrieve stale e(w)
p.sub_(e_w) # get back to "w" from "w + e(w)"

@torch.no_grad()
@torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def first_step(self):
grad_norm = self._grad_norm()
for group in self.param_groups:
@@ -71,7 +71,7 @@ def first_step(self):
p.add_(e_w) # climb to the local maximum "w + e(w)"
self.state[p]['e_w'] = e_w

@torch.no_grad()
@torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def second_step(self):
for group in self.param_groups:
for p in group['params']:
@@ -80,7 +80,7 @@ def second_step(self):
p.sub_(self.state[p]['e_w']) # get back to "w" from "w + e(w)"
self.base_optimizer.step() # do the actual "sharpness-aware" update

@torch.no_grad()
@torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator]
def step(self, closure=None):
assert closure is not None, 'Sharpness Aware Minimization requires closure, but it was not provided'
closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass
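The pyright ignores above attach to torch.no_grad() used as a decorator. For context on why step() asserts that a closure is provided: SAM has to re-evaluate the loss at the perturbed weights before taking the real update, so the optimizer's step receives a closure that redoes the forward/backward pass. A minimal sketch of that closure contract with a plain SGD optimizer (composer's SAMOptimizer wraps a base optimizer and uses the same contract):

    import torch

    model = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    x, y = torch.rand(8, 4), torch.rand(8, 1)

    def closure():
        # Re-runs forward/backward so the optimizer can evaluate the loss at
        # whatever point in weight space it has moved the parameters to.
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    loss = opt.step(closure)
    print(float(loss))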
2 changes: 1 addition & 1 deletion composer/algorithms/squeeze_excite/squeeze_excite.py
@@ -109,7 +109,7 @@ class SqueezeExciteConv2d(torch.nn.Module):
def __init__(self, *args, latent_channels: float = 0.125, conv: Optional[torch.nn.Conv2d] = None, **kwargs):
super().__init__()
self.conv = torch.nn.Conv2d(*args, **kwargs) if conv is None else conv
self.conv._already_squeeze_excited = True # Mark to avoid rewrapping on duplicate calls
self.conv._already_squeeze_excited = True # Mark to avoid rewrapping on duplicate calls # pyright: ignore[reportGeneralTypeIssues]
self.se = SqueezeExcite2d(num_features=self.conv.out_channels, latent_channels=latent_channels)

def forward(self, input: torch.Tensor) -> torch.Tensor:
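The annotated line sets a marker attribute on the wrapped conv so module surgery skips modules it has already processed; pyright flags the attribute because Conv2d does not declare it. An illustrative sketch of that guard pattern (the SE block here is a stand-in, not composer's implementation):

    import torch

    def wrap_with_se(conv: torch.nn.Conv2d) -> torch.nn.Module:
        if getattr(conv, '_already_squeeze_excited', False):
            return conv  # already wrapped on a previous surgery pass
        conv._already_squeeze_excited = True  # the dynamically added marker pyright complains about
        return torch.nn.Sequential(conv, torch.nn.Identity())  # Identity stands in for the SE block

    conv = torch.nn.Conv2d(3, 8, kernel_size=3)
    first = wrap_with_se(conv)   # returns the wrapper
    second = wrap_with_se(conv)  # returns the already-marked conv untouched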
4 changes: 2 additions & 2 deletions composer/callbacks/activation_monitor.py
@@ -146,8 +146,8 @@ def register_forward_hook(self, model: torch.nn.Module, logger: Logger, step: Op
def _register_forward_hook(self, logger: Logger, step: Optional[int], module: torch.nn.Module):
self.handles.append(module.register_forward_hook(partial(self.forward_hook, logger, step)))

def forward_hook(self, logger: Logger, step: Optional[int], module: torch.nn.Module, input: Sequence,
output: Sequence):
def forward_hook(self, logger: Logger, step: Optional[int], module: torch.nn.Module, input: Optional[Sequence],
output: Optional[Sequence]):
module_name = self.module_names[module]

if self.ignore_module_types is not None:
36 changes: 30 additions & 6 deletions composer/callbacks/checkpoint_saver.py
@@ -15,12 +15,12 @@
from typing import Any, Callable, Dict, List, Optional, Union

from composer.core import Callback, Event, State, Time, Timestamp
from composer.loggers import Logger
from composer.loggers import Logger, MLFlowLogger
from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
checkpoint, create_interval_scheduler, create_symlink_file, dist,
ensure_folder_has_no_conflicting_files, format_name_with_dist,
format_name_with_dist_and_time, is_model_deepspeed, using_torch_2)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
format_name_with_dist_and_time, is_model_deepspeed, partial_format, using_torch_2)
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

log = logging.getLogger(__name__)

@@ -271,6 +271,30 @@ def __init__(
self.start_batch = None

def init(self, state: State, logger: Logger) -> None:
# If MLFlowLogger is being used, format MLFlow-specific placeholders in the save folder and paths.
# Assumes that MLFlowLogger comes before CheckpointSaver in the list of loggers.
for destination in logger.destinations:
if isinstance(destination, MLFlowLogger):
mlflow_format_kwargs = {
MLFLOW_EXPERIMENT_ID_FORMAT_KEY: destination._experiment_id,
MLFLOW_RUN_ID_FORMAT_KEY: destination._run_id
}
self.folder = partial_format(self.folder, **mlflow_format_kwargs)

self.filename.folder = self.folder
if self.latest_filename is not None:
self.latest_filename.folder = self.folder

# The remote paths have the placeholders in their filename rather than folder
if self.remote_file_name is not None:
self.remote_file_name.filename = partial_format(self.remote_file_name.filename,
**mlflow_format_kwargs)
if self.latest_remote_file_name is not None:
self.latest_remote_file_name.filename = partial_format(self.latest_remote_file_name.filename,
**mlflow_format_kwargs)

break

folder = format_name_with_dist(self.folder, state.run_name)
os.makedirs(folder, exist_ok=True)

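The new init() above pre-fills the MLflow placeholders while leaving the usual run-name and timestamp placeholders for the later formatting passes. A rough stand-in for what partial_format does (the placeholder names below are illustrative, not copied from the repo):

    def partial_format_sketch(template: str, **known: str) -> str:
        # Substitute only the keys we already know; unknown placeholders survive
        # untouched so later format_name_with_dist* calls can fill them in.
        out = template
        for key, value in known.items():
            out = out.replace('{' + key + '}', value)
        return out

    path = 'runs/{mlflow_experiment_id}/{mlflow_run_id}/ep{epoch}-ba{batch}.pt'
    print(partial_format_sketch(path, mlflow_experiment_id='exp123', mlflow_run_id='run456'))
    # -> runs/exp123/run456/ep{epoch}-ba{batch}.pt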
@@ -335,9 +359,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
# Store before saving so state_dict in checkpoint has reference to latest checkpoint (itself)
self.all_saved_checkpoints_to_timestamp[save_filename] = state.timestamp

saved_path = checkpoint._save_checkpoint(
saved_path = checkpoint.save_checkpoint(
state=state,
save_filename=save_filename,
filename=filename_with_placeholders,
weights_only=self.weights_only,
)
log.debug(f'Checkpoint locally saved to {saved_path}')
@@ -377,7 +401,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
).lstrip('/')
assert state.sharded_ckpt_prefix_dir is not None
remote_prefix = state.sharded_ckpt_prefix_dir
ckpt_filename = _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME if using_torch_2() else pathlib.Path(
ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME if using_torch_2() else pathlib.Path(
remote_file_name).name
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp)
Expand Down
2 changes: 1 addition & 1 deletion composer/callbacks/early_stopper.py
@@ -107,7 +107,7 @@ def __init__(
raise ValueError('If `patience` is an instance of Time, it must have units of EPOCH or BATCH.')

def _get_monitored_metric(self, state: State):
if self.dataloader_label == 'train':
if self.dataloader_label == 'train' and state.train_metrics is not None:
if self.monitor in state.train_metrics:
return state.train_metrics[self.monitor].compute()
else: