Skip to content

Commit

Permalink
add LinearExpressionRolling class
Browse files Browse the repository at this point in the history
adjust rename function
  • Loading branch information
FabianHofmann committed Nov 11, 2022
1 parent cfb4427 commit c07cda8
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 6 deletions.
70 changes: 66 additions & 4 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _expr_unwrap(maybe_expr):


@dataclass
@forward_as_properties(groupby=["dims", "groups"])
class LinearExpressionGroupby:
"""
GroupBy object specialized to grouping LinearExpression objects.
Expand All @@ -71,7 +72,30 @@ def func(ds):

return self.map(func, **kwargs)

def roll(self, **kwargs):
return self.map(Dataset.roll, **kwargs)


@dataclass
@forward_as_properties(rolling=["center", "dim", "obj", "rollings", "window"])
class LinearExpressionRolling:
"""
GroupBy object specialized to grouping LinearExpression objects.
"""

rolling: xr.core.rolling.DataArrayRolling

def sum(self, **kwargs):
ds = (
self.rolling.construct("_rolling_term", keep_attrs=True)
.rename(_term="_stacked_term")
.stack(_term=["_stacked_term", "_rolling_term"])
.reset_index("_term", drop=True)
)
return LinearExpression(ds)


@dataclass(repr=False)
@forward_as_properties(data=["attrs", "coords", "dims", "indexes"])
class LinearExpression:
"""
Expand Down Expand Up @@ -453,7 +477,7 @@ def groupby(
group,
squeeze: "bool" = True,
restore_coord_dims: "bool" = None,
):
) -> LinearExpressionGroupby:
"""
Returns a LinearExpressionGroupBy object for performing grouped
operations.
Expand All @@ -479,7 +503,7 @@ def groupby(
A `LinearExpressionGroupBy` containing the xarray groups and ensuring
the correct return type.
"""
ds = self.to_dataset()
ds = self.data
groups = ds.groupby(
group=group, squeeze=squeeze, restore_coord_dims=restore_coord_dims
)
Expand Down Expand Up @@ -513,6 +537,43 @@ def func(ds):

return self.__class__(groups.map(func)) # .reset_index('_term')

def rolling(
self,
dim: "Mapping[Any, int]" = None,
min_periods: "int" = None,
center: "bool | Mapping[Any, bool]" = False,
**window_kwargs: "int",
) -> LinearExpressionRolling:
"""
Rolling window object.
Docstring and arguments are borrowed from `xarray.Dataset.rolling`
Parameters
----------
dim : dict, optional
Mapping from the dimension name to create the rolling iterator
along (e.g. `time`) to its moving window size.
min_periods : int, default: None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : bool or mapping, default: False
Set the labels at the center of the window.
**window_kwargs : optional
The keyword arguments form of ``dim``.
One of dim or window_kwargs must be provided.
Returns
-------
linopy.expression.LinearExpressionRolling
"""
ds = self.data
rolling = ds.rolling(
dim=dim, min_periods=min_periods, center=center, **window_kwargs
)
return LinearExpressionRolling(rolling)

def rolling_sum(self, **kwargs):
"""
Rolling sum of the linear expression.
Expand Down Expand Up @@ -621,6 +682,7 @@ def sanitize(self):
def equals(self, other: "LinearExpression"):
return self.data.equals(_expr_unwrap(other))

# TODO: make this return a LinearExpression (needs refactoring of __init__)
def rename(self, name_dict=None, **names) -> Dataset:
return self.data.rename(name_dict, **names)

Expand Down Expand Up @@ -658,9 +720,9 @@ def __iter__(self):

reindex = exprwrap(Dataset.reindex, fill_value=fill_value)

roll = exprwrap(Dataset.roll)
rename_dims = exprwrap(Dataset.rename_dims)

rolling = exprwrap(Dataset.rolling)
roll = exprwrap(Dataset.roll)


def _pd_series_wo_index_name(ds):
Expand Down
2 changes: 1 addition & 1 deletion linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def add_constraints(
), "Dimensions of mask not a subset of resulting labels dimensions."
labels = labels.where(mask, -1)

lhs = lhs.rename({"_term": f"{name}_term"})
lhs = lhs.data.rename({"_term": f"{name}_term"})

if self.chunk:
lhs = lhs.chunk(self.chunk)
Expand Down
15 changes: 15 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ def test_groupby():
assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all()
assert grouped.data._term.size == 12

expr = 1 * v
groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords)
grouped = expr.groupby(groups).roll(dim_2=1)
assert grouped.nterm == 1
assert grouped.vars[0].item() == 19


def test_groupby_variable():
groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords)
Expand Down Expand Up @@ -328,6 +334,15 @@ def test_groupby_sum_variable():
assert grouped.data._term.size == 10


def test_rolling():
expr = 1 * v
rolled = expr.rolling(dim_2=2).sum()
assert rolled.nterm == 2

rolled = expr.rolling(dim_2=3).sum()
assert rolled.nterm == 3


def test_rolling_sum():
rolled = v.to_linexpr().rolling_sum(dim_2=2)
assert rolled.nterm == 2
Expand Down
2 changes: 1 addition & 1 deletion test/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_variable_type_preservation():
x = m.add_variables(coords=[range(10)])

assert isinstance(x.bfill("dim_0"), linopy.variables.Variable)
assert isinstance(x.broadcast_like(x.to_array()), linopy.variables.Variable)
assert isinstance(x.broadcast_like(x.labels), linopy.variables.Variable)
assert isinstance(x.ffill("dim_0"), linopy.variables.Variable)
assert isinstance(x.fillna(-1), linopy.variables.Variable)

Expand Down

0 comments on commit c07cda8

Please sign in to comment.