Skip to content

Commit

Permalink
Merge pull request #192 from Idein/allow-any-split
Browse files Browse the repository at this point in the history
  • Loading branch information
Guriido authored Oct 2, 2023
2 parents e89e820 + 10b6753 commit 8a670ce
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 52 deletions.
2 changes: 0 additions & 2 deletions nnoir-onnx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ docker run --rm -it -u $UID:$GID -v $(pwd):/work idein/nnoir-tools:20230720 onnx
* [Sin](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sin)
* [Softmax](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax)
* [Split](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Split)
* must be from opset version >= 13
* Second optional parameter `split` is not supported
* [Squeeze](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Squeeze)
* [Sub](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sub)
* 1st input must not be `"constant"`
Expand Down
69 changes: 30 additions & 39 deletions nnoir-onnx/nnoir_onnx/operators/split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import onnx
Expand All @@ -10,13 +10,18 @@
ShapeLike = Union[Tuple[int, ...], List[int]]


def create_half_split_matrices(k: int) -> Tuple[NDArray[Any], NDArray[Any]]:
k_2 = k // 2
def create_split_matrices(k: int, sizes: List[int]) -> List[NDArray[Any]]:
split_matrices: List[NDArray[Any]] = []
acc = 0
for sub_k in sizes:
zero_before: NDArray[Any] = np.zeros((acc, sub_k), dtype="float32")
eye: NDArray[Any] = np.eye(sub_k, dtype="float32")
zero_after: NDArray[Any] = np.zeros((k - sub_k - acc, sub_k), dtype="float32")

eye = np.eye(k_2, dtype="float32")
zero = np.zeros((k_2, k_2), dtype="float32")
split_matrices.append(np.concatenate([zero_before, eye, zero_after])) # type: ignore
acc += sub_k

return (np.concatenate([eye, zero]), np.concatenate([zero, eye])) # type: ignore
return split_matrices


def gen_value(env: Dict[str, NDArray[Any]], arr: NDArray[Any]) -> str:
Expand All @@ -39,22 +44,23 @@ def __init__(self, node: onnx.NodeProto, *args: Any):
super().__init__(node, *args)

def to_function(self, env: Dict[str, NDArray[Any]], constants: Dict[str, NDArray[Any]]) -> List[Function]:
if len(self.node.input) > 1:
raise UnsupportedONNXOperation(self.node, "the number of inputs must be 1.")
if len(self.node.output) > 2:
raise UnsupportedONNXOperation(self.node, "the number of outputs must be 2.")

split_axis = 0
for attr in self.node.attribute:
if attr.name == "axis":
split_axis = attr.i

output_sizes_on_axis = []
for output in self.node.output:
output_sizes_on_axis.append(env[output].shape[split_axis])

shape = env[self.node.input[0]].shape
k = shape[split_axis]
if k % 2 != 0:
raise Exception("Cannot reshape on odd size, shape {}, axis {}".format(shape, split_axis))
assert k == sum(
output_sizes_on_axis
), f"outputs are not a split of input on for axis {split_axis} of input shape {shape}: {output_sizes_on_axis}"

matrices = create_split_matrices(k, output_sizes_on_axis)

matrice_up, matrice_down = create_half_split_matrices(k)
transpose_perm_0 = list(range(len(shape)))
transpose_perm_0.append(transpose_perm_0.pop(split_axis))
transpose_perm_1 = list(range(len(shape)))
Expand All @@ -81,30 +87,15 @@ def linear_shape(x_shape: ShapeLike, w_shape: ShapeLike) -> ShapeLike:
trans_shape = transpose_shape(shape, transpose_perm_0)
trans_out = gen_dummy_value(env, trans_shape)

linear_up_shape = linear_shape(trans_shape, matrice_up.shape)
linear_up_out = gen_dummy_value(env, linear_up_shape)

linear_down_shape = linear_shape(trans_shape, matrice_down.shape)
linear_down_out = gen_dummy_value(env, linear_down_shape)

up_const = gen_value(env, matrice_up)
up_const_node = Constant([], [up_const], value=matrice_up) # type: ignore
down_const = gen_value(env, matrice_up)
down_const_node = Constant([], [down_const], value=matrice_down) # type: ignore

