diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index e1891bef916f4..3ba24e90cf101 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -677,7 +677,6 @@ public ValueTask SendToAsync(ReadOnlyMemory buffer, SocketFlags socke Debug.Assert(saea.BufferList == null); saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); saea.SocketFlags = socketFlags; - saea._socketAddress = null; saea.RemoteEndPoint = remoteEP; saea.WrapExceptionsForNetworkStream = false; return saea.SendToAsync(this, cancellationToken); @@ -709,8 +708,17 @@ public ValueTask SendToAsync(ReadOnlyMemory buffer, SocketFlags socke saea.SetBuffer(MemoryMarshal.AsMemory(buffer)); saea.SocketFlags = socketFlags; saea._socketAddress = socketAddress; + saea.RemoteEndPoint = null; saea.WrapExceptionsForNetworkStream = false; - return saea.SendToAsync(this, cancellationToken); + try + { + return saea.SendToAsync(this, cancellationToken); + } + finally + { + // detach user provided SA so we do not accidentally stomp on it later. + saea._socketAddress = null; + } } /// diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 11b8674d681f3..a8c95005154c9 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -3095,14 +3095,22 @@ private bool SendToAsync(SocketAsyncEventArgs e, CancellationToken cancellationT ArgumentNullException.ThrowIfNull(e); EndPoint? endPointSnapshot = e.RemoteEndPoint; - if (e._socketAddress == null) + + // RemoteEndPoint should be set unless somebody used SendTo with their own SA. + // In that case RemoteEndPoint will be null and we take provided SA as given. + if (endPointSnapshot == null && e._socketAddress == null) { - if (endPointSnapshot == null) - { - throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); - } + throw new ArgumentException(SR.Format(SR.InvalidNullArgument, "e.RemoteEndPoint"), nameof(e)); + } - // Prepare SocketAddress + if (e._socketAddress != null && endPointSnapshot is IPEndPoint ipep && e._socketAddress.Family == endPointSnapshot?.AddressFamily) + { + // we have matching SocketAddress. Since this is only used internally, it is ok to overwrite it without + ipep.Serialize(e._socketAddress.Buffer.Span); + } + else if (endPointSnapshot != null) + { + // Prepare new SocketAddress e._socketAddress = Serialize(ref endPointSnapshot); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs index e94d862571a0f..78dd22e5eda7b 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs @@ -923,7 +923,12 @@ internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags case SocketAsyncOperation.ReceiveFrom: // Deal with incoming address. UpdateReceivedSocketAddress(_socketAddress!); - if (_remoteEndPoint != null && !SocketAddressExtensions.Equals(_socketAddress!, _remoteEndPoint)) + if (_remoteEndPoint == null) + { + // detach user provided SA as it was updated in place. + _socketAddress = null; + } + else if (!SocketAddressExtensions.Equals(_socketAddress!, _remoteEndPoint)) { try { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs index bf0ad14658869..7a3c33b64bf79 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendTo.cs @@ -173,6 +173,35 @@ public void SendToAsync_NullAsyncEventArgs_Throws_ArgumentNullException() public sealed class SendTo_Task : SendTo { public SendTo_Task(ITestOutputHelper output) : base(output) { } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SendTo_DifferentEP_Success(bool ipv4) + { + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + IPEndPoint remoteEp = new IPEndPoint(address, 0); + + using Socket receiver1 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket receiver2 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket sender = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + + receiver1.BindToAnonymousPort(address); + receiver2.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[32]; + var receiveInternalBuffer = new byte[sendBuffer.Length]; + ArraySegment receiveBuffer = new ArraySegment(receiveInternalBuffer, 0, receiveInternalBuffer.Length); + + + await sender.SendToAsync(sendBuffer, SocketFlags.None, receiver1.LocalEndPoint); + SocketReceiveFromResult result = await ReceiveFromAsync(receiver1, receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); + Assert.Equal(sendBuffer.Length, result.ReceivedBytes); + + await sender.SendToAsync(sendBuffer, SocketFlags.None, receiver2.LocalEndPoint); + result = await ReceiveFromAsync(receiver2, receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); + Assert.Equal(sendBuffer.Length, result.ReceivedBytes); + } } public sealed class SendTo_CancellableTask : SendTo diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs index ded34276f322f..3d865cb864570 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs @@ -895,5 +895,52 @@ void CreateSocketAsyncEventArgs() // separated out so that JIT doesn't extend li return cwt.Count() == 0; // validate that the cwt becomes empty }, 30_000)); } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SendTo_DifferentEP_Success(bool ipv4) + { + IPAddress address = ipv4 ? IPAddress.Loopback : IPAddress.IPv6Loopback; + IPEndPoint remoteEp = new IPEndPoint(address, 0); + + using Socket receiver1 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket receiver2 = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + using Socket sender = new Socket(address.AddressFamily, SocketType.Dgram, ProtocolType.Udp); + + receiver1.BindToAnonymousPort(address); + receiver2.BindToAnonymousPort(address); + + byte[] sendBuffer = new byte[32]; + var receiveInternalBuffer = new byte[sendBuffer.Length]; + ArraySegment receiveBuffer = new ArraySegment(receiveInternalBuffer, 0, receiveInternalBuffer.Length); + + using SocketAsyncEventArgs saea = new SocketAsyncEventArgs(); + ManualResetEventSlim mres = new ManualResetEventSlim(false); + + saea.SetBuffer(sendBuffer); + saea.RemoteEndPoint = receiver1.LocalEndPoint; + saea.Completed += delegate { mres.Set(); }; + if (sender.SendToAsync(saea)) + { + // did not finish synchronously. + mres.Wait(); + } + + SocketReceiveFromResult result = await receiver1.ReceiveFromAsync(receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); + Assert.Equal(sendBuffer.Length, result.ReceivedBytes); + mres.Reset(); + + + saea.RemoteEndPoint = receiver2.LocalEndPoint; + if (sender.SendToAsync(saea)) + { + // did not finish synchronously. + mres.Wait(); + } + + result = await receiver2.ReceiveFromAsync(receiveBuffer, remoteEp).WaitAsync(TestSettings.PassingTestTimeout); + Assert.Equal(sendBuffer.Length, result.ReceivedBytes); + } } }