Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable TLS certificate validation when disableTLSCertificateValidation is set in the config file for a source. #5514

Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using NuGet.Configuration;
Expand Down Expand Up @@ -56,6 +58,11 @@ private HttpHandlerResourceV3 CreateResource(PackageSource packageSource)
AutomaticDecompression = (DecompressionMethods.GZip | DecompressionMethods.Deflate),
};

if (packageSource.DisableTLSCertificateValidation)
{
clientHandler.ServerCertificateCustomValidationCallback = (HttpRequestMessage message, X509Certificate2 cert, X509Chain chain, SslPolicyErrors errors) => true;
}

#if IS_DESKTOP
if (packageSource.MaxHttpRequestsPerSource > 0)
{
Expand Down
2 changes: 1 addition & 1 deletion src/NuGet.Core/NuGet.Protocol/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -530,4 +530,4 @@ The "s" should be localized to the abbreviation for seconds.</comment>
{1} - number of vulerabilitiy files reported
{2} - max count allowed</comment>
</data>
</root>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Security.Authentication;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Moq;
using NuGet.Configuration;
using NuGet.Protocol.Core.Types;
using NuGet.Test.Server;
using NuGet.Test.Utility;
using Org.BouncyCastle.Asn1.X509;
using Xunit;

namespace NuGet.Protocol.Tests
Expand Down Expand Up @@ -124,5 +128,119 @@ static IEnumerable<DelegatingHandler> GetDelegatingHandlers(HttpMessageHandler h
}
}
}

[Fact]
public async Task TryCreate_WhenCertificateValidationIsDisabled_HandlerShouldNotBeNull()
{
// Arrange
Mock<IProxyCache> proxyCache = new();
proxyCache.Setup(pc => pc.GetProxy(It.IsAny<Uri>())).Returns((IWebProxy)null);
PackageSource packageSource = new(_testPackageSourceURL, "source")
{
DisableTLSCertificateValidation = true
};
SourceRepository sourceRepository = new(packageSource, Array.Empty<INuGetResourceProvider>());
HttpHandlerResourceV3Provider target = new(proxyCache.Object);

// Act
var result = await target.TryCreate(sourceRepository, CancellationToken.None);

// Assert
result.Item1.Should().BeTrue();
HttpHandlerResourceV3 resource = (HttpHandlerResourceV3)result.Item2;
resource.Should().NotBeNull();
HttpClientHandler clientHandler = resource.ClientHandler;

clientHandler.ServerCertificateCustomValidationCallback.Should().NotBeNull();
var callbackResult = clientHandler.ServerCertificateCustomValidationCallback.Invoke(null, null, null, SslPolicyErrors.RemoteCertificateChainErrors
& SslPolicyErrors.RemoteCertificateNameMismatch
& SslPolicyErrors.RemoteCertificateNotAvailable
& SslPolicyErrors.None);
callbackResult.Should().BeTrue();
}

[Fact]
public async Task TryCreate_WhenCertificateValidationIsEnabled_HandlerShouldBeNull()
{
// Arrange
Mock<IProxyCache> proxyCache = new();
proxyCache.Setup(pc => pc.GetProxy(It.IsAny<Uri>())).Returns((IWebProxy)null);
PackageSource packageSource = new(_testPackageSourceURL, "source")
{
DisableTLSCertificateValidation = false
};
SourceRepository sourceRepository = new(packageSource, Array.Empty<INuGetResourceProvider>());
HttpHandlerResourceV3Provider target = new(proxyCache.Object);

// Act
var result = await target.TryCreate(sourceRepository, CancellationToken.None);

// Assert
result.Item1.Should().BeTrue();
HttpHandlerResourceV3 resource = (HttpHandlerResourceV3)result.Item2;
resource.Should().NotBeNull();
HttpClientHandler clientHandler = resource.ClientHandler;

clientHandler.ServerCertificateCustomValidationCallback.Should().BeNull();
}

[Fact]
public async Task TryCreate_WhenCertificateValidationIsNotDisabled_ClientHandlerThrowsAnException()
{
// Arrange
TcpListenerServer server = new()
{
Mode = TestServerMode.InvalidTLSCertificate
};

Mock<IProxyCache> proxyCache = new();
proxyCache.Setup(pc => pc.GetProxy(It.IsAny<Uri>())).Returns((IWebProxy)null);
PackageSource packageSource = new(_testPackageSourceURL, "source");
SourceRepository sourceRepository = new(packageSource, Array.Empty<INuGetResourceProvider>());
HttpHandlerResourceV3Provider target = new(proxyCache.Object);
var result = await target.TryCreate(sourceRepository, CancellationToken.None);
HttpHandlerResourceV3 resource = (HttpHandlerResourceV3)result.Item2;
HttpClientHandler clientHandler = resource.ClientHandler;
var client = new HttpClient(clientHandler);

await server.ExecuteAsync(async uri =>
{
// Act & Assert
var exception = await Assert.ThrowsAsync<HttpRequestException>(async () => await client.GetAsync(uri));
return 0;
});
}

