Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to configure websocket keep-alive #2352

Merged
merged 5 commits into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions iothub/device/src/AmqpTransportSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ public TimeSpan OpenTimeout
set => SetOpenTimeout(value);
}

/// <summary>
/// A keep-alive for when using web sockets.
drwill-ms marked this conversation as resolved.
Show resolved Hide resolved
/// </summary>
/// <seealso href="https://docs.microsoft.com/dotnet/api/system.net.websockets.clientwebsocketoptions.keepaliveinterval"/>
public TimeSpan? WebSocketKeepAlive { get; set; }

/// <summary>
/// The pre-fetch count
/// </summary>
Expand Down
47 changes: 7 additions & 40 deletions iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ public AmqpConnectionHolder(IDeviceIdentity deviceIdentity)
_deviceIdentity = deviceIdentity;
_amqpIotConnector = new AmqpIotConnector(deviceIdentity.AmqpTransportSettings, deviceIdentity.IotHubConnectionString.HostName);
if (Logging.IsEnabled)
{
Logging.Associate(this, _deviceIdentity, nameof(_deviceIdentity));
}
}

public AmqpUnit CreateAmqpUnit(
Expand All @@ -43,9 +41,7 @@ public AmqpUnit CreateAmqpUnit(
Action onUnitDisconnected)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(CreateAmqpUnit));
}

var amqpUnit = new AmqpUnit(
deviceIdentity,
Expand All @@ -59,20 +55,17 @@ public AmqpUnit CreateAmqpUnit(
{
_amqpUnits.Add(amqpUnit);
}

if (Logging.IsEnabled)
{
Logging.Exit(this, deviceIdentity, nameof(CreateAmqpUnit));
}

return amqpUnit;
}

private void OnConnectionClosed(object o, EventArgs args)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, o, nameof(OnConnectionClosed));
}

if (_amqpIotConnection != null && ReferenceEquals(_amqpIotConnection, o))
{
Expand All @@ -87,25 +80,21 @@ private void OnConnectionClosed(object o, EventArgs args)
unit.OnConnectionDisconnected();
}
}

if (Logging.IsEnabled)
{
Logging.Exit(this, o, nameof(OnConnectionClosed));
}
}

public void Shutdown()
{
if (Logging.IsEnabled)
{
Logging.Enter(this, _amqpIotConnection, nameof(Shutdown));
}

_amqpAuthenticationRefresher?.StopLoop();
_amqpIotConnection?.SafeClose();

if (Logging.IsEnabled)
{
Logging.Exit(this, _amqpIotConnection, nameof(Shutdown));
}
}

public void Dispose()
Expand All @@ -122,9 +111,7 @@ private void Dispose(bool disposing)
}

if (Logging.IsEnabled)
{
Logging.Info(this, disposing, nameof(Dispose));
}

if (disposing)
{
Expand All @@ -144,38 +131,29 @@ private void Dispose(bool disposing)
public async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(CreateRefresherAsync));
}

AmqpIotConnection amqpIotConnection = await EnsureConnectionAsync(cancellationToken).ConfigureAwait(false);
IAmqpAuthenticationRefresher amqpAuthenticator = await amqpIotConnection
.CreateRefresherAsync(deviceIdentity, cancellationToken)
.ConfigureAwait(false);

if (Logging.IsEnabled)
{
Logging.Exit(this, deviceIdentity, nameof(CreateRefresherAsync));
}

return amqpAuthenticator;
}

public async Task<AmqpIotSession> OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(OpenSessionAsync));
}

AmqpIotConnection amqpIotConnection = await EnsureConnectionAsync(cancellationToken).ConfigureAwait(false);
AmqpIotSession amqpIotSession = await amqpIotConnection.OpenSessionAsync(cancellationToken).ConfigureAwait(false);
if (Logging.IsEnabled)
{
Logging.Associate(amqpIotConnection, amqpIotSession, nameof(OpenSessionAsync));
}

if (Logging.IsEnabled)
{
Logging.Exit(this, deviceIdentity, nameof(OpenSessionAsync));
}

Expand All @@ -185,9 +163,7 @@ public async Task<AmqpIotSession> OpenSessionAsync(IDeviceIdentity deviceIdentit
public async Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, nameof(EnsureConnectionAsync));
}

