From ebc2f9e08088e2285577330f64474948e7bc438b Mon Sep 17 00:00:00 2001 From: Even Rouault Date: Sat, 27 May 2023 20:55:23 +0200 Subject: [PATCH] Parquet/Arrow: implement faster spatial filtering with ArrowArray interface We no longer fallback to the slow & generic implementation that goes through GetNextFeature(), but directly post filter the ArrowArray to remove features not intersecting the spatial filter. The performance gain is mostly when a big number of features is selected (the fallback GetNextFeature() has already an efficient spatial filtering, so when selecting a small number of features, this optimization doesn't bring anything) Can be up to 10x faster in that situation. Now: ``` $ time bench_ogr_batch nz-building-outlines.parquet -spat 1167513 4794680 2089113 6190596 real 0m1,275s user 0m1,565s sys 0m0,322s ``` Before: ``` $ time bench_ogr_batch nz-building-outlines.parquet -spat 1167513 4794680 2089113 6190596 real 0m13,507s user 0m13,728s sys 0m0,712s ``` --- autotest/ogr/ogr_parquet.py | 83 +++ ogr/ogrsf_frmts/arrow_common/ogr_arrow.h | 9 + .../arrow_common/ograrrowlayer.hpp | 163 ++++-- ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp | 524 ++++++++++++++++++ ogr/ogrsf_frmts/ogrsf_frmts.h | 4 + perftests/bench_ogr_to_geopandas.py | 48 +- 6 files changed, 778 insertions(+), 53 deletions(-) diff --git a/autotest/ogr/ogr_parquet.py b/autotest/ogr/ogr_parquet.py index f4ea5822e48b..d1d9267b9f18 100755 --- a/autotest/ogr/ogr_parquet.py +++ b/autotest/ogr/ogr_parquet.py @@ -1638,6 +1638,89 @@ def test_ogr_parquet_arrow_stream_numpy(): ) +############################################################################### + + +def test_ogr_parquet_arrow_stream_numpy_fast_spatial_filter(): + pytest.importorskip("osgeo.gdal_array") + numpy = pytest.importorskip("numpy") + import datetime + + ds = ogr.Open("data/parquet/test.parquet") + lyr = ds.GetLayer(0) + ignored_fields = ["decimal128", "decimal256", "time64_ns"] + lyr_defn = lyr.GetLayerDefn() + for i in range(lyr_defn.GetFieldCount()): + fld_defn = lyr_defn.GetFieldDefn(i) + if ( + fld_defn.GetName().startswith("map_") + or fld_defn.GetName().startswith("struct_") + or fld_defn.GetName().startswith("fixed_size_") + or fld_defn.GetType() + not in ( + ogr.OFTInteger, + ogr.OFTInteger64, + ogr.OFTReal, + ogr.OFTString, + ogr.OFTBinary, + ogr.OFTTime, + ogr.OFTDate, + ogr.OFTDateTime, + ) + ): + ignored_fields.append(fld_defn.GetNameRef()) + lyr.SetIgnoredFields(ignored_fields) + lyr.SetSpatialFilterRect(-10, -10, 10, 10) + assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1 + + stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"]) + fc = 0 + for batch in stream: + fc += len(batch["uint8"]) + assert fc == 4 + + lyr.SetSpatialFilterRect(3, 2, 3, 2) + assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1 + + stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"]) + batches = [batch for batch in stream] + assert len(batches) == 1 + batch = batches[0] + assert len(batch["geometry"]) == 1 + assert batch["boolean"][0] == False + assert batch["uint8"][0] == 4 + assert batch["int8"][0] == 1 + assert batch["uint16"][0] == 30001 + assert batch["int16"][0] == 10000 + assert batch["uint32"][0] == 3000000001 + assert batch["int32"][0] == 1000000000 + assert batch["uint64"][0] == 300000000001 + assert batch["int64"][0] == 100000000000 + assert batch["int64"][0] == 100000000000 + assert batch["float32"][0] == 4.5 + assert batch["float64"][0] == 4.5 + assert batch["string"][0] == b"c" + assert batch["large_string"][0] == b"c" + assert batch["timestamp_ms_gmt"][0] == numpy.datetime64("2019-01-01T14:00:00.000") + assert batch["time32_s"][0] == datetime.time(0, 0, 4) + assert batch["time32_ms"][0] == datetime.time(0, 0, 0, 4000) + assert batch["time64_us"][0] == datetime.time(0, 0, 0, 4) + assert batch["date32"][0] == numpy.datetime64("1970-01-05") + assert batch["date64"][0] == numpy.datetime64("1970-01-01") + assert bytes(batch["binary"][0]) == b"\00\01" + assert bytes(batch["large_binary"][0]) == b"\00\01" + assert ( + ogr.CreateGeometryFromWkb(batch["geometry"][0]).ExportToWkt() == "POINT (3 2)" + ) + + lyr.SetSpatialFilterRect(1, 1, 1, 1) + assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1 + + stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"]) + batches = [batch for batch in stream] + assert len(batches) == 0 + + ############################################################################### # Test bbox diff --git a/ogr/ogrsf_frmts/arrow_common/ogr_arrow.h b/ogr/ogrsf_frmts/arrow_common/ogr_arrow.h index 58b5bf3f5f1e..1b74c8edc955 100644 --- a/ogr/ogrsf_frmts/arrow_common/ogr_arrow.h +++ b/ogr/ogrsf_frmts/arrow_common/ogr_arrow.h @@ -85,6 +85,10 @@ class OGRArrowLayer CPL_NON_FINAL std::vector m_asAttributeFilterConstraints{}; int m_nUseOptimizedAttributeFilter = -1; bool m_bSpatialFilterIntersectsLayerExtent = true; + bool m_bUseRecordBatchBaseImplementation = false; + + // Modified by UseRecordBatchBaseImplementation() + mutable struct ArrowSchema m_sCachedSchema = {}; bool SkipToNextFeatureDueToAttributeFilter() const; void ExploreExprNode(const swq_expr_node *poNode); @@ -93,6 +97,8 @@ class OGRArrowLayer CPL_NON_FINAL static struct ArrowArray * CreateWKTArrayFromWKBArray(const struct ArrowArray *sourceArray); + int GetArrowSchemaInternal(struct ArrowSchema *out) const; + protected: OGRArrowDataset *m_poArrowDS = nullptr; arrow::MemoryPool *m_poMemoryPool = nullptr; @@ -214,6 +220,9 @@ class OGRArrowLayer CPL_NON_FINAL int TestCapability(const char *pszCap) override; + bool GetArrowStream(struct ArrowArrayStream *out_stream, + CSLConstList papszOptions = nullptr) override; + virtual std::unique_ptr BuildDomain(const std::string &osDomainName, int iFieldIndex) const = 0; diff --git a/ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp b/ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp index 4e18e15fc294..aa0bbccf0b24 100644 --- a/ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp +++ b/ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp @@ -60,6 +60,9 @@ inline OGRArrowLayer::OGRArrowLayer(OGRArrowDataset *poDS, inline OGRArrowLayer::~OGRArrowLayer() { + if (m_sCachedSchema.release) + m_sCachedSchema.release(&m_sCachedSchema); + CPLDebug("ARROW", "Memory pool: bytes_allocated = %" PRId64, m_poMemoryPool->bytes_allocated()); CPLDebug("ARROW", "Memory pool: max_memory = %" PRId64, @@ -3237,7 +3240,7 @@ static void OverrideArrowRelease(OGRArrowDataset *poDS, T *obj) inline bool OGRArrowLayer::UseRecordBatchBaseImplementation() const { - if (m_poAttrQuery != nullptr || m_poFilterGeom != nullptr || + if (m_poAttrQuery != nullptr || CPLTestBool(CPLGetConfigOption("OGR_ARROW_STREAM_BASE_IMPL", "NO"))) { return true; @@ -3254,6 +3257,8 @@ inline bool OGRArrowLayer::UseRecordBatchBaseImplementation() const m_aeGeomEncoding[i] != OGRArrowGeomEncoding::WKB && m_aeGeomEncoding[i] != OGRArrowGeomEncoding::WKT) { + CPLDebug("ARROW", "Geometry encoding not compatible of fast " + "Arrow implementation"); return true; } } @@ -3279,15 +3284,44 @@ inline bool OGRArrowLayer::UseRecordBatchBaseImplementation() const // struct fields will point to the same arrow column if (ignoredState[nArrowCol] != static_cast(bIsIgnored)) { + CPLDebug("ARROW", + "Inconsistent ignore state for Arrow Columns"); return true; } } } } + if (m_poFilterGeom) + { + struct ArrowSchema *psSchema = &m_sCachedSchema; + if (psSchema->release) + psSchema->release(psSchema); + memset(psSchema, 0, sizeof(*psSchema)); + + const bool bCanPostFilter = GetArrowSchemaInternal(psSchema) == 0 && + CanPostFilterArrowArray(psSchema); + if (!bCanPostFilter) + return true; + } + return false; } +/************************************************************************/ +/* GetArrowStream() */ +/************************************************************************/ + +inline bool OGRArrowLayer::GetArrowStream(struct ArrowArrayStream *out_stream, + CSLConstList papszOptions) +{ + if (!OGRLayer::GetArrowStream(out_stream, papszOptions)) + return false; + + m_bUseRecordBatchBaseImplementation = UseRecordBatchBaseImplementation(); + return true; +} + /************************************************************************/ /* GetArrowSchema() */ /************************************************************************/ @@ -3295,9 +3329,19 @@ inline bool OGRArrowLayer::UseRecordBatchBaseImplementation() const inline int OGRArrowLayer::GetArrowSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out_schema) { - if (UseRecordBatchBaseImplementation()) + if (m_bUseRecordBatchBaseImplementation) return OGRLayer::GetArrowSchema(stream, out_schema); + return GetArrowSchemaInternal(out_schema); +} + +/************************************************************************/ +/* GetArrowSchemaInternal() */ +/************************************************************************/ + +inline int +OGRArrowLayer::GetArrowSchemaInternal(struct ArrowSchema *out_schema) const +{ auto status = arrow::ExportSchema(*m_poSchema, out_schema); if (!status.ok()) { @@ -3420,76 +3464,93 @@ inline int OGRArrowLayer::GetArrowSchema(struct ArrowArrayStream *stream, inline int OGRArrowLayer::GetNextArrowArray(struct ArrowArrayStream *stream, struct ArrowArray *out_array) { - if (UseRecordBatchBaseImplementation()) + if (m_bUseRecordBatchBaseImplementation) return OGRLayer::GetNextArrowArray(stream, out_array); - if (m_bEOF) - { - memset(out_array, 0, sizeof(*out_array)); - return 0; - } - - if (m_poBatch == nullptr || m_nIdxInBatch == m_poBatch->num_rows()) + while (true) { - m_bEOF = !ReadNextBatch(); if (m_bEOF) { memset(out_array, 0, sizeof(*out_array)); return 0; } - } - auto status = arrow::ExportRecordBatch(*m_poBatch, out_array, nullptr); - m_nIdxInBatch = m_poBatch->num_rows(); - if (!status.ok()) - { - CPLError(CE_Failure, CPLE_AppDefined, - "ExportRecordBatch() failed with %s", - status.message().c_str()); - return EIO; - } + if (m_poBatch == nullptr || m_nIdxInBatch == m_poBatch->num_rows()) + { + m_bEOF = !ReadNextBatch(); + if (m_bEOF) + { + memset(out_array, 0, sizeof(*out_array)); + return 0; + } + } - if (EQUAL(m_aosArrowArrayStreamOptions.FetchNameValueDef( - "GEOMETRY_ENCODING", ""), - "WKB")) - { - const int nGeomFieldCount = m_poFeatureDefn->GetGeomFieldCount(); - for (int i = 0; i < nGeomFieldCount; i++) + auto status = arrow::ExportRecordBatch(*m_poBatch, out_array, nullptr); + m_nIdxInBatch = m_poBatch->num_rows(); + if (!status.ok()) { - const auto poGeomFieldDefn = m_poFeatureDefn->GetGeomFieldDefn(i); - if (!poGeomFieldDefn->IsIgnored()) + CPLError(CE_Failure, CPLE_AppDefined, + "ExportRecordBatch() failed with %s", + status.message().c_str()); + return EIO; + } + + if (EQUAL(m_aosArrowArrayStreamOptions.FetchNameValueDef( + "GEOMETRY_ENCODING", ""), + "WKB")) + { + const int nGeomFieldCount = m_poFeatureDefn->GetGeomFieldCount(); + for (int i = 0; i < nGeomFieldCount; i++) { - if (m_aeGeomEncoding[i] == OGRArrowGeomEncoding::WKT) + const auto poGeomFieldDefn = + m_poFeatureDefn->GetGeomFieldDefn(i); + if (!poGeomFieldDefn->IsIgnored()) { - const int nArrayIdx = - m_bIgnoredFields - ? m_anMapGeomFieldIndexToArrayIndex[i] - : m_anMapGeomFieldIndexToArrowColumn[i]; - auto sourceArray = out_array->children[nArrayIdx]; - auto targetArray = CreateWKTArrayFromWKBArray(sourceArray); - if (targetArray) + if (m_aeGeomEncoding[i] == OGRArrowGeomEncoding::WKT) { - sourceArray->release(sourceArray); - out_array->children[nArrayIdx] = targetArray; + const int nArrayIdx = + m_bIgnoredFields + ? m_anMapGeomFieldIndexToArrayIndex[i] + : m_anMapGeomFieldIndexToArrowColumn[i]; + auto sourceArray = out_array->children[nArrayIdx]; + auto targetArray = + CreateWKTArrayFromWKBArray(sourceArray); + if (targetArray) + { + sourceArray->release(sourceArray); + out_array->children[nArrayIdx] = targetArray; + } + else + { + out_array->release(out_array); + memset(out_array, 0, sizeof(*out_array)); + return ENOMEM; + } } - else + else if (m_aeGeomEncoding[i] != OGRArrowGeomEncoding::WKB) { - out_array->release(out_array); - memset(out_array, 0, sizeof(*out_array)); - return ENOMEM; + // Shouldn't happen if UseRecordBatchBaseImplementation() + // is up to date + CPLAssert(false); } } - else if (m_aeGeomEncoding[i] != OGRArrowGeomEncoding::WKB) - { - // Shouldn't happen if UseRecordBatchBaseImplementation() - // is up to date - CPLAssert(false); - } } } - } - OverrideArrowRelease(m_poArrowDS, out_array); + OverrideArrowRelease(m_poArrowDS, out_array); + + if (m_poFilterGeom) + { + PostFilterArrowArray(&m_sCachedSchema, out_array); + if (out_array->length == 0) + { + // If there are no records after filtering, start again + // with a new batch + continue; + } + } + break; + } return 0; } diff --git a/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp b/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp index 1ce74ca52b6e..6ad005735568 100644 --- a/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp +++ b/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp @@ -30,6 +30,7 @@ #include "ogr_api.h" #include "ogr_recordbatch.h" #include "ograrrowarrayhelper.h" +#include "ogr_wkb.h" #include "cpl_time.h" #include @@ -1986,3 +1987,526 @@ bool OGR_L_GetArrowStream(OGRLayerH hLayer, struct ArrowArrayStream *out_stream, return OGRLayer::FromHandle(hLayer)->GetArrowStream(out_stream, papszOptions); } + +/************************************************************************/ +/* ParseArrowMetadata() */ +/************************************************************************/ + +static std::map +ParseArrowMetadata(const char *pabyMetadata) +{ + std::map oMetadata; + int32_t nKVP; + memcpy(&nKVP, pabyMetadata, sizeof(int32_t)); + pabyMetadata += sizeof(int32_t); + for (int i = 0; i < nKVP; ++i) + { + int32_t nSizeKey; + memcpy(&nSizeKey, pabyMetadata, sizeof(int32_t)); + pabyMetadata += sizeof(int32_t); + std::string osKey; + osKey.assign(pabyMetadata, nSizeKey); + pabyMetadata += nSizeKey; + + int32_t nSizeValue; + memcpy(&nSizeValue, pabyMetadata, sizeof(int32_t)); + pabyMetadata += sizeof(int32_t); + std::string osValue; + osValue.assign(pabyMetadata, nSizeValue); + pabyMetadata += nSizeValue; + + oMetadata[osKey] = osValue; + } + + return oMetadata; +} + +/************************************************************************/ +/* OGRLayer::CanPostFilterArrowArray() */ +/************************************************************************/ + +/** Whether the PostFilterArrowArray() can work on the schema to remove + * rows that aren't selected by the spatial or attribute filter. + * + * Note: only spatial filter implemented for now. + */ +bool OGRLayer::CanPostFilterArrowArray(const struct ArrowSchema *schema) const +{ + if (m_poAttrQuery) + { + CPLDebug("OGR", + "Cannot post filter ArrowArray with attribute filter set"); + return false; + } + + if (strcmp(schema->format, "+s") != 0) + { + CPLDebug("OGR", "Unexpected top level schema->format = %s", + schema->format); + return false; + } + + const char *const apszHandledFormats[] = { + "b", // boolean + "c", // int8 + "C", // uint8 + "s", // int16 + "S", // uint16 + "i", // int32 + "I", // uint32 + "l", // int64 + "L", // uint64 + "e", // float16 + "f", // float32 + "g", // float64, + "z", // binary + "Z", // large binary + "u", // UTF-8 string + "U", // large UTF-8 string + // "d:xxxxx" // decimal128, decimal256 + // "w:xxxxx" // fixed width binary + "tdD", // date32[days] + "tdm", // date64[milliseconds] + "tts", //time32 [seconds] + "ttm", //time32 [milliseconds] + "ttu", //time64 [microseconds] + "ttn", //time64 [nanoseconds] + }; + + const char *const apszHandledFormatsPrefix[] = { + "tss:", // timestamp [seconds] with timezone + "tsm:", // timestamp [milliseconds] with timezone + "tsu:", // timestamp [microseconds] with timezone + "tsn:", // timestamp [nanoseconds] with timezone + }; + + for (int64_t i = 0; i < schema->n_children; ++i) + { + const auto fieldSchema = schema->children[i]; + bool bFound = false; + for (const char *pszHandledFormat : apszHandledFormats) + { + if (strcmp(fieldSchema->format, pszHandledFormat) == 0) + { + bFound = true; + break; + } + } + if (!bFound) + { + for (const char *pszHandledFormat : apszHandledFormatsPrefix) + { + if (strncmp(fieldSchema->format, pszHandledFormat, + strlen(pszHandledFormat)) == 0) + { + bFound = true; + break; + } + } + } + if (!bFound) + { + CPLDebug("OGR", "Field %s has unhandled format '%s'", + fieldSchema->name, fieldSchema->format); + return false; + } + } + + if (m_poFilterGeom) + { + bool bFound = false; + const char *pszGeomFieldName = + const_cast(this) + ->GetLayerDefn() + ->GetGeomFieldDefn(m_iGeomFieldFilter) + ->GetNameRef(); + for (int64_t i = 0; i < schema->n_children; ++i) + { + const auto fieldSchema = schema->children[i]; + if (strcmp(fieldSchema->name, pszGeomFieldName) == 0) + { + if (strcmp(fieldSchema->format, "z") != 0 && + strcmp(fieldSchema->format, "Z") != 0) + { + CPLDebug("OGR", "Geometry field %s has handled format '%s'", + fieldSchema->name, fieldSchema->format); + return false; + } + + // Check if ARROW:extension:name = ogc.wkb + const char *pabyMetadata = fieldSchema->metadata; + if (!pabyMetadata) + { + CPLDebug( + "OGR", + "Geometry field %s lacks metadata in its schema field", + fieldSchema->name); + return false; + } + + const auto oMetadata = ParseArrowMetadata(pabyMetadata); + auto oIter = oMetadata.find(ARROW_EXTENSION_NAME_KEY); + if (oIter == oMetadata.end()) + { + CPLDebug("OGR", + "Geometry field %s lacks " + "%s metadata " + "in its schema field", + fieldSchema->name, ARROW_EXTENSION_NAME_KEY); + return false; + } + if (oIter->second != EXTENSION_NAME) + { + CPLDebug("OGR", + "Geometry field %s has unexpected " + "%s = '%s' metadata " + "in its schema field", + fieldSchema->name, ARROW_EXTENSION_NAME_KEY, + oIter->second.c_str()); + return false; + } + + bFound = true; + break; + } + } + if (!bFound) + { + CPLDebug("OGR", "Cannot find geometry field %s in schema", + pszGeomFieldName); + return false; + } + } + + return true; +} + +/************************************************************************/ +/* TestBit() */ +/************************************************************************/ + +inline bool TestBit(const uint8_t *pabyData, size_t nIdx) +{ + return (pabyData[nIdx / 8] & (1 << (nIdx % 8))) != 0; +} + +/************************************************************************/ +/* SetBit() */ +/************************************************************************/ + +inline void SetBit(uint8_t *pabyData, size_t nIdx) +{ + pabyData[nIdx / 8] |= (1 << (nIdx % 8)); +} + +/************************************************************************/ +/* UnsetBit() */ +/************************************************************************/ + +inline void UnsetBit(uint8_t *pabyData, size_t nIdx) +{ + pabyData[nIdx / 8] &= uint8_t(~(1 << (nIdx % 8))); +} + +/************************************************************************/ +/* CompactValidityBuffer() */ +/************************************************************************/ + +static void +CompactValidityBuffer(struct ArrowArray *array, + const std::vector &abyValidityFromFilters) +{ + if (array->null_count == 0) + return; + uint8_t *pabyValidity = + static_cast(const_cast(array->buffers[0])); + const size_t nLength = static_cast(array->length); + const size_t nOffset = static_cast(array->offset); + for (size_t i = 0, j = 0; i < nLength; ++i) + { + if (abyValidityFromFilters[i]) + { + if (TestBit(pabyValidity, i + nOffset)) + SetBit(pabyValidity, j + nOffset); + else + UnsetBit(pabyValidity, j + nOffset); + + ++j; + } + } +} + +/************************************************************************/ +/* CompactBoolArray() */ +/************************************************************************/ + +static void CompactBoolArray(struct ArrowArray *array, + const std::vector &abyValidityFromFilters) +{ + CPLAssert(array->n_children == 0); + CPLAssert(array->n_buffers == 2); + CPLAssert(static_cast(array->length) == + abyValidityFromFilters.size()); + + const size_t nLength = static_cast(array->length); + const size_t nOffset = static_cast(array->offset); + uint8_t *pabyData = + static_cast(const_cast(array->buffers[1])); + size_t j = 0; + for (size_t i = 0; i < nLength; ++i) + { + if (abyValidityFromFilters[i]) + { + if (TestBit(pabyData, i + nOffset)) + SetBit(pabyData, j + nOffset); + else + UnsetBit(pabyData, j + nOffset); + + ++j; + } + } + + CompactValidityBuffer(array, abyValidityFromFilters); + array->length = j; +} + +/************************************************************************/ +/* CompactPrimitiveArray() */ +/************************************************************************/ + +template +static void +CompactPrimitiveArray(struct ArrowArray *array, + const std::vector &abyValidityFromFilters) +{ + CPLAssert(array->n_children == 0); + CPLAssert(array->n_buffers == 2); + CPLAssert(static_cast(array->length) == + abyValidityFromFilters.size()); + + const size_t nLength = static_cast(array->length); + const size_t nOffset = static_cast(array->offset); + T *paData = + static_cast(const_cast(array->buffers[1])) + nOffset; + size_t j = 0; + for (size_t i = 0; i < nLength; ++i) + { + if (abyValidityFromFilters[i]) + { + paData[j] = paData[i]; + ++j; + } + } + + CompactValidityBuffer(array, abyValidityFromFilters); + array->length = j; +} + +/************************************************************************/ +/* CompactStringOrBinaryArray() */ +/************************************************************************/ + +template +static void +CompactStringOrBinaryArray(struct ArrowArray *array, + const std::vector &abyValidityFromFilters) +{ + CPLAssert(array->n_children == 0); + CPLAssert(array->n_buffers == 3); + CPLAssert(static_cast(array->length) == + abyValidityFromFilters.size()); + + const size_t nLength = static_cast(array->length); + const size_t nOffset = static_cast(array->offset); + OffsetType *panOffsets = + static_cast(const_cast(array->buffers[1])) + + nOffset; + GByte *pabyData = + static_cast(const_cast(array->buffers[2])); + size_t j = 0; + OffsetType nCurOffset = panOffsets[0]; + for (size_t i = 0; i < nLength; ++i) + { + if (abyValidityFromFilters[i]) + { + const auto nStartOffset = panOffsets[i]; + const auto nEndOffset = panOffsets[i + 1]; + panOffsets[j] = nCurOffset; + const auto nSize = static_cast(nEndOffset - nStartOffset); + if (nSize) + { + if (nCurOffset < nStartOffset) + { + memmove(pabyData + nCurOffset, pabyData + nStartOffset, + nSize); + } + nCurOffset += static_cast(nSize); + } + ++j; + } + } + panOffsets[j] = nCurOffset; + + CompactValidityBuffer(array, abyValidityFromFilters); + array->length = j; +} + +/************************************************************************/ +/* FillValidityArrayFromWKBArray() */ +/************************************************************************/ + +template +static size_t +FillValidityArrayFromWKBArray(struct ArrowArray *array, + const OGREnvelope &sFilterEnvelope, + std::vector &abyValidityFromFilters) +{ + const size_t nLength = static_cast(array->length); + const uint8_t *pabyValidity = + array->null_count == 0 + ? nullptr + : static_cast(array->buffers[0]); + const size_t nOffset = static_cast(array->offset); + const OffsetType *panOffsets = + static_cast(array->buffers[1]) + nOffset; + const GByte *pabyData = static_cast(array->buffers[2]); + OGREnvelope sEnvelope; + abyValidityFromFilters.resize(nLength); + size_t nCountIntersecting = 0; + for (size_t i = 0; i < nLength; ++i) + { + if (!pabyValidity || TestBit(pabyValidity, i + nOffset)) + { + if (OGRWKBGetBoundingBox( + pabyData + panOffsets[i], + static_cast(panOffsets[i + 1] - panOffsets[i]), + sEnvelope) && + sFilterEnvelope.Intersects(sEnvelope)) + { + abyValidityFromFilters[i] = true; + nCountIntersecting++; + } + } + } + return nCountIntersecting; +} + +/************************************************************************/ +/* OGRLayer::PostFilterArrowArray() */ +/************************************************************************/ + +/** Remove rows that aren't selected by the spatial or attribute filter. + * + * Assumes that CanPostFilterArrowArray() has been called and returned true. + * + * Note: only spatial filter implemented for now. + */ +void OGRLayer::PostFilterArrowArray(const struct ArrowSchema *schema, + struct ArrowArray *array) const +{ + if (!m_poFilterGeom) + return; + + CPLAssert(schema->n_children == array->n_children); + + const char *pszGeomFieldName = const_cast(this) + ->GetLayerDefn() + ->GetGeomFieldDefn(m_iGeomFieldFilter) + ->GetNameRef(); + int64_t iGeomField = -1; + for (int64_t iField = 0; iField < schema->n_children; ++iField) + { + const auto fieldSchema = schema->children[iField]; + if (strcmp(fieldSchema->name, pszGeomFieldName) == 0) + { + iGeomField = iField; + break; + } + CPLAssert(array->children[iField]->length == + array->children[0]->length); + } + // Guaranteed if CanPostFilterArrowArray() returned true + CPLAssert(iGeomField >= 0); + CPLAssert(strcmp(schema->children[iGeomField]->format, "z") == 0 || + strcmp(schema->children[iGeomField]->format, "Z") == 0); + + CPLAssert(array->children[iGeomField]->n_buffers == 3); + + std::vector abyValidityFromFilters; + const size_t nCountIntersecting = + strcmp(schema->children[iGeomField]->format, "z") == 0 + ? FillValidityArrayFromWKBArray( + array->children[iGeomField], m_sFilterEnvelope, + abyValidityFromFilters) + : FillValidityArrayFromWKBArray( + array->children[iGeomField], m_sFilterEnvelope, + abyValidityFromFilters); + const size_t nLength = + static_cast(array->children[iGeomField]->length); + // Nothing to do ? + if (nCountIntersecting == nLength) + { + // CPLDebug("OGR", "All rows match filter"); + return; + } + + array->length = nCountIntersecting; + + for (int64_t iField = 0; iField < array->n_children; ++iField) + { + const auto psSchemaField = schema->children[iField]; + const auto psArray = array->children[iField]; + const char *format = psSchemaField->format; + + if (strcmp(format, "b") == 0) + { + CompactBoolArray(psArray, abyValidityFromFilters); + } + else if (strcmp(format, "c") == 0 || strcmp(format, "C") == 0) + { + CompactPrimitiveArray(psArray, abyValidityFromFilters); + } + else if (strcmp(format, "s") == 0 || strcmp(format, "S") == 0 || + strcmp(format, "e") == 0) + { + CompactPrimitiveArray(psArray, abyValidityFromFilters); + } + else if (strcmp(format, "i") == 0 || strcmp(format, "I") == 0 || + strcmp(format, "f") == 0 || strcmp(format, "tdD") == 0 || + strcmp(format, "tts") == 0 || strcmp(format, "ttm") == 0) + { + CompactPrimitiveArray(psArray, abyValidityFromFilters); + } + else if (strcmp(format, "l") == 0 || strcmp(format, "L") == 0 || + strcmp(format, "g") == 0 || strcmp(format, "tdm") == 0 || + strcmp(format, "ttu") == 0 || strcmp(format, "ttn") == 0 || + strncmp(format, "ts", 2) == 0) + { + CompactPrimitiveArray(psArray, abyValidityFromFilters); + } + else if (strcmp(format, "z") == 0 || strcmp(format, "u") == 0) + { + CompactStringOrBinaryArray(psArray, + abyValidityFromFilters); + } + else if (strcmp(format, "Z") == 0 || strcmp(format, "U") == 0) + { + CompactStringOrBinaryArray(psArray, + abyValidityFromFilters); + } + else + { + CPLError(CE_Failure, CPLE_AppDefined, + "Unexpected error in PostFilterArrowArray(): unhandled " + "field format: %s", + format); + + array->release(array); + memset(array, 0, sizeof(*array)); + + break; + } + + CPLAssert(psArray->length == array->length); + } +} diff --git a/ogr/ogrsf_frmts/ogrsf_frmts.h b/ogr/ogrsf_frmts/ogrsf_frmts.h index 8bc8b80c3827..77955591dbce 100644 --- a/ogr/ogrsf_frmts/ogrsf_frmts.h +++ b/ogr/ogrsf_frmts/ogrsf_frmts.h @@ -158,6 +158,10 @@ class CPL_DLL OGRLayer : public GDALMajorObject CreateSchemaForWKBGeometryColumn(const OGRGeomFieldDefn *poFieldDefn, const char *pszArrowFormat = "z"); + bool CanPostFilterArrowArray(const struct ArrowSchema *schema) const; + void PostFilterArrowArray(const struct ArrowSchema *schema, + struct ArrowArray *array) const; + public: OGRLayer(); virtual ~OGRLayer(); diff --git a/perftests/bench_ogr_to_geopandas.py b/perftests/bench_ogr_to_geopandas.py index f10315f529ea..fcf8f2381ad9 100644 --- a/perftests/bench_ogr_to_geopandas.py +++ b/perftests/bench_ogr_to_geopandas.py @@ -7,6 +7,8 @@ from osgeo import ogr +ogr.UseExceptions() + def layer_as_geopandas(lyr): @@ -75,7 +77,49 @@ def layer_as_geopandas(lyr): return df +def Usage(): + print("bench_ogr_to_geopandas.py [-spat xmin ymin xmax ymax] [-where cond]") + print(" filename [layer_name]") + sys.exit(1) + + if __name__ == "__main__": - ds = ogr.Open(sys.argv[1]) - lyr = ds.GetLayer(0) + + i = 1 + filename = None + where = None + minx = None + miny = None + maxx = None + maxy = None + layer_name = None + while i < len(sys.argv): + if sys.argv[i] == "-spat": + minx = float(sys.argv[i + 1]) + miny = float(sys.argv[i + 2]) + maxx = float(sys.argv[i + 3]) + maxy = float(sys.argv[i + 4]) + i += 4 + elif sys.argv[i] == "-where": + where = sys.argv[i + 1] + i += 1 + elif sys.argv[i][0] == "-": + Usage() + elif filename is None: + filename = sys.argv[i] + elif layer_name is None: + layer_name = sys.argv[i] + else: + Usage() + i += 1 + + if not filename: + Usage() + + ds = ogr.Open(filename) + lyr = ds.GetLayer(layer_name if layer_name else 0) + if minx: + lyr.SetSpatialFilterRect(minx, miny, maxx, maxy) + if where: + lyr.SetAttributeFilter(where) print(layer_as_geopandas(lyr))