From 677e16d96bde53c88459e61ec72b67cd3ef29a3a Mon Sep 17 00:00:00 2001 From: Varun Puranik Date: Mon, 26 Nov 2018 18:46:13 -0800 Subject: [PATCH] EdgeHub: Resync service identity if client request cannot be authenticated (#556) * Add logic to reget service identity if auth fails * Fix creds * Fix code and add tests * Cleanup and add logs * Don't resync when reauthenticating * Adding comment --- .../DeviceScopeTokenAuthenticator.cs | 168 ++++------------- .../DeviceScopeAuthenticator.cs | 173 ++++++++++++++++++ .../HubCoreEventIds.cs | 2 + .../modules/CommonModule.cs | 4 +- .../DeviceScopeTokenAuthenticatorTest.cs | 102 +++++++++-- 5 files changed, 295 insertions(+), 154 deletions(-) create mode 100644 edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/DeviceScopeAuthenticator.cs diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs index c7f642adaee..06eff166409 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.CloudProxy/authenticators/DeviceScopeTokenAuthenticator.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. + namespace Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Authenticators { using System; using System.Net; - using System.Threading.Tasks; using Microsoft.Azure.Devices.Common.Data; using Microsoft.Azure.Devices.Common.Security; using Microsoft.Azure.Devices.Edge.Hub.Core; @@ -13,102 +13,31 @@ namespace Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Authenticators using Microsoft.Azure.Devices.Edge.Util; using Microsoft.Extensions.Logging; - public class DeviceScopeTokenAuthenticator : IAuthenticator + public class DeviceScopeTokenAuthenticator : DeviceScopeAuthenticator { - readonly IDeviceScopeIdentitiesCache deviceScopeIdentitiesCache; readonly string iothubHostName; readonly string edgeHubHostName; - readonly IAuthenticator underlyingAuthenticator; public DeviceScopeTokenAuthenticator( IDeviceScopeIdentitiesCache deviceScopeIdentitiesCache, string iothubHostName, string edgeHubHostName, - IAuthenticator underlyingAuthenticator) + IAuthenticator underlyingAuthenticator, + bool allowDeviceAuthForModule, + bool syncServiceIdentityOnFailure) + : + base(deviceScopeIdentitiesCache, underlyingAuthenticator, allowDeviceAuthForModule, syncServiceIdentityOnFailure) { - this.underlyingAuthenticator = Preconditions.CheckNotNull(underlyingAuthenticator, nameof(underlyingAuthenticator)); - this.deviceScopeIdentitiesCache = Preconditions.CheckNotNull(deviceScopeIdentitiesCache, nameof(deviceScopeIdentitiesCache)); this.iothubHostName = Preconditions.CheckNonWhiteSpace(iothubHostName, nameof(iothubHostName)); - this.edgeHubHostName = Preconditions.CheckNotNull(edgeHubHostName, nameof(edgeHubHostName)); + this.edgeHubHostName = Preconditions.CheckNonWhiteSpace(edgeHubHostName, nameof(edgeHubHostName)); } - public async Task AuthenticateAsync(IClientCredentials clientCredentials) - { - if (!(clientCredentials is ITokenCredentials tokenCredentials)) - { - return false; - } - - Option serviceIdentity = await this.deviceScopeIdentitiesCache.GetServiceIdentity(clientCredentials.Identity.Id, true); - if (serviceIdentity.HasValue) - { - try - { - bool isAuthenticated = await serviceIdentity - .Map(s => this.AuthenticateInternalAsync(tokenCredentials, s)) - .GetOrElse(Task.FromResult(false)); - Events.AuthenticatedInScope(clientCredentials.Identity, isAuthenticated); - return isAuthenticated; - } - catch (Exception e) - { - Events.ErrorAuthenticating(e, clientCredentials); - return await this.underlyingAuthenticator.AuthenticateAsync(clientCredentials); - } - } - else - { - Events.ServiceIdentityNotFound(clientCredentials.Identity); - return await this.underlyingAuthenticator.AuthenticateAsync(clientCredentials); - } - } + protected override bool AreInputCredentialsValid(ITokenCredentials credentials) => this.TryGetSharedAccessSignature(credentials.Token, credentials.Identity, out SharedAccessSignature _); - public async Task ReauthenticateAsync(IClientCredentials clientCredentials) - { - if (!(clientCredentials is ITokenCredentials tokenCredentials)) - { - return false; - } - - Option serviceIdentity = await this.deviceScopeIdentitiesCache.GetServiceIdentity(clientCredentials.Identity.Id); - if (serviceIdentity.HasValue) - { - try - { - bool isAuthenticated = await serviceIdentity.Map(s => this.AuthenticateInternalAsync(tokenCredentials, s)).GetOrElse(Task.FromResult(false)); - Events.ReauthenticatedInScope(clientCredentials.Identity, isAuthenticated); - return isAuthenticated; - } - catch (Exception e) - { - Events.ErrorAuthenticating(e, clientCredentials); - return await this.underlyingAuthenticator.ReauthenticateAsync(clientCredentials); - } - } - else - { - Events.ServiceIdentityNotFound(clientCredentials.Identity); - return await this.underlyingAuthenticator.ReauthenticateAsync(clientCredentials); - } - } - - async Task AuthenticateInternalAsync(ITokenCredentials tokenCredentials, ServiceIdentity serviceIdentity) - { - if (!this.TryGetSharedAccessSignature(tokenCredentials.Token, tokenCredentials.Identity, out SharedAccessSignature sharedAccessSignature)) - { - return false; - } - - bool result = this.ValidateCredentials(sharedAccessSignature, serviceIdentity, tokenCredentials.Identity); - if (!result && tokenCredentials.Identity is IModuleIdentity moduleIdentity && serviceIdentity.IsModule) - { - // Module can use the Device key to authenticate - Option deviceServiceIdentity = await this.deviceScopeIdentitiesCache.GetServiceIdentity(moduleIdentity.DeviceId); - result = await deviceServiceIdentity.Map(d => this.AuthenticateInternalAsync(tokenCredentials, d)) - .GetOrElse(Task.FromResult(false)); - } - return result; - } + protected override bool ValidateWithServiceIdentity(ServiceIdentity serviceIdentity, ITokenCredentials credentials) => + this.TryGetSharedAccessSignature(credentials.Token, credentials.Identity, out SharedAccessSignature sharedAccessSignature) + ? this.ValidateCredentials(sharedAccessSignature, serviceIdentity, credentials.Identity) + : false; bool TryGetSharedAccessSignature(string token, IIdentity identity, out SharedAccessSignature sharedAccessSignature) { @@ -156,25 +85,25 @@ bool ValidateTokenWithSecurityIdentity(SharedAccessSignature sharedAccessSignatu } return serviceIdentity.Authentication.SymmetricKey.Map( - s => - { - var rule = new SharedAccessSignatureAuthorizationRule + s => { - PrimaryKey = s.PrimaryKey, - SecondaryKey = s.SecondaryKey - }; - - try - { - sharedAccessSignature.Authenticate(rule); - return true; - } - catch (UnauthorizedAccessException e) - { - Events.KeysMismatch(serviceIdentity.Id, e); - return false; - } - }) + var rule = new SharedAccessSignatureAuthorizationRule + { + PrimaryKey = s.PrimaryKey, + SecondaryKey = s.SecondaryKey + }; + + try + { + sharedAccessSignature.Authenticate(rule); + return true; + } + catch (UnauthorizedAccessException e) + { + Events.KeysMismatch(serviceIdentity.Id, e); + return false; + } + }) .GetOrElse(() => throw new InvalidOperationException($"Unable to validate token because the service identity has empty symmetric keys")); } @@ -246,23 +175,14 @@ static class Events enum EventIds { - ErrorReauthenticating = IdStart, - InvalidHostName, + InvalidHostName = IdStart, InvalidAudience, IdMismatch, KeysMismatch, InvalidServiceIdentityType, - ErrorAuthenticating, ServiceIdentityNotEnabled, TokenExpired, - ErrorParsingToken, - ServiceIdentityNotFound, - AuthenticatedInScope - } - - public static void ErrorReauthenticating(Exception exception, ServiceIdentity serviceIdentity) - { - Log.LogWarning((int)EventIds.ErrorReauthenticating, exception, $"Error re-authenticating {serviceIdentity.Id} after the service identity was updated."); + ErrorParsingToken } public static void InvalidHostName(string id, string hostName, string iotHubHostName, string edgeHubHostName) @@ -290,11 +210,6 @@ public static void InvalidServiceIdentityType(ServiceIdentity serviceIdentity) Log.LogWarning((int)EventIds.InvalidServiceIdentityType, $"Error authenticating token for {serviceIdentity.Id} because the service identity authentication type is unexpected - {serviceIdentity.Authentication.Type}"); } - public static void ErrorAuthenticating(Exception exception, IClientCredentials credentials) - { - Log.LogWarning((int)EventIds.ErrorAuthenticating, exception, $"Error authenticating credentials for {credentials.Identity.Id}"); - } - public static void ServiceIdentityNotEnabled(ServiceIdentity serviceIdentity) { Log.LogWarning((int)EventIds.ServiceIdentityNotEnabled, $"Error authenticating token for {serviceIdentity.Id} because the service identity is not enabled"); @@ -309,23 +224,6 @@ public static void ErrorParsingToken(IIdentity identity, Exception exception) { Log.LogWarning((int)EventIds.ErrorParsingToken, exception, $"Error authenticating token for {identity.Id} because the token could not be parsed"); } - - public static void ServiceIdentityNotFound(IIdentity identity) - { - Log.LogDebug((int)EventIds.ServiceIdentityNotFound, $"Service identity for {identity.Id} not found. Using underlying authenticator to authenticate"); - } - - public static void AuthenticatedInScope(IIdentity identity, bool isAuthenticated) - { - string authenticated = isAuthenticated ? "authenticated" : "not authenticated"; - Log.LogInformation((int)EventIds.AuthenticatedInScope, $"Client {identity.Id} in device scope {authenticated} locally."); - } - - public static void ReauthenticatedInScope(IIdentity identity, bool isAuthenticated) - { - string authenticated = isAuthenticated ? "reauthenticated" : "not reauthenticated"; - Log.LogDebug((int)EventIds.AuthenticatedInScope, $"Client {identity.Id} in device scope {authenticated} locally."); - } } } } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/DeviceScopeAuthenticator.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/DeviceScopeAuthenticator.cs new file mode 100644 index 00000000000..24cb0852c30 --- /dev/null +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/DeviceScopeAuthenticator.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Azure.Devices.Edge.Hub.Core +{ + using System; + using System.Threading.Tasks; + using Microsoft.Azure.Devices.Edge.Hub.Core.Device; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity; + using Microsoft.Azure.Devices.Edge.Hub.Core.Identity.Service; + using Microsoft.Azure.Devices.Edge.Util; + using Microsoft.Extensions.Logging; + + public abstract class DeviceScopeAuthenticator : IAuthenticator + where T : IClientCredentials + { + readonly IDeviceScopeIdentitiesCache deviceScopeIdentitiesCache; + readonly IAuthenticator underlyingAuthenticator; + readonly bool allowDeviceAuthForModule; + readonly bool syncServiceIdentityOnFailure; + + protected DeviceScopeAuthenticator( + IDeviceScopeIdentitiesCache deviceScopeIdentitiesCache, + IAuthenticator underlyingAuthenticator, + bool allowDeviceAuthForModule, + bool syncServiceIdentityOnFailure) + { + this.underlyingAuthenticator = Preconditions.CheckNotNull(underlyingAuthenticator, nameof(underlyingAuthenticator)); + this.deviceScopeIdentitiesCache = Preconditions.CheckNotNull(deviceScopeIdentitiesCache, nameof(deviceScopeIdentitiesCache)); + this.allowDeviceAuthForModule = allowDeviceAuthForModule; + this.syncServiceIdentityOnFailure = syncServiceIdentityOnFailure; + } + + public async Task AuthenticateAsync(IClientCredentials clientCredentials) + { + if (!(clientCredentials is T tCredentials)) + { + return false; + } + + (bool isAuthenticated, bool shouldFallback) = await this.AuthenticateInternalAsync(tCredentials, false); + Events.AuthenticatedInScope(clientCredentials.Identity, isAuthenticated); + if (!isAuthenticated && shouldFallback) + { + isAuthenticated = await this.underlyingAuthenticator.AuthenticateAsync(clientCredentials); + } + + return isAuthenticated; + } + + public async Task ReauthenticateAsync(IClientCredentials clientCredentials) + { + if (!(clientCredentials is T tCredentials)) + { + return false; + } + + (bool isAuthenticated, bool shouldFallback) = await this.AuthenticateInternalAsync(tCredentials, true); + Events.ReauthenticatedInScope(clientCredentials.Identity, isAuthenticated); + if (!isAuthenticated && shouldFallback) + { + Events.ServiceIdentityNotFound(tCredentials.Identity); + isAuthenticated = await this.underlyingAuthenticator.ReauthenticateAsync(clientCredentials); + } + + return isAuthenticated; + } + + protected abstract bool AreInputCredentialsValid(T credentials); + + protected abstract bool ValidateWithServiceIdentity(ServiceIdentity serviceIdentity, T credentials); + + async Task<(bool isAuthenticated, bool shouldFallback)> AuthenticateInternalAsync(T tCredentials, bool reauthenticating) + { + try + { + if (!this.AreInputCredentialsValid(tCredentials)) + { + Events.InputCredentialsNotValid(tCredentials.Identity); + return (false, false); + } + + bool syncServiceIdentity = this.syncServiceIdentityOnFailure && !reauthenticating; + (bool isAuthenticated, bool valueFound) = await this.AuthenticateWithServiceIdentity(tCredentials, tCredentials.Identity.Id, syncServiceIdentity); + if (!isAuthenticated && this.allowDeviceAuthForModule && tCredentials.Identity is IModuleIdentity moduleIdentity) + { + // Module can use the Device key to authenticate + Events.AuthenticatingWithDeviceIdentity(moduleIdentity); + (isAuthenticated, valueFound) = await this.AuthenticateWithServiceIdentity(tCredentials, moduleIdentity.DeviceId, syncServiceIdentity); + } + + // In the return value, the first flag indicates if the authentication succeeded. + // The second flag indicates whether the authenticator should fall back to the underlying authenticator. This is done if + // the ServiceIdentity was not found (which means the device/module is not in scope). + return (isAuthenticated, !valueFound); + } + catch (Exception e) + { + Events.ErrorAuthenticating(e, tCredentials, reauthenticating); + return (false, true); + } + } + + async Task<(bool isAuthenticated, bool serviceIdentityFound)> AuthenticateWithServiceIdentity(T credentials, string serviceIdentityId, bool syncServiceIdentity) + { + Option serviceIdentity = await this.deviceScopeIdentitiesCache.GetServiceIdentity(serviceIdentityId); + (bool isAuthenticated, bool serviceIdentityFound) = serviceIdentity.Map(s => (this.ValidateWithServiceIdentity(s, credentials), true)).GetOrElse((false, false)); + + if (!isAuthenticated && (!serviceIdentityFound || syncServiceIdentity)) + { + Events.ResyncingServiceIdentity(credentials.Identity, serviceIdentityId); + await this.deviceScopeIdentitiesCache.RefreshServiceIdentity(serviceIdentityId); + serviceIdentity = await this.deviceScopeIdentitiesCache.GetServiceIdentity(serviceIdentityId); + (isAuthenticated, serviceIdentityFound) = serviceIdentity.Map(s => (this.ValidateWithServiceIdentity(s, credentials), true)).GetOrElse((false, false)); + } + + return (isAuthenticated, serviceIdentityFound); + } + + static class Events + { + static readonly ILogger Log = Logger.Factory.CreateLogger>(); + const int IdStart = HubCoreEventIds.DeviceScopeAuthenticator; + + enum EventIds + { + ErrorAuthenticating = IdStart, + ServiceIdentityNotFound, + AuthenticatedInScope, + InputCredentialsNotValid, + ResyncingServiceIdentity, + AuthenticatingWithDeviceIdentity + } + + public static void ErrorAuthenticating(Exception exception, IClientCredentials credentials, bool reauthenticating) + { + string operation = reauthenticating ? "reauthenticating" : "authenticating"; + Log.LogWarning((int)EventIds.ErrorAuthenticating, exception, $"Error {operation} credentials for {credentials.Identity.Id}"); + } + + public static void ServiceIdentityNotFound(IIdentity identity) + { + Log.LogDebug((int)EventIds.ServiceIdentityNotFound, $"Service identity for {identity.Id} not found. Using underlying authenticator to authenticate"); + } + + public static void AuthenticatedInScope(IIdentity identity, bool isAuthenticated) + { + string authenticated = isAuthenticated ? "authenticated" : "not authenticated"; + Log.LogInformation((int)EventIds.AuthenticatedInScope, $"Client {identity.Id} in device scope {authenticated} locally."); + } + + public static void ReauthenticatedInScope(IIdentity identity, bool isAuthenticated) + { + string authenticated = isAuthenticated ? "reauthenticated" : "not reauthenticated"; + Log.LogDebug((int)EventIds.AuthenticatedInScope, $"Client {identity.Id} in device scope {authenticated} locally."); + } + + public static void InputCredentialsNotValid(IIdentity identity) + { + Log.LogInformation((int)EventIds.InputCredentialsNotValid, $"Credentials for client {identity.Id} are not valid."); + } + + public static void ResyncingServiceIdentity(IIdentity identity, string serviceIdentityId) + { + Log.LogInformation((int)EventIds.ResyncingServiceIdentity, $"Unable to authenticate client {identity.Id} with cached service identity {serviceIdentityId}. Resyncing service identity..."); + } + + public static void AuthenticatingWithDeviceIdentity(IModuleIdentity moduleIdentity) + { + Log.LogInformation((int)EventIds.AuthenticatingWithDeviceIdentity, $"Unable to authenticate client {moduleIdentity.Id} with module credentials. Attempting to authenticate using device {moduleIdentity.DeviceId} credentials."); + } + } + } +} diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/HubCoreEventIds.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/HubCoreEventIds.cs index 61c812a3b0a..dd85e09affc 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/HubCoreEventIds.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Core/HubCoreEventIds.cs @@ -18,5 +18,7 @@ public static class HubCoreEventIds public const int InvokeMethodHandler = EventIdStart + 1100; public const int DeviceScopeIdentitiesCache = EventIdStart + 1200; public const int PeriodicConnectionAuthenticator = EventIdStart + 1300; + public const int DeviceScopeAuthenticator = EventIdStart + 1400; + } } diff --git a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/CommonModule.cs b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/CommonModule.cs index 8b05202268b..fb5d6ea3ccd 100644 --- a/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/CommonModule.cs +++ b/edge-hub/src/Microsoft.Azure.Devices.Edge.Hub.Service/modules/CommonModule.cs @@ -236,14 +236,14 @@ protected override void Load(ContainerBuilder builder) case AuthenticationMode.Scope: deviceScopeIdentitiesCache = await c.Resolve>(); - tokenAuthenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, this.iothubHostName, this.edgeDeviceHostName, new NullAuthenticator()); + tokenAuthenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, this.iothubHostName, this.edgeDeviceHostName, new NullAuthenticator(), true, true); break; default: var deviceScopeIdentitiesCacheTask = c.Resolve>(); IAuthenticator cloudTokenAuthenticator = await this.GetCloudTokenAuthenticator(c); deviceScopeIdentitiesCache = await deviceScopeIdentitiesCacheTask; - tokenAuthenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, this.iothubHostName, this.edgeDeviceHostName, cloudTokenAuthenticator); + tokenAuthenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, this.iothubHostName, this.edgeDeviceHostName, cloudTokenAuthenticator, true, true); break; } diff --git a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Test/DeviceScopeTokenAuthenticatorTest.cs b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Test/DeviceScopeTokenAuthenticatorTest.cs index 46892c931fc..17cc5f7718d 100644 --- a/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Test/DeviceScopeTokenAuthenticatorTest.cs +++ b/edge-hub/test/Microsoft.Azure.Devices.Edge.Hub.CloudProxy.Test/DeviceScopeTokenAuthenticatorTest.cs @@ -33,10 +33,42 @@ public async Task AuthenticateTest_Device() var deviceScopeIdentitiesCache = new Mock(); string key = GetKey(); var serviceIdentity = new ServiceIdentity(deviceId, "1234", new string[0], new ServiceAuthentication(new SymmetricKeyAuthentication(key, GetKey())), ServiceIdentityStatus.Enabled); - deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), true)) + deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), false)) .ReturnsAsync(Option.Some(serviceIdentity)); - IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator); + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); + + var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); + string token = GetDeviceToken(iothubHostName, deviceId, key); + var tokenCredentials = Mock.Of(t => t.Identity == identity && t.Token == token); + + // Act + bool isAuthenticated = await authenticator.AuthenticateAsync(tokenCredentials); + + // Assert + Assert.True(isAuthenticated); + Mock.Get(underlyingAuthenticator).VerifyAll(); + } + + [Fact] + public async Task AuthenticateTest_DeviceUpdateServiceIdentity() + { + // Arrange + string iothubHostName = "testiothub.azure-devices.net"; + string edgehubHostName = "edgehub1"; + string deviceId = "d1"; + var underlyingAuthenticator = Mock.Of(); + var deviceScopeIdentitiesCache = new Mock(); + string key = GetKey(); + var serviceIdentity1 = new ServiceIdentity(deviceId, "1234", new string[0], new ServiceAuthentication(new SymmetricKeyAuthentication(GetKey(), GetKey())), ServiceIdentityStatus.Enabled); + var serviceIdentity2 = new ServiceIdentity(deviceId, "1234", new string[0], new ServiceAuthentication(new SymmetricKeyAuthentication(key, GetKey())), ServiceIdentityStatus.Enabled); + deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), false)) + .ReturnsAsync(Option.Some(serviceIdentity1)); + deviceScopeIdentitiesCache.Setup(d => d.RefreshServiceIdentity(deviceId)) + .Callback(id => deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(deviceId, false)).ReturnsAsync(Option.Some(serviceIdentity2))) + .Returns(Task.CompletedTask); + + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken(iothubHostName, deviceId, key); @@ -48,6 +80,42 @@ public async Task AuthenticateTest_Device() // Assert Assert.True(isAuthenticated); Mock.Get(underlyingAuthenticator).VerifyAll(); + deviceScopeIdentitiesCache.Verify(d => d.GetServiceIdentity(deviceId, false), Times.Exactly(2)); + deviceScopeIdentitiesCache.Verify(d => d.RefreshServiceIdentity(deviceId), Times.Once); + } + + [Fact] + public async Task ReauthenticateTest_DeviceUpdateServiceIdentity() + { + // Arrange + string iothubHostName = "testiothub.azure-devices.net"; + string edgehubHostName = "edgehub1"; + string deviceId = "d1"; + var underlyingAuthenticator = Mock.Of(); + var deviceScopeIdentitiesCache = new Mock(); + string key = GetKey(); + var serviceIdentity1 = new ServiceIdentity(deviceId, "1234", new string[0], new ServiceAuthentication(new SymmetricKeyAuthentication(GetKey(), GetKey())), ServiceIdentityStatus.Enabled); + var serviceIdentity2 = new ServiceIdentity(deviceId, "1234", new string[0], new ServiceAuthentication(new SymmetricKeyAuthentication(key, GetKey())), ServiceIdentityStatus.Enabled); + deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(deviceId, false)) + .ReturnsAsync(Option.Some(serviceIdentity1)); + deviceScopeIdentitiesCache.Setup(d => d.RefreshServiceIdentity(deviceId)) + .Callback(id => deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(deviceId, false)).ReturnsAsync(Option.Some(serviceIdentity2))) + .Returns(Task.CompletedTask); + + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); + + var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); + string token = GetDeviceToken(iothubHostName, deviceId, key); + var tokenCredentials = Mock.Of(t => t.Identity == identity && t.Token == token); + + // Act + bool isAuthenticated = await authenticator.ReauthenticateAsync(tokenCredentials); + + // Assert + Assert.False(isAuthenticated); + Mock.Get(underlyingAuthenticator).VerifyAll(); + deviceScopeIdentitiesCache.Verify(d => d.GetServiceIdentity(deviceId, false), Times.Once); + deviceScopeIdentitiesCache.Verify(d => d.RefreshServiceIdentity(deviceId), Times.Never); } [Fact] @@ -62,10 +130,10 @@ public async Task AuthenticateTest_Module() var deviceScopeIdentitiesCache = new Mock(); string key = GetKey(); var serviceIdentity = new ServiceIdentity(deviceId, moduleId, "1234", new string[0], new ServiceAuthentication(new SymmetricKeyAuthentication(key, GetKey())), ServiceIdentityStatus.Enabled); - deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == $"{deviceId}/{moduleId}"), true)) + deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == $"{deviceId}/{moduleId}"), false)) .ReturnsAsync(Option.Some(serviceIdentity)); - IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator); + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.ModuleId == moduleId && d.Id == $"{deviceId}/{moduleId}"); string token = GetDeviceToken(iothubHostName, deviceId, moduleId, key); @@ -97,7 +165,7 @@ public async Task AuthenticateTest_ModuleWithDeviceToken() deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), false)) .ReturnsAsync(Option.Some(deviceServiceIdentity)); - IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator); + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.ModuleId == moduleId && d.Id == $"{deviceId}/{moduleId}"); string token = GetDeviceToken(iothubHostName, deviceId, key); @@ -129,7 +197,7 @@ public async Task AuthenticateTest_ModuleWithDeviceKey() deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), false)) .ReturnsAsync(Option.Some(deviceServiceIdentity)); - IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator); + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.ModuleId == moduleId && d.Id == $"{deviceId}/{moduleId}"); string token = GetDeviceToken(iothubHostName, deviceId, moduleId, key); @@ -157,7 +225,7 @@ public async Task AuthenticateTest_Device_ServiceIdentityNotEnabled() deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), true)) .ReturnsAsync(Option.Some(serviceIdentity)); - IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator); + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken(iothubHostName, deviceId, key); @@ -185,7 +253,7 @@ public async Task AuthenticateTest_Device_WrongToken() deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), true)) .ReturnsAsync(Option.Some(serviceIdentity)); - IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator); + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken(iothubHostName, deviceId, key); @@ -213,7 +281,7 @@ public async Task AuthenticateTest_Device_TokenExpired() deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), true)) .ReturnsAsync(Option.Some(serviceIdentity)); - IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator); + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken(iothubHostName, deviceId, key, TimeSpan.FromHours(-1)); @@ -238,7 +306,7 @@ public void ValidateAudienceTest() var deviceScopeIdentitiesCache = Mock.Of(); string key = GetKey(); - var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator); + var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken(iothubHostName, deviceId, key); @@ -264,7 +332,7 @@ public void ValidateAudienceWithEdgeHubHostNameTest() var deviceScopeIdentitiesCache = Mock.Of(); string key = GetKey(); - var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator); + var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken(edgehubHostName, deviceId, key); @@ -291,7 +359,7 @@ public void InvalidAudienceTest_DeviceId() var deviceScopeIdentitiesCache = Mock.Of(); string key = GetKey(); - var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator); + var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken(edgehubHostName, "d2", key); @@ -318,7 +386,7 @@ public void InvalidAudienceTest_Hostname() var deviceScopeIdentitiesCache = Mock.Of(); string key = GetKey(); - var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator); + var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken("edgehub2", deviceId, key); @@ -344,7 +412,7 @@ public void InvalidAudienceTest_Device_Format() var underlyingAuthenticator = Mock.Of(); var deviceScopeIdentitiesCache = Mock.Of(); - var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator); + var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string audience = $"{iothubHostName}/devices/{deviceId}/foo"; @@ -369,7 +437,7 @@ public void InvalidAudienceTest_Module_Format() var underlyingAuthenticator = Mock.Of(); var deviceScopeIdentitiesCache = Mock.Of(); - var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator); + var authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.ModuleId == moduleId && d.Id == $"{deviceId}/{moduleId}"); string audience = $"{iothubHostName}/devices/{deviceId}/modules/{moduleId}/m1"; @@ -392,10 +460,10 @@ public async Task ValidateUnderlyingAuthenticatorErrorTest() var underlyingAuthenticator = Mock.Of(); Mock.Get(underlyingAuthenticator).Setup(u => u.AuthenticateAsync(It.IsAny())).ThrowsAsync(new TimeoutException()); var deviceScopeIdentitiesCache = new Mock(); - deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), true)) + deviceScopeIdentitiesCache.Setup(d => d.GetServiceIdentity(It.Is(i => i == deviceId), false)) .ReturnsAsync(Option.None()); - IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator); + IAuthenticator authenticator = new DeviceScopeTokenAuthenticator(deviceScopeIdentitiesCache.Object, iothubHostName, edgehubHostName, underlyingAuthenticator, true, true); var identity = Mock.Of(d => d.DeviceId == deviceId && d.Id == deviceId); string token = GetDeviceToken(iothubHostName, deviceId, GetKey());