Skip to content

Commit

Permalink
[native] Ensure all ExchangeSource instances are created via make_sha…
Browse files Browse the repository at this point in the history
…red().
  • Loading branch information
spershin committed Nov 8, 2023
1 parent 3512441 commit e4d8fc1
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 30 deletions.
14 changes: 6 additions & 8 deletions presto-native-execution/presto_cpp/main/PrestoExchangeSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@

#include "presto_cpp/main/QueryContextManager.h"
#include "presto_cpp/main/common/Counters.h"
#include "presto_cpp/presto_protocol/presto_protocol.h"
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/StatsReporter.h"
#include "velox/common/testutil/TestValue.h"
#include "velox/exec/Operator.h"

using namespace facebook::velox;

Expand All @@ -46,7 +44,7 @@ std::string extractTaskId(const std::string& path) {

void onFinalFailure(
const std::string& errorMessage,
std::shared_ptr<exec::ExchangeQueue> queue) {
const std::shared_ptr<exec::ExchangeQueue>& queue) {
VLOG(1) << errorMessage;

queue->setError(errorMessage);
Expand All @@ -73,7 +71,7 @@ std::string bodyAsString(
PrestoExchangeSource::PrestoExchangeSource(
const folly::Uri& baseUri,
int destination,
std::shared_ptr<exec::ExchangeQueue> queue,
const std::shared_ptr<exec::ExchangeQueue>& queue,
memory::MemoryPool* pool,
folly::CPUThreadPoolExecutor* driverExecutor,
folly::IOThreadPoolExecutor* httpExecutor,
Expand Down Expand Up @@ -460,15 +458,15 @@ std::shared_ptr<PrestoExchangeSource> PrestoExchangeSource::getSelfPtr() {
}

// static
std::unique_ptr<exec::ExchangeSource> PrestoExchangeSource::create(
std::shared_ptr<exec::ExchangeSource> PrestoExchangeSource::create(
const std::string& url,
int destination,
std::shared_ptr<exec::ExchangeQueue> queue,
const std::shared_ptr<exec::ExchangeQueue>& queue,
memory::MemoryPool* pool,
folly::CPUThreadPoolExecutor* driverExecutor,
folly::IOThreadPoolExecutor* httpExecutor) {
if (strncmp(url.c_str(), "http://", 7) == 0) {
return std::make_unique<PrestoExchangeSource>(
return std::make_shared<PrestoExchangeSource>(
folly::Uri(url),
destination,
queue,
Expand All @@ -480,7 +478,7 @@ std::unique_ptr<exec::ExchangeSource> PrestoExchangeSource::create(
const auto clientCertAndKeyPath =
systemConfig->httpsClientCertAndKeyPath().value_or("");
const auto ciphers = systemConfig->httpsSupportedCiphers();
return std::make_unique<PrestoExchangeSource>(
return std::make_shared<PrestoExchangeSource>(
folly::Uri(url),
destination,
queue,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class PrestoExchangeSource : public velox::exec::ExchangeSource {
PrestoExchangeSource(
const folly::Uri& baseUri,
int destination,
std::shared_ptr<velox::exec::ExchangeQueue> queue,
const std::shared_ptr<velox::exec::ExchangeQueue>& queue,
velox::memory::MemoryPool* pool,
folly::CPUThreadPoolExecutor* driverExecutor,
folly::IOThreadPoolExecutor* httpExecutor,
Expand Down Expand Up @@ -100,10 +100,10 @@ class PrestoExchangeSource : public velox::exec::ExchangeSource {
uint32_t maxBytes,
uint32_t maxWaitSeconds) override;

static std::unique_ptr<ExchangeSource> create(
static std::shared_ptr<ExchangeSource> create(
const std::string& url,
int destination,
std::shared_ptr<velox::exec::ExchangeQueue> queue,
const std::shared_ptr<velox::exec::ExchangeQueue>& queue,
velox::memory::MemoryPool* pool,
folly::CPUThreadPoolExecutor* cpuExecutor,
folly::IOThreadPoolExecutor* ioExecutor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include "presto_cpp/main/common/Configs.h"
#include "presto_cpp/main/operators/BroadcastExchangeSource.h"
#include "presto_cpp/main/operators/BroadcastFactory.h"

namespace facebook::presto::operators {

Expand Down Expand Up @@ -68,11 +67,11 @@ folly::F14FastMap<std::string, int64_t> BroadcastExchangeSource::stats() const {
}

// static
std::unique_ptr<exec::ExchangeSource>
std::shared_ptr<exec::ExchangeSource>
BroadcastExchangeSource::createExchangeSource(
const std::string& url,
int destination,
std::shared_ptr<exec::ExchangeQueue> queue,
const std::shared_ptr<exec::ExchangeQueue>& queue,
memory::MemoryPool* pool) {
if (::strncmp(url.c_str(), "batch://", 8) != 0) {
return nullptr;
Expand All @@ -95,10 +94,10 @@ BroadcastExchangeSource::createExchangeSource(
}

auto fileSystemBroadcast = BroadcastFactory(broadcastFileInfo->filePath_);
return std::make_unique<BroadcastExchangeSource>(
return std::make_shared<BroadcastExchangeSource>(
uri.host(),
destination,
std::move(queue),
queue,
fileSystemBroadcast.createReader(std::move(broadcastFileInfo), pool),
pool);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class BroadcastExchangeSource : public velox::exec::ExchangeSource {
BroadcastExchangeSource(
const std::string& taskId,
int destination,
std::shared_ptr<velox::exec::ExchangeQueue> queue,
const std::shared_ptr<velox::exec::ExchangeQueue>& queue,
const std::shared_ptr<BroadcastFileReader>& reader,
velox::memory::MemoryPool* pool)
: ExchangeSource(taskId, destination, queue, pool), reader_(reader) {}
Expand All @@ -51,10 +51,10 @@ class BroadcastExchangeSource : public velox::exec::ExchangeSource {

/// Url format for this exchange source:
/// batch://<taskid>?broadcastInfo={fileInfos:[<fileInfo>]}.
static std::unique_ptr<ExchangeSource> createExchangeSource(
static std::shared_ptr<ExchangeSource> createExchangeSource(
const std::string& url,
int destination,
std::shared_ptr<velox::exec::ExchangeQueue> queue,
const std::shared_ptr<velox::exec::ExchangeQueue>& queue,
velox::memory::MemoryPool* pool);

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ std::optional<std::string> getSerializedShuffleInfo(folly::Uri& uri) {
} // namespace

// static
std::unique_ptr<velox::exec::ExchangeSource>
std::shared_ptr<velox::exec::ExchangeSource>
UnsafeRowExchangeSource::createExchangeSource(
const std::string& url,
int32_t destination,
std::shared_ptr<velox::exec::ExchangeQueue> queue,
const std::shared_ptr<velox::exec::ExchangeQueue>& queue,
velox::memory::MemoryPool* FOLLY_NONNULL pool) {
if (::strncmp(url.c_str(), "batch://", 8) != 0) {
return nullptr;
Expand All @@ -113,10 +113,10 @@ UnsafeRowExchangeSource::createExchangeSource(
"shuffle.name is not provided in config.properties to create a shuffle "
"interface.");
auto shuffleFactory = ShuffleInterfaceFactory::factory(shuffleName);
return std::make_unique<UnsafeRowExchangeSource>(
return std::make_shared<UnsafeRowExchangeSource>(
uri.host(),
destination,
std::move(queue),
queue,
shuffleFactory->createReader(
serializedShuffleInfo.value(), destination, pool),
pool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class UnsafeRowExchangeSource : public velox::exec::ExchangeSource {
UnsafeRowExchangeSource(
const std::string& taskId,
int destination,
std::shared_ptr<velox::exec::ExchangeQueue> queue,
const std::shared_ptr<velox::exec::ExchangeQueue>& queue,
const std::shared_ptr<ShuffleReader>& shuffle,
velox::memory::MemoryPool* FOLLY_NONNULL pool)
: ExchangeSource(taskId, destination, queue, pool), shuffle_(shuffle) {}
Expand All @@ -48,10 +48,10 @@ class UnsafeRowExchangeSource : public velox::exec::ExchangeSource {

/// url needs to follow below format:
/// batch://<taskid>?shuffleInfo=<serialized-shuffle-info>
static std::unique_ptr<velox::exec::ExchangeSource> createExchangeSource(
static std::shared_ptr<velox::exec::ExchangeSource> createExchangeSource(
const std::string& url,
int32_t destination,
std::shared_ptr<velox::exec::ExchangeQueue> queue,
const std::shared_ptr<velox::exec::ExchangeQueue>& queue,
velox::memory::MemoryPool* FOLLY_NONNULL pool);

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,17 @@ void registerExchangeSource(const std::string& shuffleName) {
[shuffleName](
const std::string& taskId,
int destination,
std::shared_ptr<exec::ExchangeQueue> queue,
const std::shared_ptr<exec::ExchangeQueue>& queue,
memory::MemoryPool* FOLLY_NONNULL pool)
-> std::unique_ptr<exec::ExchangeSource> {
-> std::shared_ptr<exec::ExchangeSource> {
if (strncmp(taskId.c_str(), "batch://", 8) == 0) {
auto uri = folly::Uri(taskId);
for (auto& pair : uri.getQueryParams()) {
if (pair.first == "shuffleInfo") {
return std::make_unique<UnsafeRowExchangeSource>(
return std::make_shared<UnsafeRowExchangeSource>(
taskId,
destination,
std::move(queue),
queue,
ShuffleInterfaceFactory::factory(shuffleName)
->createReader(pair.second, destination, pool),
pool);
Expand Down

0 comments on commit e4d8fc1

Please sign in to comment.