Skip to content

Commit

Permalink
Fix issue with kw_index_list and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj committed Jan 6, 2025
1 parent 00792c6 commit e831f0c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
4 changes: 1 addition & 3 deletions python/resdata/grid/rd_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,9 +1173,7 @@ def contains_active(self, active_index):
return self._contains_active(active_index)

def kw_index_list(self, rd_kw, force_active):
c_ptr = self._get_kw_index_list(rd_kw, force_active)
index_list = IntVector.createCReference(c_ptr, self)
return index_list
return self._get_kw_index_list(rd_kw, force_active)

@property
def name(self):
Expand Down
34 changes: 34 additions & 0 deletions python/tests/rd_tests/test_rd_kw.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import random

from resdata import ResDataType, ResdataTypeEnum, FileMode
from resdata.grid import ResdataRegion, GridGenerator
from resdata.resfile import ResdataKW, ResdataFile, FortIO, openFortIO


Expand Down Expand Up @@ -645,6 +646,39 @@ def test_imul():
kw2 *= "a"


def test_assign():
kw1 = ResdataKW("KW1", 5, ResDataType.RD_INT)
kw2 = ResdataKW("KW2", 6, ResDataType.RD_INT)
kw3 = ResdataKW("KW3", 5, ResDataType.RD_FLOAT)
for i in range(len(kw1)):
kw1[i] = 1
with pytest.raises(TypeError, match="Type / size mismatch"):
kw2.assign(kw1)
with pytest.raises(TypeError, match="Type / size mismatch"):
kw3.assign(kw1)
with pytest.raises(TypeError, match="Type mismatch"):
kw2.assign("a")
with pytest.raises(TypeError, match="Only muliplication with scalar supported"):
kw3.assign("a")


def test_apply():
kw1 = ResdataKW("KW1", 5, ResDataType.RD_INT)
kw2 = ResdataKW("KW2", 6, ResDataType.RD_INT)
kw3 = ResdataKW("KW3", 5, ResDataType.RD_FLOAT)
kw1.assign(1)
kw1.apply(lambda x: x + 1)
assert list(kw1) == [2] * 5
kw2.assign(5)
kw2.apply(lambda x, y: x + y, arg=5)
assert list(kw2) == [10] * 6
grid = GridGenerator.create_rectangular(dims=(5, 1, 1), dV=(1, 1, 1))
region = ResdataRegion(grid, True)
kw3.assign(3.0)
kw3.apply(lambda x: x + 1.0, mask=region)
assert list(kw3) == [4.0]*5


def test_get_ptr_data():
assert ResdataKW("KW1", 10, ResDataType.RD_INT).get_data_ptr()
assert ResdataKW("KW1", 10, ResDataType.RD_FLOAT).get_data_ptr()
Expand Down

0 comments on commit e831f0c

Please sign in to comment.