Skip to content

Commit

Permalink
ArrowArray: implement fast 'FID IN (...)' / 'FID = ...' attribute filter in generic GetNextArrowArray(), and use it for FlatGeoBuf one too (when it has a spatial index) (fixes OSGeo#8590)
Browse files Browse the repository at this point in the history
  • Loading branch information
rouault committed Oct 21, 2023
1 parent fada29f commit bc5aadd
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 14 deletions.
24 changes: 20 additions & 4 deletions autotest/ogr/ogr_flatgeobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,7 @@ def test_ogr_flatgeobuf_arrow_stream_numpy(layer_creation_options):
for options in ([], ["MAX_FEATURES_IN_BATCH=1"]):
# Test attribute filter
lyr.SetAttributeFilter("int16 = -123")
stream = lyr.GetArrowStreamAsNumPy()
stream = lyr.GetArrowStreamAsNumPy(options)
batches = [batch for batch in stream]
lyr.SetAttributeFilter(None)
assert len(batches) == 1
Expand All @@ -1087,7 +1087,7 @@ def test_ogr_flatgeobuf_arrow_stream_numpy(layer_creation_options):

# Test spatial filter
lyr.SetSpatialFilterRect(0, 0, 10, 10)
stream = lyr.GetArrowStreamAsNumPy()
stream = lyr.GetArrowStreamAsNumPy(options)
batches = [batch for batch in stream]
lyr.SetSpatialFilter(None)
assert len(batches) == 1
Expand All @@ -1098,7 +1098,7 @@ def test_ogr_flatgeobuf_arrow_stream_numpy(layer_creation_options):
# Test attribute + spatial filter: no result
lyr.SetAttributeFilter("int16 = -123")
lyr.SetSpatialFilterRect(0, 0, 10, 10)
stream = lyr.GetArrowStreamAsNumPy()
stream = lyr.GetArrowStreamAsNumPy(options)
batches = [batch for batch in stream]
lyr.SetAttributeFilter(None)
lyr.SetSpatialFilter(None)
Expand All @@ -1107,7 +1107,7 @@ def test_ogr_flatgeobuf_arrow_stream_numpy(layer_creation_options):
# Test attribute + spatial filter: result
lyr.SetAttributeFilter("int16 = -123")
lyr.SetSpatialFilterRect(-1, 2, -1, 2)
stream = lyr.GetArrowStreamAsNumPy()
stream = lyr.GetArrowStreamAsNumPy(options)
batches = [batch for batch in stream]
lyr.SetAttributeFilter(None)
lyr.SetSpatialFilter(None)
Expand All @@ -1116,6 +1116,22 @@ def test_ogr_flatgeobuf_arrow_stream_numpy(layer_creation_options):
assert batches[0]["OGC_FID"][0] == 1
assert batches[0]["int16"][0] == -123

# Test fast FID filtering
lyr.SetAttributeFilter("FID IN (1, -2, 0)")
stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
batches = [batch for batch in stream]
lyr.SetAttributeFilter(None)
assert len(batches) == 1
batch = batches[0]
assert len(batch["OGC_FID"]) == 2
assert set(batch["OGC_FID"]) == set([0, 1])

lyr.SetAttributeFilter("FID = 2")
stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
batches = [batch for batch in stream]
lyr.SetAttributeFilter(None)
assert len(batches) == 0

# Test ignored fields
assert lyr.SetIgnoredFields(["OGR_GEOMETRY", "int16"]) == ogr.OGRERR_NONE
stream = lyr.GetArrowStreamAsNumPy(options=["INCLUDE_FID=NO"])
Expand Down
17 changes: 17 additions & 0 deletions autotest/ogr/ogr_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,23 @@ def test_ogr_mem_arrow_stream_numpy():
assert batch["binary"][1] == b"\xDE\xAD"
assert len(batch["wkb_geometry"][1]) == 21

# Test fast FID filtering
lyr.SetAttributeFilter("FID IN (1, -2, 0)")
stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
batches = [batch for batch in stream]
lyr.SetAttributeFilter(None)
assert len(batches) == 1
batch = batches[0]
assert len(batch["OGC_FID"]) == 2
assert batch["OGC_FID"][0] == 1
assert batch["OGC_FID"][1] == 0

lyr.SetAttributeFilter("FID = 2")
stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
batches = [batch for batch in stream]
lyr.SetAttributeFilter(None)
assert len(batches) == 0


###############################################################################
# Test optimization to save memory on string fields with huge strings compared
Expand Down
18 changes: 13 additions & 5 deletions ogr/ogrsf_frmts/flatgeobuf/ogrflatgeobuflayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -800,10 +800,9 @@ OGRFeature *OGRFlatGeobufLayer::GetFeature(GIntBig nFeatureId)
}
else
{
if (static_cast<uint64_t>(nFeatureId) >= m_featuresCount)
if (nFeatureId < 0 ||
static_cast<uint64_t>(nFeatureId) >= m_featuresCount)
{
CPLError(CE_Failure, CPLE_AppDefined,
"Requested feature id is out of bounds");
return nullptr;
}
ResetReading();
Expand Down Expand Up @@ -1374,7 +1373,8 @@ OGRErr OGRFlatGeobufLayer::parseFeature(OGRFeature *poFeature)
int OGRFlatGeobufLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
struct ArrowArray *out_array)
{
if (CPLTestBool(
if (!m_poSharedArrowArrayStreamPrivateData->m_anQueriedFIDs.empty() ||
CPLTestBool(
CPLGetConfigOption("OGR_FLATGEOBUF_STREAM_BASE_IMPL", "NO")))
{
return OGRLayer::GetNextArrowArray(stream, out_array);
Expand Down Expand Up @@ -1418,6 +1418,8 @@ int OGRFlatGeobufLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
sHelper.nMaxBatchSize = 0;
}

const GIntBig nFeatureIdxStart = m_featuresPos;

while (iFeat < sHelper.nMaxBatchSize)
{
bEOFOrError = true;
Expand Down Expand Up @@ -1893,7 +1895,13 @@ int OGRFlatGeobufLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
// Spatial filter already evaluated
auto poFilterGeomBackup = m_poFilterGeom;
m_poFilterGeom = nullptr;
PostFilterArrowArray(&schema, out_array, nullptr);
CPLStringList aosOptions;
if (!m_poFilterGeom)
{
aosOptions.SetNameValue("BASE_SEQUENTIAL_FID",
CPLSPrintf(CPL_FRMT_GIB, nFeatureIdxStart));
}
PostFilterArrowArray(&schema, out_array, aosOptions.List());
schema.release(&schema);
m_poFilterGeom = poFilterGeomBackup;
}
Expand Down
69 changes: 64 additions & 5 deletions ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1591,15 +1591,47 @@ int OGRLayer::GetNextArrowArray(struct ArrowArrayStream *stream,

out_array->release = OGRLayerDefaultReleaseArray;

for (int i = 0; i < nMaxBatchSize; i++)
if (!m_poSharedArrowArrayStreamPrivateData->m_anQueriedFIDs.empty())
{
auto poFeature = std::unique_ptr<OGRFeature>(GetNextFeature());
if (!poFeature)
if (m_poSharedArrowArrayStreamPrivateData->m_iQueriedFIDS == 0)
{
CPLDebug("OGR", "Using fast FID filtering");
}
while (
apoFeatures.size() < static_cast<size_t>(nMaxBatchSize) &&
m_poSharedArrowArrayStreamPrivateData->m_iQueriedFIDS <
m_poSharedArrowArrayStreamPrivateData->m_anQueriedFIDs.size())
{
const auto nFID =
m_poSharedArrowArrayStreamPrivateData->m_anQueriedFIDs
[m_poSharedArrowArrayStreamPrivateData->m_iQueriedFIDS];
auto poFeature = std::unique_ptr<OGRFeature>(GetFeature(nFID));
++m_poSharedArrowArrayStreamPrivateData->m_iQueriedFIDS;
if (poFeature && (m_poFilterGeom == nullptr ||
FilterGeometry(poFeature->GetGeomFieldRef(
m_iGeomFieldFilter))))
{
apoFeatures.emplace_back(std::move(poFeature));
}
}
if (m_poSharedArrowArrayStreamPrivateData->m_iQueriedFIDS ==
m_poSharedArrowArrayStreamPrivateData->m_anQueriedFIDs.size())
{
poPrivate->poShared->m_bEOF = true;
break;
}
apoFeatures.emplace_back(std::move(poFeature));
}
else
{
for (int i = 0; i < nMaxBatchSize; i++)
{
auto poFeature = std::unique_ptr<OGRFeature>(GetNextFeature());
if (!poFeature)
{
poPrivate->poShared->m_bEOF = true;
break;
}
apoFeatures.emplace_back(std::move(poFeature));
}
}
if (apoFeatures.empty())
{
Expand Down Expand Up @@ -2077,6 +2109,33 @@ bool OGRLayer::GetArrowStream(struct ArrowArrayStream *out_stream,
m_poSharedArrowArrayStreamPrivateData->m_poLayer = this;
}
m_poSharedArrowArrayStreamPrivateData->m_bArrowArrayStreamInProgress = true;

// Special case for "FID = constant", or "FID IN (constant1, ...., constantN)"
m_poSharedArrowArrayStreamPrivateData->m_anQueriedFIDs.clear();
m_poSharedArrowArrayStreamPrivateData->m_iQueriedFIDS = 0;
if (m_poAttrQuery)
{
swq_expr_node *poNode =
static_cast<swq_expr_node *>(m_poAttrQuery->GetSWQExpr());
if (poNode->eNodeType == SNT_OPERATION &&
(poNode->nOperation == SWQ_IN || poNode->nOperation == SWQ_EQ) &&
poNode->papoSubExpr[0]->eNodeType == SNT_COLUMN &&
poNode->papoSubExpr[0]->field_index ==
GetLayerDefn()->GetFieldCount() + SPF_FID &&
TestCapability(OLCRandomRead))
{
for (int i = 1; i < poNode->nSubExprCount; ++i)
{
if (poNode->papoSubExpr[i]->eNodeType == SNT_CONSTANT &&
poNode->papoSubExpr[i]->field_type == SWQ_INTEGER64)
{
m_poSharedArrowArrayStreamPrivateData->m_anQueriedFIDs
.push_back(poNode->papoSubExpr[i]->int_value);
}
}
}
}

auto poPrivateData = new ArrowArrayStreamPrivateDataSharedDataWrapper();
poPrivateData->poShared = m_poSharedArrowArrayStreamPrivateData;
out_stream->private_data = poPrivateData;
Expand Down
2 changes: 2 additions & 0 deletions ogr/ogrsf_frmts/ogrsf_frmts.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class CPL_DLL OGRLayer : public GDALMajorObject
bool m_bArrowArrayStreamInProgress = false;
bool m_bEOF = false;
OGRLayer *m_poLayer = nullptr;
std::vector<GIntBig> m_anQueriedFIDs{};
size_t m_iQueriedFIDS = 0;
};
std::shared_ptr<ArrowArrayStreamPrivateData>
m_poSharedArrowArrayStreamPrivateData{};
Expand Down

0 comments on commit bc5aadd

Please sign in to comment.