diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index 206cc6505c4b..ae9759b6685f 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -27,13 +27,18 @@ repository = { workspace = true }
 license = { workspace = true }
 
 [dependencies]
+arrow-arith = { workspace = true, optional = true }
 arrow-array = { workspace = true }
 arrow-buffer = { workspace = true }
 # Cast is needed to work around https://github.com/apache/arrow-rs/issues/3389
 arrow-cast = { workspace = true }
-arrow-data = { workspace = true }
+arrow-data = { workspace = true, optional = true }
 arrow-ipc = { workspace = true }
+arrow-ord = { workspace = true, optional = true }
+arrow-row = { workspace = true, optional = true }
+arrow-select = { workspace = true, optional = true }
 arrow-schema = { workspace = true }
+arrow-string = { workspace = true, optional = true }
 base64 = { version = "0.21", default-features = false, features = ["std"] }
 bytes = { version = "1", default-features = false }
 futures = { version = "0.3", default-features = false, features = ["alloc"] }
@@ -53,7 +58,7 @@ all-features = true
 
 [features]
 default = []
-flight-sql-experimental = ["once_cell"]
+flight-sql-experimental = ["arrow-arith", "arrow-data", "arrow-ord", "arrow-row", "arrow-select", "arrow-string", "once_cell"]
 tls = ["tonic/tls"]
 
 # Enable CLI tools
diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs
index 783e0bf5bdf6..6b92621a564d 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -20,6 +20,7 @@ use base64::Engine;
 use futures::{stream, Stream, TryStreamExt};
 use once_cell::sync::Lazy;
 use prost::Message;
+use std::collections::HashSet;
 use std::pin::Pin;
 use std::sync::Arc;
 use tonic::transport::Server;
@@ -29,6 +30,9 @@ use tonic::{Request, Response, Status, Streaming};
 use arrow_array::builder::StringBuilder;
 use arrow_array::{ArrayRef, RecordBatch};
 use arrow_flight::encode::FlightDataEncoderBuilder;
+use arrow_flight::sql::catalogs::{
+    get_catalogs_schema, get_db_schemas_schema, get_tables_schema,
+};
 use arrow_flight::sql::sql_info::SqlInfoList;
 use arrow_flight::sql::{
     server::FlightSqlService, ActionBeginSavepointRequest, ActionBeginSavepointResult,
@@ -72,6 +76,8 @@ static INSTANCE_SQL_INFO: Lazy<SqlInfoList> = Lazy::new(|| {
         .with_sql_info(SqlInfo::FlightSqlServerArrowVersion, "1.3")
 });
 
+static TABLES: Lazy<Vec<&str>> = Lazy::new(|| vec!["flight_sql.example.table"]);
+
 #[derive(Clone)]
 pub struct FlightSqlServiceImpl {}
 
@@ -236,32 +242,62 @@ impl FlightSqlService for FlightSqlServiceImpl {
 
     async fn get_flight_info_catalogs(
         &self,
-        _query: CommandGetCatalogs,
-        _request: Request<FlightDescriptor>,
+        query: CommandGetCatalogs,
+        request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>, Status> {
-        Err(Status::unimplemented(
-            "get_flight_info_catalogs not implemented",
-        ))
+        let flight_descriptor = request.into_inner();
+        let ticket = Ticket {
+            ticket: query.as_any().encode_to_vec().into(),
+        };
+        let endpoint = FlightEndpoint::new().with_ticket(ticket);
+
+        let flight_info = FlightInfo::new()
+            .try_with_schema(get_catalogs_schema())
+            .map_err(|e| status!("Unable to encode schema", e))?
+            .with_endpoint(endpoint)
+            .with_descriptor(flight_descriptor);
+
+        Ok(tonic::Response::new(flight_info))
     }
 
     async fn get_flight_info_schemas(
         &self,
-        _query: CommandGetDbSchemas,
-        _request: Request<FlightDescriptor>,
+        query: CommandGetDbSchemas,
+        request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>, Status> {
-        Err(Status::unimplemented(
-            "get_flight_info_schemas not implemented",
-        ))
+        let flight_descriptor = request.into_inner();
+        let ticket = Ticket {
+            ticket: query.as_any().encode_to_vec().into(),
+        };
+        let endpoint = FlightEndpoint::new().with_ticket(ticket);
+
+        let flight_info = FlightInfo::new()
+            .try_with_schema(get_db_schemas_schema().as_ref())
+            .map_err(|e| status!("Unable to encode schema", e))?
+            .with_endpoint(endpoint)
+            .with_descriptor(flight_descriptor);
+
+        Ok(tonic::Response::new(flight_info))
     }
 
     async fn get_flight_info_tables(
         &self,
-        _query: CommandGetTables,
-        _request: Request<FlightDescriptor>,
+        query: CommandGetTables,
+        request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>, Status> {
-        Err(Status::unimplemented(
-            "get_flight_info_tables not implemented",
-        ))
+        let flight_descriptor = request.into_inner();
+        let ticket = Ticket {
+            ticket: query.as_any().encode_to_vec().into(),
+        };
+        let endpoint = FlightEndpoint::new().with_ticket(ticket);
+
+        let flight_info = FlightInfo::new()
+            .try_with_schema(get_tables_schema(query.include_schema).as_ref())
+            .map_err(|e| status!("Unable to encode schema", e))?
+            .with_endpoint(endpoint)
+            .with_descriptor(flight_descriptor);
+
+        Ok(tonic::Response::new(flight_info))
     }
 
     async fn get_flight_info_table_types(
@@ -363,26 +399,88 @@ impl FlightSqlService for FlightSqlServiceImpl {
 
     async fn do_get_catalogs(
         &self,
-        _query: CommandGetCatalogs,
+        query: CommandGetCatalogs,
         _request: Request<Ticket>,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
-        Err(Status::unimplemented("do_get_catalogs not implemented"))
+        let catalog_names = TABLES
+            .iter()
+            .map(|full_name| full_name.split('.').collect::<Vec<_>>()[0].to_string())
+            .collect::<HashSet<_>>();
+        let mut builder = query.into_builder();
+        for catalog_name in catalog_names {
+            builder.append(catalog_name);
+        }
+        let batch = builder.build();
+        let stream = FlightDataEncoderBuilder::new()
+            .with_schema(Arc::new(get_catalogs_schema().clone()))
+            .build(futures::stream::once(async { batch }))
+            .map_err(Status::from);
+        Ok(Response::new(Box::pin(stream)))
     }
 
     async fn do_get_schemas(
         &self,
-        _query: CommandGetDbSchemas,
+        query: CommandGetDbSchemas,
         _request: Request<Ticket>,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
-        Err(Status::unimplemented("do_get_schemas not implemented"))
+        let schemas = TABLES
+            .iter()
+            .map(|full_name| {
+                let parts = full_name.split('.').collect::<Vec<_>>();
+                (parts[0].to_string(), parts[1].to_string())
+            })
+            .collect::<HashSet<_>>();
+
+        let mut builder = query.into_builder();
+        for (catalog_name, schema_name) in schemas {
+            builder.append(catalog_name, schema_name);
+        }
+
+        let batch = builder.build();
+        let stream = FlightDataEncoderBuilder::new()
+            .with_schema(get_db_schemas_schema())
+            .build(futures::stream::once(async { batch }))
+            .map_err(Status::from);
+        Ok(Response::new(Box::pin(stream)))
     }
 
     async fn do_get_tables(
         &self,
-        _query: CommandGetTables,
+        query: CommandGetTables,
         _request: Request<Ticket>,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
-        Err(Status::unimplemented("do_get_tables not implemented"))
+        let tables = TABLES
+            .iter()
+            .map(|full_name| {
+                let parts = full_name.split('.').collect::<Vec<_>>();
+                (
+                    parts[0].to_string(),
+                    parts[1].to_string(),
+                    parts[2].to_string(),
+                )
+            })
+            .collect::<Vec<_>>();
+
+        let dummy_schema = Schema::empty();
+        let include_schema = query.include_schema;
+        let mut builder = query.into_builder();
+        for (catalog_name, schema_name, table_name) in tables {
+            builder
+                .append(
+                    catalog_name,
+                    schema_name,
+                    table_name,
+                    "TABLE",
+                    &dummy_schema,
+                )
+                .map_err(Status::from)?;
+        }
+
+        let batch = builder.build();
+        let stream = FlightDataEncoderBuilder::new()
+            .with_schema(get_tables_schema(include_schema))
+            .build(futures::stream::once(async { batch }))
+            .map_err(Status::from);
+        Ok(Response::new(Box::pin(stream)))
     }
 
     async fn do_get_table_types(
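A note on the ticket contents: `get_flight_info_*` stores the command wrapped in a protobuf `Any` (via `ProstMessageExt::as_any`, already used elsewhere in this example), and the library's `do_get` dispatcher decodes exactly those bytes to route the call to the matching `do_get_*` handler. A minimal round-trip sketch, with invented field values:

```rust
use arrow_flight::sql::{CommandGetDbSchemas, ProstMessageExt};
use prost::Message;

fn main() {
    let query = CommandGetDbSchemas {
        catalog: Some("a_catalog".to_string()),
        db_schema_filter_pattern: Some("a%".to_string()),
    };

    // What get_flight_info_schemas stores in the ticket: the command wrapped
    // in a protobuf Any, then prost-encoded.
    let any = query.as_any();
    let ticket_bytes = any.encode_to_vec();
    assert!(any.type_url.ends_with("CommandGetDbSchemas"));
    assert!(!ticket_bytes.is_empty());

    // Unpacking the Any by hand, as the do_get dispatcher does, recovers the command.
    let round_trip = CommandGetDbSchemas::decode(&*any.value).unwrap();
    assert_eq!(round_trip.db_schema_filter_pattern.as_deref(), Some("a%"));
}
```

Encoding the bare command without the `Any` wrapper would leave the dispatcher unable to recover the type URL, so the ticket would not route to `do_get_schemas`.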
diff --git a/arrow-flight/src/sql/catalogs/db_schemas.rs b/arrow-flight/src/sql/catalogs/db_schemas.rs
new file mode 100644
index 000000000000..76c5499c89cc
--- /dev/null
+++ b/arrow-flight/src/sql/catalogs/db_schemas.rs
@@ -0,0 +1,284 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`GetSchemasBuilder`] for building responses to [`CommandGetDbSchemas`] queries.
+//!
+//! [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas
+
+use std::sync::Arc;
+
+use arrow_arith::boolean::and;
+use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch};
+use arrow_ord::comparison::eq_utf8_scalar;
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use arrow_select::{filter::filter_record_batch, take::take};
+use arrow_string::like::like_utf8_scalar;
+use once_cell::sync::Lazy;
+
+use super::lexsort_to_indices;
+use crate::error::*;
+use crate::sql::CommandGetDbSchemas;
+
+/// Return the schema of the RecordBatch that will be returned from [`CommandGetDbSchemas`]
+///
+/// [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas
+pub fn get_db_schemas_schema() -> SchemaRef {
+    Arc::clone(&GET_DB_SCHEMAS_SCHEMA)
+}
+
+/// The schema for GetDbSchemas
+static GET_DB_SCHEMAS_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
+    Arc::new(Schema::new(vec![
+        Field::new("catalog_name", DataType::Utf8, false),
+        Field::new("db_schema_name", DataType::Utf8, false),
+    ]))
+});
+
+/// Builds rows like this:
+///
+/// * catalog_name: utf8,
+/// * db_schema_name: utf8,
+pub struct GetSchemasBuilder {
+    // Specifies the Catalog to search for the tables.
+    // - An empty string retrieves those without a catalog.
+    // - If omitted the catalog name is not used to narrow the search.
+    catalog_filter: Option<String>,
+    // Optional filters to apply
+    db_schema_filter_pattern: Option<String>,
+    // array builder for catalog names
+    catalog_name: StringBuilder,
+    // array builder for schema names
+    db_schema_name: StringBuilder,
+}
+
+impl CommandGetDbSchemas {
+    pub fn into_builder(self) -> GetSchemasBuilder {
+        self.into()
+    }
+}
+
+impl From<CommandGetDbSchemas> for GetSchemasBuilder {
+    fn from(value: CommandGetDbSchemas) -> Self {
+        Self::new(value.catalog, value.db_schema_filter_pattern)
+    }
+}
+
+impl GetSchemasBuilder {
+    /// Create a new instance of [`GetSchemasBuilder`]
+    ///
+    /// # Parameters
+    ///
+    /// - `catalog`: Specifies the Catalog to search for the tables.
+    ///   - An empty string retrieves those without a catalog.
+    ///   - If omitted the catalog name is not used to narrow the search.
+    /// - `db_schema_filter_pattern`: Specifies a filter pattern for schemas to search for.
+    ///   When no pattern is provided, the pattern will not be used to narrow the search.
+    ///   In the pattern string, two special characters can be used to denote matching rules:
+    ///     - "%" means to match any substring with 0 or more characters.
+    ///     - "_" means to match any one character.
+    ///
+    /// [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas
+    pub fn new(
+        catalog: Option<impl Into<String>>,
+        db_schema_filter_pattern: Option<impl Into<String>>,
+    ) -> Self {
+        Self {
+            catalog_filter: catalog.map(|v| v.into()),
+            db_schema_filter_pattern: db_schema_filter_pattern.map(|v| v.into()),
+            catalog_name: StringBuilder::new(),
+            db_schema_name: StringBuilder::new(),
+        }
+    }
+
+    /// Append a row
+    ///
+    /// In case the catalog should be considered as empty, pass in an empty string `""`.
+    pub fn append(
+        &mut self,
+        catalog_name: impl AsRef<str>,
+        schema_name: impl AsRef<str>,
+    ) {
+        self.catalog_name.append_value(catalog_name);
+        self.db_schema_name.append_value(schema_name);
+    }
+
+    /// builds a `RecordBatch` with the correct schema for a `CommandGetDbSchemas` response
+    pub fn build(self) -> Result<RecordBatch> {
+        let Self {
+            catalog_filter,
+            db_schema_filter_pattern,
+            mut catalog_name,
+            mut db_schema_name,
+        } = self;
+
+        // Make the arrays
+        let catalog_name = catalog_name.finish();
+        let db_schema_name = db_schema_name.finish();
+
+        let mut filters = vec![];
+
+        if let Some(db_schema_filter_pattern) = db_schema_filter_pattern {
+            // use like kernel to get wildcard matching
+            filters.push(like_utf8_scalar(
+                &db_schema_name,
+                &db_schema_filter_pattern,
+            )?)
+        }
+
+        if let Some(catalog_filter_name) = catalog_filter {
+            filters.push(eq_utf8_scalar(&catalog_name, &catalog_filter_name)?);
+        }
+
+        // `AND` any filters together
+        let mut total_filter = None;
+        while let Some(filter) = filters.pop() {
+            let new_filter = match total_filter {
+                Some(total_filter) => and(&total_filter, &filter)?,
+                None => filter,
+            };
+            total_filter = Some(new_filter);
+        }
+
+        let batch = RecordBatch::try_new(
+            get_db_schemas_schema(),
+            vec![
+                Arc::new(catalog_name) as ArrayRef,
+                Arc::new(db_schema_name) as ArrayRef,
+            ],
+        )?;
+
+        // Apply the filters if needed
+        let filtered_batch = if let Some(filter) = total_filter {
+            filter_record_batch(&batch, &filter)?
+        } else {
+            batch
+        };
+
+        // Order filtered results by catalog_name, then db_schema_name
+        let indices = lexsort_to_indices(filtered_batch.columns());
+        let columns = filtered_batch
+            .columns()
+            .iter()
+            .map(|c| take(c, &indices, None))
+            .collect::<std::result::Result<Vec<_>, _>>()?;
+
+        Ok(RecordBatch::try_new(get_db_schemas_schema(), columns)?)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_array::{StringArray, UInt32Array};
+
+    fn get_ref_batch() -> RecordBatch {
+        RecordBatch::try_new(
+            get_db_schemas_schema(),
+            vec![
+                Arc::new(StringArray::from(vec![
+                    "a_catalog",
+                    "a_catalog",
+                    "b_catalog",
+                    "b_catalog",
+                ])) as ArrayRef,
+                Arc::new(StringArray::from(vec![
+                    "a_schema", "b_schema", "a_schema", "b_schema",
+                ])) as ArrayRef,
+            ],
+        )
+        .unwrap()
+    }
+
+    #[test]
+    fn test_schemas_are_filtered() {
+        let ref_batch = get_ref_batch();
+
+        let mut builder = GetSchemasBuilder::new(None::<String>, None::<String>);
+        builder.append("a_catalog", "a_schema");
+        builder.append("a_catalog", "b_schema");
+        builder.append("b_catalog", "a_schema");
+        builder.append("b_catalog", "b_schema");
+        let schema_batch = builder.build().unwrap();
+
+        assert_eq!(schema_batch, ref_batch);
+
+        let mut builder = GetSchemasBuilder::new(None::<String>, Some("a%"));
+        builder.append("a_catalog", "a_schema");
+        builder.append("a_catalog", "b_schema");
+        builder.append("b_catalog", "a_schema");
+        builder.append("b_catalog", "b_schema");
+        let schema_batch = builder.build().unwrap();
+
+        let indices = UInt32Array::from(vec![0, 2]);
+        let ref_filtered = RecordBatch::try_new(
+            get_db_schemas_schema(),
+            ref_batch
+                .columns()
+                .iter()
+                .map(|c| take(c, &indices, None))
+                .collect::<std::result::Result<Vec<_>, _>>()
+                .unwrap(),
+        )
+        .unwrap();
+
+        assert_eq!(schema_batch, ref_filtered);
+    }
+
+    #[test]
+    fn test_schemas_are_sorted() {
+        let ref_batch = get_ref_batch();
+
+        let mut builder = GetSchemasBuilder::new(None::<String>, None::<String>);
+        builder.append("a_catalog", "b_schema");
+        builder.append("b_catalog", "a_schema");
+        builder.append("a_catalog", "a_schema");
+        builder.append("b_catalog", "b_schema");
+        let schema_batch = builder.build().unwrap();
+
+        assert_eq!(schema_batch, ref_batch)
+    }
+
+    #[test]
+    fn test_builder_from_query() {
+        let ref_batch = get_ref_batch();
+        let query = CommandGetDbSchemas {
+            catalog: Some("a_catalog".into()),
+            db_schema_filter_pattern: Some("b%".into()),
+        };
+
+        let mut builder = query.into_builder();
+        builder.append("a_catalog", "a_schema");
+        builder.append("a_catalog", "b_schema");
+        builder.append("b_catalog", "a_schema");
+        builder.append("b_catalog", "b_schema");
+        let schema_batch = builder.build().unwrap();
+
+        let indices = UInt32Array::from(vec![1]);
+        let ref_filtered = RecordBatch::try_new(
+            get_db_schemas_schema(),
+            ref_batch
+                .columns()
+                .iter()
+                .map(|c| take(c, &indices, None))
+                .collect::<std::result::Result<Vec<_>, _>>()
+                .unwrap(),
+        )
+        .unwrap();
+
+        assert_eq!(schema_batch, ref_filtered);
+    }
+}
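The builder added above can be exercised without a running server. A minimal sketch (invented names; requires the crate to be compiled with `flight-sql-experimental`) of the append/filter/sort pipeline that `build()` applies:

```rust
use arrow_flight::sql::catalogs::GetSchemasBuilder;

fn main() {
    // Filters are applied in build(): catalog must equal "a_catalog"
    // (eq_utf8_scalar) and the schema name must match "b%" (like_utf8_scalar).
    let mut builder = GetSchemasBuilder::new(Some("a_catalog"), Some("b%"));
    builder.append("a_catalog", "a_schema");
    builder.append("a_catalog", "b_schema");
    builder.append("b_catalog", "b_schema");

    let batch = builder.build().expect("valid batch");
    assert_eq!(batch.num_rows(), 1); // only ("a_catalog", "b_schema") survives
}
```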
diff --git a/arrow-flight/src/sql/catalogs/mod.rs b/arrow-flight/src/sql/catalogs/mod.rs
new file mode 100644
index 000000000000..e4cbb6fedc45
--- /dev/null
+++ b/arrow-flight/src/sql/catalogs/mod.rs
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Builders and functions for building responses to information schema requests
+//!
+//! - [`get_catalogs_batch`] and [`get_catalogs_schema`] for building responses to [`CommandGetCatalogs`] queries.
+//! - [`GetSchemasBuilder`] and [`get_db_schemas_schema`] for building responses to [`CommandGetDbSchemas`] queries.
+//! - [`GetTablesBuilder`] and [`get_tables_schema`] for building responses to [`CommandGetTables`] queries.
+//!
+//! [`CommandGetCatalogs`]: crate::sql::CommandGetCatalogs
+//! [`CommandGetDbSchemas`]: crate::sql::CommandGetDbSchemas
+//! [`CommandGetTables`]: crate::sql::CommandGetTables
+
+use std::sync::Arc;
+
+use arrow_array::{Array, ArrayRef, RecordBatch, StringArray, UInt32Array};
+use arrow_row::{RowConverter, SortField};
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use once_cell::sync::Lazy;
+
+use crate::error::Result;
+use crate::sql::CommandGetCatalogs;
+
+pub use db_schemas::{get_db_schemas_schema, GetSchemasBuilder};
+pub use tables::{get_tables_schema, GetTablesBuilder};
+
+mod db_schemas;
+mod tables;
+
+/// A builder for a [`CommandGetCatalogs`] response.
+pub struct GetCatalogsBuilder {
+    catalogs: Vec<String>,
+}
+
+impl CommandGetCatalogs {
+    pub fn into_builder(self) -> GetCatalogsBuilder {
+        self.into()
+    }
+}
+
+impl From<CommandGetCatalogs> for GetCatalogsBuilder {
+    fn from(_: CommandGetCatalogs) -> Self {
+        Self::new()
+    }
+}
+
+impl Default for GetCatalogsBuilder {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl GetCatalogsBuilder {
+    /// Create a new instance of [`GetCatalogsBuilder`]
+    pub fn new() -> Self {
+        Self {
+            catalogs: Vec::new(),
+        }
+    }
+
+    /// Append a row
+    pub fn append(&mut self, catalog_name: impl Into<String>) {
+        self.catalogs.push(catalog_name.into());
+    }
+
+    /// builds a `RecordBatch` with the correct schema for a `CommandGetCatalogs` response
+    pub fn build(self) -> Result<RecordBatch> {
+        get_catalogs_batch(self.catalogs)
+    }
+}
+
+/// Returns the RecordBatch for `CommandGetCatalogs`
+pub fn get_catalogs_batch(mut catalog_names: Vec<String>) -> Result<RecordBatch> {
+    catalog_names.sort_unstable();
+
+    let batch = RecordBatch::try_new(
+        Arc::clone(&GET_CATALOG_SCHEMA),
+        vec![Arc::new(StringArray::from_iter_values(catalog_names)) as _],
+    )?;
+
+    Ok(batch)
+}
+
+/// Returns the schema that will result from [`CommandGetCatalogs`]
+///
+/// [`CommandGetCatalogs`]: crate::sql::CommandGetCatalogs
+pub fn get_catalogs_schema() -> &'static Schema {
+    &GET_CATALOG_SCHEMA
+}
+
+/// The schema for GetCatalogs
+static GET_CATALOG_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
+    Arc::new(Schema::new(vec![Field::new(
+        "catalog_name",
+        DataType::Utf8,
+        false,
+    )]))
+});
+
+fn lexsort_to_indices(arrays: &[ArrayRef]) -> UInt32Array {
+    let fields = arrays
+        .iter()
+        .map(|a| SortField::new(a.data_type().clone()))
+        .collect();
+    let mut converter = RowConverter::new(fields).unwrap();
+    let rows = converter.convert_columns(arrays).unwrap();
+    let mut sort: Vec<_> = rows.iter().enumerate().collect();
+    sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
+    UInt32Array::from_iter_values(sort.iter().map(|(i, _)| *i as u32))
+}
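Catalog responses are simpler: rows are single strings, and `get_catalogs_batch` sorts them before building the batch. A small sketch with invented names:

```rust
use arrow_array::{Array, StringArray};
use arrow_flight::sql::catalogs::GetCatalogsBuilder;

fn main() {
    let mut builder = GetCatalogsBuilder::new();
    builder.append("z_catalog");
    builder.append("a_catalog");

    // build() delegates to get_catalogs_batch(), which sorts the names,
    // so "a_catalog" comes first regardless of append order.
    let batch = builder.build().expect("valid batch");
    let names = batch
        .column(0)
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    assert_eq!(names.value(0), "a_catalog");
    assert_eq!(names.value(1), "z_catalog");
}
```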
diff --git a/arrow-flight/src/sql/catalogs/tables.rs b/arrow-flight/src/sql/catalogs/tables.rs
new file mode 100644
index 000000000000..fcdc0dbb7447
--- /dev/null
+++ b/arrow-flight/src/sql/catalogs/tables.rs
@@ -0,0 +1,466 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`GetTablesBuilder`] for building responses to [`CommandGetTables`] queries.
+//!
+//! [`CommandGetTables`]: crate::sql::CommandGetTables
+
+use std::sync::Arc;
+
+use arrow_arith::boolean::{and, or};
+use arrow_array::builder::{BinaryBuilder, StringBuilder};
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_ord::comparison::eq_utf8_scalar;
+use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
+use arrow_select::{filter::filter_record_batch, take::take};
+use arrow_string::like::like_utf8_scalar;
+use once_cell::sync::Lazy;
+
+use super::lexsort_to_indices;
+use crate::error::*;
+use crate::sql::CommandGetTables;
+use crate::{IpcMessage, IpcWriteOptions, SchemaAsIpc};
+
+/// Return the schema of the RecordBatch that will be returned from [`CommandGetTables`]
+///
+/// Note the schema differs based on the value of `include_schema`
+///
+/// [`CommandGetTables`]: crate::sql::CommandGetTables
+pub fn get_tables_schema(include_schema: bool) -> SchemaRef {
+    if include_schema {
+        Arc::clone(&GET_TABLES_SCHEMA_WITH_TABLE_SCHEMA)
+    } else {
+        Arc::clone(&GET_TABLES_SCHEMA_WITHOUT_TABLE_SCHEMA)
+    }
+}
+
+/// Builds rows like this:
+///
+/// * catalog_name: utf8,
+/// * db_schema_name: utf8,
+/// * table_name: utf8 not null,
+/// * table_type: utf8 not null,
+/// * (optional) table_schema: bytes not null (schema of the table as described
+///   in Schema.fbs::Schema, serialized as an IPC message)
+pub struct GetTablesBuilder {
+    catalog_filter: Option<String>,
+    table_types_filter: Vec<String>,
+    // Optional filters to apply to schemas
+    db_schema_filter_pattern: Option<String>,
+    // Optional filters to apply to tables
+    table_name_filter_pattern: Option<String>,
+    // array builder for catalog names
+    catalog_name: StringBuilder,
+    // array builder for db schema names
+    db_schema_name: StringBuilder,
+    // array builder for table names
+    table_name: StringBuilder,
+    // array builder for table types
+    table_type: StringBuilder,
+    // array builder for table schemas
+    table_schema: Option<BinaryBuilder>,
+}
+
+impl CommandGetTables {
+    pub fn into_builder(self) -> GetTablesBuilder {
+        self.into()
+    }
+}
+
+impl From<CommandGetTables> for GetTablesBuilder {
+    fn from(value: CommandGetTables) -> Self {
+        Self::new(
+            value.catalog,
+            value.db_schema_filter_pattern,
+            value.table_name_filter_pattern,
+            value.table_types,
+            value.include_schema,
+        )
+    }
+}
+
+impl GetTablesBuilder {
+    /// Create a new instance of [`GetTablesBuilder`]
+    ///
+    /// # Parameters
+    ///
+    /// - `catalog`: Specifies the Catalog to search for the tables.
+    ///   - An empty string retrieves those without a catalog.
+    ///   - If omitted the catalog name is not used to narrow the search.
+    /// - `db_schema_filter_pattern`: Specifies a filter pattern for schemas to search for.
+    ///   When no pattern is provided, the pattern will not be used to narrow the search.
+    ///   In the pattern string, two special characters can be used to denote matching rules:
+    ///     - "%" means to match any substring with 0 or more characters.
+    ///     - "_" means to match any one character.
+    /// - `table_name_filter_pattern`: Specifies a filter pattern for tables to search for.
+    ///   When no pattern is provided, all tables matching other filters are searched.
+    ///   In the pattern string, two special characters can be used to denote matching rules:
+    ///     - "%" means to match any substring with 0 or more characters.
+    ///     - "_" means to match any one character.
+    /// - `table_types`: Specifies a filter of table types which must match.
+    ///   An empty Vec matches all table types.
+    /// - `include_schema`: Specifies if the Arrow schema should be returned for found tables.
+    ///
+    /// [`CommandGetTables`]: crate::sql::CommandGetTables
+    pub fn new(
+        catalog: Option<impl Into<String>>,
+        db_schema_filter_pattern: Option<impl Into<String>>,
+        table_name_filter_pattern: Option<impl Into<String>>,
+        table_types: impl IntoIterator<Item = impl Into<String>>,
+        include_schema: bool,
+    ) -> Self {
+        let table_schema = if include_schema {
+            Some(BinaryBuilder::new())
+        } else {
+            None
+        };
+        Self {
+            catalog_filter: catalog.map(|s| s.into()),
+            table_types_filter: table_types.into_iter().map(|tt| tt.into()).collect(),
+            db_schema_filter_pattern: db_schema_filter_pattern.map(|s| s.into()),
+            table_name_filter_pattern: table_name_filter_pattern.map(|t| t.into()),
+            catalog_name: StringBuilder::new(),
+            db_schema_name: StringBuilder::new(),
+            table_name: StringBuilder::new(),
+            table_type: StringBuilder::new(),
+            table_schema,
+        }
+    }
+
+    /// Append a row
+    pub fn append(
+        &mut self,
+        catalog_name: impl AsRef<str>,
+        schema_name: impl AsRef<str>,
+        table_name: impl AsRef<str>,
+        table_type: impl AsRef<str>,
+        table_schema: &Schema,
+    ) -> Result<()> {
+        self.catalog_name.append_value(catalog_name);
+        self.db_schema_name.append_value(schema_name);
+        self.table_name.append_value(table_name);
+        self.table_type.append_value(table_type);
+        if let Some(self_table_schema) = self.table_schema.as_mut() {
+            let options = IpcWriteOptions::default();
+            // encode the schema into the correct form
+            let message: std::result::Result<IpcMessage, ArrowError> =
+                SchemaAsIpc::new(table_schema, &options).try_into();
+            let IpcMessage(schema) = message?;
+            self_table_schema.append_value(schema);
+        }
+
+        Ok(())
+    }
+
+    /// builds a `RecordBatch` for `CommandGetTables`
+    pub fn build(self) -> Result<RecordBatch> {
+        let Self {
+            catalog_filter,
+            table_types_filter,
+            db_schema_filter_pattern,
+            table_name_filter_pattern,
+
+            mut catalog_name,
+            mut db_schema_name,
+            mut table_name,
+            mut table_type,
+            table_schema,
+        } = self;
+
+        // Make the arrays
+        let catalog_name = catalog_name.finish();
+        let db_schema_name = db_schema_name.finish();
+        let table_name = table_name.finish();
+        let table_type = table_type.finish();
+        let table_schema = table_schema.map(|mut table_schema| table_schema.finish());
+
+        // apply any filters, getting a BooleanArray that represents
+        // the rows that passed the filter
+        let mut filters = vec![];
+
+        if let Some(catalog_filter_name) = catalog_filter {
+            filters.push(eq_utf8_scalar(&catalog_name, &catalog_filter_name)?);
+        }
+
+        let tt_filter = table_types_filter
+            .into_iter()
+            .map(|tt| eq_utf8_scalar(&table_type, &tt))
+            .collect::<std::result::Result<Vec<_>, _>>()?
+            .into_iter()
+            // We know the arrays are of same length as they are produced from the same root array
+            .reduce(|filter, arr| or(&filter, &arr).unwrap());
+        if let Some(filter) = tt_filter {
+            filters.push(filter);
+        }
+
+        if let Some(db_schema_filter_pattern) = db_schema_filter_pattern {
+            // use like kernel to get wildcard matching
+            filters.push(like_utf8_scalar(
+                &db_schema_name,
+                &db_schema_filter_pattern,
+            )?)
+        }
+
+        if let Some(table_name_filter_pattern) = table_name_filter_pattern {
+            // use like kernel to get wildcard matching
+            filters.push(like_utf8_scalar(&table_name, &table_name_filter_pattern)?)
+        }
+
+        let include_schema = table_schema.is_some();
+        let batch = if let Some(table_schema) = table_schema {
+            RecordBatch::try_new(
+                get_tables_schema(include_schema),
+                vec![
+                    Arc::new(catalog_name) as ArrayRef,
+                    Arc::new(db_schema_name) as ArrayRef,
+                    Arc::new(table_name) as ArrayRef,
+                    Arc::new(table_type) as ArrayRef,
+                    Arc::new(table_schema) as ArrayRef,
+                ],
+            )
+        } else {
+            RecordBatch::try_new(
+                get_tables_schema(include_schema),
+                vec![
+                    Arc::new(catalog_name) as ArrayRef,
+                    Arc::new(db_schema_name) as ArrayRef,
+                    Arc::new(table_name) as ArrayRef,
+                    Arc::new(table_type) as ArrayRef,
+                ],
+            )
+        }?;
+
+        // `AND` any filters together
+        let mut total_filter = None;
+        while let Some(filter) = filters.pop() {
+            let new_filter = match total_filter {
+                Some(total_filter) => and(&total_filter, &filter)?,
+                None => filter,
+            };
+            total_filter = Some(new_filter);
+        }
+
+        // Apply the filters if needed
+        let filtered_batch = if let Some(total_filter) = total_filter {
+            filter_record_batch(&batch, &total_filter)?
+        } else {
+            batch
+        };
+
+        // Order filtered results by catalog_name, then db_schema_name, then table_name, then table_type
+        // https://github.com/apache/arrow/blob/130f9e981aa98c25de5f5bfe55185db270cec313/format/FlightSql.proto#LL1202C1-L1202C1
+        let sort_cols = filtered_batch.project(&[0, 1, 2, 3])?;
+        let indices = lexsort_to_indices(sort_cols.columns());
+        let columns = filtered_batch
+            .columns()
+            .iter()
+            .map(|c| take(c, &indices, None))
+            .collect::<std::result::Result<Vec<_>, _>>()?;
+
+        Ok(RecordBatch::try_new(
+            get_tables_schema(include_schema),
+            columns,
+        )?)
+    }
+}
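Driving `GetTablesBuilder` straight from the protobuf command, the way `do_get_tables` does, looks like the sketch below; names are invented and the empty `Schema` stands in for real table metadata:

```rust
use arrow_flight::sql::catalogs::get_tables_schema;
use arrow_flight::sql::CommandGetTables;
use arrow_schema::Schema;

fn main() {
    // As received in do_get_tables: keep only TABLEs whose name matches "b%".
    let query = CommandGetTables {
        catalog: None,
        db_schema_filter_pattern: None,
        table_name_filter_pattern: Some("b%".to_string()),
        table_types: vec!["TABLE".to_string()],
        include_schema: false,
    };

    let dummy_schema = Schema::empty();
    let mut builder = query.into_builder();
    for (name, table_type) in [("a_table", "TABLE"), ("b_table", "TABLE"), ("b_table", "VIEW")] {
        builder
            .append("a_catalog", "a_schema", name, table_type, &dummy_schema)
            .expect("schema serializes");
    }

    let batch = builder.build().expect("valid batch");
    assert_eq!(batch.schema(), get_tables_schema(false));
    assert_eq!(batch.num_rows(), 1); // only ("b_table", "TABLE") passes both filters
}
```

Note how the table-type filter is an OR across the requested types, which is then ANDed with the LIKE filters, matching the `tt_filter`/`total_filter` logic in `build()` above.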
+
+/// The schema for GetTables without `table_schema` column
+static GET_TABLES_SCHEMA_WITHOUT_TABLE_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
+    Arc::new(Schema::new(vec![
+        Field::new("catalog_name", DataType::Utf8, false),
+        Field::new("db_schema_name", DataType::Utf8, false),
+        Field::new("table_name", DataType::Utf8, false),
+        Field::new("table_type", DataType::Utf8, false),
+    ]))
+});
+
+/// The schema for GetTables with `table_schema` column
+static GET_TABLES_SCHEMA_WITH_TABLE_SCHEMA: Lazy<SchemaRef> = Lazy::new(|| {
+    Arc::new(Schema::new(vec![
+        Field::new("catalog_name", DataType::Utf8, false),
+        Field::new("db_schema_name", DataType::Utf8, false),
+        Field::new("table_name", DataType::Utf8, false),
+        Field::new("table_type", DataType::Utf8, false),
+        Field::new("table_schema", DataType::Binary, false),
+    ]))
+});
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_array::{StringArray, UInt32Array};
+
+    fn get_ref_batch() -> RecordBatch {
+        RecordBatch::try_new(
+            get_tables_schema(false),
+            vec![
+                Arc::new(StringArray::from(vec![
+                    "a_catalog",
+                    "a_catalog",
+                    "a_catalog",
+                    "a_catalog",
+                    "b_catalog",
+                    "b_catalog",
+                    "b_catalog",
+                    "b_catalog",
+                ])) as ArrayRef,
+                Arc::new(StringArray::from(vec![
+                    "a_schema", "a_schema", "b_schema", "b_schema", "a_schema",
+                    "a_schema", "b_schema", "b_schema",
+                ])) as ArrayRef,
+                Arc::new(StringArray::from(vec![
+                    "a_table", "b_table", "a_table", "b_table", "a_table", "a_table",
+                    "b_table", "b_table",
+                ])) as ArrayRef,
+                Arc::new(StringArray::from(vec![
+                    "TABLE", "TABLE", "TABLE", "TABLE", "TABLE", "VIEW", "TABLE", "VIEW",
+                ])) as ArrayRef,
+            ],
+        )
+        .unwrap()
+    }
+
+    fn get_ref_builder(
+        catalog: Option<&str>,
+        db_schema_filter_pattern: Option<&str>,
+        table_name_filter_pattern: Option<&str>,
+        table_types: Vec<&str>,
+        include_schema: bool,
+    ) -> GetTablesBuilder {
+        let dummy_schema = Schema::empty();
+        let tables = [
+            ("a_catalog", "a_schema", "a_table", "TABLE"),
+            ("a_catalog", "a_schema", "b_table", "TABLE"),
+            ("a_catalog", "b_schema", "a_table", "TABLE"),
+            ("a_catalog", "b_schema", "b_table", "TABLE"),
+            ("b_catalog", "a_schema", "a_table", "TABLE"),
+            ("b_catalog", "a_schema", "a_table", "VIEW"),
+            ("b_catalog", "b_schema", "b_table", "TABLE"),
+            ("b_catalog", "b_schema", "b_table", "VIEW"),
+        ];
+        let mut builder = GetTablesBuilder::new(
+            catalog,
+            db_schema_filter_pattern,
+            table_name_filter_pattern,
+            table_types,
+            include_schema,
+        );
+        for (catalog_name, schema_name, table_name, table_type) in tables {
+            builder
+                .append(
+                    catalog_name,
+                    schema_name,
+                    table_name,
+                    table_type,
+                    &dummy_schema,
+                )
+                .unwrap();
+        }
+        builder
+    }
+
+    #[test]
+    fn test_tables_are_filtered() {
+        let ref_batch = get_ref_batch();
+
+        let builder = get_ref_builder(None, None, None, Vec::new(), false);
+        let table_batch = builder.build().unwrap();
+        assert_eq!(table_batch, ref_batch);
+
+        let builder = get_ref_builder(None, Some("a%"), Some("a%"), Vec::new(), false);
+        let table_batch = builder.build().unwrap();
+        let indices = UInt32Array::from(vec![0, 4, 5]);
+        let ref_filtered = RecordBatch::try_new(
+            get_tables_schema(false),
+            ref_batch
+                .columns()
+                .iter()
+                .map(|c| take(c, &indices, None))
+                .collect::<std::result::Result<Vec<_>, _>>()
+                .unwrap(),
+        )
+        .unwrap();
+        assert_eq!(table_batch, ref_filtered);
+
+        let builder = get_ref_builder(Some("a_catalog"), None, None, Vec::new(), false);
+        let table_batch = builder.build().unwrap();
+        let indices = UInt32Array::from(vec![0, 1, 2, 3]);
+        let ref_filtered = RecordBatch::try_new(
+            get_tables_schema(false),
+            ref_batch
+                .columns()
+                .iter()
+                .map(|c| take(c, &indices, None))
+                .collect::<std::result::Result<Vec<_>, _>>()
+                .unwrap(),
+        )
+        .unwrap();
+        assert_eq!(table_batch, ref_filtered);
+
+        let builder = get_ref_builder(None, None, None, vec!["VIEW"], false);
+        let table_batch = builder.build().unwrap();
+        let indices = UInt32Array::from(vec![5, 7]);
+        let ref_filtered = RecordBatch::try_new(
+            get_tables_schema(false),
+            ref_batch
+                .columns()
+                .iter()
+                .map(|c| take(c, &indices, None))
+                .collect::<std::result::Result<Vec<_>, _>>()
+                .unwrap(),
+        )
+        .unwrap();
+        assert_eq!(table_batch, ref_filtered);
+    }
+
+    #[test]
+    fn test_tables_are_sorted() {
+        let ref_batch = get_ref_batch();
+        let dummy_schema = Schema::empty();
+
+        let tables = [
+            ("b_catalog", "a_schema", "a_table", "TABLE"),
+            ("b_catalog", "b_schema", "b_table", "TABLE"),
+            ("b_catalog", "b_schema", "b_table", "VIEW"),
+            ("b_catalog", "a_schema", "a_table", "VIEW"),
+            ("a_catalog", "a_schema", "a_table", "TABLE"),
+            ("a_catalog", "b_schema", "a_table", "TABLE"),
+            ("a_catalog", "b_schema", "b_table", "TABLE"),
+            ("a_catalog", "a_schema", "b_table", "TABLE"),
+        ];
+        let mut builder = GetTablesBuilder::new(
+            None::<String>,
+            None::<String>,
+            None::<String>,
+            None::<String>,
+            false,
+        );
+        for (catalog_name, schema_name, table_name, table_type) in tables {
+            builder
+                .append(
+                    catalog_name,
+                    schema_name,
+                    table_name,
+                    table_type,
+                    &dummy_schema,
+                )
+                .unwrap();
+        }
+        let table_batch = builder.build().unwrap();
+        assert_eq!(table_batch, ref_batch);
+    }
+}
diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs
index 74d3176c67b1..c864a1ba4f7e 100644
--- a/arrow-flight/src/sql/mod.rs
+++ b/arrow-flight/src/sql/mod.rs
@@ -96,6 +96,7 @@ pub use gen::UpdateDeleteRules;
 
 pub use sql_info::SqlInfoList;
 
+pub mod catalogs;
 pub mod client;
 pub mod server;
 pub mod sql_info;
diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs
index 9010c8d9a2a9..5b9a1bb88078 100644
--- a/arrow-row/src/lib.rs
+++ b/arrow-row/src/lib.rs
@@ -110,7 +110,7 @@
 //!         .map(|a| SortField::new(a.data_type().clone()))
 //!         .collect();
 //!     let mut converter = RowConverter::new(fields).unwrap();
-//!     let rows = converter.convert_columns(&arrays).unwrap();
+//!     let rows = converter.convert_columns(arrays).unwrap();
 //!     let mut sort: Vec<_> = rows.iter().enumerate().collect();
 //!     sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
 //!     UInt32Array::from_iter_values(sort.iter().map(|(i, _)| *i as u32))
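The arrow-row doc fix above drops the redundant borrow: `arrays` is already a `&[ArrayRef]`, so `&arrays` would pass a `&&[ArrayRef]`. For reference, the row-format argsort pattern that this doc example (and the private `lexsort_to_indices` helper in `catalogs/mod.rs`) relies on can be run standalone; data values here are invented:

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, StringArray, UInt32Array};
use arrow_row::{RowConverter, SortField};

// Same shape as the helper in catalogs/mod.rs: encode columns as comparable
// rows, argsort, and return the permutation as indices usable with take().
fn lexsort_to_indices(arrays: &[ArrayRef]) -> UInt32Array {
    let fields = arrays
        .iter()
        .map(|a| SortField::new(a.data_type().clone()))
        .collect();
    let mut converter = RowConverter::new(fields).unwrap();
    let rows = converter.convert_columns(arrays).unwrap();
    let mut sort: Vec<_> = rows.iter().enumerate().collect();
    sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
    UInt32Array::from_iter_values(sort.iter().map(|(i, _)| *i as u32))
}

fn main() {
    let catalogs: ArrayRef =
        Arc::new(StringArray::from(vec!["b_catalog", "a_catalog", "a_catalog"]));
    let schemas: ArrayRef =
        Arc::new(StringArray::from(vec!["a_schema", "b_schema", "a_schema"]));

    // Lexicographic row order:
    // ("a_catalog", "a_schema") < ("a_catalog", "b_schema") < ("b_catalog", "a_schema")
    let indices = lexsort_to_indices(&[catalogs, schemas]);
    assert_eq!(indices.value(0), 2);
    assert_eq!(indices.value(1), 1);
    assert_eq!(indices.value(2), 0);
}
```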