Skip to content

Commit

Permalink
Internal change only
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698218778
  • Loading branch information
SiqiaoWu1993 authored and tensorflower-gardener committed Nov 20, 2024
1 parent 43227b2 commit 79dee4c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/tfrt/ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ cc_library(
":ifrt_restore_tensor_registry",
"//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime:function",
"//tensorflow/core/framework:attr_value_proto_cc",
"//tensorflow/core/framework:node_def_util",
Expand Down
13 changes: 8 additions & 5 deletions tensorflow/core/tfrt/ifrt/checkpoint_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
Expand Down Expand Up @@ -274,12 +275,14 @@ absl::Status RunShard(RestoreVariableShard shard,
if (!use_async_restore) {
RunShardHelper(runner, async_state.get(), shard);
} else {
tensorflow::Context bg_context(tensorflow::ContextKind::kThread);
// Use dedicated work queue for restore operation.
checkpoint_loader_work_queue->AddTask([runner = std::move(runner),
async_state = std::move(async_state),
shard = std::move(shard)]() {
RunShardHelper(runner, async_state.get(), shard);
});
checkpoint_loader_work_queue->AddTask(
[runner = std::move(runner), async_state = std::move(async_state),
shard = std::move(shard), bg_context = std::move(bg_context)]() {
tensorflow::WithContext wc(bg_context);
RunShardHelper(runner, async_state.get(), shard);
});
}

return absl::OkStatus();
Expand Down

0 comments on commit 79dee4c

Please sign in to comment.