From a04d372ab4bc9b0e5e7f01bcb93ae3e8ef81db02 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Thu, 13 Aug 2020 12:51:15 -0700 Subject: [PATCH] Fix Dynamic Axes feature and add unit test --- .../orttraining/python/training/orttrainer.py | 8 +++-- .../orttraining_test_orttrainer_frontend.py | 36 +++++++++++++++++-- .../python/pytorch_transformer/ort_utils.py | 23 +++++++++--- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index d2cf087fb6b24..e2e9d74653ce0 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -681,12 +681,16 @@ def _resolve_symbolic_dimensions(self, inputs, inputs_desc, outputs_desc): for input, i_desc in zip(inputs, inputs_desc): for i_idx, i_axis in enumerate(i_desc.shape): if isinstance(i_axis, str): - resolved_dims[i_axis] = input.size()[i_idx] + if i_axis not in resolved_dims: + resolved_dims[i_axis] = input.size()[i_idx] + else: + assert resolved_dims[i_axis] == input.size()[i_idx],\ + f"Mismatch in dynamic shape {i_axis}" for o_desc in outputs: for idx_o, o_axis in enumerate(o_desc.shape): if isinstance(o_axis, str): - o_desc.shape_[idx_o] = resolved_dims[o_axis] + o_desc.shape[idx_o] = resolved_dims[o_axis] unknown_dim = [o_desc.name for dim in o_desc.shape for o_desc in outputs if isinstance(dim, str)] if unknown_dim: diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index c4c537538538a..ebdf4676d43f9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -22,7 +22,7 @@ ############################################################################### -def _load_pytorch_transformer_model(device, legacy_api=False): +def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False): # Loads external Pytorch TransformerModel into utils pytorch_transformer_path = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer') pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py') @@ -36,9 +36,16 @@ def _load_pytorch_transformer_model(device, legacy_api=False): model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device) my_loss = ort_utils.my_loss if legacy_api: - model_desc = ort_utils.legacy_transformer_model_description() + if dynamic_axes: + model_desc = ort_utils.legacy_transformer_model_description_dynamic_axes() + else: + model_desc = ort_utils.legacy_transformer_model_description() else: - model_desc = ort_utils.transformer_model_description() + if dynamic_axes: + model_desc = ort_utils.transformer_model_description_dynamic_axes() + else: + model_desc = ort_utils.transformer_model_description() + # Preparing data train_data, val_data, test_data = utils.prepare_data(device, 20, 20) @@ -744,6 +751,29 @@ def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=1e-6) +@pytest.mark.parametrize("dynamic_axes", [ + (True), + (False), +]) +def testORTTrainerDynamicShape(dynamic_axes): + # Common setup + device = 'cuda' + + # Setup ORTTrainer + options = orttrainer.ORTTrainerOptions({}) + model, model_desc, my_loss, batcher_fn,\ + train_data, val_data, _ = _load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes) + optim_config = optim.LambConfig(lr=0.001) + trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) + + # Training loop + total_steps = 10 + for i in range(total_steps): + data, targets = batcher_fn(train_data, i) + _, _ = trainer.train_step(data, targets) + + assert trainer._onnx_model is not None + ############################################################################### # Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ ############################################################################### diff --git a/samples/python/pytorch_transformer/ort_utils.py b/samples/python/pytorch_transformer/ort_utils.py index 7c51b7f0d3b1e..c97f9b4396b95 100644 --- a/samples/python/pytorch_transformer/ort_utils.py +++ b/samples/python/pytorch_transformer/ort_utils.py @@ -10,19 +10,34 @@ def my_loss(x, target): def transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - # TODO: Update to Dynamic Axis when backend is fixed model_desc = {'inputs': [('input1', [bptt, batch_size]), - ('label', [bptt, batch_size, ntokens],)], + ('label', [bptt * batch_size])], 'outputs': [('loss', [], True), ('predictions', [bptt, batch_size, ntokens])]} return model_desc +def transformer_model_description_dynamic_axes(ntokens=28785): + model_desc = {'inputs': [('input1', ['bptt', 'batch_size']), + ('label', ['bptt_x_batch_size'])], + 'outputs': [('loss', [], True), + ('predictions', ['bptt', 'batch_size', ntokens])]} + return model_desc + + def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - # TODO: Update to Dynamic Axis when backend is fixed input_desc = Legacy_IODescription('input1', [bptt, batch_size]) - label_desc = Legacy_IODescription('label', [bptt, batch_size, ntokens]) + label_desc = Legacy_IODescription('label', [bptt * batch_size]) loss_desc = Legacy_IODescription('loss', []) predictions_desc = Legacy_IODescription('predictions', [bptt, batch_size, ntokens]) return Legacy_ModelDescription([input_desc, label_desc],[loss_desc, predictions_desc]),\ Legacy_IODescription('__learning_rate', [1]) + + +def legacy_transformer_model_description_dynamic_axes(ntokens=28785): + input_desc = Legacy_IODescription('input1', ['bptt', 'batch_size']) + label_desc = Legacy_IODescription('label', ['bptt_x_batch_size']) + loss_desc = Legacy_IODescription('loss', []) + predictions_desc = Legacy_IODescription('predictions', ['bptt', 'batch_size', ntokens]) + return Legacy_ModelDescription([input_desc, label_desc],[loss_desc, predictions_desc]),\ + Legacy_IODescription('__learning_rate', [1])