From 85a4528b0b518ca76c0c9b74090ddcdbaa176b6b Mon Sep 17 00:00:00 2001 From: windweller Date: Sat, 18 Jan 2025 11:45:48 -0800 Subject: [PATCH] revert adding deepcopy override. Add test for nested deepcopy --- opto/trace/modules.py | 12 +------ tests/unit_tests/test_class_method.py | 52 ++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/opto/trace/modules.py b/opto/trace/modules.py index 36b0d912..ae7a8913 100644 --- a/opto/trace/modules.py +++ b/opto/trace/modules.py @@ -61,14 +61,4 @@ def _set(self, new_parameters): parameters_dict[k]._set(v) else: # if the parameter does not exist assert k not in self.__dict__ - setattr(self, k, v) - - def __deepcopy__(self, memo): - """ Custom deepcopy behavior for Module. """ - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if '__TRACE_RESERVED_' not in k: - setattr(result, k, copy.deepcopy(v, memo)) - return result \ No newline at end of file + setattr(self, k, v) \ No newline at end of file diff --git a/tests/unit_tests/test_class_method.py b/tests/unit_tests/test_class_method.py index 33748f10..2bf62600 100644 --- a/tests/unit_tests/test_class_method.py +++ b/tests/unit_tests/test_class_method.py @@ -1,5 +1,5 @@ from opto import trace -from copy import deepcopy +from copy import deepcopy, copy @trace.model class Model: @@ -98,8 +98,52 @@ def test_case_model_copy(): assert len(y1.parents) == 3 # since it's trainable assert len(y2.parents) == 3 -def printout_deecopy_modules(): - pass +def test_case_model_nested_copy(): + m1 = Model() + m3 = deepcopy(m1) + m2 = deepcopy(m3) + + # Make sure the parameters are different + try: + assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node + except AttributeError: + # These secrets attributes are not defined yet. They will only be defined after the bundled method is accessed. + pass + + # The hidden nodes are defined now + assert len(m1.parameters()) == 1 + assert len(m2.parameters()) == 1 + + # Make sure the parameters are different + assert m1.__TRACE_RESERVED_self_node is not m2.__TRACE_RESERVED_self_node # they are defined now + assert m1.parameters()[0] is not m2.parameters()[0] + + # check that the reserved node is the returned parameter + assert getattr(m1, '__TRACE_RESERVED_bundle_Model.forward').parameter is m1.parameters()[0] + assert getattr(m2, '__TRACE_RESERVED_bundle_Model.forward').parameter is m2.parameters()[0] + + # each instance has a version different from the class' version + assert m1.forward is not m2.forward + assert m1.forward is not Model.forward + assert m2.forward.parameter == Model.forward.parameter == m1.forward.parameter + + y1 = m1.forward(1) + y2 = m2.forward(2) + + from opto.trace.utils import contain + # self is not duplicated + assert contain(y1.parents, m1.__TRACE_RESERVED_self_node) + assert contain(y2.parents, m2.__TRACE_RESERVED_self_node) + # assert m1.__TRACE_RESERVED_self_node in y1.parents + # assert m1.__TRACE_RESERVED_self_node in y2.parents + assert contain(y1.parents, m1.forward.parameter) + assert contain(y2.parents, m2.forward.parameter) + + # assert m1.forward.parameter in y1.parents + # assert m1.forward.parameter in y2.parents + assert len(y1.parents) == 3 # since it's trainable + assert len(y2.parents) == 3 test_case_two_models() -test_case_model_copy() \ No newline at end of file +test_case_model_copy() +test_case_model_nested_copy() \ No newline at end of file