From daa118fa16fa32095b4f1722ff1fbae5edfccb3c Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 10 Nov 2022 14:51:02 +0100 Subject: [PATCH] add LinearExpressionRolling class adjust rename function --- linopy/expressions.py | 70 ++++++++++++++++++++++++++++++++-- linopy/model.py | 2 +- linopy/variables.py | 39 +++++++++++++++++-- test/test_linear_expression.py | 20 ++++++++++ test/test_variable.py | 2 +- 5 files changed, 124 insertions(+), 9 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 05af5e43..d6729309 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -51,6 +51,7 @@ def _expr_unwrap(maybe_expr): @dataclass +@forward_as_properties(groupby=["dims", "groups"]) class LinearExpressionGroupby: """ GroupBy object specialized to grouping LinearExpression objects. @@ -71,7 +72,30 @@ def func(ds): return self.map(func, **kwargs) + def roll(self, **kwargs): + return self.map(Dataset.roll, **kwargs) + +@dataclass +@forward_as_properties(rolling=["center", "dim", "obj", "rollings", "window"]) +class LinearExpressionRolling: + """ + GroupBy object specialized to grouping LinearExpression objects. + """ + + rolling: xr.core.rolling.DataArrayRolling + + def sum(self, **kwargs): + ds = ( + self.rolling.construct("_rolling_term", keep_attrs=True) + .rename(_term="_stacked_term") + .stack(_term=["_stacked_term", "_rolling_term"]) + .reset_index("_term", drop=True) + ) + return LinearExpression(ds) + + +@dataclass(repr=False) @forward_as_properties(data=["attrs", "coords", "dims", "indexes"]) class LinearExpression: """ @@ -453,7 +477,7 @@ def groupby( group, squeeze: "bool" = True, restore_coord_dims: "bool" = None, - ): + ) -> LinearExpressionGroupby: """ Returns a LinearExpressionGroupBy object for performing grouped operations. @@ -479,7 +503,7 @@ def groupby( A `LinearExpressionGroupBy` containing the xarray groups and ensuring the correct return type. """ - ds = self.to_dataset() + ds = self.data groups = ds.groupby( group=group, squeeze=squeeze, restore_coord_dims=restore_coord_dims ) @@ -513,6 +537,43 @@ def func(ds): return self.__class__(groups.map(func)) # .reset_index('_term') + def rolling( + self, + dim: "Mapping[Any, int]" = None, + min_periods: "int" = None, + center: "bool | Mapping[Any, bool]" = False, + **window_kwargs: "int", + ) -> LinearExpressionRolling: + """ + Rolling window object. + + Docstring and arguments are borrowed from `xarray.Dataset.rolling` + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. + min_periods : int, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or mapping, default: False + Set the labels at the center of the window. + **window_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or window_kwargs must be provided. + + Returns + ------- + linopy.expression.LinearExpressionRolling + """ + ds = self.data + rolling = ds.rolling( + dim=dim, min_periods=min_periods, center=center, **window_kwargs + ) + return LinearExpressionRolling(rolling) + def rolling_sum(self, **kwargs): """ Rolling sum of the linear expression. @@ -621,6 +682,7 @@ def sanitize(self): def equals(self, other: "LinearExpression"): return self.data.equals(_expr_unwrap(other)) + # TODO: make this return a LinearExpression (needs refactoring of __init__) def rename(self, name_dict=None, **names) -> Dataset: return self.data.rename(name_dict, **names) @@ -658,9 +720,9 @@ def __iter__(self): reindex = exprwrap(Dataset.reindex, fill_value=fill_value) - roll = exprwrap(Dataset.roll) + rename_dims = exprwrap(Dataset.rename_dims) - rolling = exprwrap(Dataset.rolling) + roll = exprwrap(Dataset.roll) def _pd_series_wo_index_name(ds): diff --git a/linopy/model.py b/linopy/model.py index 1a26052b..4f82954b 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -565,7 +565,7 @@ def add_constraints( ), "Dimensions of mask not a subset of resulting labels dimensions." labels = labels.where(mask, -1) - lhs = lhs.rename({"_term": f"{name}_term"}) + lhs = lhs.data.rename({"_term": f"{name}_term"}) if self.chunk: lhs = lhs.chunk(self.chunk) diff --git a/linopy/variables.py b/linopy/variables.py index a1cc2fb2..96cb9d81 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -9,7 +9,7 @@ import re from dataclasses import dataclass, field from distutils.log import warn -from typing import Any, Sequence, Union +from typing import Any, Mapping, Sequence, Union import dask import numpy as np @@ -284,6 +284,41 @@ def groupby_sum(self, group): """ return self.to_linexpr().groupby_sum(group) + def rolling( + self, + dim: "Mapping[Any, int]" = None, + min_periods: "int" = None, + center: "bool | Mapping[Any, bool]" = False, + **window_kwargs: "int", + ) -> "expressions.LinearExpressionRolling": + """ + Rolling window object. + + Docstring and arguments are borrowed from `xarray.Dataset.rolling` + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. + min_periods : int, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or mapping, default: False + Set the labels at the center of the window. + **window_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or window_kwargs must be provided. + + Returns + ------- + linopy.expression.LinearExpressionRolling + """ + return self.to_linexpr().rolling( + dim=dim, min_periods=min_periods, center=center, **window_kwargs + ) + def rolling_sum(self, **kwargs): """ Rolling sum of variable. @@ -452,8 +487,6 @@ def equals(self, other): roll = varwrap(DataArray.roll) - rolling = varwrap(DataArray.rolling) - class _LocIndexer: __slots__ = ("variable",) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 5082cfd9..585dfea9 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -292,6 +292,12 @@ def test_groupby(): assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() assert grouped.data._term.size == 12 + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).roll(dim_2=1) + assert grouped.nterm == 1 + assert grouped.vars[0].item() == 19 + def test_groupby_variable(): groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) @@ -328,6 +334,15 @@ def test_groupby_sum_variable(): assert grouped.data._term.size == 10 +def test_rolling(): + expr = 1 * v + rolled = expr.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + rolled = expr.rolling(dim_2=3).sum() + assert rolled.nterm == 3 + + def test_rolling_sum(): rolled = v.to_linexpr().rolling_sum(dim_2=2) assert rolled.nterm == 2 @@ -338,6 +353,11 @@ def test_rolling_sum(): assert rolled.nterm == 6 +def test_rolling_variable(): + rolled = v.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + def test_rolling_sum_variable(): rolled = v.rolling_sum(dim_2=2) assert rolled.nterm == 2 diff --git a/test/test_variable.py b/test/test_variable.py index 424e7ef4..16db37a8 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -130,7 +130,7 @@ def test_variable_type_preservation(): x = m.add_variables(coords=[range(10)]) assert isinstance(x.bfill("dim_0"), linopy.variables.Variable) - assert isinstance(x.broadcast_like(x.to_array()), linopy.variables.Variable) + assert isinstance(x.broadcast_like(x.labels), linopy.variables.Variable) assert isinstance(x.ffill("dim_0"), linopy.variables.Variable) assert isinstance(x.fillna(-1), linopy.variables.Variable)