add decompression unit test
ddavis-2015 committed Oct 17, 2024
1 parent 9bb2b63 commit b318421
Showing 6 changed files with 304 additions and 42 deletions.
6 changes: 5 additions & 1 deletion tensorflow/lite/micro/kernels/Makefile.inc
@@ -1,4 +1,4 @@
-# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -180,6 +180,10 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unpack_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/while_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/zeros_like_test.cc

+ifeq ($(ENABLE_COMPRESSION), yes)
+MICROLITE_KERNEL_SIMPLE_TEST_SRCS += $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_test.cc
+endif

# Generate simple kernel test targets in a common way
$(foreach TEST_TARGET,$(MICROLITE_KERNEL_SIMPLE_TEST_SRCS),\
$(eval $(call microlite_test,kernel_$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET))))
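
With this gate in place, the decompression test is compiled only when compression support is enabled. With the standard TFLM makefile, building and running it would look roughly like make -f tensorflow/lite/micro/tools/make/Makefile ENABLE_COMPRESSION=yes test_kernel_decompress_test (the exact test target name is an assumption based on the microlite_test naming convention above).
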
12 changes: 7 additions & 5 deletions tensorflow/lite/micro/kernels/decompress.h
@@ -15,7 +15,8 @@ limitations under the License.

#include <cstdint>

-#include "tensorflow/lite/micro/micro_context.h"
+#include "tensorflow/lite/micro/compression.h"
+#include "tensorflow/lite/micro/micro_profiler.h"

namespace tflite {

@@ -27,19 +28,20 @@ struct DecompressionState {
DecompressionState(const uint8_t* compressed_indices,
const size_t count_indices,
const CompressionTensorData& comp_data,
-const size_t num_channels, MicroContext* micro_context)
+const size_t num_channels,
+MicroProfiler* profiler = nullptr)
: compressed_indices_(compressed_indices),
count_indices_(count_indices),
comp_data_(comp_data),
num_channels_(num_channels),
-micro_context_(micro_context) {}
+micro_profiler_(profiler) {}

DecompressionState(const DecompressionState& other)
: compressed_indices_(other.compressed_indices_),
count_indices_(other.count_indices_),
comp_data_(other.comp_data_),
num_channels_(other.num_channels_),
-micro_context_(other.micro_context_) {}
+micro_profiler_(other.micro_profiler_) {}

template <typename T>
T* DecompressToBuffer(void* buffer);
@@ -74,7 +76,7 @@ struct DecompressionState {
comp_data_.data.lut_data->use_alternate_axis
? 1
: count_indices_ / num_channels_;
-MicroContext* micro_context_;
+MicroProfiler* micro_profiler_;
};

#endif // USE_TFLM_COMPRESSION
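
With this change the caller passes the optional profiler directly instead of routing it through a MicroContext. A minimal sketch of the updated usage, assuming comp_data, compressed_indices, element_count, num_channels, and output_buffer have already been set up by the caller (as the unit test below does):

// Sketch only: every variable here other than the DecompressionState API is an
// assumption standing in for caller-provided data.
tflite::MicroProfiler profiler;  // optional; omit to run without profiling
tflite::DecompressionState ds(compressed_indices, element_count, comp_data,
                              num_channels, &profiler);
int8_t* decompressed = ds.DecompressToBuffer<int8_t>(output_buffer);

Leaving the profiler argument at its default (nullptr) simply disables the profiling scopes; decompression no longer needs a MicroContext at all.
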
18 changes: 5 additions & 13 deletions tensorflow/lite/micro/kernels/decompress_common.cc
@@ -29,9 +29,7 @@ limitations under the License.
namespace tflite {

void DecompressionState::DecompressToBufferWidth4_16(int8_t* buffer) {
-MicroProfiler* profiler =
-static_cast<MicroProfiler*>(micro_context_->external_context());
-ScopedMicroProfiler scoped_profiler(__func__, profiler);
+ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);

const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
const uint8_t* value_table =
@@ -108,9 +106,7 @@ void DecompressionState::DecompressToBufferWidth4_16(int8_t* buffer) {
}

void DecompressionState::DecompressToBufferWidth2_16(int8_t* buffer) {
-MicroProfiler* profiler =
-static_cast<MicroProfiler*>(micro_context_->external_context());
-ScopedMicroProfiler scoped_profiler(__func__, profiler);
+ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);

const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
const uint8_t* value_table =
@@ -187,9 +183,7 @@ }
}

void DecompressionState::DecompressToBufferWidth3_32(int8_t* buffer) {
-MicroProfiler* profiler =
-static_cast<MicroProfiler*>(micro_context_->external_context());
-ScopedMicroProfiler scoped_profiler(__func__, profiler);
+ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);

const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
const uint8_t* value_table =
@@ -325,16 +319,14 @@ void DecompressionState::DecompressToBufferWidth3_32(int8_t* buffer) {
template <typename T>
void DecompressionState::DecompressToBufferWidthAny(T* buffer) {
const char* func_name_p = nullptr;
-MicroProfiler* profiler =
-static_cast<MicroProfiler*>(micro_context_->external_context());
-if (profiler != nullptr) {
+if (micro_profiler_ != nullptr) {
static char func_name[35];
MicroSnprintf(func_name, sizeof(func_name), "%s_%u_%s", __func__,
compressed_bit_width_,
TfLiteTypeGetName(typeToTfLiteType<T>()));
func_name_p = func_name;
}
-ScopedMicroProfiler scoped_profiler(func_name_p, profiler);
+ScopedMicroProfiler scoped_profiler(func_name_p, micro_profiler_);

if (comp_data_.data.lut_data->use_alternate_axis) {
const size_t stride = comp_data_.data.lut_data->value_table_channel_stride;
273 changes: 273 additions & 0 deletions tensorflow/lite/micro/kernels/decompress_test.cc
@@ -0,0 +1,273 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef USE_TFLM_COMPRESSION

#include "tensorflow/lite/micro/kernels/decompress.h"

#include <algorithm>
#include <type_traits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

namespace tflite {
namespace testing {
namespace {

template <typename T>
struct TestingInfo {
T* output;
T* goldens;
uint8_t* compressed;
T* value_table;

size_t bit_width;
size_t total_elements;
size_t total_value_table_elements;
size_t channel_count;
bool use_alt_axis;
};

template <typename T>
struct TestingData7_2_256 {
static constexpr size_t kBitWidth = 7;
static constexpr size_t kChannels = 2;
static constexpr size_t kElementsPerChannel = 256;

static constexpr size_t kTotalElements = kElementsPerChannel * kChannels;
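// Packed size rounded up to whole bytes, e.g. 512 elements at 7 bits each is
// 3584 bits = 448 bytes.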
static constexpr size_t kCompressedBytes =
((kTotalElements * kBitWidth) + 7) / 8;
static constexpr size_t kValueTableSize = (1 << kBitWidth) * kChannels;

alignas(MicroArenaBufferAlignment()) T output[kTotalElements];
alignas(MicroArenaBufferAlignment()) uint8_t compressed[kCompressedBytes];
alignas(MicroArenaBufferAlignment()) T value_table[kValueTableSize];
T goldens[kTotalElements];
};

TestingData7_2_256<bool> TestingData7_2_256_Bool;
#ifdef notyet
TestingData7_2_256<float> TestingData7_2_256_Float32;
TestingData7_2_256<int8_t> TestingData7_2_256_Int8;
TestingData7_2_256<int16_t> TestingData7_2_256_Int16;
TestingData7_2_256<int32_t> TestingData7_2_256_Int32;
TestingData7_2_256<int64_t> TestingData7_2_256_Int64;
#endif // notyet

template <typename T>
void FillValueTable(const size_t total_elements, T* value_table) {
T fill_value = -1;
for (size_t i = 0; i < total_elements; i++) {
value_table[i] = fill_value;
fill_value -= 1;
}
}

#ifdef notyet
template <>
void FillValueTable(const size_t total_elements, float* value_table) {
float fill_value = -1.1f;
for (size_t i = 0; i < total_elements; i++) {
value_table[i] = fill_value;
fill_value -= 1.0f;
}
}
#endif // notyet

template <>
void FillValueTable(const size_t total_elements, bool* value_table) {
bool fill_value = true;
for (size_t i = 0; i < total_elements; i++) {
value_table[i] = fill_value;
fill_value = !fill_value;
}
}

template <typename T>
void FillGoldens(const size_t total_elements, T* goldens,
const size_t value_table_elements, const T* value_table,
const size_t channels, const bool use_alt_axis) {
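// With use_alt_axis the channel index varies fastest (each group of
// consecutive elements holds one value per channel); otherwise each channel's
// elements are laid out contiguously.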
if (use_alt_axis) {
const size_t value_table_stride = value_table_elements / channels;
const size_t element_groups = total_elements / channels;
size_t value_table_index = 0; // index within current channel

for (size_t group = 0; group < element_groups; group++) {
for (size_t channel = 0; channel < channels; channel++) {
goldens[(group * channels) + channel] =
value_table[(channel * value_table_stride) + value_table_index];
}
if (++value_table_index == value_table_stride) {
value_table_index = 0;
}
}
} else {
const size_t value_table_stride = value_table_elements / channels;
const size_t elements_per_channel = total_elements / channels;
size_t value_table_index = 0; // index within current channel

for (size_t channel = 0; channel < channels; channel++) {
for (size_t i = 0; i < elements_per_channel; i++) {
goldens[(channel * elements_per_channel) + i] =
value_table[(channel * value_table_stride) + value_table_index++];
if (value_table_index == value_table_stride) {
value_table_index = 0;
}
}
value_table_index = 0;
}
}
}

// returns index within channel
template <typename T>
size_t FindValueTableIndex(const T value, const T* value_table,
const size_t value_table_stride) {
for (size_t i = 0; i < value_table_stride; i++) {
if (value == value_table[i]) {
return i;
}
}
return 0;
}

template <typename T>
void FillCompressed(uint8_t* compressed, const size_t total_golden_elements,
const T* goldens, const size_t value_table_stride,
const T* value_table, const size_t channels,
const bool use_alt_axis, const size_t bit_width) {
uint16_t bits = 0;
size_t bits_accumulated = 0;
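// Indices are packed MSB-first: each index is shifted into the high bits of a
// 16-bit accumulator, and the accumulator's top byte is flushed to the output
// whenever more than 8 bits have been collected.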

if (use_alt_axis) {
size_t golden_element_groups = total_golden_elements / channels;

for (size_t group = 0; group < golden_element_groups; group++) {
for (size_t channel = 0; channel < channels; channel++) {
size_t value_table_index = FindValueTableIndex(
goldens[(group * channels) + channel],
&value_table[channel * value_table_stride], value_table_stride);
bits |= value_table_index << (16 - bits_accumulated - bit_width);
bits_accumulated += bit_width;
if (bits_accumulated > 8) {
*compressed++ = static_cast<uint8_t>(bits >> 8);
bits <<= 8;
bits_accumulated -= 8;
}
}
}
} else {
size_t golden_elements_per_channel = total_golden_elements / channels;

for (size_t channel = 0; channel < channels; channel++) {
for (size_t i = 0; i < golden_elements_per_channel; i++) {
size_t value_table_index = FindValueTableIndex(
goldens[(channel * golden_elements_per_channel) + i], value_table,
value_table_stride);
bits |= value_table_index << (16 - bits_accumulated - bit_width);
bits_accumulated += bit_width;
if (bits_accumulated > 8) {
*compressed++ = static_cast<uint8_t>(bits >> 8);
bits <<= 8;
bits_accumulated -= 8;
}
}
value_table += value_table_stride;
}
}

// Flush any bits still in the accumulator so the final (possibly partial)
// byte of the compressed stream is written out.
if (bits_accumulated > 0) {
*compressed = static_cast<uint8_t>(bits >> 8);
}
}

template <typename T>
TfLiteStatus TestDecompression(TestingInfo<T>* info) {
CompressionTensorData ctd = {};
LookupTableData lut_data = {};
ctd.scheme = CompressionScheme::kBinQuant;
ctd.data.lut_data = &lut_data;
lut_data.compressed_bit_width = info->bit_width;
lut_data.is_per_channel_quantized = info->channel_count > 1;
lut_data.use_alternate_axis = info->use_alt_axis;
lut_data.value_table = info->value_table;
lut_data.value_table_channel_stride =
info->total_value_table_elements / info->channel_count;

DecompressionState ds(info->compressed, info->total_elements, ctd,
info->channel_count);
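// The profiler argument is left at its default (nullptr), so decompression
// runs without profiling scopes.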

std::fill_n(info->output, info->total_elements, ~0ULL);
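// Decompress into the poisoned buffer; any element left unwritten is unlikely
// to match its golden value.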
ds.DecompressToBuffer<T>(info->output);

for (size_t i = 0; i < info->total_elements; i++) {
TF_LITE_MICRO_EXPECT_EQ(info->goldens[i], info->output[i]);
TF_LITE_MICRO_CHECK_FAIL();
}

return kTfLiteOk;
}

template <typename T>
void TestBitWidth(size_t bit_width) {
MicroPrintf(" Testing bit width %d", static_cast<int>(bit_width));

TestingInfo<T> info = {};

if (std::is_same<T, bool>::value) {
info.output = TestingData7_2_256_Bool.output;
info.goldens = TestingData7_2_256_Bool.goldens;
info.compressed = TestingData7_2_256_Bool.compressed;
info.value_table = TestingData7_2_256_Bool.value_table;
}

info.bit_width = bit_width;
info.channel_count = 1;
info.total_elements = 16;
info.total_value_table_elements = 1 << bit_width;
info.use_alt_axis = false;

FillValueTable(info.total_value_table_elements, info.value_table);
FillGoldens(info.total_elements, info.goldens,
info.total_value_table_elements, info.value_table,
info.channel_count, info.use_alt_axis);
FillCompressed(info.compressed, info.total_elements, info.goldens,
info.total_value_table_elements / info.channel_count,
info.value_table, info.channel_count, info.use_alt_axis,
info.bit_width);

TestDecompression(&info);
}

template <typename T>
void TestAllBitWidths() {
for (size_t bw = 1; bw <= 7; bw++) {
TestBitWidth<T>(bw);
}
}

} // namespace
} // namespace testing
} // namespace tflite

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(TestBool) { tflite::testing::TestAllBitWidths<bool>(); }

TF_LITE_MICRO_TESTS_END

#endif // USE_TFLM_COMPRESSION