Skip to content

Commit

Permalink
reimplement socket duplication, fix dotnet#1760
Browse files Browse the repository at this point in the history
  • Loading branch information
antonfirsov committed Jan 17, 2020
1 parent 704aa77 commit 52dc552
Show file tree
Hide file tree
Showing 16 changed files with 555 additions and 82 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System;
using System.Net.Sockets;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal static partial class Winsock
{
[StructLayout(LayoutKind.Sequential, CharSet=CharSet.Auto)]
internal struct WSAProtocolChain
{
internal int ChainLen;
[MarshalAs(UnmanagedType.ByValArray, SizeConst=7)]
internal uint[] ChainEntries;
}

[StructLayout(LayoutKind.Sequential, CharSet=CharSet.Auto)]
internal struct WSAProtocolInfo
{
internal uint ServiceFlags1;
internal uint ServiceFlags2;
internal uint ServiceFlags3;
internal uint ServiceFlags4;
internal uint ProviderFlags;
internal Guid ProviderId;
internal uint CatalogEntryId;
internal WSAProtocolChain ProtocolChain;
internal int Version;
internal AddressFamily AddressFamily;
internal int MaxSockAddr;
internal int MinSockAddr;
internal SocketType SocketType;
internal ProtocolType ProtocolType;
internal int ProtocolMaxOffset;
internal int NetworkByteOrder;
internal int SecurityScheme;
internal uint MessageSize;
internal uint ProviderReserved;
[MarshalAs(UnmanagedType.ByValTStr, SizeConst = 256)]
internal string ProtocolName;

public static readonly int Size = Marshal.SizeOf(typeof(WSAProtocolInfo));
}

[DllImport(Interop.Libraries.Ws2_32, SetLastError = true)]
internal static extern unsafe SocketError WSADuplicateSocket(
[In] SafeHandle socketHandle,
[In] uint targetProcessId,
[In] byte* pinnedBuffer
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,18 @@ public static void ForceNonBlocking(this Socket socket, bool force)
socket.Blocking = true;
}
}

