diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/ApplicationBuilderTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/ApplicationBuilderTests.cs index 5b559c451..85d4fc2fb 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/ApplicationBuilderTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/ApplicationBuilderTests.cs @@ -55,7 +55,7 @@ public void Test_ApplicationBuilder_CustomSetup() Moderator = new TestModerator() }; AuthenticationOptions authOptions = new(); - authOptions.AddAuthentication("graph", new OAuthSettings()); + authOptions.AddAuthentication("graph", new OAuthSettings() { ConnectionName = "graph-connection" }); // Act var app = new ApplicationBuilder() diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/ApplicationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/ApplicationTests.cs index 46e43dc15..d317f54b5 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/ApplicationTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/ApplicationTests.cs @@ -65,7 +65,7 @@ public void Test_Application_CustomSetup() { AutoSignIn = (context, cancellationToken) => Task.FromResult(false) }; - authenticationOptions.AddAuthentication("graph", new OAuthSettings()); + authenticationOptions.AddAuthentication("graph", new OAuthSettings() { ConnectionName = "graph-connection" }); ApplicationOptions applicationOptions = new() { RemoveRecipientMention = removeRecipientMention, diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/OAuthBotAuthenticationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/OAuthBotAuthenticationTests.cs index 51ab1db18..65bbce9ac 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/OAuthBotAuthenticationTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/OAuthBotAuthenticationTests.cs @@ -122,7 +122,7 @@ public async void Test_VerifyStateRouteSelector_ReturnsTrue() ((JObject)turnContext.Activity.Value).Add("settingName", SETTING_NAME); var app = new TestApplication(new() { Adapter = testAdapter }); - var botAuth = new TestOAuthBotAuthentication(app, new(), SETTING_NAME); + var botAuth = new TestOAuthBotAuthentication(app, new() { ConnectionName = "connectionName" }, SETTING_NAME); // Act var result = await botAuth.VerifyStateRouteSelector(turnContext, default); @@ -143,7 +143,7 @@ public async void Test_VerifyStateRouteSelector_IncorrectActivity_ReturnsFalse() ((JObject)turnContext.Activity.Value).Add("settingName", SETTING_NAME); var app = new TestApplication(new() { Adapter = testAdapter }); - var botAuth = new TestOAuthBotAuthentication(app, new(), SETTING_NAME); + var botAuth = new TestOAuthBotAuthentication(app, new() { ConnectionName = "connectionName" }, SETTING_NAME); // Act var result = await botAuth.VerifyStateRouteSelector(turnContext, default); @@ -164,7 +164,7 @@ public async void Test_VerifyStateRouteSelector_IncorrectInvokeName_ReturnsFalse ((JObject)turnContext.Activity.Value).Add("settingName", SETTING_NAME); var app = new TestApplication(new() { Adapter = testAdapter }); - var botAuth = new TestOAuthBotAuthentication(app, new(), SETTING_NAME); + var botAuth = new TestOAuthBotAuthentication(app, new() { ConnectionName = "connectionName" }, SETTING_NAME); // Act var result = await botAuth.VerifyStateRouteSelector(turnContext, default); @@ -185,7 +185,7 @@ public async void Test_VerifyStateRouteSelector_IncorrectSettingName_ReturnsFals ((JObject)turnContext.Activity.Value).Add("settingName", "NOT SETTING_NAME"); var app = new TestApplication(new() { Adapter = testAdapter }); - var botAuth = new TestOAuthBotAuthentication(app, new(), SETTING_NAME); + var botAuth = new TestOAuthBotAuthentication(app, new() { ConnectionName = "connectionName" }, SETTING_NAME); // Act var result = await botAuth.VerifyStateRouteSelector(turnContext, default); @@ -206,7 +206,7 @@ public async void Test_TokenExchangeRouteSelector_ReturnsTrue() ((JObject)turnContext.Activity.Value).Add("settingName", SETTING_NAME); var app = new TestApplication(new() { Adapter = testAdapter }); - var botAuth = new TestOAuthBotAuthentication(app, new(), SETTING_NAME); + var botAuth = new TestOAuthBotAuthentication(app, new() { ConnectionName = "connectionName" }, SETTING_NAME); // Act var result = await botAuth.TokenExchangeRouteSelector(turnContext, default); @@ -227,7 +227,7 @@ public async void Test_TokenExchangeRouteSelector_IncorrectActivity_ReturnsFalse ((JObject)turnContext.Activity.Value).Add("settingName", SETTING_NAME); var app = new TestApplication(new() { Adapter = testAdapter }); - var botAuth = new TestOAuthBotAuthentication(app, new(), SETTING_NAME); + var botAuth = new TestOAuthBotAuthentication(app, new() { ConnectionName = "connectionName" }, SETTING_NAME); // Act var result = await botAuth.TokenExchangeRouteSelector(turnContext, default); @@ -248,7 +248,7 @@ public async void Test_TokenExchangeRouteSelector_IncorrectInvokeName_ReturnsFal ((JObject)turnContext.Activity.Value).Add("settingName", SETTING_NAME); var app = new TestApplication(new() { Adapter = testAdapter }); - var botAuth = new TestOAuthBotAuthentication(app, new(), SETTING_NAME); + var botAuth = new TestOAuthBotAuthentication(app, new() { ConnectionName = "connectionName" }, SETTING_NAME); // Act var result = await botAuth.TokenExchangeRouteSelector(turnContext, default); @@ -269,7 +269,7 @@ public async void Test_TokenExchangeRouteSelector_IncorrectSettingName_ReturnsFa ((JObject)turnContext.Activity.Value).Add("settingName", "NOT SETTING_NAME"); var app = new TestApplication(new() { Adapter = testAdapter }); - var botAuth = new TestOAuthBotAuthentication(app, new(), SETTING_NAME); + var botAuth = new TestOAuthBotAuthentication(app, new() { ConnectionName = "connectionName" }, SETTING_NAME); // Act var result = await botAuth.TokenExchangeRouteSelector(turnContext, default); diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/FilteredTeamsSSOTokenExchangeMiddleware.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/FilteredTeamsSSOTokenExchangeMiddleware.cs index f966154be..c4b94120d 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/FilteredTeamsSSOTokenExchangeMiddleware.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/FilteredTeamsSSOTokenExchangeMiddleware.cs @@ -1,37 +1,49 @@ using Microsoft.Bot.Builder; using Microsoft.Bot.Builder.Teams; +using Microsoft.Bot.Connector; +using Microsoft.Bot.Schema; using Microsoft.Teams.AI.Exceptions; using Newtonsoft.Json.Linq; namespace Microsoft.Teams.AI.Application.Authentication.Bot { - internal class FilteredTeamsSSOTokenExchangeMiddleware : TeamsSSOTokenExchangeMiddleware + internal class FilteredTeamsSSOTokenExchangeMiddleware : IMiddleware { private string _oauthConnectionName; + private TeamsSSOTokenExchangeMiddleware tokenExchangeMiddleware; - public FilteredTeamsSSOTokenExchangeMiddleware(IStorage storage, string oauthConnectionName) : base(storage, oauthConnectionName) + public FilteredTeamsSSOTokenExchangeMiddleware(IStorage storage, string oauthConnectionName) { + this.tokenExchangeMiddleware = new TeamsSSOTokenExchangeMiddleware(storage, oauthConnectionName); this._oauthConnectionName = oauthConnectionName; } - public new async Task OnTurnAsync(ITurnContext turnContext, NextDelegate next, CancellationToken cancellationToken = default) + public async Task OnTurnAsync(ITurnContext turnContext, NextDelegate next, CancellationToken cancellationToken = default) + { + if (string.Equals(Channels.Msteams, turnContext.Activity.ChannelId, StringComparison.OrdinalIgnoreCase) + && string.Equals(SignInConstants.TokenExchangeOperationName, turnContext.Activity.Name, StringComparison.OrdinalIgnoreCase)) + { + string? connectionName = _GetConnectionName(turnContext); + + // If connection name matches then continue to the Teams SSO Token Exchange Middleware. + if (connectionName == this._oauthConnectionName) + { + await tokenExchangeMiddleware.OnTurnAsync(turnContext, next, cancellationToken).ConfigureAwait(false); + return; + } + } + + await next(cancellationToken).ConfigureAwait(false); + } + + private string? _GetConnectionName(ITurnContext turnContext) { JObject? obj = turnContext.Activity.Value as JObject; if (obj == null) { throw new TeamsAIException("Excepted `turnContext.Activity.Value` to have `connectionName` property"); }; - string? connectionName = obj.Value("connectionName"); - - // If connection name matches then continue to the Teams SSO Token Exchange Middleware. - if (connectionName == this._oauthConnectionName) - { - await base.OnTurnAsync(turnContext, next, cancellationToken); - } - else - { - await next(cancellationToken); - } + return obj.Value("connectionName"); } } } diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/OAuthBotAuthentication.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/OAuthBotAuthentication.cs index 92727d825..bf717160d 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/OAuthBotAuthentication.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/OAuthBotAuthentication.cs @@ -30,7 +30,10 @@ public OAuthBotAuthentication(Application app, OAuthSettings oauthSettin this._oauthPrompt = new OAuthPrompt("OAuthPrompt", this._oauthSettings); // Handles deduplication of token exchange event when using SSO with Bot Authentication - app.Adapter.Use(new FilteredTeamsSSOTokenExchangeMiddleware(storage ?? new MemoryStorage(), settingName)); + if (!IsTokenExchangeMiddlewareRegistered(app)) + { + app.Adapter.Use(new FilteredTeamsSSOTokenExchangeMiddleware(storage ?? new MemoryStorage(), oauthSettings.ConnectionName)); + } } /// @@ -115,9 +118,14 @@ public async Task CreateOAuthCard(ITurnContext context, Cancellation }; } - protected async virtual Task GetSignInResourceAsync(ITurnContext context, string connectionName, CancellationToken cancellationToken = default) + protected virtual async Task GetSignInResourceAsync(ITurnContext context, string connectionName, CancellationToken cancellationToken = default) { return await UserTokenClientWrapper.GetSignInResourceAsync(context, this._oauthSettings.ConnectionName, cancellationToken); } + + private bool IsTokenExchangeMiddlewareRegistered(Application app) + { + return app.Adapter.MiddlewareSet.Where(middleWare => middleWare as FilteredTeamsSSOTokenExchangeMiddleware is not null).Count() > 0; + } } }