Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bitcast buffer allocation #632

Merged
merged 1 commit into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/nncase/transforms/neutral/optimize_allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ class NNCASE_API add_copy_to_output_pass : public graph_pass
void run_core(graph &graph, nncase::target &target, const run_pass_options &options) override;
};

// Graph pass that puts a `copy` node in front of every `bitcast` whose input is
// not already produced by a copy. The behavior lives in run_core, defined in
// src/transforms/neutral/optimize_allocation.cpp.
class NNCASE_API add_copy_to_bitcast_pass : public graph_pass
{
public:
// Reuse the constructors of graph_pass unchanged.
using graph_pass::graph_pass;

protected:
// Visits the graph and rewires each qualifying bitcast to read from a newly
// emplaced copy of its original input. `target` and `options` are unused by
// the implementation.
void run_core(graph &graph, nncase::target &target, const run_pass_options &options) override;
};

class NNCASE_API remove_exclusive_copy_to_output_transform : public transform
{
public:
Expand All @@ -89,6 +98,15 @@ class NNCASE_API remove_exclusive_copy_to_concat_transform : public transform
bool on_try_match(ir::node &node, transform_context &context) override;
};

// Transform that removes a `copy` feeding a `bitcast` when the copy's input
// lives in mem_data and is not flagged cnctr_attr_no_buffer_fusion.
class NNCASE_API remove_exclusive_copy_to_bitcast_transform : public transform
{
public:
// Reconnects every consumer of the matched copy's output directly to the
// copy's input, and marks that input with cnctr_attr_no_buffer_fusion.
void process(transform_context &context) override;

protected:
// Returns true (and records the copy's input/output connectors and the copy
// node in `context`) when `node` is a copy with a direct bitcast child and
// its input satisfies the conditions above.
bool on_try_match(ir::node &node, transform_context &context) override;
};

class NNCASE_API remove_simple_copy_from_slice_transform : public transform
{
public:
Expand Down
2 changes: 2 additions & 0 deletions src/nncase/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,14 @@ class compiler_impl : public compiler
pmgr.add_pass<add_copy_to_concat_pass>();
pmgr.add_pass<add_copy_to_slice_pass>();
pmgr.add_pass<add_copy_to_output_pass>();
pmgr.add_pass<add_copy_to_bitcast_pass>();

transform_pass pass("optimize_copy");
pass.emplace<remove_exclusive_copy_to_output_transform>();
pass.emplace<remove_simple_copy_from_slice_transform>();
pass.emplace<remove_non_simple_copy_from_slice_transform>();
pass.emplace<remove_exclusive_copy_to_concat_transform>();
pass.emplace<remove_exclusive_copy_to_bitcast_transform>();
pmgr.add_pass(std::move(pass));
});
}
Expand Down
53 changes: 53 additions & 0 deletions src/transforms/neutral/optimize_allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,25 @@ void add_copy_to_output_pass::run_core(graph &graph, [[maybe_unused]] nncase::ta
alias_visitor.visit(graph);
}

// Walks the graph and makes sure each bitcast reads from a copy node.
// For a bitcast whose input producer is not already a copy, a new copy is
// emplaced (same type and shape as the producer's output), named
// "<producer name>/copy", and spliced between the producer and the bitcast.
// `target` and `options` are not used.
void add_copy_to_bitcast_pass::run_core(graph &graph, [[maybe_unused]] nncase::target &target, [[maybe_unused]] const run_pass_options &options)
{
    auto visitor = make_relay_ir_visitor([&](node &current) {
        auto bc = node_cast<bitcast>(current);
        if (!bc)
            return; // only bitcast nodes are of interest

        auto &producer = *bc->input().connection();
        if (producer.owner().runtime_opcode() == op_copy)
            return; // already fed by a copy, nothing to insert

        // Splice: producer -> new copy -> bitcast.
        auto inserted = graph.emplace<copy>(producer.type(), producer.shape());
        inserted->module_type(graph.module_type());
        inserted->name(producer.owner().name() + "/copy");
        inserted->input().connect(producer);
        bc->input().connect(inserted->output());
    });
    visitor.visit(graph);
}

// x@data x@output
// | |
// copy |
Expand Down Expand Up @@ -222,6 +241,40 @@ void remove_exclusive_copy_to_concat_transform::process(transform_context &conte
in->connect(output);
}

// Matches a copy node that has a direct bitcast child and whose input
// connector is located in mem_data without the cnctr_attr_no_buffer_fusion
// attribute. On a match, the copy's input and output connectors and the copy
// itself are recorded in `context` and true is returned; otherwise `context`
// is left untouched and false is returned.
bool remove_exclusive_copy_to_bitcast_transform::on_try_match(node &node, transform_context &context)
{
    copy *candidate = node_cast<copy>(node);
    if (!candidate)
        return false;

    // The copy must feed a bitcast directly.
    if (!try_get_direct_child<bitcast>(*candidate))
        return false;

    auto source = candidate->input().connection();
    if (source->memory_location() != mem_data)
        return false;

    // Respect connectors that have opted out of buffer fusion.
    if ((source->attributes() & cnctr_attr_no_buffer_fusion) != 0)
        return false;

    context.inputs.emplace_back(&candidate->input());
    context.outputs.emplace_back(&candidate->output());
    context.matched_nodes.emplace_back(candidate);
    return true;
}

// Bypasses the matched copy: every connector that consumed the copy's output
// is reconnected to the connector that fed the copy, and that feeding
// connector is flagged with cnctr_attr_no_buffer_fusion.
void remove_exclusive_copy_to_bitcast_transform::process(transform_context &context)
{
    auto &source = *context.inputs[0]->connection();
    auto consumers = context.outputs[0]->connections();

    source.attributes(source.attributes() | cnctr_attr_no_buffer_fusion);

    // Iterate over a duplicate because connect() is invoked on each consumer
    // while walking the list.
    for (auto &consumer : dup(consumers))
        consumer->connect(source);
}

// x x
// | |
// slice |
Expand Down
42 changes: 42 additions & 0 deletions tests/schedule/buffer_fusion/test_bitcast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import pytest
import tensorflow as tf
import numpy as np
from tflite_test_runner import TfliteTestRunner


def _make_module():
    """Build a tf.Module whose call reshapes a [1, 4, 4, 3] float32 tensor to [1, -1, 3]."""
    class Module(tf.Module):
        def __init__(self):
            # Zero-argument super() is bound to this instance, so the base-class
            # constructor actually runs. The one-argument form super(Module)
            # yields an unbound proxy whose __init__ is the proxy's own, which
            # left tf.Module uninitialized.
            super().__init__()

        @tf.function(input_signature=[tf.TensorSpec([1, 4, 4, 3], tf.float32)])
        def __call__(self, x):
            # -1 lets reshape infer the middle dimension from the input size.
            return tf.reshape(x, [1, -1, 3])
    return Module()


def test_bitcast(request):
    """Feed the reshape module through TfliteTestRunner: convert it, then run the result."""
    reshape_module = _make_module()

    # The runner is keyed by the pytest node name of this test.
    test_runner = TfliteTestRunner(request.node.name)
    test_runner.run(test_runner.from_tensorflow(reshape_module))


if __name__ == "__main__":
    # Allow running this file directly by handing it to pytest.
    pytest_args = ['-vv', 'test_bitcast.py']
    pytest.main(pytest_args)