Skip to content

Commit

Permalink
Allow users to configure websocket keep-alive (Azure#2352)
Browse files Browse the repository at this point in the history
  • Loading branch information
David R. Williamson authored Apr 26, 2022
1 parent 41e8eb4 commit e98b4db
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 62 deletions.
6 changes: 6 additions & 0 deletions iothub/device/src/AmqpTransportSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ public TimeSpan OpenTimeout
set => SetOpenTimeout(value);
}

/// <summary>
/// A keep-alive for the transport layer in sending ping/pong control frames when using web sockets.
/// </summary>
/// <seealso href="https://docs.microsoft.com/dotnet/api/system.net.websockets.clientwebsocketoptions.keepaliveinterval"/>
public TimeSpan? WebSocketKeepAlive { get; set; }

/// <summary>
/// The pre-fetch count
/// </summary>
Expand Down
47 changes: 7 additions & 40 deletions iothub/device/src/Transport/Amqp/AmqpConnectionHolder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ public AmqpConnectionHolder(IDeviceIdentity deviceIdentity)
_deviceIdentity = deviceIdentity;
_amqpIotConnector = new AmqpIotConnector(deviceIdentity.AmqpTransportSettings, deviceIdentity.IotHubConnectionString.HostName);
if (Logging.IsEnabled)
{
Logging.Associate(this, _deviceIdentity, nameof(_deviceIdentity));
}
}

public AmqpUnit CreateAmqpUnit(
Expand All @@ -43,9 +41,7 @@ public AmqpUnit CreateAmqpUnit(
Action onUnitDisconnected)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(CreateAmqpUnit));
}

var amqpUnit = new AmqpUnit(
deviceIdentity,
Expand All @@ -59,20 +55,17 @@ public AmqpUnit CreateAmqpUnit(
{
_amqpUnits.Add(amqpUnit);
}

if (Logging.IsEnabled)
{
Logging.Exit(this, deviceIdentity, nameof(CreateAmqpUnit));
}

return amqpUnit;
}

private void OnConnectionClosed(object o, EventArgs args)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, o, nameof(OnConnectionClosed));
}

if (_amqpIotConnection != null && ReferenceEquals(_amqpIotConnection, o))
{
Expand All @@ -87,25 +80,21 @@ private void OnConnectionClosed(object o, EventArgs args)
unit.OnConnectionDisconnected();
}
}

if (Logging.IsEnabled)
{
Logging.Exit(this, o, nameof(OnConnectionClosed));
}
}

public void Shutdown()
{
if (Logging.IsEnabled)
{
Logging.Enter(this, _amqpIotConnection, nameof(Shutdown));
}

_amqpAuthenticationRefresher?.StopLoop();
_amqpIotConnection?.SafeClose();

if (Logging.IsEnabled)
{
Logging.Exit(this, _amqpIotConnection, nameof(Shutdown));
}
}

public void Dispose()
Expand All @@ -122,9 +111,7 @@ private void Dispose(bool disposing)
}

if (Logging.IsEnabled)
{
Logging.Info(this, disposing, nameof(Dispose));
}

if (disposing)
{
Expand All @@ -144,38 +131,29 @@ private void Dispose(bool disposing)
public async Task<IAmqpAuthenticationRefresher> CreateRefresherAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(CreateRefresherAsync));
}

AmqpIotConnection amqpIotConnection = await EnsureConnectionAsync(cancellationToken).ConfigureAwait(false);
IAmqpAuthenticationRefresher amqpAuthenticator = await amqpIotConnection
.CreateRefresherAsync(deviceIdentity, cancellationToken)
.ConfigureAwait(false);

if (Logging.IsEnabled)
{
Logging.Exit(this, deviceIdentity, nameof(CreateRefresherAsync));
}

return amqpAuthenticator;
}

public async Task<AmqpIotSession> OpenSessionAsync(IDeviceIdentity deviceIdentity, CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, deviceIdentity, nameof(OpenSessionAsync));
}

AmqpIotConnection amqpIotConnection = await EnsureConnectionAsync(cancellationToken).ConfigureAwait(false);
AmqpIotSession amqpIotSession = await amqpIotConnection.OpenSessionAsync(cancellationToken).ConfigureAwait(false);
if (Logging.IsEnabled)
{
Logging.Associate(amqpIotConnection, amqpIotSession, nameof(OpenSessionAsync));
}

if (Logging.IsEnabled)
{
Logging.Exit(this, deviceIdentity, nameof(OpenSessionAsync));
}

Expand All @@ -185,9 +163,7 @@ public async Task<AmqpIotSession> OpenSessionAsync(IDeviceIdentity deviceIdentit
public async Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken cancellationToken)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, nameof(EnsureConnectionAsync));
}