AmqpIotConnection amqpIotConnection = null;
IAmqpAuthenticationRefresher amqpAuthenticationRefresher = null;
Expand All @@ -205,18 +181,15 @@ public async Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken can
if (_amqpIotConnection == null || _amqpIotConnection.IsClosing())
{
if (Logging.IsEnabled)
{
Logging.Info(this, "Creating new AmqpConnection", nameof(EnsureConnectionAsync));
}

// Create AmqpConnection
amqpIotConnection = await _amqpIotConnector.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

if (_deviceIdentity.AuthenticationModel == AuthenticationModel.SasGrouped)
{
if (Logging.IsEnabled)
{
Logging.Info(this, "Creating connection wide AmqpAuthenticationRefresher", nameof(EnsureConnectionAsync));
}

amqpAuthenticationRefresher = new AmqpAuthenticationRefresher(_deviceIdentity, amqpIotConnection.GetCbsLink());
await amqpAuthenticationRefresher.InitLoopAsync(cancellationToken).ConfigureAwait(false);
Expand All @@ -226,9 +199,7 @@ public async Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken can
_amqpAuthenticationRefresher = amqpAuthenticationRefresher;
_amqpIotConnection.Closed += OnConnectionClosed;
if (Logging.IsEnabled)
{
Logging.Associate(this, _amqpIotConnection, nameof(_amqpIotConnection));
}
}
else
{
Expand All @@ -245,20 +216,17 @@ public async Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken can
{
_lock.Release();
}

if (Logging.IsEnabled)
{
Logging.Exit(this, nameof(EnsureConnectionAsync));
}

return amqpIotConnection;
}

public void RemoveAmqpUnit(AmqpUnit amqpUnit)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, amqpUnit, nameof(RemoveAmqpUnit));
}

lock (_unitsLock)
{
Expand All @@ -269,10 +237,9 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
Shutdown();
}
}

if (Logging.IsEnabled)
{
Logging.Exit(this, amqpUnit, nameof(RemoveAmqpUnit));
}
}

