Update DeepSpeed Zero Stage option to a separate option group (#4772)
Thiago Crepaldi committed Aug 15, 2020
1 parent b85c770 commit 5191ac4
Showing 5 changed files with 56 additions and 32 deletions.
7 changes: 5 additions & 2 deletions orttraining/orttraining/python/training/_utils.py
@@ -161,14 +161,17 @@ def decorate(func):
     return decorate
 
 
-def import_module_from_file(file_path, module_name):
+def import_module_from_file(file_path, module_name=None):
     '''Import a Python module from a file into interpreter'''
 
     assert isinstance(file_path, str) and os.path.exists(file_path),\
         "'file_path' must be a full path string with the python file to load"
-    assert isinstance(module_name, str) and module_name,\
+    assert module_name is None or isinstance(module_name, str) and module_name,\
         "'module_name' must be a string with the python module name to load"
 
+    if not module_name:
+        module_name = os.path.basename(file_path).split('.')[0]
+
     spec = importlib.util.spec_from_file_location(module_name, file_path)
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
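
With the new default, callers can omit the module name and let it be derived from the file's basename. A minimal usage sketch of the updated helper (the 'pt_model.py' path is hypothetical, and the import assumes the packaged module name onnxruntime.training._utils):

    # Sketch only: 'pt_model.py' is a hypothetical file path
    from onnxruntime.training import _utils

    # Before this change the module name had to be passed explicitly
    pt_model = _utils.import_module_from_file('pt_model.py', 'pt_model')

    # After this change it defaults to the file's basename, i.e. 'pt_model'
    pt_model = _utils.import_module_from_file('pt_model.py')
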
40 changes: 20 additions & 20 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -218,7 +218,7 @@ def eval_step(self, *args, **kwargs):
                 raise RuntimeError("Model is uninitialized. Only ONNX and PyTorch models are supported")
 
         # Prepare input/output description
-        input_desc = self._model_desc_inputs_with_lr
+        input_desc = self.model_desc.inputs
         output_desc = self.model_desc.outputs
 
         # Normalize input
@@ -236,8 +236,10 @@ def eval_step(self, *args, **kwargs):
                                                        input_desc,
                                                        output_desc,
                                                        run_options)
-        return session_run_results[output_desc.name][0] if len (session_run_results) == 1\
-            else [session_run_results[output_desc.name] for output_desc in output_desc]
+
+        # Output must be returned in the same order as defined in the model description
+        results = [session_run_results[output_desc.name] for output_desc in self.model_desc.outputs]
+        return results[0] if len (results) == 1 else results
 
     def save_as_onnx(self, path):
         r"""Persists ONNX model into :py:attr:`path`
@@ -361,7 +363,7 @@ def _combine_torch_model_with_loss_fn(self):
                             "loss function should take two arguments - predict and label.")
 
         # Basic input names from model
-        input_names = [input[0] for input in self.model_desc.inputs]
+        input_names = [input.name for input in self.model_desc.inputs]
         sig = signature(self._torch_model.forward)
         ordered_input_list = list(sig.parameters.keys())
 
@@ -415,7 +417,6 @@ def forward(self, *inputs):
         return CombineTorchModelLossFn(self._torch_model, self.loss_fn, input_names)
 
     def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):
-
         # Dynamic axes
         dynamic_axes = {}
         for input in self.model_desc.inputs:
@@ -467,8 +468,8 @@ def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):
         # Export the model to ONNX
         f = io.BytesIO()
         torch.onnx._export(model, tuple(sample_inputs), f,
-                           input_names=[input[0] for input in self.model_desc.inputs],
-                           output_names=[output[0] for output in self.model_desc.outputs],
+                           input_names=[input.name for input in self.model_desc.inputs],
+                           output_names=[output.name for output in self.model_desc.outputs],
                            opset_version=self.options._internal_use.onnx_opset_version,
                            dynamic_axes=dynamic_axes,
                            _retain_param_name=True,
@@ -498,8 +499,6 @@ def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):
 
         return onnx_model
 
-    # TODO: Test this througly along with train step, including
-    # various optimizer parameter groups, frozen weights, loss and lr
     def _create_ort_training_session(self):
         # Validating frozen_weights names
         unused_frozen_weights = [n for n in self.options.utils.frozen_weights\
@@ -544,9 +543,10 @@ def _create_ort_training_session(self):
         ort_parameters.world_size = self.options.distributed.world_size
         ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps
         ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
-        ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_stage
+        ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
         ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
         ort_parameters.set_gradients_as_graph_outputs = False
+        ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
         ort_parameters.training_optimizer_name = self.optim_config.name
         ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
         ort_parameters.weights_to_train = trainable_params
@@ -562,14 +562,6 @@ def _create_ort_training_session(self):
                                                        ort_parameters,
                                                        session_options)
 
-        # Update model description to update dtype when mixed precision is enabled
-        # C++ backend modifies model's output dtype from float32 to float16 for mixed precision
-        # Note that for training we must use float32 and for evaluation we must use float16
-        for idx, o_desc in enumerate(self.model_desc.outputs):
-            if (self.options.mixed_precision.enabled and o_desc.dtype == torch.float32 and
-                    not self._training_session.is_output_fp32_node(o_desc.name)):
-                self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16)
-
         # I/O bindings
         self._train_io_binding = self._training_session.io_binding()
         self._eval_io_binding = self._training_session.io_binding()
@@ -604,6 +596,14 @@ def _init_session(self):
         # Create training session used by train_step
        self._create_ort_training_session()
 
+        # Update model description to update dtype when mixed precision is enabled
+        # C++ backend modifies model's output dtype from float32 to float16 for mixed precision
+        # Note that for training we must use float32 and for evaluation we must use float16
+        for idx, o_desc in enumerate(self.model_desc.outputs):
+            if (self.options.mixed_precision.enabled and o_desc.dtype == torch.float32 and
+                    not self._training_session.is_output_fp32_node(o_desc.name)):
+                self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16)
+
         # Update model description
         self._model_desc_inputs_with_lr = [*self.model_desc.inputs, self.model_desc.learning_rate]
 
@@ -638,8 +638,8 @@ def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs):
 
         # Append input from 'kwargs'
         for input_desc in inputs_desc:
-            if input_desc[0] in kwargs:
-                input = input + (kwargs[input_desc[0]],)
+            if input_desc.name in kwargs:
+                input = input + (kwargs[input_desc.name],)
 
         # Append learning rate
         extra_inputs = 0
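
The reworked return path in eval_step means outputs now always come back in the order given by the model description. A hedged sketch of the resulting calling convention, assuming trainer is an already-configured ORTTrainer and data/targets are valid inputs for its model:

    # Sketch only: trainer, data and targets are assumed to exist
    results = trainer.eval_step(data, targets)

    if len(trainer.model_desc.outputs) == 1:
        # A single described output is returned directly, not wrapped in a list
        loss = results
    else:
        # Multiple outputs are returned as a list ordered like model_desc.outputs
        loss, preds = results[0], results[1]
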
27 changes: 22 additions & 5 deletions orttraining/orttraining/python/training/orttrainer_options.py
@@ -121,6 +121,10 @@ class ORTTrainerOptions(object):
                 'grad_norm_clip' : {
                     'type' : 'boolean',
                     'default' : True
+                },
+                'invertible_layer_norm_gradient' : {
+                    'type' : 'boolean',
+                    'default' : False
                 }
             }
         },
@@ -202,6 +206,8 @@ class ORTTrainerOptions(object):
                list of model parameter names to skip training (weights don't change)
            utils.grad_norm_clip (bool, default is True):
                enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
+           utils.invertible_layer_norm_gradient (bool, default is False):
+               enables use of invertible layer norm gradients
            debug (dict):
                debug options
            debug.deterministic_compute (bool, default is False)
@@ -359,11 +365,18 @@ def _check_is_callable(field, value, error):
                 'type': 'boolean',
                 'default': False
             },
-            'deepspeed_zero_stage' : {
-                'type' : 'integer',
-                'min' : 0,
-                'max' : 1,
-                'default' : 0,
+            'deepspeed_zero_optimization' : {
+                'type' : 'dict',
+                'default_setter': lambda _: {},
+                'required': False,
+                'schema': {
+                    'stage': {
+                        'type': 'integer',
+                        'min': 0,
+                        'max': 1,
+                        'default': 0
+                    },
+                }
             },
             'enable_adasum': {
                 'type': 'boolean',
@@ -405,6 +418,10 @@ def _check_is_callable(field, value, error):
             'grad_norm_clip': {
                 'type': 'boolean',
                 'default': True
+            },
+            'invertible_layer_norm_gradient' : {
+                'type': 'boolean',
+                'default': False
             }
         }
     },
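
After the schema change, ZeRO is configured through a nested option group rather than a flat integer. A sketch of a user-facing configuration, assuming ORTTrainerOptions is importable from onnxruntime.training as in this experimental frontend:

    from onnxruntime.training import ORTTrainerOptions

    # 'deepspeed_zero_stage' becomes 'deepspeed_zero_optimization.stage';
    # 'invertible_layer_norm_gradient' is the new utils flag added by this commit
    opts = ORTTrainerOptions({
        'distributed': {
            'deepspeed_zero_optimization': {
                'stage': 1
            }
        },
        'utils': {
            'invertible_layer_norm_gradient': True
        }
    })

    # Validated options are read back with attribute access, as orttrainer.py does above
    assert opts.distributed.deepspeed_zero_optimization.stage == 1
    assert opts.utils.invertible_layer_norm_gradient
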
@@ -26,11 +26,11 @@ def _load_pytorch_transformer_model(device, legacy_api=False):
     # Loads external Pytorch TransformerModel into utils
     pytorch_transformer_path = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer')
     pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py')
-    pt_model = _utils.import_module_from_file(pt_model_path, 'pt_model')
+    pt_model = _utils.import_module_from_file(pt_model_path)
     ort_utils_path = os.path.join(pytorch_transformer_path, 'ort_utils.py')
-    ort_utils = _utils.import_module_from_file(ort_utils_path, 'ort_utils')
+    ort_utils = _utils.import_module_from_file(ort_utils_path)
     utils_path = os.path.join(pytorch_transformer_path, 'utils.py')
-    utils = _utils.import_module_from_file(utils_path, 'utils')
+    utils = _utils.import_module_from_file(utils_path)
 
     # Modeling
     model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
@@ -75,7 +75,9 @@ def testORTTrainerOptionsDefaultValues(test_input):
                 'world_size': 1,
                 'local_rank': 0,
                 'allreduce_post_accumulation': False,
-                'deepspeed_zero_stage': 0,
+                'deepspeed_zero_optimization': {
+                    'stage' : 0,
+                },
                 'enable_adasum': False
             },
             'lr_scheduler': None,
@@ -85,7 +87,8 @@ def testORTTrainerOptionsDefaultValues(test_input):
             },
             'utils': {
                 'frozen_weights': [],
-                'grad_norm_clip': True
+                'grad_norm_clip': True,
+                'invertible_layer_norm_gradient': False,
             },
             'debug': {
                 'deterministic_compute': False
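
The expected defaults above imply the following behavior for an empty configuration; a small hedged check mirroring (but not reproducing) the test, under the same import assumption as the previous sketch:

    from onnxruntime.training import ORTTrainerOptions

    default_opts = ORTTrainerOptions({})
    assert default_opts.distributed.deepspeed_zero_optimization.stage == 0
    assert default_opts.utils.grad_norm_clip
    assert not default_opts.utils.invertible_layer_norm_gradient
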
1 change: 1 addition & 0 deletions samples/python/pytorch_transformer/ort_utils.py
@@ -1,4 +1,5 @@
 import torch
+
 from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
                                          ModelDescription as Legacy_ModelDescription
 
