Skip to content

Commit

Permalink
fix: make sure last slice is included in iterate slices (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianHofmann authored Dec 23, 2024
1 parent 3fb5b6c commit b558f55
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
2 changes: 2 additions & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Release Notes
Upcoming Version
----------------


* IMPORTANT BUGFIX: The last slice of constraints was not correctly written to LP files in case the constraint size was not a multiple of the slice size. This is fixed now.
* Solution files that following a different naming scheme of variables and constraints using more than on initial letter in the prefix (e.g. `col123`, `row456`) are now supported.

Version 0.4.3
Expand Down
6 changes: 3 additions & 3 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def iterate_slices(
return

# number of slices
n_slices = max(size // slice_size, 1)
n_slices = max((size + slice_size - 1) // slice_size, 1)

# leading dimension (the dimension with the largest size)
sizes = {dim: ds.sizes[dim] for dim in slice_dims}
Expand All @@ -533,12 +533,12 @@ def iterate_slices(
if size_of_leading_dim < n_slices:
n_slices = size_of_leading_dim

chunk_size = ds.sizes[leading_dim] // n_slices
chunk_size = (ds.sizes[leading_dim] + n_slices - 1) // n_slices

# Iterate over the Cartesian product of slice indices
for i in range(n_slices):
start = i * chunk_size
end = start + chunk_size
end = min(start + chunk_size, size_of_leading_dim)
slice_dict = {leading_dim: slice(start, end)}
yield ds.isel(slice_dict)

Expand Down
30 changes: 30 additions & 0 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,20 @@ def test_iterate_slices_slice_size_none():
assert ds.equals(s)


def test_iterate_slices_includes_last_slice():
ds = xr.Dataset(
{"var": (("x"), np.random.rand(10))}, # noqa: NPY002
coords={"x": np.arange(10)},
)
slices = list(iterate_slices(ds, slice_size=3, slice_dims=["x"]))
assert len(slices) == 4 # 10 slices for dimension 'x' with size 10
total_elements = sum(s.sizes["x"] for s in slices)
assert total_elements == ds.sizes["x"] # Ensure all elements are included
for s in slices:
assert isinstance(s, xr.Dataset)
assert set(s.dims) == set(ds.dims)


def test_iterate_slices_empty_slice_dims():
ds = xr.Dataset(
{"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002
Expand All @@ -542,6 +556,22 @@ def test_iterate_slices_invalid_slice_dims():
list(iterate_slices(ds, slice_size=50, slice_dims=["z"]))


def test_iterate_slices_empty_dataset():
ds = xr.Dataset(
{"var": (("x", "y"), np.array([]).reshape(0, 0))}, coords={"x": [], "y": []}
)
slices = list(iterate_slices(ds, slice_size=10, slice_dims=["x"]))
assert len(slices) == 1
assert ds.equals(slices[0])


def test_iterate_slices_single_element():
ds = xr.Dataset({"var": (("x", "y"), np.array([[1]]))}, coords={"x": [0], "y": [0]})
slices = list(iterate_slices(ds, slice_size=1, slice_dims=["x"]))
assert len(slices) == 1
assert ds.equals(slices[0])


def test_get_dims_with_index_levels():
# Create test data

Expand Down

0 comments on commit b558f55

Please sign in to comment.