Skip to content

Commit

Permalink
Add ONNX LinearRegressor operator support (#248)
Browse files Browse the repository at this point in the history
* Add ONNX LinearRegressor operator support

* Update branch name to release-1.11.1 in build/test scripts
  • Loading branch information
Xingyu Zhou authored Mar 24, 2022
1 parent a6ddc8e commit 7ccef62
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ $(foreach CMAKE_TARGET,$(CMAKE_TARGETS),$(eval $(GEN_CMAKE_RULE)))
# scripts that are executed in the CI should be in tests/lint. This
# allows docker/lint.sh to behave similarly to the CI.
format:
./tests/lint/git-clang-format.sh -i origin/release-1.11.0
./tests/lint/git-clang-format.sh -i origin/release-1.11.1
black .
cd rust && which cargo && cargo fmt --all

Expand Down
4 changes: 2 additions & 2 deletions docker/lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function run_lint_step() {
# NOTE: need to run git status to update some docker-side cache. Otherwise,
# git-clang-format will fail with "The following files would be modified but have
# unstaged changes:"
cmd=( bash -c 'git status &>/dev/null && tests/lint/git-clang-format.sh -i origin/release-1.11.0' )
cmd=( bash -c 'git status &>/dev/null && tests/lint/git-clang-format.sh -i origin/release-1.11.1' )
fi
;;
cpplint)
Expand All @@ -61,7 +61,7 @@ function run_lint_step() {
if [ $inplace_fix -eq 0 ]; then
cmd=( tests/lint/python_format.sh )
else
cmd=( tests/lint/git-black.sh -i origin/release-1.11.0 )
cmd=( tests/lint/git-black.sh -i origin/release-1.11.1 )
fi
;;
jnilint)
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3444,6 +3444,35 @@ def body_fn(*loop_inputs):
return outputs


class LinearRegressor(OnnxOpConverter):
"""Operator converter for LinearRegressor."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
coefficients = attr.get("coefficients", 0)
data_shape = infer_shape(data)
targets = attr.get("targets", 1)
coefficients = _expr.const(list(coefficients), dtype="float32")
coefficients_shape = infer_shape(coefficients)

coefficients = _op.reshape(coefficients, (targets, coefficients_shape[0] // targets))
if coefficients_shape[0] // targets < data_shape[-1]:
data = _op.split(data, [coefficients_shape[0] // targets], -1)[0]

mm_out = _op.nn.dense(data, coefficients)

if "intercepts" in attr:
intercepts = attr.get("intercepts", 0)
intercepts = _expr.const(list(intercepts), dtype="float32")

if targets == 1:
return _op.nn.bias_add(mm_out, intercepts, axis=-1)
return get_relay_op("add")(mm_out, intercepts)

return mm_out


class NonMaxSuppression(OnnxOpConverter):
"""Operator converter for NonMaxSuppression."""

Expand Down Expand Up @@ -4762,6 +4791,8 @@ def _get_convert_map(opset):
"Adam": Adam.get_converter(opset),
"Momentum": Momentum.get_converter(opset),
"Scan": Scan.get_converter(opset),
# ML
"LinearRegressor": LinearRegressor.get_converter(opset),
}


Expand Down
2 changes: 1 addition & 1 deletion tests/lint/clang_format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@

# check lastest change, for squash merge into main
./tests/lint/git-clang-format.sh HEAD~1
./tests/lint/git-clang-format.sh origin/release-1.11.0
./tests/lint/git-clang-format.sh origin/release-1.11.1
2 changes: 1 addition & 1 deletion tests/lint/python_format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@


./tests/lint/git-black.sh HEAD~1
./tests/lint/git-black.sh origin/release-1.11.0
./tests/lint/git-black.sh origin/release-1.11.1
44 changes: 44 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -6137,6 +6137,49 @@ def verify_scan(
verify_scan(input_shapes, output_shapes, 3, [-4, -1, -2], [1] * 3, [-3, -2], [1] * 2, 9)


@tvm.testing.parametrize_targets
def test_LinearRegressor(target, dev):
def verify_LinearRegressor(a_shape, c_shape, i_shape, targets=1, batch=1):
a_array = np.random.uniform(size=a_shape).astype("float32")
out_shape = (batch, targets)

coefficients = np.random.uniform(size=c_shape).astype("float32")
intercepts = np.random.uniform(size=i_shape).astype("float32")

mul_node = helper.make_node(
"LinearRegressor",
["a"],
["out"],
coefficients=coefficients,
intercepts=intercepts,
targets=targets,
domain="ai.onnx.ml",
)

graph = helper.make_graph(
[mul_node],
"LinearRegressor_test",
inputs=[
helper.make_tensor_value_info("a", TensorProto.FLOAT, list(a_shape)),
],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, out_shape)],
)
model = helper.make_model(
graph,
producer_name="LinearRegressor_test",
opset_imports=[
onnx.helper.make_opsetid("ai.onnx.ml", 1),
],
)
verify_with_ort_with_inputs(model, [a_array], target=target, dev=dev)

verify_LinearRegressor((1, 3), (3), (1))
verify_LinearRegressor((2, 10), (10), (1), batch=2)
verify_LinearRegressor((1, 3), (30), (10), targets=10)
verify_LinearRegressor((10, 3), (30), (10), targets=10, batch=10)
verify_LinearRegressor((1, 4), (3), (1))


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -6232,3 +6275,4 @@ def verify_scan(
test_random_uniform_like()
test_random_normal()
test_random_normal_like()
test_LinearRegressor()
2 changes: 1 addition & 1 deletion tests/scripts/git_change_docs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ DOCS_DIR=0
OTHER_DIR=0
DOC_DIR="docs/"

changed_files=`git diff --no-commit-id --name-only -r origin/release-1.11.0`
changed_files=`git diff --no-commit-id --name-only -r origin/release-1.11.1`

for file in $changed_files; do
FOUND_ONE_FILE=1
Expand Down

0 comments on commit 7ccef62

Please sign in to comment.