diff --git a/python/resdata/grid/rd_region.py b/python/resdata/grid/rd_region.py index 69ad6e24b..18aca7f69 100644 --- a/python/resdata/grid/rd_region.py +++ b/python/resdata/grid/rd_region.py @@ -1173,9 +1173,7 @@ def contains_active(self, active_index): return self._contains_active(active_index) def kw_index_list(self, rd_kw, force_active): - c_ptr = self._get_kw_index_list(rd_kw, force_active) - index_list = IntVector.createCReference(c_ptr, self) - return index_list + return self._get_kw_index_list(rd_kw, force_active) @property def name(self): diff --git a/python/tests/rd_tests/test_rd_kw.py b/python/tests/rd_tests/test_rd_kw.py index 7adb45cef..e857f262f 100644 --- a/python/tests/rd_tests/test_rd_kw.py +++ b/python/tests/rd_tests/test_rd_kw.py @@ -8,6 +8,7 @@ import random from resdata import ResDataType, ResdataTypeEnum, FileMode +from resdata.grid import ResdataRegion, GridGenerator from resdata.resfile import ResdataKW, ResdataFile, FortIO, openFortIO @@ -645,6 +646,39 @@ def test_imul(): kw2 *= "a" +def test_assign(): + kw1 = ResdataKW("KW1", 5, ResDataType.RD_INT) + kw2 = ResdataKW("KW2", 6, ResDataType.RD_INT) + kw3 = ResdataKW("KW3", 5, ResDataType.RD_FLOAT) + for i in range(len(kw1)): + kw1[i] = 1 + with pytest.raises(TypeError, match="Type / size mismatch"): + kw2.assign(kw1) + with pytest.raises(TypeError, match="Type / size mismatch"): + kw3.assign(kw1) + with pytest.raises(TypeError, match="Type mismatch"): + kw2.assign("a") + with pytest.raises(TypeError, match="Only muliplication with scalar supported"): + kw3.assign("a") + + +def test_apply(): + kw1 = ResdataKW("KW1", 5, ResDataType.RD_INT) + kw2 = ResdataKW("KW2", 6, ResDataType.RD_INT) + kw3 = ResdataKW("KW3", 5, ResDataType.RD_FLOAT) + kw1.assign(1) + kw1.apply(lambda x: x + 1) + assert list(kw1) == [2] * 5 + kw2.assign(5) + kw2.apply(lambda x, y: x + y, arg=5) + assert list(kw2) == [10] * 6 + grid = GridGenerator.create_rectangular(dims=(5, 1, 1), dV=(1, 1, 1)) + region = ResdataRegion(grid, True) + kw3.assign(3.0) + kw3.apply(lambda x: x + 1.0, mask=region) + assert list(kw3) == [4.0] * 5 + + def test_get_ptr_data(): assert ResdataKW("KW1", 10, ResDataType.RD_INT).get_data_ptr() assert ResdataKW("KW1", 10, ResDataType.RD_FLOAT).get_data_ptr() diff --git a/python/tests/rd_tests/test_region.py b/python/tests/rd_tests/test_region.py index 840c70e74..46d0a1ec5 100644 --- a/python/tests/rd_tests/test_region.py +++ b/python/tests/rd_tests/test_region.py @@ -517,3 +517,10 @@ def test_get_set_name(full_region): def test_contains_active(full_region): assert full_region.contains_active(0) + + +def test_kw_index_list(grid, full_region): + kw_int = ResdataKW("INT", grid.get_global_size(), ResDataType.RD_INT) + kw_float = ResdataKW("FLOAT", grid.get_global_size(), ResDataType.RD_FLOAT) + full_region.kw_index_list(kw_int, False) + full_region.kw_index_list(kw_float, True)