Skip to content

Commit

Permalink
GetNextArrowArray() implementations: automatically adjust batch size …
Browse files Browse the repository at this point in the history
…of list/string/binary arrays do not saturate their capacity (2 billion elements)
  • Loading branch information
rouault committed Oct 29, 2023
1 parent 4de5a34 commit 342fc84
Show file tree
Hide file tree
Showing 10 changed files with 814 additions and 213 deletions.
71 changes: 71 additions & 0 deletions autotest/ogr/ogr_gpkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9515,3 +9515,74 @@ def test_ogr_gpkg_sql_exact_spatial_filter_for_feature_count(tmp_vsimem):
lyr = ds.GetLayer(0)
lyr.SetSpatialFilterRect(0.1, 0.2, 0.15, 0.3)
assert lyr.GetFeatureCount() == 1


###############################################################################


@pytest.mark.parametrize("too_big_field", ["huge_string", "huge_binary", "geometry"])
def test_ogr_gpkg_arrow_stream_huge_array(tmp_vsimem, too_big_field):
    """Check that GeoPackage Arrow batches honour OGR_ARROW_MEM_LIMIT.

    A layer of 50 features is created where one column (string, binary or
    geometry, selected by ``too_big_field``) holds large values. With
    OGR_ARROW_MEM_LIMIT=20000 and MAX_FEATURES_IN_BATCH=10, the stream must
    still return every FID exactly once and in order, split over the expected
    number of batches. The check is run on the table, on a view over it and
    on a SQL result layer.
    """
    pytest.importorskip("osgeo.gdal_array")
    pytest.importorskip("numpy")

    filename = tmp_vsimem / "test_ogr_gpkg_arrow_stream_huge_array.gpkg"
    ds = gdal.GetDriverByName("GPKG").Create(filename, 0, 0, 0, gdal.GDT_Unknown)
    lyr = ds.CreateLayer("foo")
    lyr.CreateField(ogr.FieldDefn("huge_string", ogr.OFTString))
    lyr.CreateField(ogr.FieldDefn("huge_binary", ogr.OFTBinary))
    for i in range(50):
        f = ogr.Feature(lyr.GetLayerDefn())
        if too_big_field == "huge_string":
            if i > 10:
                f["huge_string"] = "x" * 10000
        elif too_big_field == "huge_binary":
            if i > 10:
                f["huge_binary"] = b"x" * 10000
        else:
            # Setting point index 500 yields a 501-point line string
            geom = ogr.Geometry(ogr.wkbLineString)
            geom.SetPoint_2D(500, 0, 0)
            f.SetGeometry(geom)
        lyr.CreateFeature(f)
    # Register a view over the table as a feature layer
    ds.ExecuteSQL(
        "CREATE VIEW my_view AS SELECT fid AS my_fid, geom AS my_geom, huge_string, huge_binary FROM foo"
    )
    ds.ExecuteSQL(
        "INSERT INTO gpkg_contents (table_name, identifier, data_type, srs_id) VALUES ( 'my_view', 'my_view', 'features', 0 )"
    )
    ds.ExecuteSQL(
        "INSERT INTO gpkg_geometry_columns (table_name, column_name, geometry_type_name, srs_id, z, m) values ('my_view', 'my_geom', 'GEOMETRY', 0, 0, 0)"
    )
    ds = None

    expected_fids = [i + 1 for i in range(50)]
    expected_batch_count = 25 if too_big_field == "geometry" else 21

    def check_stream(layer, label):
        # Read the whole Arrow stream of `layer` under the memory limit and
        # verify FIDs and batch count; `label` identifies the layer on failure.
        with gdaltest.config_option("OGR_ARROW_MEM_LIMIT", "20000", thread_local=False):
            stream = layer.GetArrowStreamAsNumPy(
                ["INCLUDE_FID=YES", "MAX_FEATURES_IN_BATCH=10", "USE_MASKED_ARRAYS=NO"]
            )
            batch_count = 0
            got_fids = []
            for batch in stream:
                batch_count += 1
                for fid in batch[layer.GetFIDColumn()]:
                    got_fids.append(fid)
            del stream
        assert got_fids == expected_fids, label
        assert batch_count == expected_batch_count, label

    ds = ogr.Open(filename)
    for lyr_name in ["foo", "my_view"]:
        check_stream(ds.GetLayer(lyr_name), lyr_name)

    with ds.ExecuteSQL("SELECT * FROM foo") as sql_lyr:
        # Use a dedicated label: the loop variable above would wrongly
        # report "my_view" for a failure on the SQL result layer.
        check_stream(sql_lyr, "SELECT * FROM foo")
