Skip to content

Commit

Permalink
gdalattachpct.py: fix it when output file is a VRT (fixes #9513)
Browse files Browse the repository at this point in the history
  • Loading branch information
rouault committed Mar 20, 2024
1 parent 8878a23 commit 9e4699d
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 24 deletions.
119 changes: 119 additions & 0 deletions autotest/pyscripts/test_gdalattachpct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#!/usr/bin/env pytest
###############################################################################
# $Id$
#
# Project: GDAL/OGR Test Suite
# Purpose: gdalattachpct.py testing
# Author: Even Rouault <even dot rouault at spatialys dot com>
#
###############################################################################
# Copyright (c) 2024, Even Rouault <even dot rouault at spatialys dot com>
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
###############################################################################

import pytest
import test_py_scripts

from osgeo import gdal

pytestmark = [
pytest.mark.skipif(
test_py_scripts.get_py_script("gdalattachpct") is None,
reason="gdalattachpct.py not available",
),
]


@pytest.fixture()
def script_path():
return test_py_scripts.get_py_script("gdalattachpct")


@pytest.fixture()
def palette_file(tmp_path):
palette_filename = str(tmp_path / "pal.txt")
with open(palette_filename, "wt") as f:
f.write("0 1 2 3\n")
f.write("255 254 253 252\n")
return palette_filename


###############################################################################
# Basic test


def test_gdalattachpct_basic(script_path, tmp_path, palette_file):

src_filename = str(tmp_path / "src.tif")
ds = gdal.GetDriverByName("GTiff").Create(src_filename, 1, 1)
ds.GetRasterBand(1).Fill(1)
ds = None

out_filename = str(tmp_path / "dst.tif")

test_py_scripts.run_py_script(
script_path,
"gdalattachpct",
f" {palette_file} {src_filename} {out_filename}",
)

ds = gdal.Open(out_filename)
assert ds.GetDriver().ShortName == "GTiff"
assert ds.GetRasterBand(1).GetColorInterpretation() == gdal.GCI_PaletteIndex
ct = ds.GetRasterBand(1).GetColorTable()
assert ct
assert ct.GetCount() == 256
assert ct.GetColorEntry(0) == (1, 2, 3, 255)
assert ct.GetColorEntry(255) == (254, 253, 252, 255)
assert ds.GetRasterBand(1).Checksum() == 1


###############################################################################
# Test outputing to VRT


def test_gdalattachpct_vrt_output(script_path, tmp_path, palette_file):

src_filename = str(tmp_path / "src.tif")
ds = gdal.GetDriverByName("GTiff").Create(src_filename, 1, 1)
ds.GetRasterBand(1).Fill(1)
ds = None

out_filename = str(tmp_path / "dst.vrt")

test_py_scripts.run_py_script(
script_path,
"gdalattachpct",
f" {palette_file} {src_filename} {out_filename}",
)

ds = gdal.Open(out_filename)
assert ds.GetDriver().ShortName == "VRT"
assert ds.GetRasterBand(1).GetColorInterpretation() == gdal.GCI_PaletteIndex
ct = ds.GetRasterBand(1).GetColorTable()
assert ct
assert ct.GetCount() == 256
assert ct.GetColorEntry(0) == (1, 2, 3, 255)
assert ct.GetColorEntry(255) == (254, 253, 252, 255)
assert ds.GetRasterBand(1).Checksum() == 1

# Check source file is not altered
ds = gdal.Open(src_filename)
assert ds.GetRasterBand(1).GetColorTable() is None
56 changes: 32 additions & 24 deletions swig/python/gdal-utils/osgeo_utils/gdalattachpct.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,39 +111,47 @@ def doit(
print("No color table on file ", pct_filename)
return None, 1

# =============================================================================
# Create a MEM clone of the source file.
# =============================================================================

src_ds = open_ds(src_filename)

mem_ds = gdal.GetDriverByName("MEM").CreateCopy("mem", src_ds)

# =============================================================================
# Assign the color table in memory.
# =============================================================================

mem_ds.GetRasterBand(1).SetRasterColorTable(ct)
mem_ds.GetRasterBand(1).SetRasterColorInterpretation(gdal.GCI_PaletteIndex)

# =============================================================================
# Write the dataset to the output file.
# =============================================================================

# Figure out destination driver
if not driver_name:
driver_name = GetOutputDriverFor(dst_filename)

dst_driver = gdal.GetDriverByName(driver_name)
if dst_driver is None:
print('"%s" driver not registered.' % driver_name)
print(f'"{driver_name}" driver not registered.')
return None, 1

src_ds = open_ds(src_filename)
if src_ds is None:
print(f"Cannot open {src_filename}")
return None, 1

if driver_name.upper() == "MEM":
out_ds = mem_ds
if driver_name.upper() == "VRT":
# For VRT, create the VRT first from the source dataset, so it
# correctly referes to it
out_ds = dst_driver.CreateCopy(dst_filename or "", src_ds)
if out_ds is None:
print(f"Cannot create {dst_filename}")
return None, 1

# And now assign the color table to the VRT
out_ds.GetRasterBand(1).SetRasterColorTable(ct)
out_ds.GetRasterBand(1).SetRasterColorInterpretation(gdal.GCI_PaletteIndex)
else:
out_ds = dst_driver.CreateCopy(dst_filename or "", mem_ds)
# Create a MEM clone of the source file.
mem_ds = gdal.GetDriverByName("MEM").CreateCopy("mem", src_ds)

# Assign the color table in memory.
mem_ds.GetRasterBand(1).SetRasterColorTable(ct)
mem_ds.GetRasterBand(1).SetRasterColorInterpretation(gdal.GCI_PaletteIndex)

# Write the dataset to the output file.
if driver_name.upper() == "MEM":
out_ds = mem_ds
else:
out_ds = dst_driver.CreateCopy(dst_filename or "", mem_ds)

mem_ds = None

mem_ds = None
src_ds = None

return out_ds, 0
Expand Down

0 comments on commit 9e4699d

Please sign in to comment.