From b318421f63f954c2e983b223e90f1c86ea2d9022 Mon Sep 17 00:00:00 2001
From: ddavis-2015
Date: Wed, 16 Oct 2024 17:34:09 -0700
Subject: [PATCH] add decompression unit test

---
 tensorflow/lite/micro/kernels/Makefile.inc    |   6 +-
 tensorflow/lite/micro/kernels/decompress.h    |  12 +-
 .../lite/micro/kernels/decompress_common.cc   |  18 +-
 .../lite/micro/kernels/decompress_test.cc     | 273 ++++++++++++++++++
 .../lite/micro/kernels/xtensa/decompress.cc   |  34 +--
 tensorflow/lite/micro/micro_context.cc        |   3 +-
 6 files changed, 304 insertions(+), 42 deletions(-)
 create mode 100644 tensorflow/lite/micro/kernels/decompress_test.cc

diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc
index 0bd846bc679..f4456242fef 100644
--- a/tensorflow/lite/micro/kernels/Makefile.inc
+++ b/tensorflow/lite/micro/kernels/Makefile.inc
@@ -1,4 +1,4 @@
-# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -180,6 +180,10 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unpack_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/while_test.cc \
 $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/zeros_like_test.cc
 
+ifeq ($(ENABLE_COMPRESSION), yes)
+MICROLITE_KERNEL_SIMPLE_TEST_SRCS += $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_test.cc
+endif
+
 # Generate simple kernel test targets in a common way
 $(foreach TEST_TARGET,$(MICROLITE_KERNEL_SIMPLE_TEST_SRCS),\
 $(eval $(call microlite_test,kernel_$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET))))
diff --git a/tensorflow/lite/micro/kernels/decompress.h b/tensorflow/lite/micro/kernels/decompress.h
index 2ce9501cb90..2debed9c5e9 100644
--- a/tensorflow/lite/micro/kernels/decompress.h
+++ b/tensorflow/lite/micro/kernels/decompress.h
@@ -15,7 +15,8 @@ limitations under the License.
 
 #include <cstdint>
 
-#include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/compression.h"
+#include "tensorflow/lite/micro/micro_profiler.h"
 
 namespace tflite {
 
@@ -27,19 +28,20 @@ struct DecompressionState {
   DecompressionState(const uint8_t* compressed_indices,
                      const size_t count_indices,
                      const CompressionTensorData& comp_data,
-                     const size_t num_channels, MicroContext* micro_context)
+                     const size_t num_channels,
+                     MicroProfiler* profiler = nullptr)
       : compressed_indices_(compressed_indices),
         count_indices_(count_indices),
         comp_data_(comp_data),
         num_channels_(num_channels),
-        micro_context_(micro_context) {}
+        micro_profiler_(profiler) {}
 
   DecompressionState(const DecompressionState& other)
       : compressed_indices_(other.compressed_indices_),
         count_indices_(other.count_indices_),
         comp_data_(other.comp_data_),
         num_channels_(other.num_channels_),
-        micro_context_(other.micro_context_) {}
+        micro_profiler_(other.micro_profiler_) {}
 
   template <typename T>
   T* DecompressToBuffer(void* buffer);
@@ -74,7 +76,7 @@ struct DecompressionState {
       comp_data_.data.lut_data->use_alternate_axis
           ? 1
           : count_indices_ / num_channels_;
 
-  MicroContext* micro_context_;
+  MicroProfiler* micro_profiler_;
 };
 
 #endif  // USE_TFLM_COMPRESSION
diff --git a/tensorflow/lite/micro/kernels/decompress_common.cc b/tensorflow/lite/micro/kernels/decompress_common.cc
index ce8deda6e84..45d0c683a4f 100644
--- a/tensorflow/lite/micro/kernels/decompress_common.cc
+++ b/tensorflow/lite/micro/kernels/decompress_common.cc
@@ -29,9 +29,7 @@ limitations under the License.
 namespace tflite {
 
 void DecompressionState::DecompressToBufferWidth4_16(int8_t* buffer) {
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  ScopedMicroProfiler scoped_profiler(__func__, profiler);
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
 
   const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
   const uint8_t* value_table =
@@ -108,9 +106,7 @@ void DecompressionState::DecompressToBufferWidth4_16(int8_t* buffer) {
 }
 
 void DecompressionState::DecompressToBufferWidth2_16(int8_t* buffer) {
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  ScopedMicroProfiler scoped_profiler(__func__, profiler);
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
 
   const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
   const uint8_t* value_table =
@@ -187,9 +183,7 @@ void DecompressionState::DecompressToBufferWidth2_16(int8_t* buffer) {
 }
 
 void DecompressionState::DecompressToBufferWidth3_32(int8_t* buffer) {
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  ScopedMicroProfiler scoped_profiler(__func__, profiler);
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
 
   const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
   const uint8_t* value_table =
@@ -325,16 +319,14 @@ void DecompressionState::DecompressToBufferWidth3_32(int8_t* buffer) {
 
 template <typename T>
 void DecompressionState::DecompressToBufferWidthAny(T* buffer) {
   const char* func_name_p = nullptr;
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  if (profiler != nullptr) {
+  if (micro_profiler_ != nullptr) {
     static char func_name[35];
     MicroSnprintf(func_name, sizeof(func_name), "%s_%u_%s", __func__,
                   compressed_bit_width_,
                   TfLiteTypeGetName(typeToTfLiteType<T>()));
     func_name_p = func_name;
   }
-  ScopedMicroProfiler scoped_profiler(func_name_p, profiler);
+  ScopedMicroProfiler scoped_profiler(func_name_p, micro_profiler_);
 
   if (comp_data_.data.lut_data->use_alternate_axis) {
     const size_t stride =
         comp_data_.data.lut_data->value_table_channel_stride;
diff --git a/tensorflow/lite/micro/kernels/decompress_test.cc b/tensorflow/lite/micro/kernels/decompress_test.cc
new file mode 100644
index 00000000000..e4b39366dd6
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/decompress_test.cc
@@ -0,0 +1,273 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef USE_TFLM_COMPRESSION
+
+#include "tensorflow/lite/micro/kernels/decompress.h"
+
+#include <algorithm>
+#include <type_traits>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/micro/kernels/kernel_runner.h"
+#include "tensorflow/lite/micro/micro_arena_constants.h"
+#include "tensorflow/lite/micro/micro_log.h"
+#include "tensorflow/lite/micro/test_helpers.h"
+#include "tensorflow/lite/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+template <typename T>
+struct TestingInfo {
+  T* output;
+  T* goldens;
+  uint8_t* compressed;
+  T* value_table;
+
+  size_t bit_width;
+  size_t total_elements;
+  size_t total_value_table_elements;
+  size_t channel_count;
+  bool use_alt_axis;
+};
+
+template <typename T>
+struct TestingData7_2_256 {
+  static constexpr size_t kBitWidth = 7;
+  static constexpr size_t kChannels = 2;
+  static constexpr size_t kElementsPerChannel = 256;
+
+  static constexpr size_t kTotalElements = kElementsPerChannel * kChannels;
+  static constexpr size_t kCompressedBytes =
+      ((kTotalElements * kBitWidth) + 7) / 8;
+  static constexpr size_t kValueTableSize = (1 << kBitWidth) * kChannels;
+
+  alignas(MicroArenaBufferAlignment()) T output[kTotalElements];
+  alignas(MicroArenaBufferAlignment()) uint8_t compressed[kCompressedBytes];
+  alignas(MicroArenaBufferAlignment()) T value_table[kValueTableSize];
+  T goldens[kTotalElements];
+};
+
+TestingData7_2_256<bool> TestingData7_2_256_Bool;
+#ifdef notyet
+TestingData7_2_256<float> TestingData7_2_256_Float32;
+TestingData7_2_256<int8_t> TestingData7_2_256_Int8;
+TestingData7_2_256<int16_t> TestingData7_2_256_Int16;
+TestingData7_2_256<int32_t> TestingData7_2_256_Int32;
+TestingData7_2_256<int64_t> TestingData7_2_256_Int64;
+#endif  // notyet
+
+template <typename T>
+void FillValueTable(const size_t total_elements, T* value_table) {
+  T fill_value = -1;
+  for (size_t i = 0; i < total_elements; i++) {
+    value_table[i] = fill_value;
+    fill_value -= 1;
+  }
+}
+
+#ifdef notyet
+template <>
+void FillValueTable(const size_t total_elements, float* value_table) {
+  float fill_value = -1.1f;
+  for (size_t i = 0; i < total_elements; i++) {
+    value_table[i] = fill_value;
+    fill_value -= 1.0f;
+  }
+}
+#endif  // notyet
+
+template <>
+void FillValueTable(const size_t total_elements, bool* value_table) {
+  bool fill_value = true;
+  for (size_t i = 0; i < total_elements; i++) {
+    value_table[i] = fill_value;
+    fill_value = !fill_value;
+  }
+}
+
+template <typename T>
+void FillGoldens(const size_t total_elements, T* goldens,
+                 const size_t value_table_elements, const T* value_table,
+                 const size_t channels, const bool use_alt_axis) {
+  if (use_alt_axis) {
+    const size_t value_table_stride = value_table_elements / channels;
+    const size_t element_groups = total_elements / channels;
+    size_t value_table_index = 0;  // index within current channel
+
+    for (size_t group = 0; group < element_groups; group++) {
+      for (size_t channel = 0; channel < channels; channel++) {
+        goldens[(group * channels) + channel] =
+            value_table[(channel * value_table_stride) + value_table_index];
+      }
+      if (++value_table_index == value_table_stride) {
+        value_table_index = 0;
+      }
+    }
+  } else {
+    const size_t value_table_stride = value_table_elements / channels;
+    const size_t elements_per_channel = total_elements / channels;
+    size_t value_table_index = 0;  // index within current channel
+
+    for (size_t channel = 0; channel < channels; channel++) {
+      for (size_t i = 0; i < elements_per_channel; i++) {
+        goldens[(channel * elements_per_channel) + i] =
+            value_table[(channel * value_table_stride) + value_table_index++];
+        if (value_table_index == value_table_stride) {
+          value_table_index = 0;
+        }
+      }
+      value_table_index = 0;
+    }
+  }
+}
+
+// returns index within channel
+template <typename T>
+size_t FindValueTableIndex(const T value, const T* value_table,
+                           const size_t value_table_stride) {
+  for (size_t i = 0; i < value_table_stride; i++) {
+    if (value == value_table[i]) {
+      return i;
+    }
+  }
+  return 0;
+}
+
+template <typename T>
+void FillCompressed(uint8_t* compressed, const size_t total_golden_elements,
+                    const T* goldens, const size_t value_table_stride,
+                    const T* value_table, const size_t channels,
+                    const bool use_alt_axis, const size_t bit_width) {
+  uint16_t bits = 0;
+  size_t bits_accumulated = 0;
+
+  if (use_alt_axis) {
+    size_t golden_element_groups = total_golden_elements / channels;
+
+    for (size_t group = 0; group < golden_element_groups; group++) {
+      for (size_t channel = 0; channel < channels; channel++) {
+        size_t value_table_index = FindValueTableIndex(
+            goldens[(group * channels) + channel],
+            &value_table[channel * value_table_stride], value_table_stride);
+        bits |= value_table_index << (16 - bits_accumulated - bit_width);
+        bits_accumulated += bit_width;
+        if (bits_accumulated >= 8) {
+          *compressed++ = static_cast<uint8_t>(bits >> 8);
+          bits <<= 8;
+          bits_accumulated -= 8;
+        }
+      }
+    }
+  } else {
+    size_t golden_elements_per_channel = total_golden_elements / channels;
+
+    for (size_t channel = 0; channel < channels; channel++) {
+      for (size_t i = 0; i < golden_elements_per_channel; i++) {
+        size_t value_table_index = FindValueTableIndex(
+            goldens[(channel * golden_elements_per_channel) + i], value_table,
+            value_table_stride);
+        bits |= value_table_index << (16 - bits_accumulated - bit_width);
+        bits_accumulated += bit_width;
+        if (bits_accumulated >= 8) {
+          *compressed++ = static_cast<uint8_t>(bits >> 8);
+          bits <<= 8;
+          bits_accumulated -= 8;
+        }
+      }
+      value_table += value_table_stride;
+    }
+  }
+}
+
+template <typename T>
+TfLiteStatus TestDecompression(TestingInfo<T>* info) {
+  CompressionTensorData ctd = {};
+  LookupTableData lut_data = {};
+  ctd.scheme = CompressionScheme::kBinQuant;
+  ctd.data.lut_data = &lut_data;
+  lut_data.compressed_bit_width = info->bit_width;
+  lut_data.is_per_channel_quantized = info->channel_count > 1;
+  lut_data.use_alternate_axis = info->use_alt_axis;
+  lut_data.value_table = info->value_table;
+  lut_data.value_table_channel_stride =
+      info->total_value_table_elements / info->channel_count;
+
+  DecompressionState ds(info->compressed, info->total_elements, ctd,
+                        info->channel_count);
+
+  std::fill_n(info->output, info->total_elements, ~0ULL);
+  ds.DecompressToBuffer<T>(info->output);
+
+  for (size_t i = 0; i < info->total_elements; i++) {
+    TF_LITE_MICRO_EXPECT_EQ(info->goldens[i], info->output[i]);
+    TF_LITE_MICRO_CHECK_FAIL();
+  }
+
+  return kTfLiteOk;
+}
+
+template <typename T>
+void TestBitWidth(size_t bit_width) {
+  MicroPrintf(" Testing bit width %d", static_cast<int>(bit_width));
+
+  TestingInfo<T> info = {};
+
+  if (std::is_same<T, bool>::value) {
+    info.output = TestingData7_2_256_Bool.output;
+    info.goldens = TestingData7_2_256_Bool.goldens;
+    info.compressed = TestingData7_2_256_Bool.compressed;
+    info.value_table = TestingData7_2_256_Bool.value_table;
+  }
+
+  info.bit_width = bit_width;
+  info.channel_count = 1;
+  info.total_elements = 16;
+  info.total_value_table_elements = 1 << bit_width;
+  info.use_alt_axis = false;
+
+  FillValueTable(info.total_value_table_elements, info.value_table);
+  FillGoldens(info.total_elements, info.goldens,
+              info.total_value_table_elements, info.value_table,
+              info.channel_count, info.use_alt_axis);
+  FillCompressed(info.compressed, info.total_elements, info.goldens,
+                 info.total_value_table_elements / info.channel_count,
+                 info.value_table, info.channel_count, info.use_alt_axis,
+                 info.bit_width);
+
+  TestDecompression(&info);
+}
+
+template <typename T>
+void TestAllBitWidths() {
+  for (size_t bw = 1; bw <= 7; bw++) {
+    TestBitWidth<T>(bw);
+  }
+}
+
+}  // namespace
+}  // namespace testing
+}  // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestBool) { tflite::testing::TestAllBitWidths<bool>(); }
+
+TF_LITE_MICRO_TESTS_END
+
+#endif  // USE_TFLM_COMPRESSION
diff --git a/tensorflow/lite/micro/kernels/xtensa/decompress.cc b/tensorflow/lite/micro/kernels/xtensa/decompress.cc
index 21c9cf8eb06..e1a15edb399 100644
--- a/tensorflow/lite/micro/kernels/xtensa/decompress.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/decompress.cc
@@ -54,9 +54,7 @@ struct DecompressionStateXtensa : DecompressionState {
 
 // TODO(ddavis-2015): unaligned/stride code has error, method not currently
 // used.
 void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  ScopedMicroProfiler scoped_profiler(__func__, profiler);
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
 
   ae_int8x8 d_shuffle_t = AE_MOVINT8X8_FROMINT64(0xFB73EA62D951C840LL);
   ae_int8x8 d_shuffle_value_t = AE_MOVINT8X8_FROMINT64(0x08192A3B4C5D6E7FLL);
@@ -112,9 +110,7 @@ void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa(int8_t* buffer) {
 
 void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa_Old(
     int8_t* buffer) {
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  ScopedMicroProfiler scoped_profiler(__func__, profiler);
+  ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
 
   char shuffle_pattern_1[8] = {0x08, 0x19, 0x2A, 0x3B, 0x4C, 0x5D, 0x6E, 0x7F};
   ae_int8x8 d_shuffle_t = *(ae_int8x8*)&shuffle_pattern_1[0];
@@ -155,15 +151,13 @@ void DecompressionStateXtensa::DecompressToBufferWidth4_Xtensa_Old(
 
 void DecompressionStateXtensa::DecompressToBufferWidthAnyInt8_Xtensa(
     int8_t* buffer) {
   const char* func_name_p = nullptr;
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  if (profiler != nullptr) {
+  if (micro_profiler_ != nullptr) {
     static char func_name[42];
     MicroSnprintf(func_name, sizeof(func_name), "%s_%u", __func__,
                   compressed_bit_width_);
     func_name_p = func_name;
   }
-  ScopedMicroProfiler scoped_profiler(func_name_p, profiler);
+  ScopedMicroProfiler scoped_profiler(func_name_p, micro_profiler_);
 
   const int stride = comp_data_.data.lut_data->value_table_channel_stride;
   const uint8_t* __restrict value_table =
@@ -215,15 +209,13 @@ void DecompressionStateXtensa::DecompressToBufferWidthAnyInt8_Xtensa(
 
 void DecompressionStateXtensa::DecompressToBufferWidthAnyInt16_Xtensa(
     int16_t* buffer) {
   const char* func_name_p = nullptr;
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  if (profiler != nullptr) {
+  if (micro_profiler_ != nullptr) {
     static char func_name[43];
     MicroSnprintf(func_name, sizeof(func_name), "%s_%u", __func__,
                   compressed_bit_width_);
     func_name_p = func_name;
   }
-  ScopedMicroProfiler scoped_profiler(func_name_p, profiler);
+  ScopedMicroProfiler scoped_profiler(func_name_p, micro_profiler_);
 
   const int stride = comp_data_.data.lut_data->value_table_channel_stride;
   const uint16_t* __restrict value_table =
@@ -275,15 +267,13 @@ void DecompressionStateXtensa::DecompressToBufferWidthAnyInt16_Xtensa(
 
 void DecompressionStateXtensa::DecompressToBufferWidthAnyInt32_Xtensa(
     int32_t* buffer) {
   const char* func_name_p = nullptr;
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  if (profiler != nullptr) {
+  if (micro_profiler_ != nullptr) {
     static char func_name[43];
     MicroSnprintf(func_name, sizeof(func_name), "%s_%u", __func__,
                   compressed_bit_width_);
     func_name_p = func_name;
   }
-  ScopedMicroProfiler scoped_profiler(func_name_p, profiler);
+  ScopedMicroProfiler scoped_profiler(func_name_p, micro_profiler_);
 
   const int stride = comp_data_.data.lut_data->value_table_channel_stride;
   const uint32_t* __restrict value_table =
@@ -335,15 +325,13 @@ void DecompressionStateXtensa::DecompressToBufferWidthAnyInt32_Xtensa(
 
 void DecompressionStateXtensa::DecompressToBufferWidthAnyInt64_Xtensa(
     int64_t* buffer) {
   const char* func_name_p = nullptr;
-  MicroProfiler* profiler =
-      static_cast<MicroProfiler*>(micro_context_->external_context());
-  if (profiler != nullptr) {
+  if (micro_profiler_ != nullptr) {
     static char func_name[43];
     MicroSnprintf(func_name, sizeof(func_name), "%s_%u", __func__,
                   compressed_bit_width_);
     func_name_p = func_name;
   }
-  ScopedMicroProfiler scoped_profiler(func_name_p, profiler);
+  ScopedMicroProfiler scoped_profiler(func_name_p, micro_profiler_);
 
   const int stride = comp_data_.data.lut_data->value_table_channel_stride;
   const uint64_t* __restrict value_table =
@@ -427,9 +415,11 @@ int8_t* DecompressionState::DecompressToBuffer<int8_t>(void* buffer) {
     }
   } else if (comp_data_.data.lut_data->compressed_bit_width == 3 &&
              !comp_data_.data.lut_data->use_alternate_axis) {
+    // TODO(ddavis-2015): placeholder
     dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
   } else if (comp_data_.data.lut_data->compressed_bit_width == 2 &&
              !comp_data_.data.lut_data->use_alternate_axis) {
+    // TODO(ddavis-2015): placeholder
    dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
   } else {
     dsx.DecompressToBufferWidthAnyInt8_Xtensa(static_cast<int8_t*>(buffer));
diff --git a/tensorflow/lite/micro/micro_context.cc b/tensorflow/lite/micro/micro_context.cc
index eb557e83166..ce8aed8f683 100644
--- a/tensorflow/lite/micro/micro_context.cc
+++ b/tensorflow/lite/micro/micro_context.cc
@@ -96,7 +96,8 @@ void* MicroContext::DecompressTensorToBuffer(
   }
 
   DecompressionState ds(static_cast<uint8_t*>(tensor.data.data), count,
-                        compression_data, num_channels, this);
+                        compression_data, num_channels,
+                        static_cast<MicroProfiler*>(external_context()));
 
   switch (tensor.type) {
     case kTfLiteBool: {
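---

Usage sketch (illustrative, not part of the diff): after this change
DecompressionState takes the MicroProfiler directly (defaulting to nullptr)
instead of reaching through MicroContext::external_context(), which is what
lets decompress_test.cc drive the decompressor without building a full
interpreter. A minimal standalone example under those assumptions follows;
the index bytes, value table, and DecompressExample() function are made up
for illustration:

  // Requires USE_TFLM_COMPRESSION to be defined at compile time.
  #include <cstddef>
  #include <cstdint>

  #include "tensorflow/lite/micro/kernels/decompress.h"

  void DecompressExample(tflite::MicroProfiler* profiler) {
    // Four 2-bit indices per byte, most-significant bits first:
    // 0x1B == 0b00'01'10'11 -> indices 0, 1, 2, 3.
    static const uint8_t compressed[4] = {0x1B, 0x1B, 0x1B, 0x1B};
    static int8_t value_table[4] = {10, 20, 30, 40};  // 1 << 2 LUT entries
    constexpr size_t kCount = 16;                     // total indices
    int8_t output[kCount];

    tflite::LookupTableData lut_data = {};
    lut_data.compressed_bit_width = 2;
    lut_data.is_per_channel_quantized = false;
    lut_data.use_alternate_axis = false;
    lut_data.value_table = value_table;
    lut_data.value_table_channel_stride = 4;  // LUT entries per channel

    tflite::CompressionTensorData ctd = {};
    ctd.scheme = tflite::CompressionScheme::kBinQuant;
    ctd.data.lut_data = &lut_data;

    // The profiler argument may be omitted entirely; it defaults to nullptr.
    tflite::DecompressionState ds(compressed, kCount, ctd,
                                  /*num_channels=*/1, profiler);
    ds.DecompressToBuffer<int8_t>(output);
    // output now holds {10, 20, 30, 40, 10, 20, 30, 40, ...}
  }

Passing the profiler explicitly removes the decompressor's dependency on
MicroContext, so per-call timing still works inside the interpreter (see the
micro_context.cc hunk above) while unit tests can simply leave it out.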