From 0b80962246d302ac4257820791578b7757d947c3 Mon Sep 17 00:00:00 2001 From: eirannejad Date: Tue, 19 Mar 2024 11:41:36 -0700 Subject: [PATCH 1/2] Fixed calling str and repr overrides on base classes --- src/runtime/TypeManager.cs | 13 ++++++++++++- src/runtime/Types/ReflectedClrType.cs | 28 +++++++++++++++------------ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/runtime/TypeManager.cs b/src/runtime/TypeManager.cs index 45133b3a7..d224fe913 100644 --- a/src/runtime/TypeManager.cs +++ b/src/runtime/TypeManager.cs @@ -294,7 +294,7 @@ static void InitializeCoreFields(PyType type) Util.WriteIntPtr(type, TypeOffset.tp_itemsize, IntPtr.Zero); } - internal static void InitializeClass(PyType type, ClassBase impl, Type clrType) + internal static void InitializeClass(PyType type, ClassBase impl, Type clrType, bool reflected = false) { // we want to do this after the slot stuff above in case the class itself implements a slot method SlotsHolder slotsHolder = CreateSlotsHolder(type); @@ -307,6 +307,17 @@ internal static void InitializeClass(PyType type, ClassBase impl, Type clrType) // that the type of the new type must PyType_Type at the time we // call this, else PyType_Ready will skip some slot initialization. + // NOTE: + // if this is a reflected type, lets keep these slots empty so + // python copies the parent class (also dotnet) slots into this + // type. ReflectedCLRType handles overriding these slots if + // type has override methods e.g. __str__ or __repr__ overrides + if (reflected) + { + Util.WriteIntPtr(type, TypeOffset.tp_str, IntPtr.Zero); + Util.WriteIntPtr(type, TypeOffset.tp_repr, IntPtr.Zero); + } + if (!type.IsReady && Runtime.PyType_Ready(type) != 0) { throw PythonException.ThrowLastAsClrException(); diff --git a/src/runtime/Types/ReflectedClrType.cs b/src/runtime/Types/ReflectedClrType.cs index 984f68f88..61db3633d 100644 --- a/src/runtime/Types/ReflectedClrType.cs +++ b/src/runtime/Types/ReflectedClrType.cs @@ -24,7 +24,7 @@ internal ReflectedClrType(BorrowedReference original) : base(original) { } /// /// Returned might be partially initialized. /// - public static ReflectedClrType GetOrCreate(Type type) + public static ReflectedClrType GetOrCreate(Type type, bool reflected = false) { if (ClassManager.cache.TryGetValue(type, out ReflectedClrType pyType)) { @@ -41,7 +41,7 @@ public static ReflectedClrType GetOrCreate(Type type) // Now we force initialize the Python type object to reflect the given // managed type, filling the Python type slots with thunks that // point to the managed methods providing the implementation. - TypeManager.InitializeClass(pyType, impl, type); + TypeManager.InitializeClass(pyType, impl, type, reflected); return pyType; } @@ -80,7 +80,7 @@ internal static NewReference CreateSubclass(ClassBase baseType, IEnumerable Date: Tue, 19 Mar 2024 11:41:41 -0700 Subject: [PATCH 2/2] updated tests --- tests/test_method_getattr.py | 78 +++++++++++---- tests/test_method_getattr_self.py | 29 +++--- tests/test_method_getattribute.py | 78 +++++++++++---- tests/test_method_getitem.py | 58 +++++++++++ ...der.py => test_method_getitem_len_iter.py} | 0 tests/test_method_repr.py | 98 +++++++++++-------- tests/test_method_str.py | 96 +++++++++++------- 7 files changed, 301 insertions(+), 136 deletions(-) create mode 100644 tests/test_method_getitem.py rename tests/{test_method_dunder.py => test_method_getitem_len_iter.py} (100%) diff --git a/tests/test_method_getattr.py b/tests/test_method_getattr.py index fc2ca42a2..ede3d5163 100644 --- a/tests/test_method_getattr.py +++ b/tests/test_method_getattr.py @@ -3,32 +3,27 @@ IBASE = System.Collections.IEnumerable -class S1(BASE): - def __getattr__(self, name): - return super().__getattr__(name) - - def __getattr__(self, name): - if name == 'answer1': - return 1 - return super().__getattr__(name) - - -class S2(S1): - def __getattr__(self, name): - if name == 'answer2': - return 2 - return super().__getattr__(name) +def test_method_getattr(): + class S1(BASE): + def __getattr__(self, name): + if name == 'answer1': + return 1 + return super().__getattr__(name) -class S3(S2): - def __getattr__(self, name): - if name == 'answer3': - return 3 - return super().__getattr__(name) + class S2(S1): + def __getattr__(self, name): + if name == 'answer2': + return 2 + return super().__getattr__(name) + class S3(S2): + def __getattr__(self, name): + if name == 'answer3': + return 3 + return super().__getattr__(name) -def test_method_getattr(): s1 = S1() assert s1.answer1 == 1 try: @@ -44,3 +39,44 @@ def test_method_getattr(): assert s2.answer1 == 1 assert s2.answer2 == 2 assert s3.answer3 == 3 + + +def test_method_getattr_missing_middle(): + class S1(BASE): + def __getattr__(self, name): + if name == 'answer1': + return 1 + return super().__getattr__(name) + + + class S2(S1): + pass + + + class S3(S2): + def __getattr__(self, name): + if name == 'answer3': + return 3 + return super().__getattr__(name) + + s1 = S1() + assert s1.answer1 == 1 + try: + assert s1.answer2 == 2 + except AttributeError: + pass + + s2 = S2() + assert s2.answer1 == 1 + try: + assert s2.answer2 == 2 + except AttributeError: + pass + + s3 = S3() + assert s3.answer1 == 1 + try: + assert s3.answer2 == 2 + except AttributeError: + pass + assert s3.answer3 == 3 diff --git a/tests/test_method_getattr_self.py b/tests/test_method_getattr_self.py index 6aa07cfd9..bcaf5b400 100644 --- a/tests/test_method_getattr_self.py +++ b/tests/test_method_getattr_self.py @@ -13,18 +13,14 @@ def __getattr__(self, key): return super().__getattr__() -class T2(System.Object): +class T2(T1): def __init__(self, value) -> None: - self.answer = value - - def __getattribute__(self, __name: str) -> Any: - if __name == 'answer': - return 43 - return super().__getattribute__(__name) + super().__init__(value) + self.answer1 = value + 1 def __getattr__(self, key): - if key == 'answer': - return self.answer + if key == 'answer1': + return self.answer1 return super().__getattr__() @@ -37,15 +33,12 @@ def __getattr__(self, key): return super().__getattr__() -class T4(System.Object): - def __getattribute__(self, __name: str) -> Any: - if __name == 'answer': - return 45 - return super().__getattribute__(__name) +class T4(T3): + answer1 = 45 def __getattr__(self, key): - if key == 'answer': - return self.answer + if key == 'answer1': + return self.answer1 return super().__getattr__() @@ -54,10 +47,10 @@ def test_method_getattr_self(): assert t1.answer == 42 t2 = T2(42) - assert t2.answer == 43 + assert t2.answer1 == 43 t3 = T3() assert t3.answer == 44 t4 = T4() - assert t4.answer == 45 + assert t4.answer1 == 45 diff --git a/tests/test_method_getattribute.py b/tests/test_method_getattribute.py index 76d537b26..58e755172 100644 --- a/tests/test_method_getattribute.py +++ b/tests/test_method_getattribute.py @@ -3,32 +3,28 @@ IBASE = System.Collections.IEnumerable -class S1(BASE): - def __getattribute__(self, name): - return super().__getattribute__(name) - - def __getattr__(self, name): - if name == 'answer1': - return 1 - return super().__getattr__(name) - +def test_method_getattribute(): + class S1(BASE): + def __getattribute__(self, name): + if name == 'answer1': + return 1 + return super().__getattribute__(name) -class S2(S1): - def __getattribute__(self, name): - if name == 'answer2': - return 2 - return super().__getattribute__(name) + class S2(S1): + def __getattribute__(self, name): + if name == 'answer2': + return 2 + return super().__getattribute__(name) -class S3(S2): - def __getattr__(self, name): - if name == 'answer3': - return 3 - return super().__getattr__(name) + class S3(S2): + def __getattribute__(self, name): + if name == 'answer3': + return 3 + return super().__getattribute__(name) -def test_method_getattribute(): s1 = S1() assert s1.answer1 == 1 try: @@ -44,3 +40,45 @@ def test_method_getattribute(): assert s2.answer1 == 1 assert s2.answer2 == 2 assert s3.answer3 == 3 + + +def test_method_getattribute_missing_middle(): + class S1(BASE): + def __getattribute__(self, name): + if name == 'answer1': + return 1 + return super().__getattribute__(name) + + + class S2(S1): + pass + + + class S3(S2): + def __getattribute__(self, name): + if name == 'answer3': + return 3 + return super().__getattribute__(name) + + + s1 = S1() + assert s1.answer1 == 1 + try: + assert s1.answer2 == 2 + except AttributeError: + pass + + s2 = S2() + assert s2.answer1 == 1 + try: + assert s2.answer2 == 2 + except AttributeError: + pass + + s3 = S3() + assert s2.answer1 == 1 + try: + assert s3.answer2 == 2 + except AttributeError: + pass + assert s3.answer3 == 3 diff --git a/tests/test_method_getitem.py b/tests/test_method_getitem.py new file mode 100644 index 000000000..7419abc5e --- /dev/null +++ b/tests/test_method_getitem.py @@ -0,0 +1,58 @@ +BASE = object + +try: + import System + BASE = System.Object +except: + pass + + +class T1(BASE): + def __getitem__(self, key): + if key == 'answer1': + return 1 + return super().__getitem__(key) + + +class T2(T1): + def __getitem__(self, key): + if key == 'answer2': + return 2 + return super().__getitem__(key) + + +class T3(T2): + pass + + +class T4(T3): + def __getitem__(self, key): + if key == 'answer4': + return 4 + return super().__getitem__(key) + + + +def test_method_getattr_self(): + t1 = T1() + assert t1['answer1'] == 1 + + t2 = T2() + assert t2['answer2'] == 2 + + t3 = T3() + assert t3['answer2'] == 2 + + t4 = T4() + assert t4['answer4'] == 4 + + assert t4['answer2'] == 2 + + try: + assert t4['answer3'] == 3 + except AttributeError: + pass + + +if __name__ == '__main__': + test_method_getattr_self() diff --git a/tests/test_method_dunder.py b/tests/test_method_getitem_len_iter.py similarity index 100% rename from tests/test_method_dunder.py rename to tests/test_method_getitem_len_iter.py diff --git a/tests/test_method_repr.py b/tests/test_method_repr.py index 6c3da3a17..9d931900f 100644 --- a/tests/test_method_repr.py +++ b/tests/test_method_repr.py @@ -3,20 +3,18 @@ IBASE = System.Collections.IEnumerable -# =========================================================================== -class S1(BASE): - def __repr__(self): - s = super().__repr__() - return f"S1 {s}" +def test_repr_overload(): + class S1(BASE): + def __repr__(self): + s = super().__repr__() + return f"S1 {s}" -class S2(S1): - def __repr__(self): - s = super().__repr__() - return f"S2 {s}" + class S2(S1): + def __repr__(self): + s = super().__repr__() + return f"S2 {s}" - -def test_str_overload(): s1 = S1() assert repr(s1).startswith('S1