diff --git a/src/vt/objgroup/proxy/proxy_objgroup.impl.h b/src/vt/objgroup/proxy/proxy_objgroup.impl.h index a0779cf7fd..283338d695 100644 --- a/src/vt/objgroup/proxy/proxy_objgroup.impl.h +++ b/src/vt/objgroup/proxy/proxy_objgroup.impl.h @@ -41,6 +41,8 @@ //@HEADER */ +#include "vt/group/region/group_region.h" +#include "vt/termination/term_common.h" #if !defined INCLUDED_VT_OBJGROUP_PROXY_PROXY_OBJGROUP_IMPL_H #define INCLUDED_VT_OBJGROUP_PROXY_PROXY_OBJGROUP_IMPL_H @@ -139,24 +141,38 @@ Proxy::multicast(GroupType type, Params&&... params) const{ template template typename Proxy::PendingSendType Proxy::multicast( - group::region::Region::RegionUPtrType&& nodes, Params&&... params) const { + group::region::Region::RegionUPtrType&& nodes, Params&&... params +) const { vtAssert( not dynamic_cast(nodes.get()), - "multicast: range of nodes is not supported for ShallowList!" - ); + "multicast: range of nodes is not supported for ShallowList!"); nodes->sort(); auto& range = nodes->makeList(); auto groupID = theGroup()->GetTempGroupForRange(range); if (!groupID.has_value()) { - groupID = theGroup()->newGroup( - std::move(nodes), []([[maybe_unused]] GroupType type) {} - ); - theGroup()->AddNewTempGroup(range, groupID.value()); + return typename Proxy::PendingSendType{ + theTerm()->getCurrentEpoch(), + [nodes_range = range, this, + args = std::make_tuple(std::forward(params)...)] { + std::apply( + [&, this](auto&&... unpackedArgs) { + auto id = + theGroup()->newGroup(std::make_unique(nodes_range), [&](GroupType type) { + multicast( + type, std::forward(unpackedArgs)...); + }); + theGroup()->AddNewTempGroup(nodes_range, id); + }, + std::move(args)); + }}; + } else { + return multicast(groupID.value(), std::forward(params)...); } - return multicast(groupID.value(), std::forward(params)...); + // Silence nvcc warning (no longer needed for CUDA 11.7 and up) + return typename Proxy::PendingSendType{std::nullptr_t{}}; } template