[Fact]
public async Task TryCreate_WhenCertificateValidationIsDisabled_ClientHandlerDoesNotThrowAnException()
{
// Arrange
TcpListenerServer server = new()
{
Mode = TestServerMode.InvalidTLSCertificate
};
Mock<IProxyCache> proxyCache = new();
proxyCache.Setup(pc => pc.GetProxy(It.IsAny<Uri>())).Returns((IWebProxy)null);
PackageSource packageSource = new(_testPackageSourceURL, "source")
{
DisableTLSCertificateValidation = true
};
SourceRepository sourceRepository = new(packageSource, Array.Empty<INuGetResourceProvider>());
HttpHandlerResourceV3Provider target = new(proxyCache.Object);
var result = await target.TryCreate(sourceRepository, CancellationToken.None);
HttpHandlerResourceV3 resource = (HttpHandlerResourceV3)result.Item2;
HttpClientHandler clientHandler = resource.ClientHandler;
var client = new HttpClient(clientHandler);

await server.ExecuteAsync(async uri =>
{
// Act
var response = await client.GetAsync(uri);

// Assert
Assert.True(response.IsSuccessStatusCode);
return 0;
});
}
}
}
3 changes: 2 additions & 1 deletion test/TestUtilities/Test.Utility/TestServer/ITestServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ public enum TestServerMode
ConnectFailure,
ServerProtocolViolation,
NameResolutionFailure,
SlowResponseBody
SlowResponseBody,
InvalidTLSCertificate,
}

public interface ITestServer
Expand Down
65 changes: 64 additions & 1 deletion test/TestUtilities/Test.Utility/TestServer/TcpListenerServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Net.Security;
using System.Security.Authentication;

namespace NuGet.Test.Server
{
public class TcpListenerServer : ITestServer
{
private X509Certificate2 _tlsCertificate;
public async Task<T> ExecuteAsync<T>(Func<string, Task<T>> action)
{
Func<TcpListener, CancellationToken, Task> startServer;
Expand All @@ -25,6 +30,9 @@ public async Task<T> ExecuteAsync<T>(Func<string, Task<T>> action)
case TestServerMode.SlowResponseBody:
startServer = StartSlowResponseBody;
break;
case TestServerMode.InvalidTLSCertificate:
startServer = StartInvalidTlsCertificateServer;
break;

default:
throw new InvalidOperationException($"The mode {Mode} is not supported by this server.");
Expand All @@ -39,7 +47,16 @@ public async Task<T> ExecuteAsync<T>(Func<string, Task<T>> action)
var tcpListener = new TcpListener(IPAddress.Loopback, port);
tcpListener.Start();
var serverTask = startServer(tcpListener, serverCts.Token);
var address = $"http://localhost:{port}/";
string address;

if (Mode == TestServerMode.InvalidTLSCertificate)
{
address = $"https://localhost:{port}/";
}
else
{
address = $"http://localhost:{port}/";
}

// execute the caller's action
var result = await action(address);
Expand All @@ -53,6 +70,52 @@ public async Task<T> ExecuteAsync<T>(Func<string, Task<T>> action)
CancellationToken.None);
}

public TcpListenerServer()
{
_tlsCertificate = GenerateSelfSignedCertificate();
}

private static X509Certificate2 GenerateSelfSignedCertificate()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GenerateSelfSignedCertificate method has same logic in this class and also SelfSignedCertificateMockServer. If possible, please consider removing the code duplication in a follow-up PR.

{
using (var rsa = RSA.Create(2048))
{
var request = new CertificateRequest("cn=test", rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1);
var start = DateTime.UtcNow;
var end = DateTime.UtcNow.AddYears(1);
var cert = request.CreateSelfSigned(start, end);
var certBytes = cert.Export(X509ContentType.Pfx, "password");

return new X509Certificate2(certBytes, "password", X509KeyStorageFlags.Exportable);
}
}

private async Task StartInvalidTlsCertificateServer(TcpListener tcpListener, CancellationToken token)
{
while (!token.IsCancellationRequested)
{
using (var client = await Task.Run(tcpListener.AcceptTcpClientAsync, token))
using (var sslStream = new SslStream(client.GetStream(), false))
{
sslStream.AuthenticateAsServer(_tlsCertificate, clientCertificateRequired: false, SslProtocols.Tls12, checkCertificateRevocation: true);
using (var reader = new StreamReader(sslStream, Encoding.ASCII, false, 1))
using (var writer = new StreamWriter(sslStream, Encoding.ASCII, 1, false))
{
while (!string.IsNullOrEmpty(reader.ReadLine()))
{
}

string content = "{}";
writer.WriteLine("HTTP/1.1 200 OK");
writer.WriteLine($"Date: {DateTimeOffset.UtcNow:R}");
writer.WriteLine($"Content-Length: {content.Length}");
writer.WriteLine("Content-Type: application/json");
writer.WriteLine();
writer.WriteLine(content);
}
}
}
}

public TestServerMode Mode { get; set; } = TestServerMode.ServerProtocolViolation;
public TimeSpan SleepDuration { get; set; } = TimeSpan.FromSeconds(110);

Expand Down