Skip to content

Commit

Permalink
Added support for casting python lists to non-generic IList
Browse files Browse the repository at this point in the history
  • Loading branch information
eirannejad committed Feb 10, 2023
1 parent 1729698 commit 44e8597
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 7 deletions.
10 changes: 5 additions & 5 deletions src/runtime/Codecs/ListDecoder.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

namespace Python.Runtime.Codecs
{
public class ListDecoder : IPyObjectDecoder
{
private static bool IsList(Type targetType)
{
if (!targetType.IsGenericType)
return false;

return targetType.GetGenericTypeDefinition() == typeof(IList<>);
return (targetType.IsGenericType && targetType.GetGenericTypeDefinition() == typeof(IList<>))
|| (targetType == typeof(IList));
}

private static bool IsList(PyType objectType)
Expand All @@ -32,7 +32,7 @@ public bool TryDecode<T>(PyObject pyObj, out T value)
{
if (pyObj == null) throw new ArgumentNullException(nameof(pyObj));

var elementType = typeof(T).GetGenericArguments()[0];
var elementType = typeof(T).IsGenericType ? typeof(T).GetGenericArguments()[0] : typeof(object);
Type collectionType = typeof(CollectionWrappers.ListWrapper<>).MakeGenericType(elementType);

var instance = Activator.CreateInstance(collectionType, new[] { pyObj });
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/Codecs/PyObjectConversions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ public static class PyObjectConversions
static readonly DecoderGroup decoders = new();
static readonly EncoderGroup encoders = new();

static PyObjectConversions()
{
ListDecoder.Register();
}

/// <summary>
/// Registers specified encoder (marshaller)
/// <para>Python.NET will pick suitable encoder/decoder registered first</para>
Expand Down
1 change: 0 additions & 1 deletion src/runtime/CollectionWrappers/IterableWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ public IEnumerator<T> GetEnumerator()
iterObject = PyIter.GetIter(pyObject);
}

using var _ = iterObject;
while (true)
{
using var GIL = Py.GIL();
Expand Down
92 changes: 91 additions & 1 deletion src/runtime/CollectionWrappers/ListWrapper.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
using System;
using System.Collections;
using System.Collections.Generic;

namespace Python.Runtime.CollectionWrappers
{
internal class ListWrapper<T> : SequenceWrapper<T>, IList<T>
internal class ListWrapper<T> : SequenceWrapper<T>, IList<T>, IList
{
public ListWrapper(PyObject pyObj) : base(pyObj)
{
Expand Down Expand Up @@ -53,5 +54,94 @@ public void RemoveAt(int index)
if (result == false)
Runtime.CheckExceptionOccurred();
}

public class InvalidTypeException : Exception
{
public InvalidTypeException() : base($"value is not of type {typeof(T)}") { }
}

#region IList
object? IList.this[int index]
{
get
{
return this[index];
}
set
{
if (value is T tvalue)
this[index] = tvalue;
else
throw new InvalidTypeException();
}
}

bool IList.IsFixedSize => false;

bool IList.IsReadOnly => false;

int ICollection.Count => Count;

bool ICollection.IsSynchronized => false;

object? ICollection.SyncRoot => null;

int IList.Add(object value)
{
if (value is T tvalue)
{
Add(tvalue);
return IndexOf(tvalue);
}

throw new InvalidTypeException();
}

void IList.Clear() => Clear();

bool IList.Contains(object value)
{
if (value is T tvalue)
return Contains(tvalue);

throw new InvalidTypeException();
}

int IList.IndexOf(object value)
{
if (value is T tvalue)
return indexOf(tvalue);

throw new InvalidTypeException();
}

void IList.Insert(int index, object value)
{
if (value is T tvalue)
{
Insert(index, tvalue);
return;
}

throw new InvalidTypeException();
}

void IList.Remove(object value)
{
if (value is T tvalue)
{
Remove(tvalue);
return;
}

throw new InvalidTypeException();
}

void IList.RemoveAt(int index) => RemoveAt(index);

void ICollection.CopyTo(Array array, int index) => throw new InvalidTypeException();

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
#endregion
}
}

0 comments on commit 44e8597

Please sign in to comment.