
Commit 712749c

Merge pull request #2029 from DARMA-tasking/2028-setcontext-doesnt-suspend-correct-task-with-threads

2028: setcontext doesn't suspend the correct task with threads
PhilMiller authored Dec 13, 2022
2 parents 63376f6 + df9e052 commit 712749c
Showing 5 changed files with 298 additions and 18 deletions.
11 changes: 9 additions & 2 deletions src/vt/context/runnable_context/set_context.cc
@@ -80,7 +80,12 @@ void SetContext::suspend() {
     print_ptr(this), print_ptr(prev_task_.get()), print_ptr(cur_task_.get())
   );
 
-  finish();
+  // Get the innermost task, the one that actually asked for this thread to block
+  suspended_task_ = theContext()->getTask();
+
+  vtAssert(prev_task_ == nullptr, "There should be no prev_task");
+
+  ContextAttorney::setTask(prev_task_);
 }
 
 void SetContext::resume() {
@@ -90,7 +95,9 @@ void SetContext::resume() {
     print_ptr(this), print_ptr(prev_task_.get()), print_ptr(cur_task_.get())
   );
 
-  start();
+  prev_task_ = theContext()->getTask();
+  ContextAttorney::setTask(suspended_task_);
+  suspended_task_ = nullptr;
 }
 
 }} /* end namespace vt::ctx */
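Editor's note: to see what the suspend()/resume() change above accomplishes, here is a minimal, self-contained sketch of the task-pointer handoff. Task, Context, and the assert are simplified stand-ins invented for this illustration; only suspended_task_, prev_task_, getTask(), and setTask() mirror names from the diff. This is not vt's actual implementation.

#include <cassert>
#include <cstdio>

struct Task { const char* name; };

// Simplified stand-in for vt's context singleton
struct Context {
  Task* cur = nullptr;
  Task* getTask() { return cur; }
  void setTask(Task* t) { cur = t; }  // plays the role of ContextAttorney::setTask
};

struct SetContext {
  Context* ctx;
  Task* prev_task_ = nullptr;       // whoever was current before resume()
  Task* suspended_task_ = nullptr;  // innermost task captured at suspend()

  void suspend() {
    // Capture the innermost task, the one that asked this thread to block,
    // rather than tearing the context down as the old finish() call did
    suspended_task_ = ctx->getTask();
    assert(prev_task_ == nullptr && "There should be no prev_task");
    ctx->setTask(prev_task_);  // i.e., clear the current task
  }

  void resume() {
    prev_task_ = ctx->getTask();    // remember the task running right now
    ctx->setTask(suspended_task_);  // reinstall the suspended task
    suspended_task_ = nullptr;
  }
};

int main() {
  Context c;
  Task inner{"inner"};
  c.setTask(&inner);

  SetContext sc{&c, nullptr, nullptr};
  sc.suspend();                   // inner is parked; no task is current
  assert(c.getTask() == nullptr);
  sc.resume();                    // inner is current again
  assert(c.getTask() == &inner);
  std::printf("resumed task: %s\n", c.getTask()->name);
}

The point of the new suspended_task_ member (added to the header below) is that the innermost task is remembered explicitly, so resume() reinstalls exactly the task that blocked, not whatever finish()/start() happened to leave behind.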
1 change: 1 addition & 0 deletions src/vt/context/runnable_context/set_context.h
@@ -97,6 +97,7 @@ struct SetContext {
   util::ObserverPtr<runnable::RunnableNew> prev_task_ = nullptr;
   /// The new runnable that is replacing it
   util::ObserverPtr<runnable::RunnableNew> cur_task_ = nullptr;
+  util::ObserverPtr<runnable::RunnableNew> suspended_task_ = nullptr;
   NodeType node_ = uninitialized_destination; /**< The from node */
 };
 
19 changes: 16 additions & 3 deletions src/vt/scheduler/thread_action.cc
@@ -72,7 +72,18 @@ ThreadAction::~ThreadAction() {
 
 void ThreadAction::run() {
   ctx_ = make_fcontext_stack(stack_, runFnImpl);
+
+  // `jump_fcontext` transfers control to the fcontext thread indicated by
+  // `ctx_` that we just set up. Control will 'return' from `jump_fcontext` to
+  // here when that thread eventually suspends or completes
+  auto prev_running = cur_running_;
+  cur_running_ = this;
+
   transfer_in_ = jump_fcontext(ctx_, static_cast<void*>(this));
+
+  // reset the current running thread ID after we finish or suspend the
+  // current thread
+  cur_running_ = prev_running;
 }
 
 void ThreadAction::resume() {
@@ -85,8 +96,13 @@ void ThreadAction::resume() {
     return;
   }
 
+  // the same logic as in run() applies here
+  auto prev_running = cur_running_;
+  cur_running_ = this;
+
   transfer_in_ = jump_fcontext(transfer_in_.ctx, nullptr);
 
+  cur_running_ = prev_running;
 }
 
 void ThreadAction::runUntilDone() {
@@ -104,10 +120,8 @@ void ThreadAction::runUntilDone() {
 
   auto ta = static_cast<ThreadAction*>(t.data);
   if (ta->action_) {
-    cur_running_ = ta;
     ta->transfer_out_ = t;
     ta->action_();
-    cur_running_ = nullptr;
   }
 
   vt_debug_print(
@@ -122,7 +136,6 @@
 /*static*/ void ThreadAction::suspend() {
   if (cur_running_ != nullptr) {
     auto x = cur_running_;
-    cur_running_ = nullptr;
     vt_debug_print(
       normal, gen,
       "suspend\n"
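Editor's note: the substance of this fix is that run() and resume() now save and restore the previous value of cur_running_ around each jump_fcontext, instead of the old code resetting it to nullptr. The sketch below is a simplified model, not vt's real ThreadAction: runBody() and nested_ are invented stand-ins for the actual fcontext jump, kept only to show why the save/restore discipline matters once user-level threads nest.

#include <cassert>

struct ThreadAction {
  static ThreadAction* cur_running_;  // which action is running right now

  ThreadAction* nested_ = nullptr;    // optionally run another action inside

  void run() {
    // Save whoever was running before we transfer control in, and restore it
    // afterwards. The old code set cur_running_ = nullptr on the way out,
    // which lost the outer action's identity whenever actions nested.
    auto prev_running = cur_running_;
    cur_running_ = this;

    runBody();  // stands in for jump_fcontext(ctx_, this)

    cur_running_ = prev_running;
  }

  void runBody() {
    assert(cur_running_ == this);
    if (nested_ != nullptr) {
      nested_->run();               // an inner action runs and returns
    }
    assert(cur_running_ == this);   // outer identity survives the nesting
  }
};

ThreadAction* ThreadAction::cur_running_ = nullptr;

int main() {
  ThreadAction inner;
  ThreadAction outer;
  outer.nested_ = &inner;
  outer.run();
  assert(ThreadAction::cur_running_ == nullptr);  // fully restored at the end
}

With plain nulling, suspend() would see the wrong (or no) current action after a nested thread finished; with the stack-like save/restore, suspend() always sees the innermost running action, which is what the new tests exercise.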
107 changes: 94 additions & 13 deletions tests/unit/active/test_async_op_threads.cc
@@ -46,19 +46,21 @@
 #include <vt/messaging/async_op_mpi.h>
 #include <vt/messaging/active.h>
-#include <vt/objgroup/headers.h>
+#include <vt/transport.h>
 
 #include <gtest/gtest.h>
 
 #if vt_check_enabled(fcontext)
 
 namespace vt { namespace tests { namespace unit { namespace threads {
 
 using TestAsyncOpThreads = TestParallelHarness;
 
-using MyMsg = Message;
+static std::size_t stack_size_before_running_handler = 0;
 
-struct MyObjGroup {
+struct MyCol : vt::Collection<MyCol, vt::Index1D> {
+  using MyMsg = vt::CollectionMessage<MyCol>;
+
   void handler(MyMsg* msg) {
     auto const this_node = theContext()->getNode();
@@ -69,9 +71,8 @@
       from_node_ = num_nodes - 1;
     }
 
-    // get the epoch stack and store the original size
+    // save a reference to the epoch stack
    auto& epoch_stack = theTerm()->getEpochStack();
-    std::size_t original_epoch_size = epoch_stack.size();
 
     auto comm = theContext()->getComm();
     int const tag = 299999;
@@ -91,11 +92,11 @@
     }
     auto op2 = std::make_unique<messaging::AsyncOpMPI>(
       req2,
-      [this,original_epoch_size]{
+      [this]{
         done_ = true;
         // stack should be the size before running this method since we haven't
         // resumed the thread yet!
-        EXPECT_EQ(theTerm()->getEpochStack().size(), original_epoch_size - 2);
+        EXPECT_EQ(theTerm()->getEpochStack().size(), stack_size_before_running_handler);
       }
     );
 
@@ -104,7 +105,7 @@
     theMsg()->pushEpoch(cur_ep);
     theMsg()->pushEpoch(cur_ep);
 
-    EXPECT_EQ(epoch_stack.size(), original_epoch_size + 2);
+    auto const stack_size_after_push = epoch_stack.size();
 
     // Register these async operations to block the user-level thread until
     // completion of the MPI request; since these operations are enclosed in an
@@ -114,13 +115,13 @@
     theMsg()->blockOnAsyncOp(std::move(op1));
     vt_print(gen, "done with op1\n");
 
-    EXPECT_EQ(epoch_stack.size(), original_epoch_size + 2);
+    EXPECT_EQ(epoch_stack.size(), stack_size_after_push);
 
     vt_print(gen, "call blockOnAsyncOp(op2)\n");
     theMsg()->blockOnAsyncOp(std::move(op2));
     vt_print(gen, "done with op2\n");
 
-    EXPECT_EQ(epoch_stack.size(), original_epoch_size + 2);
+    EXPECT_EQ(epoch_stack.size(), stack_size_after_push);
 
     check();
 
@@ -129,6 +130,57 @@
     theMsg()->popEpoch(cur_ep);
     theMsg()->popEpoch(cur_ep);
   }
 
+  void handlerInvoke(MyMsg* msg) {
+    auto const this_node = theContext()->getNode();
+    auto const num_nodes = theContext()->getNumNodes();
+    auto const to_node = (this_node + 1) % num_nodes;
+    from_node_ = this_node - 1;
+    if (from_node_ < 0) {
+      from_node_ = num_nodes - 1;
+    }
+
+    auto comm = theContext()->getComm();
+    int const tag = 299999;
+
+    MPI_Request req1;
+    send_val_ = this_node;
+    {
+      VT_ALLOW_MPI_CALLS; // MPI_Issend
+      MPI_Issend(&send_val_, 1, MPI_INT, to_node, tag, comm, &req1);
+    }
+    auto op1 = std::make_unique<messaging::AsyncOpMPI>(req1);
+
+    MPI_Request req2;
+    {
+      VT_ALLOW_MPI_CALLS; // MPI_Irecv
+      MPI_Irecv(&recv_val_, 1, MPI_INT, from_node_, tag, comm, &req2);
+    }
+    auto op2 = std::make_unique<messaging::AsyncOpMPI>(
+      req2,
+      [this]{ done_ = true; }
+    );
+
+    auto p = getCollectionProxy();
+    p[this_node].invoke<decltype(&MyCol::handlerToInvoke), &MyCol::handlerToInvoke>(std::move(op1), std::move(op2));
+  }
+
+  void handlerToInvoke(
+    std::unique_ptr<messaging::AsyncOpMPI> op1,
+    std::unique_ptr<messaging::AsyncOpMPI> op2
+  ) {
+    vt_print(gen, "call blockOnAsyncOp(op1) inside invoke\n");
+    theMsg()->blockOnAsyncOp(std::move(op1));
+    vt_print(gen, "done with op1 inside invoke\n");
+
+    vt_print(gen, "call blockOnAsyncOp(op2) inside invoke\n");
+    theMsg()->blockOnAsyncOp(std::move(op2));
+    vt_print(gen, "done with op2 inside invoke\n");
+
+    check();
+  }
+
   void check() {
     vt_print(gen, "running check method\n");
     EXPECT_EQ(from_node_, recv_val_);
@@ -145,18 +197,47 @@
 
 TEST_F(TestAsyncOpThreads, test_async_op_threads_1) {
   auto const this_node = theContext()->getNode();
-  auto p = theObjGroup()->makeCollective<MyObjGroup>("test_async_op_threads_1");
+
+  stack_size_before_running_handler = theTerm()->getEpochStack().size();
+
+  vt::Index1D range(static_cast<int>(theContext()->getNumNodes()));
+  auto p = vt::makeCollection<MyCol>("test_async_op_threads_1")
+    .bounds(range)
+    .bulkInsert()
+    .wait();
 
   auto ep = theTerm()->makeEpochRooted(term::UseDS{true});
 
   // When this returns all the MPI requests should be done
   runInEpoch(ep, [p, this_node]{
-    p[this_node].send<MyMsg, &MyObjGroup::handler>();
+    p[this_node].send<typename MyCol::MyMsg, &MyCol::handler>();
   });
 
   // Ensure the check method actually ran.
-  EXPECT_TRUE(p[this_node].get()->check_done_);
+  EXPECT_TRUE(p[this_node].tryGetLocalPtr()->check_done_);
 }
 
+TEST_F(TestAsyncOpThreads, test_async_op_threads_invoke_2) {
+  auto const this_node = theContext()->getNode();
+
+  vt::Index1D range(static_cast<int>(theContext()->getNumNodes()));
+  auto p = vt::makeCollection<MyCol>("test_async_op_threads_invoke")
+    .bounds(range)
+    .bulkInsert()
+    .wait();
+
+  auto ep = theTerm()->makeEpochRooted(term::UseDS{true});
+
+  // When this returns all the MPI requests should be done
+  runInEpoch(ep, [p, this_node]{
+    p[this_node].send<typename MyCol::MyMsg, &MyCol::handlerInvoke>();
+  });
+
+  // Ensure the check method actually ran.
+  EXPECT_TRUE(p[this_node].tryGetLocalPtr()->check_done_);
+}
+
 }}}} // end namespace vt::tests::unit::threads
 
 #endif /*vt_check_enabled(fcontext)*/
[diff for the fifth changed file (the remaining 178 additions) not rendered in this view]
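Editor's note: stripped to its core, the pattern both tests exercise looks like the following. AsyncOpMPI, blockOnAsyncOp, VT_ALLOW_MPI_CALLS, vt::theMsg(), and vt::theContext() all appear in the diff above; the free-function wrapper blockOnRecv and its parameters are invented here for illustration, as a sketch rather than code from the repository.

#include <vt/transport.h>
#include <vt/messaging/async_op_mpi.h>

#include <memory>

// Block the calling user-level thread on an MPI receive. This must run
// inside a handler executing on a vt user-level (fcontext) thread.
void blockOnRecv(int from_node, int tag, int& recv_val) {
  auto comm = vt::theContext()->getComm();
  MPI_Request req;
  {
    VT_ALLOW_MPI_CALLS; // MPI_Irecv
    MPI_Irecv(&recv_val, 1, MPI_INT, from_node, tag, comm, &req);
  }
  auto op = std::make_unique<vt::messaging::AsyncOpMPI>(
    req,
    []{ /* completion callback: runs once the request finishes, before the
           blocked thread is resumed */ }
  );
  // Suspend this thread until the request completes; the scheduler keeps
  // making progress and, with this commit's fix, reinstates the correct
  // (innermost) task when the thread resumes.
  vt::theMsg()->blockOnAsyncOp(std::move(op));
}

The second test adds a wrinkle: handlerInvoke calls blockOnAsyncOp from a handler reached via invoke() on the same collection element, i.e., a task nested inside another task, which is exactly the case the cur_running_ and suspended_task_ changes are meant to handle.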
