From 81ecf2e878d07a165f16b9dff53d6fcaac1c35e8 Mon Sep 17 00:00:00 2001 From: ddavis-2015 Date: Thu, 17 Oct 2024 18:17:06 -0700 Subject: [PATCH] decompression unit test improvements --- .../lite/micro/kernels/decompress_test.cc | 184 ++++++++++++++---- 1 file changed, 146 insertions(+), 38 deletions(-) diff --git a/tensorflow/lite/micro/kernels/decompress_test.cc b/tensorflow/lite/micro/kernels/decompress_test.cc index e4b39366dd6..e34906c5834 100644 --- a/tensorflow/lite/micro/kernels/decompress_test.cc +++ b/tensorflow/lite/micro/kernels/decompress_test.cc @@ -47,7 +47,7 @@ struct TestingInfo { }; template -struct TestingData7_2_256 { +struct TestingData { static constexpr size_t kBitWidth = 7; static constexpr size_t kChannels = 2; static constexpr size_t kElementsPerChannel = 256; @@ -63,14 +63,45 @@ struct TestingData7_2_256 { T goldens[kTotalElements]; }; -TestingData7_2_256 TestingData7_2_256_Bool; -#ifdef notyet -TestingData7_2_256 TestingData7_2_256_Float32; -TestingData7_2_256 TestingData7_2_256_Int8; -TestingData7_2_256 TestingData7_2_256_Int16; -TestingData7_2_256 TestingData7_2_256_Int32; -TestingData7_2_256 TestingData7_2_256_Int64; -#endif // notyet +TestingData TestingData_Bool; +TestingData TestingData_Float32; +TestingData TestingData_Int8; +TestingData TestingData_Int16; +TestingData TestingData_Int32; +TestingData TestingData_Int64; + +template +TestingData* GetTestingData(); + +template <> +TestingData* GetTestingData() { + return &TestingData_Bool; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Float32; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Int8; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Int16; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Int32; +} + +template <> +TestingData* GetTestingData() { + return &TestingData_Int64; +} template void FillValueTable(const size_t total_elements, T* value_table) { @@ -81,7 +112,6 @@ void FillValueTable(const size_t total_elements, T* value_table) { } } -#ifdef notyet template <> void FillValueTable(const size_t total_elements, float* value_table) { float fill_value = -1.1f; @@ -90,7 +120,6 @@ void FillValueTable(const size_t total_elements, float* value_table) { fill_value -= 1.0f; } } -#endif // notyet template <> void FillValueTable(const size_t total_elements, bool* value_table) { @@ -127,8 +156,8 @@ void FillGoldens(const size_t total_elements, T* goldens, for (size_t channel = 0; channel < channels; channel++) { for (size_t i = 0; i < elements_per_channel; i++) { goldens[(channel * elements_per_channel) + i] = - value_table[(channel * value_table_stride) + value_table_index++]; - if (value_table_index == value_table_stride) { + value_table[(channel * value_table_stride) + value_table_index]; + if (++value_table_index == value_table_stride) { value_table_index = 0; } } @@ -163,7 +192,7 @@ void FillCompressed(uint8_t* compressed, const size_t total_golden_elements, for (size_t group = 0; group < golden_element_groups; group++) { for (size_t channel = 0; channel < channels; channel++) { size_t value_table_index = FindValueTableIndex( - goldens[(group * golden_element_groups) + channel], + goldens[(group * channels) + channel], &value_table[channel * value_table_stride], value_table_stride); bits |= value_table_index << (16 - bits_accumulated - bit_width); bits_accumulated += bit_width; @@ -193,10 +222,14 @@ void FillCompressed(uint8_t* compressed, const size_t total_golden_elements, value_table += value_table_stride; } } + + if (bits_accumulated > 0) { + *compressed = static_cast(bits >> 8); + } } template -TfLiteStatus TestDecompression(TestingInfo* info) { +void TestDecompression(TestingInfo* info) { CompressionTensorData ctd = {}; LookupTableData lut_data = {}; ctd.scheme = CompressionScheme::kBinQuant; @@ -211,36 +244,22 @@ TfLiteStatus TestDecompression(TestingInfo* info) { DecompressionState ds(info->compressed, info->total_elements, ctd, info->channel_count); - std::fill_n(info->output, info->total_elements, ~0ULL); + std::fill_n(info->output, info->total_elements, static_cast(~0ULL)); ds.DecompressToBuffer(info->output); + bool saved_fail_state = micro_test::did_test_fail; + micro_test::did_test_fail = false; for (size_t i = 0; i < info->total_elements; i++) { TF_LITE_MICRO_EXPECT_EQ(info->goldens[i], info->output[i]); - TF_LITE_MICRO_CHECK_FAIL(); + if (micro_test::did_test_fail) { + return; + } } - - return kTfLiteOk; + micro_test::did_test_fail = saved_fail_state; } template -void TestBitWidth(size_t bit_width) { - MicroPrintf(" Testing bit width %d", bit_width); - - TestingInfo info = {}; - - if (std::is_same::value) { - info.output = TestingData7_2_256_Bool.output; - info.goldens = TestingData7_2_256_Bool.goldens; - info.compressed = TestingData7_2_256_Bool.compressed; - info.value_table = TestingData7_2_256_Bool.value_table; - } - - info.bit_width = bit_width; - info.channel_count = 1; - info.total_elements = 16; - info.total_value_table_elements = 1 << bit_width; - info.use_alt_axis = false; - +void GenerateData(TestingInfo& info) { FillValueTable(info.total_value_table_elements, info.value_table); FillGoldens(info.total_elements, info.goldens, info.total_value_table_elements, info.value_table, @@ -249,14 +268,98 @@ void TestBitWidth(size_t bit_width) { info.total_value_table_elements / info.channel_count, info.value_table, info.channel_count, info.use_alt_axis, info.bit_width); +} +template +void TestDataSetup(TestingInfo* info, TestingData* data) { + info->output = data->output; + info->goldens = data->goldens; + info->compressed = data->compressed; + info->value_table = data->value_table; +} + +template +void TestValueTable2n(TestingInfo& info) { + info.total_elements = 16; + if (std::is_same::value) { + info.total_value_table_elements = 2 * info.channel_count; + } else { + info.total_value_table_elements = + (1 << info.bit_width) * info.channel_count; + info.total_value_table_elements = + std::min(info.total_value_table_elements, info.total_elements); + } + + MicroPrintf(" Testing value table 2^n: %d", + info.total_value_table_elements); + GenerateData(info); TestDecompression(&info); } +template +void TestValueTable2nMinus1(TestingInfo& info) { + info.total_elements = 16; + if (std::is_same::value) { + info.total_value_table_elements = 1 * info.channel_count; + } else { + info.total_value_table_elements = + ((1 << info.bit_width) - 1) * info.channel_count; + info.total_value_table_elements = + std::min(info.total_value_table_elements, info.total_elements); + } + + MicroPrintf(" Testing value table 2^n-1: %d", + info.total_value_table_elements); + GenerateData(info); + TestDecompression(&info); +} + +template +void TestSingleChannel(TestingInfo& info) { + info.channel_count = 1; + + MicroPrintf(" Testing single channel"); + TestValueTable2n(info); + TestValueTable2nMinus1(info); +} + +template +void TestMultiChannel(TestingInfo& info) { + info.channel_count = 2; + + MicroPrintf(" Testing multiple channels: %d", info.channel_count); + TestValueTable2n(info); + TestValueTable2nMinus1(info); +} + +template +void TestBitWidth(TestingInfo& info) { + info.use_alt_axis = false; + + MicroPrintf(" Testing bit width %d", info.bit_width); + TestSingleChannel(info); + TestMultiChannel(info); +} + +template +void TestBitWidthAltAxis(TestingInfo& info) { + info.use_alt_axis = true; + + MicroPrintf(" Testing alt-axis bit width %d", info.bit_width); + TestSingleChannel(info); + TestMultiChannel(info); +} + template void TestAllBitWidths() { + TestingInfo info = {}; + TestDataSetup(&info, GetTestingData()); + for (size_t bw = 1; bw <= 7; bw++) { - TestBitWidth(bw); + info.bit_width = bw; + + TestBitWidth(info); + TestBitWidthAltAxis(info); } } @@ -267,6 +370,11 @@ void TestAllBitWidths() { TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TEST(TestBool) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestFloat) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestInt8) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestInt16) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestInt32) { tflite::testing::TestAllBitWidths(); } +TF_LITE_MICRO_TEST(TestInt64) { tflite::testing::TestAllBitWidths(); } TF_LITE_MICRO_TESTS_END