174 changes: 174 additions & 0 deletions autotest/ogr/ogr_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,180 @@ def test_ogr_mem_arrow_stream_numpy():
assert len(batches) == 0


###############################################################################


@pytest.mark.parametrize(
    "limited_field",
    [
        "str",
        "strlist",
        "int32list",
        "int64list",
        "float64list",
        "boollist",
        "binary",
        "binary_fixed_width",
        "geometry",
    ],
)
def test_ogr_mem_arrow_stream_numpy_memlimit(limited_field):
    """Check that the Memory driver Arrow stream honours OGR_ARROW_MEM_LIMIT.

    Three features are created: a fully populated one, an empty one, and one
    whose ``limited_field`` column holds a larger value. With a limit of 100
    the stream must return two batches (FIDs [0, 1] then [2]). With a limit
    of 1 no batch can be returned and an error must be reported.
    """
    pytest.importorskip("osgeo.gdal_array")
    pytest.importorskip("numpy")

    ds = ogr.GetDriverByName("Memory").CreateDataSource("")
    lyr = ds.CreateLayer("foo")

    # (name, type, subtype or None, width or None), in creation order
    field_specs = [
        ("str", ogr.OFTString, None, None),
        ("bool", ogr.OFTInteger, ogr.OFSTBoolean, None),
        ("int16", ogr.OFTInteger, ogr.OFSTInt16, None),
        ("int32", ogr.OFTInteger, None, None),
        ("int64", ogr.OFTInteger64, None, None),
        ("float32", ogr.OFTReal, ogr.OFSTFloat32, None),
        ("float64", ogr.OFTReal, None, None),
        ("date", ogr.OFTDate, None, None),
        ("time", ogr.OFTTime, None, None),
        ("datetime", ogr.OFTDateTime, None, None),
        ("binary", ogr.OFTBinary, None, None),
        ("binary_fixed_width", ogr.OFTBinary, None, 50),
        ("strlist", ogr.OFTStringList, None, None),
        ("boollist", ogr.OFTIntegerList, ogr.OFSTBoolean, None),
        ("int16list", ogr.OFTIntegerList, ogr.OFSTInt16, None),
        ("int32list", ogr.OFTIntegerList, None, None),
        ("int64list", ogr.OFTInteger64List, None, None),
        ("float32list", ogr.OFTRealList, ogr.OFSTFloat32, None),
        ("float64list", ogr.OFTRealList, None, None),
    ]
    for name, ogr_type, subtype, width in field_specs:
        defn = ogr.FieldDefn(name, ogr_type)
        if subtype is not None:
            defn.SetSubType(subtype)
        if width is not None:
            defn.SetWidth(width)
        lyr.CreateField(defn)

    # First feature: every field and the geometry are set
    full_feature = ogr.Feature(lyr.GetLayerDefn())
    first_feature_values = [
        ("bool", 1),
        ("int16", -12345),
        ("int32", 12345678),
        ("int64", 12345678901234),
        ("float32", 1.25),
        ("float64", 1.250123),
        ("str", "abc"),
        ("date", "2022-05-31"),
        ("time", "12:34:56.789"),
        ("datetime", "2022-05-31T12:34:56.789Z"),
        ("boollist", "[False,True]"),
        ("int16list", "[-12345,12345]"),
        ("int32list", "[-12345678,12345678]"),
        ("int64list", "[-12345678901234,12345678901234]"),
        ("float32list", "[-1.25,1.25]"),
        ("float64list", "[-1.250123,1.250123]"),
        ("strlist", '["abc","defghi"]'),
        ("binary", b"\xDE\xAD"),
        ("binary_fixed_width", b"\xDE\xAD" * 25),
    ]
    for name, value in first_feature_values:
        full_feature.SetField(name, value)
    full_feature.SetGeometryDirectly(ogr.CreateGeometryFromWkt("POINT(1 2)"))
    lyr.CreateFeature(full_feature)

    # Second feature: nothing set
    lyr.CreateFeature(ogr.Feature(lyr.GetLayerDefn()))

    # Third feature: only the field under test is set, with a larger value
    big_feature = ogr.Feature(lyr.GetLayerDefn())
    if limited_field == "str":
        big_feature["str"] = "x" * 100
    elif limited_field == "strlist":
        big_feature["strlist"] = ["x" * 100]
    elif limited_field == "int32list":
        big_feature["int32list"] = [0] * 100
    elif limited_field == "int64list":
        big_feature["int64list"] = [0] * 100
    elif limited_field == "float64list":
        big_feature["float64list"] = [0] * 100
    elif limited_field == "boollist":
        big_feature["boollist"] = [False] * 100
    elif limited_field == "binary":
        big_feature["binary"] = b"x" * 100
    elif limited_field == "binary_fixed_width":
        big_feature.SetField("binary_fixed_width", b"\xDE\xAD" * 25)
    elif limited_field == "geometry":
        sizeof_first_point = 21
        sizeof_linestring_preamble = 9
        sizeof_coord_pair = 2 * 8
        # Index of the last vertex of the line string, derived from the
        # 100 byte limit minus the preamble and the first feature's point
        last_point_index = (
            100 - sizeof_linestring_preamble - sizeof_first_point
        ) // sizeof_coord_pair
        line = ogr.Geometry(ogr.wkbLineString)
        line.SetPoint_2D(last_point_index, 0, 0)
        big_feature.SetGeometry(line)
    lyr.CreateFeature(big_feature)

    stream_options = ["USE_MASKED_ARRAYS=NO"]
    too_large_msg = "Too large feature: not even a single feature can be returned"

    # Limit of 100: the third feature must go to its own batch
    with gdaltest.config_option("OGR_ARROW_MEM_LIMIT", "100", thread_local=False):
        stream = lyr.GetArrowStreamAsNumPy(options=stream_options)
        batches = [batch for batch in stream]
        assert len(batches) == 2
        assert list(batches[0]["OGC_FID"]) == [0, 1]
        assert list(batches[1]["OGC_FID"]) == [2]

    # Limit of 1: not even one feature fits
    with gdaltest.config_option("OGR_ARROW_MEM_LIMIT", "1", thread_local=False):

        # Without exceptions: error message set, no batch returned
        stream = lyr.GetArrowStreamAsNumPy(options=stream_options)
        with gdal.quiet_errors():
            gdal.ErrorReset()
            batches = [batch for batch in stream]
            assert gdal.GetLastErrorMsg() == too_large_msg
            assert len(batches) == 0

        # With exceptions enabled: iterating raises with the same message
        stream = lyr.GetArrowStreamAsNumPy(options=stream_options)
        with gdaltest.enable_exceptions():
            with pytest.raises(Exception, match=too_large_msg):
                batches = [batch for batch in stream]
                assert len(batches) == 0


