From c9b9df22ab42be37aef908bfd445bceb0ed2c4d9 Mon Sep 17 00:00:00 2001 From: Rijin-N Date: Thu, 14 Nov 2024 10:17:47 +0530 Subject: [PATCH] [native] Add Arrow Flight Connector The native Arrow Flight connector can be used to connect to any Arrow Flight enabled Data Source. The metadata layer is handled by the Presto coordinator and does not need to be re-implemented in C++. Any Java connector that inherits from `presto-base-arrow-flight` can use this connector as it's counterpart for the Prestissimo layer. Different Arrow-Flight enabled data sources can differ in authentication styles. A plugin-style interface is provided to handle such cases with custom authentication code by extending `arrow_flight::auth::Authenticator`. RFC: https://github.com/prestodb/rfcs/blob/main/RFC-0004-arrow-flight-connector.md#prestissimo-implementation Co-authored-by: Ashwin Kumar Co-authored-by: Rijin-N Co-authored-by: Nischay Yadav --- presto-native-execution/CMakeLists.txt | 2 + presto-native-execution/Makefile | 3 + presto-native-execution/README.md | 9 + .../presto_cpp/main/CMakeLists.txt | 2 + .../presto_cpp/main/PrestoServer.cpp | 3 + .../presto_cpp/main/SystemConnector.cpp | 3 +- .../presto_cpp/main/SystemConnector.h | 4 +- .../presto_cpp/main/TaskManager.cpp | 5 +- .../presto_cpp/main/TaskManager.h | 1 + .../presto_cpp/main/connectors/CMakeLists.txt | 20 + .../main/connectors/ConnectorRegistration.cpp | 40 ++ .../main/connectors/ConnectorRegistration.h | 20 + .../arrow_flight/ArrowFlightConnector.cpp | 184 +++++++++ .../arrow_flight/ArrowFlightConnector.h | 198 ++++++++++ .../ArrowPrestoToVeloxConnector.cpp | 64 ++++ .../ArrowPrestoToVeloxConnector.h | 49 +++ .../connectors/arrow_flight/CMakeLists.txt | 43 +++ .../connectors/arrow_flight/FlightConfig.cpp | 45 +++ .../connectors/arrow_flight/FlightConfig.h | 56 +++ .../main/connectors/arrow_flight/Macros.h | 50 +++ .../arrow_flight/auth/Authenticator.cpp | 48 +++ .../arrow_flight/auth/Authenticator.h | 85 +++++ .../arrow_flight/auth/CMakeLists.txt | 15 + .../tests/ArrowFlightConnectorAuthTest.cpp | 251 +++++++++++++ .../ArrowFlightConnectorDataTypeTest.cpp | 355 ++++++++++++++++++ .../tests/ArrowFlightConnectorTest.cpp | 197 ++++++++++ .../tests/ArrowFlightConnectorTlsTest.cpp | 138 +++++++ .../arrow_flight/tests/CMakeLists.txt | 44 +++ .../arrow_flight/tests/FlightConfigTest.cpp | 49 +++ .../tests/TestFlightServerTest.cpp | 84 +++++ .../arrow_flight/tests/data/README.md | 6 + .../tests/data/generate_tls_certs.sh | 40 ++ .../arrow_flight/tests/data/tls_certs/ca.crt | 22 ++ .../tests/data/tls_certs/server.crt | 22 ++ .../tests/data/tls_certs/server.key | 28 ++ .../arrow_flight/tests/utils/CMakeLists.txt | 19 + .../tests/utils/FlightConnectorTestBase.cpp | 85 +++++ .../tests/utils/FlightConnectorTestBase.h | 88 +++++ .../tests/utils/FlightPlanBuilder.cpp | 43 +++ .../tests/utils/FlightPlanBuilder.h | 35 ++ .../tests/utils/TestFlightServer.cpp | 34 ++ .../tests/utils/TestFlightServer.h | 48 +++ .../arrow_flight/tests/utils/Utils.cpp | 90 +++++ .../arrow_flight/tests/utils/Utils.h | 54 +++ .../main/types/PrestoToVeloxConnector.cpp | 9 +- .../main/types/PrestoToVeloxConnector.h | 16 +- .../main/types/PrestoToVeloxSplit.cpp | 6 +- .../main/types/PrestoToVeloxSplit.h | 3 +- .../presto_cpp/presto_protocol/Makefile | 9 + .../ArrowFlightConnectorProtocol.h | 29 ++ .../presto_protocol-json-cpp.mustache | 150 ++++++++ .../presto_protocol-json-hpp.mustache | 76 ++++ .../presto_protocol_arrow_flight.cpp | 215 +++++++++++ .../presto_protocol_arrow_flight.h | 82 ++++ .../presto_protocol_arrow_flight.yml | 40 ++ .../special/ArrowTransactionHandle.cpp.inc | 30 ++ .../special/ArrowTransactionHandle.hpp.inc | 28 ++ .../core/presto_protocol_core.cpp | 1 + .../core/presto_protocol_core.h | 2 +- .../core/presto_protocol_core.yml | 4 + .../ConnectorTransactionHandle.cpp.inc | 1 + .../presto_protocol/presto_protocol.cpp | 1 + .../presto_protocol/presto_protocol.h | 1 + .../presto_protocol/presto_protocol.yml | 8 + .../scripts/setup-adapters.sh | 66 ++++ 65 files changed, 3444 insertions(+), 14 deletions(-) create mode 100644 presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/FlightConfigTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestFlightServerTest.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md create mode 100755 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp create mode 100644 presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc create mode 100644 presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index d5001dde70a74..6ac789b73e3a1 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -63,6 +63,8 @@ option(PRESTO_ENABLE_TESTING "Enable tests" ON) option(PRESTO_ENABLE_JWT "Enable JWT (JSON Web Token) authentication" OFF) +option(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR "Enable Arrow Flight connector" OFF) + # Set all Velox options below add_compile_definitions(FOLLY_HAVE_INT128_T=1) diff --git a/presto-native-execution/Makefile b/presto-native-execution/Makefile index f3fb5f709f4d5..7cd37a714867c 100644 --- a/presto-native-execution/Makefile +++ b/presto-native-execution/Makefile @@ -45,6 +45,9 @@ endif ifneq ($(PRESTO_MEMORY_CHECKER_TYPE),) EXTRA_CMAKE_FLAGS += -DPRESTO_MEMORY_CHECKER_TYPE=$(PRESTO_MEMORY_CHECKER_TYPE) endif +ifneq ($(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR),) + EXTRA_CMAKE_FLAGS += -DPRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR=$(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) +endif CMAKE_FLAGS := -DTREAT_WARNINGS_AS_ERRORS=${TREAT_WARNINGS_AS_ERRORS} CMAKE_FLAGS += -DENABLE_ALL_WARNINGS=${ENABLE_WALL} diff --git a/presto-native-execution/README.md b/presto-native-execution/README.md index cccebfcfb8d03..36a8659f6cbac 100644 --- a/presto-native-execution/README.md +++ b/presto-native-execution/README.md @@ -115,6 +115,15 @@ follow these steps: * For development, use `make debug` to build a non-optimized debug version. * Use `make unittest` to build and run tests. +#### Arrow Flight Connector +To enable Arrow Flight connector support, set the environment variable: +`PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR = "ON"`. + +The Arrow Flight connector requires the Arrow Flight library. You can install this dependency +by running the following script from the `presto/presto-native-execution` directory: + +`./scripts/setup-adapters.sh arrow_flight` + ### Makefile Targets A reminder of the available Makefile targets can be obtained using `make help` ``` diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index b00fee32927b6..91a65c10f8e6b 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(types) add_subdirectory(http) add_subdirectory(common) add_subdirectory(thrift) +add_subdirectory(connectors) add_library( presto_server_lib @@ -48,6 +49,7 @@ target_link_libraries( presto_common presto_exception presto_function_metadata + presto_connector presto_http presto_operators presto_velox_conversion diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index f79f1b5ee1fce..9288df00c6a9d 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -26,6 +26,7 @@ #include "presto_cpp/main/common/ConfigReader.h" #include "presto_cpp/main/common/Counters.h" #include "presto_cpp/main/common/Utils.h" +#include "presto_cpp/main/connectors/ConnectorRegistration.h" #include "presto_cpp/main/http/HttpConstants.h" #include "presto_cpp/main/http/filters/AccessLogFilter.h" #include "presto_cpp/main/http/filters/HttpEndpointLatencyFilter.h" @@ -280,6 +281,8 @@ void PrestoServer::run() { registerPrestoToVeloxConnector( std::make_unique("$system@system")); + presto::connector::registerAllPrestoConnectors(); + velox::exec::OutputBufferManager::initialize({}); initializeVeloxMemory(); initializeThreadPools(); diff --git a/presto-native-execution/presto_cpp/main/SystemConnector.cpp b/presto-native-execution/presto_cpp/main/SystemConnector.cpp index 7622d203e8689..d6d300384c7db 100644 --- a/presto-native-execution/presto_cpp/main/SystemConnector.cpp +++ b/presto-native-execution/presto_cpp/main/SystemConnector.cpp @@ -351,7 +351,8 @@ std::unique_ptr SystemPrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* const connectorSplit, - const protocol::SplitContext* splitContext) const { + const protocol::SplitContext* splitContext, + const std::map& extraCredentials) const { auto systemSplit = dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( systemSplit, "Unexpected split type {}", connectorSplit->_type); diff --git a/presto-native-execution/presto_cpp/main/SystemConnector.h b/presto-native-execution/presto_cpp/main/SystemConnector.h index 52d9df595f736..c615fa6b8a917 100644 --- a/presto-native-execution/presto_cpp/main/SystemConnector.h +++ b/presto-native-execution/presto_cpp/main/SystemConnector.h @@ -185,7 +185,9 @@ class SystemPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const final; + const protocol::SplitContext* splitContext, + const std::map& extraCredentials = {}) + const final; std::unique_ptr toVeloxColumnHandle( const protocol::ColumnHandle* column, diff --git a/presto-native-execution/presto_cpp/main/TaskManager.cpp b/presto-native-execution/presto_cpp/main/TaskManager.cpp index 237ca4c334e9e..d7a204926d244 100644 --- a/presto-native-execution/presto_cpp/main/TaskManager.cpp +++ b/presto-native-execution/presto_cpp/main/TaskManager.cpp @@ -472,6 +472,7 @@ std::unique_ptr TaskManager::createOrUpdateTask( updateRequest.sources, updateRequest.outputIds, summarize, + updateRequest.extraCredentials, std::move(queryCtx), startProcessCpuTime); } @@ -493,6 +494,7 @@ std::unique_ptr TaskManager::createOrUpdateBatchTask( updateRequest.sources, updateRequest.outputIds, summarize, + updateRequest.extraCredentials, std::move(queryCtx), startProcessCpuTime); } @@ -503,6 +505,7 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( const std::vector& sources, const protocol::OutputBuffers& outputBuffers, bool summarize, + const std::map& extraCredentials, std::shared_ptr queryCtx, long startProcessCpuTime) { std::shared_ptr execTask; @@ -606,7 +609,7 @@ std::unique_ptr TaskManager::createOrUpdateTaskImpl( // Keep track of the max sequence for this batch of splits. long maxSplitSequenceId{-1}; for (const auto& protocolSplit : source.splits) { - auto split = toVeloxSplit(protocolSplit); + auto split = toVeloxSplit(protocolSplit, extraCredentials); if (split.hasConnectorSplit()) { maxSplitSequenceId = std::max(maxSplitSequenceId, protocolSplit.sequenceId); diff --git a/presto-native-execution/presto_cpp/main/TaskManager.h b/presto-native-execution/presto_cpp/main/TaskManager.h index d2e931cd357ad..a365f8b70061b 100644 --- a/presto-native-execution/presto_cpp/main/TaskManager.h +++ b/presto-native-execution/presto_cpp/main/TaskManager.h @@ -183,6 +183,7 @@ class TaskManager { const std::vector& sources, const protocol::OutputBuffers& outputBuffers, bool summarize, + const std::map& extraCredentials, std::shared_ptr queryCtx, long startProcessCpuTime); diff --git a/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt new file mode 100644 index 0000000000000..5e4e2b04e5e02 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/CMakeLists.txt @@ -0,0 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +if(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + add_subdirectory(arrow_flight) +endif() + +add_library(presto_connector ConnectorRegistration.cpp) + +if(PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + target_link_libraries(presto_connector presto_flight_connector) +endif() diff --git a/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.cpp b/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.cpp new file mode 100644 index 0000000000000..ac2cc164c12ae --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.cpp @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/ConnectorRegistration.h" + +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h" +#endif + +namespace facebook::presto::connector { + +void registerAllPrestoConnectors() { +#ifdef PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR + registerPrestoToVeloxConnector( + std::make_unique< + presto::connector::arrow_flight::ArrowPrestoToVeloxConnector>( + "arrow-flight")); + + if (!velox::connector::hasConnectorFactory( + presto::connector::arrow_flight::ArrowFlightConnectorFactory:: + kArrowFlightConnectorName)) { + velox::connector::registerConnectorFactory( + std::make_shared< + presto::connector::arrow_flight::ArrowFlightConnectorFactory>()); + } +#endif +} + +} // namespace facebook::presto::connector diff --git a/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.h b/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.h new file mode 100644 index 0000000000000..187362876247f --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/ConnectorRegistration.h @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace facebook::presto::connector { + +void registerAllPrestoConnectors(); + +} // namespace facebook::presto::connector diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp new file mode 100644 index 0000000000000..06a192e1b0191 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.cpp @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" + +#include +#include "arrow/c/abi.h" +#include "arrow/c/bridge.h" +#include "presto_cpp/main/common/ConfigReader.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/vector/arrow/Bridge.h" + +namespace facebook::presto::connector::arrow_flight { + +using namespace arrow::flight; +using namespace velox; +using namespace velox::connector; + +// Wrapper for CallOptions which does not add any member variables, +// but provides a write-only interface for adding call headers. +class CallOptionsAddHeaders : public FlightCallOptions, public AddCallHeaders { + public: + void AddHeader(const std::string& key, const std::string& value) override { + headers.emplace_back(key, value); + } +}; + +std::optional ArrowFlightConnector::getDefaultLocation( + const std::shared_ptr& config) { + auto defaultHost = config->defaultServerHostname(); + auto defaultPort = config->defaultServerPort(); + if (!defaultHost.has_value() || !defaultPort.has_value()) { + return std::nullopt; + } + + bool defaultSslEnabled = config->defaultServerSslEnabled(); + AFC_RETURN_OR_RAISE( + defaultSslEnabled + ? Location::ForGrpcTls(defaultHost.value(), defaultPort.value()) + : Location::ForGrpcTcp(defaultHost.value(), defaultPort.value())); +} + +std::shared_ptr +ArrowFlightConnector::initClientOpts( + const std::shared_ptr& config) { + auto clientOpts = std::make_shared(); + clientOpts->disable_server_verification = !config->serverVerify(); + + auto certPath = config->serverSslCertificate(); + if (certPath.has_value()) { + std::ifstream file(certPath.value()); + VELOX_CHECK(file.is_open(), "Could not open TLS certificate"); + std::string cert( + (std::istreambuf_iterator(file)), + (std::istreambuf_iterator())); + clientOpts->tls_root_certs = cert; + } + + return clientOpts; +} + +FlightDataSource::FlightDataSource( + const RowTypePtr& outputType, + const std::unordered_map>& + columnHandles, + std::shared_ptr authenticator, + memory::MemoryPool* pool, + const std::shared_ptr& flightConfig, + const std::shared_ptr& clientOpts, + const std::optional& defaultLocation) + : outputType_{outputType}, + authenticator_{std::move(authenticator)}, + pool_{pool}, + flightConfig_{flightConfig}, + clientOpts_{clientOpts}, + defaultLocation_{defaultLocation} { + // columnMapping_ contains the real column names in the expected order. + // This is later used by projectOutputColumns to filter out unnecessary + // columns from the fetched chunk. + columnMapping_.reserve(outputType_->size()); + + for (const auto& columnName : outputType_->names()) { + auto it = columnHandles.find(columnName); + VELOX_CHECK( + it != columnHandles.end(), + "missing columnHandle for column '{}'", + columnName); + + auto handle = std::dynamic_pointer_cast(it->second); + VELOX_CHECK_NOT_NULL( + handle, + "handle for column '{}' is not an FlightColumnHandle", + columnName); + + columnMapping_.push_back(handle->name()); + } +} + +void FlightDataSource::addSplit(std::shared_ptr split) { + auto flightSplit = std::dynamic_pointer_cast(split); + VELOX_CHECK(flightSplit, "FlightDataSource received wrong type of split"); + + FlightEndpoint flightEndpoint; + AFC_ASSIGN_OR_RAISE( + flightEndpoint, + arrow::flight::FlightEndpoint::Deserialize( + flightSplit->flightEndpointBytes)); + + Location loc; + if (!flightEndpoint.locations.empty()) { + loc = flightEndpoint.locations[0]; + } else { + VELOX_CHECK( + defaultLocation_.has_value(), + "Split has empty Location list, but default host or port is missing"); + loc = defaultLocation_.value(); + } + VELOX_CHECK_NOT_NULL(clientOpts_, "FlightClientOptions is not initialized"); + + AFC_ASSIGN_OR_RAISE(auto client, FlightClient::Connect(loc, *clientOpts_)); + + CallOptionsAddHeaders callOptsAddHeaders{}; + authenticator_->authenticateClient( + client, flightSplit->extraCredentials, callOptsAddHeaders); + + auto readerResult = client->DoGet(callOptsAddHeaders, flightEndpoint.ticket); + AFC_ASSIGN_OR_RAISE(currentReader_, readerResult); +} + +std::optional FlightDataSource::next( + uint64_t size, + velox::ContinueFuture& /* unused */) { + VELOX_CHECK_NOT_NULL(currentReader_, "Missing split, call addSplit() first"); + + AFC_ASSIGN_OR_RAISE(auto chunk, currentReader_->Next()); + + // Null values in the chunk indicates that the Flight stream is complete. + if (!chunk.data) { + currentReader_ = nullptr; + return nullptr; + } + + // Extract only required columns from the record batch as a velox RowVector. + auto output = projectOutputColumns(chunk.data); + + completedRows_ += output->size(); + completedBytes_ += output->inMemoryBytes(); + return output; +} + +RowVectorPtr FlightDataSource::projectOutputColumns( + const std::shared_ptr& input) { + std::vector children; + children.reserve(columnMapping_.size()); + + // Extract and convert desired columns in the correct order. + for (const auto& name : columnMapping_) { + auto column = input->GetColumnByName(name); + VELOX_CHECK_NOT_NULL(column, "column with name '{}' not found", name); + ArrowArray array; + ArrowSchema schema; + AFC_RAISE_NOT_OK(arrow::ExportArray(*column, &array, &schema)); + children.push_back(importFromArrowAsOwner(schema, array, pool_)); + } + + return std::make_shared( + pool_, + outputType_, + BufferPtr() /*nulls*/, + input->num_rows(), + std::move(children)); +} + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h new file mode 100644 index 0000000000000..8219cfc9c40a4 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "arrow/flight/api.h" +#include "presto_cpp/main/connectors/arrow_flight/FlightConfig.h" +#include "presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h" +#include "velox/connectors/Connector.h" + +namespace facebook::presto::connector::arrow_flight { + +class FlightTableHandle : public velox::connector::ConnectorTableHandle { + public: + explicit FlightTableHandle(const std::string& connectorId) + : ConnectorTableHandle(connectorId) {} +}; + +struct FlightSplit : public velox::connector::ConnectorSplit { + /// @param connectorId + /// @param flightEndpointBytes Serialized `FlightEndpoint` + /// @param extraCredentials Extra credentials for authentication + FlightSplit( + const std::string& connectorId, + const std::string& flightEndpointBytes, + const std::map& extraCredentials = {}) + : ConnectorSplit(connectorId), + flightEndpointBytes(flightEndpointBytes), + extraCredentials(extraCredentials) {} + + const std::string flightEndpointBytes; + std::map extraCredentials; +}; + +class FlightColumnHandle : public velox::connector::ColumnHandle { + public: + explicit FlightColumnHandle(const std::string& columnName) + : columnName_(columnName) {} + + const std::string& name() { + return columnName_; + } + + private: + std::string columnName_; +}; + +class FlightDataSource : public velox::connector::DataSource { + public: + FlightDataSource( + const velox::RowTypePtr& outputType, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + std::shared_ptr authenticator, + velox::memory::MemoryPool* pool, + const std::shared_ptr& flightConfig, + const std::shared_ptr& clientOpts, + const std::optional& defaultLocation = + std::nullopt); + + void addSplit( + std::shared_ptr split) override; + + std::optional next( + uint64_t size, + velox::ContinueFuture& /* unused */) override; + + void addDynamicFilter( + velox::column_index_t outputChannel, + const std::shared_ptr& filter) override { + VELOX_NYI("This connector doesn't support dynamic filters"); + } + + uint64_t getCompletedBytes() override { + return completedBytes_; + } + + uint64_t getCompletedRows() override { + return completedRows_; + } + + std::unordered_map runtimeStats() + override { + return {}; + } + + private: + /// Convert an Arrow record batch to Velox RowVector. + /// Process only those columns that are present in outputType_. + velox::RowVectorPtr projectOutputColumns( + const std::shared_ptr& input); + + velox::RowTypePtr outputType_; + std::vector columnMapping_; + std::unique_ptr currentReader_; + uint64_t completedRows_ = 0; + uint64_t completedBytes_ = 0; + std::shared_ptr authenticator_; + velox::memory::MemoryPool* const pool_; + const std::shared_ptr flightConfig_; + const std::shared_ptr clientOpts_; + const std::optional defaultLocation_; +}; + +class ArrowFlightConnector : public velox::connector::Connector { + public: + explicit ArrowFlightConnector( + const std::string& id, + std::shared_ptr config, + const char* authenticatorName = nullptr) + : Connector(id), + flightConfig_(std::make_shared(config)), + clientOpts_(initClientOpts(flightConfig_)), + defaultLocation_(getDefaultLocation(flightConfig_)), + authenticator_(auth::getAuthenticatorFactory( + authenticatorName + ? authenticatorName + : flightConfig_->authenticatorName()) + ->newAuthenticator(config)) {} + + std::unique_ptr createDataSource( + const velox::RowTypePtr& outputType, + const std::shared_ptr& + tableHandle, + const std::unordered_map< + std::string, + std::shared_ptr>& columnHandles, + velox::connector::ConnectorQueryCtx* ctx) override { + return std::make_unique( + outputType, + columnHandles, + authenticator_, + ctx->memoryPool(), + flightConfig_, + clientOpts_, + defaultLocation_); + } + + std::unique_ptr createDataSink( + velox::RowTypePtr inputType, + std::shared_ptr + connectorInsertTableHandle, + velox::connector::ConnectorQueryCtx* connectorQueryCtx, + velox::connector::CommitStrategy commitStrategy) override { + VELOX_NYI("Flight connector does not support DataSink"); + } + + private: + // Returns the default location specified in the FlightConfig. + // Returns nullopt if either host or port is missing. + static std::optional getDefaultLocation( + const std::shared_ptr& config); + + static std::shared_ptr initClientOpts( + const std::shared_ptr& config); + + const std::shared_ptr flightConfig_; + const std::shared_ptr clientOpts_; + const std::optional defaultLocation_; + const std::shared_ptr authenticator_; +}; + +class ArrowFlightConnectorFactory : public velox::connector::ConnectorFactory { + public: + static constexpr const char* kArrowFlightConnectorName = "arrow-flight"; + + ArrowFlightConnectorFactory() : ConnectorFactory(kArrowFlightConnectorName) {} + + explicit ArrowFlightConnectorFactory( + const char* name, + const char* authenticatorName = nullptr) + : ConnectorFactory(name), authenticatorName_(authenticatorName) {} + + std::shared_ptr newConnector( + const std::string& id, + std::shared_ptr config, + folly::Executor* ioExecutor = nullptr, + folly::Executor* cpuExecutor = nullptr) override { + return std::make_shared( + id, config, authenticatorName_); + } + + private: + const char* authenticatorName_{nullptr}; +}; + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp new file mode 100644 index 0000000000000..f0af5f5485912 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.cpp @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h" +#include "folly/base64.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h" + +namespace facebook::presto::connector::arrow_flight { + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* const connectorSplit, + const protocol::SplitContext* splitContext, + const std::map& extraCredentials) const { + auto arrowSplit = + dynamic_cast(connectorSplit); + VELOX_CHECK_NOT_NULL( + arrowSplit, "Unexpected split type {}", connectorSplit->_type); + return std::make_unique( + catalogId, arrowSplit->flightEndpointBytes, extraCredentials); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const { + auto arrowColumn = + dynamic_cast(column); + VELOX_CHECK_NOT_NULL( + arrowColumn, "Unexpected column handle type {}", column->_type); + return std::make_unique( + arrowColumn->columnName); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, + std::unordered_map< + std::string, + std::shared_ptr>& assignments) const { + return std::make_unique( + tableHandle.connectorId); +} + +std::unique_ptr +ArrowPrestoToVeloxConnector::createConnectorProtocol() const { + return std::make_unique(); +} + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h new file mode 100644 index 0000000000000..74284cf74aa19 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/ArrowPrestoToVeloxConnector.h @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/main/types/PrestoToVeloxConnector.h" + +namespace facebook::presto::connector::arrow_flight { + +class ArrowPrestoToVeloxConnector final : public PrestoToVeloxConnector { + public: + explicit ArrowPrestoToVeloxConnector(std::string connectorName) + : PrestoToVeloxConnector(std::move(connectorName)) {} + + std::unique_ptr toVeloxSplit( + const protocol::ConnectorId& catalogId, + const protocol::ConnectorSplit* connectorSplit, + const protocol::SplitContext* splitContext, + const std::map& extraCredentials = {}) + const final; + + std::unique_ptr toVeloxColumnHandle( + const protocol::ColumnHandle* column, + const TypeParser& typeParser) const final; + + std::unique_ptr toVeloxTableHandle( + const protocol::TableHandle& tableHandle, + const VeloxExprConverter& exprConverter, + const TypeParser& typeParser, + std::unordered_map< + std::string, + std::shared_ptr>& assignments) + const final; + + std::unique_ptr createConnectorProtocol() + const final; +}; + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt new file mode 100644 index 0000000000000..d2d92fd036524 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/CMakeLists.txt @@ -0,0 +1,43 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +find_package(Arrow REQUIRED) +find_package(PkgConfig REQUIRED) +pkg_check_modules(ARROW_FLIGHT REQUIRED IMPORTED_TARGET GLOBAL arrow-flight) + +if(NOT ARROW_FLIGHT_FOUND) + message(FATAL_ERROR "Arrow Flight package not found") +endif() + +set(ArrowFlight_FOUND TRUE) +set(ArrowFlight_INCLUDE_DIRS ${ARROW_FLIGHT_INCLUDE_DIRS}) +set(ArrowFlight_LIBRARIES ${ARROW_FLIGHT_LIBRARIES}) +include_directories(${ArrowFlight_INCLUDE_DIRS}) + +add_subdirectory(auth) + +add_library(presto_flight_connector_utils INTERFACE Macros.h) +target_link_libraries(presto_flight_connector_utils INTERFACE velox_exception) + +add_library( + presto_flight_connector OBJECT + ArrowFlightConnector.cpp ArrowPrestoToVeloxConnector.cpp FlightConfig.cpp) + +target_compile_definitions(presto_flight_connector + PUBLIC PRESTO_ENABLE_ARROW_FLIGHT_CONNECTOR) + +target_link_libraries( + presto_flight_connector velox_connector PkgConfig::ARROW_FLIGHT + presto_flight_connector_utils presto_flight_connector_auth presto_types) + +if(PRESTO_ENABLE_TESTING) + add_subdirectory(tests) +endif() diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.cpp new file mode 100644 index 0000000000000..bf52d69e22df8 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.cpp @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/FlightConfig.h" + +namespace facebook::presto::connector::arrow_flight { + +std::string FlightConfig::authenticatorName() { + return config_->get(kAuthenticatorName, "none"); +} + +std::optional FlightConfig::defaultServerHostname() { + return static_cast>( + config_->get(kDefaultServerHost)); +} + +std::optional FlightConfig::defaultServerPort() { + return static_cast>( + config_->get(kDefaultServerPort)); +} + +bool FlightConfig::defaultServerSslEnabled() { + return config_->get(kDefaultServerSslEnabled, false); +} + +bool FlightConfig::serverVerify() { + return config_->get(kServerVerify, true); +} + +std::optional FlightConfig::serverSslCertificate() { + return static_cast>( + config_->get(kServerSslCertificate)); +} + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.h new file mode 100644 index 0000000000000..6969030f24f16 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/FlightConfig.h @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/config/Config.h" + +namespace facebook::presto::connector::arrow_flight { + +class FlightConfig { + public: + explicit FlightConfig(std::shared_ptr config) + : config_{config} {} + + static constexpr const char* kAuthenticatorName = + "arrow-flight.authenticator.name"; + + static constexpr const char* kDefaultServerHost = "arrow-flight.server"; + + static constexpr const char* kDefaultServerPort = "arrow-flight.server.port"; + + static constexpr const char* kDefaultServerSslEnabled = + "arrow-flight.server-ssl-enabled"; + + static constexpr const char* kServerVerify = "arrow-flight.server.verify"; + + static constexpr const char* kServerSslCertificate = + "arrow-flight.server-ssl-certificate"; + + std::string authenticatorName(); + + std::optional defaultServerHostname(); + + std::optional defaultServerPort(); + + bool defaultServerSslEnabled(); + + bool serverVerify(); + + std::optional serverSslCertificate(); + + private: + std::shared_ptr config_; +}; + +} // namespace facebook::presto::connector::arrow_flight diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h new file mode 100644 index 0000000000000..5ab725e582cc6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/Macros.h @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/common/base/Exceptions.h" + +// Macros for dealing with arrow::Status and arrow::Result objects +// and converting them to velox exceptions. + +/// Raise a Velox exception if status is not OK. +/// Counterpart of ARROW_RETURN_NOT_OK. +#define AFC_RAISE_NOT_OK(status) \ + do { \ + ::arrow::Status __s = ::arrow::internal::GenericToStatus(status); \ + VELOX_CHECK(__s.ok(), __s.message()); \ + } while (false) + +#define AFC_ASSIGN_OR_RAISE_IMPL(result_name, lhs, rexpr) \ + auto&& result_name = (rexpr); \ + VELOX_CHECK((result_name).ok(), (result_name).status().message()); \ + lhs = std::move(result_name).ValueUnsafe(); + +/// Raise a Velox exception if expr doesn't return an OK result, +/// else unwrap the value and assign it to `lhs`. +/// `std::move`s its right hand operand. +/// Counterpart of ARROW_ASSIGN_OR_RAISE. +#define AFC_ASSIGN_OR_RAISE(lhs, rexpr) \ + AFC_ASSIGN_OR_RAISE_IMPL( \ + ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), lhs, rexpr); + +/// Raise a Velox exception if rexpr doesn't return an OK result, +/// else unwrap the value and return it. +/// `std::move`s its right hand operand. +#define AFC_RETURN_OR_RAISE(rexpr) \ + do { \ + auto&& __r = (rexpr); \ + VELOX_CHECK(__r.ok(), __r.status().message()); \ + return std::move(__r).ValueUnsafe(); \ + } while (false) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp new file mode 100644 index 0000000000000..3b30123916653 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.cpp @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto::connector::arrow_flight::auth { +namespace { +auto& authenticatorFactories() { + static std::unordered_map> + factories; + return factories; +} +} // namespace + +bool registerAuthenticatorFactory( + std::shared_ptr factory) { + bool ok = authenticatorFactories().insert({factory->name(), factory}).second; + VELOX_CHECK( + ok, + "Flight AuthenticatorFactory with name {} is already registered", + factory->name()); + return true; +} + +std::shared_ptr getAuthenticatorFactory( + const std::string& name) { + auto it = authenticatorFactories().find(name); + VELOX_CHECK( + it != authenticatorFactories().end(), + "Flight AuthenticatorFactory with name {} not registered", + name); + return it->second; +} + +AFC_REGISTER_AUTH_FACTORY(std::make_shared()) + +} // namespace facebook::presto::connector::arrow_flight::auth diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h new file mode 100644 index 0000000000000..543535c4ca6b6 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/Authenticator.h @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "arrow/flight/api.h" +#include "velox/common/config/Config.h" + +namespace facebook::presto::connector::arrow_flight::auth { + +class Authenticator { + public: + /// @brief Override this method to define implementation-specific + /// authentication This could be through client->Authenticate, or + /// client->AuthenticateBasicToken or any other custom strategy + /// @param client the Flight client which is to be authenticated + /// @param extraCredentials extra credential data used for authentication + /// @param headerWriter write-only object used to set authentication headers + virtual void authenticateClient( + std::unique_ptr& client, + const std::map& extraCredentials, + arrow::flight::AddCallHeaders& headerWriter) = 0; +}; + +class AuthenticatorFactory { + public: + explicit AuthenticatorFactory(std::string_view name) : name_{name} {} + + const std::string& name() const { + return name_; + } + + virtual std::shared_ptr newAuthenticator( + std::shared_ptr config) = 0; + + private: + std::string name_; +}; + +bool registerAuthenticatorFactory( + std::shared_ptr factory); + +std::shared_ptr getAuthenticatorFactory( + const std::string& name); + +#define AFC_REGISTER_AUTH_FACTORY(factory) \ + namespace { \ + static bool FB_ANONYMOUS_VARIABLE(g_ConnectorFactory) = ::facebook::presto:: \ + connector::arrow_flight::auth::registerAuthenticatorFactory((factory)); \ + } + +class NoOpAuthenticator : public Authenticator { + public: + void authenticateClient( + std::unique_ptr& client, + const std::map& extraCredentials, + arrow::flight::AddCallHeaders& headerWriter) override {} +}; + +class NoOpAuthenticatorFactory : public AuthenticatorFactory { + public: + static constexpr const std::string_view kNoOpAuthenticatorName{"none"}; + + NoOpAuthenticatorFactory() : AuthenticatorFactory{kNoOpAuthenticatorName} {} + + explicit NoOpAuthenticatorFactory(std::string_view name) + : AuthenticatorFactory{name} {} + + std::shared_ptr newAuthenticator( + std::shared_ptr config) override { + return std::make_shared(); + } +}; + +} // namespace facebook::presto::connector::arrow_flight::auth diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt new file mode 100644 index 0000000000000..1e7eba3154a0e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/auth/CMakeLists.txt @@ -0,0 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library(presto_flight_connector_auth Authenticator.cpp) + +target_link_libraries(presto_flight_connector_auth + presto_flight_connector_utils velox_exception) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp new file mode 100644 index 0000000000000..b67d05bc810b8 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorAuthTest.cpp @@ -0,0 +1,251 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "arrow/testing/gtest_util.h" +#include "folly/init/Init.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace { + +namespace velox = facebook::velox; +namespace core = facebook::velox::core; +namespace exec = facebook::velox::exec; +namespace connector = facebook::velox::connector; +namespace flight = arrow::flight; + +using namespace facebook::presto::connector::arrow_flight; +using namespace facebook::presto::connector::arrow_flight::test; +using exec::test::AssertQueryBuilder; +using exec::test::OperatorTestBase; + +class TestingServerMiddlewareFactory : public flight::ServerMiddlewareFactory { + public: + static constexpr const char* kAuthHeader = "authorization"; + static constexpr const char* kAuthToken = "Bearer 1234"; + static constexpr const char* kAuthTokenUnauthorized = "Bearer 2112"; + + arrow::Status StartCall( + const flight::CallInfo& info, + const flight::ServerCallContext& context, + std::shared_ptr* middleware) override { + auto iter = context.incoming_headers().find(kAuthHeader); + + if (iter == context.incoming_headers().end()) { + return flight::MakeFlightError( + flight::FlightStatusCode::Unauthenticated, + "Authorization token not provided"); + } else { + std::lock_guard l(mutex_); + checkedTokens_.emplace_back(iter->second); + } + + if (kAuthToken != iter->second) { + return flight::MakeFlightError( + flight::FlightStatusCode::Unauthorized, + "Authorization token is invalid"); + } + + return arrow::Status::OK(); + } + + bool isTokenChecked(const std::string& authToken) { + { + std::lock_guard l(mutex_); + return std::find( + checkedTokens_.begin(), checkedTokens_.end(), authToken) != + checkedTokens_.end(); + } + } + + private: + std::string validToken_; + std::vector checkedTokens_; + std::mutex mutex_; +}; + +class TestingAuthenticator : public auth::Authenticator { + public: + explicit TestingAuthenticator(const std::string& authToken) + : authToken_(authToken) {} + + void authenticateClient( + std::unique_ptr& client, + const std::map& extraCredentials, + arrow::flight::AddCallHeaders& headerWriter) override { + if (!authToken_.empty()) { + headerWriter.AddHeader( + TestingServerMiddlewareFactory::kAuthHeader, authToken_); + } + } + + private: + std::string authToken_; +}; + +class TestingAuthenticatorFactory : public auth::AuthenticatorFactory { + public: + TestingAuthenticatorFactory( + const std::string& name, + const std::string& authToken) + : auth::AuthenticatorFactory(name), + testingAuthenticator_{ + std::make_shared(authToken)} {} + + std::shared_ptr newAuthenticator( + std::shared_ptr config) override { + return testingAuthenticator_; + } + + private: + std::shared_ptr testingAuthenticator_; +}; + +constexpr const char* kAuthFactoryName = "testing-auth-valid"; +constexpr const char* kAuthFactoryUnauthorizedName = + "testing-auth-unauthorized"; +constexpr const char* kAuthFactoryNoTokenName = "testing-auth-no-token"; + +bool registerTestAuthFactories() { + static bool once = [] { + auto authFactory = std::make_shared( + kAuthFactoryName, TestingServerMiddlewareFactory::kAuthToken); + auth::registerAuthenticatorFactory(authFactory); + auto authFactoryUnauthorized = + std::make_shared( + kAuthFactoryUnauthorizedName, + TestingServerMiddlewareFactory::kAuthTokenUnauthorized); + auth::registerAuthenticatorFactory(authFactoryUnauthorized); + auto authFactoryNoToken = std::make_shared( + kAuthFactoryNoTokenName, ""); + auth::registerAuthenticatorFactory(authFactoryNoToken); + return true; + }(); + return once; +} + +class FlightConnectorCustomAuthTestBase : public FlightWithServerTestBase { + public: + explicit FlightConnectorCustomAuthTestBase(const std::string& authFactoryName) + : FlightWithServerTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kAuthenticatorName, authFactoryName}})), + testingMiddlewareFactory_( + std::make_shared()) {} + + void SetUp() override { + registerTestAuthFactories(); + FlightWithServerTestBase::SetUp(); + } + + void setFlightServerOptions( + flight::FlightServerOptions* serverOptions) override { + serverOptions->middleware.push_back( + {"bearer-auth-server", testingMiddlewareFactory_}); + } + + core::PlanNodePtr addSampleDataAndRunQuery() { + updateTable( + "sample-data", + makeArrowTable( + {"id", "value"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()}), + makeNumericArray( + {41, 42, 43, std::numeric_limits::min()})})); + + auto idVec = makeFlatVector( + {1, 12, 2, std::numeric_limits::max()}); + auto valueVec = makeFlatVector( + {41, 42, 43, std::numeric_limits::min()}); + + return FlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + } + + protected: + std::shared_ptr testingMiddlewareFactory_; +}; + +class FlightConnectorCustomAuthTest : public FlightConnectorCustomAuthTestBase { + public: + FlightConnectorCustomAuthTest() + : FlightConnectorCustomAuthTestBase(kAuthFactoryName) {} +}; + +TEST_F(FlightConnectorCustomAuthTest, customAuthenticator) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + auto idVec = + makeFlatVector({1, 12, 2, std::numeric_limits::max()}); + auto valueVec = makeFlatVector( + {41, 42, 43, std::numeric_limits::min()}); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + ASSERT_TRUE(testingMiddlewareFactory_->isTokenChecked( + TestingServerMiddlewareFactory::kAuthToken)); +} + +class FlightConnectorCustomAuthUnauthorizedTest + : public FlightConnectorCustomAuthTestBase { + public: + FlightConnectorCustomAuthUnauthorizedTest() + : FlightConnectorCustomAuthTestBase(kAuthFactoryUnauthorizedName) {} +}; + +TEST_F(FlightConnectorCustomAuthUnauthorizedTest, unauthorizedToken) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertEmptyResults(), + "Unauthorized"); +} + +class FlightConnectorCustomAuthUnauthenticatedTest + : public FlightConnectorCustomAuthTestBase { + public: + FlightConnectorCustomAuthUnauthenticatedTest() + : FlightConnectorCustomAuthTestBase(kAuthFactoryNoTokenName) {} +}; + +TEST_F(FlightConnectorCustomAuthUnauthenticatedTest, unauthenticatedNoToken) { + core::PlanNodePtr plan = addSampleDataAndRunQuery(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertEmptyResults(), + "Unauthenticated"); +} + +} // namespace + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp new file mode 100644 index 0000000000000..dd12509793023 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorDataTypeTest.cpp @@ -0,0 +1,355 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "arrow/testing/gtest_util.h" +#include "folly/init/Init.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace { + +namespace velox = facebook::velox; +namespace core = facebook::velox::core; +namespace exec = facebook::velox::exec; +namespace connector = facebook::velox::connector; +namespace flight = arrow::flight; + +using namespace facebook::presto::connector::arrow_flight; +using namespace facebook::presto::connector::arrow_flight::test; +using exec::test::AssertQueryBuilder; +using exec::test::OperatorTestBase; +using exec::test::PlanBuilder; + +static const std::string kFlightConnectorId = "test-flight"; + +class FlightConnectorDataTypeTest : public FlightWithServerTestBase {}; + +TEST_F(FlightConnectorDataTypeTest, booleanType) { + updateTable( + "sample-data", + makeArrowTable( + {"bool_col"}, {makeBooleanArray({true, false, true, false})})); + + auto boolVec = makeFlatVector({true, false, true, false}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"bool_col"}, {velox::BOOLEAN()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({boolVec})); +} + +TEST_F(FlightConnectorDataTypeTest, integerTypes) { + updateTable( + "sample-data", + makeArrowTable( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {makeNumericArray( + {-128, 0, 127, std::numeric_limits::max()}), + makeNumericArray( + {-32768, 0, 32767, std::numeric_limits::max()}), + makeNumericArray( + {-2147483648, + 0, + 2147483647, + std::numeric_limits::max()}), + makeNumericArray( + {-3435678987654321234LL, + 0, + 4527897896541234567LL, + std::numeric_limits::max()})})); + + auto tinyintVec = makeFlatVector( + {-128, 0, 127, std::numeric_limits::max()}); + + auto smallintVec = makeFlatVector( + {-32768, 0, 32767, std::numeric_limits::max()}); + + auto integerVec = makeFlatVector( + {-2147483648, 0, 2147483647, std::numeric_limits::max()}); + + auto bigintVec = makeFlatVector( + {-3435678987654321234LL, + 0, + 4527897896541234567LL, + std::numeric_limits::max()}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"tinyint_col", "smallint_col", "integer_col", "bigint_col"}, + {velox::TINYINT(), + velox::SMALLINT(), + velox::INTEGER(), + velox::BIGINT()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults( + makeRowVector({tinyintVec, smallintVec, integerVec, bigintVec})); +} + +TEST_F(FlightConnectorDataTypeTest, realType) { + updateTable( + "sample-data", + makeArrowTable( + {"real_col", "double_col"}, + {makeNumericArray( + {std::numeric_limits::min(), + 0.0f, + 3.14f, + std::numeric_limits::max()}), + makeNumericArray( + {std::numeric_limits::min(), + 0.0, + 3.14159, + std::numeric_limits::max()})})); + + auto realVec = makeFlatVector( + {std::numeric_limits::min(), + 0.0f, + 3.14f, + std::numeric_limits::max()}); + auto doubleVec = makeFlatVector( + {std::numeric_limits::min(), + 0.0, + 3.14159, + std::numeric_limits::max()}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"real_col", "double_col"}, {velox::REAL(), velox::DOUBLE()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({realVec, doubleVec})); +} + +TEST_F(FlightConnectorDataTypeTest, varcharType) { + updateTable( + "sample-data", + makeArrowTable( + {"varchar_col"}, {makeStringArray({"Hello", "World", "India"})})); + + auto vec = makeFlatVector( + {facebook::velox::StringView("Hello"), + facebook::velox::StringView("World"), + facebook::velox::StringView("India")}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"varchar_col"}, {velox::VARCHAR()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({vec})); +} + +TEST_F(FlightConnectorDataTypeTest, timestampType) { + auto timestampValues = + std::vector{1622538000, 1622541600, 1622545200}; + + updateTable( + "sample-data", + makeArrowTable( + {"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, + {makeTimestampArray(timestampValues, arrow::TimeUnit::SECOND), + makeTimestampArray(timestampValues, arrow::TimeUnit::MILLI), + makeTimestampArray(timestampValues, arrow::TimeUnit::MICRO)})); + + std::vector veloxTimestampSec; + for (const auto& ts : timestampValues) { + veloxTimestampSec.emplace_back(ts, 0); // Assuming 0 microseconds part + } + + auto timestampSecCol = + makeFlatVector(veloxTimestampSec); + + std::vector veloxTimestampMilli; + for (const auto& ts : timestampValues) { + veloxTimestampMilli.emplace_back( + ts / 1000, (ts % 1000) * 1000000); // Convert to seconds and nanoseconds + } + + auto timestampMilliCol = + makeFlatVector(veloxTimestampMilli); + + std::vector veloxTimestampMicro; + for (const auto& ts : timestampValues) { + veloxTimestampMicro.emplace_back( + ts / 1000000, + (ts % 1000000) * 1000); // Convert to seconds and nanoseconds + } + + auto timestampMicroCol = + makeFlatVector(veloxTimestampMicro); + + core::PlanNodePtr plan; + plan = + FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"timestampsec_col", "timestampmilli_col", "timestampmicro_col"}, + {velox::TIMESTAMP(), velox::TIMESTAMP(), velox::TIMESTAMP()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector( + {timestampSecCol, timestampMilliCol, timestampMicroCol})); +} + +TEST_F(FlightConnectorDataTypeTest, dateDayType) { + std::vector datesDay = {18748, 18749, 18750}; // Days since epoch + std::vector datesMilli = { + 1622538000000, 1622541600000, 1622545200000}; // Milliseconds since epoch + + updateTable( + "sample-data", + makeArrowTable( + {"daydate_col", "daymilli_col"}, + {makeNumericArray(datesDay), + makeNumericArray(datesMilli)})); + + auto dateVec = makeFlatVector(datesDay); + auto milliVec = makeFlatVector(datesMilli); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"daydate_col"}, {velox::DATE()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({dateVec})); + + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"daymilli_col"}, {velox::DATE()})) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({milliVec})), + "Unable to convert 'tdm' ArrowSchema format type to Velox"); +} + +TEST_F(FlightConnectorDataTypeTest, decimalType) { + std::vector decimalValuesBigInt = { + 123456789012345678, + -123456789012345678, + std::numeric_limits::max()}; + std::vector> decimalArrayVec; + decimalArrayVec.push_back(makeDecimalArray(decimalValuesBigInt, 18, 2)); + updateTable( + "sample-data", makeArrowTable({"decimal_col_bigint"}, decimalArrayVec)); + auto decimalVecBigInt = makeFlatVector(decimalValuesBigInt); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"decimal_col_bigint"}, + {velox::DECIMAL(18, 2)})) // precision can't be 0 and < scale + .planNode(); + + // Execute the query and assert the results + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({decimalVecBigInt})); +} + +TEST_F(FlightConnectorDataTypeTest, allTypes) { + auto timestampValues = + std::vector{1622550000, 1622553600, 1622557200}; + + auto sampleTable = makeArrowTable( + {"id", + "daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {makeNumericArray({1, 2, 3}), + makeNumericArray({18748, 18749, 18750}), + makeTimestampArray(timestampValues, arrow::TimeUnit::SECOND), + makeStringArray({"apple", "banana", "cherry"}), + makeNumericArray({3.14, 2.718, 1.618}), + makeNumericArray( + {-32768, 32767, std::numeric_limits::max()}), + makeBooleanArray({true, false, true})}); + + updateTable("gen-data", sampleTable); + + auto dateVec = makeFlatVector({18748, 18749, 18750}); + + std::vector veloxTimestampSec; + for (const auto& ts : timestampValues) { + veloxTimestampSec.emplace_back(ts, 0); // Assuming 0 microseconds part + } + auto timestampSecVec = + makeFlatVector(veloxTimestampSec); + + auto stringVec = makeFlatVector( + {facebook::velox::StringView("apple"), + facebook::velox::StringView("banana"), + facebook::velox::StringView("cherry")}); + auto realVec = makeFlatVector({3.14, 2.718, 1.618}); + auto intVec = makeFlatVector( + {-32768, 32767, std::numeric_limits::max()}); + auto boolVec = makeFlatVector({true, false, true}); + + core::PlanNodePtr plan; + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"daydate_col", + "timestamp_col", + "varchar_col", + "real_col", + "int_col", + "bool_col"}, + {velox::DATE(), + velox::TIMESTAMP(), + velox::VARCHAR(), + velox::DOUBLE(), + velox::INTEGER(), + velox::BOOLEAN()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"gen-data"})) + .assertResults(makeRowVector( + {dateVec, timestampSecVec, stringVec, realVec, intVec, boolVec})); +} + +} // namespace + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp new file mode 100644 index 0000000000000..2a7ff7cf4b4cc --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTest.cpp @@ -0,0 +1,197 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "arrow/testing/gtest_util.h" +#include "folly/init/Init.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace { + +namespace velox = facebook::velox; +namespace core = facebook::velox::core; +namespace exec = facebook::velox::exec; +namespace connector = facebook::velox::connector; +namespace flight = arrow::flight; + +using namespace facebook::presto::connector::arrow_flight; +using namespace facebook::presto::connector::arrow_flight::test; +using exec::test::AssertQueryBuilder; +using exec::test::OperatorTestBase; + +static const std::string kFlightConnectorId = "test-flight"; + +class FlightConnectorTest : public FlightWithServerTestBase {}; + +TEST_F(FlightConnectorTest, invalidSplitTest) { + auto plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({{"id", velox::BIGINT()}})) + .planNode(); + + VELOX_ASSERT_THROW( + velox::exec::test::AssertQueryBuilder(plan) + .splits(makeSplits({"unknown"})) + .copyResults(pool()), + "table does not exist"); +} + +TEST_F(FlightConnectorTest, dataSourceCreationTest) { + // missing columnHandle test + auto plan = + FlightPlanBuilder() + .flightTableScan( + velox::ROW({"id", "value"}, {velox::BIGINT(), velox::INTEGER()}), + {{"id", std::make_shared("id")}}, + false /*createDefaultColumnHandles*/) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .copyResults(pool()), + "missing columnHandle for column 'value'"); +} + +TEST_F(FlightConnectorTest, dataSourceTest) { + updateTable( + "sample-data", + makeArrowTable( + {"id", "value", "unsigned"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()}), + makeNumericArray( + {41, 42, 43, std::numeric_limits::min()}), + // note that velox doesn't support unsigned types + // connector should still be able to query such tables + // as long as this specific column isn't requested. + makeNumericArray( + {101, 102, 12, std::numeric_limits::max()})})); + + auto idColumn = std::make_shared("id"); + auto idVec = + makeFlatVector({1, 12, 2, std::numeric_limits::max()}); + + auto valueColumn = std::make_shared("value"); + auto valueVec = makeFlatVector( + {41, 42, 43, std::numeric_limits::min()}); + + core::PlanNodePtr plan; + + // direct test + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, std::vector{})) + .assertResults(makeRowVector({idVec, valueVec})), + "default host or port is missing"); + + // column alias test + plan = + FlightPlanBuilder() + .flightTableScan( + velox::ROW({"ducks", "id"}, {velox::BIGINT(), velox::BIGINT()}), + {{"ducks", idColumn}}) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, idVec})); + + // invalid columnHandle test + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"ducks", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .copyResults(pool()), + "column with name 'ducks' not found"); +} + +class FlightConnectorTestDefaultServer : public FlightWithServerTestBase { + public: + FlightConnectorTestDefaultServer() + : FlightWithServerTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kDefaultServerHost, CONNECT_HOST}, + {FlightConfig::kDefaultServerPort, + std::to_string(LISTEN_PORT)}})) {} +}; + +TEST_F(FlightConnectorTestDefaultServer, dataSourceTest) { + updateTable( + "sample-data", + makeArrowTable( + {"id", "value"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()}), + makeNumericArray( + {41, 42, 43, std::numeric_limits::min()})})); + + auto idColumn = std::make_shared("id"); + auto idVec = + makeFlatVector({1, 12, 2, std::numeric_limits::max()}); + + auto valueColumn = std::make_shared("value"); + auto valueVec = makeFlatVector( + {41, 42, 43, std::numeric_limits::min()}); + + core::PlanNodePtr plan; + + // direct test + plan = FlightPlanBuilder() + .flightTableScan(velox::ROW( + {"id", "value"}, {velox::BIGINT(), velox::INTEGER()})) + .planNode(); + + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"})) + .assertResults(makeRowVector({idVec, valueVec})); + + AssertQueryBuilder(plan) + .splits(makeSplits( + {"sample-data"}, + std::vector{})) // Using default connector + .assertResults(makeRowVector({idVec, valueVec})); +} + +} // namespace + +} // namespace facebook::presto::connector::arrow_flight::test + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp new file mode 100644 index 0000000000000..b7f7ef5dfc070 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/ArrowFlightConnectorTlsTest.cpp @@ -0,0 +1,138 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "arrow/flight/types.h" +#include "arrow/testing/gtest_util.h" +#include "folly/init/Init.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/config/Config.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace { + +namespace velox = facebook::velox; +namespace core = facebook::velox::core; +namespace exec = facebook::velox::exec; +namespace connector = facebook::velox::connector; +namespace flight = arrow::flight; + +using namespace facebook::presto::connector::arrow_flight; +using namespace facebook::presto::connector::arrow_flight::test; +using exec::test::AssertQueryBuilder; +using exec::test::OperatorTestBase; + +class FlightConnectorTlsTestBase : public FlightWithServerTestBase { + protected: + explicit FlightConnectorTlsTestBase( + std::shared_ptr config) + : FlightWithServerTestBase(std::move(config)) {} + + flight::Location getServerLocation() override { + AFC_ASSIGN_OR_RAISE( + auto loc, flight::Location::ForGrpcTls(BIND_HOST, LISTEN_PORT)); + return loc; + } + + void setFlightServerOptions( + flight::FlightServerOptions* serverOptions) override { + flight::CertKeyPair tlsCertificate{ + .pem_cert = readFile("./data/tls_certs/server.crt"), + .pem_key = readFile("./data/tls_certs/server.key")}; + serverOptions->tls_certificates.push_back(tlsCertificate); + } + + void executeTest( + bool isPositiveTest = true, + const std::string& expectedError = "") { + updateTable( + "sample-data", + makeArrowTable( + {"id"}, + {makeNumericArray( + {1, 12, 2, std::numeric_limits::max()})})); + + auto idVec = makeFlatVector( + {1, 12, 2, std::numeric_limits::max()}); + + auto plan = FlightPlanBuilder() + .flightTableScan(velox::ROW({"id"}, {velox::BIGINT()})) + .planNode(); + + ASSERT_OK_AND_ASSIGN( + auto loc, flight::Location::ForGrpcTls(CONNECT_HOST, LISTEN_PORT)); + auto locs = std::vector{loc}; + if (isPositiveTest) { + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, locs)) + .assertResults(makeRowVector({idVec})); + } else { + VELOX_ASSERT_THROW( + AssertQueryBuilder(plan) + .splits(makeSplits({"sample-data"}, locs)) + .assertResults(makeRowVector({idVec})), + expectedError); + } + } +}; + +class FlightConnectorTlsTest : public FlightConnectorTlsTestBase { + protected: + explicit FlightConnectorTlsTest() + : FlightConnectorTlsTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kServerVerify, "true"}, + {FlightConfig::kServerSslCertificate, + "./data/tls_certs/ca.crt"}})) {} +}; + +TEST_F(FlightConnectorTlsTest, tlsTest) { + executeTest(); +} + +class FlightConnectorTlsNoCertValidationTest + : public FlightConnectorTlsTestBase { + protected: + explicit FlightConnectorTlsNoCertValidationTest() + : FlightConnectorTlsTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kServerVerify, "false"}})) {} +}; + +TEST_F(FlightConnectorTlsNoCertValidationTest, tlsNoCertValidationTest) { + executeTest(); +} + +class FlightConnectorTlsNoCertTest : public FlightConnectorTlsTestBase { + protected: + FlightConnectorTlsNoCertTest() + : FlightConnectorTlsTestBase(std::make_shared( + std::unordered_map{ + {FlightConfig::kServerVerify, "true"}})) {} +}; + +TEST_F(FlightConnectorTlsNoCertTest, tlsNoCertTest) { + executeTest(false, "handshake failed"); +} + +} // namespace + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt new file mode 100644 index 0000000000000..9830ad9c213fb --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/CMakeLists.txt @@ -0,0 +1,44 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_subdirectory(utils) + +add_executable(presto_flight_connector_infra_test TestFlightServerTest.cpp) + +add_test(presto_flight_connector_infra_test presto_flight_connector_infra_test) + +target_link_libraries( + presto_flight_connector_infra_test presto_protocol + presto_flight_connector_test_lib GTest::gtest GTest::gtest_main ${GLOG}) + +add_executable( + presto_flight_connector_test + ArrowFlightConnectorTest.cpp ArrowFlightConnectorAuthTest.cpp + ArrowFlightConnectorTlsTest.cpp ArrowFlightConnectorDataTypeTest.cpp + FlightConfigTest.cpp) + +set(DATA_DIR "${CMAKE_CURRENT_SOURCE_DIR}/data/tls_certs") + +add_custom_target( + copy_flight_test_data ALL + COMMAND ${CMAKE_COMMAND} -E copy_directory ${DATA_DIR} + $/data/tls_certs) + +add_test(presto_flight_connector_test presto_flight_connector_test) + +target_link_libraries( + presto_flight_connector_test + velox_exec_test_lib + presto_flight_connector + gtest + gtest_main + presto_flight_connector_test_lib + presto_protocol) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/FlightConfigTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/FlightConfigTest.cpp new file mode 100644 index 0000000000000..21f8ff2aeb72d --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/FlightConfigTest.cpp @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/FlightConfig.h" +#include "gtest/gtest.h" + +namespace facebook::presto::connector::arrow_flight::test { + +TEST(FlightConfigTest, defaultConfig) { + auto rawConfig = std::make_shared( + std::move(std::unordered_map{})); + auto config = FlightConfig(rawConfig); + ASSERT_EQ(config.authenticatorName(), "none"); + ASSERT_EQ(config.defaultServerHostname(), std::nullopt); + ASSERT_EQ(config.defaultServerPort(), std::nullopt); + ASSERT_EQ(config.defaultServerSslEnabled(), false); + ASSERT_EQ(config.serverVerify(), true); + ASSERT_EQ(config.serverSslCertificate(), std::nullopt); +} + +TEST(FlightConfigTest, overrideConfig) { + std::unordered_map configMap = { + {FlightConfig::kAuthenticatorName, "my-authenticator"}, + {FlightConfig::kDefaultServerHost, "my-server-host"}, + {FlightConfig::kDefaultServerPort, "9000"}, + {FlightConfig::kDefaultServerSslEnabled, "true"}, + {FlightConfig::kServerVerify, "false"}, + {FlightConfig::kServerSslCertificate, "my-cert.crt"}}; + auto config = FlightConfig( + std::make_shared(std::move(configMap))); + ASSERT_EQ(config.authenticatorName(), "my-authenticator"); + ASSERT_EQ(config.defaultServerHostname(), "my-server-host"); + ASSERT_EQ(config.defaultServerPort(), 9000); + ASSERT_EQ(config.defaultServerSslEnabled(), true); + ASSERT_EQ(config.serverVerify(), false); + ASSERT_EQ(config.serverSslCertificate(), "my-cert.crt"); +} + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestFlightServerTest.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestFlightServerTest.cpp new file mode 100644 index 0000000000000..ceb335f89c329 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/TestFlightServerTest.cpp @@ -0,0 +1,84 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h" +#include "arrow/api.h" +#include "arrow/flight/api.h" +#include "arrow/testing/gtest_util.h" +#include "glog/logging.h" +#include "gtest/gtest.h" + +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h" + +namespace { + +using namespace facebook::presto::connector::arrow_flight::test; +using namespace arrow::flight; + +class TestFlightServerTest : public testing::Test { + public: + static void SetUpTestSuite() { + server = std::make_unique(); + ASSERT_OK_AND_ASSIGN(auto loc, Location::ForGrpcTcp("127.0.0.1", 0)); + ASSERT_OK(server->Init(FlightServerOptions(loc))); + } + + static void TearDownTestSuite() { + ASSERT_OK(server->Shutdown()); + } + + static void updateTable( + std::string name, + std::shared_ptr table) { + server->updateTable(std::move(name), std::move(table)); + } + + void SetUp() override { + ASSERT_OK_AND_ASSIGN( + auto loc, Location::ForGrpcTcp("localhost", server->port())); + ASSERT_OK_AND_ASSIGN(client, FlightClient::Connect(loc)); + } + + std::unique_ptr client; + static std::unique_ptr server; +}; + +std::unique_ptr TestFlightServerTest::server; + +TEST_F(TestFlightServerTest, basicTest) { + auto sampleTable = makeArrowTable( + {"id", "value"}, + {makeNumericArray({1, 2}), + makeNumericArray({41, 42})}); + updateTable("sample-data", sampleTable); + + ASSERT_RAISES(KeyError, client->DoGet(Ticket{"empty"})); + + auto emptyTable = makeArrowTable({}, {}); + updateTable("empty", emptyTable); + + ASSERT_RAISES(KeyError, client->DoGet(Ticket{"non-existent-table"})); + + ASSERT_OK_AND_ASSIGN(auto reader, client->DoGet(Ticket{"empty"})); + ASSERT_OK_AND_ASSIGN(auto actual, reader->ToTable()); + EXPECT_TRUE(actual->Equals(*emptyTable)); + + ASSERT_OK_AND_ASSIGN(reader, client->DoGet(Ticket{"sample-data"})); + ASSERT_OK_AND_ASSIGN(actual, reader->ToTable()); + EXPECT_TRUE(actual->Equals(*sampleTable)); + + server->removeTable("sample-data"); + ASSERT_RAISES(KeyError, client->DoGet(Ticket{"sample-data"})); +} + +} // namespace diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md new file mode 100644 index 0000000000000..bac4938fc47d5 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/README.md @@ -0,0 +1,6 @@ +### Placeholder TLS Certificates for Arrow Flight Connector Unit Testing +The `tls_certs` directory contains placeholder TLS certificates generated for unit testing the Arrow Flight Connector with TLS enabled. These certificates are not intended for production use and should only be used in the context of unit tests. + +### Generating TLS Certificates +To create the TLS certificates and keys inside the `tls_certs` folder, run the following command: +`./generate_tls_certs.sh` diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh new file mode 100755 index 0000000000000..718f313c70a75 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/generate_tls_certs.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Set directory for certificates and keys. +CERT_DIR="./tls_certs" +mkdir -p $CERT_DIR + +# Dummy values for the certificates. +COUNTRY="US" +STATE="State" +LOCALITY="City" +ORGANIZATION="MyOrg" +ORG_UNIT="MyUnit" +COMMON_NAME="MyCA" +SERVER_CN="server.mydomain.com" + +# Step 1: Generate CA private key and self-signed certificate. +openssl genpkey -algorithm RSA -out $CERT_DIR/ca.key +openssl req -key $CERT_DIR/ca.key -new -x509 -out $CERT_DIR/ca.crt -days 365000 \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$COMMON_NAME" + +# Step 2: Generate server private key. +openssl genpkey -algorithm RSA -out $CERT_DIR/server.key + +# Step 3: Generate server certificate signing request (CSR). +openssl req -new -key $CERT_DIR/server.key -out $CERT_DIR/server.csr \ + -subj "/C=$COUNTRY/ST=$STATE/L=$LOCALITY/O=$ORGANIZATION/OU=$ORG_UNIT/CN=$SERVER_CN" \ + -addext "subjectAltName=DNS:$COMMON_NAME,DNS:localhost" \ + +# Step 4: Sign server CSR with the CA certificate to generate the server certificate. +openssl x509 -req -in $CERT_DIR/server.csr -CA $CERT_DIR/ca.crt -CAkey $CERT_DIR/ca.key \ + -CAcreateserial -out $CERT_DIR/server.crt -days 365000 \ + -extfile <(printf "subjectAltName=DNS:$COMMON_NAME,DNS:localhost") + +# Step 5: Output the generated files. +echo "Certificate Authority (CA) certificate: $CERT_DIR/ca.crt" +echo "Server certificate: $CERT_DIR/server.crt" +echo "Server private key: $CERT_DIR/server.key" + +# Step 6: Remove unused files. +rm -rf $CERT_DIR/server.csr $CERT_DIR/ca.srl $CERT_DIR/ca.key diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt new file mode 100644 index 0000000000000..6740e89c54e17 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/ca.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDmzCCAoOgAwIBAgIUf+rP48iL39yGlAfFQTIp5bmM4uQwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI0MTIwMzExMDQxMVoYDzMwMjQwNDA1MTEwNDExWjBcMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxDTALBgNVBAMMBE15Q0EwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCliiXIcSmxXAAq2k/XjcZniDgEDCxWKZGiV8JBiJwY +MMBJtqcVzWfiDpO2u6d1dfGb6utlRW+1dnwupzURCMmZff4bqlPx4ZejRXDrWzKz +08WSpDVZwC2H5XOllwK36Cn4gvPRe3YWVcdDGHy7GL+zsJENvawJj0BH952MU4bk +sV52zEkN291bfN9sSYfT1NCJuLPM0Qsf97DeQ+wHXEw+t4XVMF3FQbciQp0y6CnA +wfFFN14WDiWxukP1I3kuDYYA6h/WJCQMp5rU2NCB9nIQrulYRxFaepMYENLxgAyj +gFaoRh2Kt2k7XKv6WOa6CmYm2dZERPlbA+oNAHkaHw6lAgMBAAGjUzBRMB0GA1Ud +DgQWBBSN+3vRlXGjs6c+rN94qgEnkPLl3DAfBgNVHSMEGDAWgBSN+3vRlXGjs6c+ +rN94qgEnkPLl3DAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAb +L40Oe2b/2xdUSyjqXJceVxaeA291fCpnu1C1JliP0hNI3fu9jjZhXHQoYub/4mod +8lriEDIcOCCiUfmi404akpqQHuBmOHaKEOtaaQkezjPsYnUra+O2ssqUo2zto5bK +gR0LGsb+4AO0bDvq+QVI6kEQqAAIf6qC+kpg/jV4iKJ1J6Qw4R3QppYBm6SQcfvI +hfUfDSO6SNfy0f/ZVCavbJIP9zG/BfAD9DEERocw03PiN5bm4IXJ3HH8rxyuBfJ5 +Eg/fPP5TlZ2H7Kqb3VgVBGWJtNXWmJphHyraBJTEuxgXWvl6AaW0P/3dsJi3rfdD +zDIT7AmENLCom8Gl0bgM +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt new file mode 100644 index 0000000000000..92c91f2d613b0 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.crt @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDtTCCAp2gAwIBAgIUUhmhZP94nIowrg2EarzfEBp6W1EwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MQ0wCwYDVQQDDARNeUNB +MCAXDTI0MTIwMzExMDQxMVoYDzMwMjQwNDA1MTEwNDExWjBrMQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFU3RhdGUxDTALBgNVBAcMBENpdHkxDjAMBgNVBAoMBU15T3Jn +MQ8wDQYDVQQLDAZNeVVuaXQxHDAaBgNVBAMME3NlcnZlci5teWRvbWFpbi5jb20w +ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDSxC4zCC4GFZbX+fdFgWbL +sj4PortyOM7mzRjNaQ3M0FTSEy5xET9C2qFlBCJ7AL7DlbSLmSckYY/FkdfMqNN2 ++NZ0Dy2d6bZN+ly5N/QBVnyS/5HVC3MXa6Y2BmFXiBnczWfGBwj+uVHlKOUWUyNi +EyUkhuPwtYXkFmJoqBxJSPC6cxX6NzMujnwCF18dUf0Vra44osu4moaovmg3c9jM +cBtmafFs9F54FoAEuLotjISVEa7VY6th5RxXJHpgas+0R5EBddGYKbTRiUYjht7r +pS+An0ey02oOjEWdqLnQSg/SUGKuRXULyE5l1A0HfNQtvepUQotb9ull1F7OrbfB +AgMBAAGjXjBcMBoGA1UdEQQTMBGCBE15Q0GCCWxvY2FsaG9zdDAdBgNVHQ4EFgQU +vnCLWjre4jqkKzC24psCPh1oIQwwHwYDVR0jBBgwFoAUjft70ZVxo7OnPqzfeKoB +J5Dy5dwwDQYJKoZIhvcNAQELBQADggEBAJCiJgtTw/7b/g3QnJfM4teQkFS440Ii +weqQJMoP6als8Fc3opPKv9eC5w0wqaLlIdwJjzGM5PmCAtGVafo22TbqhZyQdzQu +TUKv1DaVF0JBVAGVxTSDIK9r5Ww4mDAQnQENLC6soS3AvYDEi+8667YLoNNdhRCX +q2D5v76UN45idiShppxOw53whsvpHv+wyqcdse7DhgM9boCbx51Uvv3l/AEToyaj +S1xeIkBwNpSYU0ax2Lr1j2yoKbzAa3MHy8Php+T5CGji02+HwwlvlPDLtw8q5gHw +BLSwlAHgclPxUTWNNoCqjfX8Bi083+QDCLm0rgQ45xljNDbFAF1Y5hA= +-----END CERTIFICATE----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key new file mode 100644 index 0000000000000..2cdf5750a4753 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/data/tls_certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDSxC4zCC4GFZbX ++fdFgWbLsj4PortyOM7mzRjNaQ3M0FTSEy5xET9C2qFlBCJ7AL7DlbSLmSckYY/F +kdfMqNN2+NZ0Dy2d6bZN+ly5N/QBVnyS/5HVC3MXa6Y2BmFXiBnczWfGBwj+uVHl +KOUWUyNiEyUkhuPwtYXkFmJoqBxJSPC6cxX6NzMujnwCF18dUf0Vra44osu4moao +vmg3c9jMcBtmafFs9F54FoAEuLotjISVEa7VY6th5RxXJHpgas+0R5EBddGYKbTR +iUYjht7rpS+An0ey02oOjEWdqLnQSg/SUGKuRXULyE5l1A0HfNQtvepUQotb9ull +1F7OrbfBAgMBAAECggEAAxbZuuESGGAMMm9HLGhKHgbHU8gnv2Phdbrka+SYBYg5 +UYzTHLh3FwEsjd4VnaweJ7CN1WDb1NvWmTum/DCebJ1HKqtjKLAZfk8q2TLGmXdL +pzWOdQ8MX1fKP2sIlcl0kFbNCE8vprjneDyBLtqOK36eiAh/fl6BQ12QAMLjyv/L +OwXSY4ESs/RzxRzFgdT98cDZFL7y0FVIjJo/Q5lfW9UwwSfw8tOLNXKTYwPHqIfJ +NjfWD7IqztQlnanyRXv5dScp80i8p9qgH0i8YfVBHZDeOmHGLcltilLRZ0dQ/X0g +Lrr0aIO3iLhmTIkJRzUnGeyvDjxcPINvRSBBwXy04QKBgQDpFJa/EwSsWj8586oh +xgm0Z3q+FiEeCe7aLLPcXAS2EDvix5ibJDT2y1Aa/kXq25S53npa/7Ov6TJs5H4g +eyshDtR1wVhz+rIggREiX/sagkhwnNsssUZFv5t9PdnaFXpVnH49m5Qc8HO3owtN +t8EGSRcAQ4o/fLWLs51qd38cIQKBgQDnfd8YPyDQ03xDC/3+Qrypyc/xhGnCuj7w +ZeA5iEyTnnNxL0a0B6PWcSk2BZReMNQKgYtipnsOQKtwHMttxtXYs/VQpeB4KoWE +zEwW0fV3MMsXN+nVJlEZnVaTbmYXknjeZrh/rNjsY96yxw8NtvAuYSpnqtr3N2nd +iMQ3G/QnoQKBgGMi+bdNvIgeXpQkmrGAzTHpbaCaQv3G1cwAhYPts6dIomAj6znZ +nZl3ApxomI57VPf1s+8uoVvqASOl0Cu6l66Y4y8uzJOQBuGiZApN7rzouy0C2opY +4H3cMKOFgjqrNfxh8qP7n3TrpRxvgehNhxFIVzsqfwvf3EwOWp8lMnBhAoGAZ25E +Ge9K2ENGCCb5i3uCFFLJiF3ja1AQAxVhxBL0NBjd97pp2tJ3D79r7Gk9y4ABndgX +0TIVVV7ruqIC+r+WmMZ/W1NiIg7NrXIipSeWh3TTqUIgRk5iehFkt2biUrHtM2Gu +Gc2+9pAA1tw+C6CrW+2qJrueLksiEAulsAHba0ECgYBIgIiY+Gx+XecEgCwAhWcn +GzNDAAlA4IgBjHpUtIByflzQDqlECKXjPbVBKfyq6eLt40upFmQCLsn+AkiQau8A +3cFAK9wJOAHv9KuWDrbHyhRE9CrJ6BqsY2goC3LiFCTgJy1TrRl6CDaFzHivONwF +LNPflYk5s376UWqxC+HtIA== +-----END PRIVATE KEY----- diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt new file mode 100644 index 0000000000000..a285b9381345c --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/CMakeLists.txt @@ -0,0 +1,19 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +add_library( + presto_flight_connector_test_lib + TestFlightServer.cpp FlightConnectorTestBase.cpp Utils.cpp + FlightPlanBuilder.cpp) + +target_link_libraries( + presto_flight_connector_test_lib arrow presto_flight_connector + velox_exception presto_flight_connector_utils velox_exec_test_lib) diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.cpp new file mode 100644 index 0000000000000..2fff1b28c9f38 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.cpp @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h" +#include "arrow/testing/gtest_util.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" + +namespace facebook::presto::connector::arrow_flight::test { + +namespace connector = velox::connector; +namespace core = velox::core; + +using namespace arrow::flight; +using velox::exec::test::OperatorTestBase; + +void FlightConnectorTestBase::SetUp() { + OperatorTestBase::SetUp(); + + if (!velox::connector::hasConnectorFactory( + presto::connector::arrow_flight::ArrowFlightConnectorFactory:: + kArrowFlightConnectorName)) { + connector::registerConnectorFactory( + std::make_shared< + presto::connector::arrow_flight::ArrowFlightConnectorFactory>()); + } + connector::registerConnector( + connector::getConnectorFactory( + ArrowFlightConnectorFactory::kArrowFlightConnectorName) + ->newConnector(kFlightConnectorId, config_)); +} + +void FlightConnectorTestBase::TearDown() { + connector::unregisterConnector(kFlightConnectorId); + OperatorTestBase::TearDown(); +} + +void FlightWithServerTestBase::SetUp() { + FlightConnectorTestBase::SetUp(); + + FlightServerOptions serverOptions(getServerLocation()); + server_ = std::make_unique(); + setFlightServerOptions(&serverOptions); + ASSERT_OK(server_->Init(serverOptions)); +} + +void FlightWithServerTestBase::TearDown() { + ASSERT_OK(server_->Shutdown()); + FlightConnectorTestBase::TearDown(); +} + +Location FlightWithServerTestBase::getServerLocation() { + AFC_ASSIGN_OR_RAISE(auto loc, Location::ForGrpcTcp(BIND_HOST, LISTEN_PORT)); + return loc; +} + +std::vector> +FlightWithServerTestBase::makeSplits( + const std::initializer_list& tickets, + const std::vector& locations) { + std::vector> splits; + splits.reserve(tickets.size()); + for (auto& ticket : tickets) { + FlightEndpoint flightEndpoint; + flightEndpoint.ticket.ticket = ticket; + flightEndpoint.locations = locations; + AFC_ASSIGN_OR_RAISE( + auto flightEndpointBytes, flightEndpoint.SerializeToString()); + splits.push_back( + std::make_shared(kFlightConnectorId, flightEndpointBytes)); + } + return splits; +} + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h new file mode 100644 index 0000000000000..81c041e3c325a --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightConnectorTestBase.h @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "arrow/flight/api.h" +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h" +#include "velox/common/config/Config.h" +#include "velox/connectors/Connector.h" +#include "velox/exec/tests/utils/OperatorTestBase.h" + +namespace facebook::presto::connector::arrow_flight::test { + +static const std::string kFlightConnectorId = "test-flight"; + +class FlightConnectorTestBase : public velox::exec::test::OperatorTestBase { + public: + void SetUp() override; + + void TearDown() override; + + protected: + explicit FlightConnectorTestBase( + std::shared_ptr config) + : config_{std::move(config)} {} + + FlightConnectorTestBase() + : config_{std::make_shared( + std::move(std::unordered_map{}))} {} + + protected: + std::shared_ptr config_; +}; + +/// Creates and registers an Arrow Flight connector and +/// spawns a Flight server for testing. +/// Initially there is no data in the Flight server, +/// tests should call FlightWithServerTestBase::updateTables to populate it. +class FlightWithServerTestBase : public FlightConnectorTestBase { + public: + static constexpr const char* BIND_HOST = "127.0.0.1"; + static constexpr const char* CONNECT_HOST = "localhost"; + constexpr static int LISTEN_PORT = 5000; + + void SetUp() override; + + void TearDown() override; + + /// Convenience method which creates splits for the test flight server + static std::vector> + makeSplits( + const std::initializer_list& tokens, + const std::vector& locations = + std::vector{ + *arrow::flight::Location::ForGrpcTcp(CONNECT_HOST, LISTEN_PORT)}); + + /// Add (or update) a table in the test flight server + void updateTable(std::string name, std::shared_ptr table) { + server_->updateTable(std::move(name), std::move(table)); + } + + virtual arrow::flight::Location getServerLocation(); + + virtual void setFlightServerOptions( + arrow::flight::FlightServerOptions* serverOptions) {} + + protected: + explicit FlightWithServerTestBase( + std::shared_ptr config) + : FlightConnectorTestBase{std::move(config)} {} + + FlightWithServerTestBase() : FlightConnectorTestBase() {} + + private: + std::unique_ptr server_; +}; + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.cpp new file mode 100644 index 0000000000000..42194df9c890e --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.cpp @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h" +#include "presto_cpp/main/connectors/arrow_flight/ArrowFlightConnector.h" + +namespace facebook::presto::connector::arrow_flight::test { + +static const std::string kFlightConnectorId = "test-flight"; + +velox::exec::test::PlanBuilder& FlightPlanBuilder::flightTableScan( + const velox::RowTypePtr& outputType, + std::unordered_map< + std::string, + std::shared_ptr> assignments, + bool createDefaultColumnHandles) { + if (createDefaultColumnHandles) { + for (const auto& name : outputType->names()) { + // Provide unaliased defaults for unmapped columns. + // `emplace` won't modify the map if the key already exists, + // so existing aliases are kept. + assignments.emplace(name, std::make_shared(name)); + } + } + + return startTableScan() + .tableHandle(std::make_shared(kFlightConnectorId)) + .outputType(outputType) + .assignments(std::move(assignments)) + .endTableScan(); +} + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h new file mode 100644 index 0000000000000..bfc75c3704585 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/FlightPlanBuilder.h @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/exec/tests/utils/PlanBuilder.h" + +namespace facebook::presto::connector::arrow_flight::test { + +class FlightPlanBuilder : public velox::exec::test::PlanBuilder { + public: + /// @brief Add a table scan node to the Plan, using the Flight connector + /// @param outputType The output type of the table scan node + /// @param assignments mapping from the column aliases to real column handles + /// @param createDefaultColumnHandles If true, generate column handles for + /// for the columns which don't have an entry in assignments + velox::exec::test::PlanBuilder& flightTableScan( + const velox::RowTypePtr& outputType, + std::unordered_map< + std::string, + std::shared_ptr> assignments = {}, + bool createDefaultColumnHandles = true); +}; + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.cpp new file mode 100644 index 0000000000000..672ea16730fc8 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.cpp @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h" + +namespace facebook::presto::connector::arrow_flight::test { + +using namespace arrow::flight; + +arrow::Status TestFlightServer::DoGet( + const ServerCallContext& context, + const Ticket& request, + std::unique_ptr* stream) { + auto it = tables_.find(request.ticket); + if (it == tables_.end()) { + return arrow::Status::KeyError("requested table does not exist"); + } + auto& table = it->second; + auto reader = std::make_shared(table); + *stream = std::make_unique(std::move(reader)); + return arrow::Status::OK(); +} + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h new file mode 100644 index 0000000000000..f3f985465a8d4 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/TestFlightServer.h @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "arrow/api.h" +#include "arrow/flight/api.h" + +namespace facebook::presto::connector::arrow_flight::test { + +/// Test Flight server which supports DoGet operations. +/// Maintains a list of named arrow tables, +/// +/// Normally, the tickets would be obtained by calling GetFlightInfo, +/// but since this is done by the coordinator this part is omitted. +/// Instead, the ticket is simply the name of the table to fetch. +class TestFlightServer : public arrow::flight::FlightServerBase { + public: + TestFlightServer() = default; + + void updateTable(std::string name, std::shared_ptr table) { + tables_.emplace(std::move(name), std::move(table)); + } + + void removeTable(const std::string& name) { + tables_.erase(name); + } + + arrow::Status DoGet( + const arrow::flight::ServerCallContext& context, + const arrow::flight::Ticket& request, + std::unique_ptr* stream) override; + + private: + std::unordered_map> tables_; +}; + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp new file mode 100644 index 0000000000000..4b63b8822d45b --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.cpp @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "Utils.h" +#include +#include + +namespace facebook::presto::connector::arrow_flight::test { + +std::shared_ptr makeDecimalArray( + const std::vector& decimalValues, + int precision, + int scale) { + auto decimalType = arrow::decimal(precision, scale); + auto builder = + arrow::Decimal128Builder(decimalType, arrow::default_memory_pool()); + + for (const auto& value : decimalValues) { + arrow::Decimal128 dec(value); + AFC_RAISE_NOT_OK(builder.Append(dec)); + } + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeTimestampArray( + const std::vector& values, + arrow::TimeUnit::type timeUnit, + arrow::MemoryPool* memory_pool) { + arrow::TimestampBuilder builder(arrow::timestamp(timeUnit), memory_pool); + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeStringArray( + const std::vector& values) { + auto builder = arrow::StringBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeBooleanArray( + const std::vector& values) { + auto builder = arrow::BooleanBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +auto makeRecordBatch( + const std::vector& names, + const arrow::ArrayVector& arrays) { + VELOX_CHECK_EQ(names.size(), arrays.size()); + + auto nrows = (!arrays.empty()) ? (arrays[0]->length()) : 0; + arrow::FieldVector fields{}; + for (int i = 0; i < arrays.size(); i++) { + VELOX_CHECK_EQ(arrays[i]->length(), nrows); + fields.push_back( + std::make_shared(names[i], arrays[i]->type())); + } + + auto schema = arrow::schema(fields); + return arrow::RecordBatch::Make(schema, nrows, arrays); +} + +std::shared_ptr makeArrowTable( + const std::vector& names, + const arrow::ArrayVector& arrays) { + AFC_RETURN_OR_RAISE( + arrow::Table::FromRecordBatches({makeRecordBatch(names, arrays)})); +} + +std::string readFile(const std::string& path) { + std::ifstream file(path); + VELOX_CHECK( + file.is_open(), "Could not open file \"{}\": {}", path, strerror(errno)); + return { + std::istreambuf_iterator(file), std::istreambuf_iterator()}; +} + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h new file mode 100644 index 0000000000000..b092d08b02170 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/connectors/arrow_flight/tests/utils/Utils.h @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "arrow/api.h" +#include "presto_cpp/main/connectors/arrow_flight/Macros.h" +#include "velox/common/base/Exceptions.h" + +namespace facebook::presto::connector::arrow_flight::test { + +template +auto makeNumericArray(const std::vector& values) { + auto builder = arrow::NumericBuilder{}; + AFC_RAISE_NOT_OK(builder.AppendValues(values)); + AFC_RETURN_OR_RAISE(builder.Finish()); +} + +std::shared_ptr makeDecimalArray( + const std::vector& decimalValues, + int precision, + int scale); + +std::shared_ptr makeTimestampArray( + const std::vector& values, + arrow::TimeUnit::type timeUnit, + arrow::MemoryPool* memory_pool = arrow::default_memory_pool()); + +std::shared_ptr makeStringArray( + const std::vector& values); + +std::shared_ptr makeBooleanArray(const std::vector& values); + +auto makeRecordBatch( + const std::vector& names, + const arrow::ArrayVector& arrays); + +std::shared_ptr makeArrowTable( + const std::vector& names, + const arrow::ArrayVector& arrays); + +std::string readFile(const std::string& path); + +} // namespace facebook::presto::connector::arrow_flight::test diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp index 9cc3568787e2b..d0492ab957e2a 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.cpp @@ -1099,7 +1099,8 @@ std::unique_ptr HivePrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const { + const protocol::SplitContext* splitContext, + const std::map& extraCredentials) const { auto hiveSplit = dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( @@ -1335,7 +1336,8 @@ std::unique_ptr IcebergPrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const { + const protocol::SplitContext* splitContext, + const std::map& extraCredentials) const { auto icebergSplit = dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( @@ -1488,7 +1490,8 @@ std::unique_ptr TpchPrestoToVeloxConnector::toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const { + const protocol::SplitContext* splitContext, + const std::map& extraCredentials) const { auto tpchSplit = dynamic_cast(connectorSplit); VELOX_CHECK_NOT_NULL( diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h index eb33dfb54ca1d..15370ec5154f8 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxConnector.h @@ -46,7 +46,9 @@ class PrestoToVeloxConnector { toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const = 0; + const protocol::SplitContext* splitContext, + const std::map& extraCredentials = {}) + const = 0; [[nodiscard]] virtual std::unique_ptr toVeloxColumnHandle( @@ -117,7 +119,9 @@ class HivePrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const final; + const protocol::SplitContext* splitContext, + const std::map& extraCredentials = {}) + const final; std::unique_ptr toVeloxColumnHandle( const protocol::ColumnHandle* column, @@ -169,7 +173,9 @@ class IcebergPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const final; + const protocol::SplitContext* splitContext, + const std::map& extraCredentials = {}) + const final; std::unique_ptr toVeloxColumnHandle( const protocol::ColumnHandle* column, @@ -196,7 +202,9 @@ class TpchPrestoToVeloxConnector final : public PrestoToVeloxConnector { std::unique_ptr toVeloxSplit( const protocol::ConnectorId& catalogId, const protocol::ConnectorSplit* connectorSplit, - const protocol::SplitContext* splitContext) const final; + const protocol::SplitContext* splitContext, + const std::map& extraCredentials = {}) + const final; std::unique_ptr toVeloxColumnHandle( const protocol::ColumnHandle* column, diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp index 1d11be2e904fc..4bc1c1465e016 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.cpp @@ -20,7 +20,8 @@ using namespace facebook::velox; namespace facebook::presto { velox::exec::Split toVeloxSplit( - const presto::protocol::ScheduledSplit& scheduledSplit) { + const presto::protocol::ScheduledSplit& scheduledSplit, + const std::map& extraCredentials) { const auto& connectorSplit = scheduledSplit.split.connectorSplit; const auto splitGroupId = scheduledSplit.split.lifespan.isgroup ? scheduledSplit.split.lifespan.groupid @@ -42,7 +43,8 @@ velox::exec::Split toVeloxSplit( auto veloxSplit = connector.toVeloxSplit( scheduledSplit.split.connectorId, connectorSplit.get(), - &scheduledSplit.split.splitContext); + &scheduledSplit.split.splitContext, + extraCredentials); return velox::exec::Split(std::move(veloxSplit), splitGroupId); } diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.h index ff8692c2700a0..5b93a767ce029 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxSplit.h @@ -21,6 +21,7 @@ namespace facebook::presto { // Creates and returns exec::Split (with connector::ConnectorSplit inside) based // on the given protocol split. velox::exec::Split toVeloxSplit( - const presto::protocol::ScheduledSplit& scheduledSplit); + const presto::protocol::ScheduledSplit& scheduledSplit, + const std::map& extraCredentials = {}); } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/presto_protocol/Makefile b/presto-native-execution/presto_cpp/presto_protocol/Makefile index 3ee2b4e802b81..09b43df28b4f5 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/Makefile +++ b/presto-native-execution/presto_cpp/presto_protocol/Makefile @@ -45,14 +45,23 @@ presto_protocol-cpp: presto_protocol-json chevron -d connector/tpch/presto_protocol_tpch.json connector/tpch/presto_protocol-json-hpp.mustache >> connector/tpch/presto_protocol_tpch.h clang-format -style=file -i connector/tpch/presto_protocol_tpch.h connector/tpch/presto_protocol_tpch.cpp + # build arrow_flight connector related structs + echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.cpp + chevron -d connector/arrow_flight/presto_protocol_arrow_flight.json connector/arrow_flight/presto_protocol-json-cpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.cpp + echo "// DO NOT EDIT : This file is generated by chevron" > connector/arrow_flight/presto_protocol_arrow_flight.h + chevron -d connector/arrow_flight/presto_protocol_arrow_flight.json connector/arrow_flight/presto_protocol-json-hpp.mustache >> connector/arrow_flight/presto_protocol_arrow_flight.h + clang-format -style=file -i connector/arrow_flight/presto_protocol_arrow_flight.h connector/arrow_flight/presto_protocol_arrow_flight.cpp + presto_protocol-json: ./java-to-struct-json.py --config core/presto_protocol_core.yml core/special/*.java core/special/*.inc -j | jq . > core/presto_protocol_core.json ./java-to-struct-json.py --config connector/hive/presto_protocol_hive.yml connector/hive/special/*.inc -j | jq . > connector/hive/presto_protocol_hive.json ./java-to-struct-json.py --config connector/iceberg/presto_protocol_iceberg.yml connector/iceberg/special/*.inc -j | jq . > connector/iceberg/presto_protocol_iceberg.json ./java-to-struct-json.py --config connector/tpch/presto_protocol_tpch.yml connector/tpch/special/*.inc -j | jq . > connector/tpch/presto_protocol_tpch.json + ./java-to-struct-json.py --config connector/arrow_flight/presto_protocol_arrow_flight.yml connector/arrow_flight/special/*.inc -j | jq . > connector/arrow_flight/presto_protocol_arrow_flight.json presto_protocol.proto: presto_protocol-json pystache presto_protocol-protobuf.mustache core/presto_protocol_core.json > core/presto_protocol_core.proto pystache presto_protocol-protobuf.mustache connector/hive/presto_protocol_hive.json > connector/hive/presto_protocol_hive.proto pystache presto_protocol-protobuf.mustache connector/iceberg/presto_protocol_iceberg.json > connector/iceberg/presto_protocol_iceberg.proto pystache presto_protocol-protobuf.mustache connector/tpch/presto_protocol_tpch.json > connector/tpch/presto_protocol_tpch.proto + pystache presto_protocol-protobuf.mustache connector/arrow_flight/presto_protocol_arrow_flight.json > connector/arrow_flight/presto_protocol_arrow_flight.proto diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h new file mode 100644 index 0000000000000..95cda16115695 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/ArrowFlightConnectorProtocol.h @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +#include "presto_cpp/presto_protocol/core/ConnectorProtocol.h" + +namespace facebook::presto::protocol::arrow_flight { +using ArrowConnectorProtocol = ConnectorProtocolTemplate< + ArrowTableHandle, + ArrowTableLayoutHandle, + ArrowColumnHandle, + NotImplemented, + NotImplemented, + ArrowSplit, + NotImplemented, + ArrowTransactionHandle, + NotImplemented>; +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache new file mode 100644 index 0000000000000..b6ecb68507285 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-cpp.mustache @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// presto_protocol.prolog.cpp +// + +{{#.}} +{{#comment}} +{{comment}} +{{/comment}} +{{/.}} + +#include + +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +using namespace std::string_literals; + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight +{{#.}} +{{#cinc}} +{{&cinc}} +{{/cinc}} +{{^cinc}} +{{#struct}} +namespace facebook::presto::protocol::arrow_flight { + {{#super_class}} + {{&class_name}}::{{&class_name}}() noexcept { + _type = "{{json_key}}"; + } + {{/super_class}} + + void to_json(json& j, const {{&class_name}}& p) { + j = json::object(); + {{#super_class}} + j["@type"] = "{{&json_key}}"; + {{/super_class}} + {{#fields}} + to_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); + {{/fields}} + } + + void from_json(const json& j, {{&class_name}}& p) { + {{#super_class}} + p._type = j["@type"]; + {{/super_class}} + {{#fields}} + from_json_key(j, "{{&field_name}}", p.{{field_name}}, "{{&class_name}}", "{{&field_text}}", "{{&field_name}}"); + {{/fields}} + } +} +{{/struct}} +{{#enum}} +namespace facebook::presto::protocol::arrow_flight { + //Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() + + // NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays + static const std::pair<{{&class_name}}, json> + {{&class_name}}_enum_table[] = { // NOLINT: cert-err58-cpp + {{#elements}} + { {{&class_name}}::{{&element}}, "{{&element}}" }{{^_last}},{{/_last}} + {{/elements}} + }; + void to_json(json& j, const {{&class_name}}& e) + { + static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); + const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), + [e](const std::pair<{{&class_name}}, json>& ej_pair) -> bool + { + return ej_pair.first == e; + }); + j = ((it != std::end({{&class_name}}_enum_table)) ? it : std::begin({{&class_name}}_enum_table))->second; + } + void from_json(const json& j, {{&class_name}}& e) + { + static_assert(std::is_enum<{{&class_name}}>::value, "{{&class_name}} must be an enum!"); + const auto* it = std::find_if(std::begin({{&class_name}}_enum_table), std::end({{&class_name}}_enum_table), + [&j](const std::pair<{{&class_name}}, json>& ej_pair) -> bool + { + return ej_pair.second == j; + }); + e = ((it != std::end({{&class_name}}_enum_table)) ? it : std::begin({{&class_name}}_enum_table))->first; + } +} +{{/enum}} +{{#abstract}} +namespace facebook::presto::protocol::arrow_flight { + void to_json(json& j, const std::shared_ptr<{{&class_name}}>& p) { + if ( p == nullptr ) { + return; + } + String type = p->_type; + + {{#subclasses}} + if ( type == "{{&key}}" ) { + j = *std::static_pointer_cast<{{&type}}>(p); + return; + } + {{/subclasses}} + + throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); + } + + void from_json(const json& j, std::shared_ptr<{{&class_name}}>& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error &e) { + throw ParseError(std::string(e.what()) + " {{&class_name}} {{&key}} {{&class_name}}"); + } + + {{#subclasses}} + if ( type == "{{&key}}" ) { + std::shared_ptr<{{&type}}> k = std::make_shared<{{&type}}>(); + j.get_to(*k); + p = std::static_pointer_cast<{{&class_name}}>(k); + return; + } + {{/subclasses}} + + throw TypeError(type + " no abstract type {{&class_name}} {{&key}}"); + } +} +{{/abstract}} +{{/cinc}} +{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache new file mode 100644 index 0000000000000..be08bd9e491c2 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol-json-hpp.mustache @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +{{#.}} +{{#comment}} +{{comment}} +{{/comment}} +{{/.}} + +#include +#include +#include +#include + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight +{{#.}} +{{#hinc}} +{{&hinc}} +{{/hinc}} +{{^hinc}} +{{#struct}} +namespace facebook::presto::protocol::arrow_flight { + struct {{class_name}} {{#super_class}}: public {{super_class}}{{/super_class}}{ + {{#fields}} + {{#field_local}}{{#optional}}std::shared_ptr<{{/optional}}{{&field_text}}{{#optional}}>{{/optional}} {{&field_name}} = {};{{/field_local}} + {{/fields}} + + {{#super_class}} + {{class_name}}() noexcept; + {{/super_class}} + }; + void to_json(json& j, const {{class_name}}& p); + void from_json(const json& j, {{class_name}}& p); +} +{{/struct}} +{{#enum}} +namespace facebook::presto::protocol::arrow_flight { + enum class {{class_name}} { + {{#elements}} + {{&element}}{{^_last}},{{/_last}} + {{/elements}} + }; + extern void to_json(json& j, const {{class_name}}& e); + extern void from_json(const json& j, {{class_name}}& e); +} +{{/enum}} +{{/hinc}} +{{/.}} diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp new file mode 100644 index 0000000000000..e5b5cf2f9ae3b --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp @@ -0,0 +1,215 @@ +// DO NOT EDIT : This file is generated by chevron +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// presto_protocol.prolog.cpp +// + +// This file is generated DO NOT EDIT @generated + +#include + +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" +using namespace std::string_literals; + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowColumnHandle::ArrowColumnHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowColumnHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + to_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} + +void from_json(const json& j, ArrowColumnHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "columnName", + p.columnName, + "ArrowColumnHandle", + "String", + "columnName"); + from_json_key( + j, "columnType", p.columnType, "ArrowColumnHandle", "Type", "columnType"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowSplit::ArrowSplit() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowSplit& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, "schemaName", p.schemaName, "ArrowSplit", "String", "schemaName"); + to_json_key(j, "tableName", p.tableName, "ArrowSplit", "String", "tableName"); + to_json_key( + j, + "flightEndpointBytes", + p.flightEndpointBytes, + "ArrowSplit", + "String", + "flightEndpointBytes"); +} + +void from_json(const json& j, ArrowSplit& p) { + p._type = j["@type"]; + from_json_key( + j, "schemaName", p.schemaName, "ArrowSplit", "String", "schemaName"); + from_json_key( + j, "tableName", p.tableName, "ArrowSplit", "String", "tableName"); + from_json_key( + j, + "flightEndpointBytes", + p.flightEndpointBytes, + "ArrowSplit", + "String", + "flightEndpointBytes"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowTableHandle::ArrowTableHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowTableHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key(j, "schema", p.schema, "ArrowTableHandle", "String", "schema"); + to_json_key(j, "table", p.table, "ArrowTableHandle", "String", "table"); +} + +void from_json(const json& j, ArrowTableHandle& p) { + p._type = j["@type"]; + from_json_key(j, "schema", p.schema, "ArrowTableHandle", "String", "schema"); + from_json_key(j, "table", p.table, "ArrowTableHandle", "String", "table"); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +void to_json(json& j, const std::shared_ptr& p) { + if (p == nullptr) { + return; + } + String type = p->_type; + + if (type == "arrow-flight") { + j = *std::static_pointer_cast(p); + return; + } + + throw TypeError(type + " no abstract type ColumnHandle "); +} + +void from_json(const json& j, std::shared_ptr& p) { + String type; + try { + type = p->getSubclassKey(j); + } catch (json::parse_error& e) { + throw ParseError(std::string(e.what()) + " ColumnHandle ColumnHandle"); + } + + if (type == "arrow-flight") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } + + throw TypeError(type + " no abstract type ColumnHandle "); +} +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +ArrowTableLayoutHandle::ArrowTableLayoutHandle() noexcept { + _type = "arrow-flight"; +} + +void to_json(json& j, const ArrowTableLayoutHandle& p) { + j = json::object(); + j["@type"] = "arrow-flight"; + to_json_key( + j, + "table", + p.table, + "ArrowTableLayoutHandle", + "ArrowTableHandle", + "table"); + to_json_key( + j, + "columnHandles", + p.columnHandles, + "ArrowTableLayoutHandle", + "List", + "columnHandles"); + to_json_key( + j, + "tupleDomain", + p.tupleDomain, + "ArrowTableLayoutHandle", + "TupleDomain>", + "tupleDomain"); +} + +void from_json(const json& j, ArrowTableLayoutHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "table", + p.table, + "ArrowTableLayoutHandle", + "ArrowTableHandle", + "table"); + from_json_key( + j, + "columnHandles", + p.columnHandles, + "ArrowTableLayoutHandle", + "List", + "columnHandles"); + from_json_key( + j, + "tupleDomain", + p.tupleDomain, + "ArrowTableLayoutHandle", + "TupleDomain>", + "tupleDomain"); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h new file mode 100644 index 0000000000000..2a9cb81d00b47 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h @@ -0,0 +1,82 @@ +// DO NOT EDIT : This file is generated by chevron +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +// This file is generated DO NOT EDIT @generated + +#include +#include +#include +#include + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowColumnHandle : public ColumnHandle { + String columnName = {}; + Type columnType = {}; + + ArrowColumnHandle() noexcept; +}; +void to_json(json& j, const ArrowColumnHandle& p); +void from_json(const json& j, ArrowColumnHandle& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowSplit : public ConnectorSplit { + String schemaName = {}; + String tableName = {}; + String flightEndpointBytes = {}; + + ArrowSplit() noexcept; +}; +void to_json(json& j, const ArrowSplit& p); +void from_json(const json& j, ArrowSplit& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowTableHandle : public ConnectorTableHandle { + String schema = {}; + String table = {}; + + ArrowTableHandle() noexcept; +}; +void to_json(json& j, const ArrowTableHandle& p); +void from_json(const json& j, ArrowTableHandle& p); +} // namespace facebook::presto::protocol::arrow_flight +namespace facebook::presto::protocol::arrow_flight { +struct ArrowTableLayoutHandle : public ConnectorTableLayoutHandle { + ArrowTableHandle table = {}; + List columnHandles = {}; + TupleDomain> tupleDomain = {}; + + ArrowTableLayoutHandle() noexcept; +}; +void to_json(json& j, const ArrowTableLayoutHandle& p); +void from_json(const json& j, ArrowTableLayoutHandle& p); +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml new file mode 100644 index 0000000000000..f34f6068eb777 --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.yml @@ -0,0 +1,40 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +AbstractClasses: + ColumnHandle: + super: JsonEncodedSubclass + comparable: true + subclasses: + - { name: ArrowColumnHandle, key: arrow-flight } + + ConnectorTableHandle: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowTableHandle, key: arrow-flight } + + ConnectorTableLayoutHandle: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowTableLayoutHandle, key: arrow-flight } + + ConnectorSplit: + super: JsonEncodedSubclass + subclasses: + - { name: ArrowSplit, key: arrow-flight } + +JavaClasses: + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc new file mode 100644 index 0000000000000..a93325f5b154a --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.cpp.inc @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +void to_json(json& j, const ArrowTransactionHandle& p) { + j = json::array(); + j.push_back(p._type); + j.push_back(p.instance); +} + +void from_json(const json& j, ArrowTransactionHandle& p) { + j[0].get_to(p._type); + j[1].get_to(p.instance); +} +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc new file mode 100644 index 0000000000000..dc573ca2e68cf --- /dev/null +++ b/presto-native-execution/presto_cpp/presto_protocol/connector/arrow_flight/special/ArrowTransactionHandle.hpp.inc @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// ArrowTransactionHandle is special since +// the corresponding class in Java is an enum. + +namespace facebook::presto::protocol::arrow_flight { + +struct ArrowTransactionHandle : public ConnectorTransactionHandle { + String instance = {}; +}; + +void to_json(json& j, const ArrowTransactionHandle& p); + +void from_json(const json& j, ArrowTransactionHandle& p); + +} // namespace facebook::presto::protocol::arrow_flight diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 3dcff5034cdd5..8c22eacb56f44 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -1069,6 +1069,7 @@ void from_json(const json& j, std::shared_ptr& p) { */ // dependency TpchTransactionHandle +// dependency ArrowTransactionHandle namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index f429324ac8d57..e7235117082a9 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -77,7 +77,7 @@ class TypeError : public Exception { class OutOfRange : public Exception { public: - explicit OutOfRange(const std::string& message) : Exception(message){}; + explicit OutOfRange(const std::string& message) : Exception(message) {}; }; class ParseError : public Exception { public: diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index ae0222da2e67a..a208a25168ab2 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -54,6 +54,7 @@ AbstractClasses: - { name: IcebergColumnHandle, key: hive-iceberg } - { name: TpchColumnHandle, key: tpch } - { name: SystemColumnHandle, key: $system@system } + - { name: ArrowColumnHandle, key: arrow-flight } ConnectorPartitioningHandle: super: JsonEncodedSubclass @@ -69,6 +70,7 @@ AbstractClasses: - { name: IcebergTableHandle, key: hive-iceberg } - { name: TpchTableHandle, key: tpch } - { name: SystemTableHandle, key: $system@system } + - { name: ArrowTableHandle, key: arrow-flight } ConnectorOutputTableHandle: super: JsonEncodedSubclass @@ -96,6 +98,7 @@ AbstractClasses: - { name: IcebergTableLayoutHandle, key: hive-iceberg } - { name: TpchTableLayoutHandle, key: tpch } - { name: SystemTableLayoutHandle, key: $system@system } + - { name: ArrowTableLayoutHandle, key: arrow-flight } ConnectorMetadataUpdateHandle: super: JsonEncodedSubclass @@ -111,6 +114,7 @@ AbstractClasses: - { name: RemoteSplit, key: $remote } - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } + - { name: ArrowSplit, key: arrow-flight } ConnectorHistogram: super: JsonEncodedSubclass diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc index 8ec2a94e84bd9..1dfb17e4a908f 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc +++ b/presto-native-execution/presto_cpp/presto_protocol/core/special/ConnectorTransactionHandle.cpp.inc @@ -13,6 +13,7 @@ */ // dependency TpchTransactionHandle +// dependency ArrowTransactionHandle namespace facebook::presto::protocol { void to_json(json& j, const std::shared_ptr& p) { diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp index c15084817a434..24f24f27f87a3 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.cpp @@ -15,6 +15,7 @@ // DEPRECATED: This file is deprecated and will be removed in future versions. +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.cpp" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.cpp" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.cpp" #include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.cpp" diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h index dd94975e3760d..c43ec92629f44 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.h @@ -16,6 +16,7 @@ // DEPRECATED: This file is deprecated and will be removed in future versions. +#include "presto_cpp/presto_protocol/connector/arrow_flight/presto_protocol_arrow_flight.h" #include "presto_cpp/presto_protocol/connector/hive/presto_protocol_hive.h" #include "presto_cpp/presto_protocol/connector/iceberg/presto_protocol_iceberg.h" #include "presto_cpp/presto_protocol/connector/tpch/presto_protocol_tpch.h" diff --git a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml index 18c9afda02ece..83a18b28a72ad 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/presto_protocol.yml @@ -53,6 +53,7 @@ AbstractClasses: - { name: IcebergColumnHandle, key: hive-iceberg } - { name: TpchColumnHandle, key: tpch } - { name: SystemColumnHandle, key: $system@system } + - { name: ArrowColumnHandle, key: arrow-flight } ConnectorPartitioningHandle: super: JsonEncodedSubclass @@ -68,6 +69,7 @@ AbstractClasses: - { name: IcebergTableHandle, key: hive-iceberg } - { name: TpchTableHandle, key: tpch } - { name: SystemTableHandle, key: $system@system } + - { name: ArrowTableHandle, key: arrow-flight } ConnectorOutputTableHandle: super: JsonEncodedSubclass @@ -95,6 +97,7 @@ AbstractClasses: - { name: IcebergTableLayoutHandle, key: hive-iceberg } - { name: TpchTableLayoutHandle, key: tpch } - { name: SystemTableLayoutHandle, key: $system@system } + - { name: ArrowTableLayoutHandle, key: arrow-flight } ConnectorMetadataUpdateHandle: super: JsonEncodedSubclass @@ -110,6 +113,7 @@ AbstractClasses: - { name: RemoteSplit, key: $remote } - { name: EmptySplit, key: $empty } - { name: SystemSplit, key: $system@system } + - { name: ArrowSplit, key: arrow-flight } ConnectorHistogram: super: JsonEncodedSubclass @@ -365,3 +369,7 @@ JavaClasses: - presto-main/src/main/java/com/facebook/presto/connector/system/SystemTransactionHandle.java - presto-spi/src/main/java/com/facebook/presto/spi/function/AggregationFunctionMetadata.java - presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/json/JsonBasedUdfFunctionMetadata.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java + - presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java diff --git a/presto-native-execution/scripts/setup-adapters.sh b/presto-native-execution/scripts/setup-adapters.sh index 6c36424ebf90c..532ec01d9e8f8 100755 --- a/presto-native-execution/scripts/setup-adapters.sh +++ b/presto-native-execution/scripts/setup-adapters.sh @@ -35,15 +35,73 @@ function install_prometheus_cpp { cmake_install -DBUILD_SHARED_LIBS=ON -DENABLE_PUSH=OFF -DENABLE_COMPRESSION=OFF } +function install_abseil { + # abseil-cpp + github_checkout abseil/abseil-cpp 20240116.2 --depth 1 + cmake_install \ + -DABSL_BUILD_TESTING=OFF \ + -DCMAKE_CXX_STANDARD=17 \ + -DABSL_PROPAGATE_CXX_STD=ON \ + -DABSL_ENABLE_INSTALL=ON +} + +function install_grpc { + # grpc + github_checkout grpc/grpc v1.48.1 --depth 1 + cmake_install \ + -DgRPC_BUILD_TESTS=OFF \ + -DgRPC_ABSL_PROVIDER=package \ + -DgRPC_ZLIB_PROVIDER=package \ + -DgRPC_CARES_PROVIDER=package \ + -DgRPC_RE2_PROVIDER=package \ + -DgRPC_SSL_PROVIDER=package \ + -DgRPC_PROTOBUF_PROVIDER=package \ + -DgRPC_INSTALL=ON +} + +function install_arrow_flight { + ARROW_VERSION="15.0.0" + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + export INSTALL_PREFIX=${INSTALL_PREFIX:-"/usr/local"} + LINUX_DISTRIBUTION=$(. /etc/os-release && echo ${ID}) + if [[ "$LINUX_DISTRIBUTION" == "ubuntu" || "$LINUX_DISTRIBUTION" == "debian" ]]; then + SUDO="${SUDO:-"sudo --preserve-env"}" + ${SUDO} apt install -y libc-ares-dev + ${SUDO} ldconfig -v 2>/dev/null | grep "${INSTALL_PREFIX}/lib" || \ + echo "${INSTALL_PREFIX}/lib" | ${SUDO} tee /etc/ld.so.conf.d/local-libraries.conf > /dev/null \ + && ${SUDO} ldconfig + else + dnf -y install c-ares-devel + ldconfig -v 2>/dev/null | grep "${INSTALL_PREFIX}/lib" || \ + echo "${INSTALL_PREFIX}/lib" | tee /etc/ld.so.conf.d/local-libraries.conf > /dev/null \ + && ldconfig + fi + else + # The installation script for the Arrow Flight connector currently works only on Linux distributions. + return 0 + fi + + install_abseil + install_grpc + + wget_and_untar https://github.com/apache/arrow/archive/apache-arrow-${ARROW_VERSION}.tar.gz arrow + cmake_install_dir arrow/cpp \ + -DARROW_FLIGHT=ON \ + -DARROW_BUILD_BENCHMARKS=ON \ + -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} +} + cd "${DEPENDENCY_DIR}" || exit install_jwt=0 install_prometheus_cpp=0 +install_arrow_flight=0 if [ "$#" -eq 0 ]; then # Install all adapters by default install_jwt=1 install_prometheus_cpp=1 + install_arrow_flight=1 fi while [[ $# -gt 0 ]]; do @@ -56,6 +114,10 @@ while [[ $# -gt 0 ]]; do install_prometheus_cpp=1; shift ;; + arrow_flight) + install_arrow_flight=1; + shift + ;; *) echo "ERROR: Unknown option $1! will be ignored!" shift @@ -72,6 +134,10 @@ if [ $install_prometheus_cpp -eq 1 ]; then install_prometheus_cpp fi +if [ $install_arrow_flight -eq 1 ]; then + install_arrow_flight +fi + _ret=$? if [ $_ret -eq 0 ] ; then echo "All deps for Presto adapters installed!"