Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow any input in to_onnx and to_torchscript #4378

Merged
merged 41 commits into from
Dec 12, 2020
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
d8d0608
branch merge
rohitgr7 Oct 26, 2020
0037f28
sample
rohitgr7 Oct 26, 2020
2da23cd
update with valid input tensors
rohitgr7 Oct 31, 2020
b45792c
pep
rohitgr7 Oct 31, 2020
e98cc33
pathlib
rohitgr7 Oct 31, 2020
3096054
Updated with BoringModel and added more input types
rohitgr7 Nov 2, 2020
62d5b84
try fix
rohitgr7 Nov 2, 2020
d05f145
pep
rohitgr7 Nov 2, 2020
9aab119
skip test with torch < 1.4
rohitgr7 Nov 3, 2020
72e1c17
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Nov 3, 2020
aee7bfa
Merge branch 'master' into bugfix/onnx_batch_transfer
tchaton Nov 3, 2020
fb47563
fix test
rohitgr7 Nov 2, 2020
7378f26
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Nov 6, 2020
e4e897f
Merge branch 'master' into bugfix/onnx_batch_transfer
Nov 8, 2020
5ab5e4c
Merge branch 'master' into bugfix/onnx_batch_transfer
tchaton Nov 10, 2020
16488dd
Merge branch 'master' into bugfix/onnx_batch_transfer
s-rog Nov 16, 2020
9d799f1
Apply suggestions from code review
Borda Nov 16, 2020
972834d
Merge branch 'master' into bugfix/onnx_batch_transfer
tchaton Nov 16, 2020
5770a75
update tests
rohitgr7 Nov 16, 2020
dd5f190
Merge branch 'master' into bugfix/onnx_batch_transfer
tchaton Nov 17, 2020
d7b814d
Merge branch 'master' into bugfix/onnx_batch_transfer
tchaton Nov 18, 2020
48ed87d
Allow any input in to_onnx and to_torchscript
rohitgr7 Nov 20, 2020
14771fc
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Nov 20, 2020
689b630
Update tests/models/test_torchscript.py
rohitgr7 Nov 22, 2020
5ee20fc
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Nov 22, 2020
e85dc07
no_grad
rohitgr7 Nov 22, 2020
898add2
Merge branch 'master' into bugfix/onnx_batch_transfer
s-rog Nov 23, 2020
080d688
Merge branch 'master' into bugfix/onnx_batch_transfer
tchaton Nov 23, 2020
77ddc48
Merge branch 'master' into bugfix/onnx_batch_transfer
tchaton Nov 25, 2020
4ec81a9
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Nov 30, 2020
7adf810
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Dec 1, 2020
3c9ed6d
try fix random failing test
rohitgr7 Dec 1, 2020
e7b1a0a
rm example_input_array
rohitgr7 Dec 1, 2020
dbdc7c3
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Dec 1, 2020
341a8cd
rm example_input_array
rohitgr7 Dec 1, 2020
ace662d
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Dec 1, 2020
6ac100d
Merge branch 'master' into bugfix/onnx_batch_transfer
rohitgr7 Dec 2, 2020
20af7d5
Merge branch 'master' into bugfix/onnx_batch_transfer
edenlightning Dec 8, 2020
ddcd731
Merge branch 'master' into bugfix/onnx_batch_transfer
tchaton Dec 8, 2020
33553ca
Merge branch 'master' into bugfix/onnx_batch_transfer
edenlightning Dec 8, 2020
ea1be78
Merge branch 'master' into bugfix/onnx_batch_transfer
s-rog Dec 12, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Various hooks to be used in the Lightning code."""

from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import torch
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
Expand Down Expand Up @@ -501,7 +501,7 @@ def val_dataloader(self):
will have an argument ``dataloader_idx`` which matches the order here.
"""

def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
wrapped in a custom data structure.
Expand Down Expand Up @@ -549,6 +549,7 @@ def transfer_batch_to_device(self, batch, device)
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
- :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
"""
device = device or self.device
return move_data_to_device(batch, device)


Expand Down
87 changes: 51 additions & 36 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tempfile
from abc import ABC
from argparse import Namespace
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
Expand Down Expand Up @@ -1512,12 +1513,19 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
else:
self._hparams = hp

def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
"""Saves the model in ONNX format
@torch.no_grad()
def to_onnx(
self,
file_path: Union[str, Path],
input_sample: Optional[Any] = None,
**kwargs,
):
"""
Saves the model in ONNX format

Args:
file_path: The path of the file the model should be saved to.
input_sample: A sample of an input tensor for tracing.
file_path: The path of the file the onnx model should be saved to.
input_sample: An input for tracing. Default: None (Use self.example_input_array)
**kwargs: Will be passed to torch.onnx.export function.

Example:
Expand All @@ -1536,31 +1544,32 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
... os.path.isfile(tmpfile.name)
True
"""
mode = self.training

