Input preprocessing library to support vocab-dimension mini-batching.
Currently only PMAP is supported for simplicity. JAX support will be added later.

PiperOrigin-RevId: 716024250
Google-ML-Automation committed Jan 18, 2025
1 parent b91e774 commit f7bc071
Showing 14 changed files with 2,099 additions and 161 deletions.
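For orientation, vocab-dimension mini-batching splits a stacked table's embedding-ID (vocabulary) range into contiguous buckets so that each bucket can be processed as its own mini-batch. The sketch below illustrates the bucketing arithmetic under the assumption of equal-sized contiguous ranges; it is commentary only, not code from this commit (the actual logic lives in the new input_preprocessing_with_mini_batching.cc, whose body is not shown on this page).

// Illustrative sketch, not part of this commit: map an embedding row id to a
// vocab-dimension mini-batch bucket, assuming the stacked table's vocabulary
// is split into equal, contiguous ranges.
inline int VocabMiniBatchBucket(int row_id, int stack_vocab_size,
                                int num_mini_batches) {
  // Ceiling division so the last bucket absorbs any remainder.
  const int rows_per_bucket =
      (stack_vocab_size + num_mini_batches - 1) / num_mini_batches;
  return row_id / rows_per_bucket;
}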
52 changes: 50 additions & 2 deletions jax_tpu_embedding/sparsecore/lib/core/BUILD
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension")
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension", "pybind_library")
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library")

@@ -72,17 +72,65 @@ cc_test(
],
)

pybind_library(
name = "input_preprocessing_py_util",
srcs = [
"input_preprocessing_py_util.cc",
],
hdrs = [
"input_preprocessing_py_util.h",
],
deps = [
":input_preprocessing_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@tsl//tsl/profiler/lib:traceme",
],
)

pybind_extension(
name = "input_preprocessing_cc",
srcs = ["input_preprocessing.cc"],
srcs = [
"input_preprocessing.cc",
],
deps = [
":input_preprocessing_py_util",
":input_preprocessing_threads",
":input_preprocessing_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
"@tsl//tsl/profiler/lib:connected_traceme",
"@tsl//tsl/profiler/lib:traceme",
],
)

pybind_extension(
name = "input_preprocessing_with_mini_batching_cc",
srcs = [
"input_preprocessing_with_mini_batching.cc",
"input_preprocessing_with_mini_batching.h",
],
deps = [
":input_preprocessing_py_util",
":input_preprocessing_threads",
":input_preprocessing_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
"@tsl//tsl/profiler/lib:connected_traceme",
"@tsl//tsl/profiler/lib:traceme",
],
)
44 changes: 1 addition & 43 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <optional>
#include <string>
#include <utility>
@@ -24,6 +23,7 @@
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/synchronization/blocking_counter.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h"
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "pybind11/cast.h" // from @pybind11
@@ -148,48 +148,6 @@ int ExtractCooTensors(const py::array& features,
global_device_count, coo_tensors);
}

absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(py::list feature_specs, py::list features) {
tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
stacked_table_metadata;
for (int i = 0; i < feature_specs.size(); ++i) {
const py::object& feature_spec = feature_specs[i];
const py::array& feature = features[i].cast<py::array>();
const py::object& feature_transformation =
feature_spec.attr("_id_transformation");
const py::object& table_spec = feature_spec.attr("table_spec");
const py::object& stacked_table_spec =
table_spec.attr("stacked_table_spec");
const std::string stacked_table_name = py::cast<std::string>(
table_spec.attr("_setting_in_stack").attr("stack_name"));
int col_shift = 0;
int col_offset = 0;
int row_offset = 0;
const int max_ids_per_partition =
py::cast<int>(stacked_table_spec.attr("max_ids_per_partition"));
const int max_unique_ids_per_partition =
py::cast<int>(stacked_table_spec.attr("max_unique_ids_per_partition"));
if (!feature_transformation.is_none()) {
row_offset = py::cast<int>(feature_transformation.attr("row_offset"));
col_shift = py::cast<int>(feature_transformation.attr("col_shift"));
col_offset = py::cast<int>(feature_transformation.attr("col_offset"));
}
stacked_table_metadata[stacked_table_name].emplace_back(
i, max_ids_per_partition, max_unique_ids_per_partition, row_offset,
col_offset, col_shift,
/*batch_size=*/feature.shape(0));
}
// Sort the stacked tables by row_offset.
for (auto& [_, t] : stacked_table_metadata) {
std::sort(t.begin(), t.end(),
[](const StackedTableMetadata& a, const StackedTableMetadata& b) {
return a.row_offset < b.row_offset;
});
}
return stacked_table_metadata;
}

// Preprocess inputs for a single table. Stacked table here refers to
// a table that has no parent in the table stacking hierarchy. So in the case
// of table stacking, the stacked table is the top-level table, and in the case
// of no stacking, the stacked table is the table itself.
88 changes: 88 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.cc
@@ -0,0 +1,88 @@
// Copyright 2024 The JAX SC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h"

#include <algorithm>
#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/log/check.h" // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "pybind11/cast.h" // from @pybind11
#include "pybind11/gil.h" // from @pybind11
#include "pybind11/numpy.h" // from @pybind11
#include "pybind11/pybind11.h" // from @pybind11
#include "pybind11/pytypes.h" // from @pybind11
#include "tsl/profiler/lib/traceme.h" // from @tsl

namespace jax_sc_embedding {

namespace py = ::pybind11;

absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs, const int batch_size) {
tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
stacked_table_metadata;
for (int i = 0; i < feature_specs.size(); ++i) {
const py::object& feature_spec = feature_specs[i];

const py::object& feature_transformation =
feature_spec.attr("_id_transformation");
const py::object& table_spec = feature_spec.attr("table_spec");
const py::object& stacked_table_spec =
table_spec.attr("stacked_table_spec");
const std::string stacked_table_name = py::cast<std::string>(
table_spec.attr("_setting_in_stack").attr("stack_name"));
int col_shift = 0;
int col_offset = 0;
int row_offset = 0;
const int max_ids_per_partition =
py::cast<int>(stacked_table_spec.attr("max_ids_per_partition"));
const int max_unique_ids_per_partition =
py::cast<int>(stacked_table_spec.attr("max_unique_ids_per_partition"));
const int vocab_size =
py::cast<int>(stacked_table_spec.attr("stack_vocab_size"));
if (!feature_transformation.is_none()) {
row_offset = py::cast<int>(feature_transformation.attr("row_offset"));
col_shift = py::cast<int>(feature_transformation.attr("col_shift"));
col_offset = py::cast<int>(feature_transformation.attr("col_offset"));
}
stacked_table_metadata[stacked_table_name].emplace_back(
i, max_ids_per_partition, max_unique_ids_per_partition, row_offset,
col_offset, col_shift,
/*batch_size=*/batch_size, vocab_size);
}
// Sort the stacked tables by row_offset.
for (auto& [_, t] : stacked_table_metadata) {
std::sort(t.begin(), t.end(),
[](const StackedTableMetadata& a, const StackedTableMetadata& b) {
return a.row_offset < b.row_offset;
});
}
return stacked_table_metadata;
}

absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs,
const py::list& features) {
tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
int batch_size = features[0].cast<py::array>().shape(0);
return GetStackedTableMetadata(feature_specs, batch_size);
}

} // namespace jax_sc_embedding
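A minimal sketch of how the returned metadata map might be consumed, assuming a caller that already holds the Python feature_specs list and has absl logging available; the variable names and numeric values are illustrative, not part of this commit. Per the sort above, the features within each stacked table come back ordered by row_offset.

// Illustrative usage sketch (not part of this commit).
const auto stacked_tables =
    GetStackedTableMetadata(feature_specs, /*batch_size=*/128);
for (const auto& [stack_name, features] : stacked_tables) {
  for (const StackedTableMetadata& m : features) {
    // Features within a stack are sorted by row_offset (see above).
    LOG(INFO) << stack_name << ": feature " << m.feature_index
              << " row_offset=" << m.row_offset
              << " vocab_size=" << m.stacked_table_vocab_size;
  }
}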
40 changes: 40 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h
@@ -0,0 +1,40 @@
// Copyright 2024 The JAX SC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_
#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "pybind11/numpy.h" // from @pybind11
#include "pybind11/pytypes.h" // from @pybind11

namespace jax_sc_embedding {

namespace py = ::pybind11;

// Copy information from feature_specs to StackedTableMetadata.
// The features argument is only used to get the batch size.
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs,
const py::list& features);

// Copy information from feature_specs to StackedTableMetadata.
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs, int batch_size);

} // namespace jax_sc_embedding

#endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_
10 changes: 7 additions & 3 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h
@@ -35,7 +35,7 @@ struct CooFormat {

// Get adjusted col_id based on shift and offset.
int GetColId(int col_id, int col_shift, int col_offset, int num_scs_mod,
int num_scs_mod_inv);
int num_scs_mod_inv);

inline unsigned int CeilOfRatio(unsigned int numerator,
unsigned int denominator) {
@@ -50,14 +50,16 @@ struct StackedTableMetadata {
StackedTableMetadata() = delete;
StackedTableMetadata(int feature_index, int max_ids_per_partition,
int max_unique_ids_per_partition, int row_offset,
int col_offset, int col_shift, int batch_size)
int col_offset, int col_shift, int batch_size,
int stacked_table_vocab_size = 0)
: feature_index(feature_index),
max_ids_per_partition(max_ids_per_partition),
max_unique_ids_per_partition(max_unique_ids_per_partition),
row_offset(row_offset),
col_offset(col_offset),
col_shift(col_shift),
batch_size(batch_size) {}
batch_size(batch_size),
stacked_table_vocab_size(stacked_table_vocab_size) {}
// The batch is given as a list of features (numpy arrays). `feature_index`
// represents the index of the feature in the list.
int feature_index;
@@ -70,6 +72,8 @@

// Process local batch size of the feature.
int batch_size;

int stacked_table_vocab_size;
};
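// Illustrative example (commentary, not part of this commit): constructing
// metadata for one feature of a stacked table with the new
// stacked_table_vocab_size argument. All numeric values below are made up.
StackedTableMetadata example_metadata(
    /*feature_index=*/0,
    /*max_ids_per_partition=*/256,
    /*max_unique_ids_per_partition=*/256,
    /*row_offset=*/0,
    /*col_offset=*/0,
    /*col_shift=*/0,
    /*batch_size=*/128,
    /*stacked_table_vocab_size=*/32768);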

void SortAndGroupCooTensors(
(The remaining changed files are not shown in this view.)
