Skip to content

Commit

Permalink
major cleanup of improved MPI implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremiah Wilke committed Mar 3, 2016
1 parent 8e3097e commit a03bb5e
Show file tree
Hide file tree
Showing 24 changed files with 277 additions and 312 deletions.
2 changes: 1 addition & 1 deletion sstmac/hardware/nic/nic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ nic::recv_message(const sst_message::ptr& msg)
switch (netmsg->type()) {
case network_message::rdma_get_request: {
netmsg->nic_reverse(network_message::rdma_get_payload);
netmsg->put_on_wire();
internode_send(netmsg);
finish_recv_req(msg);
break;
Expand Down Expand Up @@ -181,7 +182,6 @@ void
nic::ack_send(const network_message::ptr& payload)
{
if (payload->needs_ack()){
payload->put_on_wire();
network_message::ptr ack = payload->clone_injection_ack();
nic_debug("acking payload %p:%s with ack %p",
payload.get(), payload->to_string().c_str(), ack.get());
Expand Down
3 changes: 2 additions & 1 deletion sstmac/libraries/mpi/mpi_comm/mpi_comm_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
#include <sstmac/software/process/task_id.h>


namespace sumi {
namespace sstmac {
namespace sw {

class mpi_comm_data {
public:
Expand Down
2 changes: 1 addition & 1 deletion sumi
1 change: 0 additions & 1 deletion sumi-mpi/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ nobase_library_include_HEADERS = \
mpi_queue/mpi_queue_send_request.h \
mpi_queue/mpi_queue.h \
mpi_queue/mpi_queue_fwd.h \
mpi_queue/user_thread_mpi_queue.h \
mpi_protocol/mpi_protocol.h \
mpi_protocol/mpi_protocol_fwd.h \
mpi_types/mpi_type.h \
Expand Down
27 changes: 25 additions & 2 deletions sumi-mpi/mpi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <sstmac/common/messages/sleep_message.h>

#include <sumi-mpi/mpi_queue/mpi_queue.h>
#include <sumi-mpi/mpi_queue/user_thread_mpi_queue.h>

#include <sumi-mpi/mpi_api.h>
#include <sumi-mpi/mpi_api_persistent.h>
Expand Down Expand Up @@ -104,7 +103,7 @@ mpi_api::init_factory_params(sprockit::sim_parameters* params)
blocks the application [user] thread.;
}
*/
queue_ = new user_thread_mpi_queue;
queue_ = new mpi_queue;
queue_->init_sid(id_);
queue_->init_factory_params(queue_params);
queue_->set_api(this);
Expand Down Expand Up @@ -249,6 +248,30 @@ mpi_api::wtime()
return os_->now().sec();
}

/**
 * Map an MPI reduction operation handle to a printable name.
 *
 * @param op  The reduction operation to describe.
 * @return    The identifier of the matching predefined operation
 *            (e.g. "MPI_SUM"), or "CUSTOM" for any value that is not
 *            one of the cases listed below. The returned pointer refers
 *            to a string literal and must not be freed.
 */
const char*
mpi_api::op_str(MPI_Op op)
{
// Helper: emit a case label that returns the stringized operation name.
// No trailing semicolon in the macro body; each use supplies its own.
#define op_case(x) case x: return #x
  switch(op)
  {
    op_case(MPI_MAX);
    op_case(MPI_MIN);
    op_case(MPI_SUM);
    op_case(MPI_PROD);
    op_case(MPI_LAND);
    op_case(MPI_BAND);
    op_case(MPI_LOR);
    op_case(MPI_BOR);
    op_case(MPI_LXOR);
    op_case(MPI_BXOR);
    op_case(MPI_MAXLOC);
    op_case(MPI_MINLOC);
    op_case(MPI_REPLACE);
    default:
      return "CUSTOM";
  }
// Keep the helper local to this function so it cannot collide with or
// silently affect code later in the translation unit.
#undef op_case
}

std::string
mpi_api::type_str(MPI_Datatype mid)
{
Expand Down
7 changes: 7 additions & 0 deletions sumi-mpi/mpi_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <sstmac/software/process/software_id.h>
#include <sstmac/software/process/key.h>
#include <sstmac/software/process/pmi.h>
#include <sstmac/software/process/backtrace.h>

#include <sumi-mpi/mpi_queue/mpi_queue_fwd.h>
#include <sstmac/software/process/operating_system_fwd.h>
Expand Down Expand Up @@ -464,6 +465,9 @@ class mpi_api :
std::string
type_str(MPI_Datatype mid);

const char*
op_str(MPI_Op op);

mpi_comm*
get_comm(MPI_Comm comm);

Expand Down Expand Up @@ -510,6 +514,9 @@ class mpi_api :
get_keyval(int key);

private:
int
do_wait(MPI_Request *request, MPI_Status *status);

void
validate_mpi_collective(const char* name, MPI_Datatype sendtype, MPI_Datatype recvtype);

Expand Down
32 changes: 31 additions & 1 deletion sumi-mpi/mpi_api_collectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,19 @@ mpi_api::start_allgather(const void *sendbuf, void *recvbuf, int count, MPI_Data
int
mpi_api::allgather(int sendcount, MPI_Datatype sendtype, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
{
validate_mpi_collective("allgather", sendtype, recvtype);
return allgather(NULL, sendcount, sendtype, NULL, recvcount, recvtype, comm);
}

int
mpi_api::allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Allgather");
validate_mpi_collective("allgather", sendtype, recvtype);
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_collective,
"MPI_Allgather(%d,%s,%d,%s,%s)",
sendcount, type_str(sendtype).c_str(),
recvcount, type_str(recvtype).c_str(),
comm_str(comm).c_str());
int tag = start_allgather(sendbuf, recvbuf, sendcount, sendtype, comm);
collective_progress_loop(collective::allgather, tag);
return MPI_SUCCESS;
Expand All @@ -99,7 +105,13 @@ mpi_api::start_alltoall(const void *sendbuf, void *recvbuf, int count, MPI_Datat
int
mpi_api::alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Alltoall");
validate_mpi_collective("alltoall", sendtype, recvtype);
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_collective,
"MPI_Alltoall(%d,%s,%d,%s,%s)",
sendcount, type_str(sendtype).c_str(),
recvcount, type_str(recvtype).c_str(),
comm_str(comm).c_str());
int tag = start_alltoall(sendbuf, recvbuf, sendcount, sendtype, comm);
collective_progress_loop(collective::alltoall, tag);
return MPI_SUCCESS;
Expand All @@ -124,6 +136,10 @@ mpi_api::start_allreduce(const void *src, void *dst, int count, MPI_Datatype typ
int
mpi_api::allreduce(const void *src, void *dst, int count, MPI_Datatype type, MPI_Op op, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Allreduce");
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_collective,
"MPI_Allreduce(%d,%s,%s,%s)",
count, type_str(type).c_str(), op_str(op), comm_str(comm).c_str());
int tag = start_allreduce(src, dst, count, type, op, comm);
collective_progress_loop(collective::allreduce, tag);
return MPI_SUCCESS;
Expand All @@ -150,6 +166,8 @@ mpi_api::barrier(MPI_Comm comm)
{
SSTMACBacktrace("MPI_Barrier");
int tag = start_barrier(comm);
mpi_api_debug(sprockit::dbg::mpi, "MPI_Barrier(%s) on tag %d",
comm_str(comm).c_str(), int(tag));
collective_progress_loop(collective::barrier, tag);
return MPI_SUCCESS;
}
Expand Down Expand Up @@ -241,6 +259,10 @@ mpi_api::start_reduce(const void *src, void *dst, int count, MPI_Datatype type,
int
mpi_api::reduce(const void *src, void *dst, int count, MPI_Datatype type, MPI_Op op, int root, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Reduce");
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_collective,
"MPI_Reduce(%d,%s,%s,%d,%s)", count, type_str(type).c_str(),
op_str(op), int(root), comm_str(comm).c_str());
int tag = start_reduce(src, dst, count, type, op, root, comm);
collective_progress_loop(collective::reduce, tag);
return MPI_SUCCESS;
Expand Down Expand Up @@ -271,6 +293,10 @@ mpi_api::start_reduce_scatter(const void *src, void *dst, int *recvcnts, MPI_Dat
int
mpi_api::reduce_scatter(const void *src, void *dst, int *recvcnts, MPI_Datatype type, MPI_Op op, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Reducescatter");
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_collective,
"MPI_Reduce_scatter(<...>,%s,%s,%s)",
type_str(type).c_str(), op_str(op), comm_str(comm).c_str());
int tag = start_reduce_scatter(src, dst, recvcnts, type, op, comm);
collective_progress_loop(collective::reduce_scatter, tag);
return MPI_SUCCESS;
Expand Down Expand Up @@ -301,6 +327,10 @@ mpi_api::start_scan(const void *src, void *dst, int count, MPI_Datatype type, MP
int
mpi_api::scan(const void *src, void *dst, int count, MPI_Datatype type, MPI_Op op, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Scan");
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_collective,
"MPI_Scan(%d,%s,%s,%s)",
count, type_str(type).c_str(), op_str(op), comm_str(comm).c_str());
int tag = start_scan(src, dst, count, type, op, comm);
collective_progress_loop(collective::scan, tag);
return MPI_SUCCESS;
Expand Down
12 changes: 12 additions & 0 deletions sumi-mpi/mpi_api_comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ int
mpi_api::comm_dup(MPI_Comm input, MPI_Comm *output)
{
SSTMACBacktrace("MPI_Comm_dup");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Comm_dup(...)");
mpi_comm* inputPtr = get_comm(input);
mpi_comm* outputPtr = comm_factory_->comm_dup(inputPtr);
*output = add_comm_ptr(outputPtr);
Expand All @@ -19,6 +20,7 @@ int
mpi_api::comm_create(MPI_Comm input, MPI_Group group, MPI_Comm *output)
{
SSTMACBacktrace("MPI_Comm_create");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Comm_create(...)");
mpi_comm* inputPtr = get_comm(input);
mpi_group* groupPtr = get_group(group);
*output = add_comm_ptr(comm_factory_->comm_create(inputPtr, groupPtr));
Expand All @@ -30,6 +32,7 @@ mpi_api::cart_create(MPI_Comm comm_old, int ndims, const int dims[],
const int periods[], int reorder, MPI_Comm *comm_cart)
{
SSTMACBacktrace("MPI_Cart_create");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Cart_create(...)");
mpi_comm* incommPtr = get_comm(comm_old);
mpi_comm* outcommPtr = comm_factory_->create_cart(incommPtr, ndims, dims, periods, reorder);
*comm_cart = add_comm_ptr(outcommPtr);
Expand All @@ -41,6 +44,7 @@ mpi_api::cart_get(MPI_Comm comm, int maxdims, int dims[], int periods[],
int coords[])
{
SSTMACBacktrace("MPI_Cart_get");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Cart_get(...)");

mpi_comm* incommPtr = get_comm(comm);
mpi_comm_cart* c = safe_cast(mpi_comm_cart, incommPtr,
Expand All @@ -60,6 +64,7 @@ int
mpi_api::cartdim_get(MPI_Comm comm, int *ndims)
{
SSTMACBacktrace("MPI_Cartdim_get");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Cartdim_get(...)");
mpi_comm* incommPtr = get_comm(comm);
mpi_comm_cart* c = safe_cast(mpi_comm_cart, incommPtr,
"mpi_api::cartdim_get: mpi comm did not cast to mpi_comm_cart");
Expand All @@ -71,6 +76,7 @@ int
mpi_api::cart_rank(MPI_Comm comm, const int coords[], int *rank)
{
SSTMACBacktrace("MPI_Cart_rank");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Cart_rank(...)");
mpi_comm* incommPtr = get_comm(comm);
mpi_comm_cart* c = safe_cast(mpi_comm_cart, incommPtr,
"mpi_api::cart_rank: mpi comm did not cast to mpi_comm_cart");
Expand All @@ -83,6 +89,7 @@ mpi_api::cart_shift(MPI_Comm comm, int direction, int disp, int *rank_source,
int *rank_dest)
{
SSTMACBacktrace("MPI_Cart_shift");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Cart_shift(...)");
mpi_comm* incommPtr = get_comm(comm);
mpi_comm_cart* c = safe_cast(mpi_comm_cart, incommPtr,
"mpi_api::cart_shift: mpi comm did not cast to mpi_comm_cart");
Expand All @@ -95,6 +102,7 @@ int
mpi_api::cart_coords(MPI_Comm comm, int rank, int maxdims, int coords[])
{
SSTMACBacktrace("MPI_Cart_coords");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Cart_coords(...)");
mpi_comm* incommPtr = get_comm(comm);
mpi_comm_cart* c = safe_cast(mpi_comm_cart, incommPtr,
"mpi_api::cart_coords: mpi comm did not cast to mpi_comm_cart");
Expand All @@ -107,6 +115,9 @@ int
mpi_api::comm_split(MPI_Comm incomm, int color, int key, MPI_Comm *outcomm)
{
SSTMACBacktrace("MPI_Comm_split");
mpi_api_debug(sprockit::dbg::mpi,
"MPI_Comm_split(%s,%d,%d) enter",
comm_str(incomm).c_str(), color, key);
mpi_comm* incommPtr = get_comm(incomm);
mpi_comm* outcommPtr = comm_factory_->comm_split(incommPtr, color, key);
*outcomm = add_comm_ptr(outcommPtr);
Expand All @@ -117,6 +128,7 @@ int
mpi_api::comm_free(MPI_Comm* input)
{
SSTMACBacktrace("MPI_Comm_free");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Comm_free(...)");
mpi_comm* inputPtr = get_comm(*input);
delete inputPtr;
*input = MPI_COMM_NULL;
Expand Down
14 changes: 12 additions & 2 deletions sumi-mpi/mpi_api_send_recv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,13 @@ namespace sumi {
int
mpi_api::send(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm)
{
mpi_request* req = mpi_request::construct(default_key_category);
SSTMACBacktrace("MPI_Send");
mpi_comm* commPtr = get_comm(comm);
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_pt2pt,
"MPI_Send(%d,%s,%d:%d,%s,%s)",
count, type_str(datatype).c_str(), int(dest), int(commPtr->peer_task(dest)),
tag_str(tag).c_str(), comm_str(comm).c_str());
mpi_request* req = mpi_request::construct(default_key_category);
queue_->send(req, count, datatype, dest, tag, commPtr, const_cast<void*>(buf));
queue_->progress_loop(req);
delete req;
Expand All @@ -19,8 +24,13 @@ mpi_api::send(const void *buf, int count, MPI_Datatype datatype, int dest, int t
int
mpi_api::isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)
{
mpi_request* req = mpi_request::construct(default_key_category);
SSTMACBacktrace("MPI_Isend");
mpi_comm* commPtr = get_comm(comm);
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_request | sprockit::dbg::mpi_pt2pt,
"MPI_Isend(%d,%s,%d:%d,%s,%s)",
count, type_str(datatype).c_str(), int(dest), int(commPtr->peer_task(dest)),
tag_str(tag).c_str(), comm_str(comm).c_str());
mpi_request* req = mpi_request::construct(default_key_category);
queue_->send(req, count, datatype, dest, tag, commPtr, const_cast<void*>(buf));
*request = add_request_ptr(req);
return MPI_SUCCESS;
Expand Down
3 changes: 3 additions & 0 deletions sumi-mpi/mpi_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ namespace sumi {
bool
mpi_api::test(MPI_Request *request, MPI_Status *status)
{
SSTMACBacktrace("MPI_Test");
mpi_api_debug(sprockit::dbg::mpi | sprockit::dbg::mpi_request, "MPI_Test(...)");

if (*request == MPI_REQUEST_NULL){
return true;
}
Expand Down
12 changes: 6 additions & 6 deletions sumi-mpi/mpi_api_vcollectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ mpi_api::start_gatherv(const void *sendbuf, void *recvbuf, int sendcount, const
int
mpi_api::gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int *recvcounts, const int *displs, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Scatterv");
SSTMACBacktrace("MPI_Gatherv");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Gatherv(%d,%s,<...>,%s,%d,%s)",
sendcount, type_str(sendtype).c_str(),
type_str(recvtype).c_str(),
Expand Down Expand Up @@ -115,6 +115,11 @@ mpi_api::start_scatterv(const void *sendbuf, void *recvbuf, const int* sendcount
int
mpi_api::scatterv(const void *sendbuf, const int *sendcounts, const int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Scatterv");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Scatterv(<...>,%s,%d,%s,%d,%s)",
type_str(sendtype).c_str(),
recvcount, type_str(recvtype).c_str(),
int(root), comm_str(comm).c_str());
validate_mpi_collective("alltoallv", sendtype, recvtype);
int tag = start_scatterv(sendbuf, recvbuf, sendcounts, displs, recvcount, sendtype, root, comm);
collective_progress_loop(collective::scatterv, tag);
Expand All @@ -124,11 +129,6 @@ mpi_api::scatterv(const void *sendbuf, const int *sendcounts, const int *displs,
int
mpi_api::scatterv(const int *sendcounts, MPI_Datatype sendtype, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
SSTMACBacktrace("MPI_Scatterv");
mpi_api_debug(sprockit::dbg::mpi, "MPI_Scatterv(<...>,%s,%d,%s,%d,%s)",
type_str(sendtype).c_str(),
recvcount, type_str(recvtype).c_str(),
int(root), comm_str(comm).c_str());
return scatterv(NULL, sendcounts, NULL, sendtype, NULL, recvcount, recvtype, root, comm);
}

Expand Down
Loading

0 comments on commit a03bb5e

Please sign in to comment.