Skip to content

Commit

Permalink
improved equality checks and added test
Browse files Browse the repository at this point in the history
  • Loading branch information
eirannejad committed Jan 17, 2024
1 parent fe2dbe8 commit 7f8b1b6
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 140 deletions.
149 changes: 71 additions & 78 deletions src/runtime/MethodBinder.Solver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ larger number of parameters.
N │ │ │ │┼┼┼│ │ │ │ │ │ │
C ├───────────────┤ │ │┼┼┼│ │ ├────────┤ │ ├──────────────┤
E │ │ │ │┼┼┼│ │ │ │ │ │ │
│ STATIC │ │ │┼┼┼│ │ │ PARAMS │ │ │ GENERIC │
│ STATIC │ │ │┼┼┼│ │ │ resvd. │ │ │ GENERIC │
R │ │ │ │┼┼┼│ │ │ │ │ │ │
A ├───────────────┤ │ │┼┼┼│ │ ├────────┤ │ ├──────────────┤
N │ │ │ │┼┼┼│ │ │ │ │ │ │
G │ STATIC<T> │ │ │┼┼┼│ │ │ resvd. │ │ │ CAST/CONVERT │
G │ STATIC<T> │ │ │┼┼┼│ │ │ resvd. │ │ │ CONVERT/CAST
E │ │ └─► │┼┼┼│ └── │ │ └── │ │
└───────────────┘ └───┘ └────────┘ └──────────────┘
FUNCTION MAX ARGS TYPE MATCH
Expand Down Expand Up @@ -364,12 +364,7 @@ public bool TryGetArguments(object? instance,
return true;
}

public void SetArgs(ref ArgProvider prov)
{
ExtractParameters(ref prov, computeDist: false);
}

public uint SetArgsAndGetDistance(ref ArgProvider prov)
public uint AssignArguments(ref ArgProvider prov)
{
uint distance = 0;

Expand All @@ -381,7 +376,7 @@ public uint SetArgsAndGetDistance(ref ArgProvider prov)
// NOTE:
// if method contains generic parameters, the distance
// compute logic will take that into consideration.
// but methods can be generic with no generic paramerers,
// but methods can be generic with no generic parameters,
// e.g. Foo<T>(int, float).
// this ensures these methods are furthur away from
// non-generic instance methods with matching parameters
Expand All @@ -393,17 +388,16 @@ public uint SetArgsAndGetDistance(ref ArgProvider prov)
(uint)Method.GetGenericArguments().Length);
}

distance += ExtractParameters(ref prov, computeDist: true);

#if UNIT_TEST_DEBUG
Debug.WriteLine($"{Method} -> {distance}");
#endif

distance += ExtractArguments(ref prov);

return distance;
}

