Skip to content

Commit

Permalink
feat(compression): add decompression library (#2996)
Browse files Browse the repository at this point in the history
Add a decompression library, defining structures for compressed
tensors and decompression logic to be used by kernels. Add a unit
test to validate decompression logic.

BUG=part of #2636
  • Loading branch information
rkuester authored Dec 10, 2024
1 parent 4a8bb6b commit d59136a
Show file tree
Hide file tree
Showing 9 changed files with 1,230 additions and 1 deletion.
10 changes: 10 additions & 0 deletions tensorflow/lite/micro/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ tflm_cc_library(
],
)

tflm_cc_library(
name = "compression",
hdrs = [
"compression.h",
],
deps = [
"//tensorflow/lite/c:common",
],
)

tflm_cc_library(
# TODO(b/187093492): Rename to micro_interpreter.
name = "micro_framework",
Expand Down
68 changes: 68 additions & 0 deletions tensorflow/lite/micro/compression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_
#define TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_

#ifdef USE_TFLM_COMPRESSION

#include "tensorflow/lite/c/common.h"

namespace tflite {

//
// Compressed tensors
//

static constexpr const char* kCompressionMetadataString =
"COMPRESSION_METADATA";

enum class CompressionScheme : uint8_t {
kBinQuant,
};

struct LookupTableData {
static constexpr size_t kMaxBitWidth = 7;
static constexpr size_t kMaxValueTableChannelStride = 128;

const void* value_table; // Pointer into FlatBuffer Values.
uint8_t value_table_channel_stride; // elements per channel
uint8_t compressed_bit_width : 3; // 1 to 7 bits
bool is_per_channel_quantized : 1; // tensor is per-channel quantized
bool use_alternate_axis : 1; // shape default channel:
// 0 = first, 1 = last
uint8_t reserved : 3;
};

union CompressionData {
LookupTableData* lut_data;
};

struct CompressionTensorData {
CompressionScheme scheme;
CompressionData data;
};

struct CompressedTensorList {
// Sparsely populated array with the same number of elements as there are
// tensors in the Subgraph. An alternative would include a tensor index in
// the struct for each and walk the list on look up. This could be slow.
const CompressionTensorData** tensors;
};

} // namespace tflite

#endif // USE_TFLM_COMPRESSION
#endif // TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_
42 changes: 42 additions & 0 deletions tensorflow/lite/micro/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,29 @@ tflm_cc_library(
],
)

tflm_cc_library(
name = "decompress",
srcs = [
"decompress.cc",
"decompress_common.cc",
],
hdrs = [
"decompress.h",
],
visibility = [
":kernel_friends",
":tflite_micro",
],
deps = [
"//tensorflow/lite:type_to_tflitetype",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/micro:compression",
"//tensorflow/lite/micro:micro_common",
"//tensorflow/lite/micro:micro_log",
"//tensorflow/lite/micro:micro_profiler",
],
)

