Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ClientDisconnectedSource #48188

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions src/Workspaces/Remote/Core/RemoteCallback.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using MessagePack;
using Microsoft.CodeAnalysis.ErrorReporting;
using Nerdbank.Streams;
using Newtonsoft.Json;
using Roslyn.Utilities;
using StreamJsonRpc;

Expand All @@ -30,12 +27,9 @@ internal readonly struct RemoteCallback<T>
{
private readonly T _callback;

public readonly CancellationTokenSource ClientDisconnectedSource;

public RemoteCallback(T callback, CancellationTokenSource clientDisconnectedSource)
public RemoteCallback(T callback)
{
_callback = callback;
ClientDisconnectedSource = clientDisconnectedSource;
}

public async ValueTask InvokeAsync(Func<T, CancellationToken, ValueTask> invocation, CancellationToken cancellationToken)
Expand All @@ -46,7 +40,7 @@ public async ValueTask InvokeAsync(Func<T, CancellationToken, ValueTask> invocat
}
catch (Exception exception) when (ReportUnexpectedException(exception, cancellationToken))
{
throw OnUnexpectedException(cancellationToken);
throw OnUnexpectedException(exception, cancellationToken);
}
}

Expand All @@ -58,7 +52,7 @@ public async ValueTask<TResult> InvokeAsync<TResult>(Func<T, CancellationToken,
}
catch (Exception exception) when (ReportUnexpectedException(exception, cancellationToken))
{
throw OnUnexpectedException(cancellationToken);
throw OnUnexpectedException(exception, cancellationToken);
}
}

Expand All @@ -76,7 +70,7 @@ public async ValueTask<TResult> InvokeAsync<TResult>(
}
catch (Exception exception) when (ReportUnexpectedException(exception, cancellationToken))
{
throw OnUnexpectedException(cancellationToken);
throw OnUnexpectedException(exception, cancellationToken);
}
}