uint ExtractParameters(ref ArgProvider prov,
bool computeDist)
uint ExtractArguments(ref ArgProvider prov)
{
uint distance = 0;

Expand Down Expand Up @@ -438,7 +432,7 @@ uint ExtractParameters(ref ArgProvider prov,
item = prov.GetKWArg(slot.Key);
if (item != null)
{
PyObject value = new PyObject(item);
PyObject value = new(item);

// NOTE:
// if this param is a capturing params[], expect
Expand All @@ -453,13 +447,8 @@ uint ExtractParameters(ref ArgProvider prov,
slot.Value = value;
}

if (computeDist)
{
slot.Distance =
GetTypeDistance(ref prov, item, slot);

distance += slot.Distance;
}
slot.Distance = GetDistance(ref prov, item, slot);
distance += slot.Distance;

continue;
}
Expand Down Expand Up @@ -494,10 +483,11 @@ uint ExtractParameters(ref ArgProvider prov,

// compute distance on first arg
// that is being captured by params []
if (computeDist && ai == argidx)
if (ai == argidx)
{
distance +=
GetTypeDistance(ref prov, item, slot);
slot.Distance =
GetDistance(ref prov, item, slot);
distance += slot.Distance;
}

argidx++;
Expand All @@ -512,7 +502,7 @@ uint ExtractParameters(ref ArgProvider prov,
// a default distance for this param slot
else if (argidx > prov.ArgsCount)
{
distance += computeDist ? ARG_GROUP_SIZE : 0;
distance += ARG_GROUP_SIZE;
}

continue;
Expand All @@ -528,14 +518,9 @@ uint ExtractParameters(ref ArgProvider prov,
if (item != null)
{
slot.Value = new PyObject(item);

if (computeDist)
{
slot.Distance =
GetTypeDistance(ref prov, item, slot);

distance += slot.Distance;
}
slot.Distance =
GetDistance(ref prov, item, slot);
distance += slot.Distance;

argidx++;
continue;
Expand Down Expand Up @@ -580,7 +565,7 @@ static bool TryBind(MethodBase[] methods,
spec = default;
error = default;

ArgProvider provider = new ArgProvider(args, kwargs);
ArgProvider provider = new(args, kwargs);

// Find any method that could accept this many args and kwargs
int index = 0;
Expand Down Expand Up @@ -637,26 +622,14 @@ static bool TryBindByValue(int count,
return false;
}

if (count == 1)
{
spec = specs[0];
spec!.SetArgs(ref prov);
return true;
}

uint ambigCount = 0;
MethodBase?[] ambigMethods = new MethodBase?[count];
uint closest = uint.MaxValue;
for (int sidx = 0; sidx < count; sidx++)
{
BindSpec mspec = specs[sidx]!;

uint distance = mspec!.SetArgsAndGetDistance(ref prov);

if (distance == ARG_GROUP_SIZE)
{
continue;
}
uint distance = mspec!.AssignArguments(ref prov);

// NOTE:
// if method has the exact same distance,
Expand Down Expand Up @@ -854,16 +827,16 @@ static bool TryBindByCount(MethodBase method,
{
argSpecs = new BindParam[]
{
new BindParam(mparams[0], BindParamKind.Default),
new BindParam(mparams[1], BindParamKind.Self),
new BindParam(mparams[0], BindParamKind.Default),
new BindParam(mparams[1], BindParamKind.Self),
};
}
else
{
argSpecs = new BindParam[]
{
new BindParam(mparams[0], BindParamKind.Self),
new BindParam(mparams[1], BindParamKind.Default),
new BindParam(mparams[0], BindParamKind.Self),
new BindParam(mparams[1], BindParamKind.Default),
};
}
}
Expand Down Expand Up @@ -1074,37 +1047,54 @@ static HashSet<string> GetKeys(BorrowedReference kwargs)

static readonly uint TOTAL_MAX_DIST = uint.MaxValue;
static readonly uint FUNC_GROUP_SIZE = TOTAL_MAX_DIST / 4;
static readonly uint ARGS_MAX_DIST = FUNC_GROUP_SIZE;
static readonly uint ARG_GROUP_SIZE = FUNC_GROUP_SIZE / MAX_ARGS;
static readonly uint ARG_MAX_DIST = ARG_GROUP_SIZE;
static readonly uint TYPE_GROUP_SIZE = ARG_GROUP_SIZE / 4;
static readonly uint MATCH_GROUP_SIZE = TYPE_GROUP_SIZE / 4;

static uint GetTypeDistance(ref ArgProvider prov,
BorrowedReference from, BindParam to)
static readonly uint MATCH_MAX_DIST = MATCH_GROUP_SIZE;
static readonly uint CONVERT_MATCH_THRESHOLD = MATCH_GROUP_SIZE * 3;

// NOTE:
// this method computes a distance between the given python arg
// and the expected type iin target parameter slot.
// However in many cases when given arg is a python object,
// the final clr type of arg is unknown. therefore we return the
// max distance for these and let the arg converter attempt
// to convert the type to the expected type later.
static uint GetDistance(ref ArgProvider prov,
BorrowedReference from, BindParam to)
{
Type toType = to.Type;

if (to.Kind == BindParamKind.Params
&& Runtime.PySequence_Check(from))
{
uint argsCount = (uint)Runtime.PyTuple_Size(from);
uint argsCount = (uint)Runtime.PySequence_Size(from);
if (argsCount > 0)
{
BorrowedReference item = Runtime.PyTuple_GetItem(from, 0);
if (item != null)
using var iterObj = Runtime.PyObject_GetIter(from);
using var item = Runtime.PyIter_Next(iterObj.Borrow());
if (!item.IsNull()
&& prov.GetCLRType(item.Borrow()) is Type argType)
{
if (prov.GetCLRType(item) is Type argType)
{
return GetTypeDistance(argType, toType);
}
return GetTypeDistance(argType, toType);
}
}
}
else if (prov.GetCLRType(from) is Type argType)
{
return GetTypeDistance(argType, toType);
}
else if (from == null
|| Runtime.None == from
|| toType == typeof(object)
|| toType == typeof(PyObject))
{
return 0;
}

return GetTypeDistance(typeof(object), toType);
return ARG_MAX_DIST;
}

static uint GetTypeDistance(Type from, Type to)
Expand Down Expand Up @@ -1136,14 +1126,14 @@ static uint GetTypeDistance(Type from, Type to)

if (from.IsArray != to.IsArray)
{
distance = ARG_GROUP_SIZE;
distance = ARG_MAX_DIST;
goto computed;
}

if ((from.IsArray && to.IsArray)
&& (from.GetElementType() != to.GetElementType()))
{
distance = ARG_GROUP_SIZE;
distance = ARG_MAX_DIST;
goto computed;
}

Expand All @@ -1162,16 +1152,9 @@ static uint GetTypeDistance(Type from, Type to)
goto computed;
}

// cast/convert match
// convert/cast match
distance += MATCH_GROUP_SIZE;
if (TryGetPrecedence(from, out uint fromPrec)
&& TryGetPrecedence(to, out uint toPrec))
{
distance += GetConvertTypeDistance(fromPrec, toPrec);
goto computed;
}

distance = ARG_GROUP_SIZE;
distance += GetConvertTypeDistance(from, to);

computed:
#if METHODBINDER_SOLVER_NEW_CACHE_DIST
Expand Down Expand Up @@ -1207,7 +1190,7 @@ static uint GetDerivedTypeDistance(Type from, Type to)
Type t = from;
while (t != null
&& t != to
&& depth < MATCH_GROUP_SIZE)
&& depth < MATCH_MAX_DIST)
{
depth++;
t = t.BaseType;
Expand All @@ -1218,9 +1201,15 @@ static uint GetDerivedTypeDistance(Type from, Type to)

// zero when types are equal.
// 0 <= x < MATCH_MAX_DIST
static uint GetConvertTypeDistance(uint from, uint to)
static uint GetConvertTypeDistance(Type from, Type to)
{
return (uint)Math.Abs((int)to - (int)from);
if (TryGetPrecedence(from, out uint fromPrec)
&& TryGetPrecedence(to, out uint toPrec))
{
return (uint)Math.Abs((int)toPrec - (int)fromPrec);
}

return MATCH_MAX_DIST;
}

static bool TryGetPrecedence(Type of, out uint predecence)
Expand All @@ -1232,6 +1221,11 @@ static bool TryGetPrecedence(Type of, out uint predecence)
return false;
}

if (!of.IsPrimitive)
{
return false;
}

if (of.IsArray)
{
return TryGetPrecedence(of.GetElementType(), out predecence);
Expand All @@ -1253,8 +1247,7 @@ static bool TryGetPrecedence(Type of, out uint predecence)
{
// 0-9
case TypeCode.Object:
predecence = 1;
return true;
return false;

// 10-19
case TypeCode.UInt64:
Expand Down
Loading

0 comments on commit 7f8b1b6

Please sign in to comment.