Skip to content

Commit

Permalink
VRTDerivedRasterBand: Support Int8, (U)Int64 with Python pixel functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston authored and github-actions[bot] committed Feb 20, 2024
1 parent 7dfe769 commit 429f481
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
45 changes: 45 additions & 0 deletions autotest/gdrivers/vrtderived.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,51 @@ def identity(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize
_validate(xml)


###############################################################################


@pytest.mark.parametrize("dtype", range(1, gdal.GDT_TypeCount))
def test_vrt_derived_dtype(tmp_vsimem, dtype):
pytest.importorskip("numpy")

input_fname = tmp_vsimem / "input.tif"

nx = 1
ny = 1

with gdal.GetDriverByName("GTiff").Create(
input_fname, nx, ny, 1, eType=gdal.GDT_Int8
) as input_ds:
input_ds.GetRasterBand(1).Fill(1)
gt = input_ds.GetGeoTransform()

vrt_xml = f"""
<VRTDataset rasterXSize="{nx}" rasterYSize="{ny}">
<GeoTransform>{', '.join([str(x) for x in gt])}</GeoTransform>
<VRTRasterBand dataType="{gdal.GetDataTypeName(dtype)}" band="1" subClass="VRTDerivedRasterBand">
<PixelFunctionLanguage>Python</PixelFunctionLanguage>
<PixelFunctionType>identity</PixelFunctionType>
<PixelFunctionCode><![CDATA[
def identity(in_ar, out_ar, *args, **kwargs):
out_ar[:] = in_ar[0]
]]>
</PixelFunctionCode>
<SimpleSource>
<SourceFilename relativeToVRT="0">{input_fname}</SourceFilename>
<SourceBand>1</SourceBand>
<SrcRect xOff="0" yOff="0" xSize="{nx}" ySize="{ny}" />
<DstRect xOff="0" yOff="0" xSize="{nx}" ySize="{ny}" />
</SimpleSource>
</VRTRasterBand></VRTDataset>"""

with gdal.config_option("GDAL_VRT_ENABLE_PYTHON", "YES"):
with gdal.Open(vrt_xml) as vrt_ds:
arr = vrt_ds.ReadAsArray()
if dtype not in {gdal.GDT_CInt16, gdal.GDT_CInt32}:
assert arr[0, 0] == 1
assert vrt_ds.GetRasterBand(1).DataType == dtype


###############################################################################
# Cleanup.

Expand Down
12 changes: 11 additions & 1 deletion frmts/vrt/vrtderivedrasterband.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ static PyObject *GDALCreateNumpyArray(PyObject *pCreateArray, void *pBuffer,
case GDT_Byte:
pszDataType = "uint8";
break;
case GDT_Int8:
pszDataType = "int8";
break;
case GDT_UInt16:
pszDataType = "uint16";
break;
Expand All @@ -100,6 +103,12 @@ static PyObject *GDALCreateNumpyArray(PyObject *pCreateArray, void *pBuffer,
case GDT_Int32:
pszDataType = "int32";
break;
case GDT_Int64:
pszDataType = "int64";
break;
case GDT_UInt64:
pszDataType = "uint64";
break;
case GDT_Float32:
pszDataType = "float32";
break;
Expand All @@ -116,7 +125,8 @@ static PyObject *GDALCreateNumpyArray(PyObject *pCreateArray, void *pBuffer,
case GDT_CFloat64:
pszDataType = "complex128";
break;
default:
case GDT_Unknown:
case GDT_TypeCount:
CPLAssert(FALSE);
break;
}
Expand Down

0 comments on commit 429f481

Please sign in to comment.