diff --git a/examples/collection/transpose.cc b/examples/collection/transpose.cc index f0b79d41ef..82c1932fb2 100644 --- a/examples/collection/transpose.cc +++ b/examples/collection/transpose.cc @@ -43,6 +43,7 @@ */ #include +#include "vt/runtime/mpi_access.h" #include @@ -206,9 +207,13 @@ vt::NodeType my_map(IndexT* idx, IndexT* max_idx, vt::NodeType num_nodes) { // group-targeted handler for the sub-solve /*static*/ void SubSolveInfo::subSolveHandler(SubSolveMsg* msg) { + // Normally MPI calls are NOT ALLOWED inside user code. + VT_ALLOW_MPI_CALLS; + vt::NodeType this_node = vt::theContext()->getNode(); vt::NodeType num_nodes = vt::theContext()->getNumNodes(); + auto const group_id = vt::envelopeGetGroup(msg->env); MPI_Comm sub_comm = vt::theGroup()->getGroupComm(group_id); int sub_size = 0, sub_rank = 0; diff --git a/src/vt/event/event_record.cc b/src/vt/event/event_record.cc index a925c8fff4..b30c7d234c 100644 --- a/src/vt/event/event_record.cc +++ b/src/vt/event/event_record.cc @@ -44,6 +44,7 @@ #include "vt/event/event.h" #include "vt/event/event_record.h" +#include "vt/runtime/mpi_access.h" #include @@ -53,6 +54,7 @@ namespace vt { namespace event { EventRecord::EventRecord(EventRecordType const& type, EventType const& id) : event_id_(id), type_(type) { + VT_ALLOW_MPI_CALLS; switch (type) { case EventRecordType::MPI_EventRecord: @@ -79,6 +81,8 @@ EventRecord::EventRecord(EventRecordType const& type, EventType const& id) } bool EventRecord::testMPIEventReady() { + VT_ALLOW_MPI_CALLS; + int flag = 0; MPI_Request* req = getRequest(); MPI_Status stat; diff --git a/src/vt/group/collective/group_info_collective.cc b/src/vt/group/collective/group_info_collective.cc index 722f2bd91b..2ff1e5c30d 100644 --- a/src/vt/group/collective/group_info_collective.cc +++ b/src/vt/group/collective/group_info_collective.cc @@ -56,6 +56,7 @@ #include "vt/collective/tree/tree.h" #include "vt/collective/collective_alg.h" #include "vt/collective/collective_ops.h" +#include "vt/runtime/mpi_access.h" #include #include @@ -101,6 +102,8 @@ MPI_Comm InfoColl::getComm() const { } void InfoColl::freeComm() { + VT_ALLOW_MPI_CALLS; + if (mpi_group_comm != MPI_COMM_WORLD) { MPI_Comm_free(&mpi_group_comm); mpi_group_comm = MPI_COMM_WORLD; @@ -108,6 +111,8 @@ void InfoColl::freeComm() { } void InfoColl::setupCollective() { + VT_ALLOW_MPI_CALLS; + auto const& this_node = theContext()->getNode(); auto const& num_nodes = theContext()->getNumNodes(); auto const& group_ = getGroupID(); diff --git a/src/vt/messaging/active.cc b/src/vt/messaging/active.cc index 16f01452cc..e360b65545 100644 --- a/src/vt/messaging/active.cc +++ b/src/vt/messaging/active.cc @@ -52,6 +52,7 @@ #include "vt/runnable/general.h" #include "vt/timing/timing.h" #include "vt/scheduler/priority.h" +#include "vt/runtime/mpi_access.h" namespace vt { namespace messaging { @@ -106,6 +107,8 @@ EventType ActiveMessenger::sendMsgBytesWithPut( NodeType const& dest, MsgSharedPtr const& base, MsgSizeType const& msg_size, TagType const& send_tag ) { + VT_ALLOW_MPI_CALLS; + auto msg = base.get(); auto const& is_term = envelopeIsTerm(msg->env); auto const& is_put = envelopeIsPut(msg->env); @@ -179,6 +182,8 @@ EventType ActiveMessenger::sendMsgBytes( NodeType const& dest, MsgSharedPtr const& base, MsgSizeType const& msg_size, TagType const& send_tag ) { + VT_ALLOW_MPI_CALLS; + auto const& msg = base.get(); auto const epoch = envelopeIsEpochType(msg->env) ? @@ -320,6 +325,8 @@ EventType ActiveMessenger::doMessageSend( ActiveMessenger::SendDataRetType ActiveMessenger::sendData( RDMA_GetType const& ptr, NodeType const& dest, TagType const& tag ) { + VT_ALLOW_MPI_CALLS; + auto const& data_ptr = std::get<0>(ptr); auto const& num_bytes = std::get<1>(ptr); auto const send_tag = tag == no_tag ? cur_direct_buffer_tag_++ : tag; @@ -429,6 +436,8 @@ bool ActiveMessenger::recvDataMsgBuffer( RDMA_ContinuationDeleteType next ) { if (not enqueue) { + VT_ALLOW_MPI_CALLS; + CountType num_probe_bytes; MPI_Status stat; int flag; diff --git a/src/vt/messaging/irecv_holder.h b/src/vt/messaging/irecv_holder.h index e4bf5fa773..11d2cff85c 100644 --- a/src/vt/messaging/irecv_holder.h +++ b/src/vt/messaging/irecv_holder.h @@ -48,6 +48,9 @@ #include "vt/config.h" #include "vt/configs/arguments/args.h" +// Unfortunate header leak for VT_ALLOW_MPI_CALLS +#include "vt/runtime/mpi_access.h" + #if backend_check_enabled(trace_enabled) #include "vt/trace/trace_headers.h" #endif @@ -94,6 +97,7 @@ struct IRecvHolder { */ template bool testAll(Callable c) { + VT_ALLOW_MPI_CALLS; # if backend_check_enabled(trace_enabled) std::size_t const holder_size_start = holder_.size(); diff --git a/src/vt/rdma/channel/rdma_channel.cc b/src/vt/rdma/channel/rdma_channel.cc index 7581feefd6..e731aca8b4 100644 --- a/src/vt/rdma/channel/rdma_channel.cc +++ b/src/vt/rdma/channel/rdma_channel.cc @@ -43,6 +43,7 @@ */ #include "vt/rdma/channel/rdma_channel.h" +#include "vt/runtime/mpi_access.h" #define PRINT_RDMA_OP_TYPE(OP) ((OP) == RDMA_TypeType::Get ? "GET" : "PUT") @@ -87,6 +88,8 @@ Channel::Channel( void Channel::initChannelGroup() { + VT_ALLOW_MPI_CALLS; + debug_print( rdma_channel, node, "channel: initChannelGroup: target={}, non_target={}, han={}\n", @@ -236,6 +239,8 @@ void Channel::writeDataToChannel( RDMA_PtrType const& ptr, ByteType const& ptr_num_bytes, ByteType const& offset ) { + VT_ALLOW_MPI_CALLS; + vtAssert(initialized_, "Channel must be initialized"); vtAssert(not is_target_, "The target can not write to this channel"); @@ -278,6 +283,8 @@ Channel::writeDataToChannel( void Channel::freeChannel() { + VT_ALLOW_MPI_CALLS; + if (locked_) { unlockChannelForOp(); } @@ -294,6 +301,8 @@ Channel::freeChannel() { void Channel::initChannelWindow() { + VT_ALLOW_MPI_CALLS; + debug_print( rdma_channel, node, "channel: create window: num_bytes={}\n", num_bytes_ diff --git a/src/vt/runtime/runtime.cc b/src/vt/runtime/runtime.cc index 786530bfb0..dbf630cd22 100644 --- a/src/vt/runtime/runtime.cc +++ b/src/vt/runtime/runtime.cc @@ -69,6 +69,7 @@ #include "vt/configs/error/stack_out.h" #include "vt/configs/error/pretty_print_stack.h" #include "vt/utils/memory/memory_usage.h" +#include "vt/runtime/mpi_access.h" #include #include @@ -1162,6 +1163,7 @@ void Runtime::setup() { // wait for all nodes to start up to initialize the runtime theCollective->barrierThen([this]{ + VT_ALLOW_MPI_CALLS; MPI_Barrier(theContext->getComm()); });