-
Notifications
You must be signed in to change notification settings - Fork 435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support unique_with_counts #2195
Changes from all commits
2a06441
b30801c
476c57b
1e2e4b1
d26326b
8d53e3c
74b4301
f96cc4d
01d61af
863d06c
0144e23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2364,32 +2364,45 @@ def version_10(cls, ctx, node, **kwargs): | |
|
||
|
||
@tf_op("Unique", onnx_op="Unique") | ||
@tf_op("UniqueWithCounts", onnx_op="Unique") | ||
class Unique: | ||
int_cast = [TensorProto.BOOL, TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8, | ||
TensorProto.UINT16, TensorProto.UINT32, TensorProto.UINT64] | ||
dtype_map = {k: TensorProto.INT64 for k in int_cast} | ||
dtype_map[TensorProto.DOUBLE] = TensorProto.FLOAT | ||
|
||
@classmethod | ||
def version_11(cls, ctx, node, **kwargs): | ||
# opset 11 supports explicitly | ||
dtypes = node.output_dtypes | ||
node_name = node.name | ||
node_inputs = node.input | ||
node_outputs = node.output | ||
inp_dtype = ctx.get_dtype(node.input[0]) | ||
|
||
ctx.remove_node(node_name) | ||
if dtypes[0] in [TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8, TensorProto.UINT16]: | ||
inp_cast = ctx.make_node("Cast", [node_inputs[0]], attr={'to': TensorProto.INT64}).output[0] | ||
|
||
# due to ORT missing implementations we need to cast INT inputs to INT64 and FLOAT inputs to FLOAT32 | ||
if inp_dtype in cls.dtype_map: | ||
inp_cast = ctx.make_node("Cast", [node_inputs[0]], attr={'to': cls.dtype_map[inp_dtype]}).output[0] | ||
node_inputs[0] = inp_cast | ||
new_node = ctx.make_node("Unique", node_inputs, name=node_name, output_count=3, attr={'sorted': 0}) | ||
|
||
new_node = ctx.make_node("Unique", node_inputs, name=node_name, attr={'sorted': 0}, | ||
outputs=[utils.make_name("y"), utils.make_name("idx_first"), | ||
utils.make_name("idx"), utils.make_name("counts")]) | ||
ctx.replace_all_inputs(node_outputs[0], new_node.output[0]) | ||
ctx.replace_all_inputs(node_outputs[1], new_node.output[2]) | ||
if ctx.get_dtype(new_node.output[0]) != dtypes[0]: | ||
ctx.insert_new_node_on_output("Cast", new_node.output[0], name=utils.make_name(node.name) + "_cast", | ||
to=dtypes[0]) | ||
if len(node_outputs) > 1: | ||
# cast to int64 if needed | ||
if dtypes[1] != onnx_pb.TensorProto.INT64: | ||
cast_node = ctx.insert_new_node_on_output("Cast", new_node.output[2], | ||
name=utils.make_name(node.name) + "_cast", | ||
to=dtypes[1]) | ||
ctx.set_dtype(cast_node.output[0], dtypes[1]) | ||
ctx.copy_shape(new_node.output[2], cast_node.output[0]) | ||
if len(node_outputs) == 3: # we need counts too (UniqueWithCounts) | ||
ctx.replace_all_inputs(node_outputs[2], new_node.output[3]) | ||
if ctx.get_dtype(new_node.output[0]) != inp_dtype: | ||
ctx.insert_new_node_on_output("Cast", new_node.output[0], to=inp_dtype, | ||
name=utils.make_name(node.name) + "_cast") | ||
|
||
# cast idx and counts if needed | ||
out_dtype = node.get_attr_value('out_idx') | ||
if out_dtype != TensorProto.INT64: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since out_idx is optional, what will happen here if the output of Unique onnx op contains only 1 element? Line 2402 may still be true, but it is not what we expect, right? I agree accessing these attributes by a name is more meaningful, but probably it is more safe to access them by output lenght and index. Thoughts? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to TF API 'out_idx' should be either INT32 or INT64, and has a default value set to INT32, so I would expect that the attribute is always set to one of the two dtypes. And since onnx Unique op always provides "inverse_indices" and "counts" as INT64 they should be casted if 'out_idx' is different from INT64. |
||
for i in range(1, len(node_outputs)): | ||
cast_node = ctx.insert_new_node_on_output("Cast", new_node.output[i+1], to=out_dtype, | ||
name=utils.make_name(node.name) + "_cast") | ||
|
||
|
||
@tf_op(["Bincount", "DenseBincount"]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where do we cast 'counts'?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the the for loop (Line 2403). If the node has 2 outputs (TF "Unique" op) we cast onnx output 2 ("inverse_indices"). If the node has 3 outputs (TF "UniqueWithCounts" op) we cast onnx output 2 ("inverse_indices") and output 3 ("counts"). Either case, we skip onnx output 1 ("indices") since it is not present in the TF ops.