Skip to content

Commit

Permalink
Add custom request header decoder API to Kestrel (#23233)
Browse files Browse the repository at this point in the history
  • Loading branch information
halter73 authored Jun 27, 2020
1 parent bfbb8b0 commit b446ab7
Show file tree
Hide file tree
Showing 20 changed files with 336 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ public KestrelServerOptions() { }
public bool DisableStringReuse { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } }
public bool EnableAltSvc { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } }
public Microsoft.AspNetCore.Server.Kestrel.Core.KestrelServerLimits Limits { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
public System.Func<string, System.Text.Encoding> RequestHeaderEncodingSelector { get { throw null; } set { } }
public Microsoft.AspNetCore.Server.Kestrel.KestrelConfigurationLoader Configure() { throw null; }
public Microsoft.AspNetCore.Server.Kestrel.KestrelConfigurationLoader Configure(Microsoft.Extensions.Configuration.IConfiguration config) { throw null; }
public Microsoft.AspNetCore.Server.Kestrel.KestrelConfigurationLoader Configure(Microsoft.Extensions.Configuration.IConfiguration config, bool reloadOnChange) { throw null; }
Expand Down
3 changes: 0 additions & 3 deletions src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ internal class ConfigurationReader
private const string EndpointDefaultsKey = "EndpointDefaults";
private const string EndpointsKey = "Endpoints";
private const string UrlKey = "Url";
private const string Latin1RequestHeadersKey = "Latin1RequestHeaders";

private readonly IConfiguration _configuration;

private IDictionary<string, CertificateConfig> _certificates;
private EndpointDefaults _endpointDefaults;
private IEnumerable<EndpointConfig> _endpoints;
private bool? _latin1RequestHeaders;

public ConfigurationReader(IConfiguration configuration)
{
Expand All @@ -35,7 +33,6 @@ public ConfigurationReader(IConfiguration configuration)
public IDictionary<string, CertificateConfig> Certificates => _certificates ??= ReadCertificates();
public EndpointDefaults EndpointDefaults => _endpointDefaults ??= ReadEndpointDefaults();
public IEnumerable<EndpointConfig> Endpoints => _endpoints ??= ReadEndpoints();
public bool Latin1RequestHeaders => _latin1RequestHeaders ??= _configuration.GetValue<bool>(Latin1RequestHeadersKey);

private IDictionary<string, CertificateConfig> ReadCertificates()
{
Expand Down

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ public void Reset()
ConnectionIdFeature = ConnectionId;

HttpRequestHeaders.Reset();
HttpRequestHeaders.UseLatin1 = ServerOptions.Latin1RequestHeaders;
HttpRequestHeaders.EncodingSelector = ServerOptions.RequestHeaderEncodingSelector;
HttpRequestHeaders.ReuseHeaderValues = !ServerOptions.DisableStringReuse;
HttpResponseHeaders.Reset();
RequestHeaders = HttpRequestHeaders;
Expand Down Expand Up @@ -532,7 +532,7 @@ public void OnTrailer(ReadOnlySpan<byte> name, ReadOnlySpan<byte> value)
}

string key = name.GetHeaderName();
var valueStr = value.GetRequestHeaderStringNonNullCharacters(ServerOptions.Latin1RequestHeaders);
var valueStr = value.GetRequestHeaderString(key, HttpRequestHeaders.EncodingSelector);
RequestTrailers.Append(key, valueStr);
}

Expand Down
40 changes: 32 additions & 8 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
using System.Buffers.Text;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Runtime.CompilerServices;
using System.Text;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
Expand All @@ -17,12 +19,12 @@ internal sealed partial class HttpRequestHeaders : HttpHeaders
private long _previousBits = 0;

public bool ReuseHeaderValues { get; set; }
public bool UseLatin1 { get; set; }
public Func<string, Encoding> EncodingSelector { get; set; }

public HttpRequestHeaders(bool reuseHeaderValues = true, bool useLatin1 = false)
public HttpRequestHeaders(bool reuseHeaderValues = true, Func<string, Encoding> encodingSelector = null)
{
ReuseHeaderValues = reuseHeaderValues;
UseLatin1 = useLatin1;
EncodingSelector = encodingSelector ?? KestrelServerOptions.DefaultRequestHeaderEncodingSelector;
}

public void OnHeadersComplete()
Expand Down Expand Up @@ -87,7 +89,30 @@ private void AppendContentLength(ReadOnlySpan<byte> value)
parsed < 0 ||
consumed != value.Length)
{
KestrelBadHttpRequestException.Throw(RequestRejectionReason.InvalidContentLength, value.GetRequestHeaderStringNonNullCharacters(UseLatin1));
KestrelBadHttpRequestException.Throw(RequestRejectionReason.InvalidContentLength, value.GetRequestHeaderString(HeaderNames.ContentLength, EncodingSelector));
}