internal IDeviceIdentity GetDeviceIdentityOfAuthenticationProvider()
Expand Down
9 changes: 6 additions & 3 deletions iothub/device/src/Transport/AmqpIot/AmqpIotConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ internal AmqpIotConnector(AmqpTransportSettings amqpTransportSettings, string ho

public async Task<AmqpIotConnection> OpenConnectionAsync(CancellationToken cancellationToken)
{
Logging.Enter(this, nameof(OpenConnectionAsync));
if (Logging.IsEnabled)
Logging.Enter(this, nameof(OpenConnectionAsync));

var amqpTransportProvider = new AmqpTransportProvider();
amqpTransportProvider.Versions.Add(s_amqpVersion_1_0_0);
Expand All @@ -63,7 +64,8 @@ public async Task<AmqpIotConnection> OpenConnectionAsync(CancellationToken cance
amqpConnection.Closed += amqpIotConnection.AmqpConnectionClosed;
await amqpConnection.OpenAsync(cancellationToken).ConfigureAwait(false);

Logging.Exit(this, $"{nameof(OpenConnectionAsync)}");
if (Logging.IsEnabled)
Logging.Exit(this, $"{nameof(OpenConnectionAsync)}");

return amqpIotConnection;
}
Expand All @@ -75,7 +77,8 @@ public async Task<AmqpIotConnection> OpenConnectionAsync(CancellationToken cance
}
finally
{
Logging.Exit(this, nameof(OpenConnectionAsync));
if (Logging.IsEnabled)
Logging.Exit(this, nameof(OpenConnectionAsync));
}
}

Expand Down
34 changes: 25 additions & 9 deletions iothub/device/src/Transport/AmqpIot/AmqpIotTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public void Dispose()

internal async Task<TransportBase> InitializeAsync(CancellationToken cancellationToken)
{
Logging.Enter(this, nameof(InitializeAsync));
if (Logging.IsEnabled)
Logging.Enter(this, nameof(InitializeAsync));

TransportBase transport;

Expand All @@ -96,7 +97,8 @@ internal async Task<TransportBase> InitializeAsync(CancellationToken cancellatio
default:
throw new InvalidOperationException("AmqpTransportSettings must specify WebSocketOnly or TcpOnly");
}
Logging.Exit(this, nameof(InitializeAsync));
if (Logging.IsEnabled)
Logging.Exit(this, nameof(InitializeAsync));

return transport;
}
Expand All @@ -107,7 +109,8 @@ private async Task<TransportBase> CreateClientWebSocketTransportAsync(Cancellati
{
cancellationToken.ThrowIfCancellationRequested();

Logging.Enter(this, nameof(CreateClientWebSocketTransportAsync));
if (Logging.IsEnabled)
Logging.Enter(this, nameof(CreateClientWebSocketTransportAsync));

string additionalQueryParams = "";
var websocketUri = new Uri($"{WebSocketConstants.Scheme}{_hostName}:{WebSocketConstants.SecurePort}{WebSocketConstants.UriSuffix}{additionalQueryParams}");
Expand Down Expand Up @@ -139,7 +142,8 @@ private async Task<TransportBase> CreateClientWebSocketTransportAsync(Cancellati
}
finally
{
Logging.Exit(this, $"{nameof(CreateClientWebSocketTransportAsync)}");
if (Logging.IsEnabled)
Logging.Exit(this, $"{nameof(CreateClientWebSocketTransportAsync)}");
}
}

Expand All @@ -161,7 +165,8 @@ private async Task<ClientWebSocket> CreateClientWebSocketAsync(Uri websocketUri,
{
try
{
Logging.Enter(this, nameof(CreateClientWebSocketAsync));
if (Logging.IsEnabled)
Logging.Enter(this, nameof(CreateClientWebSocketAsync));

var websocket = new ClientWebSocket();

Expand All @@ -177,13 +182,22 @@ private async Task<ClientWebSocket> CreateClientWebSocketAsync(Uri websocketUri,
{
// Configure proxy server
websocket.Options.Proxy = webProxy;
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Setting ClientWebSocket.Options.Proxy");
if (Logging.IsEnabled)
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Set ClientWebSocket.Options.Proxy to {webProxy}");
}
}
catch (PlatformNotSupportedException)
{
// .NET Core 2.0 doesn't support proxy. Ignore this setting.
Logging.Error(this, $"{nameof(CreateClientWebSocketAsync)} PlatformNotSupportedException thrown as .NET Core 2.0 doesn't support proxy");
if (Logging.IsEnabled)
Logging.Error(this, $"{nameof(CreateClientWebSocketAsync)} PlatformNotSupportedException thrown as .NET Core 2.0 doesn't support proxy");
}

if (_amqpTransportSettings.WebSocketKeepAlive.HasValue)
drwill-ms marked this conversation as resolved.
Show resolved Hide resolved
{
websocket.Options.KeepAliveInterval = _amqpTransportSettings.WebSocketKeepAlive.Value;
if (Logging.IsEnabled)
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Set websocket keep-alive to {_amqpTransportSettings.WebSocketKeepAlive}");
}

if (_amqpTransportSettings.ClientCertificate != null)
Expand All @@ -196,7 +210,8 @@ private async Task<ClientWebSocket> CreateClientWebSocketAsync(Uri websocketUri,
if (_amqpTransportSettings.RemoteCertificateValidationCallback != null)
{
websocket.Options.RemoteCertificateValidationCallback = _amqpTransportSettings.RemoteCertificateValidationCallback;
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Setting RemoteCertificateValidationCallback");
if (Logging.IsEnabled)
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Setting RemoteCertificateValidationCallback");
}
#endif
await websocket.ConnectAsync(websocketUri, cancellationToken).ConfigureAwait(false);
Expand All @@ -205,7 +220,8 @@ private async Task<ClientWebSocket> CreateClientWebSocketAsync(Uri websocketUri,
}
finally
{
Logging.Exit(this, nameof(CreateClientWebSocketAsync));
if (Logging.IsEnabled)
Logging.Exit(this, nameof(CreateClientWebSocketAsync));
}
}

Expand Down
4 changes: 2 additions & 2 deletions iothub/device/src/Transport/Mqtt/ClientWebSocketChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ private async Task<int> DoReadBytesAsync(IByteBuffer byteBuffer)
try
{
WebSocketReceiveResult receiveResult = await _webSocket
.ReceiveAsync(new ArraySegment<byte>(byteBuffer.Array, byteBuffer.ArrayOffset + byteBuffer.WriterIndex, byteBuffer.WritableBytes), CancellationToken.None)
.ConfigureAwait(false);
.ReceiveAsync(new ArraySegment<byte>(byteBuffer.Array, byteBuffer.ArrayOffset + byteBuffer.WriterIndex, byteBuffer.WritableBytes), CancellationToken.None)
.ConfigureAwait(false);
if (receiveResult.MessageType == WebSocketMessageType.Text)
{
throw new ProtocolViolationException("Mqtt over WS message cannot be in text");
Expand Down
4 changes: 3 additions & 1 deletion iothub/device/src/Transport/Mqtt/MqttIotHubAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ private async void ScheduleCheckConnectTimeoutAsync(IChannelHandlerContext conte

try
{
await context.Channel.EventLoop.ScheduleAsync(s_checkConnAckTimeoutCallback, context, _mqttTransportSettings.ConnectArrivalTimeout).ConfigureAwait(true);
await context.Channel.EventLoop
.ScheduleAsync(s_checkConnAckTimeoutCallback, context, _mqttTransportSettings.ConnectArrivalTimeout)
.ConfigureAwait(true);
}
catch (Exception ex) when (!ex.IsFatal())
{
Expand Down
Loading