Skip to content

Commit

Permalink
#510 allow internal VT calls to access MPI
Browse files Browse the repository at this point in the history
- Scoped grants have to be enabled at all usage sites
  and cannot extend into the scheduler itself.
  • Loading branch information
pnstickne committed Apr 28, 2020
1 parent 278527e commit f0dd2e0
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/collection/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
*/

#include <vt/transport.h>
#include "vt/runtime/mpi_access.h"

#include <vector>

Expand Down Expand Up @@ -206,9 +207,13 @@ vt::NodeType my_map(IndexT* idx, IndexT* max_idx, vt::NodeType num_nodes) {

// group-targeted handler for the sub-solve
/*static*/ void SubSolveInfo::subSolveHandler(SubSolveMsg* msg) {
// Normally MPI calls are NOT ALLOWED inside user code.
VT_ALLOW_MPI_CALLS;

vt::NodeType this_node = vt::theContext()->getNode();
vt::NodeType num_nodes = vt::theContext()->getNumNodes();


auto const group_id = vt::envelopeGetGroup(msg->env);
MPI_Comm sub_comm = vt::theGroup()->getGroupComm(group_id);
int sub_size = 0, sub_rank = 0;
Expand Down
4 changes: 4 additions & 0 deletions src/vt/event/event_record.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

#include "vt/event/event.h"
#include "vt/event/event_record.h"
#include "vt/runtime/mpi_access.h"

#include <memory>

Expand All @@ -53,6 +54,7 @@ namespace vt { namespace event {

EventRecord::EventRecord(EventRecordType const& type, EventType const& id)
: event_id_(id), type_(type) {
VT_ALLOW_MPI_CALLS;

switch (type) {
case EventRecordType::MPI_EventRecord:
Expand All @@ -79,6 +81,8 @@ EventRecord::EventRecord(EventRecordType const& type, EventType const& id)
}

bool EventRecord::testMPIEventReady() {
VT_ALLOW_MPI_CALLS;

int flag = 0;
MPI_Request* req = getRequest();
MPI_Status stat;
Expand Down
5 changes: 5 additions & 0 deletions src/vt/group/collective/group_info_collective.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "vt/collective/tree/tree.h"
#include "vt/collective/collective_alg.h"
#include "vt/collective/collective_ops.h"
#include "vt/runtime/mpi_access.h"

#include <memory>
#include <set>
Expand Down Expand Up @@ -101,13 +102,17 @@ MPI_Comm InfoColl::getComm() const {
}

void InfoColl::freeComm() {
VT_ALLOW_MPI_CALLS;

if (mpi_group_comm != MPI_COMM_WORLD) {
MPI_Comm_free(&mpi_group_comm);
mpi_group_comm = MPI_COMM_WORLD;
}
}

void InfoColl::setupCollective() {
VT_ALLOW_MPI_CALLS;

auto const& this_node = theContext()->getNode();
auto const& num_nodes = theContext()->getNumNodes();
auto const& group_ = getGroupID();
Expand Down
9 changes: 9 additions & 0 deletions src/vt/messaging/active.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "vt/runnable/general.h"
#include "vt/timing/timing.h"
#include "vt/scheduler/priority.h"
#include "vt/runtime/mpi_access.h"

namespace vt { namespace messaging {

Expand Down Expand Up @@ -106,6 +107,8 @@ EventType ActiveMessenger::sendMsgBytesWithPut(
NodeType const& dest, MsgSharedPtr<BaseMsgType> const& base,
MsgSizeType const& msg_size, TagType const& send_tag
) {
VT_ALLOW_MPI_CALLS;

auto msg = base.get();
auto const& is_term = envelopeIsTerm(msg->env);
auto const& is_put = envelopeIsPut(msg->env);
Expand Down Expand Up @@ -179,6 +182,8 @@ EventType ActiveMessenger::sendMsgBytes(
NodeType const& dest, MsgSharedPtr<BaseMsgType> const& base,
MsgSizeType const& msg_size, TagType const& send_tag
) {
VT_ALLOW_MPI_CALLS;

auto const& msg = base.get();

auto const epoch = envelopeIsEpochType(msg->env) ?
Expand Down Expand Up @@ -320,6 +325,8 @@ EventType ActiveMessenger::doMessageSend(
ActiveMessenger::SendDataRetType ActiveMessenger::sendData(
RDMA_GetType const& ptr, NodeType const& dest, TagType const& tag
) {
VT_ALLOW_MPI_CALLS;

auto const& data_ptr = std::get<0>(ptr);
auto const& num_bytes = std::get<1>(ptr);
auto const send_tag = tag == no_tag ? cur_direct_buffer_tag_++ : tag;
Expand Down Expand Up @@ -429,6 +436,8 @@ bool ActiveMessenger::recvDataMsgBuffer(
RDMA_ContinuationDeleteType next
) {
if (not enqueue) {
VT_ALLOW_MPI_CALLS;

CountType num_probe_bytes;
MPI_Status stat;
int flag;
Expand Down
4 changes: 4 additions & 0 deletions src/vt/messaging/irecv_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
#include "vt/config.h"
#include "vt/configs/arguments/args.h"

// Unfortunate header leak for VT_ALLOW_MPI_CALLS
#include "vt/runtime/mpi_access.h"

#if backend_check_enabled(trace_enabled)
#include "vt/trace/trace_headers.h"
#endif
Expand Down Expand Up @@ -94,6 +97,7 @@ struct IRecvHolder {
*/
template <typename Callable>
bool testAll(Callable c) {
VT_ALLOW_MPI_CALLS;

# if backend_check_enabled(trace_enabled)
std::size_t const holder_size_start = holder_.size();
Expand Down
9 changes: 9 additions & 0 deletions src/vt/rdma/channel/rdma_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
*/

#include "vt/rdma/channel/rdma_channel.h"
#include "vt/runtime/mpi_access.h"

#define PRINT_RDMA_OP_TYPE(OP) ((OP) == RDMA_TypeType::Get ? "GET" : "PUT")

Expand Down Expand Up @@ -87,6 +88,8 @@ Channel::Channel(

void
Channel::initChannelGroup() {
VT_ALLOW_MPI_CALLS;

debug_print(
rdma_channel, node,
"channel: initChannelGroup: target={}, non_target={}, han={}\n",
Expand Down Expand Up @@ -236,6 +239,8 @@ void
Channel::writeDataToChannel(
RDMA_PtrType const& ptr, ByteType const& ptr_num_bytes, ByteType const& offset
) {
VT_ALLOW_MPI_CALLS;

vtAssert(initialized_, "Channel must be initialized");
vtAssert(not is_target_, "The target can not write to this channel");

Expand Down Expand Up @@ -278,6 +283,8 @@ Channel::writeDataToChannel(

void
Channel::freeChannel() {
VT_ALLOW_MPI_CALLS;

if (locked_) {
unlockChannelForOp();
}
Expand All @@ -294,6 +301,8 @@ Channel::freeChannel() {

void
Channel::initChannelWindow() {
VT_ALLOW_MPI_CALLS;

debug_print(
rdma_channel, node,
"channel: create window: num_bytes={}\n", num_bytes_
Expand Down
2 changes: 2 additions & 0 deletions src/vt/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "vt/configs/error/stack_out.h"
#include "vt/configs/error/pretty_print_stack.h"
#include "vt/utils/memory/memory_usage.h"
#include "vt/runtime/mpi_access.h"

#include <memory>
#include <iostream>
Expand Down Expand Up @@ -1162,6 +1163,7 @@ void Runtime::setup() {

// wait for all nodes to start up to initialize the runtime
theCollective->barrierThen([this]{
VT_ALLOW_MPI_CALLS;
MPI_Barrier(theContext->getComm());
});

Expand Down

0 comments on commit f0dd2e0

Please sign in to comment.