public static (Socket, Socket) CreateConnectedSocketPair()
{
using Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
listener.Listen(1);

Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
client.Connect(listener.LocalEndPoint);
Socket server = listener.Accept();

return (client, server);
}
}
}
3 changes: 3 additions & 0 deletions src/libraries/System.Net.Sockets/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,7 @@
<data name="net_sockets_valuetaskmisuse" xml:space="preserve">
<value>A ValueTask returned from an asynchronous socket operation was consumed concurrently. ValueTasks must only ever be awaited once. (Id: {0}).</value>
</data>
<data name="net_sockets_invalid_socketinformation" xml:space="preserve">
<value>The specified value for the socket information is invalid.</value>
</data>
</root>
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<AssemblyName>System.Net.Sockets</AssemblyName>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
Expand Down Expand Up @@ -197,6 +197,9 @@
<Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.WSAConnect.cs">
<Link>Common\Interop\Windows\WinSock\Interop.WSAConnect.cs</Link>
</Compile>
<Compile Include="..\..\Common\src\Interop\Windows\WinSock\Interop.WSADuplicateSocket.cs">
<Link>Common\Interop\Windows\WinSock\Interop.WSADuplicateSocket.cs</Link>
</Compile>
<Compile Include="$(CommonPath)Interop\Windows\WinSock\Interop.WSAGetOverlappedResult.cs">
<Link>Common\Interop\Windows\WinSock\Interop.WSAGetOverlappedResult.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ internal ThreadPoolBoundHandle GetOrAllocateThreadPoolBoundHandle(bool trySkipCo
catch (Exception exception) when (!ExceptionCheck.IsFatal(exception))
{
bool closed = IsClosed;
bool alreadyBound = !IsInvalid && !IsClosed && (exception is ArgumentException);
CloseAsIs(abortive: false);
if (closed)
{
Expand All @@ -67,6 +68,12 @@ internal ThreadPoolBoundHandle GetOrAllocateThreadPoolBoundHandle(bool trySkipCo
// instead propagate as an ObjectDisposedException.
ThrowSocketDisposedException(exception);
}

if (alreadyBound)
{
throw new InvalidOperationException("Asynchronous operations are not allowed on this socket. It's handle might have been previously bound to a Thread Pool / IOCP port.");
}

throw;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@ namespace System.Net.Sockets
{
public partial class Socket
{
public Socket(SocketInformation socketInformation)
{
//
// This constructor works in conjunction with DuplicateAndClose, which is not supported.
// See comments in DuplicateAndClose.
//
throw new PlatformNotSupportedException(SR.net_sockets_duplicateandclose_notsupported);
}

public SocketInformation DuplicateAndClose(int targetProcessId)
{
//
// DuplicateAndClose is not supported on Unix, since passing FD-s between processes
// should involve Unix Domain Sockets. This programming model is fundamentally different,
// and incompatible with the design of SocketInformation API-s.
//
throw new PlatformNotSupportedException(SR.net_sockets_duplicateandclose_notsupported);
}

partial void ValidateForMultiConnect(bool isMultiEndpoint)
{
// ValidateForMultiConnect is called before any {Begin}Connect{Async} call,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,85 @@ public partial class Socket

internal void ReplaceHandleIfNecessaryAfterFailedConnect() { /* nop on Windows */ }

public Socket(SocketInformation socketInformation)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this);

InitializeSockets();

SocketError errorCode = SocketPal.CreateSocket(socketInformation, out _handle,
ref _addressFamily, ref _socketType, ref _protocolType);

if (errorCode != SocketError.Success)
{
Debug.Assert(_handle.IsInvalid);

if (errorCode == SocketError.InvalidArgument)
{
throw new ArgumentException(SR.net_sockets_invalid_socketinformation, nameof(socketInformation));
}

// Failed to create the socket, throw.
throw new SocketException((int)errorCode);
}

if (_handle.IsInvalid)
{
throw new SocketException();
}

if (_addressFamily != AddressFamily.InterNetwork && _addressFamily != AddressFamily.InterNetworkV6)
{
throw new NotSupportedException(SR.net_invalidversion);
}

_isConnected = socketInformation.GetOption(SocketInformationOptions.Connected);
_willBlock = !socketInformation.GetOption(SocketInformationOptions.NonBlocking);
InternalSetBlocking(_willBlock);
_isListening = socketInformation.GetOption(SocketInformationOptions.Listening);

IPAddress tempAddress = _addressFamily == AddressFamily.InterNetwork ? IPAddress.Any : IPAddress.IPv6Any;
IPEndPoint ep = new IPEndPoint(tempAddress, 0);

Internals.SocketAddress socketAddress = IPEndPointExtensions.Serialize(ep);
errorCode = SocketPal.GetSockName(_handle, socketAddress.Buffer, ref socketAddress.InternalSize);
if (errorCode == SocketError.Success)
{
try
{
_rightEndPoint = ep.Create(socketAddress);
}
catch
{
}
}

if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
}

public SocketInformation DuplicateAndClose(int targetProcessId)
{
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, targetProcessId);

ThrowIfDisposed();

SocketError errorCode = SocketPal.DuplicateSocket(_handle, targetProcessId, out SocketInformation info);

if (errorCode != SocketError.Success)
{
throw new SocketException((int)errorCode);
}

info.SetOption(SocketInformationOptions.Connected, Connected);
info.SetOption(SocketInformationOptions.NonBlocking, !Blocking);
info.SetOption(SocketInformationOptions.Listening, _isListening);

Close(-1);

if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
return info;
}

private void EnsureDynamicWinsockMethods()
{
if (_dynamicWinsockMethods == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ public Socket(AddressFamily addressFamily, SocketType socketType, ProtocolType p
if (errorCode != SocketError.Success)
{
Debug.Assert(_handle.IsInvalid);

// Failed to create the socket, throw.
throw new SocketException((int)errorCode);
}
Expand All @@ -106,15 +105,6 @@ public Socket(AddressFamily addressFamily, SocketType socketType, ProtocolType p
if (NetEventSource.IsEnabled) NetEventSource.Exit(this);
}

public Socket(SocketInformation socketInformation)
{
//
// This constructor works in conjunction with DuplicateAndClose, which is not supported.
// See comments in DuplicateAndClose.
//
throw new PlatformNotSupportedException(SR.net_sockets_duplicateandclose_notsupported);
}

// Called by the class to create a socket to accept an incoming request.
private Socket(SafeSocketHandle fd)
{
Expand Down Expand Up @@ -2043,16 +2033,7 @@ private bool CanUseConnectEx(EndPoint remoteEP)
(_rightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint));
}

public SocketInformation DuplicateAndClose(int targetProcessId)
{
//
// On Windows, we cannot duplicate a socket that is bound to an IOCP. In this implementation, we *only*
// support IOCPs, so this will not work.
//
// On Unix, duplication of a socket into an arbitrary process is not supported at all.
//
throw new PlatformNotSupportedException(SR.net_sockets_duplicateandclose_notsupported);
}


internal IAsyncResult UnsafeBeginConnect(EndPoint remoteEP, AsyncCallback callback, object state, bool flowContext = false)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,24 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.Serialization;

namespace System.Net.Sockets
{
public struct SocketInformation
{
public byte[] ProtocolInformation { get; set; }
public SocketInformationOptions Options { get; set; }

internal void SetOption(SocketInformationOptions option, bool value)
{
if (value) Options |= option;
else Options &= ~option;
}

internal bool GetOption(SocketInformationOptions option)
{
return ((Options & option) != 0);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;

Expand Down Expand Up @@ -49,6 +50,43 @@ public static SocketError CreateSocket(AddressFamily addressFamily, SocketType s
return socket.IsInvalid ? GetLastSocketError() : SocketError.Success;
}

public static unsafe SocketError CreateSocket(
SocketInformation socketInformation,
out SafeSocketHandle socket,
ref AddressFamily addressFamily,
ref SocketType socketType,
ref ProtocolType protocolType)
{
if (socketInformation.ProtocolInformation == null || socketInformation.ProtocolInformation.Length < Interop.Winsock.WSAProtocolInfo.Size)
{
throw new ArgumentException(SR.net_sockets_invalid_socketinformation, nameof(socketInformation));
}

fixed (byte* pinnedBuffer = socketInformation.ProtocolInformation)
{
IntPtr handle = Interop.Winsock.WSASocketW(
(AddressFamily)(-1),
(SocketType)(-1),
(ProtocolType)(-1),
(IntPtr)pinnedBuffer, 0, Interop.Winsock.SocketConstructorFlags.WSA_FLAG_OVERLAPPED);

socket = new SafeSocketHandle(handle, ownsHandle: true);
if (NetEventSource.IsEnabled) NetEventSource.Info(null, socket);

if (socket.IsInvalid)
{
return GetLastSocketError();
}

Interop.Winsock.WSAProtocolInfo protocolInfo = Marshal.PtrToStructure<Interop.Winsock.WSAProtocolInfo>((IntPtr)pinnedBuffer);
addressFamily = protocolInfo.AddressFamily;
socketType = protocolInfo.SocketType;
protocolType = protocolInfo.ProtocolType;

return SocketError.Success;
}
}

public static SocketError SetBlocking(SafeSocketHandle handle, bool shouldBlock, out bool willBlock)
{
int intBlocking = shouldBlock ? 0 : -1;
Expand Down Expand Up @@ -1246,5 +1284,18 @@ internal static SocketError Disconnect(Socket socket, SafeSocketHandle handle, b

return errorCode;
}

internal static unsafe SocketError DuplicateSocket(SafeSocketHandle handle, int targetProcessId, out SocketInformation socketInformation)
{
socketInformation = new SocketInformation
{
ProtocolInformation = new byte[Interop.Winsock.WSAProtocolInfo.Size]
};

fixed (byte* pinnedBuffer = socketInformation.ProtocolInformation)
{
return Interop.Winsock.WSADuplicateSocket(handle, (uint)targetProcessId, pinnedBuffer);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,6 @@ public void SupportsIPv6_MatchesOSSupportsIPv6()
#pragma warning restore
}

[Fact]
public void UseOnlyOverlappedIO_AlwaysFalse()
{
using (var s = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
{
Assert.Equal(AddressFamily.InterNetwork, s.AddressFamily);
Assert.Equal(SocketType.Stream, s.SocketType);
Assert.Equal(ProtocolType.Tcp, s.ProtocolType);

Assert.False(s.UseOnlyOverlappedIO);
s.UseOnlyOverlappedIO = true;
Assert.False(s.UseOnlyOverlappedIO);
}
}

[Fact]
public void IOControl_FIONREAD_Success()
{
Expand Down
Loading

0 comments on commit 52dc552

Please sign in to comment.