Update DeepSpeed Zero Stage option from integer to a separate option group #4772

Merged
7 changes: 5 additions & 2 deletions orttraining/orttraining/python/training/_utils.py
@@ -161,14 +161,17 @@ def decorate(func):
return decorate


def import_module_from_file(file_path, module_name):
def import_module_from_file(file_path, module_name=None):
'''Import a Python module from a file into interpreter'''

assert isinstance(file_path, str) and os.path.exists(file_path),\
"'file_path' must be a full path string with the python file to load"
assert isinstance(module_name, str) and module_name,\
assert module_name is None or isinstance(module_name, str) and module_name,\
"'module_name' must be a string with the python module name to load"

if not module_name:
module_name = os.path.basename(file_path).split('.')[0]

spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
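As a usage sketch (the file path is illustrative, and the `onnxruntime.training._utils` import path is assumed from this repo's package layout), the helper can now be called without an explicit module name, which falls back to the file's basename:

```python
import os
from onnxruntime.training import _utils  # assumed import path for this helper

pt_model_path = os.path.join('samples', 'python', 'pytorch_transformer', 'pt_model.py')

# Explicit module name, as before.
pt_model = _utils.import_module_from_file(pt_model_path, 'pt_model')

# Module name omitted: derived from the file basename, i.e. 'pt_model'.
pt_model = _utils.import_module_from_file(pt_model_path)
```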
40 changes: 20 additions & 20 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -218,7 +218,7 @@ def eval_step(self, *args, **kwargs):
raise RuntimeError("Model is uninitialized. Only ONNX and PyTorch models are supported")

# Prepare input/output description
input_desc = self._model_desc_inputs_with_lr
input_desc = self.model_desc.inputs
output_desc = self.model_desc.outputs

# Normalize input
@@ -236,8 +236,10 @@ def eval_step(self, *args, **kwargs):
input_desc,
output_desc,
run_options)
return session_run_results[output_desc.name][0] if len (session_run_results) == 1\
else [session_run_results[output_desc.name] for output_desc in output_desc]

# Output must be returned in the same order as defined in the model description
results = [session_run_results[output_desc.name] for output_desc in self.model_desc.outputs]
return results[0] if len (results) == 1 else results

def save_as_onnx(self, path):
r"""Persists ONNX model into :py:attr:`path`
@@ -361,7 +363,7 @@ def _combine_torch_model_with_loss_fn(self):
"loss function should take two arguments - predict and label.")

# Basic input names from model
input_names = [input[0] for input in self.model_desc.inputs]
input_names = [input.name for input in self.model_desc.inputs]
sig = signature(self._torch_model.forward)
ordered_input_list = list(sig.parameters.keys())

@@ -415,7 +417,6 @@ def forward(self, *inputs):
return CombineTorchModelLossFn(self._torch_model, self.loss_fn, input_names)

def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):

# Dynamic axes
dynamic_axes = {}
for input in self.model_desc.inputs:
@@ -467,8 +468,8 @@ def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):
# Export the model to ONNX
f = io.BytesIO()
torch.onnx._export(model, tuple(sample_inputs), f,
input_names=[input[0] for input in self.model_desc.inputs],
output_names=[output[0] for output in self.model_desc.outputs],
input_names=[input.name for input in self.model_desc.inputs],
output_names=[output.name for output in self.model_desc.outputs],
opset_version=self.options._internal_use.onnx_opset_version,
dynamic_axes=dynamic_axes,
_retain_param_name=True,
@@ -498,8 +499,6 @@ def _convert_torch_model_loss_fn_to_onnx(self, inputs, device):

return onnx_model

# TODO: Test this througly along with train step, including
# various optimizer parameter groups, frozen weights, loss and lr
def _create_ort_training_session(self):
# Validating frozen_weights names
unused_frozen_weights = [n for n in self.options.utils.frozen_weights\
@@ -544,9 +543,10 @@ def _create_ort_training_session(self):
ort_parameters.world_size = self.options.distributed.world_size
ort_parameters.gradient_accumulation_steps = self.options.batch.gradient_accumulation_steps
ort_parameters.allreduce_post_accumulation = self.options.distributed.allreduce_post_accumulation
ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_stage
ort_parameters.deepspeed_zero_stage = self.options.distributed.deepspeed_zero_optimization.stage
ort_parameters.enable_grad_norm_clip = self.options.utils.grad_norm_clip
ort_parameters.set_gradients_as_graph_outputs = False
ort_parameters.use_invertible_layernorm_grad = self.options.utils.invertible_layer_norm_gradient
ort_parameters.training_optimizer_name = self.optim_config.name
ort_parameters.lr_params_feed_name = self.model_desc.learning_rate.name
ort_parameters.weights_to_train = trainable_params
@@ -562,14 +562,6 @@ def _create_ort_training_session(self):
ort_parameters,
session_options)

# Update model description to update dtype when mixed precision is enabled
# C++ backend modifies model's output dtype from float32 to float16 for mixed precision
# Note that for training we must use float32 and for evaluation we must use float16
for idx, o_desc in enumerate(self.model_desc.outputs):
if (self.options.mixed_precision.enabled and o_desc.dtype == torch.float32 and
not self._training_session.is_output_fp32_node(o_desc.name)):
self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16)

# I/O bindings
self._train_io_binding = self._training_session.io_binding()
self._eval_io_binding = self._training_session.io_binding()
@@ -604,6 +596,14 @@ def _init_session(self):
# Create training session used by train_step
self._create_ort_training_session()

# Update model description to update dtype when mixed precision is enabled
# C++ backend modifies model's output dtype from float32 to float16 for mixed precision
# Note that for training we must use float32 and for evaluation we must use float16
for idx, o_desc in enumerate(self.model_desc.outputs):
if (self.options.mixed_precision.enabled and o_desc.dtype == torch.float32 and
not self._training_session.is_output_fp32_node(o_desc.name)):
self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16)

# Update model description
self._model_desc_inputs_with_lr = [*self.model_desc.inputs, self.model_desc.learning_rate]

@@ -638,8 +638,8 @@ def _prepare_model_input(self, inputs_desc, lr, loss_scale, *inputs, **kwargs):

# Append input from 'kwargs'
for input_desc in inputs_desc:
if input_desc[0] in kwargs:
input = input + (kwargs[input_desc[0]],)
if input_desc.name in kwargs:
input = input + (kwargs[input_desc.name],)

# Append learning rate
extra_inputs = 0
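To illustrate the `eval_step` change (the model description and variable names below are hypothetical, not part of this diff): outputs now come back in the order they are declared in the model description, so a multi-output model can be unpacked positionally.

```python
# Hypothetical two-output model description for the new ORTTrainer frontend.
model_desc = {
    'inputs': [('x', ['batch_size', 28 * 28])],
    'outputs': [('loss', [], True),             # first output: scalar loss
                ('probs', ['batch_size', 10])]  # second output: predictions
}

# With a trainer built from this description (construction not shown here),
# eval_step returns a single value when there is one output, otherwise a list
# ordered exactly like model_desc['outputs'], so positional unpacking is safe:
#
#   loss, probs = trainer.eval_step(x_batch)
```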
27 changes: 22 additions & 5 deletions orttraining/orttraining/python/training/orttrainer_options.py
@@ -121,6 +121,10 @@ class ORTTrainerOptions(object):
'grad_norm_clip' : {
'type' : 'boolean',
'default' : True
},
'invertible_layer_norm_gradient' : {
'type' : 'boolean',
'default' : False
}
}
},
Expand Down Expand Up @@ -202,6 +206,8 @@ class ORTTrainerOptions(object):
list of model parameter names to skip training (weights don't change)
utils.grad_norm_clip (bool, default is True):
enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
utils.invertible_layer_norm_gradient (bool, default is False):
enables use of invertible layer norm gradients
debug (dict):
debug options
debug.deterministic_compute (bool, default is False)
@@ -359,11 +365,18 @@ def _check_is_callable(field, value, error):
'type': 'boolean',
'default': False
},
'deepspeed_zero_stage' : {
'type' : 'integer',
'min' : 0,
'max' : 1,
'default' : 0,
'deepspeed_zero_optimization' : {
'type' : 'dict',
'default_setter': lambda _: {},
'required': False,
'schema': {
'stage': {
'type': 'integer',
'min': 0,
'max': 1,
'default': 0
},
}
},
'enable_adasum': {
'type': 'boolean',
@@ -405,6 +418,10 @@ def _check_is_callable(field, value, error):
'grad_norm_clip': {
'type': 'boolean',
'default': True
},
'invertible_layer_norm_gradient' : {
'type': 'boolean',
'default': False
}
}
},
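To show the new schema in use, a configuration sketch (assuming the usual `onnxruntime.training` import path; the values are arbitrary examples): the ZeRO stage now lives under the `distributed.deepspeed_zero_optimization` group rather than the old flat `deepspeed_zero_stage` integer, and the new `utils.invertible_layer_norm_gradient` flag sits next to `grad_norm_clip`.

```python
from onnxruntime.training import orttrainer_options  # assumed import path

opts = orttrainer_options.ORTTrainerOptions({
    'distributed': {
        # New option group: 'stage' replaces the former top-level
        # 'deepspeed_zero_stage' integer (valid values are still 0 or 1).
        'deepspeed_zero_optimization': {
            'stage': 1
        }
    },
    'utils': {
        'grad_norm_clip': True,
        # New flag introduced alongside this change; defaults to False.
        'invertible_layer_norm_gradient': True
    }
})

# These accessors mirror how orttrainer.py reads the options in the diff above.
assert opts.distributed.deepspeed_zero_optimization.stage == 1
assert opts.utils.invertible_layer_norm_gradient is True
```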
@@ -26,11 +26,11 @@ def _load_pytorch_transformer_model(device, legacy_api=False):
# Loads external Pytorch TransformerModel into utils
pytorch_transformer_path = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer')
pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py')
pt_model = _utils.import_module_from_file(pt_model_path, 'pt_model')
pt_model = _utils.import_module_from_file(pt_model_path)
ort_utils_path = os.path.join(pytorch_transformer_path, 'ort_utils.py')
ort_utils = _utils.import_module_from_file(ort_utils_path, 'ort_utils')
ort_utils = _utils.import_module_from_file(ort_utils_path)
utils_path = os.path.join(pytorch_transformer_path, 'utils.py')
utils = _utils.import_module_from_file(utils_path, 'utils')
utils = _utils.import_module_from_file(utils_path)

# Modeling
model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
@@ -75,7 +75,9 @@ def testORTTrainerOptionsDefaultValues(test_input):
'world_size': 1,
'local_rank': 0,
'allreduce_post_accumulation': False,
'deepspeed_zero_stage': 0,
'deepspeed_zero_optimization': {
'stage' : 0,
},
'enable_adasum': False
},
'lr_scheduler': None,
@@ -85,7 +87,8 @@
},
'utils': {
'frozen_weights': [],
'grad_norm_clip': True
'grad_norm_clip': True,
'invertible_layer_norm_gradient': False,
},
'debug': {
'deterministic_compute': False
1 change: 1 addition & 0 deletions samples/python/pytorch_transformer/ort_utils.py
@@ -1,4 +1,5 @@
import torch

from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\
ModelDescription as Legacy_ModelDescription
