Skip to content

Commit

Permalink
Merge branch 'master-rhino-8.x' into rhino-8.x
Browse files Browse the repository at this point in the history
  • Loading branch information
eirannejad committed Mar 19, 2024
2 parents ad76cee + b48dbc4 commit f4a0d92
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 149 deletions.
13 changes: 12 additions & 1 deletion src/runtime/TypeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ static void InitializeCoreFields(PyType type)
Util.WriteIntPtr(type, TypeOffset.tp_itemsize, IntPtr.Zero);
}

internal static void InitializeClass(PyType type, ClassBase impl, Type clrType)
internal static void InitializeClass(PyType type, ClassBase impl, Type clrType, bool reflected = false)
{
// we want to do this after the slot stuff above in case the class itself implements a slot method
SlotsHolder slotsHolder = CreateSlotsHolder(type);
Expand All @@ -307,6 +307,17 @@ internal static void InitializeClass(PyType type, ClassBase impl, Type clrType)
// that the type of the new type must PyType_Type at the time we
// call this, else PyType_Ready will skip some slot initialization.

// NOTE:
// if this is a reflected type, lets keep these slots empty so
// python copies the parent class (also dotnet) slots into this
// type. ReflectedCLRType handles overriding these slots if
// type has override methods e.g. __str__ or __repr__ overrides
if (reflected)
{
Util.WriteIntPtr(type, TypeOffset.tp_str, IntPtr.Zero);
Util.WriteIntPtr(type, TypeOffset.tp_repr, IntPtr.Zero);
}

if (!type.IsReady && Runtime.PyType_Ready(type) != 0)
{
throw PythonException.ThrowLastAsClrException();
Expand Down
28 changes: 16 additions & 12 deletions src/runtime/Types/ReflectedClrType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ internal ReflectedClrType(BorrowedReference original) : base(original) { }
/// <remarks>
/// Returned <see cref="ReflectedClrType"/> might be partially initialized.
/// </remarks>
public static ReflectedClrType GetOrCreate(Type type)
public static ReflectedClrType GetOrCreate(Type type, bool reflected = false)
{
if (ClassManager.cache.TryGetValue(type, out ReflectedClrType pyType))
{
Expand All @@ -41,7 +41,7 @@ public static ReflectedClrType GetOrCreate(Type type)
// Now we force initialize the Python type object to reflect the given
// managed type, filling the Python type slots with thunks that
// point to the managed methods providing the implementation.
TypeManager.InitializeClass(pyType, impl, type);
TypeManager.InitializeClass(pyType, impl, type, reflected);

return pyType;
}
Expand Down Expand Up @@ -80,7 +80,7 @@ internal static NewReference CreateSubclass(ClassBase baseType, IEnumerable<Type
assembly);

ClassManager.cache.Remove(subType);
ReflectedClrType pyTypeObj = GetOrCreate(subType);
ReflectedClrType pyTypeObj = GetOrCreate(subType, reflected: true);

// by default the class dict will have all the C# methods in it, but as this is a
// derived class we want the python overrides in there instead if they exist.
Expand All @@ -99,6 +99,8 @@ internal static NewReference CreateSubclass(ClassBase baseType, IEnumerable<Type
var tp_getattro_default = typeof(ReflectedClrType).GetMethod(nameof(ReflectedClrType.tp_getattro), tbFlags);
Util.WriteIntPtr(pyTypeObj, TypeOffset.tp_getattro, Interop.GetThunk(tp_getattro_default).Address);

// NOTE:
// lets include wrappers that would call python override methods
using var clsDict = new PyDict(dict);
using var keys = clsDict.Keys();
foreach (PyObject pyKey in keys)
Expand Down Expand Up @@ -127,12 +129,14 @@ internal static NewReference CreateSubclass(ClassBase baseType, IEnumerable<Type
Util.WriteIntPtr(pyTypeObj, TypeOffset.tp_iter, Interop.GetThunk(tp_iter).Address);
}

// include a default tp_str for reflected type so str() can call the overridden __str__
if (keyStr.StartsWith(nameof(PyIdentifier.__str__)))
{
var tp_str = typeof(ReflectedClrType).GetMethod(nameof(ReflectedClrType.tp_str), tbFlags);
Util.WriteIntPtr(pyTypeObj, TypeOffset.tp_str, Interop.GetThunk(tp_str).Address);
}

// include a default tp_repr for reflected type so repr() can call the overridden __repr__
if (keyStr.StartsWith(nameof(PyIdentifier.__repr__)))
{
var tp_repr = typeof(ReflectedClrType).GetMethod(nameof(ReflectedClrType.tp_repr), tbFlags);
Expand Down Expand Up @@ -269,8 +273,8 @@ public static NewReference tp_getattro(BorrowedReference ob, BorrowedReference k

using var objPyRepr = Runtime.PyObject_Repr(ob);
using var keyPyRepr = Runtime.PyObject_Repr(key);
string objRepr = Runtime.GetManagedString(objPyRepr.Borrow());
string keyRepr = Runtime.GetManagedString(keyPyRepr.Borrow());
string? objRepr = Runtime.GetManagedString(objPyRepr.Borrow());
string? keyRepr = Runtime.GetManagedString(keyPyRepr.Borrow());
Exceptions.SetError(Exceptions.AttributeError, $"'{objRepr}' object has no attribute '{keyRepr}'");
}

Expand All @@ -292,9 +296,9 @@ static bool TryCallBoundMethod0(BorrowedReference ob, string keyName, out NewRef
return false;
}

using var method = Runtime.PyObject_GenericGetAttr(ob, getAttrKey);
using var args = Runtime.PyTuple_New(0);
result = Runtime.PyObject_Call(method.Borrow(), args.Borrow(), null);
using var args = Runtime.PyTuple_New(1);
Runtime.PyTuple_SetItem(args.Borrow(), 0, ob);
result = Runtime.PyObject_Call(getattr, args.Borrow(), null);
return true;
}

Expand All @@ -310,10 +314,10 @@ static bool TryCallBoundMethod1(BorrowedReference ob, BorrowedReference key, str
return false;
}

using var method = Runtime.PyObject_GenericGetAttr(ob, getAttrKey);
using var args = Runtime.PyTuple_New(1);
Runtime.PyTuple_SetItem(args.Borrow(), 0, key);
result = Runtime.PyObject_Call(method.Borrow(), args.Borrow(), null);
using var args = Runtime.PyTuple_New(2);
Runtime.PyTuple_SetItem(args.Borrow(), 0, ob);
Runtime.PyTuple_SetItem(args.Borrow(), 1, key);
result = Runtime.PyObject_Call(getattr, args.Borrow(), null);
return true;
}

Expand Down
78 changes: 57 additions & 21 deletions tests/test_method_getattr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,27 @@
IBASE = System.Collections.IEnumerable


class S1(BASE):
def __getattr__(self, name):
return super().__getattr__(name)

def __getattr__(self, name):
if name == 'answer1':
return 1
return super().__getattr__(name)


class S2(S1):
def __getattr__(self, name):
if name == 'answer2':
return 2
return super().__getattr__(name)
def test_method_getattr():
class S1(BASE):
def __getattr__(self, name):
if name == 'answer1':
return 1
return super().__getattr__(name)


class S3(S2):
def __getattr__(self, name):
if name == 'answer3':
return 3
return super().__getattr__(name)
class S2(S1):
def __getattr__(self, name):
if name == 'answer2':
return 2
return super().__getattr__(name)


class S3(S2):
def __getattr__(self, name):
if name == 'answer3':
return 3
return super().__getattr__(name)

def test_method_getattr():
s1 = S1()
assert s1.answer1 == 1
try:
Expand All @@ -44,3 +39,44 @@ def test_method_getattr():
assert s2.answer1 == 1
assert s2.answer2 == 2
assert s3.answer3 == 3


def test_method_getattr_missing_middle():
class S1(BASE):
def __getattr__(self, name):
if name == 'answer1':
return 1
return super().__getattr__(name)


class S2(S1):
pass


class S3(S2):
def __getattr__(self, name):
if name == 'answer3':
return 3
return super().__getattr__(name)

s1 = S1()
assert s1.answer1 == 1
try:
assert s1.answer2 == 2
except AttributeError:
pass

s2 = S2()
assert s2.answer1 == 1
try:
assert s2.answer2 == 2
except AttributeError:
pass

s3 = S3()
assert s3.answer1 == 1
try:
assert s3.answer2 == 2
except AttributeError:
pass
assert s3.answer3 == 3
29 changes: 11 additions & 18 deletions tests/test_method_getattr_self.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,14 @@ def __getattr__(self, key):
return super().__getattr__()


class T2(System.Object):
class T2(T1):
def __init__(self, value) -> None:
self.answer = value

def __getattribute__(self, __name: str) -> Any:
if __name == 'answer':
return 43
return super().__getattribute__(__name)
super().__init__(value)
self.answer1 = value + 1

def __getattr__(self, key):
if key == 'answer':
return self.answer
if key == 'answer1':
return self.answer1
return super().__getattr__()


Expand All @@ -37,15 +33,12 @@ def __getattr__(self, key):
return super().__getattr__()


class T4(System.Object):
def __getattribute__(self, __name: str) -> Any:
if __name == 'answer':
return 45
return super().__getattribute__(__name)
class T4(T3):
answer1 = 45

def __getattr__(self, key):
if key == 'answer':
return self.answer
if key == 'answer1':
return self.answer1
return super().__getattr__()


Expand All @@ -54,10 +47,10 @@ def test_method_getattr_self():
assert t1.answer == 42

t2 = T2(42)
assert t2.answer == 43
assert t2.answer1 == 43

t3 = T3()
assert t3.answer == 44

t4 = T4()
assert t4.answer == 45
assert t4.answer1 == 45
78 changes: 58 additions & 20 deletions tests/test_method_getattribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,28 @@
IBASE = System.Collections.IEnumerable


class S1(BASE):
def __getattribute__(self, name):
return super().__getattribute__(name)

def __getattr__(self, name):
if name == 'answer1':
return 1
return super().__getattr__(name)

def test_method_getattribute():
class S1(BASE):
def __getattribute__(self, name):
if name == 'answer1':
return 1
return super().__getattribute__(name)

class S2(S1):
def __getattribute__(self, name):
if name == 'answer2':
return 2
return super().__getattribute__(name)

class S2(S1):
def __getattribute__(self, name):
if name == 'answer2':
return 2
return super().__getattribute__(name)

class S3(S2):
def __getattr__(self, name):
if name == 'answer3':
return 3
return super().__getattr__(name)

class S3(S2):
def __getattribute__(self, name):
if name == 'answer3':
return 3
return super().__getattribute__(name)


def test_method_getattribute():
s1 = S1()
assert s1.answer1 == 1
try:
Expand All @@ -44,3 +40,45 @@ def test_method_getattribute():
assert s2.answer1 == 1
assert s2.answer2 == 2
assert s3.answer3 == 3


def test_method_getattribute_missing_middle():
class S1(BASE):
def __getattribute__(self, name):
if name == 'answer1':
return 1
return super().__getattribute__(name)


class S2(S1):
pass


class S3(S2):
def __getattribute__(self, name):
if name == 'answer3':
return 3
return super().__getattribute__(name)


s1 = S1()
assert s1.answer1 == 1
try:
assert s1.answer2 == 2
except AttributeError:
pass

s2 = S2()
assert s2.answer1 == 1
try:
assert s2.answer2 == 2
except AttributeError:
pass

s3 = S3()
assert s2.answer1 == 1
try:
assert s3.answer2 == 2
except AttributeError:
pass
assert s3.answer3 == 3
Loading

0 comments on commit f4a0d92

Please sign in to comment.