Skip to content

Commit

Permalink
Merge pull request #246 from leuchtum/change_dims_dim_keyword
Browse files Browse the repository at this point in the history
Renamed kw `dims` of `LinearExpression.sum` & `Variable.sum` to `dim`
  • Loading branch information
FabianHofmann authored May 3, 2024
2 parents f182b7f + a17a0aa commit e69787b
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 21 deletions.
4 changes: 4 additions & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Upcoming Version

* Linopy now supports python 3.12.

**Deprecations**

* The argument `dims` in the `.sum` function of variables and expressions was deprecated in favor of the `dim` argument. This aligns the argument name with the xarray convention.

Version 0.3.8
-------------

Expand Down
6 changes: 3 additions & 3 deletions examples/transport-tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@
"metadata": {},
"outputs": [],
"source": [
"x.sum(dims=\"Markets\") <= a"
"x.sum(dim=\"Markets\") <= a"
]
},
{
Expand All @@ -320,10 +320,10 @@
"# demand(j) satisfy demand at market j ;\n",
"# demand(j) .. sum(i, x(i,j)) =g= b(j);\n",
"\n",
"con = x.sum(dims=\"Markets\") <= a\n",
"con = x.sum(dim=\"Markets\") <= a\n",
"con1 = m.add_constraints(con, name=\"Observe supply limit at plant i\")\n",
"\n",
"con = x.sum(dims=\"Canning Plants\") >= b\n",
"con = x.sum(dim=\"Canning Plants\") >= b\n",
"con2 = m.add_constraints(con, name=\"Satisfy demand at market j\")"
]
},
Expand Down
40 changes: 26 additions & 14 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dataclasses import dataclass, field
from itertools import product, zip_longest
from typing import Any, Mapping, Optional, Union
from warnings import warn

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -187,7 +188,7 @@ def sum(self, use_fallback=False, **kwargs):
coords = Coordinates.from_pandas_multiindex(idx, group_dim)
ds = self.data.assign_coords(coords)
ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value)
ds = LinearExpression._sum(ds, dims=GROUPED_TERM_DIM)
ds = LinearExpression._sum(ds, dim=GROUPED_TERM_DIM)

