Skip to content

Commit

Permalink
push for test
Browse files Browse the repository at this point in the history
  • Loading branch information
Aurelius84 committed Aug 17, 2021
1 parent c25c494 commit 73831b0
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 20 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ endif()

cc_test(init_test SRCS init_test.cc DEPS device_context)

cc_library(device_event SRCS device_event.cc DEPS place enforce device_context)
cc_library(device_event SRCS device_event.cc DEPS place enforce device_context op_registry)
cc_library(device_event_gpu SRCS device_event_gpu.cc DEPS device_event)


Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/platform/device_event.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

namespace paddle {
namespace platform {
#define PD_CX __attribute__((visibility("default")))

EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes];
EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes];
EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes];
PD_CX EventCreateFunction DeviceEvent::event_creator_[MaxDeviceTypes];
PD_CX EventRecordFunction DeviceEvent::event_recorder_[MaxDeviceTypes];
PD_CX EventQueryFunction DeviceEvent::event_querier_[MaxDeviceTypes];

} // namespace platform
} // namespace paddle
21 changes: 16 additions & 5 deletions paddle/fluid/platform/device_event.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
Expand Down Expand Up @@ -79,7 +80,7 @@ class DeviceEvent {
event_querier_[type_],
platform::errors::Unavailable(
"event_querier_[%d] shall not be nullptr.", type_));
event_querier_[type_](this);
return event_querier_[type_](this);
}

void InitEvent(std::shared_ptr<void> event) { event_ = event; }
Expand Down Expand Up @@ -114,14 +115,24 @@ struct EventCreateFunctionRegisterer {
explicit EventCreateFunctionRegisterer(EventCreateFunction func) {
auto type_idx = DeviceTypeToId(device_type);
DeviceEvent::event_creator_[type_idx] = func;
VLOG(2) << "register creator " << type_idx << " with "
<< DeviceEvent::event_creator_[type_idx];
}
void Touch() {}
};
#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \
namespace { \
static EventCreateFunctionRegisterer<device_type> \
g_device_event_create_##type_idx(func); \

#define REGISTER_EVENT_CREATE_FUNCTION(device_type, func) \
static ::paddle::platform::EventCreateFunctionRegisterer<device_type> \
g_device_event_create_1(func); \
int touch_g_device_event_create_1() { \
g_device_event_create_1.Touch(); \
return 0; \
}

#define USE_EVENT(device_type) \
extern int touch_g_device_event_create_1(); \
UNUSED static int use_event_itself_1 = touch_g_device_event_create_1();

template <DeviceType device_type>
struct EventRecordFunctionRegisterer {
explicit EventRecordFunctionRegisterer(EventRecordFunction func) {
Expand Down
12 changes: 7 additions & 5 deletions paddle/fluid/platform/device_event_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
#include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/event.h"

#ifdef PADDLE_WITH_CUDA
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
struct CUDADeviceEventWrapper {
explicit CUDADeviceEventWrapper(const DeviceOption& dev_opt) {
PADDLE_ENFORCE_EQ(
Expand Down Expand Up @@ -56,10 +56,12 @@ bool DeviceEventQueryCUDA(const DeviceEvent* event) {
return wrapper->inner_event_.Query();
}

REGISTER_EVENT_CREATE_FUNCTION(DeviceType::CUDA, DeviceEventCreateCUDA)
REGISTER_EVENT_RECORD_FUNCTION(DeviceType::CUDA, DeviceEventRecordCUDA)
REGISTER_EVENT_QUERY_FUNCTION(DeviceType::CUDA, DeviceEventQueryCUDA)
// REGISTER_EVENT_RECORD_FUNCTION(DeviceType::CUDA, DeviceEventRecordCUDA)
// REGISTER_EVENT_QUERY_FUNCTION(DeviceType::CUDA, DeviceEventQueryCUDA)

#endif
} // namespace platform
} // namespace paddle

using ::paddle::platform::DeviceType::CUDA;
REGISTER_EVENT_CREATE_FUNCTION(CUDA, paddle::platform::DeviceEventCreateCUDA)
#endif
19 changes: 13 additions & 6 deletions paddle/fluid/platform/device_event_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,28 @@
#include "glog/logging.h"
#include "gtest/gtest.h"

USE_EVENT(1);

#ifdef PADDLE_WITH_CUDA
Test(DeviceEvent, GPU) {
TEST(DeviceEvent, GPU) {
VLOG(1) << "In Test";
using paddle::platform::CUDAPlace;
using paddle::platform::DeviceOption;
using paddle::platform::DeviceEvent;
using paddle::platform::DeviceContextPool;
using paddle::platform::DeviceType;

auto& pool = DeviceContextPool::Instance();
auto place = CUDAPlace(0);
auto* context = pool.get(place);
DeviceOption dev_opt(place.device);
auto* context = pool.Get(place);
int device_type = static_cast<int>(DeviceType::CUDA);
DeviceOption dev_opt(device_type, place.device);

ASSERT_NE(context, nullptr);
DeviceEvent event(dev_opt);
event.Record(place, context);
bool status = event.Query();
ASSERT_EQ(status, true);
ASSERT_NE(event.GetEvent().get(), nullptr);
// event.Record(place, context);
// bool status = event.Query();
// ASSERT_EQ(status, true);
}
#endif

0 comments on commit 73831b0

Please sign in to comment.