_contentLength = parsed;
}

[MethodImpl(MethodImplOptions.NoInlining)]
private void AppendContentLengthCustomEncoding(ReadOnlySpan<byte> value, Encoding customEncoding)
{
if (_contentLength.HasValue)
{
KestrelBadHttpRequestException.Throw(RequestRejectionReason.MultipleContentLengths);
}

// long.MaxValue = 9223372036854775807 (19 chars)
Span<char> decodedChars = stackalloc char[20];
var numChars = customEncoding.GetChars(value, decodedChars);
long parsed = -1;

if (numChars > 19 ||
!long.TryParse(decodedChars.Slice(0, numChars), NumberStyles.Integer, CultureInfo.InvariantCulture, out parsed) ||
parsed < 0)
{
KestrelBadHttpRequestException.Throw(RequestRejectionReason.InvalidContentLength, value.GetRequestHeaderString(HeaderNames.ContentLength, EncodingSelector));
}

_contentLength = parsed;
Expand All @@ -108,11 +133,10 @@ private bool AddValueUnknown(string key, StringValues value)
}

[MethodImpl(MethodImplOptions.NoInlining)]
private unsafe void AppendUnknownHeaders(ReadOnlySpan<byte> name, string valueString)
private unsafe void AppendUnknownHeaders(string name, string valueString)
{
string key = name.GetHeaderName();
Unknown.TryGetValue(key, out var existing);
Unknown[key] = AppendValue(existing, valueString);
Unknown.TryGetValue(name, out var existing);
Unknown[name] = AppendValue(existing, valueString);
}

public Enumerator GetEnumerator()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ internal static partial class HttpUtilities
private const ulong _http10VersionLong = 3471766442030158920; // GetAsciiStringAsLong("HTTP/1.0"); const results in better codegen
private const ulong _http11VersionLong = 3543824036068086856; // GetAsciiStringAsLong("HTTP/1.1"); const results in better codegen

private static readonly UTF8EncodingSealed HeaderValueEncoding = new UTF8EncodingSealed();
private static readonly UTF8EncodingSealed DefaultRequestHeaderEncoding = new UTF8EncodingSealed();
private static readonly SpanAction<char, IntPtr> _getHeaderName = GetHeaderName;
private static readonly SpanAction<char, IntPtr> _getAsciiStringNonNullCharacters = GetAsciiStringNonNullCharacters;

Expand Down Expand Up @@ -120,11 +120,8 @@ public static unsafe string GetAsciiStringNonNullCharacters(this ReadOnlySpan<by
}
}

public static string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span)
=> GetAsciiOrUTF8StringNonNullCharacters((ReadOnlySpan<byte>)span);

public static string GetAsciiOrUTF8StringNonNullCharacters(this ReadOnlySpan<byte> span)
=> StringUtilities.GetAsciiOrUTF8StringNonNullCharacters(span, HeaderValueEncoding);
=> StringUtilities.GetAsciiOrUTF8StringNonNullCharacters(span, DefaultRequestHeaderEncoding);

