Skip to content

Commit

Permalink
fix #57
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianHofmann committed Nov 9, 2022
1 parent a7e6a29 commit 11f180d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 5 deletions.
25 changes: 25 additions & 0 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ def __mul__(self, other):
"""
Multiply the expr by a factor.
"""
if isinstance(other, (LinearExpression, variables.Variable)):
raise TypeError(
"unsupported operand type(s) for *: "
f"{type(self)} and {type(other)}. "
"Non-linear expressions are not yet supported."
)
coeffs = other * self.coeffs
assert coeffs.shape == self.coeffs.shape
return LinearExpression(self.assign(coeffs=coeffs))
Expand All @@ -176,6 +182,18 @@ def __rmul__(self, other):
"""
return self.__mul__(other)

def __div__(self, other):
if isinstance(other, (LinearExpression, variables.Variable)):
raise TypeError(
"unsupported operand type(s) for /: "
f"{type(self)} and {type(other)}"
"Non-linear expressions are not yet supported."
)
return self.__mul__(1 / other)

def __truediv__(self, other):
return self.__div__(other)

def __le__(self, rhs):
return constraints.AnonymousConstraint(self, "<=", rhs)

Expand Down Expand Up @@ -672,8 +690,15 @@ def __rmul__(self, other):
return self.__mul__(other)

def __div__(self, other):
if not isinstance(other, (int, np.integer, float)):
raise TypeError(
"unsupported operand type(s) for /: " f"{type(self)} and {type(other)}"
)
return self.__mul__(1 / other)

def __truediv__(self, other):
return self.__div__(other)

def __le__(self, other):
if not isinstance(other, (int, np.integer, float)):
raise TypeError(
Expand Down
39 changes: 34 additions & 5 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def to_linexpr(self, coefficient=1):
"""
Create a linear exprssion from the variables.
"""
if isinstance(coefficient, (expressions.LinearExpression, Variable)):
raise TypeError(f"unsupported type of coefficient: {type(coefficient)}")
return expressions.LinearExpression.from_tuples((coefficient, self))

def __repr__(self):
Expand Down Expand Up @@ -170,17 +172,41 @@ def __neg__(self):
"""
return self.to_linexpr(-1)

def __mul__(self, coefficient):
def __mul__(self, other):
"""
Multiply variables with a coefficient.
"""
return self.to_linexpr(coefficient)
if isinstance(other, (expressions.LinearExpression, Variable)):
raise TypeError(
"unsupported operand type(s) for *: "
f"{type(self)} and {type(other)}. "
"Non-linear expressions are not yet supported."
)
return self.to_linexpr(other)

def __rmul__(self, coefficient):
def __rmul__(self, other):
"""
Right-multiply variables with a coefficient.
"""
return self.to_linexpr(coefficient)
return self.to_linexpr(other)

def __div__(self, other):
"""
Divide variables with a coefficient.
"""
if isinstance(other, (expressions.LinearExpression, Variable)):
raise TypeError(
"unsupported operand type(s) for /: "
f"{type(self)} and {type(other)}. "
"Non-linear expressions are not yet supported."
)
return self.to_linexpr(1 / other)

def __truediv__(self, coefficient):
"""
True divide variables with a coefficient.
"""
return self.__div__(coefficient)

def __add__(self, other):
"""
Expand Down Expand Up @@ -635,7 +661,7 @@ class ScalarVariable:
coords: dict = None

def to_scalar_linexpr(self, coeff=1):
if not isinstance(coeff, (int, float)):
if not isinstance(coeff, (int, np.integer, float)):
raise TypeError(f"Coefficient must be a numeric value, got {type(coeff)}.")
return expressions.ScalarLinearExpression((coeff,), (self.label,))

Expand All @@ -660,6 +686,9 @@ def __rmul__(self, coeff):
def __div__(self, coeff):
return self.to_scalar_linexpr(1 / coeff)

def __truediv__(self, coeff):
return self.to_scalar_linexpr(1 / coeff)

def __le__(self, other):
return self.to_scalar_linexpr(1).__le__(other)

Expand Down
32 changes: 32 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def test_variable_to_linexpr():
expr = x * 1
assert isinstance(expr, LinearExpression)

expr = x / 1
assert isinstance(expr, LinearExpression)

expr = x / 1.0
assert isinstance(expr, LinearExpression)

expr = 10 * x + y
assert isinstance(expr, LinearExpression)
assert_equal(expr, m.linexpr((10, "x"), (1, "y")))
Expand Down Expand Up @@ -94,6 +100,18 @@ def test_variable_to_linexpr():
with pytest.raises(TypeError):
x - 10

with pytest.raises(TypeError):
x * x

with pytest.raises(TypeError):
x / x

with pytest.raises(TypeError):
x * (1 * x)

with pytest.raises(TypeError):
x / (1 * x)


def test_variable_to_linexpr_with_array():
expr = np.arange(20) * v
Expand Down Expand Up @@ -261,6 +279,20 @@ def test_mul():
mexpr = 10 * expr
assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 100).item()

mexpr = expr / 100
assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 1 / 10).item()

mexpr = expr / 100.0
assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 1 / 10).item()

with pytest.raises(TypeError):
expr = 10 * x + y + z
expr * x

with pytest.raises(TypeError):
expr = 10 * x + y + z
expr / x


def test_sanitize():
expr = 10 * x + y + z
Expand Down

0 comments on commit 11f180d

Please sign in to comment.