From 7c8fa4a7594fe0c333ce6180adecd50f6d1f43b9 Mon Sep 17 00:00:00 2001 From: jesszzzz Date: Thu, 16 Jan 2025 12:25:17 -0500 Subject: [PATCH] Fix bug when deep copying full config with missing parent (#3009) --- hydra/_internal/instantiate/_instantiate2.py | 25 +++++++++++++------- tests/instantiate/__init__.py | 14 +++++++++++ tests/instantiate/test_instantiate.py | 25 ++++++++++++++++++++ 3 files changed, 55 insertions(+), 9 deletions(-) diff --git a/hydra/_internal/instantiate/_instantiate2.py b/hydra/_internal/instantiate/_instantiate2.py index c5adb7e110..bc27918274 100644 --- a/hydra/_internal/instantiate/_instantiate2.py +++ b/hydra/_internal/instantiate/_instantiate2.py @@ -151,17 +151,24 @@ def _deep_copy_full_config(subconfig: Any) -> Any: return copy.deepcopy(subconfig) full_key = subconfig._get_full_key(None) - if full_key: - full_config_copy = copy.deepcopy(subconfig._get_root()) - if OmegaConf.is_list(subconfig._get_parent()): - # OmegaConf has a bug where _get_full_key doesn't add [] if the parent - # is a list, eg. instead of foo[0], it'll return foo0 - index = subconfig._key() - full_key = full_key[: -len(str(index))] + f"[{index}]" - return OmegaConf.select(full_config_copy, full_key) - else: + if not full_key: + return copy.deepcopy(subconfig) + + if OmegaConf.is_list(subconfig._get_parent()): + # OmegaConf has a bug where _get_full_key doesn't add [] if the parent + # is a list, eg. instead of foo[0], it'll return foo0 + index = subconfig._key() + full_key = full_key[: -len(str(index))] + f"[{index}]" + root = subconfig._get_root() + full_key = full_key.replace(root._get_full_key(None) or "", "", 1) + if OmegaConf.select(root, full_key) is not subconfig: + # The parent chain and full key are not consistent so don't + # try to copy the full config return copy.deepcopy(subconfig) + full_config_copy = copy.deepcopy(root) + return OmegaConf.select(full_config_copy, full_key) + def instantiate( config: Any, diff --git a/tests/instantiate/__init__.py b/tests/instantiate/__init__.py index 632b27696f..e4afaec733 100644 --- a/tests/instantiate/__init__.py +++ b/tests/instantiate/__init__.py @@ -8,6 +8,7 @@ from omegaconf import MISSING, DictConfig, ListConfig from hydra.types import TargetConf +from hydra.utils import instantiate from tests.instantiate.module_shadowed_by_function import a_function module_shadowed_by_function = a_function @@ -418,6 +419,19 @@ class NestedConf: b: Any = field(default_factory=lambda: User(name="b", age=2)) +class TargetWithInstantiateInInit: + def __init__( + self, user_config: Optional[DictConfig], user: Optional[User] = None + ) -> None: + if user: + self.user = user + else: + self.user = instantiate(user_config) + + def __eq__(self, other: Any) -> bool: + return self.user.__eq__(other.user) + + def recisinstance(got: Any, expected: Any) -> bool: """Compare got with expected type, recursively on dict and list.""" if not isinstance(got, type(expected)): diff --git a/tests/instantiate/test_instantiate.py b/tests/instantiate/test_instantiate.py index 10a67b1c14..6a89979074 100644 --- a/tests/instantiate/test_instantiate.py +++ b/tests/instantiate/test_instantiate.py @@ -43,6 +43,7 @@ SimpleClassNonPrimitiveConf, SimpleClassPrimitiveConf, SimpleDataClass, + TargetWithInstantiateInInit, Tree, TreeConf, UntypedPassthroughClass, @@ -571,6 +572,30 @@ def test_none_cases( OmegaConf.create({"unique_id": 5}), id="interpolation_from_parent_with_interpolation", ), + param( + DictConfig( + { + "username": "test_user", + "node": { + "_target_": "tests.instantiate.TargetWithInstantiateInInit", + "_recursive_": False, + "user_config": { + "_target_": "tests.instantiate.User", + "name": "${foo_b.username}", + "age": 40, + }, + }, + "foo_b": { + "username": "${username}", + }, + } + ), + {}, + TargetWithInstantiateInInit( + user_config=None, user=User(name="test_user", age=40) + ), + id="target_with_instantiate_in_init", + ), ], ) def test_interpolation_accessing_parent(