Skip to content

Commit

Permalink
revert adding deepcopy override. Add test for nested deepcopy
Browse files Browse the repository at this point in the history
  • Loading branch information
allenanie committed Jan 18, 2025
1 parent 9e2db27 commit 85a4528
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 15 deletions.
12 changes: 1 addition & 11 deletions opto/trace/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,4 @@ def _set(self, new_parameters):
parameters_dict[k]._set(v)
else: # if the parameter does not exist
assert k not in self.__dict__
setattr(self, k, v)

def __deepcopy__(self, memo):
""" Custom deepcopy behavior for Module. """
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if '__TRACE_RESERVED_' not in k:
setattr(result, k, copy.deepcopy(v, memo))
return result
setattr(self, k, v)
52 changes: 48 additions & 4 deletions tests/unit_tests/test_class_method.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from opto import trace
from copy import deepcopy
from copy import deepcopy, copy

@trace.model
class Model:
Expand Down Expand Up @@ -98,8 +98,52 @@ def test_case_model_copy():
assert len(y1.parents) == 3 # since it's trainable
assert len(y2.parents) == 3

def printout_deecopy_modules():
pass
def test_case_model_nested_copy():
m1 = Model()
m3 = deepcopy(m1)
m2 = deepcopy(m3)

# Make sure the parameters are different
try:
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node
except AttributeError:
# These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed.
pass

# The hidden nodes are defined now
assert len(m1.parameters()) == 1
assert len(m2.parameters()) == 1

# Make sure the parameters are different
assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now
assert m1.parameters()[0] is not m2.parameters()[0]

# check that the reserved node is the returned parameter
assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0]
assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0]

# each instance has a version different from the class' version
assert m1.forward is not m2.forward
assert m1.forward is not Model.forward
assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter

y1 = m1.forward(1)
y2 = m2.forward(2)

from opto.trace.utils import contain
# self is not duplicated
assert contain(y1.parents, m1.__TRACE_RESERVED_self_node)
assert contain(y2.parents, m2.__TRACE_RESERVED_self_node)
# assert m1.__TRACE_RESERVED_self_node in y1.parents
# assert m1.__TRACE_RESERVED_self_node in y2.parents
assert contain(y1.parents, m1.forward.parameter)
assert contain(y2.parents, m2.forward.parameter)

# assert m1.forward.parameter in y1.parents
# assert m1.forward.parameter in y2.parents
assert len(y1.parents) == 3 # since it's trainable
assert len(y2.parents) == 3

test_case_two_models()
test_case_model_copy()
test_case_model_copy()
test_case_model_nested_copy()

0 comments on commit 85a4528

Please sign in to comment.