Fix Dynamic Axes feature and add unit test #4795

Merged
8 changes: 6 additions & 2 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -681,12 +681,16 @@ def _resolve_symbolic_dimensions(self, inputs, inputs_desc, outputs_desc):
 for input, i_desc in zip(inputs, inputs_desc):
     for i_idx, i_axis in enumerate(i_desc.shape):
         if isinstance(i_axis, str):
-            resolved_dims[i_axis] = input.size()[i_idx]
+            if i_axis not in resolved_dims:
+                resolved_dims[i_axis] = input.size()[i_idx]
+            else:
+                assert resolved_dims[i_axis] == input.size()[i_idx],\
+                    f"Mismatch in dynamic shape {i_axis}"

 for o_desc in outputs:
     for idx_o, o_axis in enumerate(o_desc.shape):
         if isinstance(o_axis, str):
-            o_desc.shape_[idx_o] = resolved_dims[o_axis]
+            o_desc.shape[idx_o] = resolved_dims[o_axis]

 unknown_dim = [o_desc.name for dim in o_desc.shape for o_desc in outputs if isinstance(dim, str)]
 if unknown_dim:
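The fix replaces overwrite-on-every-occurrence with record-once-then-verify: the first input tensor carrying a symbolic axis pins its concrete size, and every later occurrence of the same name must match or the assert fires. A minimal standalone sketch of that logic, where SimpleDesc is a hypothetical stand-in for the trainer's real input/output description objects:

    from collections import namedtuple
    import torch

    # Hypothetical stand-in for ORTTrainer's input/output descriptions.
    SimpleDesc = namedtuple('SimpleDesc', ['name', 'shape'])

    def resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc):
        resolved_dims = {}
        for tensor, i_desc in zip(inputs, inputs_desc):
            for i_idx, i_axis in enumerate(i_desc.shape):
                if isinstance(i_axis, str):
                    if i_axis not in resolved_dims:
                        # First occurrence: record the concrete size.
                        resolved_dims[i_axis] = tensor.size()[i_idx]
                    else:
                        # Later occurrences must agree with the recorded size.
                        assert resolved_dims[i_axis] == tensor.size()[i_idx], \
                            f"Mismatch in dynamic shape {i_axis}"
        # Substitute the resolved sizes into symbolic output axes.
        return [SimpleDesc(o.name, [resolved_dims[d] if isinstance(d, str) else d
                                    for d in o.shape])
                for o in outputs_desc]

    inputs = [torch.zeros(35, 20), torch.zeros(700)]
    inputs_desc = [SimpleDesc('input1', ['bptt', 'batch_size']),
                   SimpleDesc('label', ['bptt_x_batch_size'])]
    outputs_desc = [SimpleDesc('predictions', ['bptt', 'batch_size', 28785])]
    print(resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc))
    # [SimpleDesc(name='predictions', shape=[35, 20, 28785])]

Before the fix, two tensors sharing an axis name but disagreeing on its size would silently keep whichever size was seen last; now the mismatch is reported by name.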
@@ -22,7 +22,7 @@
 ###############################################################################


-def _load_pytorch_transformer_model(device, legacy_api=False):
+def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False):
     # Loads external Pytorch TransformerModel into utils
     pytorch_transformer_path = os.path.join('..', '..', '..', 'samples', 'python', 'pytorch_transformer')
     pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py')
@@ -36,9 +36,16 @@ def _load_pytorch_transformer_model(device, legacy_api=False):
     model = pt_model.TransformerModel(28785, 200, 2, 200, 2, 0.2).to(device)
     my_loss = ort_utils.my_loss
     if legacy_api:
-        model_desc = ort_utils.legacy_transformer_model_description()
+        if dynamic_axes:
+            model_desc = ort_utils.legacy_transformer_model_description_dynamic_axes()
+        else:
+            model_desc = ort_utils.legacy_transformer_model_description()
     else:
-        model_desc = ort_utils.transformer_model_description()
+        if dynamic_axes:
+            model_desc = ort_utils.transformer_model_description_dynamic_axes()
+        else:
+            model_desc = ort_utils.transformer_model_description()
+

     # Preparing data
     train_data, val_data, test_data = utils.prepare_data(device, 20, 20)
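With the new dynamic_axes flag, the single loader now covers all four combinations of API flavor and shape style. A hedged usage sketch inside the test module (the seven-element return tuple matches the unpacking in the test below):

    # dynamic_axes=True selects the *_dynamic_axes descriptions added in
    # ort_utils.py below; legacy_api=True selects the Legacy_* variants.
    model, model_desc, my_loss, batcher_fn, train_data, val_data, test_data = \
        _load_pytorch_transformer_model('cuda', dynamic_axes=True)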
@@ -744,6 +751,29 @@ def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps
     _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=1e-6)


+@pytest.mark.parametrize("dynamic_axes", [
+    (True),
+    (False),
+])
+def testORTTrainerDynamicShape(dynamic_axes):
+    # Common setup
+    device = 'cuda'
+
+    # Setup ORTTrainer
+    options = orttrainer.ORTTrainerOptions({})
+    model, model_desc, my_loss, batcher_fn,\
+        train_data, val_data, _ = _load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes)
+    optim_config = optim.LambConfig(lr=0.001)
+    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options)
+
+    # Training loop
+    total_steps = 10
+    for i in range(total_steps):
+        data, targets = batcher_fn(train_data, i)
+        _, _ = trainer.train_step(data, targets)
+
+    assert trainer._onnx_model is not None
+
 ###############################################################################
 # Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############
 ###############################################################################
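To exercise only the new test, it can be selected by name; a hedged invocation, since the test module's path is not captured in this excerpt:

    python -m pytest -k testORTTrainerDynamicShape -v

The closing assert relies on the experimental ORTTrainer exporting its ONNX model lazily on the first train_step, so a non-None _onnx_model confirms that export and all ten steps succeeded under both shape descriptions.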
23 changes: 19 additions & 4 deletions samples/python/pytorch_transformer/ort_utils.py
@@ -10,19 +10,34 @@ def my_loss(x, target):


 def transformer_model_description(bptt=35, batch_size=20, ntokens=28785):
-    # TODO: Update to Dynamic Axis when backend is fixed
     model_desc = {'inputs': [('input1', [bptt, batch_size]),
-                             ('label', [bptt, batch_size, ntokens],)],
+                             ('label', [bptt * batch_size])],
                   'outputs': [('loss', [], True),
                               ('predictions', [bptt, batch_size, ntokens])]}
     return model_desc


+def transformer_model_description_dynamic_axes(ntokens=28785):
+    model_desc = {'inputs': [('input1', ['bptt', 'batch_size']),
+                             ('label', ['bptt_x_batch_size'])],
+                  'outputs': [('loss', [], True),
+                              ('predictions', ['bptt', 'batch_size', ntokens])]}
+    return model_desc


 def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785):
-    # TODO: Update to Dynamic Axis when backend is fixed
     input_desc = Legacy_IODescription('input1', [bptt, batch_size])
-    label_desc = Legacy_IODescription('label', [bptt, batch_size, ntokens])
+    label_desc = Legacy_IODescription('label', [bptt * batch_size])
     loss_desc = Legacy_IODescription('loss', [])
     predictions_desc = Legacy_IODescription('predictions', [bptt, batch_size, ntokens])
     return Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]),\
            Legacy_IODescription('__learning_rate', [1])


+def legacy_transformer_model_description_dynamic_axes(ntokens=28785):
+    input_desc = Legacy_IODescription('input1', ['bptt', 'batch_size'])
+    label_desc = Legacy_IODescription('label', ['bptt_x_batch_size'])
+    loss_desc = Legacy_IODescription('loss', [])
+    predictions_desc = Legacy_IODescription('predictions', ['bptt', 'batch_size', ntokens])
+    return Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]),\
+           Legacy_IODescription('__learning_rate', [1])
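Two details above are worth calling out. First, the static descriptions fix the label shape from [bptt, batch_size, ntokens] to the flattened [bptt * batch_size], matching the flattened class-index targets that my_loss consumes. Second, in the dynamic variants a string inside a shape is a named symbolic axis: every tensor using the same name must resolve to the same runtime size, enforced by the assert added in orttrainer.py above. A small sketch of how the names resolve, assuming the defaults bptt=35 and batch_size=20 carried over from the static description:

    from ort_utils import transformer_model_description_dynamic_axes

    # Substituting concrete sizes into the symbolic axes reproduces the
    # static description's input shapes.
    dims = {'bptt': 35, 'batch_size': 20, 'bptt_x_batch_size': 35 * 20}
    for name, shape in transformer_model_description_dynamic_axes()['inputs']:
        print(name, [dims.get(axis, axis) for axis in shape])
    # input1 [35, 20]
    # label [700]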