Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scatter max/min #4411

Merged
merged 17 commits into from
Sep 1, 2022
Merged

scatter max/min #4411

merged 17 commits into from
Sep 1, 2022

Conversation

philass
Copy link
Member

@philass philass commented Aug 4, 2022

This PR proposes adding max and min as supported reduction attributes for both ScatterElements and ScatterND.

Closes #4322

ScatterElements and ScatterND provide a mechanism for specifying the reduction function. Currently only add and mul are supported. It would be useful to extend these to support max and min, as these come up in practice and are supported by both Pytorch and TensorFlow. With out this addition ONNX doesn't have a good way to represent these operations.

@philass philass requested a review from a team as a code owner August 4, 2022 20:55
@philass philass marked this pull request as draft August 4, 2022 23:10
@philass philass force-pushed the plassen/scatter-max branch from 9ec2a0f to 06a0007 Compare August 9, 2022 04:20
@philass philass marked this pull request as ready for review August 10, 2022 19:57
@philass philass requested a review from a team as a code owner August 10, 2022 19:58
@gramalingam gramalingam added the topic: operator Issues related to ONNX operators label Aug 11, 2022
@@ -1259,6 +1259,20 @@ When `reduction` is set to "mul", `output` is calculated as follows:
for idx in np.ndindex(update_indices):
output[indices[idx]] *= updates[idx]

When `reduction` is set to "max", `output` is calculated as follows:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: may be we can combine all these 4 descriptions into one, as:

    output = np.copy(data)
    update_indices = indices.shape[:-1]
    for idx in np.ndindex(update_indices):
        output[indices[idx]] = reduction-op (output[indices[idx]], updates[idx])