###############################################################################
# Test optimization to save memory on string fields with huge strings compared
# to the average size
Expand Down
30 changes: 29 additions & 1 deletion ogr/ogrsf_frmts/flatgeobuf/ogrflatgeobuflayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,7 @@ int OGRFlatGeobufLayer::GetNextArrowArray(struct ArrowArrayStream *stream,

const GIntBig nFeatureIdxStart = m_featuresPos;

const uint32_t nMemLimit = OGRArrowArrayHelper::GetMemLimit();
while (iFeat < sHelper.nMaxBatchSize)
{
bEOFOrError = true;
Expand Down Expand Up @@ -1538,6 +1539,20 @@ int OGRFlatGeobufLayer::GetNextArrowArray(struct ArrowArrayStream *stream,

const int iArrowField = sHelper.mapOGRGeomFieldToArrowField[0];
const size_t nWKBSize = poOGRGeometry->WkbSize();

if (iFeat > 0)
{
auto psArray = out_array->children[iArrowField];
auto panOffsets = static_cast<int32_t *>(
const_cast<void *>(psArray->buffers[1]));
const uint32_t nCurLength =
static_cast<uint32_t>(panOffsets[iFeat]);
if (nWKBSize <= nMemLimit && nWKBSize > nMemLimit - nCurLength)
{
goto after_loop;
}
}

GByte *outPtr =
sHelper.GetPtrForStringOrBinary(iArrowField, iFeat, nWKBSize);
if (outPtr == nullptr)
Expand Down Expand Up @@ -1794,6 +1809,19 @@ int OGRFlatGeobufLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
}
if (!isIgnored)
{
if (iFeat > 0)
{
auto panOffsets = static_cast<int32_t *>(
const_cast<void *>(psArray->buffers[1]));
const uint32_t nCurLength =
static_cast<uint32_t>(panOffsets[iFeat]);
if (len <= nMemLimit &&
len > nMemLimit - nCurLength)
{
goto after_loop;
}
}

GByte *outPtr = sHelper.GetPtrForStringOrBinary(
iArrowField, iFeat, len);
if (outPtr == nullptr)
Expand Down Expand Up @@ -1880,7 +1908,7 @@ int OGRFlatGeobufLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
m_featuresPos++;
bEOFOrError = false;
}

after_loop:
if (bEOFOrError)
m_bEOF = true;

Expand Down
25 changes: 24 additions & 1 deletion ogr/ogrsf_frmts/generic/ograrrowarrayhelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,32 @@
//! @cond Doxygen_Suppress

/************************************************************************/
/* GetMaxFeaturesInBatch() */
/* GetMemLimit() */
/************************************************************************/

/** Return the maximum number of bytes a single Arrow string/binary/list
 * array is allowed to hold in one batch.
 *
 * The default is INT32_MAX, lowered to a quarter of the usable physical RAM
 * when that is smaller. The OGR_ARROW_MEM_LIMIT configuration option, meant
 * for tests, overrides the RAM-based computation.
 *
 * @return the memory limit in bytes.
 */
/*static*/ uint32_t OGRArrowArrayHelper::GetMemLimit()
{
    uint32_t nMemLimit =
        static_cast<uint32_t>(std::numeric_limits<int32_t>::max());
    // Just for tests
    const char *pszOGR_ARROW_MEM_LIMIT =
        CPLGetConfigOption("OGR_ARROW_MEM_LIMIT", nullptr);
    if (pszOGR_ARROW_MEM_LIMIT)
    {
        const int nUserLimit = atoi(pszOGR_ARROW_MEM_LIMIT);
        // A negative value converted to uint32_t would wrap around to a
        // limit larger than INT32_MAX: ignore it and keep the default.
        if (nUserLimit >= 0)
            nMemLimit = static_cast<uint32_t>(nUserLimit);
    }
    else
    {
        const auto nUsableRAM = CPLGetUsablePhysicalRAM();
        // nUsableRAM <= 0 means the amount of RAM could not be determined
        if (nUsableRAM > 0 &&
            static_cast<uint64_t>(nUsableRAM) / 4 < nMemLimit)
        {
            nMemLimit =
                static_cast<uint32_t>(static_cast<uint64_t>(nUsableRAM) / 4);
        }
    }
    return nMemLimit;
}

/************************************************************************/
/* GetMaxFeaturesInBatch() */
/************************************************************************/

/* static */
int OGRArrowArrayHelper::GetMaxFeaturesInBatch(
const CPLStringList &aosArrowArrayStreamOptions)
{
Expand Down
2 changes: 2 additions & 0 deletions ogr/ogrsf_frmts/generic/ograrrowarrayhelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class CPL_DLL OGRArrowArrayHelper
int64_t *panFIDValues = nullptr;
struct ArrowArray *m_out_array = nullptr;

/** Return the memory limit, in bytes, used to cap the size of
 * string/binary arrays in a batch. It is at most INT32_MAX, and is
 * driven by the OGR_ARROW_MEM_LIMIT configuration option when set,
 * otherwise by a quarter of the usable physical RAM. */
static uint32_t GetMemLimit();

static int
GetMaxFeaturesInBatch(const CPLStringList &aosArrowArrayStreamOptions);

Expand Down
Loading

0 comments on commit 342fc84

Please sign in to comment.