diff --git a/doc/release_notes.rst b/doc/release_notes.rst index f4444d0c..bd3c8b3f 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -1,8 +1,11 @@ Release Notes ============= -.. Upcoming Release -.. ---------------- +Upcoming Release +---------------- + +* The internal data structure of linopy classes was updated to a safer design. Instead of being defined as inherited xarray classes, the classes `Variable`, `LinearExpression` and `Constraint` are now dataclasses containing the xarray objects in the data field. This allows the package to have a more flexible function design and a reduced set of wrapped functions that are sensible to use in the optimization context. +* The classes `Variable` and `LinearExpression` have new functions `groupby` and `rolling` imitating the corresponding xarray functions but with safe type inheritance and application of appended operations. Version 0.0.15 diff --git a/linopy/__init__.py b/linopy/__init__.py index 869a86a0..8fb0cdd5 100755 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -6,6 +6,9 @@ @author: fabulous """ +# Note: For intercepting multiplications between xarray dataarrays, Variables and Expressions +# we need to extend their __mul__ functions with a quick special case +import linopy.monkey_patch_xarray from linopy import model, remote from linopy.expressions import merge from linopy.io import read_netcdf diff --git a/linopy/common.py b/linopy/common.py index 8dc4d887..361c3098 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -6,7 +6,7 @@ This module contains commonly used functions. 
""" -from functools import wraps +from functools import partialmethod, update_wrapper, wraps import numpy as np from xarray import DataArray, apply_ufunc, merge @@ -98,3 +98,20 @@ def wrapper(self, arg): return func(self, arg) return wrapper + + +def forward_as_properties(**routes): + def add_accessor(cls, item, attr): + @property + def get(self): + return getattr(getattr(self, item), attr) + + setattr(cls, attr, get) + + def deco(cls): + for item, attrs in routes.items(): + for attr in attrs: + add_accessor(cls, item, attr) + return cls + + return deco diff --git a/linopy/constraints.py b/linopy/constraints.py index e9c0d609..19013a1b 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd import xarray as xr +from deprecation import deprecated from numpy import arange, array from scipy.sparse import coo_matrix from xarray import DataArray, Dataset @@ -21,51 +22,29 @@ from linopy import expressions, variables from linopy.common import ( _merge_inplace, - has_assigned_model, has_optimized_model, is_constant, replace_by_map, ) -class Constraint(DataArray): +@dataclass(repr=False) +class Constraint: """ - Constraint container for storing constraint labels. + Projection to a single constraint in a model. The Constraint class is a subclass of xr.DataArray hence most xarray functions can be applied to it. """ - __slots__ = ("_cache", "_coords", "_indexes", "_name", "_variable", "model") - - def __init__(self, *args, **kwargs): - - # workaround until https://github.com/pydata/xarray/pull/5984 is merged - if isinstance(args[0], DataArray): - da = args[0] - args = (da.data, da.coords) - kwargs.update({"attrs": da.attrs, "name": da.name}) - - self.model = kwargs.pop("model", None) - super().__init__(*args, **kwargs) - assert self.name is not None, "Constraint data does not have a name." - - # We have to set the _reduce_method to None, in order to overwrite basic - # reduction functions as `sum`. 
There might be a better solution (?). - _reduce_method = None - - # Disable array function, only function defined below are supported - # and set priority higher than pandas/xarray/numpy - __array_ufunc__ = None - __array_priority__ = 10000 + name: str + model: Any def __repr__(self): """ Get the string representation of the constraints. """ - data_string = ( - "Constraint labels:\n" + self.to_array().__repr__().split("\n", 1)[1] - ) + data_string = "Constraint labels:\n" + self.labels.__repr__().split("\n", 1)[1] extend_line = "-" * len(self.name) return ( f"Constraint '{self.name}':\n" @@ -78,18 +57,26 @@ def _repr_html_(self): Get the html representation of the variables. """ # return self.__repr__() - data_string = self.to_array()._repr_html_() + data_string = self.labels._repr_html_() data_string = data_string.replace("xarray.DataArray", "linopy.Constraint") return data_string + @deprecated(details="Use the `labels` property instead of `to_array`") def to_array(self): """ Convert the variable array to a xarray.DataArray. """ - return DataArray(self) + return self.labels + + @property + def labels(self): + return self.model.constraints.labels[self.name] + + @labels.setter + def labels(self, value): + raise RuntimeError("Labels are read-only") @property - @has_assigned_model def coeffs(self): """ Get the left-hand-side coefficients of the constraint. @@ -100,7 +87,6 @@ def coeffs(self): return self.model.constraints.coeffs[self.name] @coeffs.setter - @has_assigned_model def coeffs(self, value): term_dim = self.name + "_term" value = DataArray(value).broadcast_like(self.vars, exclude=[term_dim]) @@ -115,7 +101,6 @@ def coeffs(self, value): self.model.constraints.coeffs = coeffs @property - @has_assigned_model def vars(self): """ Get the left-hand-side variables of the constraint. 
@@ -126,7 +111,6 @@ def vars(self): return self.model.constraints.vars[self.name] @vars.setter - @has_assigned_model def vars(self, value): term_dim = self.name + "_term" value = DataArray(value).broadcast_like(self.coeffs, exclude=[term_dim]) @@ -141,7 +125,6 @@ def vars(self, value): self.model.constraints.vars = vars @property - @has_assigned_model def lhs(self): """ Get the left-hand-side linear expression of the constraint. @@ -155,7 +138,6 @@ def lhs(self): return expressions.LinearExpression(Dataset({"coeffs": coeffs, "vars": vars})) @lhs.setter - @has_assigned_model def lhs(self, value): if not isinstance(value, expressions.LinearExpression): raise TypeError("Assigned lhs must be a LinearExpression.") @@ -165,7 +147,6 @@ def lhs(self, value): self.vars = value.vars @property - @has_assigned_model def sign(self): """ Get the signs of the constraint. @@ -176,16 +157,14 @@ def sign(self): return self.model.constraints.sign[self.name] @sign.setter - @has_assigned_model @is_constant def sign(self, value): - value = DataArray(value).broadcast_like(self) + value = DataArray(value).broadcast_like(self.sign) if (value == "==").any(): raise ValueError('Sign "==" not supported, use "=" instead.') self.model.constraints.sign[self.name] = value @property - @has_assigned_model def rhs(self): """ Get the right hand side constants of the constraint. 
@@ -196,12 +175,11 @@ def rhs(self): return self.model.constraints.rhs[self.name] @rhs.setter - @has_assigned_model @is_constant def rhs(self, value): if isinstance(value, (variables.Variable, expressions.LinearExpression)): raise TypeError(f"Assigned rhs must be a constant, got {type(value)}).") - value = DataArray(value).broadcast_like(self) + value = DataArray(value).broadcast_like(self.rhs) self.model.constraints.rhs[self.name] = value @property @@ -219,6 +197,10 @@ def dual(self): ) return self.model.dual[self.name] + @property + def shape(self): + return self.labels.shape + @dataclass(repr=False) class Constraints: @@ -268,7 +250,7 @@ def __getitem__( self, names: Union[str, Sequence[str]] ) -> Union[Constraint, "Constraints"]: if isinstance(names, str): - return Constraint(self.labels[names], model=self.model) + return Constraint(names, model=self.model) return self.__class__( self.labels[names], @@ -581,7 +563,8 @@ def __init__(self, lhs, sign, rhs): """ if isinstance(rhs, (variables.Variable, expressions.LinearExpression)): raise TypeError(f"Assigned rhs must be a constant, got {type(rhs)}).") - self._lhs, self._rhs = xr.align(lhs, DataArray(rhs)) + lhs, self._rhs = xr.align(lhs.data, DataArray(rhs)) + self._lhs = expressions.LinearExpression(lhs) self._sign = DataArray(sign) @property @@ -696,7 +679,7 @@ class AnonymousScalarConstraint: (rhs) for exactly one constraint. 
""" - lhs: expressions.ScalarLinearExpression + lhs: "expressions.ScalarLinearExpression" sign: str rhs: float diff --git a/linopy/expressions.py b/linopy/expressions.py index bb40e12e..cfc76e9f 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -10,27 +10,30 @@ import logging from dataclasses import dataclass from itertools import product, zip_longest -from warnings import warn +from typing import Any, Mapping, Union import numpy as np import pandas as pd import xarray as xr +import xarray.core.groupby +import xarray.core.rolling +from deprecation import deprecated from numpy import array, nan from xarray import DataArray, Dataset from xarray.core.dataarray import DataArrayCoordinates -from xarray.core.groupby import _maybe_reorder, peek_at from linopy import constraints, variables -from linopy.common import as_dataarray +from linopy.common import as_dataarray, forward_as_properties def exprwrap(method, *default_args, **new_default_kwargs): @functools.wraps(method) - def _exprwrap(obj, *args, **kwargs): + def _exprwrap(expr, *args, **kwargs): for k, v in new_default_kwargs.items(): kwargs.setdefault(k, v) - obj = Dataset(obj) - return LinearExpression(method(obj, *default_args, *args, **kwargs)) + return expr.__class__( + method(_expr_unwrap(expr), *default_args, *args, **kwargs) + ) _exprwrap.__doc__ = f"Wrapper for the xarray {method} function for linopy.Variable" if new_default_kwargs: @@ -39,10 +42,64 @@ def _exprwrap(obj, *args, **kwargs): return _exprwrap +def _expr_unwrap(maybe_expr): + if isinstance(maybe_expr, LinearExpression): + return maybe_expr.data + + return maybe_expr + + logger = logging.getLogger(__name__) -class LinearExpression(Dataset): +@dataclass +@forward_as_properties(groupby=["dims", "groups"]) +class LinearExpressionGroupby: + """ + GroupBy object specialized to grouping LinearExpression objects. 
+ """ + + groupby: xr.core.groupby.DatasetGroupBy + + def map(self, func, shortcut=False, args=(), **kwargs): + return LinearExpression( + self.groupby.map(func, shortcut=shortcut, args=args, **kwargs) + ) + + def sum(self, **kwargs): + def func(ds): + ds = LinearExpression._sum(ds, self.groupby._group_dim) + ds = ds.assign_coords(_term=np.arange(len(ds._term))) + return ds + + return self.map(func, **kwargs) + + def roll(self, **kwargs): + return self.map(Dataset.roll, **kwargs) + + +@dataclass +@forward_as_properties(rolling=["center", "dim", "obj", "rollings", "window"]) +class LinearExpressionRolling: + """ + GroupBy object specialized to grouping LinearExpression objects. + """ + + rolling: xr.core.rolling.DataArrayRolling + + def sum(self, **kwargs): + ds = ( + self.rolling.construct("_rolling_term", keep_attrs=True) + .rename(_term="_stacked_term") + .stack(_term=["_stacked_term", "_rolling_term"]) + .reset_index("_term", drop=True) + ) + return LinearExpression(ds) + + +@dataclass(repr=False) +@forward_as_properties(data=["attrs", "coords", "dims", "indexes"]) +class LinearExpression: """ A linear expression consisting of terms of coefficients and variables. @@ -80,10 +137,14 @@ class LinearExpression(Dataset): """ - __slots__ = ("_cache", "_coords", "_indexes", "_name", "_variable") + data: Dataset + __slots__ = ("data",) fill_value = {"vars": -1, "coeffs": np.nan} + __array_ufunc__ = None + __array_priority__ = 10000 + def __init__(self, data_vars=None, coords=None, attrs=None): ds = Dataset(data_vars, coords, attrs) @@ -97,22 +158,14 @@ def __init__(self, data_vars=None, coords=None, attrs=None): ds["vars"] = ds.vars.fillna(-1).astype(int) (ds,) = xr.broadcast(ds) ds = ds.transpose(..., "_term") - super().__init__(ds) - - # We have to set the _reduce_method to None, in order to overwrite basic - # reduction functions as `sum`. There might be a better solution (?). 
- _reduce_method = None - # Disable array function, only function defined below are supported - # and set priority higher than pandas/xarray/numpy - __array_ufunc__ = None - __array_priority__ = 10000 + self.data = ds def __repr__(self): """ Get the string representation of the expression. """ - ds_string = self.to_dataset().__repr__().split("\n", 1)[1] + ds_string = self.data.__repr__().split("\n", 1)[1] ds_string = ds_string.replace("Data variables:\n", "Data:\n") nterm = getattr(self, "nterm", 0) return ( @@ -125,7 +178,7 @@ def _repr_html_(self): Get the html representation of the expression. """ # return self.__repr__() - ds_string = self.to_dataset()._repr_html_() + ds_string = self.data._repr_html_() ds_string = ds_string.replace("Data variables:\n", "Data:\n") ds_string = ds_string.replace("xarray.Dataset", "linopy.LinearExpression") return ds_string @@ -160,7 +213,7 @@ def __neg__(self): """ Get the negative of the expression. """ - return LinearExpression(self.assign(coeffs=-self.coeffs)) + return self.assign(coeffs=-self.coeffs) def __mul__(self, other): """ @@ -174,7 +227,7 @@ def __mul__(self, other): ) coeffs = other * self.coeffs assert coeffs.shape == self.coeffs.shape - return LinearExpression(self.assign(coeffs=coeffs)) + return self.assign(coeffs=coeffs) def __rmul__(self, other): """ @@ -203,13 +256,50 @@ def __ge__(self, rhs): def __eq__(self, rhs): return constraints.AnonymousConstraint(self, "=", rhs) + @deprecated(details="Use the `data` property instead of `to_dataset`") def to_dataset(self): """ Convert the expression to a xarray.Dataset. 
""" - return Dataset(self) + return self.data + + @property + def vars(self): + return self.data.vars + + @vars.setter + def vars(self, value): + self.data["vars"] = value + + @property + def coeffs(self): + return self.data.coeffs + + @coeffs.setter + def coeffs(self, value): + self.data["coeffs"] = value + + @classmethod + def _sum(cls, expr: Union["LinearExpression", Dataset], dims=None) -> Dataset: + data = _expr_unwrap(expr) + + if dims is None: + vars = DataArray(data.vars.data.ravel(), dims="_term") + coeffs = DataArray(data.coeffs.data.ravel(), dims="_term") + ds = xr.Dataset({"vars": vars, "coeffs": coeffs}) + + else: + dims = [d for d in np.atleast_1d(dims) if d != "_term"] + ds = ( + data.reset_index(dims, drop=True) + .rename(_term="_stacked_term") + .stack(_term=["_stacked_term"] + dims) + .reset_index("_term", drop=True) + ) - def sum(self, dims=None, drop_zeros=False): + return ds + + def sum(self, dims=None, drop_zeros=False) -> "LinearExpression": """ Sum the expression over all or a subset of dimensions. @@ -226,26 +316,16 @@ def sum(self, dims=None, drop_zeros=False): linopy.LinearExpression Summed expression. """ - if dims is None: - vars = DataArray(self.vars.data.ravel(), dims="_term") - coeffs = DataArray(self.coeffs.data.ravel(), dims="_term") - ds = xr.Dataset({"vars": vars, "coeffs": coeffs}) - else: - dims = [d for d in np.atleast_1d(dims) if d != "_term"] - ds = ( - self.reset_index(dims, drop=True) - .rename(_term="_stacked_term") - .stack(_term=["_stacked_term"] + dims) - .reset_index("_term", drop=True) - ) + res = self.__class__(self._sum(self, dims=dims)) if drop_zeros: - ds = ds.densify_terms() + res = res.densify_terms() - return self.__class__(ds) + return res - def from_tuples(*tuples, chunk=None): + @classmethod + def from_tuples(cls, *tuples, chunk=None): """ Create a linear expression by using tuples of coefficients and variables. 
@@ -285,6 +365,8 @@ def from_tuples(*tuples, chunk=None): for (c, v) in tuples: if isinstance(v, variables.ScalarVariable): v = v.label + elif isinstance(v, variables.Variable): + v = v.labels v = as_dataarray(v) if isinstance(c, np.ndarray) or _pd_series_wo_index_name(c): @@ -300,9 +382,9 @@ def from_tuples(*tuples, chunk=None): ds_list.append(ds) if len(ds_list) > 1: - return merge(ds_list) + return merge(ds_list, cls=cls) else: - return LinearExpression(ds_list[0]) + return cls(ds_list[0]) def from_rule(model, rule, coords): """ @@ -384,7 +466,7 @@ def _from_scalarexpression_list(exprs, coords: DataArrayCoordinates): return LinearExpression(ds) - def where(self, cond, **kwargs): + def where(self, cond, other=xr.core.dtypes.NA, **kwargs): """ Filter variables based on a condition. @@ -404,9 +486,50 @@ def where(self, cond, **kwargs): linopy.LinearExpression """ # Cannot set `other` if drop=True - if not kwargs.get("drop", False) and "other" not in kwargs: - kwargs["other"] = self.fill_value - return self.__class__(DataArray.where(self, cond, **kwargs)) + if other is xr.core.dtypes.NA: + if not kwargs.get("drop", False): + other = self.fill_value + else: + other = _expr_unwrap(other) + cond = _expr_unwrap(cond) + return self.__class__(self.data.where(cond, other=other, **kwargs)) + + def groupby( + self, + group, + squeeze: "bool" = True, + restore_coord_dims: "bool" = None, + ) -> LinearExpressionGroupby: + """ + Returns a LinearExpressionGroupBy object for performing grouped + operations. + + Docstring and arguments are borrowed from `xarray.Dataset.groupby` + + Parameters + ---------- + group : str, DataArray or IndexVariable + Array whose unique values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. 
+ squeeze : bool, optional + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, optional + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped + A `LinearExpressionGroupBy` containing the xarray groups and ensuring + the correct return type. + """ + ds = self.data + groups = ds.groupby( + group=group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + ) + return LinearExpressionGroupby(groups) def groupby_sum(self, group): """ @@ -427,15 +550,51 @@ def groupby_sum(self, group): if isinstance(group, pd.Series): logger.info("Converting group pandas.Series to xarray.DataArray") group = group.to_xarray() - groups = xr.Dataset.groupby(self, group) + groups = xr.Dataset.groupby(self.data, group) def func(ds): - ds = LinearExpression.sum(ds, groups._group_dim) - ds = ds.to_dataset() + ds = self._sum(ds, groups._group_dim) ds = ds.assign_coords(_term=np.arange(len(ds._term))) return ds - return LinearExpression(groups.map(func)) # .reset_index('_term') + return self.__class__(groups.map(func)) # .reset_index('_term') + + def rolling( + self, + dim: "Mapping[Any, int]" = None, + min_periods: "int" = None, + center: "bool | Mapping[Any, bool]" = False, + **window_kwargs: "int", + ) -> LinearExpressionRolling: + """ + Rolling window object. + + Docstring and arguments are borrowed from `xarray.Dataset.rolling` + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. + min_periods : int, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. 
+ center : bool or mapping, default: False + Set the labels at the center of the window. + **window_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or window_kwargs must be provided. + + Returns + ------- + linopy.expression.LinearExpressionRolling + """ + ds = self.data + rolling = ds.rolling( + dim=dim, min_periods=min_periods, center=center, **window_kwargs + ) + return LinearExpressionRolling(rolling) def rolling_sum(self, **kwargs): """ @@ -474,7 +633,7 @@ def nterm(self): """ Get the number of terms in the linear expression. """ - return len(self._term) + return len(self.data._term) @property def shape(self): @@ -503,12 +662,12 @@ def densify_terms(self): Move all non-zero term entries to the front and cut off all-zero entries in the term-axis. """ - self = self.transpose(..., "_term") + data = self.data.transpose(..., "_term") - data = self.coeffs.data - axis = data.ndim - 1 - nnz = np.nonzero(data) - nterm = (data != 0).sum(axis).max() + cdata = data.coeffs.data + axis = cdata.ndim - 1 + nnz = np.nonzero(cdata) + nterm = (cdata != 0).sum(axis).max() mod_nnz = list(nnz) mod_nnz.pop(axis) @@ -519,15 +678,15 @@ def densify_terms(self): new_index = np.array([idx[:i].count(j) for i, j in enumerate(idx)]) mod_nnz.insert(axis, new_index) - vdata = np.full_like(data, -1) - vdata[tuple(mod_nnz)] = self.vars.data[nnz] - self.vars.data = vdata + vdata = np.full_like(cdata, -1) + vdata[tuple(mod_nnz)] = data.vars.data[nnz] + data.vars.data = vdata - cdata = np.zeros_like(data) - cdata[tuple(mod_nnz)] = self.coeffs.data[nnz] - self.coeffs.data = cdata + cdata = np.zeros_like(cdata) + cdata[tuple(mod_nnz)] = data.coeffs.data[nnz] + data.coeffs.data = cdata - return self.sel(_term=slice(0, nterm)) + return self.__class__(data.sel(_term=slice(0, nterm))) def sanitize(self): """ @@ -542,44 +701,50 @@ def sanitize(self): return self.assign(vars=self.vars.fillna(-1).astype(int)) return self + def equals(self, other: "LinearExpression"): + return 
self.data.equals(_expr_unwrap(other)) + + # TODO: make this return a LinearExpression (needs refactoring of __init__) + def rename(self, name_dict=None, **names) -> Dataset: + return self.data.rename(name_dict, **names) + + def __iter__(self): + return self.data.__iter__() + # Wrapped function which would convert variable to dataarray + assign = exprwrap(Dataset.assign) + + assign_attrs = exprwrap(Dataset.assign_attrs) + + assign_coords = exprwrap(Dataset.assign_coords) + astype = exprwrap(Dataset.astype) bfill = exprwrap(Dataset.bfill) broadcast_like = exprwrap(Dataset.broadcast_like) - coarsen = exprwrap(Dataset.coarsen) + chunk = exprwrap(Dataset.chunk) + + drop = exprwrap(Dataset.drop) - clip = exprwrap(Dataset.clip) + drop_sel = exprwrap(Dataset.drop_sel) + + drop_isel = exprwrap(Dataset.drop_isel) ffill = exprwrap(Dataset.ffill) fillna = exprwrap(Dataset.fillna, value=fill_value) + sel = exprwrap(Dataset.sel) + shift = exprwrap(Dataset.shift) reindex = exprwrap(Dataset.reindex, fill_value=fill_value) - roll = exprwrap(Dataset.roll) + rename_dims = exprwrap(Dataset.rename_dims) - rolling = exprwrap(Dataset.rolling) - - # TODO: explicitly disable `dangerous` functions - conj = property() - conjugate = property() - count = property() - cumsum = property() - cumprod = property() - cumulative_integrate = property() - curvefit = property() - diff = property() - differentiate = property() - groupby_bins = property() - integrate = property() - interp = property() - polyfit = property() - prod = property() + roll = exprwrap(Dataset.roll) def _pd_series_wo_index_name(ds): @@ -603,7 +768,7 @@ def _pd_dataframe_wo_axes_names(df): return False -def merge(*exprs, dim="_term"): +def merge(*exprs, dim="_term", cls=LinearExpression): """ Merge multiple linear expression together. 
@@ -626,19 +791,18 @@ def merge(*exprs, dim="_term"): else: exprs = list(exprs) + exprs = [e.data if isinstance(e, cls) else e for e in exprs] + if not all(len(expr._term) == len(exprs[0]._term) for expr in exprs[1:]): exprs = [expr.assign_coords(_term=np.arange(len(expr._term))) for expr in exprs] - exprs = [e.to_dataset() if isinstance(e, LinearExpression) else e for e in exprs] - fill_value = LinearExpression.fill_value + fill_value = cls.fill_value kwargs = dict(fill_value=fill_value, coords="minimal", compat="override") ds = xr.concat(exprs, dim, **kwargs) - res = LinearExpression(ds) - - if "_term" in res.coords: - res = res.reset_index("_term", drop=True) + if "_term" in ds.coords: + ds = ds.reset_index("_term", drop=True) - return res + return cls(ds) @dataclass diff --git a/linopy/model.py b/linopy/model.py index ce1a6e4f..4f82954b 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -306,7 +306,7 @@ def __getitem__(self, key): """ Get a model variable by the name. """ - return Variable(self.variables[key], model=self) + return self.variables[key] def check_force_dim_names(self, ds): """ @@ -534,7 +534,6 @@ def add_constraints( "Argument `sign` and `rhs` must not be None if first argument " " is an expression." ) - if isinstance(lhs, (list, tuple)): lhs = self.linexpr(*lhs) elif isinstance(lhs, (Variable, ScalarVariable, ScalarLinearExpression)): @@ -566,7 +565,7 @@ def add_constraints( ), "Dimensions of mask not a subset of resulting labels dimensions." 
labels = labels.where(mask, -1) - lhs = lhs.rename({"_term": f"{name}_term"}) + lhs = lhs.data.rename({"_term": f"{name}_term"}) if self.chunk: lhs = lhs.chunk(self.chunk) diff --git a/linopy/monkey_patch_xarray.py b/linopy/monkey_patch_xarray.py new file mode 100644 index 00000000..b717f876 --- /dev/null +++ b/linopy/monkey_patch_xarray.py @@ -0,0 +1,25 @@ +from functools import partialmethod, update_wrapper + +from xarray import DataArray + +from linopy import expressions, variables + + +def monkey_patch(cls, pass_unpatched_method=False): + def deco(func): + func_name = func.__name__ + wrapped = getattr(cls, func_name) + update_wrapper(func, wrapped) + if pass_unpatched_method: + func = partialmethod(func, unpatched_method=wrapped) + setattr(cls, func_name, func) + return func + + return deco + + +@monkey_patch(DataArray, pass_unpatched_method=True) +def __mul__(da, other, unpatched_method): + if isinstance(other, (variables.Variable, expressions.LinearExpression)): + return NotImplemented + return unpatched_method(da, other) diff --git a/linopy/testing.py b/linopy/testing.py new file mode 100644 index 00000000..b8f2de6d --- /dev/null +++ b/linopy/testing.py @@ -0,0 +1,7 @@ +from xarray.testing import assert_equal + +from linopy.expressions import _expr_unwrap + + +def assert_linequal(a, b): + return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) diff --git a/linopy/variables.py b/linopy/variables.py index 895f19c0..2fb34df7 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -8,9 +8,7 @@ import functools import re from dataclasses import dataclass, field -from distutils.log import warn -from typing import Any, Sequence, Union -from warnings import warn +from typing import Any, Mapping, Sequence, Union import dask import numpy as np @@ -18,10 +16,12 @@ from deprecation import deprecated from numpy import floating, inf, issubdtype from xarray import DataArray, Dataset, zeros_like +from xarray.core import indexing, utils import linopy.expressions as 
expressions from linopy.common import ( _merge_inplace, + forward_as_properties, has_assigned_model, has_optimized_model, is_constant, @@ -30,11 +30,10 @@ def varwrap(method, *default_args, **new_default_kwargs): @functools.wraps(method) - def _varwrap(obj, *args, **kwargs): + def _varwrap(var, *args, **kwargs): for k, v in new_default_kwargs.items(): kwargs.setdefault(k, v) - obj = DataArray(obj) - return Variable(method(obj, *default_args, *args, **kwargs)) + return var.__class__(method(var.labels, *default_args, *args, **kwargs)) _varwrap.__doc__ = f"Wrapper for the xarray {method} function for linopy.Variable" if new_default_kwargs: @@ -43,7 +42,27 @@ def _varwrap(obj, *args, **kwargs): return _varwrap -class Variable(DataArray): +def _var_unwrap(var): + if isinstance(var, Variable): + return var.labels + return var + + +@dataclass(repr=False) +@forward_as_properties( + labels=[ + "attrs", + "coords", + "indexes", + "name", + "shape", + "size", + "values", + "dims", + "ndim", + ] +) +class Variable: """ Variable container for storing variable labels. @@ -92,28 +111,10 @@ class Variable(DataArray): Further operations like taking the negative and subtracting are supported. """ - __slots__ = ("_cache", "_coords", "_indexes", "_name", "_variable", "model") - - def __init__(self, *args, **kwargs): - - # workaround until https://github.com/pydata/xarray/pull/5984 is merged - if isinstance(args[0], DataArray): - da = args[0] - args = (da.data, da.coords) - kwargs.update({"attrs": da.attrs, "name": da.name}) + labels: DataArray = field(default_factory=DataArray) + model: Any = None - self.model = kwargs.pop("model", None) - super().__init__(*args, **kwargs) - assert self.name is not None, "Variable data does not have a name." - - # We have to set the _reduce_method to None, in order to overwrite basic - # reduction functions as `sum`. There might be a better solution (?). 
- _reduce_method = None - - # Disable array function, only function defined below are supported - # and set priority higher than pandas/xarray/numpy __array_ufunc__ = None - __array_priority__ = 10000 def __getitem__(self, keys) -> "ScalarVariable": keys = (keys,) if not isinstance(keys, tuple) else keys @@ -122,18 +123,28 @@ def __getitem__(self, keys) -> "ScalarVariable": "Set single values for each dimension in order to obtain a " "ScalarVariable. For all other purposes, use `.sel` and `.isel`." ) - if not self.ndim: + if not self.labels.ndim: return ScalarVariable(self.data.item()) - assert self.ndim == len(keys), f"expected {self.ndim} keys, got {len(keys)}." - key = dict(zip(self.dims, keys)) - selector = [self.get_index(k).get_loc(v) for k, v in key.items()] - return ScalarVariable(self.data[tuple(selector)]) + assert self.labels.ndim == len( + keys + ), f"expected {self.labels.ndim} keys, got {len(keys)}." + key = dict(zip(self.labels.dims, keys)) + selector = [self.labels.get_index(k).get_loc(v) for k, v in key.items()] + return ScalarVariable(self.labels.data[tuple(selector)]) + @property + def loc(self): + return _LocIndexer(self) + + @deprecated(details="Use `labels` instead of `to_array()`") def to_array(self): """ Convert the variable array to a xarray.DataArray. """ - return DataArray(self) + return self.labels + + def to_pandas(self): + return self.labels.to_pandas() def to_linexpr(self, coefficient=1): """ @@ -147,9 +158,7 @@ def __repr__(self): """ Get the string representation of the variables. """ - data_string = ( - "Variable labels:\n" + self.to_array().__repr__().split("\n", 1)[1] - ) + data_string = "Variable labels:\n" + self.labels.__repr__().split("\n", 1)[1] extend_line = "-" * len(self.name) return ( f"Variable '{self.name}':\n" @@ -162,7 +171,7 @@ def _repr_html_(self): Get the html representation of the variables. 
""" # return self.__repr__() - data_string = self.to_array()._repr_html_() + data_string = self.labels._repr_html_() data_string = data_string.replace("xarray.DataArray", "linopy.Variable") return data_string @@ -245,6 +254,41 @@ def __ge__(self, other): def __eq__(self, other): return self.to_linexpr().__eq__(other) + def groupby( + self, + group, + squeeze: "bool" = True, + restore_coord_dims: "bool" = None, + ): + """ + Returns a LinearExpressionGroupBy object for performing grouped + operations. + + Docstring and arguments are borrowed from `xarray.Dataset.groupby` + + Parameters + ---------- + group : str, DataArray or IndexVariable + Array whose unique values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + squeeze : bool, optional + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, optional + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped + A `LinearExpressionGroupBy` containing the xarray groups and ensuring + the correct return type. + """ + return self.to_linexpr().groupby( + group=group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + ) + def groupby_sum(self, group): """ Sum variable over groups. @@ -263,11 +307,40 @@ def groupby_sum(self, group): """ return self.to_linexpr().groupby_sum(group) - def group_terms(self, group): - warn( - 'The function "group_terms" was renamed to "groupby_sum" and will be remove in v0.0.10.' + def rolling( + self, + dim: "Mapping[Any, int]" = None, + min_periods: "int" = None, + center: "bool | Mapping[Any, bool]" = False, + **window_kwargs: "int", + ) -> "expressions.LinearExpressionRolling": + """ + Rolling window object. 
+ + Docstring and arguments are borrowed from `xarray.Dataset.rolling` + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. + min_periods : int, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or mapping, default: False + Set the labels at the center of the window. + **window_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or window_kwargs must be provided. + + Returns + ------- + linopy.expression.LinearExpressionRolling + """ + return self.to_linexpr().rolling( + dim=dim, min_periods=min_periods, center=center, **window_kwargs ) - return self.groupby_sum(group) def rolling_sum(self, **kwargs): """ @@ -305,7 +378,7 @@ def upper(self, value): The function raises an error in case no model is set as a reference. """ - value = DataArray(value).broadcast_like(self) + value = DataArray(value).broadcast_like(self.upper) self.model.variables.upper[self.name] = value @property @@ -329,7 +402,7 @@ def lower(self, value): The function raises an error in case no model is set as a reference. 
""" - value = DataArray(value).broadcast_like(self) + value = DataArray(value).broadcast_like(self.lower) self.model.variables.lower[self.name] = value @property @@ -387,7 +460,7 @@ def where(self, cond, other=-1, **kwargs): ------- linopy.Variable """ - return self.__class__(DataArray.where(self, cond, other, **kwargs)) + return self.__class__(self.labels.where(cond, other, **kwargs)) def sanitize(self): """ @@ -397,28 +470,59 @@ def sanitize(self): ------- linopy.Variable """ - if issubdtype(self.dtype, floating): - return self.fillna(-1).astype(int) + if issubdtype(self.labels.dtype, floating): + return self.__class__(self.labels.fillna(-1).astype(int)) return self + def equals(self, other): + return self.labels.equals(_var_unwrap(other)) + # Wrapped function which would convert variable to dataarray + assign_attrs = varwrap(DataArray.assign_attrs) + + assign_coords = varwrap(DataArray.assign_coords) + astype = varwrap(DataArray.astype) bfill = varwrap(DataArray.bfill) broadcast_like = varwrap(DataArray.broadcast_like) - clip = varwrap(DataArray.clip) + compute = varwrap(DataArray.compute) + + drop = varwrap(DataArray.drop) + + drop_sel = varwrap(DataArray.drop_sel) + + drop_isel = varwrap(DataArray.drop_isel) ffill = varwrap(DataArray.ffill) fillna = varwrap(DataArray.fillna) + sel = varwrap(DataArray.sel) + + isel = varwrap(DataArray.isel) + shift = varwrap(DataArray.shift, fill_value=-1) + rename = varwrap(DataArray.rename) + roll = varwrap(DataArray.roll) - rolling = varwrap(DataArray.rolling) + +class _LocIndexer: + __slots__ = ("variable",) + + def __init__(self, variable: Variable): + self.variable = variable + + def __getitem__(self, key) -> DataArray: + if not utils.is_dict_like(key): + # expand the indexer so we can handle Ellipsis + labels = indexing.expanded_indexer(key, self.variable.ndim) + key = dict(zip(self.variable.dims, labels)) + return self.variable.sel(key) @dataclass(repr=False) diff --git a/test/test_constraint.py 
b/test/test_constraint.py index 9e497c14..61e6ec95 100644 --- a/test/test_constraint.py +++ b/test/test_constraint.py @@ -105,6 +105,25 @@ def test_constraint_accessor(): assert c.vars.notnull().all().item() assert c.coeffs.notnull().all().item() + # Test that assigning labels raises RuntimeError + with pytest.raises(RuntimeError): + c.labels = c.labels + + # Test that assigning lhs with other type that LinearExpression raises TypeError + with pytest.raises(TypeError): + c.lhs = x + + # Test that assigning lhs with other type that LinearExpression raises TypeError + with pytest.raises(ValueError): + c.sign = "==" + + # Test that assigning a variable or linear expression to the rhs property raises a TypeError + with pytest.raises(TypeError): + c.rhs = x + + with pytest.raises(TypeError): + c.rhs = x + y + def test_constraint_accessor_M(): m = Model() @@ -125,7 +144,7 @@ def test_constraint_accessor_M(): assert (c.rhs == 2).all().item() c.lhs = 3 * y - assert (c.vars.squeeze() == y.data).all() + assert (c.vars.squeeze() == y.labels.data).all() assert (c.coeffs == 3).all() assert isinstance(c.lhs, linopy.LinearExpression) @@ -144,20 +163,6 @@ def test_constraints_accessor(): assert isinstance(m.constraints.equalities, linopy.constraints.Constraints) -def test_constraint_getter_without_model(): - data = xr.DataArray(range(10)).rename("con") - c = linopy.constraints.Constraint(data) - - with pytest.raises(AttributeError): - c.coeffs - with pytest.raises(AttributeError): - c.vars - with pytest.raises(AttributeError): - c.sign - with pytest.raises(AttributeError): - c.rhs - - def test_constraint_sanitize_zeros(): m = Model() x = m.add_variables(coords=[range(10)]) diff --git a/test/test_io.py b/test/test_io.py index e2b778ca..006e3712 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -41,7 +41,7 @@ def test_str_arrays_with_nans(): x = m.add_variables(4, pd.Series([8, 10]), name="x") # now expand the second dimension, expended values of x will be nan y = 
m.add_variables(0, pd.DataFrame([[1, 2], [3, 4], [5, 6]]), name="y") - assert m["x"].data[-1].item() == -1 + assert m["x"].values[-1] == -1 da = int_to_str(m["x"].values) assert da.dtype == object diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index eec4cfaa..1fc563dc 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -14,6 +14,7 @@ from linopy import LinearExpression, Model, merge from linopy.expressions import ScalarLinearExpression +from linopy.testing import assert_linequal m = Model() @@ -43,7 +44,7 @@ def test_values(): def test_duplicated_index(): expr = m.linexpr((10, "x"), (-1, "x")) - assert (expr._term == [0, 1]).all() + assert (expr.data._term == [0, 1]).all() def test_variable_to_linexpr(): @@ -63,23 +64,23 @@ def test_variable_to_linexpr(): expr = 10 * x + y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((10, "x"), (1, "y"))) + assert_linequal(expr, m.linexpr((10, "x"), (1, "y"))) expr = x + 8 * y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((1, "x"), (8, "y"))) + assert_linequal(expr, m.linexpr((1, "x"), (8, "y"))) expr = x + y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((1, "x"), (1, "y"))) + assert_linequal(expr, m.linexpr((1, "x"), (1, "y"))) expr = x - y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((1, "x"), (-1, "y"))) + assert_linequal(expr, m.linexpr((1, "x"), (-1, "y"))) expr = -x - 8 * y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((-1, "x"), (-8, "y"))) + assert_linequal(expr, m.linexpr((-1, "x"), (-8, "y"))) expr = np.array([1, 2]) * x assert isinstance(expr, LinearExpression) @@ -201,7 +202,7 @@ def test_add(): assert res.nterm == expr.nterm + other.nterm assert (res.coords["dim_0"] == expr.coords["dim_0"]).all() assert (res.coords["dim_1"] == other.coords["dim_1"]).all() - assert res.notnull().all().to_array().all() + assert 
res.data.notnull().all().to_array().all() assert isinstance(x - expr, LinearExpression) assert isinstance(x + expr, LinearExpression) @@ -220,7 +221,7 @@ def test_sub(): assert res.nterm == expr.nterm + other.nterm assert (res.coords["dim_0"] == expr.coords["dim_0"]).all() assert (res.coords["dim_1"] == other.coords["dim_1"]).all() - assert res.notnull().all().to_array().all() + assert res.data.notnull().all().to_array().all() def test_sum(): @@ -228,14 +229,14 @@ def test_sum(): res = expr.sum("dim_0") assert res.size == expr.size - assert res.nterm == expr.nterm * len(expr.dim_0) + assert res.nterm == expr.nterm * len(expr.data.dim_0) res = expr.sum() assert res.size == expr.size assert res.nterm == expr.size - assert res.notnull().all().to_array().all() + assert res.data.notnull().all().to_array().all() - assert_equal(expr.sum(["dim_0", "_term"]), expr.sum("dim_0")) + assert_linequal(expr.sum(["dim_0", "_term"]), expr.sum("dim_0")) def test_where(): @@ -255,10 +256,10 @@ def test_merge(): expr2 = z.sum("dim_0") res = merge(expr1, expr2) - assert res._term.size == 6 + assert res.nterm == 6 res = merge([expr1, expr2]) - assert res._term.size == 6 + assert res.nterm == 6 # now concat with same length of terms expr1 = z.sel(dim_0=0).sum("dim_1") @@ -277,7 +278,7 @@ def test_merge(): def test_sum_drop_zeros(): - coeff = xr.zeros_like(z) + coeff = xr.zeros_like(z.labels) coeff[1, 0] = 3 coeff[0, 2] = 5 expr = coeff * z @@ -289,7 +290,7 @@ def test_sum_drop_zeros(): assert res.nterm == 1 coeff[1, 2] = 4 - expr["coeffs"] = coeff + expr.data["coeffs"] = coeff res = expr.sum() res = expr.sum("dim_0", drop_zeros=True) @@ -332,31 +333,74 @@ def test_multiindexed_expression(): assert isinstance(expr, LinearExpression) +def test_groupby(): + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).sum() + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.data._term.size == 10 + + # 
now asymetric groups which result in different nterms + groups = xr.DataArray([1] * 12 + [2] * 8, coords=v.coords) + grouped = expr.groupby(groups).sum() + assert "group" in grouped.dims + # first group must be full with vars + assert (grouped.data.sel(group=1) > 0).all() + # the last 4 entries of the second group must be empty, i.e. -1 + assert (grouped.data.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() + assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() + assert grouped.data._term.size == 12 + + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).roll(dim_2=1) + assert grouped.nterm == 1 + assert grouped.vars[0].item() == 19 + + +def test_groupby_variable(): + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = v.groupby(groups).sum() + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.data._term.size == 10 + + def test_groupby_sum(): groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) grouped = v.to_linexpr().groupby_sum(groups) assert "group" in grouped.dims - assert (grouped.group == [1, 2]).all() - assert grouped._term.size == 10 + assert (grouped.data.group == [1, 2]).all() + assert grouped.data._term.size == 10 # now asymetric groups which result in different nterms groups = xr.DataArray([1] * 12 + [2] * 8, coords=v.coords) grouped = v.to_linexpr().groupby_sum(groups) assert "group" in grouped.dims # first group must be full with vars - assert (grouped.sel(group=1) > 0).all() + assert (grouped.data.sel(group=1) > 0).all() # the last 4 entries of the second group must be empty, i.e. 
-1 - assert (grouped.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() - assert (grouped.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() - assert grouped._term.size == 12 + assert (grouped.data.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() + assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() + assert grouped.data._term.size == 12 def test_groupby_sum_variable(): groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) grouped = v.groupby_sum(groups) assert "group" in grouped.dims - assert (grouped.group == [1, 2]).all() - assert grouped._term.size == 10 + assert (grouped.data.group == [1, 2]).all() + assert grouped.data._term.size == 10 + + +def test_rolling(): + expr = 1 * v + rolled = expr.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + rolled = expr.rolling(dim_2=3).sum() + assert rolled.nterm == 3 def test_rolling_sum(): @@ -369,6 +413,11 @@ def test_rolling_sum(): assert rolled.nterm == 6 +def test_rolling_variable(): + rolled = v.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + def test_rolling_sum_variable(): rolled = v.rolling_sum(dim_2=2) assert rolled.nterm == 2 diff --git a/test/test_model_creation.py b/test/test_model_creation.py index e038a677..aa085adf 100755 --- a/test/test_model_creation.py +++ b/test/test_model_creation.py @@ -16,6 +16,7 @@ import xarray as xr from linopy import Model +from linopy.testing import assert_linequal # Test model functions @@ -220,7 +221,7 @@ def test_linexpr(): expr = m.linexpr((1, "x"), (10, "y")) target = 1 * x + 10 * y # assert (expr._term == ['x', 'y']).all() - assert (expr.to_dataset() == target.to_dataset()).all().to_array().all() + assert_linequal(expr, target) def test_constraint_assignment(): @@ -375,7 +376,7 @@ def test_remove_variable(): assert "con0" not in m.constraints.labels - assert not m.objective.vars.isin(x).any() + assert not m.objective.vars.isin(x.labels).any() def test_remove_constraint(): diff --git a/test/test_variable.py 
b/test/test_variable.py index aaac1939..4e8ae4fa 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -30,6 +30,8 @@ def test_variable_getter(): with pytest.raises(AssertionError): x[[1, 2, 3]] + assert isinstance(x.loc[[1, 2, 3]], linopy.Variable) + def test_variable_repr(): m = Model() @@ -104,7 +106,7 @@ def test_variable_where(): x = m.add_variables(coords=[range(10)]) x = x.where([True] * 4 + [False] * 6) assert isinstance(x, linopy.variables.Variable) - assert x.loc[9].item() == -1 + assert x.values[9] == -1 def test_variable_shift(): @@ -112,7 +114,7 @@ def test_variable_shift(): x = m.add_variables(coords=[range(10)]) x = x.shift(dim_0=3) assert isinstance(x, linopy.variables.Variable) - assert x.loc[0].item() == -1 + assert x.values[0] == -1 def test_variable_sanitize(): @@ -122,7 +124,7 @@ def test_variable_sanitize(): x = x.where([True] * 4 + [False] * 6, np.nan) x = x.sanitize() assert isinstance(x, linopy.variables.Variable) - assert x.loc[9].item() == -1 + assert x.values[9] == -1 def test_variable_type_preservation(): @@ -130,8 +132,7 @@ def test_variable_type_preservation(): x = m.add_variables(coords=[range(10)]) assert isinstance(x.bfill("dim_0"), linopy.variables.Variable) - assert isinstance(x.broadcast_like(x.to_array()), linopy.variables.Variable) - assert isinstance(x.clip(max=20), linopy.variables.Variable) + assert isinstance(x.broadcast_like(x.labels), linopy.variables.Variable) assert isinstance(x.ffill("dim_0"), linopy.variables.Variable) assert isinstance(x.fillna(-1), linopy.variables.Variable)