transpose_node = Transpose(list(self.node.input), [trans_out], axes=transpose_perm_0) # type: ignore
linear_up_node = MatMul([trans_out, up_const], [linear_up_out]) # type: ignore
linear_down_node = MatMul([trans_out, down_const], [linear_down_out]) # type: ignore
transpose_up_node = Transpose([linear_up_out], [self.node.output[0]], axes=transpose_perm_1) # type: ignore
transpose_down_node = Transpose([linear_down_out], [self.node.output[1]], axes=transpose_perm_1) # type: ignore
nodes = [
up_const_node,
down_const_node,
transpose_node,
linear_up_node,
linear_down_node,
transpose_up_node,
transpose_down_node,
]
nodes: List[Function] = [Transpose([self.node.input[0]], [trans_out], axes=transpose_perm_0)] # type: ignore
for i, mat in enumerate(matrices):
_linear_shape = linear_shape(trans_shape, mat.shape)
linear_out = gen_dummy_value(env, _linear_shape)

_const = gen_value(env, mat)
_const_node = Constant([], [_const], value=mat) # type: ignore
linear_node = MatMul([trans_out, _const], [linear_out]) # type: ignore
transpose_node = Transpose([linear_out], [self.node.output[i]], axes=transpose_perm_1) # type: ignore
nodes.extend([_const_node, linear_node, transpose_node])

return nodes
53 changes: 42 additions & 11 deletions nnoir-onnx/test/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from numpy.typing import NDArray
from onnx import TensorProto
from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info
from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor, make_tensor_value_info
from onnx.numpy_helper import from_array
from util import Base

Expand Down Expand Up @@ -86,22 +86,19 @@ def create_onnx(self) -> onnx.ModelProto:
SplitTester({"v0": v0}, outputs).run()


@pytest.mark.xfail()
def test_split_specify_split() -> None:
"""
Specify second input (optional parameter).
Due to lack of implementation, the second input is not supported.
Specify split attribute (opset 11).
"""

class SplitTester(Base):
def __init__(self, inputs: Dict[str, NDArray[Any]], outputs: List[str]):
super().__init__(inputs, outputs)

def create_onnx(self) -> onnx.ModelProto:
node = make_node("Split", inputs=["v0", "p0"], outputs=["v1", "v2", "v3"], axis=3)
node = make_node("Split", inputs=["v0"], outputs=["v1", "v2", "v3"], axis=3, split=[2, 3, 5])
inputs = [
info("v0", TensorProto.FLOAT, (1, 3, 4, 10)),
info("p0", TensorProto.INT64, (3,)),
]
outputs = [
info("v1", TensorProto.FLOAT, (1, 3, 4, 2)),
Expand All @@ -110,11 +107,45 @@ def create_onnx(self) -> onnx.ModelProto:
]

graph = make_graph([node], "add_graph", inputs, outputs)
model = make_model(graph)
return model
return make_model(graph, opset_imports=[make_opsetid("", 11)])

v0 = np.random.rand(1, 3, 4, 10).astype(np.float32)
p0 = np.array([2, 3, 5]).astype(np.int64)
outputs = ["v1", "v2", "v3"]
SplitTester({"v0": v0}, outputs).run()

outputs = ["v1", "v2"]
SplitTester({"v0": v0, "p0": p0}, outputs).run()

def test_split_specify_split_13() -> None:
"""
Specify split input (opset 13).
"""

class SplitTester(Base):
def __init__(self, inputs: Dict[str, NDArray[Any]], outputs: List[str]):
super().__init__(inputs, outputs)

def create_onnx(self) -> onnx.ModelProto:
node = make_node("Split", inputs=["v0", "p0"], outputs=["v1", "v2", "v3"], axis=3)
inputs = [
info("v0", TensorProto.FLOAT, (1, 3, 4, 10)),
]
node_p0 = make_node(
"Constant",
value=make_tensor(
name="p0_constant", data_type=TensorProto.INT64, dims=(3,), vals=np.array([2, 3, 5]).astype(np.int64)
),
inputs=[],
outputs=["p0"],
)
outputs = [
info("v1", TensorProto.FLOAT, (1, 3, 4, 2)),
info("v2", TensorProto.FLOAT, (1, 3, 4, 3)),
info("v3", TensorProto.FLOAT, (1, 3, 4, 5)),
]

graph = make_graph([node_p0, node], "add_graph", inputs, outputs)
return make_model(graph, opset_imports=[make_opsetid("", 13)])

v0 = np.random.rand(1, 3, 4, 10).astype(np.float32)

outputs = ["v1", "v2", "v3"]
SplitTester({"v0": v0}, outputs).run()

0 comments on commit 8a670ce

Please sign in to comment.