diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AppConfig.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AppConfig.cs new file mode 100644 index 000000000..c2a924598 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AppConfig.cs @@ -0,0 +1,55 @@ +using Microsoft.Identity.Client; +using System.Security.Cryptography.X509Certificates; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication +{ + internal class AppConfig : IAppConfig + { +#pragma warning disable CS8618 // This class is for test purpose only + public AppConfig(string clientId, string tenantId) +#pragma warning restore CS8618 + { + ClientId = clientId; + TenantId = tenantId; + } + + public string ClientId { get; } + + public bool EnablePiiLogging { get; } + + public IMsalHttpClientFactory HttpClientFactory { get; } + + public LogLevel LogLevel { get; } + + public bool IsDefaultPlatformLoggingEnabled { get; } + + public string RedirectUri { get; } + + public string TenantId { get; } + + public LogCallback LoggingCallback { get; } + + public IDictionary ExtraQueryParameters { get; } + + public bool IsBrokerEnabled { get; } + + public string ClientName { get; } + + public string ClientVersion { get; } + + [Obsolete] + public ITelemetryConfig TelemetryConfig { get; } + + public bool ExperimentalFeaturesEnabled { get; } + + public IEnumerable ClientCapabilities { get; } + + public bool LegacyCacheCompatibilityEnabled { get; } + + public string ClientSecret { get; } + + public X509Certificate2 ClientCredentialCertificate { get; } + + public Func ParentActivityOrWindowFunc { get; } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AuthenticationManagerTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AuthenticationManagerTests.cs index f8ba1ab2e..e983b3084 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AuthenticationManagerTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AuthenticationManagerTests.cs @@ -110,7 +110,6 @@ public async void Test_SignOut_DefaultHandler() public async void Test_SignOut_SpecificHandler() { // arrange - var graphToken = "graph token"; var app = new TestApplication(new TestApplicationOptions()); var options = new AuthenticationOptions(); options._authenticationSettings = new Dictionary() diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoBotAuthenticationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoBotAuthenticationTests.cs new file mode 100644 index 000000000..abae878b2 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoBotAuthenticationTests.cs @@ -0,0 +1,204 @@ +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Dialogs; +using Microsoft.Bot.Schema; +using Microsoft.Identity.Client; +using Microsoft.Teams.AI.State; +using Microsoft.Teams.AI.Tests.TestUtils; +using Moq; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication.Bot +{ + public class TeamsSsoBotAuthenticationTests + { + internal class MockTeamsSsoBotAuthentication : TeamsSsoBotAuthentication + where TState : TurnState, new() + { + public MockTeamsSsoBotAuthentication(Application app, string name, TeamsSsoSettings settings, TeamsSsoPrompt? mockPrompt = null) : base(app, name, settings, null) + { + if (mockPrompt != null) + { + _prompt = mockPrompt; + } + } + + public async Task TokenExchangeRouteSelectorPublic(ITurnContext context, CancellationToken cancellationToken) + { + return await base.TokenExchangeRouteSelector(context, cancellationToken); + } + } + + + [Fact] + public async void Test_RunDialog_BeginNew() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var mockedPrompt = CreateTeamsSsoPromptMock(settings); + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "TokenName", settings, mockedPrompt.Object); + var messageContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(messageContext); + + // act + var result = await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); + + // assert + Assert.Equal(DialogTurnStatus.Waiting, result.Status); + } + + [Fact] + public async void Test_RunDialog_ContinueExisting() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var mockedPrompt = CreateTeamsSsoPromptMock(settings); + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "TokenName", settings, mockedPrompt.Object); + var messageContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(messageContext); + await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); // Begin new dialog first + + // act + var tokenExchangeContext = MockTokenExchangeContext(); + var result = await botAuthentication.RunDialog(tokenExchangeContext, turnState, "dialogStateProperty"); + + // assert + Assert.Equal(DialogTurnStatus.Complete, result.Status); + } + + + [Fact] + public async void Test_ContinueDialog() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var mockedPrompt = CreateTeamsSsoPromptMock(settings); + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "TokenName", settings, mockedPrompt.Object); + var messageContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(messageContext); + await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); // Begin new dialog first + + // act + var tokenExchangeContext = MockTokenExchangeContext(); + var result = await botAuthentication.ContinueDialog(tokenExchangeContext, turnState, "dialogStateProperty"); + + // assert + Assert.Equal(DialogTurnStatus.Complete, result.Status); + } + + [Fact] + public async void Test_TokenExchangeRouteSelector_NameMatched() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var turnContext = MockTokenExchangeContext("test"); + + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "test", settings); + + // act + var result = await botAuthentication.TokenExchangeRouteSelectorPublic(turnContext, CancellationToken.None); + + // assert + Assert.True(result); + } + + [Fact] + public async void Test_TokenExchangeRouteSelector_NameNotMatch() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var turnContext = MockTokenExchangeContext("AnotherTokenName"); + + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "test", settings); + + // act + var result = await botAuthentication.TokenExchangeRouteSelectorPublic(turnContext, CancellationToken.None); + + // assert + Assert.False(result); + } + + [Fact] + public async void Test_Dedupe() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var mockedPrompt = CreateTeamsSsoPromptMock(settings); + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "TokenName", settings, mockedPrompt.Object); + + // act + var messageContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(messageContext); + await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); + var tokenExchangeContext = MockTokenExchangeContext(); + var tokenExchangeResult = await botAuthentication.ContinueDialog(tokenExchangeContext, turnState, "dialogStateProperty"); + + // assert + Assert.NotNull(tokenExchangeResult.Result); + Assert.Equal("test token", ((TokenResponse)tokenExchangeResult.Result).Token); + + // act - simulate processing duplicate request + await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); + tokenExchangeResult = await botAuthentication.ContinueDialog(tokenExchangeContext, turnState, "dialogStateProperty"); + + // assert + Assert.Equal(DialogTurnStatus.Waiting, tokenExchangeResult.Status); + } + + private static Mock CreateTeamsSsoPromptMock(TeamsSsoSettings settings) + { + var mockedPrompt = new Mock("TeamsSsoPrompt", "TokenName", settings); + mockedPrompt + .Setup(mock => mock.BeginDialogAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(new DialogTurnResult(DialogTurnStatus.Waiting)); + mockedPrompt + .Setup(mock => mock.ContinueDialogAsync(It.IsAny(), It.IsAny())) + .Returns(async (DialogContext dc, CancellationToken cancellationToken) => + { + return await dc.EndDialogAsync(new TokenResponse(token: "test token")); + }); + return mockedPrompt; + } + + private static TurnContext MockTurnContext(string type = ActivityTypes.Message, string? name = null) + { + return new TurnContext(new SimpleAdapter(), new Activity() + { + Type = type, + Recipient = new() { Id = "recipientId" }, + Conversation = new() { Id = "conversationId" }, + From = new() { Id = "fromId" }, + ChannelId = "channelId", + Name = name + }); + } + + private static TurnContext MockTokenExchangeContext(string settingName = "test") + { + JObject activityValue = new(); + activityValue["id"] = $"{Guid.NewGuid()}-{settingName}"; + + return new TurnContext(new SimpleAdapter(), new Activity() + { + Type = ActivityTypes.Invoke, + Name = SignInConstants.TokenExchangeOperationName, + Recipient = new() { Id = "recipientId" }, + Conversation = new() { Id = "conversationId" }, + From = new() { Id = "fromId" }, + ChannelId = "channelId", + Value = activityValue + }); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoPromptTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoPromptTests.cs new file mode 100644 index 000000000..065194b22 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoPromptTests.cs @@ -0,0 +1,272 @@ +using Microsoft.Bot.Builder.Dialogs; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Schema; +using Microsoft.Identity.Client; +using Moq; +using Microsoft.Bot.Builder.Adapters; +using Microsoft.Bot.Connector; +using System.Text.Json; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication.Bot +{ + public class TeamsSsoPromptTests + { + private const string TokenExchangeSuccess = "TokenExchangeSuccess"; + private const string TokenExchangeFail = "TokenExchangeFail"; + private const string DialogId = "DialogId"; + private const string PromptName = "PromptName"; + private const string ClientId = "ClientId"; + private const string TenantId = "TenantId"; + private const string UserReadScope = "User.Read"; + private const string AuthStartPage = "https://localhost/auth-start.html"; + private const string AccessToken = "test token"; + + private class TeamsSsoPromptMock : TeamsSsoPrompt + { + public TeamsSsoPromptMock(string dialogId, string name, TeamsSsoSettings settings, IConfidentialClientApplicationAdapter msalAdapterMock) : base(dialogId, name, settings) + { + _msalAdapter = msalAdapterMock; + } + } + + [Fact] + public async Task BeginDialogAsync_SendOAuthCard() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var testFlow = InitTestFlow(msalAdapterMock.Object); + + // Act and Assert + await testFlow + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Text = "hello", + Conversation = new ConversationAccount() { Id = "testUserId" } + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .StartTestAsync(); + } + + [Fact] + public async Task ContinueDialogAsync_TokenExchangeSuccess() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var authenticationResult = MockAuthenticationResult(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).ReturnsAsync(authenticationResult); + + var testFlow = InitTestFlow(msalAdapterMock.Object); + + // Act and Assert + await testFlow + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Text = "hello", + Conversation = new ConversationAccount() { Id = "testUserId" } + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Type = ActivityTypes.Invoke, + Name = SignInConstants.TokenExchangeOperationName, + Value = JObject.FromObject(new TokenExchangeInvokeRequest() + { + Id = "fake_id", + Token = "fake_token" + }) + }) + .AssertReply(a => + { + Assert.Equal(ActivityTypesEx.InvokeResponse, a.Type); + var response = ((Activity)a).Value as InvokeResponse; + Assert.NotNull(response); + Assert.Equal(200, response!.Status); + }) + .AssertReply(TokenExchangeSuccess) + .AssertReply(activity => + { + var response = JsonSerializer.Deserialize(((Activity)activity).Text); + Assert.Equal(authenticationResult.AccessToken, response!.Token); + Assert.Equal(authenticationResult.ExpiresOn.ToString("O"), response!.Expiration); + }) + .StartTestAsync(); + } + + [Fact] + public async Task ContinueDialogAsync_TokenExchangeFail() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).Throws(new MsalUiRequiredException("error code", "error message")); + + var testFlow = InitTestFlow(msalAdapterMock.Object); + + // Act and Assert + await testFlow + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Text = "hello", + Conversation = new ConversationAccount() { Id = "testUserId" } + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Type = ActivityTypes.Invoke, + Name = SignInConstants.TokenExchangeOperationName, + Value = JObject.FromObject(new TokenExchangeInvokeRequest() + { + Id = "fake_id", + Token = "fake_token" + }) + }) + .AssertReply(a => + { + Assert.Equal(ActivityTypesEx.InvokeResponse, a.Type); + var response = ((Activity)a).Value as InvokeResponse; + Assert.NotNull(response); + Assert.Equal(412, response!.Status); + }) + .StartTestAsync(); + } + + [Fact] + public async Task ContinueDialogAsync_SignInVerify() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var testFlow = InitTestFlow(msalAdapterMock.Object); + + // Act and Assert + await testFlow + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Text = "hello", + Conversation = new ConversationAccount() { Id = "testUserId" } + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Type = ActivityTypes.Invoke, + Name = SignInConstants.VerifyStateOperationName + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .AssertReply(a => + { + Assert.Equal(ActivityTypesEx.InvokeResponse, a.Type); + var response = ((Activity)a).Value as InvokeResponse; + Assert.NotNull(response); + Assert.Equal(200, response!.Status); + }) + .StartTestAsync(); + } + + private static AuthenticationResult MockAuthenticationResult(string token = AccessToken, string scope = UserReadScope) + { + return new AuthenticationResult(token, false, "", DateTimeOffset.Now, DateTimeOffset.Now, "", null, "", new string[] { scope }, Guid.NewGuid()); + } + + private static Mock MockMsalAdapter() + { + var msalAdapterMock = new Mock(); + msalAdapterMock.Setup(m => m.AppConfig).Returns(new AppConfig(ClientId, TenantId)); + return msalAdapterMock; + } + + private static TeamsSsoPrompt CreateTeamsSsoPrompt(IConfidentialClientApplicationAdapter msalAdapterMock) + { + var settings = new TeamsSsoSettings(new string[] { UserReadScope }, AuthStartPage, It.IsAny()); + var teamsSsoPrompt = new TeamsSsoPromptMock(DialogId, PromptName, settings, msalAdapterMock); + return teamsSsoPrompt; + } + + private static TestFlow InitTestFlow(IConfidentialClientApplicationAdapter msalAdapterMock) + { + var teamsSsoPrompt = CreateTeamsSsoPrompt(msalAdapterMock); + var conversationState = new ConversationState(new MemoryStorage()); + var dialogState = conversationState.CreateProperty("dialogState"); + var dialogs = new DialogSet(dialogState); + dialogs.Add(teamsSsoPrompt); + + var adapter = new TestAdapter() + .Use(new AutoSaveStateMiddleware(conversationState)); + + BotCallbackHandler botCallbackHandler = async (turnContext, cancellationToken) => + { + var dc = await dialogs.CreateContextAsync(turnContext, cancellationToken); + + var results = await dc.ContinueDialogAsync(cancellationToken); + if (results.Status == DialogTurnStatus.Empty) + { + await dc.PromptAsync(DialogId, new PromptOptions(), cancellationToken); + } + else if (results.Status == DialogTurnStatus.Complete) + { + if (results.Result is TokenResponse) + { + await turnContext.SendActivityAsync(MessageFactory.Text(TokenExchangeSuccess), cancellationToken); + await turnContext.SendActivityAsync(MessageFactory.Text(JsonSerializer.Serialize(results.Result)), cancellationToken); + } + else + { + await turnContext.SendActivityAsync(MessageFactory.Text(TokenExchangeFail), cancellationToken); + } + } + }; + + return new TestFlow(adapter, botCallbackHandler); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthenticationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthenticationTests.cs new file mode 100644 index 000000000..f8022eaa7 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthenticationTests.cs @@ -0,0 +1,193 @@ +using Microsoft.Bot.Builder; +using Microsoft.Identity.Client; +using Moq; +using Microsoft.Bot.Schema; +using Microsoft.Teams.AI.Tests.TestUtils; +using Newtonsoft.Json.Linq; +using Microsoft.Teams.AI.Exceptions; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication.MessageExtensions +{ + public class TeamsSsoMessageExtensionsAuthenticationTests + { + private const string ClientId = "ClientId"; + private const string TenantId = "TenantId"; + private const string UserReadScope = "User.Read"; + private const string AuthStartPage = "https://localhost/auth-start.html"; + private const string AccessToken = "test token"; + + private class TeamsSsoMessageExtensionsAuthenticationMock : TeamsSsoMessageExtensionsAuthentication + { + public TeamsSsoMessageExtensionsAuthenticationMock(TeamsSsoSettings settings, IConfidentialClientApplicationAdapter msalAdapterMock) : base(settings) + { + _msalAdapter = msalAdapterMock; + } + } + + [Fact] + public async Task GetSignInLink() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + var turnContext = MockTurnContext(); + + // Act + var signInLink = await messageExtensionAuth.GetSignInLink(turnContext); + + // Assert + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", signInLink); + } + + [Fact] + public async Task HandleUserSignIn() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + var turnContext = MockTurnContext(); + + // Act + var tokenResponse = await messageExtensionAuth.HandleUserSignIn(turnContext, "123456"); + + // Assert + Assert.Null(tokenResponse.Token); + Assert.Null(tokenResponse.Expiration); + } + + [Fact] + public void IsValidActivity_Valid() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + var turnContext = MockTurnContext(); + + // Act + var isValidActivity = messageExtensionAuth.IsValidActivity(turnContext); + + // Assert + Assert.True(isValidActivity); + } + + [Fact] + public void IsValidActivity_InValid() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + + // Act and Assert + Assert.False(messageExtensionAuth.IsValidActivity(MockTurnContext(MessageExtensionsInvokeNames.QUERY_LINK_INVOKE_NAME))); + Assert.False(messageExtensionAuth.IsValidActivity(MockTurnContext(MessageExtensionsInvokeNames.FETCH_TASK_INVOKE_NAME))); + Assert.False(messageExtensionAuth.IsValidActivity(MockTurnContext(MessageExtensionsInvokeNames.ANONYMOUS_QUERY_LINK_INVOKE_NAME))); + } + + [Fact] + public async Task HandleSsoTokenExchange_NoTokenInRequest() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var authenticationResult = MockAuthenticationResult(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).ReturnsAsync(authenticationResult); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + var turnContext = MockTurnContext(); + + // Act + var result = await messageExtensionAuth.HandleSsoTokenExchange(turnContext); + + // Assert + Assert.Null(result.Token); + Assert.Null(result.Expiration); + } + + [Fact] + public async Task HandleSsoTokenExchange_TokenExchangeSuccess() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var authenticationResult = MockAuthenticationResult(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).ReturnsAsync(authenticationResult); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + JObject activityValue = new(); + activityValue["authentication"] = new JObject(); + activityValue["authentication"]!["token"] = "sso token"; + var turnContext = MockTurnContext(activityValue: activityValue); + + // Act + var result = await messageExtensionAuth.HandleSsoTokenExchange(turnContext); + + // Assert + Assert.Equal(authenticationResult.AccessToken, result.Token); + Assert.Equal(authenticationResult.ExpiresOn.ToString("O"), result.Expiration); + } + + [Fact] + public async Task HandleSsoTokenExchange_TokenExchangeFail() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).Throws(new MsalUiRequiredException("error code", "error message")); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + JObject activityValue = new(); + activityValue["authentication"] = new JObject(); + activityValue["authentication"]!["token"] = "sso token"; + var turnContext = MockTurnContext(activityValue: activityValue); + + // Act + var result = await messageExtensionAuth.HandleSsoTokenExchange(turnContext); + + // Assert + Assert.Null(result.Token); + Assert.Null(result.Expiration); + } + + [Fact] + public async Task HandleSsoTokenExchange_UnexpectedException() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).Throws(new MsalServiceException("error code", "error message")); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + JObject activityValue = new(); + activityValue["authentication"] = new JObject(); + activityValue["authentication"]!["token"] = "sso token"; + var turnContext = MockTurnContext(activityValue: activityValue); + + // Act and Assert + await Assert.ThrowsAsync(async () => { await messageExtensionAuth.HandleSsoTokenExchange(turnContext); }); + } + + private static Mock MockMsalAdapter() + { + var msalAdapterMock = new Mock(); + msalAdapterMock.Setup(m => m.AppConfig).Returns(new AppConfig(ClientId, TenantId)); + return msalAdapterMock; + } + + private static TeamsSsoMessageExtensionsAuthentication CreateTestClass(IConfidentialClientApplicationAdapter msalAdapterMock) + { + var settings = new TeamsSsoSettings(new string[] { UserReadScope }, AuthStartPage, It.IsAny()); + return new TeamsSsoMessageExtensionsAuthenticationMock(settings, msalAdapterMock); + } + + private static AuthenticationResult MockAuthenticationResult(string token = AccessToken, string scope = UserReadScope) + { + return new AuthenticationResult(token, false, "", DateTimeOffset.Now, DateTimeOffset.Now, "", null, "", new string[] { scope }, Guid.NewGuid()); + } + + private static TurnContext MockTurnContext(string? name = null, JObject? activityValue = null) + { + return new TurnContext(new SimpleAdapter(), new Activity() + { + Type = ActivityTypes.Invoke, + Recipient = new() { Id = "recipientId" }, + Conversation = new() { Id = "conversationId", TenantId = "tenantId" }, + From = new() { Id = "fromId", AadObjectId = "aadObjectId" }, + ChannelId = "channelId", + Name = name ?? MessageExtensionsInvokeNames.QUERY_INVOKE_NAME, + Value = activityValue ?? new JObject() + }); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/TeamsSsoAuthenticationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/TeamsSsoAuthenticationTests.cs new file mode 100644 index 000000000..91054fb07 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/TeamsSsoAuthenticationTests.cs @@ -0,0 +1,172 @@ +using Microsoft.Bot.Builder; +using Microsoft.Bot.Schema; +using Microsoft.Identity.Client; +using Microsoft.Teams.AI.State; +using Microsoft.Teams.AI.Exceptions; +using Microsoft.Teams.AI.Tests.TestUtils; +using Moq; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication +{ + public class TeamsSsoAuthenticationTests + { + private const string ClientId = "ClientId"; + private const string TenantId = "TenantId"; + private const string UserReadScope = "User.Read"; + private const string AuthStartPage = "https://localhost/auth-start.html"; + private const string AccessToken = "test token"; + + private class TeamsSsoAuthenticationMock : TeamsSsoAuthentication + where TState : TurnState, new() + { + public TeamsSsoAuthenticationMock(Application app, string name, TeamsSsoSettings settings, IConfidentialClientApplicationAdapter msalAdapterMock) : base(app, name, settings, null) + { + _msalAdapter = msalAdapterMock; + } + + public Func? GetSignInSuccessHandler() + { + return _botAuth?._userSignInSuccessHandler; + } + + public Func? GetSignInFailureHandler() + { + return _botAuth?._userSignInFailureHandler; + } + } + + [Fact] + public async Task SignInUserAsync_GetTokenFromCache() + { + // Arrange + var authenticationResult = MockAuthenticationResult(); + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.AcquireTokenInLongRunningProcess(It.IsAny>(), It.IsAny())).ReturnsAsync(authenticationResult); + var turnContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + var result = await teamsSsoAuthentication.SignInUserAsync(turnContext, turnState); + + // Assert + Assert.Equal(authenticationResult.AccessToken, result); + } + + [Fact] + public async Task SignOutUserAsync() + { + // Arrange + var authenticationResult = MockAuthenticationResult(); + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.StopLongRunningProcessInWebApiAsync(It.IsAny(), It.IsAny())).ReturnsAsync(true); + var turnContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + await teamsSsoAuthentication.SignOutUserAsync(turnContext, turnState); + + // Assert + msalAdapterMock.Verify(m => m.StopLongRunningProcessInWebApiAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public void OnUserSignInSuccess() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + teamsSsoAuthentication.OnUserSignInSuccess((turnContext, turnState) => { return Task.CompletedTask; }); + + // Assert + Assert.NotNull(teamsSsoAuthentication.GetSignInSuccessHandler()); + } + + [Fact] + public void OnUserSignInFailure() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + teamsSsoAuthentication.OnUserSignInFailure((turnContext, turnState, exception) => { return Task.CompletedTask; }); + + // Assert + Assert.NotNull(teamsSsoAuthentication.GetSignInFailureHandler()); + } + + [Fact] + public async Task IsUserSignedInAsync_UserSignedIn() + { + // Arrange + var authenticationResult = MockAuthenticationResult(); + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.AcquireTokenInLongRunningProcess(It.IsAny>(), It.IsAny())).ReturnsAsync(authenticationResult); + var turnContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + var result = await teamsSsoAuthentication.IsUserSignedInAsync(turnContext); + + // Assert + Assert.Equal(authenticationResult.AccessToken, result); + } + + [Fact] + public async Task IsUserSignedInAsync_UserNotSignedIn() + { + // Arrange + var authenticationResult = MockAuthenticationResult(); + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.AcquireTokenInLongRunningProcess(It.IsAny>(), It.IsAny())).Throws(new MsalClientException("error code", "error message")); + var turnContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + var result = await teamsSsoAuthentication.IsUserSignedInAsync(turnContext); + + // Assert + Assert.Null(result); + } + + private static AuthenticationResult MockAuthenticationResult(string token = AccessToken, string scope = UserReadScope) + { + return new AuthenticationResult(token, false, "", DateTimeOffset.Now, DateTimeOffset.Now, "", null, "", new string[] { scope }, Guid.NewGuid()); + } + + private static Mock MockMsalAdapter() + { + var msalAdapterMock = new Mock(); + msalAdapterMock.Setup(m => m.AppConfig).Returns(new AppConfig(ClientId, TenantId)); + return msalAdapterMock; + } + + private static TeamsSsoAuthenticationMock CreateTestClass(IConfidentialClientApplicationAdapter msalAdapterMock) + { + var app = new Application(new ApplicationOptions()); + var settings = new TeamsSsoSettings(new string[] { UserReadScope }, AuthStartPage, It.IsAny()); + return new TeamsSsoAuthenticationMock(app, "test", settings, msalAdapterMock); + } + + private static TurnContext MockTurnContext(string? name = null, JObject? activityValue = null) + { + return new TurnContext(new SimpleAdapter(), new Activity() + { + Type = ActivityTypes.Invoke, + Recipient = new() { Id = "recipientId" }, + Conversation = new() { Id = "conversationId", TenantId = "tenantId" }, + From = new() { Id = "fromId", AadObjectId = "aadObjectId" }, + ChannelId = "channelId", + Name = name ?? MessageExtensionsInvokeNames.QUERY_INVOKE_NAME, + Value = activityValue ?? new JObject() + }); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/BotAuthenticationBase.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/BotAuthenticationBase.cs index 719f37d29..17e752ff3 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/BotAuthenticationBase.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/BotAuthenticationBase.cs @@ -25,12 +25,12 @@ internal abstract class BotAuthenticationBase /// /// Callback when user sign in success /// - protected Func? _userSignInSuccessHandler; + internal Func? _userSignInSuccessHandler; /// /// Callback when user sign in fail /// - protected Func? _userSignInFailureHandler; + internal Func? _userSignInFailureHandler; /// /// Initializes the class diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoBotAuthentication.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoBotAuthentication.cs index cb61a95e1..7a1c27645 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoBotAuthentication.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoBotAuthentication.cs @@ -16,7 +16,7 @@ internal class TeamsSsoBotAuthentication : BotAuthenticationBase { private const string SSO_DIALOG_ID = "_TeamsSsoDialog"; private Regex _tokenExchangeIdRegex; - private TeamsSsoPrompt _prompt; + protected TeamsSsoPrompt _prompt; /// /// Initializes the class @@ -90,6 +90,7 @@ private async Task CreateSsoDialogContext(ITurnContext context, T TurnStateProperty accessor = new(state, "conversation", dialogStateProperty); DialogSet dialogSet = new(accessor); WaterfallDialog ssoDialog = new(SSO_DIALOG_ID); + dialogSet.Add(this._prompt); dialogSet.Add(new WaterfallDialog(SSO_DIALOG_ID, new WaterfallStep[] { @@ -99,7 +100,7 @@ private async Task CreateSsoDialogContext(ITurnContext context, T }, async (step, cancellationToken) => { - TokenResponse? tokenResponse = step.Result as TokenResponse; + TokenResponse? tokenResponse = step.Result as TokenResponse; if (tokenResponse != null && await ShouldDedup(context)) { state.Temp.DuplicateTokenExchange = true; diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoPrompt.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoPrompt.cs index a30dc1c11..3b10f3f8c 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoPrompt.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoPrompt.cs @@ -10,6 +10,8 @@ namespace Microsoft.Teams.AI { internal class TeamsSsoPrompt : Dialog { + protected IConfidentialClientApplicationAdapter _msalAdapter; + private const string _expiresKey = "expires"; private string _name; private TeamsSsoSettings _settings; @@ -17,8 +19,9 @@ internal class TeamsSsoPrompt : Dialog public TeamsSsoPrompt(string dialogId, string name, TeamsSsoSettings settings) : base(dialogId) { - this._name = name; - this._settings = settings; + _name = name; + _settings = settings; + _msalAdapter = new ConfidentialClientApplicationAdapter(settings.MSAL); } public override async Task BeginDialogAsync(DialogContext dc, object options, CancellationToken cancellationToken) @@ -28,19 +31,7 @@ public override async Task BeginDialogAsync(DialogContext dc, IDictionary state = dc.ActiveDialog.State; state[_expiresKey] = DateTime.Now.AddMilliseconds(timeout); - AuthenticationResult? token = await this.TryGetUserToken(dc.Context); - if (token != null) - { - TokenResponse tokenResponse = new() - { - ConnectionName = "", // No connection name is available in this implementation - Token = token.AccessToken, - Expiration = token.ExpiresOn.ToString("o") - }; - return await dc.EndDialogAsync(tokenResponse); - } - - // Cannot get token from cache, send OAuth card to get SSO token + // Send OAuth card to get SSO token await this.SendOAuthCardToObtainTokenAsync(dc.Context, cancellationToken); return EndOfTurn; } @@ -106,11 +97,7 @@ private async Task> RecognizeTokenAsync(Di try { string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; - AuthenticationResult exchangedToken = await ((ILongRunningWebApi)_settings.MSAL).InitiateLongRunningProcessInWebApi( - _settings.Scopes, - ssoToken, - ref homeAccountId - ).ExecuteAsync(); + AuthenticationResult exchangedToken = await _msalAdapter.InitiateLongRunningProcessInWebApi(_settings.Scopes, ssoToken!, ref homeAccountId); tokenResponse = new TokenResponse { @@ -185,7 +172,7 @@ private async Task SendOAuthCardToObtainTokenAsync(ITurnContext context, Cancell private SignInResource GetSignInResource() { - string signInLink = $"{_settings.SignInLink}?scope={Uri.EscapeDataString(string.Join(" ", _settings.Scopes))}&clientId={_settings.MSAL.AppConfig.ClientId}&tenantId={_settings.MSAL.AppConfig.TenantId}"; + string signInLink = $"{_settings.SignInLink}?scope={Uri.EscapeDataString(string.Join(" ", _settings.Scopes))}&clientId={_msalAdapter.AppConfig.ClientId}&tenantId={_msalAdapter.AppConfig.TenantId}"; SignInResource signInResource = new() { @@ -199,18 +186,6 @@ private SignInResource GetSignInResource() return signInResource; } - private async Task TryGetUserToken(ITurnContext context) - { - string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; - IAccount account = await this._settings.MSAL.GetAccountAsync(homeAccountId); - if (account != null) - { - AuthenticationResult result = await this._settings.MSAL.AcquireTokenSilent(this._settings.Scopes, account).ExecuteAsync(); - return result; - } - return null; // Return empty indication no token found in cache - } - private bool IsTeamsVerificationInvoke(ITurnContext context) { return (context.Activity.Type == ActivityTypes.Invoke) && (context.Activity.Name == SignInConstants.VerifyStateOperationName); diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/ConfidentialClientApplicationAdapter.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/ConfidentialClientApplicationAdapter.cs new file mode 100644 index 000000000..58b1796cd --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/ConfidentialClientApplicationAdapter.cs @@ -0,0 +1,50 @@ +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Extensibility; + +namespace Microsoft.Teams.AI +{ + internal class ConfidentialClientApplicationAdapter : IConfidentialClientApplicationAdapter + { + private readonly IConfidentialClientApplication _msal; + + public ConfidentialClientApplicationAdapter(IConfidentialClientApplication msal) + { + _msal = msal; + } + + public IAppConfig AppConfig + { + get + { + return _msal.AppConfig; + } + } + + public Task InitiateLongRunningProcessInWebApi(IEnumerable scopes, string userToken, ref string longRunningProcessSessionKey) + { + return ((ILongRunningWebApi)_msal).InitiateLongRunningProcessInWebApi( + scopes, + userToken, + ref longRunningProcessSessionKey + ).ExecuteAsync(); + } + + public async Task StopLongRunningProcessInWebApiAsync(string longRunningProcessSessionKey, CancellationToken cancellationToken = default) + { + ILongRunningWebApi? oboCca = _msal as ILongRunningWebApi; + if (oboCca != null) + { + return await oboCca.StopLongRunningProcessInWebApiAsync(longRunningProcessSessionKey, cancellationToken); + } + return false; + } + + public async Task AcquireTokenInLongRunningProcess(IEnumerable scopes, string longRunningProcessSessionKey) + { + return await ((ILongRunningWebApi)_msal).AcquireTokenInLongRunningProcess( + scopes, + longRunningProcessSessionKey + ).ExecuteAsync(); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/IConfidentialClientApplicationAdapter.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/IConfidentialClientApplicationAdapter.cs new file mode 100644 index 000000000..01bd59166 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/IConfidentialClientApplicationAdapter.cs @@ -0,0 +1,15 @@ +using Microsoft.Identity.Client; + +namespace Microsoft.Teams.AI +{ + internal interface IConfidentialClientApplicationAdapter + { + IAppConfig AppConfig { get; } + + Task InitiateLongRunningProcessInWebApi(IEnumerable scopes, string userToken, ref string longRunningProcessSessionKey); + + Task StopLongRunningProcessInWebApiAsync(string longRunningProcessSessionKey, CancellationToken cancellationToken = default); + + Task AcquireTokenInLongRunningProcess(IEnumerable scopes, string longRunningProcessSessionKey); + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthentication.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthentication.cs index 7bbed9a5d..154914353 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthentication.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthentication.cs @@ -11,11 +11,14 @@ namespace Microsoft.Teams.AI /// internal class TeamsSsoMessageExtensionsAuthentication : MessageExtensionsAuthenticationBase { + protected IConfidentialClientApplicationAdapter _msalAdapter; + private TeamsSsoSettings _settings; public TeamsSsoMessageExtensionsAuthentication(TeamsSsoSettings settings) { _settings = settings; + _msalAdapter = new ConfidentialClientApplicationAdapter(settings.MSAL); } @@ -26,7 +29,7 @@ public TeamsSsoMessageExtensionsAuthentication(TeamsSsoSettings settings) /// The sign in link public override Task GetSignInLink(ITurnContext context) { - string signInLink = $"{_settings.SignInLink}?scope={Uri.EscapeDataString(string.Join(" ", _settings.Scopes))}&clientId={_settings.MSAL.AppConfig.ClientId}&tenantId={_settings.MSAL.AppConfig.TenantId}"; + string signInLink = $"{_settings.SignInLink}?scope={Uri.EscapeDataString(string.Join(" ", _settings.Scopes))}&clientId={_msalAdapter.AppConfig.ClientId}&tenantId={_msalAdapter.AppConfig.TenantId}"; return Task.FromResult(signInLink); } @@ -58,11 +61,7 @@ public override async Task HandleSsoTokenExchange(ITurnContext co try { string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; - AuthenticationResult result = await ((ILongRunningWebApi)_settings.MSAL).InitiateLongRunningProcessInWebApi( - _settings.Scopes, - token.ToString(), - ref homeAccountId - ).ExecuteAsync(); + AuthenticationResult result = await _msalAdapter.InitiateLongRunningProcessInWebApi(_settings.Scopes, token.ToString(), ref homeAccountId); return new TokenResponse() { diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/TeamsSsoAuthentication.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/TeamsSsoAuthentication.cs index 8ee11bbe8..8a4edfce0 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/TeamsSsoAuthentication.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/TeamsSsoAuthentication.cs @@ -1,6 +1,5 @@ using Microsoft.Bot.Builder; using Microsoft.Identity.Client; -using Microsoft.Identity.Client.Extensibility; using Microsoft.Teams.AI.Exceptions; using Microsoft.Teams.AI.State; @@ -12,7 +11,9 @@ namespace Microsoft.Teams.AI public class TeamsSsoAuthentication : IAuthentication where TState : TurnState, new() { - private TeamsSsoBotAuthentication? _botAuth; + internal IConfidentialClientApplicationAdapter _msalAdapter; + + internal TeamsSsoBotAuthentication? _botAuth; private TeamsSsoMessageExtensionsAuthentication? _messageExtensionsAuth; private TeamsSsoSettings _settings; @@ -28,6 +29,7 @@ public TeamsSsoAuthentication(Application app, string name, TeamsSsoSett _settings = settings; _botAuth = new TeamsSsoBotAuthentication(app, name, _settings, storage); _messageExtensionsAuth = new TeamsSsoMessageExtensionsAuthentication(_settings); + _msalAdapter = new ConfidentialClientApplicationAdapter(settings.MSAL); } /// @@ -68,11 +70,7 @@ public async Task SignOutUserAsync(ITurnContext context, TState state, Cancellat { string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; - ILongRunningWebApi? oboCca = _settings.MSAL as ILongRunningWebApi; - if (oboCca != null) - { - await oboCca.StopLongRunningProcessInWebApiAsync(homeAccountId, cancellationToken); - } + await _msalAdapter.StopLongRunningProcessInWebApiAsync(homeAccountId, cancellationToken); } /// @@ -120,10 +118,7 @@ private async Task _TryGetUserToken(ITurnContext context) string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; try { - AuthenticationResult result = await ((ILongRunningWebApi)_settings.MSAL).AcquireTokenInLongRunningProcess( - _settings.Scopes, - homeAccountId - ).ExecuteAsync(); + AuthenticationResult result = await _msalAdapter.AcquireTokenInLongRunningProcess(_settings.Scopes, homeAccountId); return result.AccessToken; } catch (MsalClientException)