if isinstance(input_sample, Tensor):
input_data = input_sample
elif self.example_input_array is not None:
input_data = self.example_input_array
else:
if input_sample is not None:
if input_sample is None:
if self.example_input_array is None:
raise ValueError(
f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
"Could not export to ONNX since neither `input_sample` nor"
" `model.example_input_array` attribute is set."
)
raise ValueError(
"Could not export to ONNX since neither `input_sample` nor"
" `model.example_input_array` attribute is set."
)
input_data = input_data.to(self.device)
input_sample = self.example_input_array

input_sample = self.transfer_batch_to_device(input_sample)
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

if "example_outputs" not in kwargs:
self.eval()
with torch.no_grad():
kwargs["example_outputs"] = self(input_data)
kwargs["example_outputs"] = self(input_sample)

torch.onnx.export(self, input_data, file_path, **kwargs)
torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)

@torch.no_grad()
def to_torchscript(
self, file_path: Optional[str] = None, method: Optional[str] = 'script',
example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
self,
file_path: Optional[Union[str, Path]] = None,
method: Optional[str] = 'script',
example_inputs: Optional[Any] = None,
**kwargs,
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
"""
By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
Expand All @@ -1572,7 +1581,7 @@ def to_torchscript(
Args:
file_path: Path where to save the torchscript. Default: None (no file saved).
method: Whether to use TorchScript's script or trace method. Default: 'script'
example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
example_inputs: An input to be used to do tracing when method is set to 'trace'.
Default: None (Use self.example_input_array)
**kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
:func:`torch.jit.trace` function.
Expand Down Expand Up @@ -1606,21 +1615,27 @@ def to_torchscript(
This LightningModule as a torchscript, regardless of whether file_path is
defined or not.
"""

mode = self.training
with torch.no_grad():
if method == 'script':
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == 'trace':
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
example_inputs = self.example_input_array
# automatically send example inputs to the right device and use trace
example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
f"{method}")

if method == 'script':
torchscript_module = torch.jit.script(self.eval(), **kwargs)
elif method == 'trace':
# if no example inputs are provided, try to see if model has example_input_array set
if example_inputs is None:
if self.example_input_array is None:
raise ValueError(
'Choosing method=`trace` requires either `example_inputs`'
' or `model.example_input_array` to be defined'
)
example_inputs = self.example_input_array

rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
# automatically send example inputs to the right device and use trace
example_inputs = self.transfer_batch_to_device(example_inputs)
torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
else:
raise ValueError("The 'method' parameter only supports 'script' or 'trace',"
f" but value given was: {method}")

self.train(mode)

if file_path is not None:
Expand Down
49 changes: 22 additions & 27 deletions tests/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,44 @@
import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
from tests.base import BoringModel, EvalModelTemplate


def test_model_saves_with_input_sample(tmpdir):
"""Test that ONNX model saves with input sample and size is greater than 3 MB"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)

file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_model_saves_on_gpu(tmpdir):
"""Test that model saves on gpu"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(gpus=1, max_epochs=1)
trainer.fit(model)

file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2


def test_model_saves_with_example_output(tmpdir):
"""Test that ONNX model saves when provided with example output"""
model = EvalModelTemplate()
model = BoringModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)

file_path = os.path.join(tmpdir, "model.onnx")
input_sample = torch.randn((1, 28 * 28))
input_sample = torch.randn((1, 32))
model.eval()
example_outputs = model.forward(input_sample)
model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
Expand All @@ -67,11 +67,13 @@ def test_model_saves_with_example_output(tmpdir):

def test_model_saves_with_example_input_array(tmpdir):
"""Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
model = EvalModelTemplate()
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path)
assert os.path.exists(file_path) is True
assert os.path.getsize(file_path) > 3e+06
assert os.path.getsize(file_path) > 4e2


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
Expand Down Expand Up @@ -100,38 +102,31 @@ def test_model_saves_on_multi_gpu(tmpdir):

def test_verbose_param(tmpdir, capsys):
"""Test that output is present when verbose parameter is set"""
model = EvalModelTemplate()
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path, verbose=True)
captured = capsys.readouterr()
assert "graph(%" in captured.out


def test_error_if_no_input(tmpdir):
"""Test that an exception is thrown when there is no input tensor"""
model = EvalModelTemplate()
"""Test that an error is thrown when there is no input tensor"""
model = BoringModel()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onnx")
with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor'
r' `model.example_input_array` attribute is set.'):
model.to_onnx(file_path)


def test_error_if_input_sample_is_not_tensor(tmpdir):
"""Test that an exception is thrown when there is no input tensor"""
model = EvalModelTemplate()
model.example_input_array = None
file_path = os.path.join(tmpdir, "model.onnx")
input_sample = np.random.randn(1, 28 * 28)
with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is '
f'`Tensor`'):
model.to_onnx(file_path, input_sample)


def test_if_inference_output_is_valid(tmpdir):
"""Test that the output inferred from ONNX model is same as from PyTorch"""
model = EvalModelTemplate()
trainer = Trainer(max_epochs=5)
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

trainer = Trainer(max_epochs=2)
trainer.fit(model)

model.eval()
Expand Down
Loading