Expand All @@ -87,7 +81,7 @@ public async ValueTask<TResult> InvokeAsync<TResult>(
// 3) Remote exception - an exception was thrown by the callee
// 4) Cancelation
//
private bool ReportUnexpectedException(Exception exception, CancellationToken cancellationToken)
private static bool ReportUnexpectedException(Exception exception, CancellationToken cancellationToken)
{
if (exception is IOException)
{
Expand All @@ -99,14 +93,10 @@ private bool ReportUnexpectedException(Exception exception, CancellationToken ca
{
if (cancellationToken.IsCancellationRequested)
{
// Cancellation was requested and expected
return false;
}

// It is not guaranteed that RPC only throws OCE when our token is signaled.
// Signal the cancelation source that our token is linked to and throw new cancellation
// exception in OnUnexpectedException.
ClientDisconnectedSource.Cancel();

return true;
}

Expand All @@ -118,20 +108,24 @@ private bool ReportUnexpectedException(Exception exception, CancellationToken ca
// as any observation of ConnectionLostException indicates a bug (e.g. https://github.com/microsoft/vs-streamjsonrpc/issues/549).
if (exception is ConnectionLostException)
{
ClientDisconnectedSource.Cancel();

return true;
}

// Indicates bug on client side or in serialization, report NFW and propagate the exception.
return FatalError.ReportWithoutCrashAndPropagate(exception);
}

private static Exception OnUnexpectedException(CancellationToken cancellationToken)
private static Exception OnUnexpectedException(Exception exception, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

// If this is hit the cancellation token passed to the service implementation did not use the correct token.
if (exception is ConnectionLostException)
{
throw new OperationCanceledException(exception.Message, exception);
}

// If this is hit the cancellation token passed to the service implementation did not use the correct token,
// and the resulting exception was not a ConnectionLostException.
return ExceptionUtilities.Unreachable;
}
}
Expand Down
14 changes: 10 additions & 4 deletions src/Workspaces/Remote/Core/RemoteHostAssetSerialization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ static void WriteAsset(ObjectWriter writer, ISerializerService serializer, Check
}
}

public static ValueTask<ImmutableArray<(Checksum, object)>> ReadDataAsync(PipeReader pipeReader, int scopeId, ISet<Checksum> checksums, ISerializerService serializerService, CancellationToken cancellationToken)
public static async ValueTask<ImmutableArray<(Checksum, object)>> ReadDataAsync(PipeReader pipeReader, int scopeId, ISet<Checksum> checksums, ISerializerService serializerService, CancellationToken cancellationToken)
{
// Workaround for ObjectReader not supporting async reading.
// Unless we read from the RPC stream asynchronously and with cancallation support we might hang when the server cancels.
Expand All @@ -98,7 +98,7 @@ static void WriteAsset(ObjectWriter writer, ISerializerService serializer, Check
Exception? exception = null;

// start a task on a thread pool thread copying from the RPC pipe to a local pipe:
Task.Run(async () =>
var copyTask = Task.Run(async () =>
{
try
{
Expand All @@ -113,20 +113,26 @@ static void WriteAsset(ObjectWriter writer, ISerializerService serializer, Check
await localPipe.Writer.CompleteAsync(exception).ConfigureAwait(false);
await pipeReader.CompleteAsync(exception).ConfigureAwait(false);
}
}, cancellationToken).Forget();
}, cancellationToken);

// blocking read from the local pipe on the current thread:
try
{
using var stream = localPipe.Reader.AsStream(leaveOpen: false);
return new(ReadData(stream, scopeId, checksums, serializerService, cancellationToken));
return ReadData(stream, scopeId, checksums, serializerService, cancellationToken);
}
catch (EndOfStreamException)
{
cancellationToken.ThrowIfCancellationRequested();

throw exception ?? ExceptionUtilities.Unreachable;
}
finally
{
// Make sure to complete the copy and pipes before returning, otherwise the caller could complete the
// reader and/or writer while they are still in use.
await copyTask.ConfigureAwait(false);
}
}

public static ImmutableArray<(Checksum, object)> ReadData(Stream stream, int scopeId, ISet<Checksum> checksums, ISerializerService serializerService, CancellationToken cancellationToken)
Expand Down
8 changes: 5 additions & 3 deletions src/Workspaces/Remote/Core/SolutionAssetProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public async ValueTask GetAssetsAsync(PipeWriter pipeWriter, int scopeId, Checks
// (non-contiguous) memory allocated for the underlying buffers. The amount of memory is bounded by the total size of the serialized assets.
var localPipe = new Pipe(RemoteHostAssetSerialization.PipeOptionsWithUnlimitedWriterBuffer);

Task.Run(() =>
var task1 = Task.Run(() =>
{
try
{
Expand All @@ -71,12 +71,14 @@ public async ValueTask GetAssetsAsync(PipeWriter pipeWriter, int scopeId, Checks
{
// no-op
}
}, cancellationToken).Forget();
}, cancellationToken);

// Complete RPC once we send the initial piece of data and start waiting for the writer to send more,
// so the client can start reading from the stream. Once CopyPipeDataAsync completes the pipeWriter
// the corresponding client-side pipeReader will complete and the data transfer will be finished.
CopyPipeDataAsync().Forget();
var task2 = CopyPipeDataAsync();

await Task.WhenAll(task1, task2).ConfigureAwait(false);

async Task CopyPipeDataAsync()
{
Expand Down
8 changes: 3 additions & 5 deletions src/Workspaces/Remote/ServiceHub/Host/SolutionAssetSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ namespace Microsoft.CodeAnalysis.Remote
internal sealed class SolutionAssetSource : IAssetSource
{
private readonly ServiceBrokerClient _client;
private readonly CancellationTokenSource _clientDisconnectedSource;

public SolutionAssetSource(ServiceBrokerClient client, CancellationTokenSource clientDisconnectedSource)
public SolutionAssetSource(ServiceBrokerClient client)
{
_client = client;
_clientDisconnectedSource = clientDisconnectedSource;
}

public async ValueTask<ImmutableArray<(Checksum, object)>> GetAssetsAsync(int scopeId, ISet<Checksum> checksums, ISerializerService serializerService, CancellationToken cancellationToken)
Expand All @@ -35,7 +33,7 @@ public SolutionAssetSource(ServiceBrokerClient client, CancellationTokenSource c
using var provider = await _client.GetProxyAsync<ISolutionAssetProvider>(SolutionAssetProvider.ServiceDescriptor, cancellationToken).ConfigureAwait(false);
Contract.ThrowIfNull(provider.Proxy);

return await new RemoteCallback<ISolutionAssetProvider>(provider.Proxy, _clientDisconnectedSource).InvokeAsync(
return await new RemoteCallback<ISolutionAssetProvider>(provider.Proxy).InvokeAsync(
(proxy, pipeWriter, cancellationToken) => proxy.GetAssetsAsync(pipeWriter, scopeId, checksums.ToArray(), cancellationToken),
(pipeReader, cancellationToken) => RemoteHostAssetSerialization.ReadDataAsync(pipeReader, scopeId, checksums, serializerService, cancellationToken),
cancellationToken).ConfigureAwait(false);
Expand All @@ -49,7 +47,7 @@ public async ValueTask<bool> IsExperimentEnabledAsync(string experimentName, Can
using var provider = await _client.GetProxyAsync<ISolutionAssetProvider>(SolutionAssetProvider.ServiceDescriptor, cancellationToken).ConfigureAwait(false);
Contract.ThrowIfNull(provider.Proxy);

return await new RemoteCallback<ISolutionAssetProvider>(provider.Proxy, _clientDisconnectedSource).InvokeAsync(
return await new RemoteCallback<ISolutionAssetProvider>(provider.Proxy).InvokeAsync(
(self, cancellationToken) => provider.Proxy.IsExperimentEnabledAsync(experimentName, cancellationToken),
cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ServiceHub.Framework;
using Microsoft.ServiceHub.Framework.Services;
Expand Down Expand Up @@ -74,7 +73,7 @@ internal TService Create(
var serviceHubTraceSource = (TraceSource)hostProvidedServices.GetService(typeof(TraceSource));
var serverConnection = descriptor.WithTraceSource(serviceHubTraceSource).ConstructRpcConnection(pipe);

var args = new ServiceConstructionArguments(hostProvidedServices, serviceBroker, new CancellationTokenSource());
var args = new ServiceConstructionArguments(hostProvidedServices, serviceBroker);
var service = CreateService(args, descriptor, serverConnection, serviceActivationOptions.ClientRpcTarget);

serverConnection.AddLocalRpcTarget(service);
Expand Down Expand Up @@ -106,7 +105,7 @@ protected sealed override TService CreateService(
{
Contract.ThrowIfNull(descriptor.ClientInterface);
var callback = (TCallback)(clientRpcTarget ?? serverConnection.ConstructRpcClient(descriptor.ClientInterface));
return CreateService(arguments, new RemoteCallback<TCallback>(callback, arguments.ClientDisconnectedSource));
return CreateService(arguments, new RemoteCallback<TCallback>(callback));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#nullable enable

using System;
using System.Threading;
using Microsoft.ServiceHub.Framework;

namespace Microsoft.CodeAnalysis.Remote
Expand All @@ -16,13 +15,11 @@ internal readonly struct ServiceConstructionArguments
{
public readonly IServiceProvider ServiceProvider;
public readonly IServiceBroker ServiceBroker;
public readonly CancellationTokenSource ClientDisconnectedSource;

public ServiceConstructionArguments(IServiceProvider serviceProvider, IServiceBroker serviceBroker, CancellationTokenSource clientDisconnectedSource)
public ServiceConstructionArguments(IServiceProvider serviceProvider, IServiceBroker serviceBroker)
{
ServiceProvider = serviceProvider;
ServiceBroker = serviceBroker;
ClientDisconnectedSource = clientDisconnectedSource;
}
}
}
Expand Down
17 changes: 1 addition & 16 deletions src/Workspaces/Remote/ServiceHub/Services/BrokeredServiceBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@

using System;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.ErrorReporting;
using Microsoft.ServiceHub.Framework;
using Microsoft.ServiceHub.Framework.Services;
using Nerdbank.Streams;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.Remote
Expand All @@ -27,7 +23,6 @@ internal abstract partial class BrokeredServiceBase : IDisposable
protected readonly RemoteWorkspaceManager WorkspaceManager;

protected readonly SolutionAssetSource SolutionAssetSource;
protected readonly CancellationTokenSource ClientDisconnectedSource;
protected readonly ServiceBrokerClient ServiceBrokerClient;

// test data are only available when running tests:
Expand All @@ -48,8 +43,7 @@ protected BrokeredServiceBase(in ServiceConstructionArguments arguments)
ServiceBrokerClient = new ServiceBrokerClient(arguments.ServiceBroker);
#pragma warning restore

SolutionAssetSource = new SolutionAssetSource(ServiceBrokerClient, arguments.ClientDisconnectedSource);
ClientDisconnectedSource = arguments.ClientDisconnectedSource;
SolutionAssetSource = new SolutionAssetSource(ServiceBrokerClient);
}

public void Dispose()
Expand All @@ -71,7 +65,6 @@ protected Task<Solution> GetSolutionAsync(PinnedSolutionInfo solutionInfo, Cance
protected async ValueTask<T> RunServiceAsync<T>(Func<CancellationToken, ValueTask<T>> implementation, CancellationToken cancellationToken)
{
WorkspaceManager.SolutionAssetCache.UpdateLastActivityTime();
using var _ = LinkToken(ref cancellationToken);

try
{
Expand All @@ -86,7 +79,6 @@ protected async ValueTask<T> RunServiceAsync<T>(Func<CancellationToken, ValueTas
protected async ValueTask RunServiceAsync(Func<CancellationToken, ValueTask> implementation, CancellationToken cancellationToken)
{
WorkspaceManager.SolutionAssetCache.UpdateLastActivityTime();
using var _ = LinkToken(ref cancellationToken);

try
{
Expand All @@ -97,12 +89,5 @@ protected async ValueTask RunServiceAsync(Func<CancellationToken, ValueTask> imp
throw ExceptionUtilities.Unreachable;
}
}

private CancellationTokenSource? LinkToken(ref CancellationToken cancellationToken)
{
var source = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, ClientDisconnectedSource.Token);
cancellationToken = source.Token;
return source;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,16 @@ public RemoteDesignerAttributeIncrementalAnalyzer(Workspace workspace, RemoteCal

protected override async ValueTask ReportProjectRemovedAsync(ProjectId projectId, CancellationToken cancellationToken)
{
// cancel whenever the analyzer runner cancels or the client disconnects and the request is canceled:
using var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _callback.ClientDisconnectedSource.Token);

await _callback.InvokeAsync(
(callback, cancellationToken) => callback.OnProjectRemovedAsync(projectId, cancellationToken),
linkedSource.Token).ConfigureAwait(false);
cancellationToken).ConfigureAwait(false);
}

protected override async ValueTask ReportDesignerAttributeDataAsync(List<DesignerAttributeData> data, CancellationToken cancellationToken)
{
// cancel whenever the analyzer runner cancels or the client disconnects and the request is canceled:
using var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _callback.ClientDisconnectedSource.Token);

await _callback.InvokeAsync(
(callback, cancellationToken) => callback.ReportDesignerAttributeDataAsync(data.ToImmutableArray(), cancellationToken),
linkedSource.Token).ConfigureAwait(false);
cancellationToken).ConfigureAwait(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.ProjectTelemetry;
using Microsoft.CodeAnalysis.SolutionCrawler;
using StreamJsonRpc;

namespace Microsoft.CodeAnalysis.Remote
{
Expand Down Expand Up @@ -67,12 +66,9 @@ public override async Task AnalyzeProjectAsync(Project project, bool semanticsCh
_projectToData[projectId] = info;
}

// cancel whenever the analyzer runner cancels or the client disconnects and the request is canceled:
using var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _callback.ClientDisconnectedSource.Token);

await _callback.InvokeAsync(
(callback, cancellationToken) => callback.ReportProjectTelemetryDataAsync(info, cancellationToken),
linkedSource.Token).ConfigureAwait(false);
cancellationToken).ConfigureAwait(false);
}

public override Task RemoveProjectAsync(ProjectId projectId, CancellationToken cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,9 @@ public RemoteTodoCommentsIncrementalAnalyzer(RemoteCallback<ITodoCommentsListene

protected override async ValueTask ReportTodoCommentDataAsync(DocumentId documentId, ImmutableArray<TodoCommentData> data, CancellationToken cancellationToken)
{
// cancel whenever the analyzer runner cancels or the client disconnects and the request is canceled:
using var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _callback.ClientDisconnectedSource.Token);

await _callback.InvokeAsync(
(callback, cancellationToken) => callback.ReportTodoCommentDataAsync(documentId, data, cancellationToken),
linkedSource.Token).ConfigureAwait(false);
cancellationToken).ConfigureAwait(false);
}
}
}