private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, IntPtr state)
{
Expand All @@ -139,8 +136,34 @@ private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, In
}
}

public static string GetRequestHeaderStringNonNullCharacters(this ReadOnlySpan<byte> span, bool useLatin1) =>
useLatin1 ? span.GetLatin1StringNonNullCharacters() : span.GetAsciiOrUTF8StringNonNullCharacters(HeaderValueEncoding);
public static string GetRequestHeaderString(this ReadOnlySpan<byte> span, string name, Func<string, Encoding> encodingSelector)
{
if (ReferenceEquals(KestrelServerOptions.DefaultRequestHeaderEncodingSelector, encodingSelector))
{
return span.GetAsciiOrUTF8StringNonNullCharacters(DefaultRequestHeaderEncoding);
}

var encoding = encodingSelector(name);

if (encoding is null)
{
return span.GetAsciiOrUTF8StringNonNullCharacters(DefaultRequestHeaderEncoding);
}

if (ReferenceEquals(encoding, Encoding.Latin1))
{
return span.GetLatin1StringNonNullCharacters();
}

try
{
return encoding.GetString(span);
}
catch (DecoderFallbackException ex)
{
throw new InvalidOperationException(ex.Message, ex);
}
}

public static string GetAsciiStringEscaped(this ReadOnlySpan<byte> span, int maxChars)
{
Expand Down
2 changes: 0 additions & 2 deletions src/Servers/Kestrel/Core/src/KestrelConfigurationLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,6 @@ public void Load()

ConfigurationReader = new ConfigurationReader(Configuration);

Options.Latin1RequestHeaders = ConfigurationReader.Latin1RequestHeaders;

LoadDefaultCert(ConfigurationReader);

foreach (var endpoint in ConfigurationReader.Endpoints)
Expand Down
16 changes: 13 additions & 3 deletions src/Servers/Kestrel/Core/src/KestrelServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,19 @@ public class KestrelServer : IServer

private IDisposable _configChangedRegistration;

public KestrelServer(IOptions<KestrelServerOptions> options, IEnumerable<IConnectionListenerFactory> transportFactories, ILoggerFactory loggerFactory)
public KestrelServer(
IOptions<KestrelServerOptions> options,
IEnumerable<IConnectionListenerFactory> transportFactories,
ILoggerFactory loggerFactory)
: this(transportFactories, null, CreateServiceContext(options, loggerFactory))
{
}

public KestrelServer(IOptions<KestrelServerOptions> options, IEnumerable<IConnectionListenerFactory> transportFactories, IEnumerable<IMultiplexedConnectionListenerFactory> multiplexedFactories, ILoggerFactory loggerFactory)
public KestrelServer(
IOptions<KestrelServerOptions> options,
IEnumerable<IConnectionListenerFactory> transportFactories,
IEnumerable<IMultiplexedConnectionListenerFactory> multiplexedFactories,
ILoggerFactory loggerFactory)
: this(transportFactories, multiplexedFactories, CreateServiceContext(options, loggerFactory))
{
}
Expand All @@ -52,7 +59,10 @@ internal KestrelServer(IEnumerable<IConnectionListenerFactory> transportFactorie
}

// For testing
internal KestrelServer(IEnumerable<IConnectionListenerFactory> transportFactories, IEnumerable<IMultiplexedConnectionListenerFactory> multiplexedFactories, ServiceContext serviceContext)
internal KestrelServer(
IEnumerable<IConnectionListenerFactory> transportFactories,
IEnumerable<IMultiplexedConnectionListenerFactory> multiplexedFactories,
ServiceContext serviceContext)
{
if (transportFactories == null)
{
Expand Down
40 changes: 27 additions & 13 deletions src/Servers/Kestrel/Core/src/KestrelServerOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using System.Net;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Microsoft.AspNetCore.Certificates.Generation;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
Expand All @@ -22,6 +23,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
/// </summary>
public class KestrelServerOptions
{
// internal to fast-path header decoding when RequestHeaderEncodingSelector is unchanged.
internal static readonly Func<string, Encoding> DefaultRequestHeaderEncodingSelector = _ => null;

private Func<string, Encoding> _requestHeaderEncodingSelector = DefaultRequestHeaderEncodingSelector;

// The following two lists configure the endpoints that Kestrel should listen to. If both lists are empty, the "urls" config setting (e.g. UseUrls) is used.
internal List<ListenOptions> CodeBackedListenOptions { get; } = new List<ListenOptions>();
internal List<ListenOptions> ConfigurationBackedListenOptions { get; } = new List<ListenOptions>();
Expand Down Expand Up @@ -65,6 +71,24 @@ public class KestrelServerOptions
/// </remarks>
public bool DisableStringReuse { get; set; } = false;

/// <summary>
/// Controls whether to return the AltSvcHeader from on an HTTP/2 or lower response for HTTP/3
/// </summary>
/// <remarks>
/// Defaults to false.
/// </remarks>
public bool EnableAltSvc { get; set; } = false;

/// <summary>
/// Gets or sets a callback that returns the <see cref="Encoding"/> to decode the value for the specified request header name,
/// or <see langword="null"/> to use the default <see cref="UTF8Encoding"/>.
/// </summary>
public Func<string, Encoding> RequestHeaderEncodingSelector
{
get => _requestHeaderEncodingSelector;
set => _requestHeaderEncodingSelector = value ?? throw new ArgumentNullException(nameof(value));
}

/// <summary>
/// Enables the Listen options callback to resolve and use services registered by the application during startup.
/// Typically initialized by UseKestrel()"/>.
Expand All @@ -78,15 +102,10 @@ public class KestrelServerOptions

/// <summary>
/// Provides a configuration source where endpoints will be loaded from on server start.
/// The default is null.
/// The default is <see langword="null"/>.
/// </summary>
public KestrelConfigurationLoader ConfigurationLoader { get; set; }

/// <summary>
/// Controls whether to return the AltSvcHeader from on an HTTP/2 or lower response for HTTP/3
/// </summary>
public bool EnableAltSvc { get; set; } = false;

/// <summary>
/// A default configuration action for all endpoints. Use for Listen, configuration, the default url, and URLs.
/// </summary>
Expand All @@ -107,11 +126,6 @@ public class KestrelServerOptions
/// </summary>
internal bool IsDevCertLoaded { get; set; }

/// <summary>
/// Treat request headers as Latin-1 or ISO/IEC 8859-1 instead of UTF-8.
/// </summary>
internal bool Latin1RequestHeaders { get; set; }

/// <summary>
/// Specifies a configuration Action to run for each newly created endpoint. Calling this again will replace
/// the prior action.
Expand Down Expand Up @@ -159,7 +173,7 @@ private void EnsureDefaultCert()
if (DefaultCertificate == null && !IsDevCertLoaded)
{
IsDevCertLoaded = true; // Only try once
var logger = ApplicationServices.GetRequiredService<ILogger<KestrelServer>>();
var logger = ApplicationServices!.GetRequiredService<ILogger<KestrelServer>>();
try
{
DefaultCertificate = CertificateManager.Instance.ListCertificates(StoreName.My, StoreLocation.CurrentUser, isValid: true)
Expand Down Expand Up @@ -220,7 +234,7 @@ private void EnsureDefaultCert()
/// </summary>
/// <param name="config">The configuration section for Kestrel.</param>
/// <param name="reloadOnChange">
/// If <see langword="true" />, Kestrel will dynamically update endpoint bindings when configuration changes.
/// If <see langword="true"/>, Kestrel will dynamically update endpoint bindings when configuration changes.
/// This will only reload endpoints defined in the "Endpoints" section of your <paramref name="config"/>. Endpoints defined in code will not be reloaded.
/// </param>
/// <returns>A <see cref="KestrelConfigurationLoader"/> for further endpoint configuration.</returns>
Expand Down
Loading

0 comments on commit b446ab7

Please sign in to comment.