Skip to content

Commit

Permalink
add LinearExpressionRolling class
Browse files Browse the repository at this point in the history
adjust rename function
  • Loading branch information
FabianHofmann committed Nov 14, 2022
1 parent cfb4427 commit daa118f
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 9 deletions.
70 changes: 66 additions & 4 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _expr_unwrap(maybe_expr):


@dataclass
@forward_as_properties(groupby=["dims", "groups"])
class LinearExpressionGroupby:
"""
GroupBy object specialized to grouping LinearExpression objects.
Expand All @@ -71,7 +72,30 @@ def func(ds):

return self.map(func, **kwargs)

def roll(self, **kwargs):
return self.map(Dataset.roll, **kwargs)


@dataclass
@forward_as_properties(rolling=["center", "dim", "obj", "rollings", "window"])
class LinearExpressionRolling:
"""
GroupBy object specialized to grouping LinearExpression objects.
"""

rolling: xr.core.rolling.DataArrayRolling

def sum(self, **kwargs):
ds = (
self.rolling.construct("_rolling_term", keep_attrs=True)
.rename(_term="_stacked_term")
.stack(_term=["_stacked_term", "_rolling_term"])
.reset_index("_term", drop=True)
)
return LinearExpression(ds)


@dataclass(repr=False)
@forward_as_properties(data=["attrs", "coords", "dims", "indexes"])
class LinearExpression:
"""
Expand Down Expand Up @@ -453,7 +477,7 @@ def groupby(
group,
squeeze: "bool" = True,
restore_coord_dims: "bool" = None,
):
) -> LinearExpressionGroupby:
"""
Returns a LinearExpressionGroupBy object for performing grouped
operations.
Expand All @@ -479,7 +503,7 @@ def groupby(
A `LinearExpressionGroupBy` containing the xarray groups and ensuring
the correct return type.
"""
ds = self.to_dataset()
ds = self.data
groups = ds.groupby(
group=group, squeeze=squeeze, restore_coord_dims=restore_coord_dims
)
Expand Down Expand Up @@ -513,6 +537,43 @@ def func(ds):

return self.__class__(groups.map(func)) # .reset_index('_term')

def rolling(
self,
dim: "Mapping[Any, int]" = None,
min_periods: "int" = None,
center: "bool | Mapping[Any, bool]" = False,
**window_kwargs: "int",
) -> LinearExpressionRolling:
"""
Rolling window object.
Docstring and arguments are borrowed from `xarray.Dataset.rolling`
Parameters
----------
dim : dict, optional
Mapping from the dimension name to create the rolling iterator
along (e.g. `time`) to its moving window size.
min_periods : int, default: None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : bool or mapping, default: False
Set the labels at the center of the window.
**window_kwargs : optional
The keyword arguments form of ``dim``.
One of dim or window_kwargs must be provided.
Returns
-------
linopy.expression.LinearExpressionRolling
"""
ds = self.data
rolling = ds.rolling(
dim=dim, min_periods=min_periods, center=center, **window_kwargs
)
return LinearExpressionRolling(rolling)

def rolling_sum(self, **kwargs):
"""
Rolling sum of the linear expression.
Expand Down Expand Up @@ -621,6 +682,7 @@ def sanitize(self):
def equals(self, other: "LinearExpression"):
return self.data.equals(_expr_unwrap(other))

# TODO: make this return a LinearExpression (needs refactoring of __init__)
def rename(self, name_dict=None, **names) -> Dataset:
return self.data.rename(name_dict, **names)

Expand Down Expand Up @@ -658,9 +720,9 @@ def __iter__(self):

reindex = exprwrap(Dataset.reindex, fill_value=fill_value)

roll = exprwrap(Dataset.roll)
rename_dims = exprwrap(Dataset.rename_dims)

rolling = exprwrap(Dataset.rolling)
roll = exprwrap(Dataset.roll)


def _pd_series_wo_index_name(ds):
Expand Down
2 changes: 1 addition & 1 deletion linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def add_constraints(
), "Dimensions of mask not a subset of resulting labels dimensions."
labels = labels.where(mask, -1)

lhs = lhs.rename({"_term": f"{name}_term"})
lhs = lhs.data.rename({"_term": f"{name}_term"})

if self.chunk:
lhs = lhs.chunk(self.chunk)
Expand Down
39 changes: 36 additions & 3 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import re
from dataclasses import dataclass, field
from distutils.log import warn
from typing import Any, Sequence, Union
from typing import Any, Mapping, Sequence, Union

import dask
import numpy as np
Expand Down Expand Up @@ -284,6 +284,41 @@ def groupby_sum(self, group):
"""
return self.to_linexpr().groupby_sum(group)

def rolling(
self,
dim: "Mapping[Any, int]" = None,
min_periods: "int" = None,
center: "bool | Mapping[Any, bool]" = False,
**window_kwargs: "int",
) -> "expressions.LinearExpressionRolling":
"""
Rolling window object.
Docstring and arguments are borrowed from `xarray.Dataset.rolling`
Parameters
----------
dim : dict, optional
Mapping from the dimension name to create the rolling iterator
along (e.g. `time`) to its moving window size.
min_periods : int, default: None
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : bool or mapping, default: False
Set the labels at the center of the window.
**window_kwargs : optional
The keyword arguments form of ``dim``.
One of dim or window_kwargs must be provided.
Returns
-------
linopy.expression.LinearExpressionRolling
"""
return self.to_linexpr().rolling(
dim=dim, min_periods=min_periods, center=center, **window_kwargs
)

def rolling_sum(self, **kwargs):
"""
Rolling sum of variable.
Expand Down Expand Up @@ -452,8 +487,6 @@ def equals(self, other):

roll = varwrap(DataArray.roll)

rolling = varwrap(DataArray.rolling)


class _LocIndexer:
__slots__ = ("variable",)
Expand Down
20 changes: 20 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ def test_groupby():
assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all()
assert grouped.data._term.size == 12

expr = 1 * v
groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords)
grouped = expr.groupby(groups).roll(dim_2=1)
assert grouped.nterm == 1
assert grouped.vars[0].item() == 19


def test_groupby_variable():
groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords)
Expand Down Expand Up @@ -328,6 +334,15 @@ def test_groupby_sum_variable():
assert grouped.data._term.size == 10


def test_rolling():
expr = 1 * v
rolled = expr.rolling(dim_2=2).sum()
assert rolled.nterm == 2

rolled = expr.rolling(dim_2=3).sum()
assert rolled.nterm == 3


def test_rolling_sum():
rolled = v.to_linexpr().rolling_sum(dim_2=2)
assert rolled.nterm == 2
Expand All @@ -338,6 +353,11 @@ def test_rolling_sum():
assert rolled.nterm == 6


def test_rolling_variable():
rolled = v.rolling(dim_2=2).sum()
assert rolled.nterm == 2


def test_rolling_sum_variable():
rolled = v.rolling_sum(dim_2=2)
assert rolled.nterm == 2
Expand Down
2 changes: 1 addition & 1 deletion test/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_variable_type_preservation():
x = m.add_variables(coords=[range(10)])

assert isinstance(x.bfill("dim_0"), linopy.variables.Variable)
assert isinstance(x.broadcast_like(x.to_array()), linopy.variables.Variable)
assert isinstance(x.broadcast_like(x.labels), linopy.variables.Variable)
assert isinstance(x.ffill("dim_0"), linopy.variables.Variable)
assert isinstance(x.fillna(-1), linopy.variables.Variable)

Expand Down

0 comments on commit daa118f

Please sign in to comment.