Skip to content

Commit

Permalink
Changed torch to torch_expression.
Browse files Browse the repository at this point in the history
  • Loading branch information
AmitSolomonPrinceton committed Sep 6, 2024
1 parent 554a9a6 commit febf20a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 45 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exp = x-y+2*z
tch_x = torch.arange(1, n+1)
tch_y = torch.arange(0, n)

tch_exp = TorchExpression(exp).tch_exp #tch_exp implements x-y+2*z, where x and y are torch.Tensor.
tch_exp = TorchExpression(exp).torch_expression #tch_exp implements x-y+2*z, where x and y are torch.Tensor.
tch_res = tch_exp(tch_x, tch_y) #Contains a torch.Tensor [7.0]*n
```

Expand Down
24 changes: 12 additions & 12 deletions cvxtorch/torch_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,21 +83,21 @@ class TorchExpression():
this dtype.
"""
@property
def torch(self):
return self._torch
def torch_expression(self):
return self._torch_expression

@property
def variables_dictionary(self):
return self._variables_dictionary

def __init__(self, expr: Expression | Constraint, provided_vars_list:list = [],
implemented_only: bool=True, dtype: torch.dtype = torch.float64):
implemented_only: bool=True, dtype: torch_expression.dtype = torch_expression.float64):
self.implemented_only = implemented_only
self._torch, self._variables_dictionary = self._gen_torch_exp(expr=expr,
self._torch_expression, self._variables_dictionary = self._gen_torch_exp(expr=expr,
provided_vars_list=provided_vars_list, dtype=dtype)

def _gen_torch_exp(self, expr, provided_vars_list: list = [],
dtype: torch.dtype = torch.float64) -> tuple[callable, VariablesDict]:
dtype: torch_expression.dtype = torch_expression.float64) -> tuple[callable, VariablesDict]:
"""
This is a helper function selects the correct gen_torch_exp based on the type of expr.
Expand All @@ -124,7 +124,7 @@ def _gen_torch_exp(self, expr, provided_vars_list: list = [],
raise ValueError(f"Unsupported expression type: {type(expr)}.")

def _gen_torch_exp_expr(self, expr: Expression, provided_vars_list: list = [],
dtype: torch.dtype = torch.float64) -> tuple[callable, VariablesDict]:
dtype: torch_expression.dtype = torch_expression.float64) -> tuple[callable, VariablesDict]:
"""
This is a helper function that generates a torch expression for an Expression.
Expand Down Expand Up @@ -287,45 +287,45 @@ def inner(self, expr, provided_vars_list: list = [], dtype: torch.dtype = torch

@_gen_torch_exp_dec
def _gen_torch_exp_leaf(self, expr: Leaf, provided_vars_list: list = [],
dtype: torch.dtype = torch.float64) -> tuple[callable, VariablesDict]:
dtype: torch_expression.dtype = torch_expression.float64) -> tuple[callable, VariablesDict]:
"""
This is a helper function that generates a torch expression for a leaf.
"""
return AddExpression([expr]) #This is an easy way to convert a leaf into an expression.

@_gen_torch_exp_dec
def _gen_torch_exp_constraint(self, expr: Constraint, provided_vars_list: list = [],
dtype: torch.dtype = torch.float64) -> tuple[callable, VariablesDict]:
dtype: torch_expression.dtype = torch_expression.float64) -> tuple[callable, VariablesDict]:
""" This function generates a torch expression (args[0]-args[1]).
The order of the arguments is as it appears in args[0]-args[1] (from left to right)
"""
return expr.args[0]-expr.args[1]

@_gen_torch_exp_dec
def _gen_torch_exp_nonpos(self, expr: NonPos, provided_vars_list: list = [],
dtype: torch.dtype = torch.float64) -> tuple[callable, VariablesDict]:
dtype: torch_expression.dtype = torch_expression.float64) -> tuple[callable, VariablesDict]:
"""
This is a helper function that generates a torch expression for a NonPos constraint.
"""
return expr.args[0]<=0

@_gen_torch_exp_dec
def _gen_torch_exp_nonneg(self, expr: NonNeg, provided_vars_list: list = [],
dtype: torch.dtype = torch.float64) -> tuple[callable, VariablesDict]:
dtype: torch_expression.dtype = torch_expression.float64) -> tuple[callable, VariablesDict]:
"""
This is a helper function that generates a torch expression for a NonNeg constraint.
"""
return expr.args[0]>=0

@_gen_torch_exp_dec
def _gen_torch_exp_zero(self, expr: Zero, provided_vars_list: list = [],
dtype: torch.dtype = torch.float64) -> tuple[callable, VariablesDict]:
dtype: torch_expression.dtype = torch_expression.float64) -> tuple[callable, VariablesDict]:
"""
This is a helper function that generates a torch expression for a Zero constraint.
"""
return expr.args[0]==0

def apply_torch_numeric(self, expr: Expression, values: list[torch.Tensor]) -> torch.Tensor:
def apply_torch_numeric(self, expr: Expression, values: list[torch_expression.Tensor]) -> torch_expression.Tensor:
"""
This function returns self.torch_numeric(values) if it exists,
and self.numeric(values) otherwise.
Expand Down
64 changes: 32 additions & 32 deletions tests/test_gen_torch_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def test_exp(self):
exp5 = w-x
exp6 = X@Y.T

torch_exp1 = TorchExpression(exp1).torch
torch_exp2 = TorchExpression(exp2).torch
torch_exp3 = TorchExpression(exp3).torch
torch_exp4 = TorchExpression(exp4).torch
torch_exp5 = TorchExpression(exp5).torch
torch_exp6 = TorchExpression(exp6).torch
torch_exp1 = TorchExpression(exp1).torch_expression
torch_exp2 = TorchExpression(exp2).torch_expression
torch_exp3 = TorchExpression(exp3).torch_expression
torch_exp4 = TorchExpression(exp4).torch_expression
torch_exp5 = TorchExpression(exp5).torch_expression
torch_exp6 = TorchExpression(exp6).torch_expression

test1 = torch_exp1(5*torch.ones(n), torch.tensor([1.,2.,3.]))
test2 = torch_exp2(1*torch.ones(n), torch.tensor([1.,2.,3.]))
Expand Down Expand Up @@ -82,13 +82,13 @@ def test_constraint(self) -> None:
constraint7 = X@Y.T <= 0
X@Y.T

exp1 = TorchExpression(constraint1).torch
exp2 = TorchExpression(constraint2).torch
exp3 = TorchExpression(constraint3).torch
exp4 = TorchExpression(constraint4).torch
exp5 = TorchExpression(constraint5).torch
exp6 = TorchExpression(constraint6).torch
exp7 = TorchExpression(constraint7).torch
exp1 = TorchExpression(constraint1).torch_expression
exp2 = TorchExpression(constraint2).torch_expression
exp3 = TorchExpression(constraint3).torch_expression
exp4 = TorchExpression(constraint4).torch_expression
exp5 = TorchExpression(constraint5).torch_expression
exp6 = TorchExpression(constraint6).torch_expression
exp7 = TorchExpression(constraint7).torch_expression

x_test = torch.tensor([1,2,3], dtype=float)
z_test = torch.zeros(m, dtype=float)
Expand Down Expand Up @@ -145,18 +145,18 @@ def test_gen_torch_exp(self):
exp10 = self.c
exp11 = self.x+2*self.w+3*self.c

torch_exp1 = TorchExpression(exp1).torch
torch_exp2 = TorchExpression(exp2).torch
torch_exp3 = TorchExpression(exp3).torch
torch_exp4 = TorchExpression(exp4).torch
torch_exp5 = TorchExpression(exp5).torch
torch_exp6 = TorchExpression(exp6).torch
torch_exp7 = TorchExpression(exp7).torch
torch_exp8 = TorchExpression(exp8).torch
torch_exp9 = TorchExpression(exp9).torch
torch_exp10 = TorchExpression(exp10).torch
torch_exp11_unordered = TorchExpression(exp11).torch
torch_exp11 = TorchExpression(exp11, provided_vars_list=[self.w, self.x]).torch
torch_exp1 = TorchExpression(exp1).torch_expression
torch_exp2 = TorchExpression(exp2).torch_expression
torch_exp3 = TorchExpression(exp3).torch_expression
torch_exp4 = TorchExpression(exp4).torch_expression
torch_exp5 = TorchExpression(exp5).torch_expression
torch_exp6 = TorchExpression(exp6).torch_expression
torch_exp7 = TorchExpression(exp7).torch_expression
torch_exp8 = TorchExpression(exp8).torch_expression
torch_exp9 = TorchExpression(exp9).torch_expression
torch_exp10 = TorchExpression(exp10).torch_expression
torch_exp11_unordered = TorchExpression(exp11).torch_expression
torch_exp11 = TorchExpression(exp11, provided_vars_list=[self.w, self.x]).torch_expression

test1 = torch_exp1(5*torch.ones(self.n, dtype=torch.float64),
torch.tensor([1.,2.,3.], dtype=torch.float64))
Expand Down Expand Up @@ -202,8 +202,8 @@ def setUp(self) -> None:
self.t2_exp = (self.a*self.t1+self.b*self.t2+self.c.value).float()

def test_nonpos(self) -> None:
tch_exp1 = TorchExpression(NonPos(self.exp1)).torch
tch_exp2 = TorchExpression(NonPos(self.exp2), dtype=torch.float32).torch
tch_exp1 = TorchExpression(NonPos(self.exp1)).torch_expression
tch_exp2 = TorchExpression(NonPos(self.exp2), dtype=torch.float32).torch_expression

test1 = tch_exp1(self.t1)
test2 = tch_exp2(self.t1, self.t2)
Expand All @@ -212,8 +212,8 @@ def test_nonpos(self) -> None:
torch.testing.assert_close(test2, self.t2_exp)

def test_nonneg(self) -> None:
tch_exp1 = TorchExpression(NonNeg(self.exp1)).torch
tch_exp2 = TorchExpression(NonNeg(self.exp2), dtype=torch.float32).torch
tch_exp1 = TorchExpression(NonNeg(self.exp1)).torch_expression
tch_exp2 = TorchExpression(NonNeg(self.exp2), dtype=torch.float32).torch_expression

test1 = tch_exp1(self.t1)
test2 = tch_exp2(self.t1, self.t2)
Expand All @@ -222,8 +222,8 @@ def test_nonneg(self) -> None:
torch.testing.assert_close(test2, -self.t2_exp)

def test_zero(self) -> None:
tch_exp1 = TorchExpression(Zero(self.exp1)).torch
tch_exp2 = TorchExpression(Zero(self.exp2), dtype=torch.float32).torch
tch_exp1 = TorchExpression(Zero(self.exp1)).torch_expression
tch_exp2 = TorchExpression(Zero(self.exp2), dtype=torch.float32).torch_expression

test1 = tch_exp1(self.t1)
test2 = tch_exp2(self.t1, self.t2)
Expand All @@ -239,7 +239,7 @@ def setUp(self) -> None:
def test_dtypes(self):
for dtype in [torch.float64, torch.float32, torch.int64, torch.int32, torch.int16,\
torch.int8]:
tch_exp = TorchExpression(self.c, dtype=dtype).torch
tch_exp = TorchExpression(self.c, dtype=dtype).torch_expression
test = tch_exp()
self.assertTrue(torch.all(test==torch.Tensor([self.n])).all())
self.assertTrue(test.dtype==dtype)

0 comments on commit febf20a

Please sign in to comment.