Skip to content

Commit

Permalink
Support unique_with_counts (#2195)
Browse files Browse the repository at this point in the history
* support UniqueWithCounts
* unique_with_counts test

---------

Signed-off-by: Salvetti, Francesco <[email protected]>
  • Loading branch information
f-salvetti authored Jul 26, 2023
1 parent 6dda2bb commit 6cdb7e3
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 15 deletions.
1 change: 1 addition & 0 deletions support_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@
| Transpose | 1 ~ 17 |
| TruncateDiv | 1 ~ 17 |
| Unique | 11 ~ 17 |
| UniqueWithCounts | 11 ~ 18 |
| Unpack | 1 ~ 17 |
| UnsortedSegmentMax | 11 ~ 17 |
| UnsortedSegmentMin | 11 ~ 17 |
Expand Down
33 changes: 33 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5047,6 +5047,39 @@ def func(x):
return y1, y2
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: x_val})

@check_opset_min_version(11, "Unique")
def test_unique_with_counts(self):
x_val = np.array([1, 2, 8, 1, 2, 2, 7, 7, 7, 1], dtype=np.float32)
def func(x):
x1_, x2_, x3_ = tf.unique_with_counts(x)
y1 = tf.identity(x1_, name=_TFOUTPUT)
y2 = tf.identity(x2_, name=_TFOUTPUT1)
y3 = tf.identity(x3_, name=_TFOUTPUT2)
return y1, y2, y3
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val})

@check_opset_min_version(11, "Unique")
def test_unique_with_counts_out_int64(self):
x_val = np.array([2, 3, 3, 6, 4, 1, 1], dtype=np.float32)
def func(x):
x1_, x2_, x3_ = tf.unique_with_counts(x, out_idx=tf.int64)
y1 = tf.identity(x1_, name=_TFOUTPUT)
y2 = tf.identity(x2_, name=_TFOUTPUT1)
y3 = tf.identity(x3_, name=_TFOUTPUT2)
return y1, y2, y3
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val})

@check_opset_min_version(11, "Unique")
def test_unique_with_counts_out_int32(self):
x_val = np.array([2, 3, 3, 6, 4, 1, 1], dtype=np.float32)
def func(x):
x1_, x2_, x3_ = tf.unique_with_counts(x, out_idx=tf.int32)
y1 = tf.identity(x1_, name=_TFOUTPUT)
y2 = tf.identity(x2_, name=_TFOUTPUT1)
y3 = tf.identity(x3_, name=_TFOUTPUT2)
return y1, y2, y3
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2], {_INPUT: x_val})

@check_opset_min_version(11, "Unique")
def test_bincount(self):
x_val = np.array([5, 2, 3, 1, 3, 2, 7, 5, 9, 10], dtype=np.int32)
Expand Down
43 changes: 28 additions & 15 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,32 +2364,45 @@ def version_10(cls, ctx, node, **kwargs):


@tf_op("Unique", onnx_op="Unique")
@tf_op("UniqueWithCounts", onnx_op="Unique")
class Unique:
int_cast = [TensorProto.BOOL, TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8,
TensorProto.UINT16, TensorProto.UINT32, TensorProto.UINT64]
dtype_map = {k: TensorProto.INT64 for k in int_cast}
dtype_map[TensorProto.DOUBLE] = TensorProto.FLOAT

@classmethod
def version_11(cls, ctx, node, **kwargs):
# opset 11 supports explicitly
dtypes = node.output_dtypes
node_name = node.name
node_inputs = node.input
node_outputs = node.output
inp_dtype = ctx.get_dtype(node.input[0])

ctx.remove_node(node_name)
if dtypes[0] in [TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8, TensorProto.UINT16]:
inp_cast = ctx.make_node("Cast", [node_inputs[0]], attr={'to': TensorProto.INT64}).output[0]

# due to ORT missing implementations we need to cast INT inputs to INT64 and FLOAT inputs to FLOAT32
if inp_dtype in cls.dtype_map:
inp_cast = ctx.make_node("Cast", [node_inputs[0]], attr={'to': cls.dtype_map[inp_dtype]}).output[0]
node_inputs[0] = inp_cast
new_node = ctx.make_node("Unique", node_inputs, name=node_name, output_count=3, attr={'sorted': 0})

new_node = ctx.make_node("Unique", node_inputs, name=node_name, attr={'sorted': 0},
outputs=[utils.make_name("y"), utils.make_name("idx_first"),
utils.make_name("idx"), utils.make_name("counts")])
ctx.replace_all_inputs(node_outputs[0], new_node.output[0])
ctx.replace_all_inputs(node_outputs[1], new_node.output[2])
if ctx.get_dtype(new_node.output[0]) != dtypes[0]:
ctx.insert_new_node_on_output("Cast", new_node.output[0], name=utils.make_name(node.name) + "_cast",
to=dtypes[0])
if len(node_outputs) > 1:
# cast to int64 if needed
if dtypes[1] != onnx_pb.TensorProto.INT64:
cast_node = ctx.insert_new_node_on_output("Cast", new_node.output[2],
name=utils.make_name(node.name) + "_cast",
to=dtypes[1])
ctx.set_dtype(cast_node.output[0], dtypes[1])
ctx.copy_shape(new_node.output[2], cast_node.output[0])
if len(node_outputs) == 3: # we need counts too (UniqueWithCounts)
ctx.replace_all_inputs(node_outputs[2], new_node.output[3])
if ctx.get_dtype(new_node.output[0]) != inp_dtype:
ctx.insert_new_node_on_output("Cast", new_node.output[0], to=inp_dtype,
name=utils.make_name(node.name) + "_cast")

# cast idx and counts if needed
out_dtype = node.get_attr_value('out_idx')
if out_dtype != TensorProto.INT64:
for i in range(1, len(node_outputs)):
cast_node = ctx.insert_new_node_on_output("Cast", new_node.output[i+1], to=out_dtype,
name=utils.make_name(node.name) + "_cast")


@tf_op(["Bincount", "DenseBincount"])
Expand Down

0 comments on commit 6cdb7e3

Please sign in to comment.