Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove contravariance from ISerialize<T> #152

Merged
merged 2 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 52 additions & 43 deletions src/generator/Generator.Serialize.Generic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.InteropServices.ComTypes;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -164,7 +165,7 @@ private static (MemberDeclarationSyntax[], BaseListSyntax) GenerateSerializeGene
Identifier("name"),
argumentList: null,
EqualsValueClause(SwitchExpression(receiverExpr, SeparatedList(cases)))) }))));
var wrapper = TryGetPrimitiveWrapper(enumType.EnumUnderlyingType!, SerdeUsage.Serialize)!;
var wrapper = TryGetPrimitiveWrapper(enumType.EnumUnderlyingType!, SerdeUsage.Serialize).Unwrap().Wrapper;
statements.Add(ExpressionStatement(InvocationExpression(
QualifiedName(IdentifierName("serializer"), IdentifierName("SerializeEnumValue")),
ArgumentList(SeparatedList(new[] {
Expand Down Expand Up @@ -205,8 +206,8 @@ private static (MemberDeclarationSyntax[], BaseListSyntax) GenerateSerializeGene
{
// Generate statements of the form `type.SerializeField<FieldType, Serialize>("FieldName", receiver.FieldValue)`
var memberExpr = MakeMemberAccessExpr(m, receiverExpr);
var serializeImpl = MakeSerializeType(m, context, memberExpr, inProgress);
if (serializeImpl is null)
var typeAndWrapperOpt = MakeSerializeType(m, context, memberExpr, inProgress);
if (typeAndWrapperOpt is not {} typeAndWrapper)
{
// No built-in handling and doesn't implement ISerialize, error
context.ReportDiagnostic(CreateDiagnostic(
Expand All @@ -218,7 +219,7 @@ private static (MemberDeclarationSyntax[], BaseListSyntax) GenerateSerializeGene
}
else
{
statements.Add(MakeSerializeFieldStmt(m, memberExpr, serializeImpl, receiverExpr));
statements.Add(MakeSerializeFieldStmt(m, memberExpr, typeAndWrapper, receiverExpr));
}
}

Expand Down Expand Up @@ -260,7 +261,7 @@ private static (MemberDeclarationSyntax[], BaseListSyntax) GenerateSerializeGene
static ExpressionStatementSyntax MakeSerializeFieldStmt(
DataMemberSymbol member,
ExpressionSyntax value,
TypeSyntax serializeType,
TypeAndWrapper typeAndWrapper,
ExpressionSyntax receiver)
{
var arguments = new List<ExpressionSyntax>() {
Expand All @@ -270,8 +271,8 @@ static ExpressionStatementSyntax MakeSerializeFieldStmt(
value,
};
var typeArgs = new List<TypeSyntax>() {
member.Type.ToFqnSyntax(),
serializeType
typeAndWrapper.Type,
typeAndWrapper.Wrapper
};

string methodName;
Expand Down Expand Up @@ -341,7 +342,7 @@ static ExpressionStatementSyntax MakeSerializeFieldStmt(
/// implements ISerialize. SerdeDn provides wrappers for primitives and common types in the
/// framework. If found, we generate and initialize the wrapper.
/// </summary>
private static TypeSyntax? MakeSerializeType(
private static TypeAndWrapper? MakeSerializeType(
DataMemberSymbol member,
GeneratorExecutionContext context,
ExpressionSyntax memberExpr,
Expand All @@ -350,23 +351,18 @@ static ExpressionStatementSyntax MakeSerializeFieldStmt(
// 1. Check for an explicit wrapper
if (TryGetExplicitWrapper(member, context, SerdeUsage.Serialize, inProgress) is {} wrapper)
{
return wrapper;
return new(member.Type.ToFqnSyntax(), wrapper);
}

// 2. Check for a direct implementation of ISerialize
if (ImplementsSerde(member.Type, context, SerdeUsage.Serialize))
{
return GenericName(Identifier("IdWrap"), TypeArgumentList(SeparatedList(new[] { member.Type.ToFqnSyntax() })));
return new(member.Type.ToFqnSyntax(),
GenericName(Identifier("IdWrap"), TypeArgumentList(SeparatedList(new[] { member.Type.ToFqnSyntax() }))));
}

// 3. A wrapper that implements ISerialize
var wrapperType = TryGetAnyWrapper(member.Type, context, SerdeUsage.Serialize, inProgress);
if (wrapperType is not null)
{
return wrapperType;
}

return null;
return TryGetAnyWrapper(member.Type, context, SerdeUsage.Serialize, inProgress);
}

/// <summary>
Expand Down Expand Up @@ -553,9 +549,9 @@ private static bool ImplementsSerde(ITypeSymbol memberType, GeneratorExecutionCo

// Otherwise we'll need to wrap the element type as well e.g.,
// ArrayWrap<`elemType`, `elemTypeWrapper`>
var wrapper = TryGetAnyWrapper(elemType, context, usage, inProgress);
var typeAndWrapper = TryGetAnyWrapper(elemType, context, usage, inProgress);

if (wrapper is null)
if (typeAndWrapper is not (_, var wrapper))
{
// Could not find a wrapper
return null;
Expand All @@ -575,7 +571,7 @@ private static bool ImplementsSerde(ITypeSymbol memberType, GeneratorExecutionCo
return wrapperSyntax;
}

private static TypeSyntax? TryGetAnyWrapper(
private static TypeAndWrapper? TryGetAnyWrapper(
ITypeSymbol elemType,
GeneratorExecutionContext context,
SerdeUsage usage,
Expand All @@ -592,16 +588,16 @@ private static bool ImplementsSerde(ITypeSymbol memberType, GeneratorExecutionCo
allTypes = parent.Name + allTypes;
}
var wrapperName = $"{allTypes}Wrap";
return IdentifierName(wrapperName);
return new(elemType.ToFqnSyntax(), IdentifierName(wrapperName));
}
var nameSyntax = TryGetPrimitiveWrapper(elemType, usage)
var typeAndWrapper = TryGetPrimitiveWrapper(elemType, usage)
?? TryGetEnumWrapper(elemType, usage)
?? TryGetCompoundWrapper(elemType, context, usage, inProgress);
if (nameSyntax is null)
if (typeAndWrapper is null)
{
return null;
}
return nameSyntax;
return typeAndWrapper;
}


Expand Down Expand Up @@ -654,7 +650,7 @@ namespace Serde
}

// If the target is a core type, we can wrap it
private static TypeSyntax? TryGetPrimitiveWrapper(ITypeSymbol type, SerdeUsage usage)
private static TypeAndWrapper? TryGetPrimitiveWrapper(ITypeSymbol type, SerdeUsage usage)
{
if (type.NullableAnnotation == NullableAnnotation.Annotated)
{
Expand All @@ -678,10 +674,10 @@ namespace Serde
SpecialType.System_Decimal => "DecimalWrap",
_ => null
};
return name is null ? null : IdentifierName(name);
return name is null ? null : new(type.ToFqnSyntax(), IdentifierName(name));
}

private static TypeSyntax? TryGetEnumWrapper(ITypeSymbol type, SerdeUsage usage)
private static TypeAndWrapper? TryGetEnumWrapper(ITypeSymbol type, SerdeUsage usage)
{
if (type.TypeKind is not TypeKind.Enum)
{
Expand All @@ -704,7 +700,7 @@ namespace Serde
? containing + "." + wrapperName
: "global::" + wrapperName;

return SyntaxFactory.ParseTypeName(wrapperFqn);
return new(type.ToFqnSyntax(), ParseTypeName(wrapperFqn));
}

private static bool HasGenerateAttribute(ITypeSymbol memberType, SerdeUsage usage)
Expand Down Expand Up @@ -735,17 +731,18 @@ private static bool HasGenerateAttribute(ITypeSymbol memberType, SerdeUsage usag
return false;
}

private static TypeSyntax? TryGetCompoundWrapper(ITypeSymbol type, GeneratorExecutionContext context, SerdeUsage usage, ImmutableList<ITypeSymbol> inProgress)
private static TypeAndWrapper? TryGetCompoundWrapper(ITypeSymbol type, GeneratorExecutionContext context, SerdeUsage usage, ImmutableList<ITypeSymbol> inProgress)
{
return type switch
(TypeSyntax?, TypeSyntax?)? valueTypeAndWrapper = type switch
{
{ OriginalDefinition.SpecialType: SpecialType.System_Nullable_T } =>
MakeWrappedExpression(
(null,
MakeWrappedExpression(
context.Compilation.GetTypeByMetadataName("Serde.NullableWrap+" + GetImplName(usage) + "`2")!,
ImmutableArray.Create(((INamedTypeSymbol)type).TypeArguments[0]),
context,
usage,
inProgress),
inProgress)),

// This is rather subtle. One might think that we would want to use a
// NullableRefWrapper for any reference type that could contain null. In fact, we
Expand All @@ -759,27 +756,36 @@ private static bool HasGenerateAttribute(ITypeSymbol memberType, SerdeUsage usag
// ISerialize, and therefore the substitution to provide the appropriate nullable
// wrapper.
{ IsReferenceType: true, NullableAnnotation: NullableAnnotation.Annotated} =>
MakeWrappedExpression(
(null,
MakeWrappedExpression(
context.Compilation.GetTypeByMetadataName("Serde.NullableRefWrap+" + GetImplName(usage) + "`2")!,
ImmutableArray.Create(type.WithNullableAnnotation(NullableAnnotation.NotAnnotated)),
context,
usage,
inProgress),
inProgress)),

IArrayTypeSymbol and { IsSZArray: true, Rank: 1, ElementType: { } elemType }
=> MakeWrappedExpression(
=> (null,
MakeWrappedExpression(
context.Compilation.GetTypeByMetadataName("Serde.ArrayWrap+" + GetImplName(usage) + "`2")!,
ImmutableArray.Create(elemType),
context,
usage,
inProgress),
inProgress)),

INamedTypeSymbol t when TryGetWrapperName(t, context, usage) is { } tuple
=> MakeWrappedExpression(
tuple.WrapperType, tuple.Args, context, usage, inProgress),
INamedTypeSymbol t when TryGetWrapperName(t, context, usage) is (var ValueType, (var WrapperType, var Args))
=> (ValueType,
MakeWrappedExpression(
WrapperType, Args, context, usage, inProgress)),

_ => null,
};
return valueTypeAndWrapper switch {
null => null,
(null, {} wrapper) => new(type.ToFqnSyntax(), wrapper),
({ } value, { } wrapper) => new(value, wrapper),
(_, null) => throw ExceptionUtilities.Unreachable
};
}

private static string GetImplName(SerdeUsage usage) => usage switch
Expand All @@ -789,19 +795,20 @@ INamedTypeSymbol t when TryGetWrapperName(t, context, usage) is { } tuple
_ => throw ExceptionUtilities.Unreachable
};

private static (INamedTypeSymbol WrapperType, ImmutableArray<ITypeSymbol> Args)? TryGetWrapperName(
private static (TypeSyntax MemberType, (INamedTypeSymbol WrapperType, ImmutableArray<ITypeSymbol> Args))? TryGetWrapperName(
ITypeSymbol typeSymbol,
GeneratorExecutionContext context,
SerdeUsage usage)
{
if (typeSymbol.NullableAnnotation == NullableAnnotation.Annotated)
{
var nullableRefWrap = context.Compilation.GetTypeByMetadataName("Serde.NullableRefWrap+" + GetImplName(usage) + "`1")!;
return (nullableRefWrap, ImmutableArray.Create(typeSymbol.WithNullableAnnotation(NullableAnnotation.NotAnnotated)));
return (typeSymbol.ToFqnSyntax(),
(nullableRefWrap, ImmutableArray.Create(typeSymbol.WithNullableAnnotation(NullableAnnotation.NotAnnotated))));
}
if (typeSymbol is INamedTypeSymbol named && TryGetWellKnownType(named, context) is {} wk)
{
return (ToWrapper(wk, context.Compilation, usage), named.TypeArguments);
return (typeSymbol.ToFqnSyntax(), (ToWrapper(wk, context.Compilation, usage), named.TypeArguments));
}

// Check if it implements well-known interfaces
Expand All @@ -813,13 +820,15 @@ private static (INamedTypeSymbol WrapperType, ImmutableArray<ITypeSymbol> Args)?
if (impl.OriginalDefinition.Equals(iface, SymbolEqualityComparer.Default) &&
ToWrapper(TryGetWellKnownType(iface, context), context.Compilation, usage) is { } wrap)
{
return (wrap, impl.TypeArguments);
return (impl.ToFqnSyntax(), (wrap, impl.TypeArguments));
}
}
}
return null;
}

private readonly record struct TypeAndWrapper(TypeSyntax Type, TypeSyntax Wrapper);

[return: NotNullIfNotNull(nameof(wk))]
internal static INamedTypeSymbol? ToWrapper(WellKnownType? wk, Compilation comp, SerdeUsage usage)
{
Expand Down
3 changes: 3 additions & 0 deletions src/generator/Utilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Serde
{
Expand Down Expand Up @@ -32,6 +33,8 @@ public static bool IsSorted<T>(this ReadOnlySpan<T> span, IComparer<T> comparer)

internal static class Utilities
{
public static T Unwrap<T>(this T? value) where T : struct => value!.Value;

public static string Concat(this string recv, string other)
{
return recv + other;
Expand Down
2 changes: 1 addition & 1 deletion src/serde/ISerialize.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public interface ISerialize
void Serialize(ISerializer serializer);
}

public interface ISerialize<in T> : ISerialize
public interface ISerialize<T> : ISerialize
{
void Serialize(T value, ISerializer serializer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ partial class C : Serde.ISerialize<C>
void ISerialize<C>.Serialize(C value, ISerializer serializer)
{
var type = serializer.SerializeType("C", 1);
type.SerializeField<R, Serde.IDictWrap.SerializeImpl<string, StringWrap, int, Int32Wrap>>("rDictionary", value.RDictionary);
type.SerializeField<System.Collections.Generic.IDictionary<string, int>, Serde.IDictWrap.SerializeImpl<string, StringWrap, int, Int32Wrap>>("rDictionary", value.RDictionary);
type.End();
}
}
2 changes: 1 addition & 1 deletion test/Serde.Test/JsonSerializerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void NullableString()
string? s = null;
var js = Serde.Json.JsonSerializer.Serialize<string?, NullableRefWrap.SerializeImpl<string, StringWrap>>(s);
Assert.Equal("null", js);
js = Serde.Json.JsonSerializer.Serialize(JsonValue.Null.Instance);
js = Serde.Json.JsonSerializer.Serialize<JsonValue>(JsonValue.Null.Instance);
Assert.Equal("null", js);
}

Expand Down
Loading