where the reduction-op is +/*/max/min as specified.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats a great suggestion

@@ -1362,7 +1378,16 @@ When `reduction` is set to "mul", the update corresponding to the [i][j] entry i
output[indices[i][j]][j] *= updates[i][j] if axis = 0,
output[i][indices[i][j]] *= updates[i][j] if axis = 1,
```

When `reduction` is set to "max", the update corresponding to the [i][j] entry is performed as below:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above: it would help to consolidate all the cases into one.

```
output[indices[i][j]][j] = min(output[indices[i][j]][j], updates[i][j]) if axis = 0,
output[i][indices[i][j]] = min(output[i][indices[i][j]], updates[i][j]) if axis = 1,
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest adding the following, as the last part of the documentation:


(Opset 18 change): Adds max/min to the set of allowed reduction ops.

@p-wysocki
Copy link
Contributor

LGTM, after @gramalingam suggestions. These are all the reduction functions PyTorch uses, but I also found more in Apple's BNNS. We could try figuring out if any of them could be useful and added together with max and min. Just a thought, I believe it's okay as is.

@philass philass force-pushed the plassen/scatter-max branch 2 times, most recently from 02c56e8 to cf02b4c Compare August 18, 2022 23:46
@philass
Copy link
Member Author

philass commented Aug 19, 2022

@gramalingam, Thanks for the review and the great suggestions. I have updated the PR with your recommendations. Sorry for the delay!

@p-wysocki thats a really good point. There is this delicate balance between trying to not add every possible reduction op and adequately expressing Pytorch and Tensorflow ops in ONNX.

There are more Pytorch Scatter ops than just Add, Mul, Max and Min. We found that the Pytorch scatter package is something that is quite common. It consists of the following Scatters.

The reason that I only added Max and Min is that we have actually seen models that need ScatterMax and ScatterMin. Namely

Furthermore Sub, Std, Mean, Div can likely be represented by a combination of ScatterAdd, and ScatterMul, and other ops to get the required behavior.

I think a true general solution is to use some of the ideas of If and Loop where some subgraph can be given as an attribute. This subgraph would be the reduction operation. That is more involved, and I don't have the knowledge of ONNX to drive that at the moment.

In the short term I think Max and Min are a good start.

@p-wysocki
Copy link
Contributor

@philass Thank you for such a detailed answer, I agree that Max and Min are sufficient for now.

Copy link
Contributor

@gramalingam gramalingam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@gramalingam
Copy link
Contributor

Hi @philass : thanks again for this PR. It seems to require some merge conflict (hopefully just the minor one in operator_sets.h). Once that is addressed, we can merge the PR.

Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
Signed-off-by: Philip Lassen <[email protected]>
@philass philass force-pushed the plassen/scatter-max branch from f14d514 to 044e0c0 Compare September 1, 2022 08:14
@philass
Copy link
Member Author

philass commented Sep 1, 2022

@gramalingam thanks for pointing that out.

I've resolved the conflict, and I am ready to get this merged :)

@gramalingam gramalingam merged commit c68365e into onnx:main Sep 1, 2022
@philass philass changed the title Plassen/scatter max scatter max/min Sep 8, 2022
@PallottaEnrico
Copy link

PallottaEnrico commented Apr 19, 2023

Hi @philass , i'm using the latest pytorch version, however i'm still facing the error:
torch.onnx.errors.UnsupportedOperatorError: ONNX export failed on an operator with unrecognized namespace torch_scatter::scatter_max. If you are trying to export a custom operator, make sure you registered it with the right domain and version.

Shouldn't it be solved?

broune pushed a commit to broune/onnx that referenced this pull request May 6, 2023
* Add changelog

Signed-off-by: Philip Lassen <[email protected]>

* Add scatterND min and max test cases

Signed-off-by: Philip Lassen <[email protected]>

* Add scatterElements test case

Signed-off-by: Philip Lassen <[email protected]>

* Fix numpy function

Signed-off-by: Philip Lassen <[email protected]>

* improve name

Signed-off-by: Philip Lassen <[email protected]>

* Add tensor defs

Signed-off-by: Philip Lassen <[email protected]>

* Add old

Signed-off-by: Philip Lassen <[email protected]>

* Fix compile issue

Signed-off-by: Philip Lassen <[email protected]>

* Add generated files

Signed-off-by: Philip Lassen <[email protected]>

* add scatter changes to opertar_set.h

Signed-off-by: Philip Lassen <[email protected]>

* update onnx models

Signed-off-by: Philip Lassen <[email protected]>

* Add converter

Signed-off-by: Philip Lassen <[email protected]>

* clean up op wording

Signed-off-by: Philip Lassen <[email protected]>

* update generated docs

Signed-off-by: Philip Lassen <[email protected]>

* Format with black

Signed-off-by: Philip Lassen <[email protected]>

* Generate updated docs

Signed-off-by: Philip Lassen <[email protected]>

* Add test coverage doc changes

Signed-off-by: Philip Lassen <[email protected]>

Signed-off-by: Philip Lassen <[email protected]>
@LebronRemonJames
Copy link

Hi @philass , i'm using the latest pytorch version, however i'm still facing the error: torch.onnx.errors.UnsupportedOperatorError: ONNX export failed on an operator with unrecognized namespace torch_scatter::scatter_max. If you are trying to export a custom operator, make sure you registered it with the right domain and version.

Shouldn't it be solved?

i'm facing same problem too

@philass
Copy link
Member Author

philass commented May 31, 2023

@PallottaEnrico, @LebronRemonJames,

scatter_max isn't an aten torch op. Its defined in its own repository. As a result the exporters in the torch repo don't know that the op exists.

What you can do is augment the export path with a custom exporter for the scatter_max op. Note I haven't tested that the way it exports to ScatterElements is functionally correct. What I mean by this is that this will export to the ONNX ScatterElements, but scatter max is likely semantically different and you will have to adjust your export accordingly.

# POC for export custom torch ops defined outside pytorch
# Example of custom op -> ScatterMax : https://github.com/rusty1s/pytorch_scatter/blob/master/csrc/scatter.cpp#L199
#
# code below follows from instructions here https://pytorch.org/docs/stable/onnx.html#c-operators


import torch
import torch.nn as nn
from torch.onnx import symbolic_helper
from torch_scatter import scatter_max

class ScatterMax(nn.Module):

    def forward(self, src: torch.Tensor, index: torch.Tensor):
        # src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        # index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out, argmax = scatter_max(src, index, dim=-1, out=src)
        return out, argmax


@symbolic_helper.parse_args("v", "v", "i", "v", "i", "i")
def symbolic_scatter_max(g, src, index, dim=-1, out=None, dim_size=None, fill_value=None):
    return (index ,g.op("ScatterElements", out, index, src, axis_i=dim, reduction_s="max"))

torch.onnx.register_custom_op_symbolic("torch_scatter::scatter_max", symbolic_scatter_max, 18)

model = ScatterMax()

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
input_names = ["input_src", "input_idx"]
output_names = ["result_arr1", "result_arr2"]

torch.onnx.export(model, (src, index), "program.onnx", input_names=input_names, output_names=output_names, opset_version=18)

You can find the code here for reference https://github.com/philass/onnx-experiments/blob/master/export-custom.py.

@LebronRemonJames
Copy link

@PallottaEnrico, @LebronRemonJames,

scatter_max isn't an aten torch op. Its defined in its own repository. As a result the exporters in the torch repo don't know that the op exists.

What you can do is augment the export path with a custom exporter for the scatter_max op. Note I haven't tested that the way it exports to ScatterElements is functionally correct. What I mean by this is that this will export to the ONNX ScatterElements, but scatter max is likely semantically different and you will have to adjust your export accordingly.

# POC for export custom torch ops defined outside pytorch
# Example of custom op -> ScatterMax : https://github.com/rusty1s/pytorch_scatter/blob/master/csrc/scatter.cpp#L199
#
# code below follows from instructions here https://pytorch.org/docs/stable/onnx.html#c-operators


import torch
import torch.nn as nn
from torch.onnx import symbolic_helper
from torch_scatter import scatter_max

class ScatterMax(nn.Module):

    def forward(self, src: torch.Tensor, index: torch.Tensor):
        # src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        # index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out, argmax = scatter_max(src, index, dim=-1, out=src)
        return out, argmax


@symbolic_helper.parse_args("v", "v", "i", "v", "i", "i")
def symbolic_scatter_max(g, src, index, dim=-1, out=None, dim_size=None, fill_value=None):
    return (index ,g.op("ScatterElements", out, index, src, axis_i=dim, reduction_s="max"))

torch.onnx.register_custom_op_symbolic("torch_scatter::scatter_max", symbolic_scatter_max, 18)

model = ScatterMax()

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
input_names = ["input_src", "input_idx"]
output_names = ["result_arr1", "result_arr2"]

torch.onnx.export(model, (src, index), "program.onnx", input_names=input_names, output_names=output_names, opset_version=18)

You can find the code here for reference https://github.com/philass/onnx-experiments/blob/master/export-custom.py.

thanks

@paladin1410
Copy link

paladin1410 commented Oct 21, 2024

Hi @philass, thank you very much for your solution. I try to use your version for converting my model and it has this problem:
Traceback (most recent call last):
File "/app/LION/tools/huy_test2.py", line 91, in
torch.onnx.export(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 516, in export
_export(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1612, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1138, in _model_to_graph
graph = _optimize_graph(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 677, in _optimize_graph
graph = _C._jit_pass_onnx(graph, operator_export_type)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1956, in _run_symbolic_function
return symbolic_fn(graph_context, *inputs, **attrs)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 395, in wrapper
return fn(g, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset11.py", line 554, in cat
return opset9.cat(g, tensor_list, dim)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
return fn(g, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py", line 546, in cat
assert all(
AssertionError

Here is the code:

import torch
import torch.nn as nn
import torch_scatter
import torch.onnx
import numpy as np
from torch.onnx import symbolic_helper

class PFNLayerV2(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 use_norm=True,
                 last_layer=False):
        super().__init__()

        self.last_vfe = last_layer
        self.use_norm = use_norm
        if not self.last_vfe:
            out_channels = out_channels // 2

        if self.use_norm:
            self.linear = nn.Linear(in_channels, out_channels, bias=False)
            self.norm = nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01)
        else:
            self.linear = nn.Linear(in_channels, out_channels, bias=True)

        self.relu = nn.ReLU()

    def forward(self, inputs, unq_inv):

        x = self.linear(inputs)
        x = self.norm(x) if self.use_norm else x
        x = self.relu(x)
        x_max = torch_scatter.scatter_max(x, unq_inv, dim=0)[0]

        if self.last_vfe:
            return x_max
        else:
            x_concatenated = torch.cat([x, x_max[unq_inv, :]], dim=1)
            return x_concatenated

    
num_points = 100
in_features = 16
out_features = 32
inputs = torch.randn(num_points, in_features)
unq_inv = torch.randint(0, 10, (num_points,), dtype=torch.long)  # Assume 10 unique voxels

model = PFNLayerV2(in_features, out_features)



@symbolic_helper.parse_args("v", "v", "i", "v", "i", "i")
def symbolic_scatter_max(g, src, index, dim=-1, out=None, dim_size=None, fill_value=None):
    return (index ,g.op("ScatterElements", out, index, src, axis_i=dim, reduction_s="max"))


torch.onnx.register_custom_op_symbolic("torch_scatter::scatter_max", symbolic_scatter_max, 17)


torch.onnx.export(
    model,
    (inputs, unq_inv),
    "simple_pfn_layer.onnx",
    opset_version=17,  # Use at least opset version 16
    input_names=['inputs', 'unq_inv'],
    output_names=['output'],
)

What could be the problem? I am using pytorch 2.3.1 + cuda11.8, onnxruntime-gpu 1.18.1 and onnx 1.17.0 .Thank you

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
topic: operator Issues related to ONNX operators
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Extend ScatterElements op with min and max reduction
6 participants