From f7bc071d51479a60fc47aec70611172699355b6f Mon Sep 17 00:00:00 2001 From: The JAX SC Authors Date: Wed, 15 Jan 2025 18:13:24 -0800 Subject: [PATCH] Input preprocessing library to support vocab-dimension mini-batching. Currently only PMAP is supported for simplicity. JAX support will be added later. PiperOrigin-RevId: 716024250 --- jax_tpu_embedding/sparsecore/lib/core/BUILD | 52 +- .../lib/core/input_preprocessing.cc | 44 +- .../lib/core/input_preprocessing_py_util.cc | 88 ++ .../lib/core/input_preprocessing_py_util.h | 40 + .../lib/core/input_preprocessing_util.h | 10 +- .../input_preprocessing_with_mini_batching.cc | 1165 +++++++++++++++++ .../input_preprocessing_with_mini_batching.h | 47 + .../sparsecore/lib/core/primitives/BUILD | 1 + ...rse_dense_matmul_csr_with_mini_batching.py | 3 +- jax_tpu_embedding/sparsecore/lib/nn/BUILD | 38 + .../sparsecore/lib/nn/embedding.py | 119 +- .../sparsecore/lib/nn/embedding_utils.py | 138 ++ .../lib/nn/embedding_with_mini_batching.py | 485 +++++++ .../sparsecore/lib/nn/tests/test_utils.py | 30 + 14 files changed, 2099 insertions(+), 161 deletions(-) create mode 100644 jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc create mode 100644 jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h create mode 100644 jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.cc create mode 100644 jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.h create mode 100644 jax_tpu_embedding/sparsecore/lib/nn/embedding_utils.py create mode 100644 jax_tpu_embedding/sparsecore/lib/nn/embedding_with_mini_batching.py diff --git a/jax_tpu_embedding/sparsecore/lib/core/BUILD b/jax_tpu_embedding/sparsecore/lib/core/BUILD index 16285a6..1d8a1df 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/BUILD +++ b/jax_tpu_embedding/sparsecore/lib/core/BUILD @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-load("//third_party/bazel/python:pybind11.bzl", "pybind_extension") +load("//third_party/bazel/python:pybind11.bzl", "pybind_extension", "pybind_library") load("//third_party/bazel/python:pypi.bzl", "pypi_requirement") load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library") @@ -72,17 +72,65 @@ cc_test( ], ) +pybind_library( + name = "input_preprocessing_py_util", + srcs = [ + "input_preprocessing_py_util.cc", + ], + hdrs = [ + "input_preprocessing_py_util.h", + ], + deps = [ + ":input_preprocessing_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@tsl//tsl/profiler/lib:traceme", + ], +) + pybind_extension( name = "input_preprocessing_cc", - srcs = ["input_preprocessing.cc"], + srcs = [ + "input_preprocessing.cc", + ], deps = [ + ":input_preprocessing_py_util", ":input_preprocessing_threads", ":input_preprocessing_util", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@highway//:hwy", + "@highway//hwy/contrib/sort:vqsort", + "@tsl//tsl/profiler/lib:connected_traceme", + "@tsl//tsl/profiler/lib:traceme", + ], +) + +pybind_extension( + name = "input_preprocessing_with_mini_batching_cc", + srcs = [ + "input_preprocessing_with_mini_batching.cc", + "input_preprocessing_with_mini_batching.h", + ], + deps = [ + ":input_preprocessing_py_util", + ":input_preprocessing_threads", + ":input_preprocessing_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@highway//:hwy", + "@highway//hwy/contrib/sort:vqsort", "@tsl//tsl/profiler/lib:connected_traceme", "@tsl//tsl/profiler/lib:traceme", ], diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc index 0b5d9ef..ac7cf2e 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
#include -#include #include #include #include @@ -24,6 +23,7 @@ #include "absl/strings/string_view.h" // from @com_google_absl #include "absl/synchronization/blocking_counter.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl +#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h" #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h" #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" #include "pybind11/cast.h" // from @pybind11 @@ -148,48 +148,6 @@ int ExtractCooTensors(const py::array& features, global_device_count, coo_tensors); } -absl::flat_hash_map> -GetStackedTableMetadata(py::list feature_specs, py::list features) { - tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; }); - absl::flat_hash_map> - stacked_table_metadata; - for (int i = 0; i < feature_specs.size(); ++i) { - const py::object& feature_spec = feature_specs[i]; - const py::array& feature = features[i].cast(); - const py::object& feature_transformation = - feature_spec.attr("_id_transformation"); - const py::object& table_spec = feature_spec.attr("table_spec"); - const py::object& stacked_table_spec = - table_spec.attr("stacked_table_spec"); - const std::string stacked_table_name = py::cast( - table_spec.attr("_setting_in_stack").attr("stack_name")); - int col_shift = 0; - int col_offset = 0; - int row_offset = 0; - const int max_ids_per_partition = - py::cast(stacked_table_spec.attr("max_ids_per_partition")); - const int max_unique_ids_per_partition = - py::cast(stacked_table_spec.attr("max_unique_ids_per_partition")); - if (!feature_transformation.is_none()) { - row_offset = py::cast(feature_transformation.attr("row_offset")); - col_shift = py::cast(feature_transformation.attr("col_shift")); - col_offset = py::cast(feature_transformation.attr("col_offset")); - } - stacked_table_metadata[stacked_table_name].emplace_back( - i, max_ids_per_partition, max_unique_ids_per_partition, row_offset, - col_offset, col_shift, - /*batch_size=*/feature.shape(0)); - } - // Sort the stacked tables by row_offset. - for (auto& [_, t] : stacked_table_metadata) { - std::sort(t.begin(), t.end(), - [](const StackedTableMetadata& a, const StackedTableMetadata& b) { - return a.row_offset < b.row_offset; - }); - } - return stacked_table_metadata; -} - // Preprocess inputs for a single table. Stacked table here refers to a // a table that has no parent in the table stacking hierarchy. So in the case // of table stacking, the stacked table is the top level table and in the case diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc new file mode 100644 index 0000000..db1747a --- /dev/null +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc @@ -0,0 +1,88 @@ +// Copyright 2024 The JAX SC Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/log/check.h" // from @com_google_absl +#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" +#include "pybind11/cast.h" // from @pybind11 +#include "pybind11/gil.h" // from @pybind11 +#include "pybind11/numpy.h" // from @pybind11 +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "tsl/profiler/lib/traceme.h" // from @tsl + +namespace jax_sc_embedding { + +namespace py = ::pybind11; + +absl::flat_hash_map> +GetStackedTableMetadata(const py::list& feature_specs, const int batch_size) { + tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; }); + absl::flat_hash_map> + stacked_table_metadata; + for (int i = 0; i < feature_specs.size(); ++i) { + const py::object& feature_spec = feature_specs[i]; + + const py::object& feature_transformation = + feature_spec.attr("_id_transformation"); + const py::object& table_spec = feature_spec.attr("table_spec"); + const py::object& stacked_table_spec = + table_spec.attr("stacked_table_spec"); + const std::string stacked_table_name = py::cast( + table_spec.attr("_setting_in_stack").attr("stack_name")); + int col_shift = 0; + int col_offset = 0; + int row_offset = 0; + const int max_ids_per_partition = + py::cast(stacked_table_spec.attr("max_ids_per_partition")); + const int max_unique_ids_per_partition = + py::cast(stacked_table_spec.attr("max_unique_ids_per_partition")); + const int vocab_size = + py::cast(stacked_table_spec.attr("stack_vocab_size")); + if (!feature_transformation.is_none()) { + row_offset = py::cast(feature_transformation.attr("row_offset")); + col_shift = py::cast(feature_transformation.attr("col_shift")); + col_offset = py::cast(feature_transformation.attr("col_offset")); + } + stacked_table_metadata[stacked_table_name].emplace_back( + i, max_ids_per_partition, max_unique_ids_per_partition, row_offset, + col_offset, col_shift, + /*batch_size=*/batch_size, vocab_size); + } + // Sort the stacked tables by row_offset. + for (auto& [_, t] : stacked_table_metadata) { + std::sort(t.begin(), t.end(), + [](const StackedTableMetadata& a, const StackedTableMetadata& b) { + return a.row_offset < b.row_offset; + }); + } + return stacked_table_metadata; +} + +absl::flat_hash_map> +GetStackedTableMetadata(const py::list& feature_specs, + const py::list& features) { + tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; }); + int batch_size = features[0].cast().shape(0); + return GetStackedTableMetadata(feature_specs, batch_size); +} + +} // namespace jax_sc_embedding diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h new file mode 100644 index 0000000..290b0ec --- /dev/null +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h @@ -0,0 +1,40 @@ +// Copyright 2024 The JAX SC Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_ +#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_ +#include +#include + +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" +#include "pybind11/numpy.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 + +namespace jax_sc_embedding { + +namespace py = ::pybind11; + +// Copy information from feature_specs to StackedTableMetadata. +// The features argument is only used to get the batch size. +absl::flat_hash_map> +GetStackedTableMetadata(const py::list& feature_specs, + const py::list& features); + +// Copy information from feature_specs to StackedTableMetadata. +absl::flat_hash_map> +GetStackedTableMetadata(const py::list& feature_specs, int batch_size); + +} // namespace jax_sc_embedding + +#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_H_ diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h index 1359271..dc9f6ab 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h @@ -35,7 +35,7 @@ struct CooFormat { // Get adjusted col_id based on shift and offset. int GetColId(int col_id, int col_shift, int col_offset, int num_scs_mod, - int num_scs_mod_inv); + int num_scs_mod_inv); inline unsigned int CeilOfRatio(unsigned int numerator, unsigned int denominator) { @@ -50,14 +50,16 @@ struct StackedTableMetadata { StackedTableMetadata() = delete; StackedTableMetadata(int feature_index, int max_ids_per_partition, int max_unique_ids_per_partition, int row_offset, - int col_offset, int col_shift, int batch_size) + int col_offset, int col_shift, int batch_size, + int stacked_table_vocab_size = 0) : feature_index(feature_index), max_ids_per_partition(max_ids_per_partition), max_unique_ids_per_partition(max_unique_ids_per_partition), row_offset(row_offset), col_offset(col_offset), col_shift(col_shift), - batch_size(batch_size) {} + batch_size(batch_size), + stacked_table_vocab_size(stacked_table_vocab_size) {} // The batch is given as a list of features (numpy arrays). `feature_index` // represents the index of the feature in the list. int feature_index; @@ -70,6 +72,8 @@ struct StackedTableMetadata { // Process local batch size of the feature. int batch_size; + + int stacked_table_vocab_size; }; void SortAndGroupCooTensors( diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.cc new file mode 100644 index 0000000..07249fe --- /dev/null +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.cc @@ -0,0 +1,1165 @@ +// Copyright 2024 The JAX SC Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" // from @com_google_absl +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/log/check.h" // from @com_google_absl +#include "absl/log/log.h" // from @com_google_absl +#include "absl/strings/str_cat.h" // from @com_google_absl +#include "absl/synchronization/blocking_counter.h" // from @com_google_absl +#include "absl/synchronization/mutex.h" // from @com_google_absl +#include "absl/types/span.h" // from @com_google_absl +#include "hwy/base.h" // from @highway +#include "hwy/contrib/sort/order.h" // from @highway +#include "hwy/contrib/sort/vqsort.h" // from @highway +#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h" +#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h" +#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" +#include "pybind11/gil.h" // from @pybind11 +#include "pybind11/numpy.h" // from @pybind11 +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "tsl/profiler/lib/connected_traceme.h" // from @tsl +#include "tsl/profiler/lib/traceme.h" // from @tsl + +namespace jax_sc_embedding { + +namespace { + +namespace py = ::pybind11; + +using MiniBatchingSplit = int64_t; + +using BufferSizes = absl::flat_hash_map; + +// This device specific list is a 1D list. +// +// The 1st level of vector is for the local devices. +// The value is the batch size for each sparsecore on this device. +using DeviceBatchSizeList = std::vector; +using DeviceBatchSizeLists = + absl::flat_hash_map; + +// This device specific list is a 2D list. +// +// The 1st level of vector is for the local devices. +// The 2nd level of vector is for the device-local list of data. +template +using DeviceDataList = std::vector>; +template +// The map from (stacked) table name to associated DeviceDataList. +using DeviceDataLists = absl::flat_hash_map>; + +// ID counter per device, aggregated across all local sparsecores. +// The index is for the global sparsecores. +using AggregatedIdCounterPerDevice = std::vector; + +// ID counter for all devices. The index is for the local devices. +using AggregatedIdCounter = std::vector; + +// The map from (stacked) table name to associated AggregatedIdCounter. +using AggregatedIdCounters = + absl::flat_hash_map; + +// This sparsecore specific list is a 2D list. +// +// The 1st level of vector is for the device-local sparsecores. +// The 2nd level of vector is for the sparsecore-local list of data. +template +using SparsecoreDataList = std::vector>; + +// This sparsecore specific list is a 3D list. +// +// The 1st level of vector is for the process-local devices. +// The 2nd level of vector is for the device-local sparsecores. +// The 3rd level of vector is for the sparsecore-local list of data. 
+template +using DeviceSparsecoreDataList = std::vector>; + +// The map from (stacked) table name to associated DeviceSparsecoreDataList. +template +using DeviceSparsecoreDataLists = + absl::flat_hash_map>; + +// This sparsecore specific list is a 4D list. +// +// The 1st level of vector is for the process-local devices. +// The 2nd level of vector is for the device-local sparsecores. +// The 3rd level of vector is for the mini-batching split. +// The 4th level of vector is for the sparsecore-local list of data. +template +using DeviceSparsecoreMiniBatchingDataList = + std::vector>; + +// The map from (stacked) table name to associated +// DeviceSparsecoreMiniBatchingDataList. +template +using DeviceSparsecoreMiniBatchingDataLists = + absl::flat_hash_map>; + +template +void Convert2dToPyDictLocked(py::dict& py_dict, const DeviceDataLists& map) { + for (const auto& [key, value] : map) { + if (value[0].size() > 0) { + py::array_t py_value({value.size(), value[0].size()}); + for (int local_device = 0; local_device < value.size(); ++local_device) { + T* const py_data_ptr = py_value.mutable_data(local_device); + std::copy(value[local_device].begin(), value[local_device].end(), + py_data_ptr); + } + py_dict[key.c_str()] = py_value; + } else { + py_dict[key.c_str()] = py::none(); + } + } +} + +template +void Reshape2dToPyDictLocked(py::dict& py_dict, const DeviceDataLists& map) { + for (const auto& [key, value] : map) { + if (value[0].size() > 0) { + py::array_t py_value( + {static_cast(value.size() * value[0].size())}); + T* py_data_ptr = py_value.mutable_data(); + for (int local_device = 0; local_device < value.size(); ++local_device) { + std::copy(value[local_device].begin(), value[local_device].end(), + py_data_ptr); + py_data_ptr += value[local_device].size(); + } + py_dict[key.c_str()] = py_value; + } else { + py_dict[key.c_str()] = py::none(); + } + } +} + +template +void Extend2dToPyDictLocked(py::dict& py_dict, const DeviceDataLists& map, + const BufferSizes& intended_buffer_sizes) { + for (const auto& [key, value] : map) { + auto buffer_size_it = intended_buffer_sizes.find(key); + CHECK(buffer_size_it != intended_buffer_sizes.end()); + const int buffer_size = buffer_size_it->second; + if (buffer_size < value[0].size()) { + throw std::runtime_error("The intended buffer size is too small."); + } + + if (value[0].size() > 0) { + // Note that the elements in the python array are only valid to the point + // of the actual size of the C++ vector. Beyond that, the elements in the + // python array are uninitialized, and accessing them could lead to + // undefined behavior. This is to save time, but we still need the shape + // to be consistent across all steps to avoid costly re-compilations. 
+ py::array_t py_value({static_cast(value.size()), buffer_size}); + for (int local_device = 0; local_device < value.size(); ++local_device) { + T* const py_data_ptr = py_value.mutable_data(local_device); + std::copy(value[local_device].begin(), value[local_device].end(), + py_data_ptr); + } + py_dict[key.c_str()] = py_value; + } else { + py_dict[key.c_str()] = py::none(); + } + } +} + +std::tuple, AggregatedIdCounterPerDevice, + AggregatedIdCounterPerDevice, AggregatedIdCounterPerDevice> +SortAndGroupCooTensorsWithIdDrop( + const std::vector& coo_tensors_for_device, bool drop_ids, + int num_scs, int num_sc_per_device, int batch_size_per_sc, + int max_ids_per_partition, int max_unique_ids_per_partition, + int initial_num_coo_tensors_per_sc, std::vector& max_ids_per_sc, + std::vector& max_unique_ids_per_sc, + std::vector& id_drop_counter_per_sc, std::vector& keys) { + tsl::profiler::TraceMe t("SortAndGroupCooTensorsWithIdDrop"); + + bool mini_batching_needed = false; + + SparsecoreDataList coo_tensors_by_sc; + + // Initialize the counters to be 0 for all SCs. + // Index is for global SCs as sources of the embedding data. + AggregatedIdCounterPerDevice max_id_counter_by_sc(num_scs, 0); + AggregatedIdCounterPerDevice max_unique_id_counter_by_sc(num_scs, 0); + AggregatedIdCounterPerDevice id_drop_counter_by_sc(num_scs, 0); + + coo_tensors_by_sc.resize(num_sc_per_device); + for (auto& coo_tensors_by_client_sc : coo_tensors_by_sc) { + coo_tensors_by_client_sc.reserve(initial_num_coo_tensors_per_sc); + } + + uint32_t index = 0; + const int32_t num_scs_bit = std::log2(num_scs); + const int total_coo_tensors = coo_tensors_for_device.size(); + for (int32_t i = 0; i < num_sc_per_device; ++i) { + // Reset the counters as the vectors are reused. + max_ids_per_sc.clear(); + max_ids_per_sc.resize(num_scs, 0); + max_unique_ids_per_sc.clear(); + max_unique_ids_per_sc.resize(num_scs, 0); + id_drop_counter_per_sc.clear(); + id_drop_counter_per_sc.resize(num_scs, 0); + keys.clear(); + // We take the advantage of the fact that the row_ids are already sorted + // within each batch. + while (index < total_coo_tensors && + (unsigned)(coo_tensors_for_device[index].row_id - + i * batch_size_per_sc) < batch_size_per_sc) { + // The key here is [col_ids % num_scs, col_ids / num_scs, index]. + // Note that this assumes `num_scs` is a power of 2. + keys.push_back( + (static_cast(absl::rotr( + static_cast(coo_tensors_for_device[index].col_id), + num_scs_bit)) + << 32) + + index); + ++index; + } + hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending()); + + uint32_t prev_col_id = std::numeric_limits::max(); + for (const auto key : keys) { + uint32_t sc_id = static_cast(key >> (64 - num_scs_bit)); + if (static_cast(key >> 32) != prev_col_id) { + max_unique_ids_per_sc[sc_id] += 1; + } + max_ids_per_sc[sc_id] += 1; + + // If either max_unique_ids_per_partition or max_ids_per_partition is + // exceeded, we drop the id. + if (max_unique_ids_per_sc[sc_id] > max_unique_ids_per_partition || + max_ids_per_sc[sc_id] > max_ids_per_partition) { + if (drop_ids) { + prev_col_id = static_cast(key >> 32); + // Record that the id is dropped. + id_drop_counter_per_sc[sc_id] += 1; + continue; + } else { + // Don't drop the id. Just record that mini-batching is needed. + mini_batching_needed = true; + } + } + + coo_tensors_by_sc[i].push_back( + coo_tensors_for_device[static_cast(key)]); + prev_col_id = static_cast(key >> 32); + } + + for (int s = 0; s < num_scs; ++s) { + // Taking the max of the id counters for each SC. 
+ // Note here the max is taken across all local sparsecores. + max_id_counter_by_sc[s] = + std::max(max_id_counter_by_sc[s], max_ids_per_sc[s]); + max_unique_id_counter_by_sc[s] = + std::max(max_unique_id_counter_by_sc[s], max_unique_ids_per_sc[s]); + + // Accumulate ID drop counters across all local sparsecores. + id_drop_counter_by_sc[s] += id_drop_counter_per_sc[s]; + } + } + return std::make_tuple(mini_batching_needed, std::move(coo_tensors_by_sc), + std::move(max_id_counter_by_sc), + std::move(max_unique_id_counter_by_sc), + std::move(id_drop_counter_by_sc)); +} + +DeviceSparsecoreMiniBatchingDataLists SplitCooTensorsByVocabularyDiv( + const absl::flat_hash_map>& + stacked_tables, + int split_count, const DeviceSparsecoreDataLists& coo_tensors) { + tsl::profiler::TraceMe t("SplitCooTensorsByVocabularyDiv"); + DeviceSparsecoreMiniBatchingDataLists split_coo_tensors; + for (const auto& [stacked_table_name, stacked_table_metadata] : + stacked_tables) { + // Vocabulary size is the same for all stacked tables. + const int vocabulary_size = + stacked_table_metadata[0].stacked_table_vocab_size; + const int split_size = vocabulary_size / split_count; + + DeviceSparsecoreMiniBatchingDataList + split_coo_tensors_current_table; + + const auto& coo_tensors_current_table = coo_tensors.at(stacked_table_name); + + // local device count + split_coo_tensors_current_table.resize(coo_tensors_current_table.size()); + for (int local_device = 0; local_device < coo_tensors_current_table.size(); + ++local_device) { + // number of sparsecores per device + split_coo_tensors_current_table[local_device].resize( + coo_tensors_current_table[local_device].size()); + for (int sc_index = 0; + sc_index < coo_tensors_current_table[local_device].size(); + ++sc_index) { + // number of mini batches per sparsecore + split_coo_tensors_current_table[local_device][sc_index].resize( + split_count); + + // Reserve space for each mini batch. 
+ for (int mini_batch_index = 0; mini_batch_index < split_count; + ++mini_batch_index) { + split_coo_tensors_current_table + [local_device][sc_index][mini_batch_index] + .reserve( + coo_tensors_current_table[local_device][sc_index].size() / + split_count); + } + + for (const auto& coo_tensor : + coo_tensors_current_table[local_device][sc_index]) { + const int mini_batch_index = coo_tensor.col_id / split_size; + split_coo_tensors_current_table[local_device][sc_index] + [mini_batch_index] + .push_back(coo_tensor); + } + } + } + + split_coo_tensors[stacked_table_name] = + std::move(split_coo_tensors_current_table); + } + return split_coo_tensors; +} + +DeviceSparsecoreMiniBatchingDataLists SplitCooTensorsByVocabularyMod( + const absl::flat_hash_map>& + stacked_tables, + int modulus, const DeviceSparsecoreDataLists& coo_tensors) { + tsl::profiler::TraceMe t("SplitCooTensorsByVocabularyMod"); + DeviceSparsecoreMiniBatchingDataLists split_coo_tensors; + for (const auto& [stacked_table_name, stacked_table_metadata] : + stacked_tables) { + DeviceSparsecoreMiniBatchingDataList + split_coo_tensors_current_table; + + const auto& coo_tensors_current_table = coo_tensors.at(stacked_table_name); + + // local device count + split_coo_tensors_current_table.resize(coo_tensors_current_table.size()); + for (int local_device = 0; local_device < coo_tensors_current_table.size(); + ++local_device) { + // number of sparsecores per device + split_coo_tensors_current_table[local_device].resize( + coo_tensors_current_table[local_device].size()); + for (int sc_index = 0; + sc_index < coo_tensors_current_table[local_device].size(); + ++sc_index) { + // number of mini batches per sparsecore + split_coo_tensors_current_table[local_device][sc_index].resize(modulus); + + // Reserve space for each mini batch. + for (int mini_batch_index = 0; mini_batch_index < modulus; + ++mini_batch_index) { + split_coo_tensors_current_table + [local_device][sc_index][mini_batch_index] + .reserve( + coo_tensors_current_table[local_device][sc_index].size() / + modulus); + } + + for (const auto& coo_tensor : + coo_tensors_current_table[local_device][sc_index]) { + const int mini_batch_index = coo_tensor.col_id % modulus; + split_coo_tensors_current_table[local_device][sc_index] + [mini_batch_index] + .push_back(coo_tensor); + } + } + } + + split_coo_tensors[stacked_table_name] = + std::move(split_coo_tensors_current_table); + } + return split_coo_tensors; +} + +void PadDataTensorsToEndOfRegisterWidth(DeviceDataList* embedding_ids, + DeviceDataList* sample_ids, + DeviceDataList* gains, + const int local_device_id, + const int sparsecore_register_width) { + CHECK_NE(embedding_ids, nullptr); + CHECK_NE(sample_ids, nullptr); + CHECK_NE(gains, nullptr); + + // All data tensors should have the same size. 
+ CHECK_EQ((*sample_ids)[local_device_id].size(), + (*embedding_ids)[local_device_id].size()); + CHECK_EQ((*gains)[local_device_id].size(), + (*sample_ids)[local_device_id].size()); + + const int remainder = + (*embedding_ids)[local_device_id].size() % sparsecore_register_width; + const int padding_size = + (sparsecore_register_width - remainder) % sparsecore_register_width; + for (int i = 0; i < padding_size; ++i) { + (*embedding_ids)[local_device_id].push_back(INT_MAX); + (*sample_ids)[local_device_id].push_back(INT_MAX); + (*gains)[local_device_id].push_back(NAN); + } +} + +std::tuple, DeviceDataLists, + DeviceDataLists, DeviceDataLists, BufferSizes> +EncodeMiniBatchingDataUnlocked( + const DeviceSparsecoreMiniBatchingDataLists& split_coo_tensors, + const absl::flat_hash_map>& + stacked_tables, + const DeviceBatchSizeLists& batch_sizes, const int local_device_count, + const int global_device_count, const int num_sc_per_device, + const int sparsecore_register_width, const bool has_leading_dimension, + const int static_buffer_size_multiplier) { + // Global number of sparsecores for embedding table + const int num_scs = num_sc_per_device * global_device_count; + + // All tables have to have the same mini-batch size. + int mini_batch_size = -1; + + if (!split_coo_tensors.empty()) { + // Note that if the COO tensors are empty, we will not have any data. + // So we can use the first table to get the mini batch size. + mini_batch_size = split_coo_tensors.begin()->second[0][0].size(); + } + + struct { + absl::Mutex mutex; + DeviceDataLists row_pointers ABSL_GUARDED_BY(mutex); + DeviceDataLists embedding_ids ABSL_GUARDED_BY(mutex); + DeviceDataLists sample_ids ABSL_GUARDED_BY(mutex); + DeviceDataLists gains ABSL_GUARDED_BY(mutex); + BufferSizes buffer_sizes ABSL_GUARDED_BY(mutex); + } results; + + absl::BlockingCounter counter(stacked_tables.size()); + tsl::profiler::TraceMeProducer producer("EncodingMainThread"); + { + for (const auto& [stacked_table_name, split_coo_tensors_current_table] : + split_coo_tensors) { + PreprocessingThreadPool()->Schedule([&, context_id = + producer.GetContextId()] { + tsl::profiler::TraceMeConsumer consumer( + [&] { return absl::StrCat("Encoding-", stacked_table_name); }, + context_id); + + // Initialize the resulting data lists. + // Note that through padding, the resulting data lists are always + // the same size, no matter how much is required after encoding. + DeviceDataList row_pointers_per_table(local_device_count); + DeviceDataList embedding_ids_per_table(local_device_count); + DeviceDataList sample_ids_per_table(local_device_count); + DeviceDataList gains_per_table(local_device_count); + + // Reserve space for lists. Note that we're not initializing the lists + // yet, so the size() member for each list is still 0. + const int expected_row_pointers_size_per_device = + num_sc_per_device * mini_batch_size * + std::max(num_scs, sparsecore_register_width); + + auto& batch_size_per_sc_for_current_table = + batch_sizes.at(stacked_table_name); + + // Allocate the static buffers. 
+ auto stacked_table_metadata = stacked_tables.find(stacked_table_name); + CHECK(stacked_table_metadata != stacked_tables.end()); + const int pad_per_device_array_to_size = ComputeCooBufferSize( + num_scs, num_sc_per_device, stacked_table_metadata->second, + static_buffer_size_multiplier); + + for (int local_device = 0; local_device < local_device_count; + ++local_device) { + row_pointers_per_table[local_device].reserve( + expected_row_pointers_size_per_device); + embedding_ids_per_table[local_device].reserve( + pad_per_device_array_to_size); + sample_ids_per_table[local_device].reserve( + pad_per_device_array_to_size); + gains_per_table[local_device].reserve(pad_per_device_array_to_size); + } + + for (int local_device = 0; local_device < local_device_count; + ++local_device) { + const int batch_size_per_sc = + batch_size_per_sc_for_current_table[local_device]; + + // A client sparsecore handles input samples, grouped by mini-batch. + for (int client_sc_id = 0; client_sc_id < num_sc_per_device; + ++client_sc_id) { + for (int mini_batch_id = 0; mini_batch_id < mini_batch_size; + ++mini_batch_id) { + auto& coo_tensors_within_mini_batch = + split_coo_tensors_current_table[local_device][client_sc_id] + [mini_batch_id]; + + auto next_coo_tensor = coo_tensors_within_mini_batch.begin(); + auto coo_tensors_within_mini_batch_end = + coo_tensors_within_mini_batch.end(); + for (int expected_server_sc_id = 0; + expected_server_sc_id < num_scs;) { + // This is a "partition", as defined by the combination of + // client sc, server sc, and mini batch. + int id_counter = 0; + int previous_row_id_within_server_sc = -1; + int unique_id_counter = 0; + int server_sc_id = -1; + while (next_coo_tensor != coo_tensors_within_mini_batch_end) { + // Consume the next COO tensor. + const auto& coo_tensor = *next_coo_tensor; + + // Which sc should provide this embedding data. + server_sc_id = coo_tensor.col_id % num_scs; + + if (server_sc_id != expected_server_sc_id) { + // Break from the COO tensor loop, so id counters are reset. + break; + } + + // Within this server sc, which (embedding table) row has the + // desired data. + const int row_id_within_server_sc = + coo_tensor.col_id / num_scs; + + // Within this client sc, which (input sample) row should + // receive (accumulate) the embedding data. + const int row_id_within_client_sc = + coo_tensor.row_id % batch_size_per_sc; + + id_counter++; + if (unique_id_counter == 0) { + // Set unique id counter to 1 if it's the first id. + unique_id_counter = 1; + } else if (row_id_within_server_sc != + previous_row_id_within_server_sc) { + // If this is not the first id, and it's not the same row id + // as the previous id, then it's a new unique id. + unique_id_counter++; + previous_row_id_within_server_sc = row_id_within_server_sc; + } + + // Record the data for this COO tensor. + embedding_ids_per_table[local_device].push_back( + row_id_within_server_sc); + sample_ids_per_table[local_device].push_back( + row_id_within_client_sc); + gains_per_table[local_device].push_back(coo_tensor.gain); + + ++next_coo_tensor; + } // End of COO tensor loop. + + if (next_coo_tensor == coo_tensors_within_mini_batch_end) { + // we've consumed all COO tensors for this mini batch. + // Set the server_sc_id to end of this mini batch, so that we + // can pad the row pointers properly. + server_sc_id = std::max(num_scs, sparsecore_register_width); + } + + int padding_size = 1; + if (server_sc_id >= 0) { + // server_sc_id == -1 if there is no data for this partition. 
+ // Since the COO tensors are sorted, as long as there is data, + // server_sc_id must be larger than expected_server_sc_id. + CHECK_GT(server_sc_id, expected_server_sc_id); + padding_size = server_sc_id - expected_server_sc_id; + } + // padding_size is at least 1. + CHECK_GE(padding_size, 1); + // Push one new row pointer for this particular server sc. + row_pointers_per_table[local_device].push_back( + embedding_ids_per_table[local_device].size()); + + // Pad all three tensors up to sparsecore register width as + // we're ending this partition. Note that if there is no new + // data for this partition, we do not pad more entries. + PadDataTensorsToEndOfRegisterWidth( + &embedding_ids_per_table, &sample_ids_per_table, + &gains_per_table, local_device, sparsecore_register_width); + + for (int i = 1; i < padding_size; ++i) { + // Push a new row pointer for each server sc. + row_pointers_per_table[local_device].push_back( + embedding_ids_per_table[local_device].size()); + } + expected_server_sc_id += padding_size; + } // End of "partition" loop. + } // Mini batch loop. + } // Client SC loop. + } // Local device loop. + + { + // Move the data lists to the main thread. + absl::MutexLock lock(&results.mutex); + results.row_pointers[stacked_table_name] = + std::move(row_pointers_per_table); + results.embedding_ids[stacked_table_name] = + std::move(embedding_ids_per_table); + results.sample_ids[stacked_table_name] = + std::move(sample_ids_per_table); + results.gains[stacked_table_name] = std::move(gains_per_table); + results.buffer_sizes[stacked_table_name] = + pad_per_device_array_to_size; + } + + // Signal the main thread that this task is done. + counter.DecrementCount(); + } // End of lambda for threaded task. + ); // End of Schedule. + } + counter.Wait(); + } // End of EncodingMainThread context. 
+ + { + absl::MutexLock lock(&results.mutex); + return std::make_tuple( + std::move(results.row_pointers), std::move(results.embedding_ids), + std::move(results.sample_ids), std::move(results.gains), + std::move(results.buffer_sizes)); + } +} + +inline int GetColIdInline(const int col_id, const int col_shift, + const int col_offset, const int num_scs_mod, + const int num_scs_mod_inv) { + // This is equivalent to: + // (col_ids + col_shift) % num_sc_shards + + // (col_ids // num_sc_shards * num_sc_shards) + col_offset + return ((col_id + col_shift) & num_scs_mod) + (col_id & num_scs_mod_inv) + + col_offset; +} + +class FeatureWeightRepresentation { + public: + using index_ref_type = py::detail::unchecked_reference; + using value_ref_type = py::detail::unchecked_reference; + using weights_ref_type = py::detail::unchecked_reference; + + FeatureWeightRepresentation(const index_ref_type& indices, + const value_ref_type& values, + const weights_ref_type& weights) + : indices_(indices), values_(values), weights_(weights) { + index_stride = &indices(1, 0) - &indices(0, 0); + value_stride = &values(1) - &values(0); + weight_stride = &weights(1) - &weights(0); + } + + void ExtractCooTensors(const int start_index, const int end_index, + const int row_offset, const int col_offset, + const int col_shift, const int num_scs, + const int global_device_count, + std::vector& coo_tensors) const { + tsl::profiler::TraceMe t([] { return "ExtractCooTensors"; }); + + const int num_scs_bit = std::log2(num_scs); + const int num_scs_mod = (1 << num_scs_bit) - 1; + const int num_scs_mod_inv = ~num_scs_mod; + + const int row_offset_per_device = row_offset / global_device_count; + + // Get the range of elements in the indices array for samples between + // start_index and end_index. + auto [begin_cursor, end_cursor] = GetElemRange(start_index, end_index); + + // Expand the size of the vector to accommodate the new COO tensors. + coo_tensors.reserve(coo_tensors.size() + end_cursor - begin_cursor); + + const bool has_weights = weights_.size() > 0; + + // Iterate through all elements in the current slice of theindices array. + // These pointers are created to avoid repeated calculations around shape + // and strides. + const int64_t* indices_ptr = &indices_(begin_cursor, 0); + const int32_t* values_ptr = &values_(begin_cursor); + const float* weights_ptr = has_weights ? &weights_(begin_cursor) : nullptr; + for (int cursor = begin_cursor; cursor < end_cursor; ++cursor) { + const int sample_id = *indices_ptr; + const int adjusted_sample_id = + sample_id - start_index + row_offset_per_device; + + coo_tensors.emplace_back( + adjusted_sample_id, + GetColIdInline(*values_ptr, col_shift, col_offset, num_scs_mod, + num_scs_mod_inv), + has_weights ? *weights_ptr : 1.0f); + + indices_ptr += index_stride; + values_ptr += value_stride; + weights_ptr += weight_stride; + } + } + + private: + // Returns a tuple of the range of elements in the indices array for samples + // between start_index and end_index. 
+ std::tuple GetElemRange(int start_index, int end_index) const { + int begin_cursor = -1; + int end_cursor = -1; + for (int i = 0; i < indices_.shape(0); ++i) { + const auto row = indices_(i, 0); + if (row >= start_index && row < end_index) { + if (begin_cursor == -1) { + begin_cursor = i; + } + end_cursor = i; + } + } + CHECK_GE(begin_cursor, 0); + CHECK_GT(end_cursor, 0); + return std::make_tuple(begin_cursor, end_cursor + 1); + } + + bool HasWeights() const { return weights_.size() > 0; } + + int index_stride; + int value_stride; + int weight_stride; + py::detail::unchecked_reference indices_; + py::detail::unchecked_reference values_; + py::detail::unchecked_reference weights_; +}; + +// This function handles one local device and one stacked table, which feeds to +// multiple features (due to feature stacking) and potentially multiple tables +// (due to table stacking). Returns a tuple of data for the stacked tables on +// the current device: +// 1. All COO tensors. +// 2. Batch size. +std::tuple, int> +GetCooTensorsForStackedTablesOnDeviceUnlocked( + const int local_batch_size, + const std::vector& features, + const std::vector& stacked_table_metadata, + const int local_device_id, const int local_device_count, + const int global_device_count, const int num_sc_per_device) { + tsl::profiler::TraceMe t("GetCooTensorsForAllTablesOnDeviceUnlocked"); + const int num_scs = num_sc_per_device * global_device_count; + std::vector coo_tensors; + int batch_size_for_device = 0; + + // Iterate through all features that have been stacked into the same table. + for (const auto& metadata : stacked_table_metadata) { + const int feature_index = metadata.feature_index; + const int row_offset = metadata.row_offset; + const int col_offset = metadata.col_offset; + const int col_shift = metadata.col_shift; + + const auto& curr_feature = features[feature_index]; + + // Split the feature and feature weights into per-device spans. + const int num_samples = local_batch_size; + const int num_samples_per_split = num_samples / local_device_count; + const int start_index = local_device_id * num_samples_per_split; + int end_index = (local_device_id + 1) * num_samples_per_split; + if (local_device_id == local_device_count - 1) { + // Just in case the last split is not a full batch. + end_index = num_samples; + } + + batch_size_for_device += (end_index - start_index); + + // In the case of feature stacking, we need to group all the COO + // tensors at this stage (i.e., before the sorting later on). + curr_feature.ExtractCooTensors(start_index, end_index, row_offset, + col_offset, col_shift, num_scs, + global_device_count, coo_tensors); + } + + return std::make_tuple(std::move(coo_tensors), batch_size_for_device); +} + +/* +Returns a tuple of data for all stacked tables on all local devices, and all +local sparsecores: +1. Whether mini-batching is needed. +2. All COO tensors. +3. Id counter. +4. Unique id counter. +5. Id drop counter. 
+*/ +std::tuple, + AggregatedIdCounters, AggregatedIdCounters, AggregatedIdCounters> +SortDeviceListOfCooTensorsWithIdDropUnlocked( + const int local_batch_size, + const std::vector& features, + const absl::flat_hash_map>& + stacked_tables, + const bool drop_ids, const int local_device_count, + const int global_device_count, const int num_sc_per_device) { + tsl::profiler::TraceMe t("SortDeviceListOfCooTensorsWithIdDropUnlocked"); + + const int num_scs = num_sc_per_device * global_device_count; + + absl::BlockingCounter counter(stacked_tables.size()); + + struct { + absl::Mutex mutex; + bool mini_batching_needed = false; + DeviceBatchSizeLists batch_sizes ABSL_GUARDED_BY(mutex); + + // Stacked table names to a list of COO tensors + DeviceSparsecoreDataLists coo_tensors ABSL_GUARDED_BY(mutex); + + // Stacked table names to a list of id counters + AggregatedIdCounters max_id_counters ABSL_GUARDED_BY(mutex); + AggregatedIdCounters max_unique_id_counters ABSL_GUARDED_BY(mutex); + AggregatedIdCounters id_drop_counters ABSL_GUARDED_BY(mutex); + } results; + + tsl::profiler::TraceMeProducer producer("SortListOfCooTensorsMainThread"); + { + for (const auto& [stacked_table_name, stacked_table_metadata] : + stacked_tables) { + PreprocessingThreadPool()->Schedule([&, context_id = + producer.GetContextId()] { + // Each thread handles one (stacked) table for all local devices. + tsl::profiler::TraceMeConsumer consumer( + [&] { + return absl::StrCat("InputPreprocessingTable-", + stacked_table_name); + }, + context_id); + + // The following lists contains data for this + // stacked table to be processed by all local devices. + // 1st dimension is for local devices. + // 2nd dimension is for SCs per device. + // 3rd dimension is for the sparsecore-local list of data. + DeviceSparsecoreDataList coo_tensors_for_current_table( + local_device_count); + AggregatedIdCounter max_id_counter_for_current_table( + local_device_count); + AggregatedIdCounter max_unique_id_counter_for_current_table( + local_device_count); + AggregatedIdCounter id_drop_counter_for_current_table( + local_device_count); + bool mini_batching_needed_for_current_table = false; + DeviceBatchSizeList batch_size_per_sc_for_current_table( + local_device_count); + + // Temporary storage for the per-sparsecore data. + // Avoid reallocations by pre-allocating the vectors. The keys could + // grow pretty large. + std::vector max_ids_per_sc_temp_storage(num_scs, 0); + std::vector max_unique_ids_per_sc_temp_storage(num_scs, 0); + std::vector id_drop_counter_per_sc_temp_storage(num_scs, 0); + std::vector keys_temp_storage; + + for (int local_device = 0; local_device < local_device_count; + ++local_device) { + // + // Per-device Step 1: Extract the COO tensors for each table. + // + auto [coo_tensors_for_device, batch_size_for_device] = + GetCooTensorsForStackedTablesOnDeviceUnlocked( + local_batch_size, features, stacked_table_metadata, + local_device, local_device_count, global_device_count, + num_sc_per_device); + + // + // Per-device Step 2: Sort the COO tensors and group them by SC. + // + const int batch_size_per_sc = + CeilOfRatio(batch_size_for_device, num_sc_per_device); + const int approximate_num_coo_tensors_per_sc = + coo_tensors_for_device.size() / num_sc_per_device + 1; + + // Make sure the keys are large enough to hold at least these many + // elements. 
+ keys_temp_storage.reserve(batch_size_per_sc); + + auto [mini_batching_needed_for_current_device, coo_tensors_by_sc, + max_id_counter_by_sc, max_unique_id_counter_by_sc, + id_drop_counter_by_sc] = + SortAndGroupCooTensorsWithIdDrop( + coo_tensors_for_device, drop_ids, num_scs, num_sc_per_device, + batch_size_per_sc, + stacked_table_metadata[0].max_ids_per_partition, + stacked_table_metadata[0].max_unique_ids_per_partition, + approximate_num_coo_tensors_per_sc, + max_ids_per_sc_temp_storage, + max_unique_ids_per_sc_temp_storage, + id_drop_counter_per_sc_temp_storage, keys_temp_storage); + + mini_batching_needed_for_current_table |= + mini_batching_needed_for_current_device; + + batch_size_per_sc_for_current_table[local_device] = + batch_size_for_device / num_sc_per_device; + coo_tensors_for_current_table[local_device] = + std::move(coo_tensors_by_sc); + max_id_counter_for_current_table[local_device] = + std::move(max_id_counter_by_sc); + max_unique_id_counter_for_current_table[local_device] = + std::move(max_unique_id_counter_by_sc); + id_drop_counter_for_current_table[local_device] = + std::move(id_drop_counter_by_sc); + } + + // Save the COO tensors for this table for all local devices. + { + absl::MutexLock lock(&results.mutex); + results.mini_batching_needed |= + mini_batching_needed_for_current_table; + results.batch_sizes[stacked_table_name] = + std::move(batch_size_per_sc_for_current_table); + results.coo_tensors[stacked_table_name.c_str()] = + std::move(coo_tensors_for_current_table); + results.max_id_counters[stacked_table_name.c_str()] = + std::move(max_id_counter_for_current_table); + results.max_unique_id_counters[stacked_table_name.c_str()] = + std::move(max_unique_id_counter_for_current_table); + results.id_drop_counters[stacked_table_name.c_str()] = + std::move(id_drop_counter_for_current_table); + } + counter.DecrementCount(); + } // End of lambda for threaded task. + ); // End of Schedule. 
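The 64-bit sort key built in SortAndGroupCooTensorsWithIdDrop above packs [col_id % num_scs, col_id / num_scs, index] into a single integer by rotating the 32-bit col_id right by log2(num_scs) bits. A small Python model of that key construction, for illustration only:

def make_sort_key(col_id, index, num_scs):
    # num_scs must be a power of two, as the C++ code above assumes.
    num_scs_bit = num_scs.bit_length() - 1
    # 32-bit rotate-right: the low num_scs_bit bits of col_id (the destination
    # sparsecore, col_id % num_scs) become the most significant bits of the key,
    # followed by col_id // num_scs (the embedding row within that sparsecore).
    rotated = ((col_id >> num_scs_bit) | (col_id << (32 - num_scs_bit))) & 0xFFFFFFFF
    return (rotated << 32) | index

# Sorting the keys ascending therefore groups ids by destination sparsecore,
# then by embedding row, then by original position within the batch.
keys = sorted(make_sort_key(c, i, num_scs=4) for i, c in enumerate([7, 2, 5, 4]))
sc_ids = [k >> (64 - 2) for k in keys]  # destination sparsecore of each key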
+ } + counter.Wait(); + } + + absl::MutexLock lock(&results.mutex); + + return std::make_tuple( + std::move(results.mini_batching_needed), std::move(results.batch_sizes), + std::move(results.coo_tensors), std::move(results.max_id_counters), + std::move(results.max_unique_id_counters), + std::move(results.id_drop_counters)); +} + +int GetMiniBatchSize(const py::dict& mini_batching_config) { + return mini_batching_config["MINI_BATCH_SIZE"].cast(); +} + +MiniBatchingMode GetMiniBatchingMode(const py::dict& mini_batching_config) { + int mode = mini_batching_config["MINI_BATCH_MODE"].cast(); + switch (mode) { + case static_cast(MiniBatchingMode::kNone): + return MiniBatchingMode::kNone; + case static_cast(MiniBatchingMode::kVocabularyDimension): + return MiniBatchingMode::kVocabularyDimension; + case static_cast(MiniBatchingMode::kSampleDimension): + return MiniBatchingMode::kSampleDimension; + case static_cast(MiniBatchingMode::kExperimentalForceVocabularyDiv): + return MiniBatchingMode::kExperimentalForceVocabularyDiv; + case static_cast(MiniBatchingMode::kExperimentalForceVocabularyMod): + return MiniBatchingMode::kExperimentalForceVocabularyMod; + default: + throw std::invalid_argument("Not supported mini-batching mode."); + } +} + +py::tuple _PreprocessSparseDenseMatmulInput( + const int local_batch_size, + const std::vector& features, + const absl::flat_hash_map>& + stacked_tables, + const py::dict& mini_batching_config, const int local_device_count, + const int global_device_count, const int static_buffer_size_multiplier, + const int num_sc_per_device, const int sparsecore_register_width, + const int sharding_strategy, const bool has_leading_dimension) { + tsl::profiler::TraceMe t("_PreprocessSparseDenseMatmulInput"); + + if (has_leading_dimension != true) { + throw std::invalid_argument( + "Currently, only leading dimension is supported for mini-batching."); + } + + // GIL is held when we enter this function. + py::dict lhs_row_pointers; + py::dict lhs_embedding_ids; + py::dict lhs_sample_ids; + py::dict lhs_gains; + py::dict id_counter_per_table; + py::dict unique_id_counter_per_table; + py::dict id_drop_counter_per_table; + int mini_batch_size = GetMiniBatchSize(mini_batching_config); + MiniBatchingMode mini_batching_mode = + GetMiniBatchingMode(mini_batching_config); + { + // Release GIL here as we don't need python objects after this point. + py::gil_scoped_release main_release; + + // Sort COO tensors and group them by SC. + // Note this function would release and reacquire the GIL. + // If mini-batching mode is set to NONE, embedding ids beyond limitations + // are directly dropped. + auto [mini_batching_needed, batch_sizes, coo_tensors, id_counters, + unique_id_counters, id_drop_counters] = + SortDeviceListOfCooTensorsWithIdDropUnlocked( + local_batch_size, features, stacked_tables, + mini_batching_mode == MiniBatchingMode::kNone, local_device_count, + global_device_count, num_sc_per_device); + + { + py::gil_scoped_acquire acq; + Reshape2dToPyDictLocked(id_counter_per_table, id_counters); + Reshape2dToPyDictLocked(unique_id_counter_per_table, unique_id_counters); + Reshape2dToPyDictLocked(id_drop_counter_per_table, id_drop_counters); + } + + // Communicate with other tasks to see if mini-batching is needed. + // Here we assume the mini-batching size is the same across all tables. + + // If mini-batching is needed in this task, determine the split points. + // Here we assume a simple mod-N split. 
+ mini_batching_needed = true; + + // Communicate with other tasks to reach consensus on mini-batching split + // points. + + // In this prototype, we always use mini-batching if allowed. + if (mini_batching_mode == MiniBatchingMode::kNone) { + // No mini-batching. + mini_batching_mode = MiniBatchingMode::kExperimentalForceVocabularyDiv; + // force the mini-batch size to be 1. + mini_batch_size = 1; + } + + if (mini_batching_mode == + MiniBatchingMode::kExperimentalForceVocabularyDiv) { + DeviceSparsecoreMiniBatchingDataLists split_coo_tensors = + SplitCooTensorsByVocabularyDiv(stacked_tables, mini_batch_size, + coo_tensors); + + auto [row_pointers, embedding_ids, sample_ids, gains, buffer_sizes] = + EncodeMiniBatchingDataUnlocked( + split_coo_tensors, stacked_tables, batch_sizes, + local_device_count, global_device_count, num_sc_per_device, + sparsecore_register_width, has_leading_dimension, + static_buffer_size_multiplier); + { + py::gil_scoped_acquire acq; + Convert2dToPyDictLocked(lhs_row_pointers, row_pointers); + Extend2dToPyDictLocked(lhs_embedding_ids, embedding_ids, buffer_sizes); + Extend2dToPyDictLocked(lhs_sample_ids, sample_ids, buffer_sizes); + Extend2dToPyDictLocked(lhs_gains, gains, buffer_sizes); + } + } else if (mini_batching_mode == + MiniBatchingMode::kExperimentalForceVocabularyMod) { + // Note modulus is the mini-batch size here. + DeviceSparsecoreMiniBatchingDataLists split_coo_tensors = + SplitCooTensorsByVocabularyMod(stacked_tables, mini_batch_size, + coo_tensors); + + auto [row_pointers, embedding_ids, sample_ids, gains, buffer_sizes] = + EncodeMiniBatchingDataUnlocked( + split_coo_tensors, stacked_tables, batch_sizes, + local_device_count, global_device_count, num_sc_per_device, + sparsecore_register_width, has_leading_dimension, + static_buffer_size_multiplier); + { + py::gil_scoped_acquire acq; + Convert2dToPyDictLocked(lhs_row_pointers, row_pointers); + Extend2dToPyDictLocked(lhs_embedding_ids, embedding_ids, buffer_sizes); + Extend2dToPyDictLocked(lhs_sample_ids, sample_ids, buffer_sizes); + Extend2dToPyDictLocked(lhs_gains, gains, buffer_sizes); + } + } else { + throw std::invalid_argument("Not supported mini-batching mode."); + } + } + + py::dict stats; + stats["max_ids"] = std::move(id_counter_per_table); + stats["max_unique_ids"] = std::move(unique_id_counter_per_table); + stats["id_drop_counters"] = std::move(id_drop_counter_per_table); + return py::make_tuple(lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids, + lhs_gains, mini_batch_size, stats); +} + +py::tuple PreprocessSparseDenseMatmulInputWithBCOO( + const int local_batch_size, const py::list& indices, const py::list& values, + const py::list& weights, const py::list& feature_specs, + const py::dict& mini_batching_config, const int local_device_count, + const int global_device_count, const int static_buffer_size_multiplier, + const int num_sc_per_device, const int sparsecore_register_width, + const int sharding_strategy, const bool has_leading_dimension) { + tsl::profiler::TraceMe t("PreprocessSparseDenseMatmulInputWithBCOO"); + + const auto num_features = indices.size(); + CHECK_EQ(num_features, indices.size()); + CHECK_EQ(num_features, values.size()); + CHECK_EQ(num_features, weights.size()); + CHECK_EQ(num_features, feature_specs.size()); + + py::array_t dummy_weights(0); + auto dummy_weights_ref = dummy_weights.unchecked<1>(); + + std::vector sparse_features; + sparse_features.reserve(num_features); + + // Fill the sparse features and weights from the BCOO tensors. 
+ for (int feature_index = 0; feature_index < num_features; ++feature_index) { + const auto& current_values = + values[feature_index].cast>(); + + const auto& current_index = + indices[feature_index].cast>(); + + if (!weights[feature_index].is_none()) { + const auto& current_weights = + weights[feature_index].cast>(); + sparse_features.emplace_back(current_index.unchecked<2>(), + current_values.unchecked<1>(), + current_weights.unchecked<1>()); + } else { + sparse_features.emplace_back(current_index.unchecked<2>(), + current_values.unchecked<1>(), + dummy_weights_ref); + } + } + + // Get the stacked table metadata for each top level table. + // The keys are stacked table names (or the table itself if not stacked) and + // the values are a vector of StackedTableMetadata for each feature that is + // mapped to the table. + const absl::flat_hash_map> + stacked_tables = GetStackedTableMetadata(feature_specs, local_batch_size); + + return _PreprocessSparseDenseMatmulInput( + local_batch_size, sparse_features, stacked_tables, mini_batching_config, + local_device_count, global_device_count, static_buffer_size_multiplier, + num_sc_per_device, sparsecore_register_width, sharding_strategy, + has_leading_dimension); +} + +} // namespace + +PYBIND11_MODULE(input_preprocessing_with_mini_batching_cc, m) { + m.def("PreprocessSparseDenseMatmulInputWithBCOO", + &PreprocessSparseDenseMatmulInputWithBCOO, + pybind11::arg("local_batch_size"), pybind11::arg("indices"), + pybind11::arg("values"), pybind11::arg("weights"), + pybind11::arg("feature_specs"), pybind11::arg("mini_batching_config"), + pybind11::arg("local_device_count"), + pybind11::arg("global_device_count"), + pybind11::arg("static_buffer_size_multiplier"), + pybind11::arg("num_sc_per_device"), + pybind11::arg("sparsecore_register_width"), + pybind11::arg("sharding_strategy"), + pybind11::arg("has_leading_dimension")); +} + +} // namespace jax_sc_embedding diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.h new file mode 100644 index 0000000..f8a4c73 --- /dev/null +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_with_mini_batching.h @@ -0,0 +1,47 @@ +// Copyright 2024 The JAX SC Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_WITH_MINI_BATCHING_H_ +#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_WITH_MINI_BATCHING_H_ + +namespace jax_sc_embedding { + +// The mode of mini-batching operation. +enum class MiniBatchingMode { + // No mini-batching, essentially the same as + // MINI_BATCHING_EXPERIMENTAL_FORCE_VOCABULARY_DIV with mini_batch_size = 1. + kNone = 1, + + // If there is no need to mini-batch in any of the tasks for this input batch, + // this is essentially + // the same as MINI_BATCHING_NONE. + // First hash the embedding IDs into 2^64 domain, and then modulo into + // 2^max_division_level buckets. 
Finally optimize the number of buckets + // necessary to a minimum through merging neighboring buckets and + // communication among all tasks. + kVocabularyDimension = 2, + + kSampleDimension = 3, + + // Linearly divide the vocabulary dimension into specified mini_batch_size + // slices. + kExperimentalForceVocabularyDiv = 200, + + // Split the vocabulary dimension into specified mini_batch_size slices + // through simple modulo operations. + kExperimentalForceVocabularyMod = 201, +}; + +} // namespace jax_sc_embedding + +#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_WITH_MINI_BATCHING_H_ diff --git a/jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD b/jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD index e8146f8..e1b6c52 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD +++ b/jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD @@ -42,6 +42,7 @@ pytype_strict_library( "//jax_tpu_embedding/sparsecore/lib/core:constants", pypi_requirement("jax"), pypi_requirement("jax/_src/lib"), + pypi_requirement("jax/extend"), pypi_requirement("numpy"), ], ) diff --git a/jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py b/jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py index 99a92f9..678381d 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py +++ b/jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py @@ -20,13 +20,14 @@ from jax._src import dispatch from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo +import jax.extend as jex from jax.interpreters import mlir import jax.numpy as jnp from jax_tpu_embedding.sparsecore.lib.core import constants import numpy as np # Define the sparse dense matmul primitive. 
-tpu_sparse_dense_matmul_csr_with_mini_batching_primitive = core.Primitive( +tpu_sparse_dense_matmul_csr_with_mini_batching_primitive = jex.core.Primitive( "sparse_dense_matmul_csr_with_mini_batching" ) diff --git a/jax_tpu_embedding/sparsecore/lib/nn/BUILD b/jax_tpu_embedding/sparsecore/lib/nn/BUILD index aa63e87..db0b124 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/BUILD +++ b/jax_tpu_embedding/sparsecore/lib/nn/BUILD @@ -35,14 +35,31 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "embedding_utils", + srcs = ["embedding_utils.py"], + visibility = ["//jax_tpu_embedding/sparsecore/lib/nn:__subpackages__"], + deps = [ + ":embedding_spec", + ":table_stacking", + pypi_requirement("absl/logging"), + pypi_requirement("jax"), + pypi_requirement("jax:experimental"), + pypi_requirement("numpy"), + pypi_requirement("tree"), + ], +) + pytype_strict_library( name = "embedding", srcs = ["embedding.py"], deps = [ ":embedding_spec", + ":embedding_utils", ":table_stacking", "//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_cc", "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr", + "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr_with_mini_batching", "//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2", pypi_requirement("absl/logging"), pypi_requirement("jax"), @@ -52,6 +69,27 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "embedding_with_mini_batching", + srcs = ["embedding_with_mini_batching.py"], + deps = [ + ":embedding", + ":embedding_spec", + ":embedding_utils", + ":table_stacking", + "//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_with_mini_batching_cc", + "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr_with_mini_batching", + "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_adagrad_with_mini_batching", + "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_sgd_with_mini_batching", + pypi_requirement("absl/logging"), + pypi_requirement("jax"), + pypi_requirement("jax:experimental"), + pypi_requirement("jax/extend"), + pypi_requirement("numpy"), + pypi_requirement("tree"), + ], +) + pytype_strict_library( name = "table_stacking", srcs = ["table_stacking.py"], diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py index 6acc533..75b7781 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py @@ -26,11 +26,13 @@ from jax_tpu_embedding.sparsecore.lib.core import input_preprocessing_cc from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_csr from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec +from jax_tpu_embedding.sparsecore.lib.nn import embedding_utils from jax_tpu_embedding.sparsecore.lib.nn import table_stacking from jax_tpu_embedding.sparsecore.lib.proto import embedding_spec_pb2 import numpy as np import tree + ArrayLike = jnp.ndarray | np.ndarray T: TypeAlias = TypeVar("T") @@ -272,16 +274,6 @@ def auto_stack_tables( ) -def sharding_strategy_to_int(sharding_strategy: str) -> int: - if sharding_strategy == "MOD": - return 1 - else: - raise ValueError( - f"Unsupported sharding strategy: {sharding_strategy}. Only MOD is" - " supported." 
- ) - - def preprocess_sparse_dense_matmul_input( features: Nested[ArrayLike], features_weights: Nested[ArrayLike], @@ -348,60 +340,13 @@ def preprocess_sparse_dense_matmul_input( local_device_count, global_device_count, num_sc_per_device, - sharding_strategy_to_int(sharding_strategy), + embedding_utils.sharding_strategy_to_int(sharding_strategy), has_leading_dimension, static_buffer_size_multiplier, allow_id_dropping=allow_id_dropping, ) -def _get_activation_for_feature( - feature: embedding_spec.FeatureSpec, - activations: dict[str, jax.Array], - global_device_count: int, -) -> jax.Array: - """Gets the activation slice for a given feature.""" - assert feature.table_spec.stacked_table_spec is not None - if feature.id_transformation is None: - raise ValueError( - "FeatureIdTransformation cannot be None. It is None for" - f" {feature.name}", - ) - per_device_offset = ( - feature.id_transformation.row_offset // global_device_count - ) - if feature.output_shape[-1] > feature.table_spec.embedding_dim: - raise ValueError( - f"Feature {feature.name} has output shape {feature.output_shape} and" - f" embedding dim {feature.table_spec.embedding_dim}. The output shape" - " must be at least same as the (original, unpadded)embedding dim." - ) - return jax.lax.slice( - activations[feature.table_spec.stacked_table_spec.stack_name], - (per_device_offset, 0), - ( - per_device_offset + feature.output_shape[0] // global_device_count, - feature.output_shape[-1], - ), - ) - - -def _unstack_embedding_activations( - activations: dict[str, jax.Array], - feature_specs: Nested[embedding_spec.FeatureSpec], - global_device_count: int, -) -> Nested[jax.Array]: - """Unstacks the activations to match the feature specs.""" - - get_activation_for = functools.partial( - _get_activation_for_feature, - activations=activations, - global_device_count=global_device_count, - ) - - return jax.tree_util.tree_map(get_activation_for, feature_specs) - - @jax.named_call def tpu_sparse_dense_matmul( lhs_row_pointers: Mapping[str, jax.Array], @@ -482,7 +427,7 @@ def tpu_sparse_dense_matmul( stacked_table_specs = get_stacked_table_specs(feature_specs) assert lhs_row_pointers.keys() == stacked_table_specs.keys() - sharding_strategy = _sharding_strategy_to_enum(sharding_strategy) + sharding_strategy = embedding_utils.sharding_strategy_to_enum(sharding_strategy) activations = {} for stacked_table_name in stacked_table_specs: @@ -507,61 +452,11 @@ def tpu_sparse_dense_matmul( ) ) - return _unstack_embedding_activations( + return embedding_utils.unstack_embedding_activations( activations, feature_specs, global_device_count ) -def _sharding_strategy_to_enum(sharding_strategy: str) -> int: - """Converts the sharding strategy string to the enum.""" - if sharding_strategy.upper() == "MOD": - return 1 - else: - raise ValueError( - f"Unsupported sharding strategy: {sharding_strategy}. Only MOD is" - " supported." - ) - - -def _stack_embedding_gradients( - activation_gradients: Nested[jax.Array], - feature_specs: Nested[embedding_spec.FeatureSpec], -) -> Mapping[str, jax.Array]: - """Stacks the gradients for update to embedding variables.""" - stacked_table_to_features = collections.defaultdict(list) - for gradient, feature in zip( - tree.flatten(activation_gradients), tree.flatten(feature_specs) - ): - assert feature.table_spec.stacked_table_spec is not None - if feature.id_transformation is None: - raise ValueError( - "FeatureIdTransformation cannot be None here. 
It is None for" - f" {feature.name}" - ) - stacked_table_to_features[ - feature.table_spec.stacked_table_spec.stack_name - ].append((feature, gradient)) - stacked_table_to_gradients = collections.defaultdict(list) - for stacked_table_name, stacked_features in stacked_table_to_features.items(): - stacked_features.sort(key=lambda x: x[0].id_transformation.row_offset) - for f, g in stacked_features: - # feature.table_spec.embedding_dim is the original table dim, before - # padding - gradient = g.reshape([-1, f.table_spec.embedding_dim]) - # Add padding for extra cols - extra_cols = ( - f.table_spec.setting_in_stack.padded_embedding_dim - - f.table_spec.embedding_dim - ) - if extra_cols != 0: - gradient = jax.lax.pad(gradient, 0.0, [(0, 0, 0), (0, extra_cols, 0)]) - stacked_table_to_gradients[stacked_table_name].append(gradient) - return { - t: jax.lax.concatenate(grads, dimension=0) - for t, grads in stacked_table_to_gradients.items() - } - - @jax.named_call def tpu_sparse_dense_matmul_grad( activation_gradients: Nested[jax.Array], @@ -643,10 +538,10 @@ def tpu_sparse_dense_matmul_grad( stacked_table_specs = get_stacked_table_specs(feature_specs) assert lhs_row_pointers.keys() == stacked_table_specs.keys() - gradients = _stack_embedding_gradients(activation_gradients, feature_specs) + gradients = embedding_utils.stack_embedding_gradients(activation_gradients, feature_specs) assert lhs_row_pointers.keys() == gradients.keys() - sharding_strategy = _sharding_strategy_to_enum(sharding_strategy) + sharding_strategy = embedding_utils.sharding_strategy_to_enum(sharding_strategy) updated_embedding_variables = {} for stacked_table_name in stacked_table_specs: diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding_utils.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding_utils.py new file mode 100644 index 0000000..8463024 --- /dev/null +++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding_utils.py @@ -0,0 +1,138 @@ +# Copyright 2024 The JAX SC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal utilities for embedding lookup and update.""" + +import collections +import functools +from typing import Mapping, Sequence, TypeAlias, TypeVar, Union + +import jax +import jax.numpy as jnp +from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec +import numpy as np +import tree + +ArrayLike = jnp.ndarray | np.ndarray + +T: TypeAlias = TypeVar("T") +Nested: TypeAlias = Union[T, Sequence[T], Mapping[str, T]] + + +def sharding_strategy_to_int(sharding_strategy: str) -> int: + if sharding_strategy == "MOD": + return 1 + else: + raise ValueError( + f"Unsupported sharding strategy: {sharding_strategy}. Only MOD is" + " supported." 
+ ) + + +def _get_activation_for_feature( + feature: embedding_spec.FeatureSpec, + activations: dict[str, jax.Array], + global_device_count: int, +) -> jax.Array: + """Gets the activation slice for a given feature.""" + assert feature.table_spec.stacked_table_spec is not None + if feature.id_transformation is None: + raise ValueError( + "FeatureIdTransformation cannot be None. It is None for" + f" {feature.name}", + ) + per_device_offset = ( + feature.id_transformation.row_offset // global_device_count + ) + if feature.output_shape[-1] > feature.table_spec.embedding_dim: + raise ValueError( + f"Feature {feature.name} has output shape {feature.output_shape} and" + f" embedding dim {feature.table_spec.embedding_dim}. The output shape" + " must be at least same as the (original, unpadded)embedding dim." + ) + return jax.lax.slice( + activations[feature.table_spec.stacked_table_spec.stack_name], + (per_device_offset, 0), + ( + per_device_offset + feature.output_shape[0] // global_device_count, + feature.output_shape[-1], + ), + ) + + +def unstack_embedding_activations( + activations: dict[str, jax.Array], + feature_specs: Nested[embedding_spec.FeatureSpec], + global_device_count: int, +) -> Nested[jax.Array]: + """Unstacks the activations to match the feature specs.""" + + get_activation_for = functools.partial( + _get_activation_for_feature, + activations=activations, + global_device_count=global_device_count, + ) + + return jax.tree_util.tree_map(get_activation_for, feature_specs) + + +def sharding_strategy_to_enum(sharding_strategy: str) -> int: + """Converts the sharding strategy string to the enum.""" + if sharding_strategy.upper() == "MOD": + return 1 + else: + raise ValueError( + f"Unsupported sharding strategy: {sharding_strategy}. Only MOD is" + " supported." + ) + + +def stack_embedding_gradients( + activation_gradients: Nested[jax.Array], + feature_specs: Nested[embedding_spec.FeatureSpec], +) -> Mapping[str, jax.Array]: + """Stacks the gradients for update to embedding variables.""" + stacked_table_to_features = collections.defaultdict(list) + for gradient, feature in zip( + tree.flatten(activation_gradients), tree.flatten(feature_specs) + ): + assert feature.table_spec.stacked_table_spec is not None + if feature.id_transformation is None: + raise ValueError( + "FeatureIdTransformation cannot be None here. 
It is None for" + f" {feature.name}" + ) + stacked_table_to_features[ + feature.table_spec.stacked_table_spec.stack_name + ].append((feature, gradient)) + stacked_table_to_gradients = collections.defaultdict(list) + for stacked_table_name, stacked_features in stacked_table_to_features.items(): + stacked_features.sort(key=lambda x: x[0].id_transformation.row_offset) + for f, g in stacked_features: + # feature.table_spec.embedding_dim is the original table dim, before + # padding + gradient = g.reshape([-1, f.table_spec.embedding_dim]) + # Add padding for extra cols + extra_cols = ( + f.table_spec.setting_in_stack.padded_embedding_dim + - f.table_spec.embedding_dim + ) + if extra_cols != 0: + gradient = jax.lax.pad(gradient, 0.0, [(0, 0, 0), (0, extra_cols, 0)]) + stacked_table_to_gradients[stacked_table_name].append(gradient) + return { + t: jax.lax.concatenate(grads, dimension=0) + for t, grads in stacked_table_to_gradients.items() + } + + diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding_with_mini_batching.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding_with_mini_batching.py new file mode 100644 index 0000000..671c48d --- /dev/null +++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding_with_mini_batching.py @@ -0,0 +1,485 @@ +# Copyright 2024 The JAX SC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""List of functions for embedding lookup.""" + +from typing import Any, List, Mapping, Sequence, TypeAlias, TypeVar, Union + +from absl import logging +import jax +import jax.extend as jex +import jax.numpy as jnp +from jax_tpu_embedding.sparsecore.lib.core import input_preprocessing_with_mini_batching_cc +from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_csr_with_mini_batching +from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_grad_with_adagrad_with_mini_batching +from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_grad_with_sgd_with_mini_batching +from jax_tpu_embedding.sparsecore.lib.nn import embedding +from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec +from jax_tpu_embedding.sparsecore.lib.nn import embedding_utils +import numpy as np +import tree + + +ArrayLike = jnp.ndarray | np.ndarray + +T: TypeAlias = TypeVar("T") +Nested: TypeAlias = Union[T, Sequence[T], Mapping[str, T]] + +# Necessary for all configurations and all operations. +CONFIG_MODE = "MINI_BATCH_MODE" + +# Necessary for all configurations and all operations. Set to 1 in MODE_NONE. +CONFIG_SIZE = "MINI_BATCH_SIZE" + +# Supported modes of preprocessing. +# The definitions must be aligned with the ones in +# input_preprocessing_with_mini_batching.h. +MODE_NONE = 1 +MODE_VOCABULARY_DIMENSION = 2 +MODE_SAMPLE_DIMENSION = 3 +MODE_EXPERIMENTAL_FORCE_VOCABULARY_DIV = 200 +MODE_EXPERIMENTAL_FORCE_VOCABULARY_MOD = 201 + + +# SGD optimizer supporting mini-batching. +class SGDOptimizerSpec(embedding_spec.SGDOptimizerSpec): + """Spec for the Stochastic Gradient Descent (SGD) optimizer. 
+
+  An iterative optimization method that updates the weights of the embedding
+  variables by taking a step in the direction of the gradient. The step size
+  is controlled by the learning rate. SGD is usually the default choice in a
+  training setup.
+
+  Attributes:
+    learning_rate: The learning rate for the training variables or embeddings.
+  """
+
+  def get_optimizer_primitive(self) -> jex.core.Primitive:
+    """Returns the optimizer primitive for the SGD optimizer."""
+    return (
+        sparse_dense_matmul_grad_with_sgd_with_mini_batching.
+        tpu_sparse_dense_matmul_grad_with_sgd_with_mini_batching_primitive
+    )
+
+
+# Adagrad optimizer supporting mini-batching.
+class AdagradOptimizerSpec(embedding_spec.AdagradOptimizerSpec):
+  """Spec for the Adagrad optimizer.
+
+  An Adagrad optimizer is an adaptive optimizer that adjusts the learning rate
+  for each embedding variable based on its past gradients. This helps in
+  reducing the number of steps needed for convergence, especially for sparse
+  data.
+
+  Attributes:
+    learning_rate: The learning rate for the training variables or embeddings.
+    initial_accumulator_value: The initial value for the accumulator slot
+      variable. This constant is used to initialize the accumulator slot
+      variable.
+  """
+
+  def get_optimizer_primitive(self) -> jex.core.Primitive:
+    """Returns the optimizer primitive for the Adagrad optimizer."""
+    return (
+        sparse_dense_matmul_grad_with_adagrad_with_mini_batching.
+        tpu_sparse_dense_matmul_grad_with_adagrad_with_mini_batching_primitive
+    )
+
+
+def preprocess_sparse_dense_matmul_input(
+    local_batch_size: int,
+    indices: Sequence[Sequence[Sequence[int]]],
+    values: Sequence[Sequence[int]],
+    weights: Sequence[Sequence[float]] | Sequence[None],
+    feature_specs: List[embedding_spec.FeatureSpec],
+    mini_batching_config: Mapping[str, Any],
+    local_device_count: int,
+    global_device_count: int,
+    static_buffer_size_multiplier: int = 0,
+    num_sc_per_device: int = 4,
+    sparsecore_register_width: int = 8,
+    sharding_strategy: str = "MOD",
+    has_leading_dimension: bool = False,
+) -> tuple[
+    Mapping[str, np.ndarray],
+    Mapping[str, np.ndarray],
+    Mapping[str, np.ndarray],
+    Mapping[str, np.ndarray],
+    int,
+    Mapping[str, np.ndarray],
+]:
+  """Preprocesses the input for sparse dense matmul.
+
+  Args:
+    local_batch_size: The number of samples in this batch. This is called a
+      'local' batch because it is the combined batch size for all local
+      devices.
+    indices: The indices into the values and weights. The first dimension is
+      over the features, the second over the samples. All elements are
+      expected to be 64-bit integers.
+    values: The values to process. The outer list is over the features. All
+      elements are expected to be 32-bit integers.
+    weights: The weights associated with the values. The outer list is over
+      the features. All elements are expected to be 32-bit floats. If the
+      weights are None for some or all features, the computation assumes those
+      weights are 1.0.
+    feature_specs: The feature specs. The order of this list must be aligned
+      with the order of the indices, values, and weights lists.
+    mini_batching_config: The mini-batching config, a dictionary containing
+      the mini-batching mode and the mini-batch size. More configuration items
+      will be added in the future.
+    local_device_count: The number of local devices (chips). Typically
+      `mesh.local_mesh.size`.
+    global_device_count: The number of global devices (chips). Typically
+      `mesh.size`.
+    static_buffer_size_multiplier: If larger than 0, this is the multiplier
+      used to determine the size of the static buffers (lhs_embedding_ids,
+      lhs_sample_ids and lhs_gains). The size of the returned buffer is
+      static_buffer_size_multiplier x batch_size. If less than or equal to 0,
+      the size of the buffer is determined based on the max_ids_per_partition
+      limits.
+    num_sc_per_device: The number of sparse cores per device.
+    sparsecore_register_width: The width of the sparsecore registers. This is
+      hardware dependent and may change across TPU generations.
+    sharding_strategy: The sharding strategy (e.g., MOD).
+    has_leading_dimension: If set to True, the first dimension of the output
+      is the number of local devices, which is useful when feeding the output
+      to jax.pmap. If set to False, the first dimension of the output is the
+      number of local devices * the static buffer size, which is useful when
+      feeding the output to jax.jit. In short, set it to True when using
+      jax.pmap and to False when using jax.jit. Currently, only jax.pmap is
+      supported.
+
+  Returns:
+    A tuple of four dictionaries mapping the stacked table names to the
+    preprocessed inputs for the corresponding table. The four dictionaries are
+    lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids, and lhs_gains. The
+    tuple also contains the resulting mini-batch size and a dictionary of
+    statistics. The mini-batch size may differ from the one in the input
+    mini-batching config and should be used as the actual mini-batch size in
+    later lookup and update calls. The statistics dictionary contains the max
+    ids per table, the max unique ids per table, and the id drop counters per
+    table.
+  """
+
+  return (
+      input_preprocessing_with_mini_batching_cc.PreprocessSparseDenseMatmulInputWithBCOO(
+          local_batch_size,
+          indices,
+          values,
+          weights,
+          feature_specs,
+          mini_batching_config,
+          local_device_count,
+          global_device_count,
+          static_buffer_size_multiplier,
+          num_sc_per_device,
+          sparsecore_register_width,
+          embedding_utils.sharding_strategy_to_int(sharding_strategy),
+          has_leading_dimension,
+      )
+  )
+
+
+def flatten_features_and_weights(
+    features: Mapping[str, ArrayLike],
+    weights: Mapping[str, ArrayLike],
+    flatten_feature_specs: Sequence[embedding_spec.FeatureSpec],
+) -> tuple[
+    int,
+    Sequence[Sequence[Sequence[int]]],
+    Sequence[Sequence[int]],
+    Sequence[Sequence[float]],
+]:
+  """Transforms features and weights from numpy arrays to sparse BCOO format.
+
+  This function transforms the features and weights from numpy arrays to the
+  sparse BCOO format, which stores the features and weights in a way that is
+  efficient for the sparsecore API. The returned tuple is suitable for the
+  sparse dense matmul API. Note that this function is not performant and
+  should only be used for testing purposes; in practice, the input features
+  and weights are expected to already be in the sparse BCOO format.
+
+  Args:
+    features: The features to process. The keys are the feature names.
+    weights: The weights associated with the values. The keys are the feature
+      names. Weights can be None for some or all features; in this case, the
+      corresponding flattened weights are also None.
+    flatten_feature_specs: The feature specs. The resulting flattened indices
+      and values are aligned with the feature order of this list.
+ + Returns: + A tuple containing the local batch size, the flattened indices, the + flattened values, and the flattened weights, to be fed to + preprocess_sparse_dense_matmul_input. + """ + local_batch_size = 0 + flatten_indices = [] + flatten_values = [] + flatten_weights = [] + + assert flatten_feature_specs, "Feature specs must not be empty." + assert features, "Features must not be empty." + assert weights, "Weights must not be empty." + + assert len(features) == len(weights), ( + "Features and weights must have the same length." + ) + assert len(features) == len(flatten_feature_specs), ( + "Features and feature specs must have the same length." + ) + + for feature_spec in flatten_feature_specs: + feature_name = feature_spec.name + current_feature = features[feature_name] + current_weights = weights[feature_name] + if local_batch_size == 0: + local_batch_size = current_feature.shape[0] + assert local_batch_size > 0, "Batch size must be greater than 0." + else: + assert ( + local_batch_size == current_feature.shape[0] + ), "Batch size must be the same for all features." + + # Create the indices array to point to all values and weights. + index_size = 0 + for row in current_feature: + index_size += len(row) + indices = np.empty((index_size, 2), dtype=np.int64) + + # Create the values array to store all the values (embedding ids). + concatenated_values = np.empty(index_size, dtype=np.int32) + + # Optionally create the weights array to store all the weights. + concatenated_weights = None + if current_weights is not None: + concatenated_weights = np.empty(index_size, dtype=np.float32) + + # Populate the indices, values, and optionally the weights arrays. + index_cursor = 0 + for sample_index in range(local_batch_size): + sample_length = len(current_feature[sample_index]) + for elem_index in range(sample_length): + indices[index_cursor] = [sample_index, elem_index] + concatenated_values[index_cursor] = current_feature[sample_index][ + elem_index + ] + if concatenated_weights is not None: + concatenated_weights[index_cursor] = current_weights[sample_index][ + elem_index + ] + index_cursor += 1 + + flatten_indices.append(indices) + flatten_values.append(concatenated_values) + flatten_weights.append(concatenated_weights) + + return (local_batch_size, flatten_indices, flatten_values, flatten_weights) + + +@jax.named_call +def tpu_sparse_dense_matmul( + lhs_row_pointers: Mapping[str, jax.Array], + lhs_embedding_ids: Mapping[str, jax.Array], + lhs_sample_ids: Mapping[str, jax.Array], + lhs_gains: Mapping[str, jax.Array], + embedding_variables: Mapping[str, embedding.EmbeddingVariables], + feature_specs: Nested[embedding_spec.FeatureSpec], + mini_batching_config: Mapping[str, Any], + global_device_count: int, + sharding_strategy: str = "MOD", +) -> Nested[jax.Array]: + """Computes the sparse dense matmul, or embedding lookup. + + Check the docstring of `tpu_sparse_dense_matmul` in embedding.py for + more details. + + Args: + lhs_row_pointers: The row pointers to process. The keys are the stacked + table names. + lhs_embedding_ids: The embedding ids to process. The keys are the stacked + table names. Must have same structure as `lhs_row_pointers`. + lhs_sample_ids: The sample ids to process. The keys are the stacked table + names. Must have same structure as `lhs_row_pointers`. + lhs_gains: The gains to process. The keys are the stacked table names. Must + have same structure as `lhs_row_pointers`. + embedding_variables: A tuple of embedding tables and slot variables. 
The + first one is always the embedding table, the following ones are slot + variables. The tree structure must be identical to the lhs_row_pointers. + feature_specs: The input features for the current process. + mini_batching_config: The mini-batching config. Note that the mini-batch + size in the config should come from the result of each preprocess call. + global_device_count: The number of global devices (chips). Typically + `mesh.size`. + sharding_strategy: The sharding strategy (e.g., MOD) + + Returns: + The activations structure with the same structure as feature_specs. + + Raises: + ValueError: The input arrays and tuples are not of the expected structure or + the sharding strategy is not supported. + """ + assert lhs_row_pointers.keys() == lhs_embedding_ids.keys() + assert lhs_row_pointers.keys() == lhs_gains.keys() + assert lhs_row_pointers.keys() == lhs_sample_ids.keys() + assert lhs_row_pointers.keys() == embedding_variables.keys() + + stacked_table_specs = embedding.get_stacked_table_specs(feature_specs) + assert lhs_row_pointers.keys() == stacked_table_specs.keys() + + sharding_strategy = embedding_utils.sharding_strategy_to_enum( + sharding_strategy + ) + + activations = {} + for stacked_table_name in stacked_table_specs: + row_pointer = lhs_row_pointers[stacked_table_name] + embedding_id = lhs_embedding_ids[stacked_table_name] + sample_id = lhs_sample_ids[stacked_table_name] + gain = lhs_gains[stacked_table_name] + embedding_variable = embedding_variables[stacked_table_name] + stacked_table = stacked_table_specs[stacked_table_name] + + if mini_batching_config[CONFIG_MODE] in ( + MODE_VOCABULARY_DIMENSION, + MODE_EXPERIMENTAL_FORCE_VOCABULARY_DIV, + MODE_EXPERIMENTAL_FORCE_VOCABULARY_MOD, + ): + activations[stacked_table.stack_name] = ( + sparse_dense_matmul_csr_with_mini_batching.tpu_sparse_dense_matmul_csr_with_mini_batching_primitive.bind( + row_pointer, + embedding_id, + sample_id, + gain, + mini_batching_config[CONFIG_SIZE], + embedding_variable[0], # [0] is the embedding table + device_batch_size=(stacked_table.total_sample_count + // global_device_count), + max_ids_per_partition=stacked_table.max_ids_per_partition, + max_unique_ids_per_partition=stacked_table.max_unique_ids_per_partition, + sharding_strategy=sharding_strategy, + ) + ) + else: + raise ValueError( + f"Unsupported mini-batching mode: {mini_batching_config[CONFIG_MODE]}" + ) + + return embedding_utils.unstack_embedding_activations( + activations, feature_specs, global_device_count + ) + + +@jax.named_call +def tpu_sparse_dense_matmul_grad( + activation_gradients: Nested[jax.Array], + lhs_row_pointers: Mapping[str, jax.Array], + lhs_embedding_ids: Mapping[str, jax.Array], + lhs_sample_ids: Mapping[str, jax.Array], + lhs_gains: Mapping[str, jax.Array], + embedding_variables: Mapping[str, embedding.EmbeddingVariables], + feature_specs: Nested[embedding_spec.FeatureSpec], + mini_batching_config: Mapping[str, Any], + sharding_strategy: str = "MOD", + label: str = "", +) -> Mapping[str, embedding.EmbeddingVariables]: + """Computes the updated embedding variables based on the activation gradients. + + Check the docstring of `tpu_sparse_dense_matmul_grad` in embedding.py for + more details. + + Args: + activation_gradients: The activation gradients. + lhs_row_pointers: The row pointers to process. The keys are the stacked + table names. + lhs_embedding_ids: The embedding ids to process. The keys are the stacked + table names. Must have same structure as `lhs_row_pointers`. 
+ lhs_sample_ids: The sample ids to process. The keys are the stacked table + names. Must have same structure as `lhs_row_pointers`. + lhs_gains: The gains to process. The keys are the stacked table names. Must + have same structure as `lhs_row_pointers`. + embedding_variables: A tuple of embedding tables and slot variables. The + first one is always the embedding table, the following ones are slot + variables. The tree structure must be identical to the lhs_row_pointers. + feature_specs: The input features for the current process. + mini_batching_config: The mini-batching config. Note that the mini-batch + size in the config should come from the result of each preprocess call. + sharding_strategy: The sharding strategy (e.g., MOD) + label: The label for the optimizer computation. + + Returns: + The updated activation embedding variables. + """ + + # Verify the input structures and lengths. + assert lhs_row_pointers.keys() == lhs_embedding_ids.keys() + assert lhs_row_pointers.keys() == lhs_gains.keys() + assert lhs_row_pointers.keys() == lhs_sample_ids.keys() + assert lhs_row_pointers.keys() == embedding_variables.keys() + # Activations match the feature specs structure + tree.assert_same_structure(feature_specs, activation_gradients) + + stacked_table_specs = embedding.get_stacked_table_specs(feature_specs) + assert lhs_row_pointers.keys() == stacked_table_specs.keys() + + gradients = embedding_utils.stack_embedding_gradients( + activation_gradients, feature_specs + ) + assert lhs_row_pointers.keys() == gradients.keys() + + sharding_strategy = embedding_utils.sharding_strategy_to_enum( + sharding_strategy + ) + + updated_embedding_variables = {} + for stacked_table_name in stacked_table_specs: + row_pointer = lhs_row_pointers[stacked_table_name] + embedding_id = lhs_embedding_ids[stacked_table_name] + sample_id = lhs_sample_ids[stacked_table_name] + gain = lhs_gains[stacked_table_name] + embedding_variable = embedding_variables[stacked_table_name] + activation_gradient = gradients[stacked_table_name] + stack_table_spec = stacked_table_specs[stacked_table_name] + learning_rate = stack_table_spec.optimizer.get_learning_rate() + hyper_params = [learning_rate] + # The MLIR computation symbol names need to be different. We attach the + # table name to the symbol name to ensure that. 
+ symbol_name = "{}-{}{}".format( + stack_table_spec.optimizer.short_name(), + stack_table_spec.stack_name, + label, + ) + optimizer_primitive = stack_table_spec.optimizer.get_optimizer_primitive() + + flatten_variables, treedef = jax.tree.flatten(embedding_variable) + updated_variables = optimizer_primitive.bind( + row_pointer, + embedding_id, + sample_id, + gain, + mini_batching_config[CONFIG_SIZE], + *flatten_variables, + activation_gradient, + *hyper_params, + max_ids_per_partition=stack_table_spec.max_ids_per_partition, + max_unique_ids_per_partition=stack_table_spec.max_unique_ids_per_partition, + computation_name=symbol_name, + sharding_strategy=sharding_strategy, + ) + + updated_embedding_variables[stacked_table_name] = jax.tree.unflatten( + treedef, + jax.tree.leaves(updated_variables), + ) + + return updated_embedding_variables diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/test_utils.py b/jax_tpu_embedding/sparsecore/lib/nn/tests/test_utils.py index ccd41e2..61a18d4 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/tests/test_utils.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/test_utils.py @@ -90,6 +90,36 @@ def init(key, shape) -> jax.Array: return init +def row_id_with_offset_initializer_value( + offset_value: int, row: int +) -> jax.numpy.float32: + """Returns the value for row_col_id_initializer.""" + return offset_value + row + + +def row_id_with_offset_initializer( + offset_value: int = 0, +) -> jax.nn.initializers.Initializer: + """Initializes a table with offset value + row id.""" + + def create_array(offset_value, shape): + rows, cols = shape + result = jax.numpy.zeros(shape, dtype=jax.numpy.float32) + for i in range(rows): + for j in range(cols): + result = result.at[i, j].set( + row_id_with_offset_initializer_value(offset_value, i) + ) + + return result + + def init(key, shape) -> jax.Array: + del key + return create_array(offset_value, shape) + + return init + + def rotate_sharded_table( embedding_table: jax.Array, rotation: int ) -> jax.Array:
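For reviewers who want to see how the new pieces fit together, the sketch below shows a minimal use of the mini-batching path added in this change. It is illustrative only: `feature_specs`, `features`, `feature_weights`, and `embedding_variables` are hypothetical objects assumed to be built elsewhere (e.g. via the existing table-stacking utilities), the requested mini-batch size of 4 is arbitrary, the `emb_mb` alias is mine, and the lookup/update calls would normally run inside jax.pmap, which is the only supported layout for now.

import jax
import jax.numpy as jnp
from jax_tpu_embedding.sparsecore.lib.nn import embedding_with_mini_batching as emb_mb

# Requested mini-batching configuration; the preprocessor may return a
# different (actual) mini-batch size.
mini_batching_config = {
    emb_mb.CONFIG_MODE: emb_mb.MODE_EXPERIMENTAL_FORCE_VOCABULARY_DIV,
    emb_mb.CONFIG_SIZE: 4,
}

# Test-only helper: flatten per-feature dense arrays into BCOO-style
# (indices, values, weights) lists, aligned with the flattened feature specs.
flat_specs = jax.tree.leaves(feature_specs)
local_batch_size, indices, values, weights = emb_mb.flatten_features_and_weights(
    features, feature_weights, flat_specs
)

# Host-side preprocessing. Feed the returned mini_batch_size back into the
# config before the device-side calls.
(row_pointers, embedding_ids, sample_ids, gains, mini_batch_size, stats) = (
    emb_mb.preprocess_sparse_dense_matmul_input(
        local_batch_size,
        indices,
        values,
        weights,
        flat_specs,
        mini_batching_config,
        local_device_count=jax.local_device_count(),
        global_device_count=jax.device_count(),
        has_leading_dimension=True,  # pmap layout; jit is not supported yet.
    )
)
mini_batching_config[emb_mb.CONFIG_SIZE] = mini_batch_size

# Device-side lookup and update (normally wrapped in jax.pmap).
activations = emb_mb.tpu_sparse_dense_matmul(
    row_pointers, embedding_ids, sample_ids, gains,
    embedding_variables, feature_specs, mini_batching_config,
    global_device_count=jax.device_count(),
)
gradients = jax.tree.map(jnp.zeros_like, activations)  # Placeholder gradients.
updated_variables = emb_mb.tpu_sparse_dense_matmul_grad(
    gradients, row_pointers, embedding_ids, sample_ids, gains,
    embedding_variables, feature_specs, mini_batching_config,
)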