set adamw_mode default true (follows FusedAdam and < 0.3.11 logic) #844

Merged · 4 commits · Mar 11, 2021
2 changes: 1 addition & 1 deletion deepspeed/ops/adam/cpu_adam.py
@@ -74,7 +74,7 @@ def __init__(self,

self.opt_id = DeepSpeedCPUAdam.optimizer_id
DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1

self.adam_w_mode = adamw_mode
self.ds_opt_adam = CPUAdamBuilder().load()

self.ds_opt_adam.create_adam(self.opt_id,
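Storing the flag as `self.adam_w_mode` makes the selected mode inspectable on the constructed optimizer (the tests added below rely on this). A minimal usage sketch, assuming the CPU-Adam op can be built in the environment; the parameter values are placeholders:

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Toy parameter list; any iterable of trainable tensors works here.
params = [torch.nn.Parameter(torch.zeros(4))]

# adamw_mode=True selects decoupled (AdamW-style) weight decay; False gives plain Adam.
opt = DeepSpeedCPUAdam(params, lr=1e-3, weight_decay=0.01, adamw_mode=True)

# The attribute added in this diff exposes the chosen mode after construction.
assert opt.adam_w_mode is True
```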
4 changes: 4 additions & 0 deletions deepspeed/runtime/config.py
@@ -40,6 +40,10 @@
# extra optimizer parameters for adam/adamw
TORCH_ADAM_PARAM = "torch_adam"

# default to adamw logic for adam/adamw optimizers unless user explicitly opts out
ADAM_W_MODE = "adam_w_mode"
ADAM_W_MODE_DEFAULT = True


class DeepSpeedConfigError(Exception):
pass
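With `ADAM_W_MODE_DEFAULT = True`, the `adam_w_mode` key only needs to appear in a user config to opt out of AdamW behavior. A hedged sketch of the optimizer section, mirroring the key placement used in the test below (values are placeholders):

```python
# DeepSpeed config fragment: "adam_w_mode" now defaults to True, so Adam
# follows AdamW (decoupled weight decay) logic unless explicitly disabled.
ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015,
            "adam_w_mode": False  # explicit opt-out; omit to keep the new default
        }
    }
}
```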
38 changes: 21 additions & 17 deletions deepspeed/runtime/engine.py
@@ -22,7 +22,7 @@
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
TORCH_ADAM_PARAM
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT

from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
@@ -640,26 +640,30 @@ def _configure_basic_optimizer(self, model_parameters):

if self.optimizer_name() in [ADAM_OPTIMIZER, ADAMW_OPTIMIZER]:
torch_adam = optimizer_parameters.pop(TORCH_ADAM_PARAM, False)
adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER
# zero-offload  torch-adam  adam_w_mode  optimizer
# T|F           T           T            torch.optim.AdamW
# T|F           T           F            torch.optim.Adam
# T             F           T|F          DeepSpeedCPUAdam(adam_w_mode)
# F             F           T|F          FusedAdam(adam_w_mode)
adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT)

# Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set
effective_adam_w_mode = self.optimizer_name(
) == ADAMW_OPTIMIZER or adam_w_mode

if torch_adam:
if adam_w_mode:
optimizer = torch.optim.AdamW(model_parameters,
**optimizer_parameters)
else:
if not effective_adam_w_mode:
optimizer = torch.optim.Adam(model_parameters,
**optimizer_parameters)
elif self.zero_cpu_offload():
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=adam_w_mode)
else:
optimizer = torch.optim.AdamW(model_parameters,
**optimizer_parameters)
else:
optimizer_parameters['adam_w_mode'] = adam_w_mode
optimizer = FusedAdam(model_parameters, **optimizer_parameters)
if self.zero_cpu_offload():
from deepspeed.ops.adam import DeepSpeedCPUAdam
optimizer = DeepSpeedCPUAdam(model_parameters,
**optimizer_parameters,
adamw_mode=effective_adam_w_mode)
else:
from deepspeed.ops.adam import FusedAdam
optimizer = FusedAdam(model_parameters,
**optimizer_parameters,
adam_w_mode=effective_adam_w_mode)

elif self.optimizer_name() == LAMB_OPTIMIZER:
from deepspeed.ops.lamb import FusedLamb
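The branching above reduces to the decision table in the added comment plus the `effective_adam_w_mode` rule. A self-contained sketch for illustration only, returning as a string the optimizer that `_configure_basic_optimizer` would pick; it is not DeepSpeed API:

```python
def select_adam_optimizer(optimizer_name: str,
                          torch_adam: bool,
                          adam_w_mode: bool,
                          zero_cpu_offload: bool) -> str:
    """Mirror of the zero-offload / torch-adam / adam_w_mode table in the diff."""
    # The "AdamW" name forces AdamW logic; "Adam" follows adam_w_mode (default True).
    effective_adam_w_mode = optimizer_name == "AdamW" or adam_w_mode

    if torch_adam:
        # Plain PyTorch optimizers, chosen by the effective mode.
        return "torch.optim.AdamW" if effective_adam_w_mode else "torch.optim.Adam"
    if zero_cpu_offload:
        # ZeRO CPU offload always uses DeepSpeedCPUAdam, mode passed through.
        return f"DeepSpeedCPUAdam(adamw_mode={effective_adam_w_mode})"
    return f"FusedAdam(adam_w_mode={effective_adam_w_mode})"


# "Adam" under the new default now behaves like AdamW on FusedAdam:
print(select_adam_optimizer("Adam", torch_adam=False, adam_w_mode=True,
                            zero_cpu_offload=False))
# FusedAdam(adam_w_mode=True)
```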
73 changes: 73 additions & 0 deletions tests/unit/test_adamw.py
@@ -0,0 +1,73 @@
import deepspeed
import torch
import pytest

from common import distributed_test
from deepspeed.ops.adam import FusedAdam
from deepspeed.ops.adam import DeepSpeedCPUAdam
from simple_model import SimpleModel, args_from_dict

# yapf: disable
# optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer
adam_configs = [["AdamW", False, False, False, (FusedAdam, True)],
["AdamW", False, True, False, (torch.optim.AdamW, None)],
["AdamW", True, False, False, (DeepSpeedCPUAdam, True)],
["AdamW", True, True, False, (torch.optim.AdamW, None)],
["AdamW", False, False, True, (FusedAdam, True)],
["AdamW", False, True, True, (torch.optim.AdamW, None)],
["AdamW", True, False, True, (DeepSpeedCPUAdam, True)],
["AdamW", True, True, True, (torch.optim.AdamW, None)],
["Adam", False, False, False, (FusedAdam, False)],
["Adam", False, True, False, (torch.optim.Adam, None)],
["Adam", True, False, False, (DeepSpeedCPUAdam, False)],
["Adam", True, True, False, (torch.optim.Adam, None)],
["Adam", False, False, True, (FusedAdam, True)],
["Adam", False, True, True, (torch.optim.AdamW, None)],
["Adam", True, False, True, (DeepSpeedCPUAdam, True)],
["Adam", True, True, True, (torch.optim.AdamW, None)]]

@pytest.mark.parametrize(
'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer',
adam_configs)
def test_adam_configs(tmpdir,
optimizer,
zero_offload,
torch_adam,
adam_w_mode,
resulting_optimizer):
config_dict = {
"train_batch_size": 2,
"steps_per_print": 1,
"optimizer": {
"type": optimizer,
"params": {
"lr": 0.00015,
"torch_adam": torch_adam,
"adam_w_mode": adam_w_mode
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": True
},
"zero_optimization": {
"stage": 2,
"cpu_offload": zero_offload
}
}
args = args_from_dict(tmpdir, config_dict)

@distributed_test(world_size=[1])
def helper(args):
model = SimpleModel(10)
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
# get base optimizer under zero
ds_optimizer = model.optimizer.optimizer
opt_class, adam_w_mode = resulting_optimizer
assert isinstance(ds_optimizer, opt_class)
if adam_w_mode in [True, False]:
assert ds_optimizer.adam_w_mode == adam_w_mode

helper(args)