support LBR #796

Merged (4 commits, Feb 10, 2023)
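For context, the `linear_before_reset` attribute (the "LBR" in the title) is defined in the ONNX GRU spec as choosing where the reset gate r_t is applied when forming the candidate hidden state:

```latex
% linear_before_reset = 0 (ONNX default): reset applied before the recurrent matmul
\tilde{h}_t = \tanh\left(W_h x_t + R_h \,(r_t \odot h_{t-1}) + W_{bh} + R_{bh}\right)

% linear_before_reset = 1: reset applied after the recurrent linear transform
\tilde{h}_t = \tanh\left(W_h x_t + r_t \odot (R_h h_{t-1} + R_{bh}) + W_{bh}\right)
```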
5 changes: 3 additions & 2 deletions include/nncase/codegen/stackvm/op_writer.h
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 9/14/2022 4:24:11 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:54 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -1475,6 +1475,7 @@ struct op_writer<nncase::runtime::stackvm::tensor_gru_op_t>
writer.write(op.input_shape_src);
writer.write(op.w_shape_src);
writer.write(op.direction);
writer.write(op.linear_before_reset);
}
};

@@ -1678,7 +1679,7 @@ class NNCASE_API op_builder
void tensor_trilu_(datatype_t datatype, uint8_t rshape_src, bool upper, int64_t k);
void tensor_unary_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, unary_op_t unary_op);
void tensor_transpose_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_perm);
void tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction);
void tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction, bool linear_before_reset);
void tensor_tflite_detection_postprocess_(uint8_t box_shape_src, uint8_t score_shape_src, uint8_t anchor_shape_src, int32_t max_detections, int32_t max_classes_per_detection, int32_t detections_per_class, bool use_regular_non_max_suppression, float nms_score_threshold, float nms_iou_threshold, int32_t num_classes, float y_scale, float x_scale, float h_scale, float w_scale);
void tensor_layer_normalization_(datatype_t datatype, uint8_t input_shape, int32_t axis, float epsilon);
void tensor_compress_(uint8_t input_shape_src, uint8_t condition_shape_src, float axis);
4 changes: 3 additions & 1 deletion include/nncase/ir/ops/gru.h
@@ -34,15 +34,17 @@ class NNCASE_API gru : public node

lstm_direction direction() const noexcept { return direction_; }
std::string framework() const noexcept { return framework_; }
bool linear_before_reset() const noexcept { return linear_before_reset_; }

gru(shape_t input_shape, shape_t w_shape, shape_t r_shape, shape_t b_shape, shape_t output_shape,
shape_t output_h_shape, lstm_direction direction, std::string framework);
shape_t output_h_shape, lstm_direction direction, std::string framework, bool linear_before_reset);

protected:
bool properties_equal(node &other) const override;

private:
lstm_direction direction_;
std::string framework_;
bool linear_before_reset_;
};
}
221 changes: 154 additions & 67 deletions include/nncase/kernels/cpu/reference/tensor_compute.h

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion include/nncase/kernels/tensor_compute.h
@@ -153,7 +153,7 @@ template <typename T>
NNCASE_API result<void> trilu(const T *input, T *output, const runtime_shape_t &in_shape, const bool upper, const int64_t k) noexcept;

template <typename T>
NNCASE_API result<void> gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept;
NNCASE_API result<void> gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode, bool linear_before_reset) noexcept;

template <typename T>
NNCASE_API result<void> tflite_detection_postprocess(const T *boxes, const T *scores, const T *anchors, T *output_locations, T *output_classes, T *output_scores, T *output_num_detections,
3 changes: 2 additions & 1 deletion include/nncase/runtime/stackvm/op_reader.h
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 9/14/2022 4:24:10 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:53 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -1736,6 +1736,7 @@ struct op_reader<tensor_gru_op_t>
op.input_shape_src = reader.read_unaligned<uint8_t>();
op.w_shape_src = reader.read_unaligned<uint8_t>();
op.direction = reader.read_unaligned<uint8_t>();
op.linear_before_reset = reader.read_unaligned<bool>();
return op;
}
};
7 changes: 4 additions & 3 deletions include/nncase/runtime/stackvm/opcode.h
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 9/14/2022 4:24:10 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:53 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -1877,10 +1877,11 @@ struct tensor_gru_op_t
uint8_t input_shape_src;
uint8_t w_shape_src;
uint8_t direction;
bool linear_before_reset;

tensor_gru_op_t(default_init_t) noexcept { }
explicit tensor_gru_op_t(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::GRU), input_shape_src(input_shape_src), w_shape_src(w_shape_src), direction(direction)
explicit tensor_gru_op_t(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction, bool linear_before_reset) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::GRU), input_shape_src(input_shape_src), w_shape_src(w_shape_src), direction(direction), linear_before_reset(linear_before_reset)
{
}
};
6 changes: 3 additions & 3 deletions src/codegen/stackvm/op_writer.cpp
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 9/14/2022 4:24:11 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:54 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
@@ -673,9 +673,9 @@ void op_builder::tensor_transpose_(datatype_t datatype, uint8_t rshape_src, uint
op_writer<tensor_transpose_op_t>()(tensor_transpose_op_t(datatype, rshape_src, rstride_src, rstride_dest, rshape_perm), writer_);
}

void op_builder::tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction)
void op_builder::tensor_gru_(uint8_t input_shape_src, uint8_t w_shape_src, uint8_t direction, bool linear_before_reset)
{
op_writer<tensor_gru_op_t>()(tensor_gru_op_t(input_shape_src, w_shape_src, direction), writer_);
op_writer<tensor_gru_op_t>()(tensor_gru_op_t(input_shape_src, w_shape_src, direction, linear_before_reset), writer_);
}

void op_builder::tensor_tflite_detection_postprocess_(uint8_t box_shape_src, uint8_t score_shape_src, uint8_t anchor_shape_src, int32_t max_detections, int32_t max_classes_per_detection, int32_t detections_per_class, bool use_regular_non_max_suppression, float nms_score_threshold, float nms_iou_threshold, int32_t num_classes, float y_scale, float x_scale, float h_scale, float w_scale)
2 changes: 1 addition & 1 deletion src/codegen/stackvm/ops/gru.cpp
@@ -39,5 +39,5 @@ void stackvm_module_builder::emit(gru &node, stackvm_op_builder &builder)
builder.stshape(0, input.shape);
builder.stshape(1, w.shape);

builder.tensor_gru_(0, 1, node.direction());
builder.tensor_gru_(0, 1, node.direction(), node.linear_before_reset());
}
2 changes: 1 addition & 1 deletion src/evaluator/ops/neutral/neutral_ops.cpp
@@ -843,7 +843,7 @@ void register_neutral_evaluators()
auto output_h = context.memory_at(rnode.output_h());
kernels::gru(input.buffer().as_span<float>().data(), W.buffer().as_span<float>().data(), R.buffer().as_span<float>().data(),
B.buffer().as_span<float>().data(), initial_h.buffer().as_span<float>().data(), output.buffer().as_span<float>().data(), output_h.buffer().as_span<float>().data(),
input.shape(), W.shape(), rnode.direction())
input.shape(), W.shape(), rnode.direction(), rnode.linear_before_reset())
.unwrap_or_throw(); });

register_evaluator(op_tflite_detection_postprocess, [](ir::node &node, function_evaluate_context &context) {
4 changes: 3 additions & 1 deletion src/importer/onnx/ops/gru.cpp
@@ -39,6 +39,8 @@ void onnx_importer::convert_op_GRU(const NodeProto &node)
direction = kBidirectional;
size_t num_directions = direction == kBidirectional ? 2 : 1;

auto linear_before_reset = get_attribute<int64_t>(node, "linear_before_reset").value_or(0);

// input
auto input_size = node.input_size();
assert(input_size >= 3 && input_size <= 8);
@@ -83,7 +85,7 @@ void onnx_importer::convert_op_GRU(const NodeProto &node)
output_h = node.output()[1];

shape_t output_shape { seq_length, num_directions, batch_size, hidden_size };
auto lstm_node = graph_.emplace<gru>(input_shape, W_shape, R_shape, B_shape, output_shape, initial_shape, direction, "onnx");
auto lstm_node = graph_.emplace<gru>(input_shape, W_shape, R_shape, B_shape, output_shape, initial_shape, direction, "onnx", linear_before_reset == 0 ? false : true);
lstm_node->name(op_name);

input_tensors_.emplace(&lstm_node->input_at(0), input);
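The importer reads the attribute with `value_or(0)`, so models that omit it keep the previous behaviour. A minimal sketch of an ONNX GRU node that exercises the new path (built with `onnx.helper`; the tensor names and hidden size are illustrative, not taken from this PR):

```python
from onnx import helper

# Hypothetical GRU node for illustration; input/output names are assumptions.
gru_node = helper.make_node(
    'GRU',
    inputs=['X', 'W', 'R', 'B'],
    outputs=['Y', 'Y_h'],
    hidden_size=32,
    direction='forward',
    linear_before_reset=1,  # omit, or set to 0, for the default reset-before-matmul form
)
```

The importer forwards this flag into the `gru` IR node, and from there it is carried through codegen and the runtime op down to the kernel.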
6 changes: 3 additions & 3 deletions src/ir/ops/gru.cpp
@@ -20,8 +20,8 @@ using namespace nncase;
using namespace nncase::ir;

gru::gru(shape_t input_shape, shape_t w_shape, shape_t r_shape, shape_t b_shape, shape_t output_shape,
shape_t output_h_shape, lstm_direction direction, std::string framework)
: direction_(direction), framework_(framework)
shape_t output_h_shape, lstm_direction direction, std::string framework, bool linear_before_reset)
: direction_(direction), framework_(framework), linear_before_reset_(linear_before_reset)
{
add_input("input", dt_float32, input_shape);
add_input("w", dt_float32, w_shape);
@@ -36,5 +36,5 @@ gru::gru(shape_t input_shape, shape_t w_shape, shape_t r_shape, shape_t b_shape,
bool gru::properties_equal(node &other) const
{
auto &r = static_cast<gru &>(other);
return direction() == r.direction() && framework() == r.framework();
return direction() == r.direction() && framework() == r.framework() && linear_before_reset() == r.linear_before_reset();
}
27 changes: 17 additions & 10 deletions src/kernels/cpu/reference/gru.cpp
@@ -16,16 +16,22 @@
#include <nncase/kernels/cpu/reference/tensor_compute.h>
#include <nncase/kernels/kernel_utils.h>
#include <nncase/runtime/runtime_op_utility.h>

using namespace nncase;
using namespace nncase::runtime;
using namespace nncase::kernels;
using namespace nncase::kernels::cpu;
using namespace nncase::kernels::cpu::reference;

template result<void> reference::gru<float>(const float *input, const float *w, const float *r, const float *b, float *initial_h, float *output, float *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept;
template result<void>
reference::gru<float>(const float *input, const float *w, const float *r, const float *b, float *initial_h,
float *output, float *output_h, const runtime_shape_t &input_shape,
const runtime_shape_t &w_shape, int mode, bool linear_before_reset) noexcept;

template <typename T>
result<void> reference::gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept
result<void> reference::gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h,
const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode,
bool linear_before_reset) noexcept
{
const int seq_length = input_shape[0];
const int batch_size = input_shape[1];
@@ -140,19 +146,20 @@ result<void> reference::gru(const T *input, const T *w, const T *r, const T *b,
tmp_a[bs * hidden_size + hs] += x_i[bs * input_size + is] * w_i[2 * hidden_size * input_size + hs * input_size + is];
}
tmp_a[bs * hidden_size + hs] += b_i[2 * hidden_size + hs];

for (int rs = 0; rs < hidden_size; rs++)
{
// if not linear
tmp_b[bs * hidden_size + hs] += gate_r[bs * hidden_size + rs] * h_t[bs * hidden_size + rs] * r_i[2 * hidden_size * hidden_size + hs * hidden_size + rs];
// if linear
// tmp_b[bs * batch_size + hs] += h_t[bs * batch_size + rs] * r_i[hidden_size * hidden_size + hs * hidden_size + rs] + b_i[5 * hidden_size + hs];
if (!linear_before_reset)
tmp_b[bs * hidden_size + hs] += gate_r[bs * hidden_size + rs] * h_t[bs * hidden_size + rs] * r_i[2 * hidden_size * hidden_size + hs * hidden_size + rs];
else
tmp_b[bs * hidden_size + hs] += h_t[bs * hidden_size + rs] * r_i[2 * hidden_size * hidden_size + hs * hidden_size + rs];
}
tmp_b[bs * hidden_size + hs] += b_i[5 * hidden_size + hs];

// if not linear
gate_h[bs * hidden_size + hs] = tmp_a[bs * hidden_size + hs] + tmp_b[bs * hidden_size + hs];
// if linear
// gate_h[bs * batch_size + hs] = tmp_a[bs * batch_size + hs] + gate_r[bs * batch_size + rs] * tmp_b[bs * batch_size + hs];
if (!linear_before_reset)
gate_h[bs * hidden_size + hs] = tmp_a[bs * hidden_size + hs] + tmp_b[bs * hidden_size + hs];
else
gate_h[bs * hidden_size + hs] = tmp_a[bs * hidden_size + hs] + gate_r[bs * hidden_size + hs] * tmp_b[bs * hidden_size + hs];
}
}
// gate_h = tanh(gate_h);
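To see the two branches the reference kernel now selects between side by side, here is a minimal NumPy sketch of the candidate-gate computation; the variable names and the single-direction, batch-major layout are assumptions for illustration, not the kernel's own:

```python
import numpy as np

def gru_candidate(x_t, h_prev, r_t, W_h, R_h, W_bh, R_bh, linear_before_reset):
    """Candidate hidden state for one GRU step (single direction).

    x_t:    (batch, input_size)        current input
    h_prev: (batch, hidden_size)       previous hidden state
    r_t:    (batch, hidden_size)       reset gate, already passed through sigmoid
    W_h:    (hidden_size, input_size)  input-side weights for the h gate
    R_h:    (hidden_size, hidden_size) recurrent weights for the h gate
    W_bh, R_bh: (hidden_size,)         input-side and recurrent-side biases
    """
    if not linear_before_reset:
        # reset is applied to the hidden state before the recurrent matmul
        return np.tanh(x_t @ W_h.T + (r_t * h_prev) @ R_h.T + W_bh + R_bh)
    # reset is applied after the recurrent linear transform (and its bias)
    return np.tanh(x_t @ W_h.T + r_t * (h_prev @ R_h.T + R_bh) + W_bh)
```

The two forms coincide when r_t is all ones, which is why models that never set the attribute are unaffected by this change.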
6 changes: 3 additions & 3 deletions src/kernels/tensor_compute.cpp
@@ -474,12 +474,12 @@ result<void> kernels::trilu(const T *input, T *output, const runtime_shape_t &in
return cpu::reference::trilu(input, output, in_shape, upper, k);
}

template result<void> kernels::gru<float>(const float *input, const float *w, const float *r, const float *b, float *initial_h, float *output, float *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept;
template result<void> kernels::gru<float>(const float *input, const float *w, const float *r, const float *b, float *initial_h, float *output, float *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode, bool linear_before_reset) noexcept;

template <typename T>
result<void> kernels::gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode) noexcept
result<void> kernels::gru(const T *input, const T *w, const T *r, const T *b, T *initial_h, T *output, T *output_h, const runtime_shape_t &input_shape, const runtime_shape_t &w_shape, int mode, bool linear_before_reset) noexcept
{
return cpu::reference::gru(input, w, r, b, initial_h, output, output_h, input_shape, w_shape, mode);
return cpu::reference::gru(input, w, r, b, initial_h, output, output_h, input_shape, w_shape, mode, linear_before_reset);
}

template result<void> kernels::tflite_detection_postprocess<float>(const float *boxes, const float *scores, const float *anchors, float *output_locations, float *output_classes, float *output_scores, float *output_num_detections,
2 changes: 1 addition & 1 deletion src/runtime/stackvm/op_reader.cpp
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 9/14/2022 4:24:11 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 2023/2/6 下午2:28:53 +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
2 changes: 1 addition & 1 deletion src/runtime/stackvm/ops/tensor.gru.cpp
@@ -37,5 +37,5 @@ result<void> stackvm_runtime_function::visit(const tensor_gru_op_t &op) noexcept
return kernels::gru(reinterpret_cast<const float *>(input), reinterpret_cast<const float *>(w),
reinterpret_cast<const float *>(r), reinterpret_cast<const float *>(b),
reinterpret_cast<float *>(initial_h), reinterpret_cast<float *>(output),
reinterpret_cast<float *>(output_h), in_shape, w_shape, op.direction);
reinterpret_cast<float *>(output_h), in_shape, w_shape, op.direction, op.linear_before_reset);
}
2 changes: 1 addition & 1 deletion targets/k210/k210_target.cpp
@@ -182,7 +182,7 @@ void k210_target::register_quantize_annotation_passes(NNCASE_UNUSED const module

{
transform_pass p("annotate_kpu_quantize");
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary, ir::k210::op_k210_fake_kpu_conv2d, ir::op_bitcast, ir::op_dequantize, ir::op_binary);
p.emplace<add_quant_checkpoints_transform>(std::in_place, ir::op_fused_unary, ir::k210::op_k210_fake_kpu_conv2d, ir::op_bitcast, ir::op_dequantize, ir::op_binary, ir::op_slice);
pass_mgr.add_pass(std::move(p));
}
}
37 changes: 25 additions & 12 deletions tests/importer/onnx_/basic/test_gru.py
@@ -20,7 +20,10 @@
from onnx_test_runner import OnnxTestRunner
import numpy as np

def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bias, sequence_lens, initial_h, Y, Y_h):


def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bias, sequence_lens, initial_h, Y, Y_h,
LBR):
nodes_inputs = []
nodes_outputs = []
initializers = []
@@ -33,6 +36,7 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia
if direction is not None:
attributes_dict['direction'] = direction
attributes_dict['hidden_size'] = hidden_size
attributes_dict['linear_before_reset'] = LBR

# input
input_shape = [seq_length, batch_size, input_size]
@@ -45,7 +49,7 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia
'W',
TensorProto.FLOAT,
dims=w_shape,
vals=np.random.rand(*w_shape).astype(np.float32).flatten().tolist()
vals=(np.random.rand(*w_shape) * 2 - 1).astype(np.float32).flatten().tolist()
)
nodes_inputs.append('W')
initializers.append(w_tensor)
@@ -55,7 +59,7 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia
'R',
TensorProto.FLOAT,
dims=r_shape,
vals=np.random.rand(*r_shape).astype(np.float32).flatten().tolist()
vals=(np.random.rand(*r_shape) * 2 - 1).astype(np.float32).flatten().tolist()
)
nodes_inputs.append('R')
initializers.append(r_tensor)
@@ -69,7 +73,7 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia
'B',
TensorProto.FLOAT,
dims=bias_shape,
vals=np.random.rand(*bias_shape).astype(np.float32).flatten().tolist()
vals=(np.random.rand(*bias_shape) * 2 - 1).astype(np.float32).flatten().tolist()
)
nodes_inputs.append('B')
initializers.append(bias_tensor)
@@ -100,7 +104,6 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia
nodes_inputs.append('initial_h')
initializers.append(initial_h_tensor)


# output
if Y is None:
nodes_outputs.append('')
@@ -140,6 +143,7 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia

return model_def


directions = [
None,
'forward',
@@ -148,19 +152,19 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia
]

hidden_sizes = [
3
32,
]

seq_lengths = [
4
4,
]

batch_sizes = [
5,
16,
]

input_sizes = [
6,
64,
]

biases = [
@@ -178,15 +182,20 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia
]

Ys = [
# None, // At least one output be requested
1
]


Y_hs = [
None,
1
]

LBRs = [
0,
1
]


@pytest.mark.parametrize('direction', directions)
@pytest.mark.parametrize('hidden_size', hidden_sizes)
@@ -198,12 +207,16 @@ def _make_module(direction, hidden_size, seq_length, batch_size, input_size, bia
@pytest.mark.parametrize('initial_h', initial_hs)
@pytest.mark.parametrize('Y', Ys)
@pytest.mark.parametrize('Y_h', Y_hs)
def test_gru(direction, hidden_size, seq_length, batch_size, input_size, bias, sequence_lens, initial_h, Y, Y_h, request):
model_def = _make_module(direction, hidden_size, seq_length, batch_size, input_size, bias, sequence_lens, initial_h, Y, Y_h)
@pytest.mark.parametrize('LBR', LBRs)
def test_gru(direction, hidden_size, seq_length, batch_size, input_size, bias, sequence_lens, initial_h, Y, Y_h, LBR,
request):
model_def = _make_module(direction, hidden_size, seq_length, batch_size,
input_size, bias, sequence_lens, initial_h, Y, Y_h, LBR)

runner = OnnxTestRunner(request.node.name)
model_file = runner.from_onnx_helper(model_def)
runner.run(model_file)


if __name__ == "__main__":
pytest.main(['-vv', 'test_gru.py'])
4 changes: 4 additions & 0 deletions tools/stackvm_gen/IsaGen/Instructions.cs
@@ -2307,6 +2307,10 @@ public class GruInstruction : TensorInstruction
[Description("direction register")]
public byte Direction { get; set; }

[DisplayName("linear_before_reset")]
[Description("LBR register")]
public bool LinearBeforeReset { get; set; }

}
[DisplayName("TENSOR.TFLITE_DETECTION_POSTPROCESS")]
[Category("Tensor Instructions")]