Skip to content

Commit

Permalink
Merge branch 'release/9.0' into release/9.0-staging
Browse files Browse the repository at this point in the history
  • Loading branch information
carlossanlop authored Dec 3, 2024
2 parents 624638d + cc58c28 commit 38709f8
Show file tree
Hide file tree
Showing 37 changed files with 1,736 additions and 838 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ namespace System.Formats.Nrbf
public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRecord
{
internal ArrayRecord() { }
public virtual long FlattenedLength { get { throw null; } }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan<int> Lengths { get; }
public int Rank { get { throw null; } }
Expand Down
2 changes: 1 addition & 1 deletion src/libraries/System.Formats.Nrbf/src/PACKAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ There are more than a dozen different serialization [record types](https://learn
- `PrimitiveTypeRecord<T>` derives from the non-generic [PrimitiveTypeRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord), which also exposes a [Value](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord.value) property. But on the base class, the value is returned as `object` (which introduces boxing for value types).
- [ClassRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.classrecord): describes all `class` and `struct` besides the aforementioned primitive types.
- [ArrayRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.arrayrecord): describes all array records, including jagged and multi-dimensional arrays.
- [`SZArrayRecord<T>`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `ClassRecord`.
- [`SZArrayRecord<T>`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `SerializationRecord`.

```csharp
SerializationRecord rootObject = NrbfDecoder.Decode(payload); // payload is a Stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ internal enum AllowedRecordTypes : uint
ArraySingleString = 1 << SerializationRecordType.ArraySingleString,

Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple,
Arrays = ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray,

/// <summary>
/// Any .NET object (a primitive, a reference type, a reference or single null).
/// </summary>
AnyObject = MemberPrimitiveTyped
| ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray
| Arrays
| ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes
| BinaryObjectString
| MemberReference
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
using System.Diagnostics.CodeAnalysis;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.Serialization;

namespace System.Formats.Nrbf;

Expand All @@ -27,12 +30,6 @@ private protected ArrayRecord(ArrayInfo arrayInfo)
/// <value>A buffer of integers that represent the number of elements in every dimension.</value>
public abstract ReadOnlySpan<int> Lengths { get; }

/// <summary>
/// When overridden in a derived class, gets the total number of all elements in every dimension.
/// </summary>
/// <value>A number that represent the total number of all elements in every dimension.</value>
public virtual long FlattenedLength => ArrayInfo.FlattenedLength;

/// <summary>
/// Gets the rank of the array.
/// </summary>
Expand Down Expand Up @@ -118,4 +115,86 @@ private void HandleNext(object value, NextInfo info, int size)
}

internal abstract (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType();

internal static void Populate(List<SerializationRecord> source, Array destination, int[] lengths, AllowedRecordTypes allowedRecordTypes, bool allowNulls)
{
int[] indices = new int[lengths.Length];
nuint numElementsWritten = 0; // only for debugging; not used in release builds

foreach (SerializationRecord record in source)
{
object? value = GetActualValue(record, allowedRecordTypes, out int incrementCount);
if (value is not null)
{
// null is a default element for all array of reference types, so we don't call SetValue for nulls.
destination.SetValue(value, indices);
Debug.Assert(incrementCount == 1, "IncrementCount other than 1 is allowed only for null records.");
}
else if (!allowNulls)
{
ThrowHelper.ThrowArrayContainedNulls();
}

while (incrementCount > 0)
{
incrementCount--;
numElementsWritten++;
int dimension = indices.Length - 1;
while (dimension >= 0)
{
indices[dimension]++;
if (indices[dimension] < lengths[dimension])
{
break;
}
indices[dimension] = 0;
dimension--;
}

if (dimension < 0)
{
break;
}
}
}

Debug.Assert(numElementsWritten == (uint)source.Count, "We should have traversed the entirety of the source records collection.");
Debug.Assert(numElementsWritten == (ulong)destination.LongLength, "We should have traversed the entirety of the destination array.");
}

private static object? GetActualValue(SerializationRecord record, AllowedRecordTypes allowedRecordTypes, out int repeatCount)
{
repeatCount = 1;

if (record is NullsRecord nullsRecord)
{
repeatCount = nullsRecord.NullCount;
return null;
}
else if (record.RecordType == SerializationRecordType.MemberReference)
{
record = ((MemberReferenceRecord)record).GetReferencedRecord();
}

if (allowedRecordTypes == AllowedRecordTypes.BinaryObjectString)
{
if (record is not BinaryObjectStringRecord stringRecord)
{
throw new SerializationException(SR.Serialization_InvalidReference);
}

return stringRecord.Value;
}
else if (allowedRecordTypes == AllowedRecordTypes.Arrays)
{
if (record is not ArrayRecord arrayRecord)
{
throw new SerializationException(SR.Serialization_InvalidReference);
}

return arrayRecord;
}

return record;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Formats.Nrbf.Utils;
using System.Linq;
using System.Reflection.Metadata;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;

namespace System.Formats.Nrbf
{
internal sealed class ArrayRectangularPrimitiveRecord<T> : ArrayRecord where T : unmanaged
{
private readonly int[] _lengths;
private readonly IReadOnlyList<T> _values;
private TypeName? _typeName;

internal ArrayRectangularPrimitiveRecord(ArrayInfo arrayInfo, int[] lengths, IReadOnlyList<T> values) : base(arrayInfo)
{
_lengths = lengths;
_values = values;
ValuesToRead = 0; // there is nothing to read anymore
}

public override ReadOnlySpan<int> Lengths => _lengths;

public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;

public override TypeName TypeName
=> _typeName ??= TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.GetPrimitiveType<T>()).MakeArrayTypeName(Rank);

internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException();

private protected override void AddValue(object value) => throw new InvalidOperationException();

[RequiresDynamicCode("May call Array.CreateInstance().")]
private protected override Array Deserialize(Type arrayType, bool allowNulls)
{
Array result =
#if NET9_0_OR_GREATER
Array.CreateInstanceFromArrayType(arrayType, _lengths);
#else
Array.CreateInstance(typeof(T), _lengths);
#endif
int[] indices = new int[_lengths.Length];
nuint numElementsWritten = 0; // only for debugging; not used in release builds

for (int i = 0; i < _values.Count; i++)
{
result.SetValue(_values[i], indices);
numElementsWritten++;

int dimension = indices.Length - 1;
while (dimension >= 0)
{
indices[dimension]++;
if (indices[dimension] < Lengths[dimension])
{
break;
}
indices[dimension] = 0;
dimension--;
}

if (dimension < 0)
{
break;
}
}

Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection.");
Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array.");

return result;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace System.Formats.Nrbf;
/// <remarks>
/// ArraySingleObject records are described in <see href="https://learn.microsoft.com/openspecs/windows_protocols/ms-nrbf/982b2f50-6367-402a-aaf2-44ee96e2a5e0">[MS-NRBF] 2.4.3.2</see>.
/// </remarks>
internal sealed class ArraySingleObjectRecord : SZArrayRecord<object?>
internal sealed class ArraySingleObjectRecord : SZArrayRecord<SerializationRecord>
{
private ArraySingleObjectRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = [];
internal ArraySingleObjectRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = [];

public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleObject;

Expand All @@ -27,25 +27,26 @@ public override TypeName TypeName
private List<SerializationRecord> Records { get; }

/// <inheritdoc/>
public override object?[] GetArray(bool allowNulls = true)
=> (object?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));
public override SerializationRecord?[] GetArray(bool allowNulls = true)
=> (SerializationRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));

private object?[] ToArray(bool allowNulls)
private SerializationRecord?[] ToArray(bool allowNulls)
{
object?[] values = new object?[Length];
SerializationRecord?[] values = new SerializationRecord?[Length];

int valueIndex = 0;
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
{
SerializationRecord record = Records[recordIndex];

int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0;
if (nullCount == 0)
if (record is MemberReferenceRecord referenceRecord)
{
// "new object[] { <SELF> }" is special cased because it allows for storing reference to itself.
values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id)
? values // a reference to self, and a way to get StackOverflow exception ;)
: record.GetValue();
record = referenceRecord.GetReferencedRecord();
}

if (record is not NullsRecord nullsRecord)
{
values[valueIndex++] = record;
continue;
}

Expand All @@ -54,6 +55,7 @@ public override TypeName TypeName
ThrowHelper.ThrowArrayContainedNulls();
}

int nullCount = nullsRecord.NullCount;
do
{
values[valueIndex++] = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public override T[] GetArray(bool allowNulls = true)

internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int count)
{
if (count == 0)
{
return Array.Empty<T>(); // Empty arrays are allowed.
}

// For decimals, the input is provided as strings, so we can't compute the required size up-front.
if (typeof(T) == typeof(decimal))
{
Expand All @@ -71,18 +76,15 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
// allocations to be proportional to the amount of data present in the input stream,
// which is a sufficient defense against DoS.

long requiredBytes = count;
if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
{
// We can't assume DateTime as represented by the runtime is 8 bytes.
// The only assumption we can make is that it's 8 bytes on the wire.
requiredBytes *= 8;
}
else if (typeof(T) != typeof(char))
{
requiredBytes *= Unsafe.SizeOf<T>();
}
// We can't assume DateTime as represented by the runtime is 8 bytes.
// The only assumption we can make is that it's 8 bytes on the wire.
int sizeOfT = typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)
? 8
: typeof(T) != typeof(char)
? Unsafe.SizeOf<T>()
: 1;

long requiredBytes = (long)count * sizeOfT;
bool? isDataAvailable = reader.IsDataAvailable(requiredBytes);
if (!isDataAvailable.HasValue)
{
Expand Down Expand Up @@ -110,26 +112,49 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c

// It's safe to pre-allocate, as we have ensured there is enough bytes in the stream.
T[] result = new T[count];
Span<byte> resultAsBytes = MemoryMarshal.AsBytes<T>(result);
#if NET
reader.BaseStream.ReadExactly(resultAsBytes);

// MemoryMarshal.AsBytes can fail for inputs that need more than int.MaxValue bytes.
// To avoid OverflowException, we read the data in chunks.
int MaxChunkLength =
#if !DEBUG
int.MaxValue / sizeOfT;
#else
byte[] bytes = ArrayPool<byte>.Shared.Rent((int)Math.Min(requiredBytes, 256_000));
// Let's use a different value for non-release builds to ensure this code path
// is covered with tests without the need of decoding enormous payloads.
8_000;
#endif

while (!resultAsBytes.IsEmpty)
#if !NET
byte[] rented = ArrayPool<byte>.Shared.Rent((int)Math.Min(requiredBytes, 256_000));
#endif

Span<T> valuesToRead = result.AsSpan();
while (!valuesToRead.IsEmpty)
{
int bytesRead = reader.Read(bytes, 0, Math.Min(resultAsBytes.Length, bytes.Length));
if (bytesRead <= 0)
int sliceSize = Math.Min(valuesToRead.Length, MaxChunkLength);

Span<byte> resultAsBytes = MemoryMarshal.AsBytes<T>(valuesToRead.Slice(0, sliceSize));
#if NET
reader.BaseStream.ReadExactly(resultAsBytes);
#else
while (!resultAsBytes.IsEmpty)
{
ArrayPool<byte>.Shared.Return(bytes);
ThrowHelper.ThrowEndOfStreamException();
}
int bytesRead = reader.Read(rented, 0, Math.Min(resultAsBytes.Length, rented.Length));
if (bytesRead <= 0)
{
ArrayPool<byte>.Shared.Return(rented);
ThrowHelper.ThrowEndOfStreamException();
}

bytes.AsSpan(0, bytesRead).CopyTo(resultAsBytes);
resultAsBytes = resultAsBytes.Slice(bytesRead);
rented.AsSpan(0, bytesRead).CopyTo(resultAsBytes);
resultAsBytes = resultAsBytes.Slice(bytesRead);
}
#endif
valuesToRead = valuesToRead.Slice(sliceSize);
}

ArrayPool<byte>.Shared.Return(bytes);
#if !NET
ArrayPool<byte>.Shared.Return(rented);
#endif

if (!BitConverter.IsLittleEndian)
Expand Down Expand Up @@ -176,7 +201,7 @@ internal static IReadOnlyList<T> DecodePrimitiveTypes(BinaryReader reader, int c
{
// See DontCastBytesToBooleans test to see what could go wrong.
bool[] booleans = (bool[])(object)result;
resultAsBytes = MemoryMarshal.AsBytes<T>(result);
Span<byte> resultAsBytes = MemoryMarshal.AsBytes<T>(result);
for (int i = 0; i < booleans.Length; i++)
{
// We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this.
Expand Down
Loading

0 comments on commit 38709f8

Please sign in to comment.