Skip to content

Commit

Permalink
Provide limited support for serializing unions (#211)
Browse files Browse the repository at this point in the history
Since discriminated unions have yet to ship as an actual language
feature, this change supports a workaround used by StaticCS to encode a
discriminated union. In fact, JsonValue is one such DU, but it uses an
"untagged" representation, which is not currently auto-generated.

The default representation is the same one used by serde.rs, which is
referred to as "externally tagged." Other encodings are possible, but
not yet implemented.
  • Loading branch information
agocke authored Jan 3, 2025
1 parent 3c353c8 commit ef64b0b
Show file tree
Hide file tree
Showing 250 changed files with 4,338 additions and 1,361 deletions.
1 change: 1 addition & 0 deletions src/generator/ConfigOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ internal record TypeOptions
public bool DenyUnknownMembers { get; init; } = false;
public MemberFormat MemberFormat { get; init; } = MemberFormat.CamelCase;
public bool SerializeNull { get; init; } = false;
public string? Rename { get; init; } = null;
}

internal readonly record struct MemberOptions()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

namespace Serde
{
internal class DeserializeImplGenerator
internal class DeserializeImplGen
{
internal static (List<MemberDeclarationSyntax>, BaseListSyntax) GenerateDeserializeImpl(
internal static (List<MemberDeclarationSyntax>, BaseListSyntax) GenDeserialize(
TypeDeclContext typeDeclContext,
GeneratorExecutionContext context,
ITypeSymbol receiverType,
Expand All @@ -25,6 +25,19 @@ internal static (List<MemberDeclarationSyntax>, BaseListSyntax) GenerateDeserial
var typeFqn = receiverType.ToDisplayString();
TypeSyntax typeSyntax = ParseTypeName(typeFqn);

if (receiverType.IsAbstract)
{
var memberDecl = ParseMemberDeclaration(GenUnionDeserializeMethod((INamedTypeSymbol)receiverType))!;
List<BaseTypeSyntax> unionBase = [
// `Serde.IDeserialize<'typeName'>
SimpleBaseType(QualifiedName(IdentifierName("Serde"), GenericName(
Identifier("IDeserialize"),
TypeArgumentList(SeparatedList(new[] { typeSyntax }))
))),
];
return ([ memberDecl ], BaseList(SeparatedList(unionBase)));
}

// Generate members for IDeserialize.Deserialize implementation
var members = new List<MemberDeclarationSyntax>();
List<BaseTypeSyntax> bases = [
Expand Down Expand Up @@ -57,6 +70,72 @@ static IDeserialize<{typeFqn}> IDeserializeProvider<{typeFqn}>.DeserializeInstan
return (members, baseList);
}

/// <summary>
/// Generates the method body for deserializing a union.
/// Code looks like:
/// <code>
/// static T IDeserialize&lt;T&gt;Deserialize(IDeserializer deserializer)
/// {
/// var serdeInfo = SerdeInfoProvider.GetInfo{T}();
/// var de = deserializer.ReadType(serdeInfo);
/// int index;
/// if ((index = de.TryReadIndex(serdeInfo, out var errorName)) == IDeserializeType.IndexNotFound)
/// {
/// throw new InvalidDeserializeValueException($"Unexpected value: {errorName}");
/// }
/// return index switch
/// {
/// {index} => deserializer.Deserialize({union member}),
/// ...
/// _ => throw new InvalidDeserializeValueException($"Unexpected index: {index}")
/// };
/// }
/// </code>
/// </summary>
private static string GenUnionDeserializeMethod(INamedTypeSymbol type)
{
Debug.Assert(type.IsAbstract);

var members = SymbolUtilities.GetDUTypeMembers(type);
var typeFqn = type.ToDisplayString();
var assignedVarType = members.Length switch {
(<= 8) => "byte",
(<= 16) => "ushort",
(<= 32) => "uint",
(<= 64) => "ulong",
_ => throw new InvalidOperationException("Too many members in type")
};
var membersBuilder = new StringBuilder();
for (int i = 0; i < members.Length; i++)
{
var m = members[i];
membersBuilder.AppendLine($"{i} => de.ReadValue<{m.ToDisplayString()}, {SerdeInfoGenerator.GetUnionProxyName(m)}>({i}),");
}

var src = $$"""
{{typeFqn}} IDeserialize<{{typeFqn}}>.Deserialize(IDeserializer deserializer)
{
var serdeInfo = global::Serde.SerdeInfoProvider.GetInfo<{{typeFqn}}>();
var de = deserializer.ReadType(serdeInfo);
int index;
if ((index = de.TryReadIndex(serdeInfo, out var errorName)) == IDeserializeType.IndexNotFound)
{
throw Serde.DeserializeException.UnknownMember(errorName!, serdeInfo);
}
{{typeFqn}} _l_result = index switch {
{{membersBuilder}}
_ => throw new InvalidOperationException($"Unexpected index: {index}")
};
if ((index = de.TryReadIndex(serdeInfo, out _)) != IDeserializeType.EndOfType)
{
throw Serde.DeserializeException.ExpectedEndOfType(index);
}
return _l_result;
}
""";
return src;
}

/// <summary>
/// Generates the method body for deserializing an enum.
/// Code looks like:
Expand Down
2 changes: 2 additions & 0 deletions src/generator/Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ internal enum DiagId
ERR_MissingPrimaryCtor = 4,
ERR_CantFindNestedWrapper = 5,
ERR_WrapperDoesntImplementInterface = 6,
ERR_CantImplementAbstract = 7,
}

internal static class Diagnostics
Expand All @@ -29,6 +30,7 @@ internal static class Diagnostics
ERR_MissingPrimaryCtor => nameof(ERR_MissingPrimaryCtor),
ERR_CantFindNestedWrapper => nameof(ERR_CantFindNestedWrapper),
ERR_WrapperDoesntImplementInterface => nameof(ERR_WrapperDoesntImplementInterface),
ERR_CantImplementAbstract => nameof(ERR_CantImplementAbstract),
};

public static Diagnostic CreateDiagnostic(DiagId id, Location location, params object[] args)
Expand Down
11 changes: 6 additions & 5 deletions src/generator/GenerationOutput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Immutable;
using System.Diagnostics;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;

namespace Serde;

Expand All @@ -14,11 +15,11 @@ namespace Serde;
public readonly record struct GenerationOutput
{
public ImmutableArray<Diagnostic> Diagnostics { get;}
public ImmutableSortedSet<(string FileName, string Content)> Sources { get; }
public ImmutableSortedSet<(string FileName, SourceBuilder Content)> Sources { get; }

public GenerationOutput(
IEnumerable<Diagnostic> diagnostics,
IEnumerable<(string fileName, string content)> sources)
IEnumerable<(string fileName, SourceBuilder content)> sources)
{
var diagSet = new HashSet<Diagnostic>();
var diagBuilder = ImmutableArray.CreateBuilder<Diagnostic>();
Expand All @@ -29,7 +30,7 @@ public GenerationOutput(
diagBuilder.Add(diag);
}
}
var outputBuilder = ImmutableSortedSet.CreateBuilder<(string, string)>(SerdeInfoImplComparer.Instance);
var outputBuilder = ImmutableSortedSet.CreateBuilder<(string, SourceBuilder)>(SerdeInfoImplComparer.Instance);
foreach (var source in sources)
{
outputBuilder.Add(source);
Expand Down Expand Up @@ -62,12 +63,12 @@ public override int GetHashCode()
/// Generally compares the source and content of SerdeInfo implementations. For ISerdeInfoProvider
/// implementations, in particular, we only care about the source name.
/// </summary>
private sealed class SerdeInfoImplComparer : IComparer<(string SrcName, string Content)>
private sealed class SerdeInfoImplComparer : IComparer<(string SrcName, SourceBuilder Content)>
{
private SerdeInfoImplComparer() { }
public static readonly SerdeInfoImplComparer Instance = new SerdeInfoImplComparer();

public int Compare((string SrcName, string Content) x, (string SrcName, string Content) y)
public int Compare((string SrcName, SourceBuilder Content) x, (string SrcName, SourceBuilder Content) y)
{
if (x.SrcName.EndsWith(".ISerdeInfoProvider") && y.SrcName.EndsWith(".ISerdeInfoProvider"))
{
Expand Down
Loading

0 comments on commit ef64b0b

Please sign in to comment.