scatter max/min #4411
Conversation
Force-pushed from 9ec2a0f to 06a0007
onnx/defs/tensor/defs.cc (Outdated)
@@ -1259,6 +1259,20 @@ When `reduction` is set to "mul", `output` is calculated as follows:
    for idx in np.ndindex(update_indices):
        output[indices[idx]] *= updates[idx]

When `reduction` is set to "max", `output` is calculated as follows:
nit: maybe we can combine all these 4 descriptions into one, as:

```
output = np.copy(data)
update_indices = indices.shape[:-1]
for idx in np.ndindex(update_indices):
    output[indices[idx]] = reduction-op(output[indices[idx]], updates[idx])
```

where the reduction-op is +/*/max/min as specified.
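For what it's worth, a small runnable NumPy sketch of that consolidated description, with `reduction_op` standing in for the +/*/max/min choice; the sample tensors below are made up for illustration and are not from the spec text:

```python
import numpy as np

def scatter_nd_reference(data, indices, updates, reduction_op):
    # Consolidated ScatterND update loop; reduction_op stands in for +/*/max/min.
    output = np.copy(data)
    update_indices = indices.shape[:-1]
    for idx in np.ndindex(update_indices):
        target = tuple(indices[idx])  # indices[idx] selects one position in output
        output[target] = reduction_op(output[target], updates[idx])
    return output

data = np.array([1.0, 2.0, 3.0, 4.0])
indices = np.array([[1], [3], [1]])   # index 1 appears twice, so the reduction matters
updates = np.array([10.0, 0.5, 7.0])

print(scatter_nd_reference(data, indices, updates, np.maximum))  # -> [ 1. 10.  3.  4.]
```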
That's a great suggestion.
onnx/defs/tensor/defs.cc (Outdated)
@@ -1362,7 +1378,16 @@ When `reduction` is set to "mul", the update corresponding to the [i][j] entry i
```
output[indices[i][j]][j] *= updates[i][j] if axis = 0,
output[i][indices[i][j]] *= updates[i][j] if axis = 1,
```

When `reduction` is set to "max", the update corresponding to the [i][j] entry is performed as below:
As above: it would help to consolidate all the cases into one.
```
output[indices[i][j]][j] = min(output[indices[i][j]][j], updates[i][j]) if axis = 0,
output[i][indices[i][j]] = min(output[i][indices[i][j]], updates[i][j]) if axis = 1,
```
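As a quick sanity check of these per-entry formulas, here is a minimal NumPy loop for the 2-D case; the sample tensors are made up for illustration and are not part of the spec text:

```python
import numpy as np

def scatter_elements_2d(data, indices, updates, axis, reduce_fn):
    # Per-entry update loop for 2-D ScatterElements with reduce_fn in {max, min}.
    output = np.copy(data)
    for i in range(indices.shape[0]):
        for j in range(indices.shape[1]):
            if axis == 0:
                output[indices[i][j]][j] = reduce_fn(output[indices[i][j]][j], updates[i][j])
            else:  # axis == 1
                output[i][indices[i][j]] = reduce_fn(output[i][indices[i][j]], updates[i][j])
    return output

data = np.zeros((3, 3))
indices = np.array([[1, 0, 2], [0, 2, 1]])
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]])
print(scatter_elements_2d(data, indices, updates, axis=0, reduce_fn=max))
```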
I suggest adding the following, as the last part of the documentation:
(Opset 18 change): Adds max/min to the set of allowed reduction ops.
LGTM, after @gramalingam's suggestions. These are all the reduction functions PyTorch uses, but I also found more in Apple's BNNS. We could try figuring out whether any of them could be useful and added together with …
Force-pushed from 02c56e8 to cf02b4c
@gramalingam, thanks for the review and the great suggestions. I have updated the PR with your recommendations. Sorry for the delay!

@p-wysocki, that's a really good point. There is a delicate balance between trying to not add every possible … There are more PyTorch Scatter ops than just … The reason that I only added … Furthermore, I think a true general solution is to use some of the ideas of … In the short term I think …
@philass Thank you for such a detailed answer, I agree that …
LGTM, thanks!
Hi @philass: thanks again for this PR. It seems to have a merge conflict (hopefully just the minor one in operator_sets.h). Once that is addressed, we can merge the PR.
Force-pushed from f14d514 to 044e0c0
@gramalingam thanks for pointing that out. I've resolved the conflict, and I am ready to get this merged :)
Hi @philass, I'm using the latest PyTorch version, however I'm still facing the error: … Shouldn't it be solved?
* Add changelog
* Add scatterND min and max test cases
* Add scatterElements test case
* Fix numpy function
* improve name
* Add tensor defs
* Add old
* Fix compile issue
* Add generated files
* add scatter changes to opertar_set.h
* update onnx models
* Add converter
* clean up op wording
* update generated docs
* Format with black
* Generate updated docs
* Add test coverage doc changes

Signed-off-by: Philip Lassen <[email protected]>
I'm facing the same problem too.
@PallottaEnrico, @LebronRemonJames,

What you can do is augment the export path with a custom exporter for the scatter_max op. Note I haven't tested that the way it exports to ScatterElements is functionally correct. What I mean by this is that this will export to the ONNX ScatterElements op.

```python
# POC for exporting custom torch ops defined outside pytorch
# Example of custom op -> ScatterMax : https://github.com/rusty1s/pytorch_scatter/blob/master/csrc/scatter.cpp#L199
#
# Code below follows the instructions here: https://pytorch.org/docs/stable/onnx.html#c-operators
import torch
import torch.nn as nn
from torch.onnx import symbolic_helper
from torch_scatter import scatter_max

class ScatterMax(nn.Module):
    def forward(self, src: torch.Tensor, index: torch.Tensor):
        # src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
        # index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
        out, argmax = scatter_max(src, index, dim=-1, out=src)
        return out, argmax

@symbolic_helper.parse_args("v", "v", "i", "v", "i", "i")
def symbolic_scatter_max(g, src, index, dim=-1, out=None, dim_size=None, fill_value=None):
    return (index, g.op("ScatterElements", out, index, src, axis_i=dim, reduction_s="max"))

torch.onnx.register_custom_op_symbolic("torch_scatter::scatter_max", symbolic_scatter_max, 18)

model = ScatterMax()
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
input_names = ["input_src", "input_idx"]
output_names = ["result_arr1", "result_arr2"]
torch.onnx.export(model, (src, index), "program.onnx", input_names=input_names, output_names=output_names, opset_version=18)
```

You can find the code here for reference: https://github.com/philass/onnx-experiments/blob/master/export-custom.py.
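If it helps, one way to sanity-check the exported graph against the eager torch_scatter result might look like the sketch below. It is untested, assumes program.onnx was produced by the snippet above, and assumes the installed onnxruntime supports opset 18:

```python
# Untested sketch: compare the exported ONNX graph with the eager torch_scatter result.
import onnxruntime as ort
import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

# Eager reference, mirroring the module above (out=src makes the input the initial value).
eager_out, eager_argmax = scatter_max(src.clone(), index, dim=-1, out=src.clone())

# Run the exported model; assumes the installed onnxruntime supports opset 18.
sess = ort.InferenceSession("program.onnx")
onnx_outputs = sess.run(None, {"input_src": src.numpy(), "input_idx": index.numpy()})

# Print both sides for manual comparison, since the custom symbolic above has not
# been verified to map (out, argmax) onto the ONNX outputs in the expected order.
print("eager out:\n", eager_out)
print("eager argmax:\n", eager_argmax)
for name, value in zip(["result_arr1", "result_arr2"], onnx_outputs):
    print(name, ":\n", value)
```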
Thanks.
Hi @philass, thank you very much for your solution. I tried to use your version for converting my model and it has this problem: … Here is the code: …

What could be the problem? I am using PyTorch 2.3.1 + CUDA 11.8, onnxruntime-gpu 1.18.1, and onnx 1.17.0. Thank you.
This PR proposes adding max and min as supported reduction attributes for both ScatterElements and ScatterND.
Closes #4322
ScatterElements and ScatterND provide a mechanism for specifying the reduction function. Currently only add and mul are supported. It would be useful to extend these to support max and min, as these come up in practice and are supported by both PyTorch and TensorFlow. Without this addition, ONNX doesn't have a good way to represent these operations.
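For illustration, here is a minimal sketch of what the change enables at the graph level: building a ScatterElements node that uses the new reduction="max" attribute value. The tensor names and shapes are arbitrary choices for this example:

```python
import onnx
from onnx import TensorProto, helper

# A ScatterElements node that uses the newly allowed reduction="max" attribute value.
node = helper.make_node(
    "ScatterElements",
    inputs=["data", "indices", "updates"],
    outputs=["output"],
    axis=0,
    reduction="max",
)
graph = helper.make_graph(
    [node],
    "scatter_elements_max_example",
    inputs=[
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 3]),
        helper.make_tensor_value_info("indices", TensorProto.INT64, [2, 3]),
        helper.make_tensor_value_info("updates", TensorProto.FLOAT, [2, 3]),
    ],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)  # requires an onnx build that includes opset 18
```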