AmqpIotConnection amqpIotConnection = null;
IAmqpAuthenticationRefresher amqpAuthenticationRefresher = null;
Expand All @@ -205,18 +181,15 @@ public async Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken can
if (_amqpIotConnection == null || _amqpIotConnection.IsClosing())
{
if (Logging.IsEnabled)
{
Logging.Info(this, "Creating new AmqpConnection", nameof(EnsureConnectionAsync));
}

// Create AmqpConnection
amqpIotConnection = await _amqpIotConnector.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

if (_deviceIdentity.AuthenticationModel == AuthenticationModel.SasGrouped)
{
if (Logging.IsEnabled)
{
Logging.Info(this, "Creating connection wide AmqpAuthenticationRefresher", nameof(EnsureConnectionAsync));
}

amqpAuthenticationRefresher = new AmqpAuthenticationRefresher(_deviceIdentity, amqpIotConnection.GetCbsLink());
await amqpAuthenticationRefresher.InitLoopAsync(cancellationToken).ConfigureAwait(false);
Expand All @@ -226,9 +199,7 @@ public async Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken can
_amqpAuthenticationRefresher = amqpAuthenticationRefresher;
_amqpIotConnection.Closed += OnConnectionClosed;
if (Logging.IsEnabled)
{
Logging.Associate(this, _amqpIotConnection, nameof(_amqpIotConnection));
}
}
else
{
Expand All @@ -245,20 +216,17 @@ public async Task<AmqpIotConnection> EnsureConnectionAsync(CancellationToken can
{
_lock.Release();
}

if (Logging.IsEnabled)
{
Logging.Exit(this, nameof(EnsureConnectionAsync));
}

return amqpIotConnection;
}

public void RemoveAmqpUnit(AmqpUnit amqpUnit)
{
if (Logging.IsEnabled)
{
Logging.Enter(this, amqpUnit, nameof(RemoveAmqpUnit));
}

lock (_unitsLock)
{
Expand All @@ -269,10 +237,9 @@ public void RemoveAmqpUnit(AmqpUnit amqpUnit)
Shutdown();
}
}

if (Logging.IsEnabled)
{
Logging.Exit(this, amqpUnit, nameof(RemoveAmqpUnit));
}
}

internal IDeviceIdentity GetDeviceIdentityOfAuthenticationProvider()
Expand Down
9 changes: 6 additions & 3 deletions iothub/device/src/Transport/AmqpIot/AmqpIotConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ internal AmqpIotConnector(AmqpTransportSettings amqpTransportSettings, string ho

public async Task<AmqpIotConnection> OpenConnectionAsync(CancellationToken cancellationToken)
{
Logging.Enter(this, nameof(OpenConnectionAsync));
if (Logging.IsEnabled)
Logging.Enter(this, nameof(OpenConnectionAsync));

var amqpTransportProvider = new AmqpTransportProvider();
amqpTransportProvider.Versions.Add(s_amqpVersion_1_0_0);
Expand All @@ -63,7 +64,8 @@ public async Task<AmqpIotConnection> OpenConnectionAsync(CancellationToken cance
amqpConnection.Closed += amqpIotConnection.AmqpConnectionClosed;
await amqpConnection.OpenAsync(cancellationToken).ConfigureAwait(false);

Logging.Exit(this, $"{nameof(OpenConnectionAsync)}");
if (Logging.IsEnabled)
Logging.Exit(this, $"{nameof(OpenConnectionAsync)}");

return amqpIotConnection;
}
Expand All @@ -75,7 +77,8 @@ public async Task<AmqpIotConnection> OpenConnectionAsync(CancellationToken cance
}
finally
{
Logging.Exit(this, nameof(OpenConnectionAsync));
if (Logging.IsEnabled)
Logging.Exit(this, nameof(OpenConnectionAsync));
}
}

Expand Down
34 changes: 25 additions & 9 deletions iothub/device/src/Transport/AmqpIot/AmqpIotTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public void Dispose()

internal async Task<TransportBase> InitializeAsync(CancellationToken cancellationToken)
{
Logging.Enter(this, nameof(InitializeAsync));
if (Logging.IsEnabled)
Logging.Enter(this, nameof(InitializeAsync));

TransportBase transport;

Expand All @@ -96,7 +97,8 @@ internal async Task<TransportBase> InitializeAsync(CancellationToken cancellatio
default:
throw new InvalidOperationException("AmqpTransportSettings must specify WebSocketOnly or TcpOnly");
}
Logging.Exit(this, nameof(InitializeAsync));
if (Logging.IsEnabled)
Logging.Exit(this, nameof(InitializeAsync));

return transport;
}
Expand All @@ -107,7 +109,8 @@ private async Task<TransportBase> CreateClientWebSocketTransportAsync(Cancellati
{
cancellationToken.ThrowIfCancellationRequested();

Logging.Enter(this, nameof(CreateClientWebSocketTransportAsync));
if (Logging.IsEnabled)
Logging.Enter(this, nameof(CreateClientWebSocketTransportAsync));

string additionalQueryParams = "";
var websocketUri = new Uri($"{WebSocketConstants.Scheme}{_hostName}:{WebSocketConstants.SecurePort}{WebSocketConstants.UriSuffix}{additionalQueryParams}");
Expand Down Expand Up @@ -139,7 +142,8 @@ private async Task<TransportBase> CreateClientWebSocketTransportAsync(Cancellati
}
finally
{
Logging.Exit(this, $"{nameof(CreateClientWebSocketTransportAsync)}");
if (Logging.IsEnabled)
Logging.Exit(this, $"{nameof(CreateClientWebSocketTransportAsync)}");
}
}

Expand All @@ -161,7 +165,8 @@ private async Task<ClientWebSocket> CreateClientWebSocketAsync(Uri websocketUri,
{
try
{
Logging.Enter(this, nameof(CreateClientWebSocketAsync));
if (Logging.IsEnabled)
Logging.Enter(this, nameof(CreateClientWebSocketAsync));

var websocket = new ClientWebSocket();

Expand All @@ -177,13 +182,22 @@ private async Task<ClientWebSocket> CreateClientWebSocketAsync(Uri websocketUri,
{
// Configure proxy server
websocket.Options.Proxy = webProxy;
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Setting ClientWebSocket.Options.Proxy");
if (Logging.IsEnabled)
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Set ClientWebSocket.Options.Proxy to {webProxy}");
}
}
catch (PlatformNotSupportedException)
{
// .NET Core 2.0 doesn't support proxy. Ignore this setting.
Logging.Error(this, $"{nameof(CreateClientWebSocketAsync)} PlatformNotSupportedException thrown as .NET Core 2.0 doesn't support proxy");
if (Logging.IsEnabled)
Logging.Error(this, $"{nameof(CreateClientWebSocketAsync)} PlatformNotSupportedException thrown as .NET Core 2.0 doesn't support proxy");
}

if (_amqpTransportSettings.WebSocketKeepAlive.HasValue)
{
websocket.Options.KeepAliveInterval = _amqpTransportSettings.WebSocketKeepAlive.Value;
if (Logging.IsEnabled)
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Set websocket keep-alive to {_amqpTransportSettings.WebSocketKeepAlive}");
}

if (_amqpTransportSettings.ClientCertificate != null)
Expand All @@ -196,7 +210,8 @@ private async Task<ClientWebSocket> CreateClientWebSocketAsync(Uri websocketUri,
if (_amqpTransportSettings.RemoteCertificateValidationCallback != null)
{
websocket.Options.RemoteCertificateValidationCallback = _amqpTransportSettings.RemoteCertificateValidationCallback;
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Setting RemoteCertificateValidationCallback");
if (Logging.IsEnabled)
Logging.Info(this, $"{nameof(CreateClientWebSocketAsync)} Setting RemoteCertificateValidationCallback");
}
#endif
await websocket.ConnectAsync(websocketUri, cancellationToken).ConfigureAwait(false);
Expand All @@ -205,7 +220,8 @@ private async Task<ClientWebSocket> CreateClientWebSocketAsync(Uri websocketUri,
}
finally
{
Logging.Exit(this, nameof(CreateClientWebSocketAsync));
if (Logging.IsEnabled)
Logging.Exit(this, nameof(CreateClientWebSocketAsync));
}
}

Expand Down
4 changes: 2 additions & 2 deletions iothub/device/src/Transport/Mqtt/ClientWebSocketChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,8 @@ private async Task<int> DoReadBytesAsync(IByteBuffer byteBuffer)
try
{
WebSocketReceiveResult receiveResult = await _webSocket
.ReceiveAsync(new ArraySegment<byte>(byteBuffer.Array, byteBuffer.ArrayOffset + byteBuffer.WriterIndex, byteBuffer.WritableBytes), CancellationToken.None)
.ConfigureAwait(false);
.ReceiveAsync(new ArraySegment<byte>(byteBuffer.Array, byteBuffer.ArrayOffset + byteBuffer.WriterIndex, byteBuffer.WritableBytes), CancellationToken.None)
.ConfigureAwait(false);
if (receiveResult.MessageType == WebSocketMessageType.Text)
{
throw new ProtocolViolationException("Mqtt over WS message cannot be in text");
Expand Down
4 changes: 3 additions & 1 deletion iothub/device/src/Transport/Mqtt/MqttIotHubAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,9 @@ private async void ScheduleCheckConnectTimeoutAsync(IChannelHandlerContext conte

try
{
await context.Channel.EventLoop.ScheduleAsync(s_checkConnAckTimeoutCallback, context, _mqttTransportSettings.ConnectArrivalTimeout).ConfigureAwait(true);
await context.Channel.EventLoop
.ScheduleAsync(s_checkConnAckTimeoutCallback, context, _mqttTransportSettings.ConnectArrivalTimeout)
.ConfigureAwait(true);
}
catch (Exception ex) when (!ex.IsFatal())
{
Expand Down
Loading

0 comments on commit e98b4db

Please sign in to comment.