tflm_cc_library(
name = "detection_postprocess_flexbuffers_generated_data",
srcs = [
Expand Down Expand Up @@ -613,6 +636,25 @@ tflm_cc_test(
],
)

tflm_cc_test(
name = "decompress_test",
srcs = [
"decompress_test.cc",
],
target_compatible_with = select({
"//conditions:default": ["@platforms//:incompatible"],
"//:with_compression_enabled": [],
}),
deps = [
":decompress",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:micro_arena_constants",
"//tensorflow/lite/micro:micro_log",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)

tflm_cc_test(
name = "depth_to_space_test",
srcs = [
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/lite/micro/kernels/Makefile.inc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -180,6 +180,10 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unpack_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/while_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/zeros_like_test.cc

ifeq ($(ENABLE_COMPRESSION), yes)
MICROLITE_KERNEL_SIMPLE_TEST_SRCS += $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decompress_test.cc
endif

# Generate simple kernel test targets in a common way
$(foreach TEST_TARGET,$(MICROLITE_KERNEL_SIMPLE_TEST_SRCS),\
$(eval $(call microlite_test,kernel_$(notdir $(basename $(TEST_TARGET))),$(TEST_TARGET))))
61 changes: 61 additions & 0 deletions tensorflow/lite/micro/kernels/decompress.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef USE_TFLM_COMPRESSION

#include "tensorflow/lite/micro/kernels/decompress.h"

#include <cstddef>
#include <type_traits>

#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/micro_common.h"

namespace tflite {

template <typename T>
T* DecompressionState::DecompressToBuffer(void* buffer) {
TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
TFLITE_DCHECK(compressed_bit_width_ > 0);

if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 4 &&
!comp_data_.data.lut_data->use_alternate_axis) {
DecompressToBufferWidth4_16(static_cast<int8_t*>(buffer));
} else if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 3 &&
!comp_data_.data.lut_data->use_alternate_axis) {
DecompressToBufferWidth3_32(static_cast<int8_t*>(buffer));
} else if (std::is_same<T, int8_t>::value &&
comp_data_.data.lut_data->compressed_bit_width == 2 &&
!comp_data_.data.lut_data->use_alternate_axis) {
DecompressToBufferWidth2_16(static_cast<int8_t*>(buffer));
} else {
DecompressToBufferWidthAny<T>(static_cast<T*>(buffer));
}

return static_cast<T*>(buffer);
}

template bool* DecompressionState::DecompressToBuffer<bool>(void*);
template float* DecompressionState::DecompressToBuffer<float>(void*);
template int8_t* DecompressionState::DecompressToBuffer<int8_t>(void*);
template int16_t* DecompressionState::DecompressToBuffer<int16_t>(void*);
template int32_t* DecompressionState::DecompressToBuffer<int32_t>(void*);
template int64_t* DecompressionState::DecompressToBuffer<int64_t>(void*);

} // namespace tflite

#endif // USE_TFLM_COMPRESSION
89 changes: 89 additions & 0 deletions tensorflow/lite/micro/kernels/decompress.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_
#define TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_

#include <cstdint>

#include "tensorflow/lite/micro/compression.h"
#include "tensorflow/lite/micro/micro_profiler.h"

namespace tflite {

#ifdef USE_TFLM_COMPRESSION

struct DecompressionState {
DecompressionState() = delete;

DecompressionState(const uint8_t* compressed_indices,
const size_t count_indices,
const CompressionTensorData& comp_data,
const size_t num_channels,
MicroProfilerInterface* profiler = nullptr)
: compressed_indices_(compressed_indices),
count_indices_(count_indices),
comp_data_(comp_data),
num_channels_(num_channels),
micro_profiler_(profiler) {}

DecompressionState(const DecompressionState& other)
: compressed_indices_(other.compressed_indices_),
count_indices_(other.count_indices_),
comp_data_(other.comp_data_),
num_channels_(other.num_channels_),
micro_profiler_(other.micro_profiler_) {}

template <typename T>
T* DecompressToBuffer(void* buffer);

protected:
// optimized C++ for INT8, use_alt_axis == false
void DecompressToBufferWidth4_16(int8_t* buffer);
void DecompressToBufferWidth3_32(int8_t* buffer);
void DecompressToBufferWidth2_16(int8_t* buffer);

// generic C++ for any bit width and value table type
template <typename T>
void DecompressToBufferWidthAny(T* buffer);

// Optimized C++ table index fetch
inline size_t GetNextTableIndexWidth7(const size_t current_offset);
inline size_t GetNextTableIndexWidth6(const size_t current_offset);
inline size_t GetNextTableIndexWidth5(const size_t current_offset);
inline size_t GetNextTableIndexWidth4(const size_t current_offset);
inline size_t GetNextTableIndexWidth3(const size_t current_offset);
inline size_t GetNextTableIndexWidth2(const size_t current_offset);
inline size_t GetNextTableIndexWidth1(const size_t current_offset);

protected:
const uint8_t* compressed_indices_;
const size_t count_indices_;
const CompressionTensorData& comp_data_;
const size_t num_channels_;
const size_t compressed_bit_width_ =
comp_data_.data.lut_data->compressed_bit_width;
const size_t elements_per_channel_ =
comp_data_.data.lut_data->use_alternate_axis
? 1
: count_indices_ / num_channels_;
MicroProfilerInterface* micro_profiler_;
};

#endif // USE_TFLM_COMPRESSION

} // namespace tflite

#endif // TENSORFLOW_LITE_MICRO_MICRO_KERNELS_DECOMPRESS_H_
Loading

0 comments on commit d59136a

Please sign in to comment.