if int_map is not None:
index = ds.indexes["group"].map({v: k for k, v in int_map.items()})
Expand Down Expand Up @@ -276,7 +277,7 @@ class LinearExpression:
Summation over dimensions
>>> type(expr.sum(dims="dim_0"))
>>> type(expr.sum(dim="dim_0"))
<class 'linopy.expressions.LinearExpression'>
"""

Expand Down Expand Up @@ -671,44 +672,55 @@ def solution(self):
return sol.rename("solution")

@classmethod
def _sum(cls, expr: Union["LinearExpression", Dataset], dims=None) -> Dataset:
def _sum(cls, expr: Union["LinearExpression", Dataset], dim=None) -> Dataset:
data = _expr_unwrap(expr)

if dims is None:
if dim is None:
vars = DataArray(data.vars.data.ravel(), dims=TERM_DIM)
coeffs = DataArray(data.coeffs.data.ravel(), dims=TERM_DIM)
const = data.const.sum()
ds = xr.Dataset({"vars": vars, "coeffs": coeffs, "const": const})
else:
dims = [d for d in np.atleast_1d(dims) if d != TERM_DIM]
dim = [d for d in np.atleast_1d(dim) if d != TERM_DIM]
ds = (
data[["coeffs", "vars"]]
.reset_index(dims, drop=True)
.reset_index(dim, drop=True)
.rename({TERM_DIM: STACKED_TERM_DIM})
.stack({TERM_DIM: [STACKED_TERM_DIM] + dims}, create_index=False)
.stack({TERM_DIM: [STACKED_TERM_DIM] + dim}, create_index=False)
)
ds["const"] = data.const.sum(dims)
ds["const"] = data.const.sum(dim)

return ds

def sum(self, dims=None, drop_zeros=False) -> "LinearExpression":
def sum(self, dim=None, drop_zeros=False, **kwargs) -> "LinearExpression":
"""
Sum the expression over all or a subset of dimensions.
This stack all terms of the dimensions, that are summed over, together.
Parameters
----------
dims : str/list, optional
dim : str/list, optional
Dimension(s) to sum over. The default is None which results in all
dimensions.
dims : str/list, optional
Deprecated. Use ``dim`` instead.
Returns
-------
linopy.LinearExpression
Summed expression.
"""
res = self.__class__(self._sum(self, dims=dims), self.model)
if dim is None and "dims" in kwargs:
dim = kwargs.pop("dims")
warn(
"The `dims` argument in `.sum` is deprecated. Use `dim` instead.",
DeprecationWarning,
)
if kwargs:
raise ValueError(f"Unknown keyword argument(s): {kwargs}")

res = self.__class__(self._sum(self, dim=dim), self.model)

if drop_zeros:
res = res.densify_terms()
Expand Down Expand Up @@ -1379,10 +1391,10 @@ def solution(self):
return sol.rename("solution")

@classmethod
def _sum(cls, expr: "QuadraticExpression", dims=None) -> Dataset:
def _sum(cls, expr: "QuadraticExpression", dim=None) -> Dataset:
data = _expr_unwrap(expr)
dims = dims or list(set(data.dims) - set(HELPER_DIMS))
return LinearExpression._sum(expr, dims)
dim = dim or list(set(data.dims) - set(HELPER_DIMS))
return LinearExpression._sum(expr, dim)

def to_constraint(self, sign, rhs):
raise NotImplementedError(
Expand Down
17 changes: 14 additions & 3 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def mask_func(data):
)
return df

def sum(self, dims=None):
def sum(self, dim=None, **kwargs):
"""
Sum the variables over all or a subset of dimensions.
Expand All @@ -735,16 +735,27 @@ def sum(self, dims=None):
Parameters
----------
dims : str/list, optional
dim : str/list, optional
Dimension(s) to sum over. The default is None which results in all
dimensions.
dims : str/list, optional
Deprecated. Use ``dim`` instead.
Returns
-------
linopy.LinearExpression
Summed expression.
"""
return self.to_linexpr().sum(dims)
if dim is None and "dims" in kwargs:
dim = kwargs.pop("dims")
warn(
"The `dims` argument is deprecated. Use `dim` instead.",
DeprecationWarning,
)
if kwargs:
raise ValueError(f"Unknown keyword argument(s): {kwargs}")

return self.to_linexpr().sum(dim)

def diff(self, dim, n=1):
"""
Expand Down
10 changes: 10 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,16 @@ def test_linear_expression_sum_drop_zeros(z):
assert res.nterm == 2


def test_linear_expression_sum_warn_using_dims(z):
with pytest.warns(DeprecationWarning):
(1 * z).sum(dims="dim_0")


def test_linear_expression_sum_warn_unknown_kwargs(z):
with pytest.raises(ValueError):
(1 * z).sum(unknown_kwarg="dim_0")


def test_linear_expression_multiplication(x, y, z):
expr = 10 * x + y + z
mexpr = expr * 10
Expand Down
12 changes: 11 additions & 1 deletion test/test_quadratic_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_quadratic_expression_rsubtraction(x, y):
def test_quadratic_expression_sum(x, y):
expr = x * y + x + 5

summed_expr = expr.sum(dims="dim_0")
summed_expr = expr.sum(dim="dim_0")
assert isinstance(summed_expr, QuadraticExpression)
assert not summed_expr.coord_dims

Expand All @@ -158,6 +158,16 @@ def test_quadratic_expression_sum(x, y):
assert not summed_expr_all.coord_dims


def test_quadratic_expression_sum_warn_using_dims(x):
with pytest.warns(DeprecationWarning):
(x**2).sum(dims="dim_0")


def test_quadratic_expression_sum_warn_unknown_kwargs(x):
with pytest.raises(ValueError):
(x**2).sum(unknown_kwarg="dim_0")


def test_quadratic_expression_wrong_multiplication(x, y):
with pytest.raises(TypeError):
x * x * y
Expand Down
10 changes: 10 additions & 0 deletions test/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ def test_variable_sum(x):
assert res.nterm == 10


def test_variable_sum_warn_using_dims(x):
with pytest.warns(DeprecationWarning):
x.sum(dims="first")


def test_variable_sum_warn_unknown_kwargs(x):
with pytest.raises(ValueError):
x.sum(unknown_kwarg="first")


def test_variable_where(x):
x = x.where([True] * 4 + [False] * 6)
assert isinstance(x, linopy.variables.Variable)
Expand Down

0 comments on commit e69787b

Please sign in to comment.