From e8827816bb8623dbfa0fa4bdb810eac851faf373 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Thu, 3 Nov 2022 16:08:11 +0100 Subject: [PATCH 01/24] Refactor LinearExpression and Constraint using composition Since there are unsafe methods in xr.Dataset that do not correctly pass on the LinearExpression or Constraint behaviour, we switch the two classes to a composition instead of inheritance scheme. --- linopy/constraints.py | 75 +++++------- linopy/expressions.py | 216 ++++++++++++++++++++------------- linopy/testing.py | 6 + test/test_constraint.py | 14 --- test/test_linear_expression.py | 45 +++---- 5 files changed, 191 insertions(+), 165 deletions(-) create mode 100644 linopy/testing.py diff --git a/linopy/constraints.py b/linopy/constraints.py index cd67c07b..391df855 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -6,6 +6,7 @@ """ import re +from deprecated import deprecated from dataclasses import dataclass from itertools import product from typing import Any, Sequence, Union @@ -21,51 +22,29 @@ from linopy import expressions, variables from linopy.common import ( _merge_inplace, - has_assigned_model, has_optimized_model, is_constant, replace_by_map, ) -class Constraint(DataArray): +@dataclass(repr=False) +class Constraint: """ - Constraint container for storing constraint labels. + Projection to a single constraint in a model The Constraint class is a subclass of xr.DataArray hence most xarray functions can be applied to it. """ - __slots__ = ("_cache", "_coords", "_indexes", "_name", "_variable", "model") - - def __init__(self, *args, **kwargs): - - # workaround until https://github.com/pydata/xarray/pull/5984 is merged - if isinstance(args[0], DataArray): - da = args[0] - args = (da.data, da.coords) - kwargs.update({"attrs": da.attrs, "name": da.name}) - - self.model = kwargs.pop("model", None) - super().__init__(*args, **kwargs) - assert self.name is not None, "Constraint data does not have a name." - - # We have to set the _reduce_method to None, in order to overwrite basic - # reduction functions as `sum`. There might be a better solution (?). - _reduce_method = None - - # Disable array function, only function defined below are supported - # and set priority higher than pandas/xarray/numpy - __array_ufunc__ = None - __array_priority__ = 10000 + name: str + model: "Model" def __repr__(self): """ Get the string representation of the constraints. """ - data_string = ( - "Constraint labels:\n" + self.to_array().__repr__().split("\n", 1)[1] - ) + data_string = "Constraint labels:\n" + self.labels.__repr__().split("\n", 1)[1] extend_line = "-" * len(self.name) return ( f"Constraint '{self.name}':\n" @@ -78,18 +57,28 @@ def _repr_html_(self): Get the html representation of the variables. """ # return self.__repr__() - data_string = self.to_array()._repr_html_() + data_string = self.labels._repr_html_() data_string = data_string.replace("xarray.DataArray", "linopy.Constraint") return data_string + @deprecated( + reason="Constraint.to_array has been replaced by using the .labels property" + ) def to_array(self): """ Convert the variable array to a xarray.DataArray. """ - return DataArray(self) + return self.labels + + @property + def labels(self): + return self.model.constraints.labels[self.name] + + @labels.setter + def labels(self, value): + raise RuntimeError("Labels are read-only") @property - @has_assigned_model def coeffs(self): """ Get the left-hand-side coefficients of the constraint. @@ -100,7 +89,6 @@ def coeffs(self): return self.model.constraints.coeffs[self.name] @coeffs.setter - @has_assigned_model def coeffs(self, value): term_dim = self.name + "_term" value = DataArray(value).broadcast_like(self.vars, exclude=[term_dim]) @@ -115,7 +103,6 @@ def coeffs(self, value): self.model.constraints.coeffs = coeffs @property - @has_assigned_model def vars(self): """ Get the left-hand-side variables of the constraint. @@ -126,7 +113,6 @@ def vars(self): return self.model.constraints.vars[self.name] @vars.setter - @has_assigned_model def vars(self, value): term_dim = self.name + "_term" value = DataArray(value).broadcast_like(self.coeffs, exclude=[term_dim]) @@ -141,7 +127,6 @@ def vars(self, value): self.model.constraints.vars = vars @property - @has_assigned_model def lhs(self): """ Get the left-hand-side linear expression of the constraint. @@ -155,7 +140,6 @@ def lhs(self): return expressions.LinearExpression(Dataset({"coeffs": coeffs, "vars": vars})) @lhs.setter - @has_assigned_model def lhs(self, value): if not isinstance(value, expressions.LinearExpression): raise TypeError("Assigned lhs must be a LinearExpression.") @@ -165,7 +149,6 @@ def lhs(self, value): self.vars = value.vars @property - @has_assigned_model def sign(self): """ Get the signs of the constraint. @@ -176,16 +159,14 @@ def sign(self): return self.model.constraints.sign[self.name] @sign.setter - @has_assigned_model @is_constant def sign(self, value): - value = DataArray(value).broadcast_like(self) + value = DataArray(value).broadcast_like(self.sign) if (value == "==").any(): raise ValueError('Sign "==" not supported, use "=" instead.') self.model.constraints.sign[self.name] = value @property - @has_assigned_model def rhs(self): """ Get the right hand side constants of the constraint. @@ -196,12 +177,11 @@ def rhs(self): return self.model.constraints.rhs[self.name] @rhs.setter - @has_assigned_model @is_constant def rhs(self, value): if isinstance(value, (variables.Variable, expressions.LinearExpression)): raise TypeError(f"Assigned rhs must be a constant, got {type(value)}).") - value = DataArray(value).broadcast_like(self) + value = DataArray(value).broadcast_like(self.rhs) self.model.constraints.rhs[self.name] = value @property @@ -219,6 +199,10 @@ def dual(self): ) return self.model.dual[self.name] + @property + def shape(self): + return self.labels.shape + @dataclass(repr=False) class Constraints: @@ -232,7 +216,7 @@ class Constraints: sign: Dataset = Dataset() rhs: Dataset = Dataset() blocks: Dataset = Dataset() - model: Any = None # Model is not defined due to circular imports + model: "Model" = None # Model is not defined due to circular imports dataset_attrs = ["labels", "coeffs", "vars", "sign", "rhs"] dataset_names = [ @@ -268,7 +252,7 @@ def __getitem__( self, names: Union[str, Sequence[str]] ) -> Union[Constraint, "Constraints"]: if isinstance(names, str): - return Constraint(self.labels[names], model=self.model) + return Constraint(names, model=self.model) return self.__class__( self.labels[names], @@ -581,7 +565,8 @@ def __init__(self, lhs, sign, rhs): """ if isinstance(rhs, (variables.Variable, expressions.LinearExpression)): raise TypeError(f"Assigned rhs must be a constant, got {type(rhs)}).") - self._lhs, self._rhs = xr.align(lhs, DataArray(rhs)) + lhs, self._rhs = xr.align(lhs.data, DataArray(rhs)) + self._lhs = expressions.LinearExpression(lhs) self._sign = DataArray(sign) @property diff --git a/linopy/expressions.py b/linopy/expressions.py index 9fd90660..ed28d015 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -6,11 +6,12 @@ This module contains definition related to affine expressions. """ +from typing import Union import functools import logging from dataclasses import dataclass from itertools import product, zip_longest -from warnings import warn +from deprecated import deprecated import numpy as np import pandas as pd @@ -18,7 +19,6 @@ from numpy import array, nan from xarray import DataArray, Dataset from xarray.core.dataarray import DataArrayCoordinates -from xarray.core.groupby import _maybe_reorder, peek_at from linopy import constraints, variables from linopy.common import as_dataarray @@ -26,11 +26,12 @@ def exprwrap(method, *default_args, **new_default_kwargs): @functools.wraps(method) - def _exprwrap(obj, *args, **kwargs): + def _exprwrap(expr, *args, **kwargs): for k, v in new_default_kwargs.items(): kwargs.setdefault(k, v) - obj = Dataset(obj) - return LinearExpression(method(obj, *default_args, *args, **kwargs)) + return expr.__class__( + method(_expr_unwrap(expr), *default_args, *args, **kwargs) + ) _exprwrap.__doc__ = f"Wrapper for the xarray {method} function for linopy.Variable" if new_default_kwargs: @@ -39,10 +40,17 @@ def _exprwrap(obj, *args, **kwargs): return _exprwrap +def _expr_unwrap(maybe_expr): + if isinstance(maybe_expr, LinearExpression): + return maybe_expr.data + + return maybe_expr + + logger = logging.getLogger(__name__) -class LinearExpression(Dataset): +class LinearExpression: """ A linear expression consisting of terms of coefficients and variables. @@ -80,7 +88,8 @@ class LinearExpression(Dataset): """ - __slots__ = ("_cache", "_coords", "_indexes", "_name", "_variable") + data: Dataset + __slots__ = ("data",) fill_value = {"vars": -1, "coeffs": np.nan} @@ -97,22 +106,14 @@ def __init__(self, data_vars=None, coords=None, attrs=None): ds["vars"] = ds.vars.fillna(-1).astype(int) (ds,) = xr.broadcast(ds) ds = ds.transpose(..., "_term") - super().__init__(ds) - # We have to set the _reduce_method to None, in order to overwrite basic - # reduction functions as `sum`. There might be a better solution (?). - _reduce_method = None - - # Disable array function, only function defined below are supported - # and set priority higher than pandas/xarray/numpy - __array_ufunc__ = None - __array_priority__ = 10000 + self.data = ds def __repr__(self): """ Get the string representation of the expression. """ - ds_string = self.to_dataset().__repr__().split("\n", 1)[1] + ds_string = self.data.__repr__().split("\n", 1)[1] ds_string = ds_string.replace("Data variables:\n", "Data:\n") nterm = getattr(self, "nterm", 0) return ( @@ -125,7 +126,7 @@ def _repr_html_(self): Get the html representation of the expression. """ # return self.__repr__() - ds_string = self.to_dataset()._repr_html_() + ds_string = self.data._repr_html_() ds_string = ds_string.replace("Data variables:\n", "Data:\n") ds_string = ds_string.replace("xarray.Dataset", "linopy.LinearExpression") return ds_string @@ -160,7 +161,7 @@ def __neg__(self): """ Get the negative of the expression. """ - return LinearExpression(self.assign(coeffs=-self.coeffs)) + return self.assign(coeffs=-self.coeffs) def __mul__(self, other): """ @@ -168,7 +169,7 @@ def __mul__(self, other): """ coeffs = other * self.coeffs assert coeffs.shape == self.coeffs.shape - return LinearExpression(self.assign(coeffs=coeffs)) + return self.assign(coeffs=coeffs) def __rmul__(self, other): """ @@ -185,13 +186,66 @@ def __ge__(self, rhs): def __eq__(self, rhs): return constraints.AnonymousConstraint(self, "=", rhs) + @deprecated(reason="Access the Dataset directly through `.data`") def to_dataset(self): """ Convert the expression to a xarray.Dataset. """ - return Dataset(self) + return self.data + + @property + def vars(self): + return self.data.vars + + @vars.setter + def vars(self, value): + self.data["vars"] = value + + @property + def coeffs(self): + return self.data.coeffs + + @coeffs.setter + def coeffs(self, value): + self.data["coeffs"] = value + + @property + def attrs(self): + return self.data.attrs + + @property + def coords(self): + return self.data.coords + + @property + def dims(self): + return self.data.dims + + @property + def indexes(self): + return self.data.indexes + + @classmethod + def _sum(cls, expr: Union["LinearExpression", Dataset], dims=None) -> Dataset: + data = _expr_unwrap(expr) + + if dims is None: + vars = DataArray(data.vars.data.ravel(), dims="_term") + coeffs = DataArray(data.coeffs.data.ravel(), dims="_term") + ds = xr.Dataset({"vars": vars, "coeffs": coeffs}) + + else: + dims = [d for d in np.atleast_1d(dims) if d != "_term"] + ds = ( + data.reset_index(dims, drop=True) + .rename(_term="_stacked_term") + .stack(_term=["_stacked_term"] + dims) + .reset_index("_term", drop=True) + ) + + return ds - def sum(self, dims=None, drop_zeros=False): + def sum(self, dims=None, drop_zeros=False) -> "LinearExpression": """ Sum the expression over all or a subset of dimensions. @@ -208,26 +262,16 @@ def sum(self, dims=None, drop_zeros=False): linopy.LinearExpression Summed expression. """ - if dims is None: - vars = DataArray(self.vars.data.ravel(), dims="_term") - coeffs = DataArray(self.coeffs.data.ravel(), dims="_term") - ds = xr.Dataset({"vars": vars, "coeffs": coeffs}) - else: - dims = [d for d in np.atleast_1d(dims) if d != "_term"] - ds = ( - self.reset_index(dims, drop=True) - .rename(_term="_stacked_term") - .stack(_term=["_stacked_term"] + dims) - .reset_index("_term", drop=True) - ) + res = self.__class__(self._sum(self, dims=dims)) if drop_zeros: - ds = ds.densify_terms() + res = res.densify_terms() - return self.__class__(ds) + return res - def from_tuples(*tuples, chunk=None): + @classmethod + def from_tuples(cls, *tuples, chunk=None): """ Create a linear expression by using tuples of coefficients and variables. @@ -282,9 +326,9 @@ def from_tuples(*tuples, chunk=None): ds_list.append(ds) if len(ds_list) > 1: - return merge(ds_list) + return merge(ds_list, cls=cls) else: - return LinearExpression(ds_list[0]) + return cls(ds_list[0]) def from_rule(model, rule, coords): """ @@ -364,7 +408,7 @@ def _from_scalarexpression_list(exprs, coords: DataArrayCoordinates): return LinearExpression(ds) - def where(self, cond, **kwargs): + def where(self, cond, other=xr.core.dtypes.NA, **kwargs): """ Filter variables based on a condition. @@ -384,9 +428,13 @@ def where(self, cond, **kwargs): linopy.LinearExpression """ # Cannot set `other` if drop=True - if not kwargs.get("drop", False) and "other" not in kwargs: - kwargs["other"] = self.fill_value - return self.__class__(DataArray.where(self, cond, **kwargs)) + if other is xr.core.dtypes.NA: + if not kwargs.get("drop", False): + other = self.fill_value + else: + other = _expr_unwrap(other) + cond = _expr_unwrap(cond) + return self.__class__(DataArray.where(self.data, cond, other=other, **kwargs)) def groupby_sum(self, group): """ @@ -407,15 +455,14 @@ def groupby_sum(self, group): if isinstance(group, pd.Series): logger.info("Converting group pandas.Series to xarray.DataArray") group = group.to_xarray() - groups = xr.Dataset.groupby(self, group) + groups = xr.Dataset.groupby(self.data, group) def func(ds): - ds = LinearExpression.sum(ds, groups._group_dim) - ds = ds.to_dataset() + ds = self._sum(ds, groups._group_dim) ds = ds.assign_coords(_term=np.arange(len(ds._term))) return ds - return LinearExpression(groups.map(func)) # .reset_index('_term') + return self.__class__(groups.map(func)) # .reset_index('_term') def rolling_sum(self, **kwargs): """ @@ -454,7 +501,7 @@ def nterm(self): """ Get the number of terms in the linear expression. """ - return len(self._term) + return len(self.data._term) @property def shape(self): @@ -483,12 +530,12 @@ def densify_terms(self): Move all non-zero term entries to the front and cut off all-zero entries in the term-axis. """ - self = self.transpose(..., "_term") + data = self.data.transpose(..., "_term") - data = self.coeffs.data - axis = data.ndim - 1 - nnz = np.nonzero(data) - nterm = (data != 0).sum(axis).max() + cdata = data.coeffs.data + axis = cdata.ndim - 1 + nnz = np.nonzero(cdata) + nterm = (cdata != 0).sum(axis).max() mod_nnz = list(nnz) mod_nnz.pop(axis) @@ -499,15 +546,15 @@ def densify_terms(self): new_index = np.array([idx[:i].count(j) for i, j in enumerate(idx)]) mod_nnz.insert(axis, new_index) - vdata = np.full_like(data, -1) - vdata[tuple(mod_nnz)] = self.vars.data[nnz] - self.vars.data = vdata + vdata = np.full_like(cdata, -1) + vdata[tuple(mod_nnz)] = data.vars.data[nnz] + data.vars.data = vdata - cdata = np.zeros_like(data) - cdata[tuple(mod_nnz)] = self.coeffs.data[nnz] - self.coeffs.data = cdata + cdata = np.zeros_like(cdata) + cdata[tuple(mod_nnz)] = data.coeffs.data[nnz] + data.coeffs.data = cdata - return self.sel(_term=slice(0, nterm)) + return self.__class__(data.sel(_term=slice(0, nterm))) def sanitize(self): """ @@ -522,13 +569,29 @@ def sanitize(self): return self.assign(vars=self.vars.fillna(-1).astype(int)) return self + def equals(self, other: "LinearExpression"): + return self.data.equals(_expr_unwrap(other)) + + def rename(self, name_dict, **renames): + # Does not return a linear expression + return self.data.rename(name_dict, **renames) + + def __iter__(self): + return self.data.__iter__() + # Wrapped function which would convert variable to dataarray + assign = exprwrap(Dataset.assign) + + assign_attrs = exprwrap(Dataset.assign_attrs) + astype = exprwrap(Dataset.astype) bfill = exprwrap(Dataset.bfill) broadcast_like = exprwrap(Dataset.broadcast_like) + chunk = exprwrap(Dataset.chunk) + coarsen = exprwrap(Dataset.coarsen) clip = exprwrap(Dataset.clip) @@ -537,6 +600,8 @@ def sanitize(self): fillna = exprwrap(Dataset.fillna, value=fill_value) + sel = exprwrap(Dataset.sel) + shift = exprwrap(Dataset.shift) reindex = exprwrap(Dataset.reindex, fill_value=fill_value) @@ -545,22 +610,6 @@ def sanitize(self): rolling = exprwrap(Dataset.rolling) - # TODO: explicitly disable `dangerous` functions - conj = property() - conjugate = property() - count = property() - cumsum = property() - cumprod = property() - cumulative_integrate = property() - curvefit = property() - diff = property() - differentiate = property() - groupby_bins = property() - integrate = property() - interp = property() - polyfit = property() - prod = property() - def _pd_series_wo_index_name(ds): if isinstance(ds, pd.Series): @@ -583,7 +632,7 @@ def _pd_dataframe_wo_axes_names(df): return False -def merge(*exprs, dim="_term"): +def merge(*exprs, dim="_term", cls=LinearExpression): """ Merge multiple linear expression together. @@ -606,19 +655,18 @@ def merge(*exprs, dim="_term"): else: exprs = list(exprs) + exprs = [e.data if isinstance(e, cls) else e for e in exprs] + if not all(len(expr._term) == len(exprs[0]._term) for expr in exprs[1:]): exprs = [expr.assign_coords(_term=np.arange(len(expr._term))) for expr in exprs] - exprs = [e.to_dataset() if isinstance(e, LinearExpression) else e for e in exprs] - fill_value = LinearExpression.fill_value + fill_value = cls.fill_value kwargs = dict(fill_value=fill_value, coords="minimal", compat="override") ds = xr.concat(exprs, dim, **kwargs) - res = LinearExpression(ds) - - if "_term" in res.coords: - res = res.reset_index("_term", drop=True) + if "_term" in ds.coords: + ds = ds.reset_index("_term", drop=True) - return res + return cls(ds) @dataclass diff --git a/linopy/testing.py b/linopy/testing.py new file mode 100644 index 00000000..1cdb6da5 --- /dev/null +++ b/linopy/testing.py @@ -0,0 +1,6 @@ +from xarray.testing import assert_equal +from .expressions import _expr_unwrap + + +def assert_linequal(a, b): + return assert_equal(_expr_unrwap(a), _expr_unwrap(b)) diff --git a/test/test_constraint.py b/test/test_constraint.py index 48d54cc3..949a779c 100644 --- a/test/test_constraint.py +++ b/test/test_constraint.py @@ -117,20 +117,6 @@ def test_constraints_accessor(): assert isinstance(m.constraints.equalities, linopy.constraints.Constraints) -def test_constraint_getter_without_model(): - data = xr.DataArray(range(10)).rename("con") - c = linopy.constraints.Constraint(data) - - with pytest.raises(AttributeError): - c.coeffs - with pytest.raises(AttributeError): - c.vars - with pytest.raises(AttributeError): - c.sign - with pytest.raises(AttributeError): - c.rhs - - def test_constraint_sanitize_zeros(): m = Model() x = m.add_variables(coords=[range(10)]) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 69f3e1eb..ff5bc009 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -14,6 +14,7 @@ from linopy import LinearExpression, Model, merge from linopy.expressions import ScalarLinearExpression +from linopy.testing import assert_linequal m = Model() @@ -43,7 +44,7 @@ def test_values(): def test_duplicated_index(): expr = m.linexpr((10, "x"), (-1, "x")) - assert (expr._term == [0, 1]).all() + assert (expr.data._term == [0, 1]).all() def test_variable_to_linexpr(): @@ -57,23 +58,23 @@ def test_variable_to_linexpr(): expr = 10 * x + y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((10, "x"), (1, "y"))) + assert_linequal(expr, m.linexpr((10, "x"), (1, "y"))) expr = x + 8 * y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((1, "x"), (8, "y"))) + assert_linequal(expr, m.linexpr((1, "x"), (8, "y"))) expr = x + y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((1, "x"), (1, "y"))) + assert_linequal(expr, m.linexpr((1, "x"), (1, "y"))) expr = x - y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((1, "x"), (-1, "y"))) + assert_linequal(expr, m.linexpr((1, "x"), (-1, "y"))) expr = -x - 8 * y assert isinstance(expr, LinearExpression) - assert_equal(expr, m.linexpr((-1, "x"), (-8, "y"))) + assert_linequal(expr, m.linexpr((-1, "x"), (-8, "y"))) expr = np.array([1, 2]) * x assert isinstance(expr, LinearExpression) @@ -155,7 +156,7 @@ def test_add(): assert res.nterm == expr.nterm + other.nterm assert (res.coords["dim_0"] == expr.coords["dim_0"]).all() assert (res.coords["dim_1"] == other.coords["dim_1"]).all() - assert res.notnull().all().to_array().all() + assert res.data.notnull().all().to_array().all() assert isinstance(x - expr, LinearExpression) assert isinstance(x + expr, LinearExpression) @@ -174,7 +175,7 @@ def test_sub(): assert res.nterm == expr.nterm + other.nterm assert (res.coords["dim_0"] == expr.coords["dim_0"]).all() assert (res.coords["dim_1"] == other.coords["dim_1"]).all() - assert res.notnull().all().to_array().all() + assert res.data.notnull().all().to_array().all() def test_sum(): @@ -182,14 +183,14 @@ def test_sum(): res = expr.sum("dim_0") assert res.size == expr.size - assert res.nterm == expr.nterm * len(expr.dim_0) + assert res.nterm == expr.nterm * len(expr.data.dim_0) res = expr.sum() assert res.size == expr.size assert res.nterm == expr.size - assert res.notnull().all().to_array().all() + assert res.data.notnull().all().to_array().all() - assert_equal(expr.sum(["dim_0", "_term"]), expr.sum("dim_0")) + assert_linequal(expr.sum(["dim_0", "_term"]), expr.sum("dim_0")) def test_where(): @@ -209,10 +210,10 @@ def test_merge(): expr2 = z.sum("dim_0") res = merge(expr1, expr2) - assert res._term.size == 6 + assert res.nterm == 6 res = merge([expr1, expr2]) - assert res._term.size == 6 + assert res.nterm == 6 # now concat with same length of terms expr1 = z.sel(dim_0=0).sum("dim_1") @@ -243,7 +244,7 @@ def test_sum_drop_zeros(): assert res.nterm == 1 coeff[1, 2] = 4 - expr["coeffs"] = coeff + expr.data["coeffs"] = coeff res = expr.sum() res = expr.sum("dim_0", drop_zeros=True) @@ -276,27 +277,27 @@ def test_groupby_sum(): groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) grouped = v.to_linexpr().groupby_sum(groups) assert "group" in grouped.dims - assert (grouped.group == [1, 2]).all() - assert grouped._term.size == 10 + assert (grouped.data.group == [1, 2]).all() + assert grouped.data._term.size == 10 # now asymetric groups which result in different nterms groups = xr.DataArray([1] * 12 + [2] * 8, coords=v.coords) grouped = v.to_linexpr().groupby_sum(groups) assert "group" in grouped.dims # first group must be full with vars - assert (grouped.sel(group=1) > 0).all() + assert (grouped.data.sel(group=1) > 0).all() # the last 4 entries of the second group must be empty, i.e. -1 - assert (grouped.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() - assert (grouped.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() - assert grouped._term.size == 12 + assert (grouped.data.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() + assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() + assert grouped.data._term.size == 12 def test_groupby_sum_variable(): groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) grouped = v.groupby_sum(groups) assert "group" in grouped.dims - assert (grouped.group == [1, 2]).all() - assert grouped._term.size == 10 + assert (grouped.data.group == [1, 2]).all() + assert grouped.data._term.size == 10 def test_rolling_sum(): From 5fda4901f1a1192cf468e63464e8b4d2b1656cfd Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Thu, 3 Nov 2022 16:15:59 +0100 Subject: [PATCH 02/24] Make linter happy --- linopy/constraints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index 391df855..c9072aea 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -38,7 +38,7 @@ class Constraint: """ name: str - model: "Model" + model: Any def __repr__(self): """ @@ -216,7 +216,7 @@ class Constraints: sign: Dataset = Dataset() rhs: Dataset = Dataset() blocks: Dataset = Dataset() - model: "Model" = None # Model is not defined due to circular imports + model: Any = None # Model is not defined due to circular imports dataset_attrs = ["labels", "coeffs", "vars", "sign", "rhs"] dataset_names = [ From ef9d2eea0a604061edd42b1ede08f79f8335e198 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Thu, 3 Nov 2022 16:22:14 +0100 Subject: [PATCH 03/24] Small fixes introduced during clean-up --- linopy/expressions.py | 5 ++--- linopy/testing.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index ed28d015..37b4731e 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -572,9 +572,8 @@ def sanitize(self): def equals(self, other: "LinearExpression"): return self.data.equals(_expr_unwrap(other)) - def rename(self, name_dict, **renames): - # Does not return a linear expression - return self.data.rename(name_dict, **renames) + def rename(self, name_dict = None, **names) -> Dataset: + return self.data.rename(name_dict, **names) def __iter__(self): return self.data.__iter__() diff --git a/linopy/testing.py b/linopy/testing.py index 1cdb6da5..7d198ff4 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -3,4 +3,4 @@ def assert_linequal(a, b): - return assert_equal(_expr_unrwap(a), _expr_unwrap(b)) + return assert_equal(_expr_unwrap(a), _expr_unwrap(b)) From f903d46b440f2005a0dd3dbf6cdf9805fadc8bab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Nov 2022 15:23:55 +0000 Subject: [PATCH 04/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/constraints.py | 4 ++-- linopy/expressions.py | 6 +++--- linopy/testing.py | 3 ++- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index c9072aea..61c863be 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -6,7 +6,6 @@ """ import re -from deprecated import deprecated from dataclasses import dataclass from itertools import product from typing import Any, Sequence, Union @@ -15,6 +14,7 @@ import numpy as np import pandas as pd import xarray as xr +from deprecated import deprecated from numpy import arange, array from scipy.sparse import coo_matrix from xarray import DataArray, Dataset @@ -31,7 +31,7 @@ @dataclass(repr=False) class Constraint: """ - Projection to a single constraint in a model + Projection to a single constraint in a model. The Constraint class is a subclass of xr.DataArray hence most xarray functions can be applied to it. diff --git a/linopy/expressions.py b/linopy/expressions.py index 37b4731e..879eddf9 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -6,16 +6,16 @@ This module contains definition related to affine expressions. """ -from typing import Union import functools import logging from dataclasses import dataclass from itertools import product, zip_longest -from deprecated import deprecated +from typing import Union import numpy as np import pandas as pd import xarray as xr +from deprecated import deprecated from numpy import array, nan from xarray import DataArray, Dataset from xarray.core.dataarray import DataArrayCoordinates @@ -572,7 +572,7 @@ def sanitize(self): def equals(self, other: "LinearExpression"): return self.data.equals(_expr_unwrap(other)) - def rename(self, name_dict = None, **names) -> Dataset: + def rename(self, name_dict=None, **names) -> Dataset: return self.data.rename(name_dict, **names) def __iter__(self): diff --git a/linopy/testing.py b/linopy/testing.py index 7d198ff4..b8f2de6d 100644 --- a/linopy/testing.py +++ b/linopy/testing.py @@ -1,5 +1,6 @@ from xarray.testing import assert_equal -from .expressions import _expr_unwrap + +from linopy.expressions import _expr_unwrap def assert_linequal(a, b): From 21ed04e9062c658787da7093faf92ce1eb751ac8 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Thu, 3 Nov 2022 16:38:21 +0100 Subject: [PATCH 05/24] Fix imports --- linopy/constraints.py | 5 ++--- linopy/expressions.py | 4 ++-- linopy/variables.py | 1 - 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index 61c863be..126b5ba3 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -6,6 +6,7 @@ """ import re +from deprecation import deprecated from dataclasses import dataclass from itertools import product from typing import Any, Sequence, Union @@ -61,9 +62,7 @@ def _repr_html_(self): data_string = data_string.replace("xarray.DataArray", "linopy.Constraint") return data_string - @deprecated( - reason="Constraint.to_array has been replaced by using the .labels property" - ) + @deprecated(details="Use the `labels` property instead of `to_array`") def to_array(self): """ Convert the variable array to a xarray.DataArray. diff --git a/linopy/expressions.py b/linopy/expressions.py index 879eddf9..06c8472e 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -11,11 +11,11 @@ from dataclasses import dataclass from itertools import product, zip_longest from typing import Union +from deprecation import deprecated import numpy as np import pandas as pd import xarray as xr -from deprecated import deprecated from numpy import array, nan from xarray import DataArray, Dataset from xarray.core.dataarray import DataArrayCoordinates @@ -186,7 +186,7 @@ def __ge__(self, rhs): def __eq__(self, rhs): return constraints.AnonymousConstraint(self, "=", rhs) - @deprecated(reason="Access the Dataset directly through `.data`") + @deprecated(details="Use the `data` property instead of `to_dataset`") def to_dataset(self): """ Convert the expression to a xarray.Dataset. diff --git a/linopy/variables.py b/linopy/variables.py index 254b4696..7213e0d7 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -15,7 +15,6 @@ import dask import numpy as np import pandas as pd -from deprecation import deprecated from numpy import floating, inf, issubdtype from xarray import DataArray, Dataset, zeros_like From d94214d9f25e6c0a463021b090cc59666d795a88 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Thu, 3 Nov 2022 16:42:50 +0100 Subject: [PATCH 06/24] Use deprecation instead of deprecated package --- linopy/constraints.py | 3 +-- linopy/expressions.py | 2 +- test/test_model_creation.py | 3 ++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index 126b5ba3..d9e4380b 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -6,16 +6,15 @@ """ import re -from deprecation import deprecated from dataclasses import dataclass from itertools import product from typing import Any, Sequence, Union +from deprecation import deprecated import dask import numpy as np import pandas as pd import xarray as xr -from deprecated import deprecated from numpy import arange, array from scipy.sparse import coo_matrix from xarray import DataArray, Dataset diff --git a/linopy/expressions.py b/linopy/expressions.py index 06c8472e..38decb9c 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -11,8 +11,8 @@ from dataclasses import dataclass from itertools import product, zip_longest from typing import Union -from deprecation import deprecated +from deprecation import deprecated import numpy as np import pandas as pd import xarray as xr diff --git a/test/test_model_creation.py b/test/test_model_creation.py index e038a677..e26d81a1 100755 --- a/test/test_model_creation.py +++ b/test/test_model_creation.py @@ -16,6 +16,7 @@ import xarray as xr from linopy import Model +from linopy.testing import assert_linequal # Test model functions @@ -220,7 +221,7 @@ def test_linexpr(): expr = m.linexpr((1, "x"), (10, "y")) target = 1 * x + 10 * y # assert (expr._term == ['x', 'y']).all() - assert (expr.to_dataset() == target.to_dataset()).all().to_array().all() + assert_linequal(expr, target) def test_constraint_assignment(): From f4504bd4dc3ea0338c754748c9181cb28692d09c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Nov 2022 15:43:04 +0000 Subject: [PATCH 07/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/constraints.py | 2 +- linopy/expressions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index d9e4380b..a2a67e13 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -10,11 +10,11 @@ from itertools import product from typing import Any, Sequence, Union -from deprecation import deprecated import dask import numpy as np import pandas as pd import xarray as xr +from deprecation import deprecated from numpy import arange, array from scipy.sparse import coo_matrix from xarray import DataArray, Dataset diff --git a/linopy/expressions.py b/linopy/expressions.py index 38decb9c..8f27750d 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -12,10 +12,10 @@ from itertools import product, zip_longest from typing import Union -from deprecation import deprecated import numpy as np import pandas as pd import xarray as xr +from deprecation import deprecated from numpy import array, nan from xarray import DataArray, Dataset from xarray.core.dataarray import DataArrayCoordinates From e10c096bea6b30f15633c69b70532ba06afa97b8 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Mon, 7 Nov 2022 01:39:08 +0100 Subject: [PATCH 08/24] variables: Convert into composition pattern (WIP) --- linopy/common.py | 17 ++++++ linopy/expressions.py | 25 +++------ linopy/variables.py | 95 +++++++++++++++++++++------------- setup.py | 2 +- test/test_constraint.py | 2 +- test/test_io.py | 2 +- test/test_linear_expression.py | 2 +- 7 files changed, 88 insertions(+), 57 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index 8dc4d887..e74dcedc 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -98,3 +98,20 @@ def wrapper(self, arg): return func(self, arg) return wrapper + + +def forward_as_properties(**routes): + def add_accessor(cls, item, attr): + @property + def get(self): + return getattr(getattr(self, item), attr) + + setattr(cls, attr, get) + + def deco(cls): + for item, attrs in routes.items(): + for attr in attrs: + add_accessor(cls, item, attr) + return cls + + return deco diff --git a/linopy/expressions.py b/linopy/expressions.py index 8f27750d..cce4b57e 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -21,7 +21,7 @@ from xarray.core.dataarray import DataArrayCoordinates from linopy import constraints, variables -from linopy.common import as_dataarray +from linopy.common import as_dataarray, forward_as_properties def exprwrap(method, *default_args, **new_default_kwargs): @@ -50,6 +50,7 @@ def _expr_unwrap(maybe_expr): logger = logging.getLogger(__name__) +@forward_as_properties(data=["attrs", "coords", "dims", "indexes"]) class LinearExpression: """ A linear expression consisting of terms of coefficients and variables. @@ -93,6 +94,8 @@ class LinearExpression: fill_value = {"vars": -1, "coeffs": np.nan} + __array_ufunc__ = None + def __init__(self, data_vars=None, coords=None, attrs=None): ds = Dataset(data_vars, coords, attrs) @@ -209,22 +212,6 @@ def coeffs(self): def coeffs(self, value): self.data["coeffs"] = value - @property - def attrs(self): - return self.data.attrs - - @property - def coords(self): - return self.data.coords - - @property - def dims(self): - return self.data.dims - - @property - def indexes(self): - return self.data.indexes - @classmethod def _sum(cls, expr: Union["LinearExpression", Dataset], dims=None) -> Dataset: data = _expr_unwrap(expr) @@ -311,6 +298,8 @@ def from_tuples(cls, *tuples, chunk=None): for (c, v) in tuples: if isinstance(v, variables.ScalarVariable): v = v.label + elif isinstance(v, variables.Variable): + v = v.labels v = as_dataarray(v) if isinstance(c, np.ndarray) or _pd_series_wo_index_name(c): @@ -434,7 +423,7 @@ def where(self, cond, other=xr.core.dtypes.NA, **kwargs): else: other = _expr_unwrap(other) cond = _expr_unwrap(cond) - return self.__class__(DataArray.where(self.data, cond, other=other, **kwargs)) + return self.__class__(self.data.where(cond, other=other, **kwargs)) def groupby_sum(self, group): """ diff --git a/linopy/variables.py b/linopy/variables.py index 7213e0d7..43251988 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -15,6 +15,7 @@ import dask import numpy as np import pandas as pd +from deprecation import deprecated from numpy import floating, inf, issubdtype from xarray import DataArray, Dataset, zeros_like @@ -24,16 +25,15 @@ has_assigned_model, has_optimized_model, is_constant, + forward_as_properties, ) - def varwrap(method, *default_args, **new_default_kwargs): @functools.wraps(method) - def _varwrap(obj, *args, **kwargs): + def _varwrap(var, *args, **kwargs): for k, v in new_default_kwargs.items(): kwargs.setdefault(k, v) - obj = DataArray(obj) - return Variable(method(obj, *default_args, *args, **kwargs)) + return var.__class__(method(var.labels, *default_args, *args, **kwargs)) _varwrap.__doc__ = f"Wrapper for the xarray {method} function for linopy.Variable" if new_default_kwargs: @@ -42,7 +42,19 @@ def _varwrap(obj, *args, **kwargs): return _varwrap -class Variable(DataArray): +def _var_unwrap(var): + if isinstance(var, Variable): + return var.labels + return var + + +@dataclass(repr=False) +@forward_as_properties( + labels=[ + "name" + ] +) +class Variable: """ Variable container for storing variable labels. @@ -91,28 +103,14 @@ class Variable(DataArray): Further operations like taking the negative and subtracting are supported. """ - __slots__ = ("_cache", "_coords", "_indexes", "_name", "_variable", "model") - - def __init__(self, *args, **kwargs): - - # workaround until https://github.com/pydata/xarray/pull/5984 is merged - if isinstance(args[0], DataArray): - da = args[0] - args = (da.data, da.coords) - kwargs.update({"attrs": da.attrs, "name": da.name}) - - self.model = kwargs.pop("model", None) - super().__init__(*args, **kwargs) - assert self.name is not None, "Variable data does not have a name." + labels: DataArray + model: Any = None - # We have to set the _reduce_method to None, in order to overwrite basic - # reduction functions as `sum`. There might be a better solution (?). - _reduce_method = None # Disable array function, only function defined below are supported # and set priority higher than pandas/xarray/numpy __array_ufunc__ = None - __array_priority__ = 10000 + # __array_priority__ = 10000 def __getitem__(self, keys) -> "ScalarVariable": keys = (keys,) if not isinstance(keys, tuple) else keys @@ -121,18 +119,19 @@ def __getitem__(self, keys) -> "ScalarVariable": "Set single values for each dimension in order to obtain a " "ScalarVariable. For all other purposes, use `.sel` and `.isel`." ) - if not self.ndim: + if not self.labels.ndim: return ScalarVariable(self.data.item()) - assert self.ndim == len(keys), f"expected {self.ndim} keys, got {len(keys)}." - key = dict(zip(self.dims, keys)) - selector = [self.get_index(k).get_loc(v) for k, v in key.items()] - return ScalarVariable(self.data[tuple(selector)]) + assert self.labels.ndim == len(keys), f"expected {self.labels.ndim} keys, got {len(keys)}." + key = dict(zip(self.labels.dims, keys)) + selector = [self.labels.get_index(k).get_loc(v) for k, v in key.items()] + return ScalarVariable(self.labels.data[tuple(selector)]) + @deprecated(details="Use `labels` instead of `to_array()`") def to_array(self): """ Convert the variable array to a xarray.DataArray. """ - return DataArray(self) + return self.labels def to_linexpr(self, coefficient=1): """ @@ -145,7 +144,7 @@ def __repr__(self): Get the string representation of the variables. """ data_string = ( - "Variable labels:\n" + self.to_array().__repr__().split("\n", 1)[1] + "Variable labels:\n" + self.labels.__repr__().split("\n", 1)[1] ) extend_line = "-" * len(self.name) return ( @@ -159,7 +158,7 @@ def _repr_html_(self): Get the html representation of the variables. """ # return self.__repr__() - data_string = self.to_array()._repr_html_() + data_string = self.labels._repr_html_() data_string = data_string.replace("xarray.DataArray", "linopy.Variable") return data_string @@ -278,7 +277,7 @@ def upper(self, value): The function raises an error in case no model is set as a reference. """ - value = DataArray(value).broadcast_like(self) + value = DataArray(value).broadcast_like(self.upper) self.model.variables.upper[self.name] = value @property @@ -302,7 +301,7 @@ def lower(self, value): The function raises an error in case no model is set as a reference. """ - value = DataArray(value).broadcast_like(self) + value = DataArray(value).broadcast_like(self.lower) self.model.variables.lower[self.name] = value @property @@ -360,7 +359,7 @@ def where(self, cond, other=-1, **kwargs): ------- linopy.Variable """ - return self.__class__(DataArray.where(self, cond, other, **kwargs)) + return self.__class__(self.labels.where(cond, other, **kwargs)) def sanitize(self): """ @@ -371,10 +370,31 @@ def sanitize(self): linopy.Variable """ if issubdtype(self.dtype, floating): - return self.fillna(-1).astype(int) + return self.__class__(self.labels.fillna(-1).astype(int)) return self - + + def equals(self, other): + return self.labels.equals(_var_unwrap(other)) + + @property + def attrs(self): + return self.labels.attrs + + @property + def values(self): + return self.labels.values + + @property + def shape(self): + return self.labels.shape + + @property + def size(self): + return self.labels.size + # Wrapped function which would convert variable to dataarray + assign_attrs = varwrap(DataArray.assign_attrs) + astype = varwrap(DataArray.astype) bfill = varwrap(DataArray.bfill) @@ -383,10 +403,15 @@ def sanitize(self): clip = varwrap(DataArray.clip) + compute = varwrap(DataArray.compute) + ffill = varwrap(DataArray.ffill) fillna = varwrap(DataArray.fillna) + sel = varwrap(DataArray.sel) + isel = varwrap(DataArray.isel) + shift = varwrap(DataArray.shift, fill_value=-1) roll = varwrap(DataArray.roll) diff --git a/setup.py b/setup.py index d797f997..eb5874d7 100755 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ "bottleneck", "toolz", "numexpr", - "xarray<=2022.03.", + "xarray<=2022.03.,>=0.21.0", "dask>=0.18.0", "tqdm", "deprecation", diff --git a/test/test_constraint.py b/test/test_constraint.py index 949a779c..a85374d1 100644 --- a/test/test_constraint.py +++ b/test/test_constraint.py @@ -98,7 +98,7 @@ def test_constraint_accessor_M(): assert (c.rhs == 2).all().item() c.lhs = 3 * y - assert (c.vars.squeeze() == y.data).all() + assert (c.vars.squeeze() == y.labels.data).all() assert (c.coeffs == 3).all() assert isinstance(c.lhs, linopy.LinearExpression) diff --git a/test/test_io.py b/test/test_io.py index e2b778ca..006e3712 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -41,7 +41,7 @@ def test_str_arrays_with_nans(): x = m.add_variables(4, pd.Series([8, 10]), name="x") # now expand the second dimension, expended values of x will be nan y = m.add_variables(0, pd.DataFrame([[1, 2], [3, 4], [5, 6]]), name="y") - assert m["x"].data[-1].item() == -1 + assert m["x"].values[-1] == -1 da = int_to_str(m["x"].values) assert da.dtype == object diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index ff5bc009..e0e6d24a 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -232,7 +232,7 @@ def test_merge(): def test_sum_drop_zeros(): - coeff = xr.zeros_like(z) + coeff = xr.zeros_like(z.labels) coeff[1, 0] = 3 coeff[0, 2] = 5 expr = coeff * z From e87e353a1decd210e0bac92333b99784f4bb2df1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Nov 2022 00:39:22 +0000 Subject: [PATCH 09/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/variables.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/linopy/variables.py b/linopy/variables.py index 43251988..7f6ee5f8 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -22,12 +22,13 @@ import linopy.expressions as expressions from linopy.common import ( _merge_inplace, + forward_as_properties, has_assigned_model, has_optimized_model, is_constant, - forward_as_properties, ) + def varwrap(method, *default_args, **new_default_kwargs): @functools.wraps(method) def _varwrap(var, *args, **kwargs): @@ -49,11 +50,7 @@ def _var_unwrap(var): @dataclass(repr=False) -@forward_as_properties( - labels=[ - "name" - ] -) +@forward_as_properties(labels=["name"]) class Variable: """ Variable container for storing variable labels. @@ -106,7 +103,6 @@ class Variable: labels: DataArray model: Any = None - # Disable array function, only function defined below are supported # and set priority higher than pandas/xarray/numpy __array_ufunc__ = None @@ -121,7 +117,9 @@ def __getitem__(self, keys) -> "ScalarVariable": ) if not self.labels.ndim: return ScalarVariable(self.data.item()) - assert self.labels.ndim == len(keys), f"expected {self.labels.ndim} keys, got {len(keys)}." + assert self.labels.ndim == len( + keys + ), f"expected {self.labels.ndim} keys, got {len(keys)}." key = dict(zip(self.labels.dims, keys)) selector = [self.labels.get_index(k).get_loc(v) for k, v in key.items()] return ScalarVariable(self.labels.data[tuple(selector)]) @@ -143,9 +141,7 @@ def __repr__(self): """ Get the string representation of the variables. """ - data_string = ( - "Variable labels:\n" + self.labels.__repr__().split("\n", 1)[1] - ) + data_string = "Variable labels:\n" + self.labels.__repr__().split("\n", 1)[1] extend_line = "-" * len(self.name) return ( f"Variable '{self.name}':\n" @@ -372,26 +368,26 @@ def sanitize(self): if issubdtype(self.dtype, floating): return self.__class__(self.labels.fillna(-1).astype(int)) return self - + def equals(self, other): return self.labels.equals(_var_unwrap(other)) - + @property def attrs(self): return self.labels.attrs - + @property def values(self): return self.labels.values - + @property def shape(self): return self.labels.shape - + @property def size(self): return self.labels.size - + # Wrapped function which would convert variable to dataarray assign_attrs = varwrap(DataArray.assign_attrs) From 95fc5c4abfdcd41edf3956fce896f3dc8966d5e3 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Mon, 7 Nov 2022 19:22:23 +0100 Subject: [PATCH 10/24] Add monkey_patch work-around for left multiplication --- linopy/common.py | 14 +++++++++++++- linopy/variables.py | 35 +++++++++++++---------------------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index e74dcedc..8f784180 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -6,7 +6,7 @@ This module contains commonly used functions. """ -from functools import wraps +from functools import wraps, update_wrapper, partial import numpy as np from xarray import DataArray, apply_ufunc, merge @@ -115,3 +115,15 @@ def deco(cls): return cls return deco + + +def monkey_patch(cls, pass_unpatched_method=False): + def deco(func): + wrapped = getattr(cls, func.__name__) + if pass_unpatched_method: + func = partial(func, unpatched_method=wrapped) + update_wrapper(func, wrapped) + setattr(cls, func.__name__, func) + return func + + return deco diff --git a/linopy/variables.py b/linopy/variables.py index 7f6ee5f8..0f5f1596 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -26,6 +26,7 @@ has_assigned_model, has_optimized_model, is_constant, + monkey_patch, ) @@ -50,7 +51,7 @@ def _var_unwrap(var): @dataclass(repr=False) -@forward_as_properties(labels=["name"]) +@forward_as_properties(labels=["attrs", "name", "values", "shape", "size"]) class Variable: """ Variable container for storing variable labels. @@ -103,11 +104,6 @@ class Variable: labels: DataArray model: Any = None - # Disable array function, only function defined below are supported - # and set priority higher than pandas/xarray/numpy - __array_ufunc__ = None - # __array_priority__ = 10000 - def __getitem__(self, keys) -> "ScalarVariable": keys = (keys,) if not isinstance(keys, tuple) else keys assert all(map(np.isscalar, keys)), ( @@ -372,22 +368,6 @@ def sanitize(self): def equals(self, other): return self.labels.equals(_var_unwrap(other)) - @property - def attrs(self): - return self.labels.attrs - - @property - def values(self): - return self.labels.values - - @property - def shape(self): - return self.labels.shape - - @property - def size(self): - return self.labels.size - # Wrapped function which would convert variable to dataarray assign_attrs = varwrap(DataArray.assign_attrs) @@ -688,3 +668,14 @@ def __ge__(self, other): def __eq__(self, other): return self.to_scalar_linexpr(1).__eq__(other) + + +##### +## MONKEY PATCH DataArray __mul__ function to pass multiplication to Variable +##### + +@monkey_patch(DataArray, pass_unpatched_method=True) +def __mul__(da, other, unpatched_method): + if isinstance(other, Variable): + return NotImplemented + return unpatched_method(da, other) From 060fe25168289c8da5923ec0c56ba9adddd5072c Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Mon, 7 Nov 2022 19:56:57 +0100 Subject: [PATCH 11/24] Fix tests --- linopy/common.py | 4 ++-- linopy/expressions.py | 1 + linopy/variables.py | 9 +++++++-- test/test_variable.py | 6 +++--- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index 8f784180..59ddef56 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -6,7 +6,7 @@ This module contains commonly used functions. """ -from functools import wraps, update_wrapper, partial +from functools import wraps, update_wrapper, partialmethod import numpy as np from xarray import DataArray, apply_ufunc, merge @@ -121,7 +121,7 @@ def monkey_patch(cls, pass_unpatched_method=False): def deco(func): wrapped = getattr(cls, func.__name__) if pass_unpatched_method: - func = partial(func, unpatched_method=wrapped) + func = partialmethod(func, unpatched_method=wrapped) update_wrapper(func, wrapped) setattr(cls, func.__name__, func) return func diff --git a/linopy/expressions.py b/linopy/expressions.py index cce4b57e..5ca7decd 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -95,6 +95,7 @@ class LinearExpression: fill_value = {"vars": -1, "coeffs": np.nan} __array_ufunc__ = None + __array_priority__ = 10000 def __init__(self, data_vars=None, coords=None, attrs=None): ds = Dataset(data_vars, coords, attrs) diff --git a/linopy/variables.py b/linopy/variables.py index 0f5f1596..89b4b42e 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -51,7 +51,9 @@ def _var_unwrap(var): @dataclass(repr=False) -@forward_as_properties(labels=["attrs", "name", "values", "shape", "size"]) +@forward_as_properties( + labels=["attrs", "coords", "indexes", "name", "shape", "size", "values"] +) class Variable: """ Variable container for storing variable labels. @@ -104,6 +106,8 @@ class Variable: labels: DataArray model: Any = None + __array_ufunc__ = None + def __getitem__(self, keys) -> "ScalarVariable": keys = (keys,) if not isinstance(keys, tuple) else keys assert all(map(np.isscalar, keys)), ( @@ -361,7 +365,7 @@ def sanitize(self): ------- linopy.Variable """ - if issubdtype(self.dtype, floating): + if issubdtype(self.labels.dtype, floating): return self.__class__(self.labels.fillna(-1).astype(int)) return self @@ -674,6 +678,7 @@ def __eq__(self, other): ## MONKEY PATCH DataArray __mul__ function to pass multiplication to Variable ##### + @monkey_patch(DataArray, pass_unpatched_method=True) def __mul__(da, other, unpatched_method): if isinstance(other, Variable): diff --git a/test/test_variable.py b/test/test_variable.py index aaac1939..6de4fc33 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -104,7 +104,7 @@ def test_variable_where(): x = m.add_variables(coords=[range(10)]) x = x.where([True] * 4 + [False] * 6) assert isinstance(x, linopy.variables.Variable) - assert x.loc[9].item() == -1 + assert x.values[9] == -1 def test_variable_shift(): @@ -112,7 +112,7 @@ def test_variable_shift(): x = m.add_variables(coords=[range(10)]) x = x.shift(dim_0=3) assert isinstance(x, linopy.variables.Variable) - assert x.loc[0].item() == -1 + assert x.values[0] == -1 def test_variable_sanitize(): @@ -122,7 +122,7 @@ def test_variable_sanitize(): x = x.where([True] * 4 + [False] * 6, np.nan) x = x.sanitize() assert isinstance(x, linopy.variables.Variable) - assert x.loc[9].item() == -1 + assert x.values[9] == -1 def test_variable_type_preservation(): From c03acd356a4dd12826cc5f42d68761b8a08961ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Nov 2022 18:57:12 +0000 Subject: [PATCH 12/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/common.py b/linopy/common.py index 59ddef56..5ff3ed31 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -6,7 +6,7 @@ This module contains commonly used functions. """ -from functools import wraps, update_wrapper, partialmethod +from functools import partialmethod, update_wrapper, wraps import numpy as np from xarray import DataArray, apply_ufunc, merge From cc58dd1045d6d977f36b44bdef7038b38c3273e3 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Mon, 7 Nov 2022 20:02:15 +0100 Subject: [PATCH 13/24] test_model_creation: Fix final test --- test/test_model_creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_model_creation.py b/test/test_model_creation.py index e26d81a1..aa085adf 100755 --- a/test/test_model_creation.py +++ b/test/test_model_creation.py @@ -376,7 +376,7 @@ def test_remove_variable(): assert "con0" not in m.constraints.labels - assert not m.objective.vars.isin(x).any() + assert not m.objective.vars.isin(x.labels).any() def test_remove_constraint(): From d1644545f4cafa919772eadac02c8c9023ecb77e Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Tue, 8 Nov 2022 18:53:10 +0100 Subject: [PATCH 14/24] Move monkey patching into dedicated module --- linopy/__init__.py | 4 ++++ linopy/common.py | 12 ------------ linopy/monkey_patch_xarray.py | 25 +++++++++++++++++++++++++ linopy/variables.py | 13 ------------- 4 files changed, 29 insertions(+), 25 deletions(-) create mode 100644 linopy/monkey_patch_xarray.py diff --git a/linopy/__init__.py b/linopy/__init__.py index 869a86a0..a640768b 100755 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -12,3 +12,7 @@ from linopy.model import LinearExpression, Model, Variable, available_solvers from linopy.remote import RemoteHandler from linopy.version import version as __version__ + +# Note: For intercepting multiplications between xarray dataarrays, Variables and Expressions +# we need to extend their __mul__ functions with a quick special case +import linopy.monkey_patch_xarray \ No newline at end of file diff --git a/linopy/common.py b/linopy/common.py index 5ff3ed31..361c3098 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -115,15 +115,3 @@ def deco(cls): return cls return deco - - -def monkey_patch(cls, pass_unpatched_method=False): - def deco(func): - wrapped = getattr(cls, func.__name__) - if pass_unpatched_method: - func = partialmethod(func, unpatched_method=wrapped) - update_wrapper(func, wrapped) - setattr(cls, func.__name__, func) - return func - - return deco diff --git a/linopy/monkey_patch_xarray.py b/linopy/monkey_patch_xarray.py new file mode 100644 index 00000000..b485adee --- /dev/null +++ b/linopy/monkey_patch_xarray.py @@ -0,0 +1,25 @@ +from functools import partialmethod, update_wrapper +from xarray import DataArray + +from .variables import Variable +from .expressions import LinearExpression + + +def monkey_patch(cls, pass_unpatched_method=False): + def deco(func): + func_name = func.__name__ + wrapped = getattr(cls, func_name) + update_wrapper(func, wrapped) + if pass_unpatched_method: + func = partialmethod(func, unpatched_method=wrapped) + setattr(cls, func_name, func) + return func + + return deco + + +@monkey_patch(DataArray, pass_unpatched_method=True) +def __mul__(da, other, unpatched_method): + if isinstance(other, (Variable, LinearExpression)): + return NotImplemented + return unpatched_method(da, other) diff --git a/linopy/variables.py b/linopy/variables.py index 89b4b42e..de018df9 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -26,7 +26,6 @@ has_assigned_model, has_optimized_model, is_constant, - monkey_patch, ) @@ -672,15 +671,3 @@ def __ge__(self, other): def __eq__(self, other): return self.to_scalar_linexpr(1).__eq__(other) - - -##### -## MONKEY PATCH DataArray __mul__ function to pass multiplication to Variable -##### - - -@monkey_patch(DataArray, pass_unpatched_method=True) -def __mul__(da, other, unpatched_method): - if isinstance(other, Variable): - return NotImplemented - return unpatched_method(da, other) From 44fafc8457442126700c17f20546bb7f43456b5c Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 8 Nov 2022 20:57:25 +0100 Subject: [PATCH 15/24] support python3.11 --- linopy/constraints.py | 14 +++++++------- linopy/variables.py | 12 ++++++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index a2a67e13..4f179775 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -6,7 +6,7 @@ """ import re -from dataclasses import dataclass +from dataclasses import dataclass, field from itertools import product from typing import Any, Sequence, Union @@ -208,12 +208,12 @@ class Constraints: A constraint container used for storing multiple constraint arrays. """ - labels: Dataset = Dataset() - coeffs: Dataset = Dataset() - vars: Dataset = Dataset() - sign: Dataset = Dataset() - rhs: Dataset = Dataset() - blocks: Dataset = Dataset() + labels: Dataset = field(default_factory=Dataset) + coeffs: Dataset = field(default_factory=Dataset) + vars: Dataset = field(default_factory=Dataset) + sign: Dataset = field(default_factory=Dataset) + rhs: Dataset = field(default_factory=Dataset) + blocks: Dataset = field(default_factory=Dataset) model: Any = None # Model is not defined due to circular imports dataset_attrs = ["labels", "coeffs", "vars", "sign", "rhs"] diff --git a/linopy/variables.py b/linopy/variables.py index 89b4b42e..0b5b7215 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -7,7 +7,7 @@ import functools import re -from dataclasses import dataclass +from dataclasses import dataclass, field from distutils.log import warn from typing import Any, Sequence, Union from warnings import warn @@ -103,7 +103,7 @@ class Variable: Further operations like taking the negative and subtracting are supported. """ - labels: DataArray + labels: DataArray = field(default_factory=DataArray) model: Any = None __array_ufunc__ = None @@ -405,10 +405,10 @@ class Variables: A variables container used for storing multiple variable arrays. """ - labels: Dataset = Dataset() - lower: Dataset = Dataset() - upper: Dataset = Dataset() - blocks: Dataset = Dataset() + labels: Dataset = field(default_factory=Dataset) + lower: Dataset = field(default_factory=Dataset) + upper: Dataset = field(default_factory=Dataset) + blocks: Dataset = field(default_factory=Dataset) model: Any = None # Model is not defined due to circular imports dataset_attrs = ["labels", "lower", "upper"] From f7e134b8e4bacfcf04857e17c7bf4b81c25726e9 Mon Sep 17 00:00:00 2001 From: Fabian Date: Wed, 9 Nov 2022 12:38:51 +0100 Subject: [PATCH 16/24] variables: add locindexer --- linopy/expressions.py | 4 +++- linopy/model.py | 3 +-- linopy/variables.py | 40 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 5ca7decd..8cd067d8 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -10,7 +10,7 @@ import logging from dataclasses import dataclass from itertools import product, zip_longest -from typing import Union +from typing import Any, Hashable, Iterable, Mapping, Sequence, Union import numpy as np import pandas as pd @@ -573,6 +573,8 @@ def __iter__(self): assign_attrs = exprwrap(Dataset.assign_attrs) + assign_coords = exprwrap(Dataset.assign_coords) + astype = exprwrap(Dataset.astype) bfill = exprwrap(Dataset.bfill) diff --git a/linopy/model.py b/linopy/model.py index ce1a6e4f..1a26052b 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -306,7 +306,7 @@ def __getitem__(self, key): """ Get a model variable by the name. """ - return Variable(self.variables[key], model=self) + return self.variables[key] def check_force_dim_names(self, ds): """ @@ -534,7 +534,6 @@ def add_constraints( "Argument `sign` and `rhs` must not be None if first argument " " is an expression." ) - if isinstance(lhs, (list, tuple)): lhs = self.linexpr(*lhs) elif isinstance(lhs, (Variable, ScalarVariable, ScalarLinearExpression)): diff --git a/linopy/variables.py b/linopy/variables.py index 0b5b7215..6ad728c0 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -10,7 +10,6 @@ from dataclasses import dataclass, field from distutils.log import warn from typing import Any, Sequence, Union -from warnings import warn import dask import numpy as np @@ -18,6 +17,7 @@ from deprecation import deprecated from numpy import floating, inf, issubdtype from xarray import DataArray, Dataset, zeros_like +from xarray.core import indexing, utils import linopy.expressions as expressions from linopy.common import ( @@ -52,7 +52,17 @@ def _var_unwrap(var): @dataclass(repr=False) @forward_as_properties( - labels=["attrs", "coords", "indexes", "name", "shape", "size", "values"] + labels=[ + "attrs", + "coords", + "indexes", + "name", + "shape", + "size", + "values", + "dims", + "ndim", + ] ) class Variable: """ @@ -124,6 +134,10 @@ def __getitem__(self, keys) -> "ScalarVariable": selector = [self.labels.get_index(k).get_loc(v) for k, v in key.items()] return ScalarVariable(self.labels.data[tuple(selector)]) + @property + def loc(self): + return _LocIndexer(self) + @deprecated(details="Use `labels` instead of `to_array()`") def to_array(self): """ @@ -131,6 +145,9 @@ def to_array(self): """ return self.labels + def to_pandas(self): + return self.labels.to_pandas() + def to_linexpr(self, coefficient=1): """ Create a linear exprssion from the variables. @@ -375,6 +392,8 @@ def equals(self, other): # Wrapped function which would convert variable to dataarray assign_attrs = varwrap(DataArray.assign_attrs) + assign_coords = varwrap(DataArray.assign_coords) + astype = varwrap(DataArray.astype) bfill = varwrap(DataArray.bfill) @@ -390,15 +409,32 @@ def equals(self, other): fillna = varwrap(DataArray.fillna) sel = varwrap(DataArray.sel) + isel = varwrap(DataArray.isel) shift = varwrap(DataArray.shift, fill_value=-1) + rename = varwrap(DataArray.rename) + roll = varwrap(DataArray.roll) rolling = varwrap(DataArray.rolling) +class _LocIndexer: + __slots__ = ("variable",) + + def __init__(self, variable: Variable): + self.variable = variable + + def __getitem__(self, key) -> DataArray: + if not utils.is_dict_like(key): + # expand the indexer so we can handle Ellipsis + labels = indexing.expanded_indexer(key, self.variable.ndim) + key = dict(zip(self.variable.dims, labels)) + return self.variable.sel(key) + + @dataclass(repr=False) class Variables: """ From 7f6ffa0917696773b26bed2ca6146e634bd527aa Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 10 Nov 2022 11:50:19 +0100 Subject: [PATCH 17/24] add LinearExpressionGroupby class --- linopy/expressions.py | 65 ++++++++++++++++++++++++++++++++-- linopy/variables.py | 50 +++++++++++++++++++++----- test/test_linear_expression.py | 28 +++++++++++++++ test/test_variable.py | 1 - 4 files changed, 133 insertions(+), 11 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 8cd067d8..05af5e43 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -50,6 +50,28 @@ def _expr_unwrap(maybe_expr): logger = logging.getLogger(__name__) +@dataclass +class LinearExpressionGroupby: + """ + GroupBy object specialized to grouping LinearExpression objects. + """ + + groupby: xr.core.groupby.DatasetGroupBy + + def map(self, func, shortcut=False, args=(), **kwargs): + return LinearExpression( + self.groupby.map(func, shortcut=shortcut, args=args, **kwargs) + ) + + def sum(self, **kwargs): + def func(ds): + ds = LinearExpression._sum(ds, self.groupby._group_dim) + ds = ds.assign_coords(_term=np.arange(len(ds._term))) + return ds + + return self.map(func, **kwargs) + + @forward_as_properties(data=["attrs", "coords", "dims", "indexes"]) class LinearExpression: """ @@ -426,6 +448,43 @@ def where(self, cond, other=xr.core.dtypes.NA, **kwargs): cond = _expr_unwrap(cond) return self.__class__(self.data.where(cond, other=other, **kwargs)) + def groupby( + self, + group, + squeeze: "bool" = True, + restore_coord_dims: "bool" = None, + ): + """ + Returns a LinearExpressionGroupBy object for performing grouped + operations. + + Docstring and arguments are borrowed from `xarray.Dataset.groupby` + + Parameters + ---------- + group : str, DataArray or IndexVariable + Array whose unique values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + squeeze : bool, optional + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, optional + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped + A `LinearExpressionGroupBy` containing the xarray groups and ensuring + the correct return type. + """ + ds = self.to_dataset() + groups = ds.groupby( + group=group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + ) + return LinearExpressionGroupby(groups) + def groupby_sum(self, group): """ Sum expression over groups. @@ -583,9 +642,11 @@ def __iter__(self): chunk = exprwrap(Dataset.chunk) - coarsen = exprwrap(Dataset.coarsen) + drop = exprwrap(Dataset.drop) + + drop_sel = exprwrap(Dataset.drop_sel) - clip = exprwrap(Dataset.clip) + drop_isel = exprwrap(Dataset.drop_isel) ffill = exprwrap(Dataset.ffill) diff --git a/linopy/variables.py b/linopy/variables.py index 6ad728c0..a1cc2fb2 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -14,6 +14,7 @@ import dask import numpy as np import pandas as pd +import xarray as xr from deprecation import deprecated from numpy import floating, inf, issubdtype from xarray import DataArray, Dataset, zeros_like @@ -230,6 +231,41 @@ def __ge__(self, other): def __eq__(self, other): return self.to_linexpr().__eq__(other) + def groupby( + self, + group, + squeeze: "bool" = True, + restore_coord_dims: "bool" = None, + ): + """ + Returns a LinearExpressionGroupBy object for performing grouped + operations. + + Docstring and arguments are borrowed from `xarray.Dataset.groupby` + + Parameters + ---------- + group : str, DataArray or IndexVariable + Array whose unique values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + squeeze : bool, optional + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, optional + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped + A `LinearExpressionGroupBy` containing the xarray groups and ensuring + the correct return type. + """ + return self.to_linexpr().groupby( + group=group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + ) + def groupby_sum(self, group): """ Sum variable over groups. @@ -248,12 +284,6 @@ def groupby_sum(self, group): """ return self.to_linexpr().groupby_sum(group) - def group_terms(self, group): - warn( - 'The function "group_terms" was renamed to "groupby_sum" and will be remove in v0.0.10.' - ) - return self.groupby_sum(group) - def rolling_sum(self, **kwargs): """ Rolling sum of variable. @@ -400,10 +430,14 @@ def equals(self, other): broadcast_like = varwrap(DataArray.broadcast_like) - clip = varwrap(DataArray.clip) - compute = varwrap(DataArray.compute) + drop = varwrap(DataArray.drop) + + drop_sel = varwrap(DataArray.drop_sel) + + drop_isel = varwrap(DataArray.drop_isel) + ffill = varwrap(DataArray.ffill) fillna = varwrap(DataArray.fillna) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index e0e6d24a..5082cfd9 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -273,6 +273,34 @@ def test_multiindexed_expression(): assert isinstance(expr, LinearExpression) +def test_groupby(): + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).sum() + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.data._term.size == 10 + + # now asymetric groups which result in different nterms + groups = xr.DataArray([1] * 12 + [2] * 8, coords=v.coords) + grouped = expr.groupby(groups).sum() + assert "group" in grouped.dims + # first group must be full with vars + assert (grouped.data.sel(group=1) > 0).all() + # the last 4 entries of the second group must be empty, i.e. -1 + assert (grouped.data.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() + assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() + assert grouped.data._term.size == 12 + + +def test_groupby_variable(): + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = v.groupby(groups).sum() + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.data._term.size == 10 + + def test_groupby_sum(): groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) grouped = v.to_linexpr().groupby_sum(groups) diff --git a/test/test_variable.py b/test/test_variable.py index 6de4fc33..424e7ef4 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -131,7 +131,6 @@ def test_variable_type_preservation(): assert isinstance(x.bfill("dim_0"), linopy.variables.Variable) assert isinstance(x.broadcast_like(x.to_array()), linopy.variables.Variable) - assert isinstance(x.clip(max=20), linopy.variables.Variable) assert isinstance(x.ffill("dim_0"), linopy.variables.Variable) assert isinstance(x.fillna(-1), linopy.variables.Variable) From fd27342f47e4411f7a346dbfc59fc5d1d1b3ffe7 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 10 Nov 2022 14:51:02 +0100 Subject: [PATCH 18/24] add LinearExpressionRolling class adjust rename function --- linopy/expressions.py | 70 ++++++++++++++++++++++++++++++++-- linopy/model.py | 2 +- linopy/variables.py | 39 +++++++++++++++++-- test/test_linear_expression.py | 20 ++++++++++ test/test_variable.py | 2 +- 5 files changed, 124 insertions(+), 9 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 05af5e43..d6729309 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -51,6 +51,7 @@ def _expr_unwrap(maybe_expr): @dataclass +@forward_as_properties(groupby=["dims", "groups"]) class LinearExpressionGroupby: """ GroupBy object specialized to grouping LinearExpression objects. @@ -71,7 +72,30 @@ def func(ds): return self.map(func, **kwargs) + def roll(self, **kwargs): + return self.map(Dataset.roll, **kwargs) + +@dataclass +@forward_as_properties(rolling=["center", "dim", "obj", "rollings", "window"]) +class LinearExpressionRolling: + """ + GroupBy object specialized to grouping LinearExpression objects. + """ + + rolling: xr.core.rolling.DataArrayRolling + + def sum(self, **kwargs): + ds = ( + self.rolling.construct("_rolling_term", keep_attrs=True) + .rename(_term="_stacked_term") + .stack(_term=["_stacked_term", "_rolling_term"]) + .reset_index("_term", drop=True) + ) + return LinearExpression(ds) + + +@dataclass(repr=False) @forward_as_properties(data=["attrs", "coords", "dims", "indexes"]) class LinearExpression: """ @@ -453,7 +477,7 @@ def groupby( group, squeeze: "bool" = True, restore_coord_dims: "bool" = None, - ): + ) -> LinearExpressionGroupby: """ Returns a LinearExpressionGroupBy object for performing grouped operations. @@ -479,7 +503,7 @@ def groupby( A `LinearExpressionGroupBy` containing the xarray groups and ensuring the correct return type. """ - ds = self.to_dataset() + ds = self.data groups = ds.groupby( group=group, squeeze=squeeze, restore_coord_dims=restore_coord_dims ) @@ -513,6 +537,43 @@ def func(ds): return self.__class__(groups.map(func)) # .reset_index('_term') + def rolling( + self, + dim: "Mapping[Any, int]" = None, + min_periods: "int" = None, + center: "bool | Mapping[Any, bool]" = False, + **window_kwargs: "int", + ) -> LinearExpressionRolling: + """ + Rolling window object. + + Docstring and arguments are borrowed from `xarray.Dataset.rolling` + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. + min_periods : int, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or mapping, default: False + Set the labels at the center of the window. + **window_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or window_kwargs must be provided. + + Returns + ------- + linopy.expression.LinearExpressionRolling + """ + ds = self.data + rolling = ds.rolling( + dim=dim, min_periods=min_periods, center=center, **window_kwargs + ) + return LinearExpressionRolling(rolling) + def rolling_sum(self, **kwargs): """ Rolling sum of the linear expression. @@ -621,6 +682,7 @@ def sanitize(self): def equals(self, other: "LinearExpression"): return self.data.equals(_expr_unwrap(other)) + # TODO: make this return a LinearExpression (needs refactoring of __init__) def rename(self, name_dict=None, **names) -> Dataset: return self.data.rename(name_dict, **names) @@ -658,9 +720,9 @@ def __iter__(self): reindex = exprwrap(Dataset.reindex, fill_value=fill_value) - roll = exprwrap(Dataset.roll) + rename_dims = exprwrap(Dataset.rename_dims) - rolling = exprwrap(Dataset.rolling) + roll = exprwrap(Dataset.roll) def _pd_series_wo_index_name(ds): diff --git a/linopy/model.py b/linopy/model.py index 1a26052b..4f82954b 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -565,7 +565,7 @@ def add_constraints( ), "Dimensions of mask not a subset of resulting labels dimensions." labels = labels.where(mask, -1) - lhs = lhs.rename({"_term": f"{name}_term"}) + lhs = lhs.data.rename({"_term": f"{name}_term"}) if self.chunk: lhs = lhs.chunk(self.chunk) diff --git a/linopy/variables.py b/linopy/variables.py index a1cc2fb2..96cb9d81 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -9,7 +9,7 @@ import re from dataclasses import dataclass, field from distutils.log import warn -from typing import Any, Sequence, Union +from typing import Any, Mapping, Sequence, Union import dask import numpy as np @@ -284,6 +284,41 @@ def groupby_sum(self, group): """ return self.to_linexpr().groupby_sum(group) + def rolling( + self, + dim: "Mapping[Any, int]" = None, + min_periods: "int" = None, + center: "bool | Mapping[Any, bool]" = False, + **window_kwargs: "int", + ) -> "expressions.LinearExpressionRolling": + """ + Rolling window object. + + Docstring and arguments are borrowed from `xarray.Dataset.rolling` + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. + min_periods : int, default: None + Minimum number of observations in window required to have a value + (otherwise result is NA). The default, None, is equivalent to + setting min_periods equal to the size of the window. + center : bool or mapping, default: False + Set the labels at the center of the window. + **window_kwargs : optional + The keyword arguments form of ``dim``. + One of dim or window_kwargs must be provided. + + Returns + ------- + linopy.expression.LinearExpressionRolling + """ + return self.to_linexpr().rolling( + dim=dim, min_periods=min_periods, center=center, **window_kwargs + ) + def rolling_sum(self, **kwargs): """ Rolling sum of variable. @@ -452,8 +487,6 @@ def equals(self, other): roll = varwrap(DataArray.roll) - rolling = varwrap(DataArray.rolling) - class _LocIndexer: __slots__ = ("variable",) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 5082cfd9..585dfea9 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -292,6 +292,12 @@ def test_groupby(): assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() assert grouped.data._term.size == 12 + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).roll(dim_2=1) + assert grouped.nterm == 1 + assert grouped.vars[0].item() == 19 + def test_groupby_variable(): groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) @@ -328,6 +334,15 @@ def test_groupby_sum_variable(): assert grouped.data._term.size == 10 +def test_rolling(): + expr = 1 * v + rolled = expr.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + rolled = expr.rolling(dim_2=3).sum() + assert rolled.nterm == 3 + + def test_rolling_sum(): rolled = v.to_linexpr().rolling_sum(dim_2=2) assert rolled.nterm == 2 @@ -338,6 +353,11 @@ def test_rolling_sum(): assert rolled.nterm == 6 +def test_rolling_variable(): + rolled = v.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + def test_rolling_sum_variable(): rolled = v.rolling_sum(dim_2=2) assert rolled.nterm == 2 diff --git a/test/test_variable.py b/test/test_variable.py index 424e7ef4..16db37a8 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -130,7 +130,7 @@ def test_variable_type_preservation(): x = m.add_variables(coords=[range(10)]) assert isinstance(x.bfill("dim_0"), linopy.variables.Variable) - assert isinstance(x.broadcast_like(x.to_array()), linopy.variables.Variable) + assert isinstance(x.broadcast_like(x.labels), linopy.variables.Variable) assert isinstance(x.ffill("dim_0"), linopy.variables.Variable) assert isinstance(x.fillna(-1), linopy.variables.Variable) From 705c97067c156ea38917e00f36c844192874124d Mon Sep 17 00:00:00 2001 From: Fabian Date: Mon, 14 Nov 2022 12:15:51 +0100 Subject: [PATCH 19/24] style: make flake8 a bit more happy --- linopy/expressions.py | 2 +- linopy/variables.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index d6729309..c513f0e7 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -10,7 +10,7 @@ import logging from dataclasses import dataclass from itertools import product, zip_longest -from typing import Any, Hashable, Iterable, Mapping, Sequence, Union +from typing import Any, Mapping, Union import numpy as np import pandas as pd diff --git a/linopy/variables.py b/linopy/variables.py index 96cb9d81..9ee4d06d 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -8,13 +8,11 @@ import functools import re from dataclasses import dataclass, field -from distutils.log import warn from typing import Any, Mapping, Sequence, Union import dask import numpy as np import pandas as pd -import xarray as xr from deprecation import deprecated from numpy import floating, inf, issubdtype from xarray import DataArray, Dataset, zeros_like @@ -778,7 +776,7 @@ def __eq__(self, other): ##### -## MONKEY PATCH DataArray __mul__ function to pass multiplication to Variable +# MONKEY PATCH DataArray __mul__ function to pass multiplication to Variable ##### From b63cf0272f11fa8cdb520dd1d77db0b2d8f9ed4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Nov 2022 21:48:43 +0000 Subject: [PATCH 20/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/__init__.py | 7 +++---- linopy/monkey_patch_xarray.py | 5 +++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/linopy/__init__.py b/linopy/__init__.py index a640768b..8fb0cdd5 100755 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -6,13 +6,12 @@ @author: fabulous """ +# Note: For intercepting multiplications between xarray dataarrays, Variables and Expressions +# we need to extend their __mul__ functions with a quick special case +import linopy.monkey_patch_xarray from linopy import model, remote from linopy.expressions import merge from linopy.io import read_netcdf from linopy.model import LinearExpression, Model, Variable, available_solvers from linopy.remote import RemoteHandler from linopy.version import version as __version__ - -# Note: For intercepting multiplications between xarray dataarrays, Variables and Expressions -# we need to extend their __mul__ functions with a quick special case -import linopy.monkey_patch_xarray \ No newline at end of file diff --git a/linopy/monkey_patch_xarray.py b/linopy/monkey_patch_xarray.py index b485adee..824a6045 100644 --- a/linopy/monkey_patch_xarray.py +++ b/linopy/monkey_patch_xarray.py @@ -1,8 +1,9 @@ from functools import partialmethod, update_wrapper + from xarray import DataArray -from .variables import Variable -from .expressions import LinearExpression +from linopy.expressions import LinearExpression +from linopy.variables import Variable def monkey_patch(cls, pass_unpatched_method=False): From 0330959324fa13fdca6b5220fb804ca7e7946973 Mon Sep 17 00:00:00 2001 From: Fabian Date: Fri, 9 Dec 2022 12:08:35 +0100 Subject: [PATCH 21/24] fix circular imports --- linopy/constraints.py | 2 +- linopy/monkey_patch_xarray.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index 9053e218..19013a1b 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -679,7 +679,7 @@ class AnonymousScalarConstraint: (rhs) for exactly one constraint. """ - lhs: expressions.ScalarLinearExpression + lhs: "expressions.ScalarLinearExpression" sign: str rhs: float diff --git a/linopy/monkey_patch_xarray.py b/linopy/monkey_patch_xarray.py index 824a6045..b717f876 100644 --- a/linopy/monkey_patch_xarray.py +++ b/linopy/monkey_patch_xarray.py @@ -2,8 +2,7 @@ from xarray import DataArray -from linopy.expressions import LinearExpression -from linopy.variables import Variable +from linopy import expressions, variables def monkey_patch(cls, pass_unpatched_method=False): @@ -21,6 +20,6 @@ def deco(func): @monkey_patch(DataArray, pass_unpatched_method=True) def __mul__(da, other, unpatched_method): - if isinstance(other, (Variable, LinearExpression)): + if isinstance(other, (variables.Variable, expressions.LinearExpression)): return NotImplemented return unpatched_method(da, other) From d12a5bbf604e2d23004f9cf87bf4bd4153c331bc Mon Sep 17 00:00:00 2001 From: Fabian Date: Fri, 9 Dec 2022 12:33:19 +0100 Subject: [PATCH 22/24] fix imports for xarray v2022.12. --- linopy/expressions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/linopy/expressions.py b/linopy/expressions.py index daf780ba..cfc76e9f 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -15,6 +15,8 @@ import numpy as np import pandas as pd import xarray as xr +import xarray.core.groupby +import xarray.core.rolling from deprecation import deprecated from numpy import array, nan from xarray import DataArray, Dataset From de43928e98aa2b55928794f21051eecde4162110 Mon Sep 17 00:00:00 2001 From: Fabian Date: Fri, 9 Dec 2022 12:44:00 +0100 Subject: [PATCH 23/24] update release notes --- doc/release_notes.rst | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/doc/release_notes.rst b/doc/release_notes.rst index f4444d0c..bd3c8b3f 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -1,8 +1,11 @@ Release Notes ============= -.. Upcoming Release -.. ---------------- +Upcoming Release +---------------- + +* The internal data structure of linopy classes were updated to a safer design. Instead of being defined as inherited xarray classes, the class `Variable`, `LinearExpression` and `Constraint` are now dataclasses with containing the xarray objects in the data field. This allows the package to have more flexible function design and a reduced set of wrapped functions that are sensible to use in the optimization context. +* The class `Variable` and `LinearExpression` have new functions `groupby` and `rolling` imitating the corresponding xarray functions but with safe type inheritance and application of appended operations. Version 0.0.15 From 749b673083571a693648d621996a003ef710fb58 Mon Sep 17 00:00:00 2001 From: Fabian Date: Fri, 9 Dec 2022 14:06:57 +0100 Subject: [PATCH 24/24] test: increase coverage --- test/test_constraint.py | 19 +++++++++++++++++++ test/test_variable.py | 2 ++ 2 files changed, 21 insertions(+) diff --git a/test/test_constraint.py b/test/test_constraint.py index 8509e5f2..61e6ec95 100644 --- a/test/test_constraint.py +++ b/test/test_constraint.py @@ -105,6 +105,25 @@ def test_constraint_accessor(): assert c.vars.notnull().all().item() assert c.coeffs.notnull().all().item() + # Test that assigning labels raises RuntimeError + with pytest.raises(RuntimeError): + c.labels = c.labels + + # Test that assigning lhs with other type that LinearExpression raises TypeError + with pytest.raises(TypeError): + c.lhs = x + + # Test that assigning lhs with other type that LinearExpression raises TypeError + with pytest.raises(ValueError): + c.sign = "==" + + # Test that assigning a variable or linear expression to the rhs property raises a TypeError + with pytest.raises(TypeError): + c.rhs = x + + with pytest.raises(TypeError): + c.rhs = x + y + def test_constraint_accessor_M(): m = Model() diff --git a/test/test_variable.py b/test/test_variable.py index 16db37a8..4e8ae4fa 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -30,6 +30,8 @@ def test_variable_getter(): with pytest.raises(AssertionError): x[[1, 2, 3]] + assert isinstance(x.loc[[1, 2, 3]], linopy.Variable) + def test_variable_repr(): m = Model()