From 718ff219ed7217825afbd6c86eb8593790b60d7b Mon Sep 17 00:00:00 2001 From: Thomas Avery <43214426+Thomas-Avery@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:09:41 -0600 Subject: [PATCH 01/28] [PM-13706] Add repository + stored procedures for private key regeneration (#4898) * Add stored procedure * Add repository --- .../Models/Data/UserAsymmetricKeys.cs | 9 +++++ .../IUserAsymmetricKeysRepository.cs | 9 +++++ .../DapperServiceCollectionExtensions.cs | 3 ++ .../UserAsymmetricKeysRepository.cs | 36 +++++++++++++++++++ ...ityFrameworkServiceCollectionExtensions.cs | 3 ++ .../UserAsymmetricKeysRepository.cs | 34 ++++++++++++++++++ .../UserAsymmetricKeys_Regenerate.sql | 16 +++++++++ ...-21_00_AddUserAsymmetricKeysRegenerate.sql | 16 +++++++++ 8 files changed, 126 insertions(+) create mode 100644 src/Core/KeyManagement/Models/Data/UserAsymmetricKeys.cs create mode 100644 src/Core/KeyManagement/Repositories/IUserAsymmetricKeysRepository.cs create mode 100644 src/Infrastructure.Dapper/KeyManagement/Repositories/UserAsymmetricKeysRepository.cs create mode 100644 src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserAsymmetricKeysRepository.cs create mode 100644 src/Sql/KeyManagement/dbo/Stored Procedures/UserAsymmetricKeys_Regenerate.sql create mode 100644 util/Migrator/DbScripts/2024-11-21_00_AddUserAsymmetricKeysRegenerate.sql diff --git a/src/Core/KeyManagement/Models/Data/UserAsymmetricKeys.cs b/src/Core/KeyManagement/Models/Data/UserAsymmetricKeys.cs new file mode 100644 index 000000000000..3c1362909260 --- /dev/null +++ b/src/Core/KeyManagement/Models/Data/UserAsymmetricKeys.cs @@ -0,0 +1,9 @@ +#nullable enable +namespace Bit.Core.KeyManagement.Models.Data; + +public class UserAsymmetricKeys +{ + public Guid UserId { get; set; } + public required string PublicKey { get; set; } + public required string UserKeyEncryptedPrivateKey { get; set; } +} diff --git a/src/Core/KeyManagement/Repositories/IUserAsymmetricKeysRepository.cs b/src/Core/KeyManagement/Repositories/IUserAsymmetricKeysRepository.cs new file mode 100644 index 000000000000..fee9aee3bbb3 --- /dev/null +++ b/src/Core/KeyManagement/Repositories/IUserAsymmetricKeysRepository.cs @@ -0,0 +1,9 @@ +#nullable enable +using Bit.Core.KeyManagement.Models.Data; + +namespace Bit.Core.KeyManagement.Repositories; + +public interface IUserAsymmetricKeysRepository +{ + Task RegenerateUserAsymmetricKeysAsync(UserAsymmetricKeys userAsymmetricKeys); +} diff --git a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs index 550c572cf509..c873f84aa068 100644 --- a/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs +++ b/src/Infrastructure.Dapper/DapperServiceCollectionExtensions.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Repositories; using Bit.Core.Auth.Repositories; using Bit.Core.Billing.Repositories; +using Bit.Core.KeyManagement.Repositories; using Bit.Core.NotificationCenter.Repositories; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; @@ -9,6 +10,7 @@ using Bit.Infrastructure.Dapper.AdminConsole.Repositories; using Bit.Infrastructure.Dapper.Auth.Repositories; using Bit.Infrastructure.Dapper.Billing.Repositories; +using Bit.Infrastructure.Dapper.KeyManagement.Repositories; using Bit.Infrastructure.Dapper.NotificationCenter.Repositories; using Bit.Infrastructure.Dapper.Repositories; using Bit.Infrastructure.Dapper.SecretsManager.Repositories; @@ -60,6 +62,7 @@ public static void AddDapperRepositories(this IServiceCollection services, bool .AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); if (selfHosted) { diff --git a/src/Infrastructure.Dapper/KeyManagement/Repositories/UserAsymmetricKeysRepository.cs b/src/Infrastructure.Dapper/KeyManagement/Repositories/UserAsymmetricKeysRepository.cs new file mode 100644 index 000000000000..f176327f4ff8 --- /dev/null +++ b/src/Infrastructure.Dapper/KeyManagement/Repositories/UserAsymmetricKeysRepository.cs @@ -0,0 +1,36 @@ +#nullable enable +using System.Data; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; +using Bit.Core.Settings; +using Bit.Infrastructure.Dapper.Repositories; +using Dapper; +using Microsoft.Data.SqlClient; + +namespace Bit.Infrastructure.Dapper.KeyManagement.Repositories; + +public class UserAsymmetricKeysRepository : BaseRepository, IUserAsymmetricKeysRepository +{ + public UserAsymmetricKeysRepository(GlobalSettings globalSettings) + : this(globalSettings.SqlServer.ConnectionString, globalSettings.SqlServer.ReadOnlyConnectionString) + { + } + + public UserAsymmetricKeysRepository(string connectionString, string readOnlyConnectionString) : base( + connectionString, readOnlyConnectionString) + { + } + + public async Task RegenerateUserAsymmetricKeysAsync(UserAsymmetricKeys userAsymmetricKeys) + { + await using var connection = new SqlConnection(ConnectionString); + + await connection.ExecuteAsync("[dbo].[UserAsymmetricKeys_Regenerate]", + new + { + userAsymmetricKeys.UserId, + userAsymmetricKeys.PublicKey, + PrivateKey = userAsymmetricKeys.UserKeyEncryptedPrivateKey + }, commandType: CommandType.StoredProcedure); + } +} diff --git a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs index b8c84f649f21..b2eefe4523f8 100644 --- a/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs +++ b/src/Infrastructure.EntityFramework/EntityFrameworkServiceCollectionExtensions.cs @@ -2,6 +2,7 @@ using Bit.Core.Auth.Repositories; using Bit.Core.Billing.Repositories; using Bit.Core.Enums; +using Bit.Core.KeyManagement.Repositories; using Bit.Core.NotificationCenter.Repositories; using Bit.Core.Repositories; using Bit.Core.SecretsManager.Repositories; @@ -10,6 +11,7 @@ using Bit.Infrastructure.EntityFramework.AdminConsole.Repositories; using Bit.Infrastructure.EntityFramework.Auth.Repositories; using Bit.Infrastructure.EntityFramework.Billing.Repositories; +using Bit.Infrastructure.EntityFramework.KeyManagement.Repositories; using Bit.Infrastructure.EntityFramework.NotificationCenter.Repositories; using Bit.Infrastructure.EntityFramework.Repositories; using Bit.Infrastructure.EntityFramework.SecretsManager.Repositories; @@ -97,6 +99,7 @@ public static void AddPasswordManagerEFRepositories(this IServiceCollection serv .AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); if (selfHosted) { diff --git a/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserAsymmetricKeysRepository.cs b/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserAsymmetricKeysRepository.cs new file mode 100644 index 000000000000..c680424f56a2 --- /dev/null +++ b/src/Infrastructure.EntityFramework/KeyManagement/Repositories/UserAsymmetricKeysRepository.cs @@ -0,0 +1,34 @@ +#nullable enable +using AutoMapper; +using Bit.Core.KeyManagement.Models.Data; +using Bit.Core.KeyManagement.Repositories; +using Bit.Infrastructure.EntityFramework.Repositories; +using Microsoft.Extensions.DependencyInjection; + +namespace Bit.Infrastructure.EntityFramework.KeyManagement.Repositories; + +public class UserAsymmetricKeysRepository : BaseEntityFrameworkRepository, IUserAsymmetricKeysRepository +{ + public UserAsymmetricKeysRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper) : base( + serviceScopeFactory, + mapper) + { + } + + public async Task RegenerateUserAsymmetricKeysAsync(UserAsymmetricKeys userAsymmetricKeys) + { + await using var scope = ServiceScopeFactory.CreateAsyncScope(); + var dbContext = GetDatabaseContext(scope); + + var entity = await dbContext.Users.FindAsync(userAsymmetricKeys.UserId); + if (entity != null) + { + var utcNow = DateTime.UtcNow; + entity.PublicKey = userAsymmetricKeys.PublicKey; + entity.PrivateKey = userAsymmetricKeys.UserKeyEncryptedPrivateKey; + entity.RevisionDate = utcNow; + entity.AccountRevisionDate = utcNow; + await dbContext.SaveChangesAsync(); + } + } +} diff --git a/src/Sql/KeyManagement/dbo/Stored Procedures/UserAsymmetricKeys_Regenerate.sql b/src/Sql/KeyManagement/dbo/Stored Procedures/UserAsymmetricKeys_Regenerate.sql new file mode 100644 index 000000000000..26d0c40183bb --- /dev/null +++ b/src/Sql/KeyManagement/dbo/Stored Procedures/UserAsymmetricKeys_Regenerate.sql @@ -0,0 +1,16 @@ +CREATE PROCEDURE [dbo].[UserAsymmetricKeys_Regenerate] + @UserId UNIQUEIDENTIFIER, + @PublicKey VARCHAR(MAX), + @PrivateKey VARCHAR(MAX) +AS +BEGIN + SET NOCOUNT ON + DECLARE @UtcNow DATETIME2(7) = GETUTCDATE(); + + UPDATE [dbo].[User] + SET [PublicKey] = @PublicKey, + [PrivateKey] = @PrivateKey, + [RevisionDate] = @UtcNow, + [AccountRevisionDate] = @UtcNow + WHERE [Id] = @UserId +END diff --git a/util/Migrator/DbScripts/2024-11-21_00_AddUserAsymmetricKeysRegenerate.sql b/util/Migrator/DbScripts/2024-11-21_00_AddUserAsymmetricKeysRegenerate.sql new file mode 100644 index 000000000000..e1f5431145eb --- /dev/null +++ b/util/Migrator/DbScripts/2024-11-21_00_AddUserAsymmetricKeysRegenerate.sql @@ -0,0 +1,16 @@ +CREATE OR ALTER PROCEDURE [dbo].[UserAsymmetricKeys_Regenerate] + @UserId UNIQUEIDENTIFIER, + @PublicKey VARCHAR(MAX), + @PrivateKey VARCHAR(MAX) +AS +BEGIN + SET NOCOUNT ON + DECLARE @UtcNow DATETIME2(7) = GETUTCDATE(); + + UPDATE [dbo].[User] + SET [PublicKey] = @PublicKey, + [PrivateKey] = @PrivateKey, + [RevisionDate] = @UtcNow, + [AccountRevisionDate] = @UtcNow + WHERE [Id] = @UserId +END From 5dbda8c831ea9562ff04724bde24b4ec39662626 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:21:47 +0100 Subject: [PATCH 02/28] [deps] Tools: Update aws-sdk-net monorepo (#5056) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- src/Core/Core.csproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 913c4b387799..b8e2b8409a2f 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -21,8 +21,8 @@ - - + + From f4dd794cba71fd9d99d451a48d8b14a075731529 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 13:01:45 +0000 Subject: [PATCH 03/28] [deps] Platform: Update Quartz to 3.13.1 (#4655) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- src/Core/Core.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index b8e2b8409a2f..2f40fb7b23b0 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -44,7 +44,7 @@ - + From dac8f66a598d2369af81e6d55dc79489e58d3365 Mon Sep 17 00:00:00 2001 From: Justin Baur <19896123+justindbaur@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:05:15 -0500 Subject: [PATCH 04/28] Resolve AC Warnings (#4644) * Resolve AC Warnings * Remove Unneeded Changes * Add Back RequiredAttribute * Format --- src/Admin/AdminConsole/Models/OrganizationEditModel.cs | 2 +- src/Api/AdminConsole/Public/Models/MemberBaseModel.cs | 9 ++++++--- .../Public/Models/Response/MemberResponseModel.cs | 3 +++ .../Models/Mail/Provider/ProviderInitiateDeleteModel.cs | 2 -- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs index 4ba22130f718..48340df708c5 100644 --- a/src/Admin/AdminConsole/Models/OrganizationEditModel.cs +++ b/src/Admin/AdminConsole/Models/OrganizationEditModel.cs @@ -143,7 +143,7 @@ public OrganizationEditModel( [Display(Name = "SCIM")] public bool UseScim { get; set; } [Display(Name = "Secrets Manager")] - public bool UseSecretsManager { get; set; } + public new bool UseSecretsManager { get; set; } [Display(Name = "Self Host")] public bool SelfHost { get; set; } [Display(Name = "Users Get Premium")] diff --git a/src/Api/AdminConsole/Public/Models/MemberBaseModel.cs b/src/Api/AdminConsole/Public/Models/MemberBaseModel.cs index c56117ae71d7..dc3f91d49f5b 100644 --- a/src/Api/AdminConsole/Public/Models/MemberBaseModel.cs +++ b/src/Api/AdminConsole/Public/Models/MemberBaseModel.cs @@ -1,8 +1,11 @@ using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Data.Organizations.OrganizationUsers; +#nullable enable + namespace Bit.Api.AdminConsole.Public.Models; public abstract class MemberBaseModel @@ -25,6 +28,7 @@ public MemberBaseModel(OrganizationUser user) } } + [SetsRequiredMembers] public MemberBaseModel(OrganizationUserUserDetails user) { if (user == null) @@ -46,14 +50,13 @@ public MemberBaseModel(OrganizationUserUserDetails user) /// [Required] [EnumDataType(typeof(OrganizationUserType))] - public OrganizationUserType? Type { get; set; } + public required OrganizationUserType? Type { get; set; } /// /// External identifier for reference or linking this member to another system, such as a user directory. /// /// external_id_123456 [StringLength(300)] - public string ExternalId { get; set; } - + public string? ExternalId { get; set; } /// /// The member's custom permissions if the member has a Custom role. If not supplied, all custom permissions will /// default to false. diff --git a/src/Api/AdminConsole/Public/Models/Response/MemberResponseModel.cs b/src/Api/AdminConsole/Public/Models/Response/MemberResponseModel.cs index ab6ecbca4497..499c27cfc9be 100644 --- a/src/Api/AdminConsole/Public/Models/Response/MemberResponseModel.cs +++ b/src/Api/AdminConsole/Public/Models/Response/MemberResponseModel.cs @@ -1,4 +1,5 @@ using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; using System.Text.Json.Serialization; using Bit.Api.Models.Public.Response; using Bit.Core.Entities; @@ -16,6 +17,7 @@ public class MemberResponseModel : MemberBaseModel, IResponseModel [JsonConstructor] public MemberResponseModel() { } + [SetsRequiredMembers] public MemberResponseModel(OrganizationUser user, IEnumerable collections) : base(user) { if (user == null) @@ -31,6 +33,7 @@ public MemberResponseModel(OrganizationUser user, IEnumerable collections) : base(user) { diff --git a/src/Core/Models/Mail/Provider/ProviderInitiateDeleteModel.cs b/src/Core/Models/Mail/Provider/ProviderInitiateDeleteModel.cs index 196decb5eef9..a5071527fe7e 100644 --- a/src/Core/Models/Mail/Provider/ProviderInitiateDeleteModel.cs +++ b/src/Core/Models/Mail/Provider/ProviderInitiateDeleteModel.cs @@ -8,10 +8,8 @@ public class ProviderInitiateDeleteModel : BaseMailModel Token, ProviderNameUrlEncoded); - public string WebVaultUrl { get; set; } public string Token { get; set; } public Guid ProviderId { get; set; } - public string SiteName { get; set; } public string ProviderName { get; set; } public string ProviderNameUrlEncoded { get; set; } public string ProviderBillingEmail { get; set; } From c4ab5f31f5979319d4b7a9d3189117edaca3304c Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:12:04 +0100 Subject: [PATCH 05/28] [deps] Tools: Update aws-sdk-net monorepo (#5065) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- src/Core/Core.csproj | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 2f40fb7b23b0..d16624898dee 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -21,8 +21,8 @@ - - + + From 07592e22b9f4102e118d817c6384a7c4da604195 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 16:17:59 +0100 Subject: [PATCH 06/28] [deps]: Update Microsoft.NET.Test.Sdk to 17.12.0 (#5067) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Daniel James Smith <2670567+djsmith85@users.noreply.github.com> --- .../Infrastructure.Dapper.Test.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Infrastructure.Dapper.Test/Infrastructure.Dapper.Test.csproj b/test/Infrastructure.Dapper.Test/Infrastructure.Dapper.Test.csproj index dfc8951cc38b..82d63bd3c19d 100644 --- a/test/Infrastructure.Dapper.Test/Infrastructure.Dapper.Test.csproj +++ b/test/Infrastructure.Dapper.Test/Infrastructure.Dapper.Test.csproj @@ -10,7 +10,7 @@ - + runtime; build; native; contentfiles; analyzers; buildtransitive From fd7ff2ac63f40cbee8457af1aa44e6bc9a1c0070 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:30:02 -0500 Subject: [PATCH 07/28] [deps] Billing: Update FluentAssertions to 6.12.2 (#5015) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> --- test/Billing.Test/Billing.Test.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Billing.Test/Billing.Test.csproj b/test/Billing.Test/Billing.Test.csproj index 3bbda52ded40..c6a7ef48e0ab 100644 --- a/test/Billing.Test/Billing.Test.csproj +++ b/test/Billing.Test/Billing.Test.csproj @@ -6,7 +6,7 @@ - + From b974899127ee0154b20fb8ad8d254b17194ee523 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Mon, 25 Nov 2024 14:30:32 -0500 Subject: [PATCH 08/28] [deps] Billing: Update Braintree to 5.28.0 (#5019) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> Co-authored-by: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> --- src/Core/Core.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index d16624898dee..5970629caa00 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -54,7 +54,7 @@ - + From 8f703a29ac4dd46d8f5eacee09b9fbb73ede9031 Mon Sep 17 00:00:00 2001 From: "renovate[bot]" <29139614+renovate[bot]@users.noreply.github.com> Date: Tue, 26 Nov 2024 13:20:42 -0500 Subject: [PATCH 09/28] [deps] DbOps: Update Microsoft.Azure.Cosmos to 3.46.0 (#5066) Co-authored-by: renovate[bot] <29139614+renovate[bot]@users.noreply.github.com> --- src/Core/Core.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Core/Core.csproj b/src/Core/Core.csproj index 5970629caa00..fd4d8cc7e124 100644 --- a/src/Core/Core.csproj +++ b/src/Core/Core.csproj @@ -36,7 +36,7 @@ - + From 1b75e35c319f588d9e8a5fe7fd9ccd6591813cf1 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Tue, 26 Nov 2024 16:37:12 -0600 Subject: [PATCH 10/28] [PM-10319] - Revoke Non Complaint Users for 2FA and Single Org Policy Enablement (#5037) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Revoking users when enabling single org and 2fa policies. - Updated emails sent when users are revoked via 2FA or Single Organization policy enablement Co-authored-by: Matt Bishop Co-authored-by: Rui Tomé <108268980+r-tome@users.noreply.github.com> --- .../AdminConsole/Enums/EventSystemUser.cs | 1 + .../AdminConsole/Models/Data/IActingUser.cs | 10 + .../AdminConsole/Models/Data/StandardUser.cs | 16 ++ .../AdminConsole/Models/Data/SystemUser.cs | 16 ++ .../VerifyOrganizationDomainCommand.cs | 93 +++++---- ...vokeNonCompliantOrganizationUserCommand.cs | 9 + .../Requests/RevokeOrganizationUserRequest.cs | 13 ++ ...vokeNonCompliantOrganizationUserCommand.cs | 112 +++++++++++ .../Policies/Models/PolicyUpdate.cs | 2 + .../SingleOrgPolicyValidator.cs | 70 ++++++- .../TwoFactorAuthenticationPolicyValidator.cs | 69 ++++++- .../IOrganizationUserRepository.cs | 2 + .../AdminConsole/Services/IPolicyService.cs | 2 +- .../Services/Implementations/PolicyService.cs | 15 +- ...tionUserRevokedForSingleOrgPolicy.html.hbs | 14 ++ ...tionUserRevokedForSingleOrgPolicy.text.hbs | 5 + ...tionUserRevokedForTwoFactorPolicy.html.hbs | 15 ++ ...tionUserRevokedForTwoFactorPolicy.text.hbs | 7 + src/Core/Models/Commands/CommandResult.cs | 12 ++ ...nUserRevokedForPolicySingleOrgViewModel.cs | 6 + ...nUserRevokedForPolicyTwoFactorViewModel.cs | 6 + ...OrganizationServiceCollectionExtensions.cs | 1 + src/Core/Services/IMailService.cs | 2 + .../Implementations/HandlebarsMailService.cs | 31 ++- .../NoopImplementations/NoopMailService.cs | 6 + .../OrganizationUserRepository.cs | 10 + .../OrganizationUserRepository.cs | 12 ++ ...OrganizationUser_SetStatusForUsersById.sql | 29 +++ .../AutoFixture/PolicyUpdateFixtures.cs | 4 +- .../VerifyOrganizationDomainCommandTests.cs | 38 +++- ...onCompliantOrganizationUserCommandTests.cs | 185 ++++++++++++++++++ .../SingleOrgPolicyValidatorTests.cs | 151 ++++++++++++++ ...actorAuthenticationPolicyValidatorTests.cs | 150 +++++++++++++- .../Policies/SavePolicyCommandTests.cs | 4 +- .../Services/EventServiceTests.cs | 1 - .../2024-11-26-00_OrgUserSetStatusBulk.sql | 28 +++ 36 files changed, 1074 insertions(+), 73 deletions(-) create mode 100644 src/Core/AdminConsole/Models/Data/IActingUser.cs create mode 100644 src/Core/AdminConsole/Models/Data/StandardUser.cs create mode 100644 src/Core/AdminConsole/Models/Data/SystemUser.cs create mode 100644 src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeNonCompliantOrganizationUserCommand.cs create mode 100644 src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Requests/RevokeOrganizationUserRequest.cs create mode 100644 src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeNonCompliantOrganizationUserCommand.cs create mode 100644 src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForSingleOrgPolicy.html.hbs create mode 100644 src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForSingleOrgPolicy.text.hbs create mode 100644 src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForTwoFactorPolicy.html.hbs create mode 100644 src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForTwoFactorPolicy.text.hbs create mode 100644 src/Core/Models/Commands/CommandResult.cs create mode 100644 src/Core/Models/Mail/OrganizationUserRevokedForPolicySingleOrgViewModel.cs create mode 100644 src/Core/Models/Mail/OrganizationUserRevokedForPolicyTwoFactorViewModel.cs create mode 100644 src/Sql/dbo/Stored Procedures/OrganizationUser_SetStatusForUsersById.sql create mode 100644 test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeNonCompliantOrganizationUserCommandTests.cs create mode 100644 util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql diff --git a/src/Core/AdminConsole/Enums/EventSystemUser.cs b/src/Core/AdminConsole/Enums/EventSystemUser.cs index df9be9a350b7..c3e13705dd9f 100644 --- a/src/Core/AdminConsole/Enums/EventSystemUser.cs +++ b/src/Core/AdminConsole/Enums/EventSystemUser.cs @@ -2,6 +2,7 @@ public enum EventSystemUser : byte { + Unknown = 0, SCIM = 1, DomainVerification = 2, PublicApi = 3, diff --git a/src/Core/AdminConsole/Models/Data/IActingUser.cs b/src/Core/AdminConsole/Models/Data/IActingUser.cs new file mode 100644 index 000000000000..f97235f34cc4 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/IActingUser.cs @@ -0,0 +1,10 @@ +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.Models.Data; + +public interface IActingUser +{ + Guid? UserId { get; } + bool IsOrganizationOwnerOrProvider { get; } + EventSystemUser? SystemUserType { get; } +} diff --git a/src/Core/AdminConsole/Models/Data/StandardUser.cs b/src/Core/AdminConsole/Models/Data/StandardUser.cs new file mode 100644 index 000000000000..f21a41db7c0e --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/StandardUser.cs @@ -0,0 +1,16 @@ +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.Models.Data; + +public class StandardUser : IActingUser +{ + public StandardUser(Guid userId, bool isOrganizationOwner) + { + UserId = userId; + IsOrganizationOwnerOrProvider = isOrganizationOwner; + } + + public Guid? UserId { get; } + public bool IsOrganizationOwnerOrProvider { get; } + public EventSystemUser? SystemUserType => throw new Exception($"{nameof(StandardUser)} does not have a {nameof(SystemUserType)}"); +} diff --git a/src/Core/AdminConsole/Models/Data/SystemUser.cs b/src/Core/AdminConsole/Models/Data/SystemUser.cs new file mode 100644 index 000000000000..c4859f928fa5 --- /dev/null +++ b/src/Core/AdminConsole/Models/Data/SystemUser.cs @@ -0,0 +1,16 @@ +using Bit.Core.Enums; + +namespace Bit.Core.AdminConsole.Models.Data; + +public class SystemUser : IActingUser +{ + public SystemUser(EventSystemUser systemUser) + { + SystemUserType = systemUser; + } + + public Guid? UserId => throw new Exception($"{nameof(SystemUserType)} does not have a {nameof(UserId)}."); + + public bool IsOrganizationOwnerOrProvider => false; + public EventSystemUser? SystemUserType { get; } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs index 870fa72aa709..dd6add669f2d 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs @@ -1,7 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.Services; +using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -12,124 +14,121 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains; -public class VerifyOrganizationDomainCommand : IVerifyOrganizationDomainCommand +public class VerifyOrganizationDomainCommand( + IOrganizationDomainRepository organizationDomainRepository, + IDnsResolverService dnsResolverService, + IEventService eventService, + IGlobalSettings globalSettings, + IPolicyService policyService, + IFeatureService featureService, + ICurrentContext currentContext, + ILogger logger) + : IVerifyOrganizationDomainCommand { - private readonly IOrganizationDomainRepository _organizationDomainRepository; - private readonly IDnsResolverService _dnsResolverService; - private readonly IEventService _eventService; - private readonly IGlobalSettings _globalSettings; - private readonly IPolicyService _policyService; - private readonly IFeatureService _featureService; - private readonly ILogger _logger; - - public VerifyOrganizationDomainCommand( - IOrganizationDomainRepository organizationDomainRepository, - IDnsResolverService dnsResolverService, - IEventService eventService, - IGlobalSettings globalSettings, - IPolicyService policyService, - IFeatureService featureService, - ILogger logger) - { - _organizationDomainRepository = organizationDomainRepository; - _dnsResolverService = dnsResolverService; - _eventService = eventService; - _globalSettings = globalSettings; - _policyService = policyService; - _featureService = featureService; - _logger = logger; - } public async Task UserVerifyOrganizationDomainAsync(OrganizationDomain organizationDomain) { - var domainVerificationResult = await VerifyOrganizationDomainAsync(organizationDomain); + if (currentContext.UserId is null) + { + throw new InvalidOperationException( + $"{nameof(UserVerifyOrganizationDomainAsync)} can only be called by a user. " + + $"Please call {nameof(SystemVerifyOrganizationDomainAsync)} for system users."); + } - await _eventService.LogOrganizationDomainEventAsync(domainVerificationResult, + var actingUser = new StandardUser(currentContext.UserId.Value, await currentContext.OrganizationOwner(organizationDomain.OrganizationId)); + + var domainVerificationResult = await VerifyOrganizationDomainAsync(organizationDomain, actingUser); + + await eventService.LogOrganizationDomainEventAsync(domainVerificationResult, domainVerificationResult.VerifiedDate != null ? EventType.OrganizationDomain_Verified : EventType.OrganizationDomain_NotVerified); - await _organizationDomainRepository.ReplaceAsync(domainVerificationResult); + await organizationDomainRepository.ReplaceAsync(domainVerificationResult); return domainVerificationResult; } public async Task SystemVerifyOrganizationDomainAsync(OrganizationDomain organizationDomain) { + var actingUser = new SystemUser(EventSystemUser.DomainVerification); + organizationDomain.SetJobRunCount(); - var domainVerificationResult = await VerifyOrganizationDomainAsync(organizationDomain); + var domainVerificationResult = await VerifyOrganizationDomainAsync(organizationDomain, actingUser); if (domainVerificationResult.VerifiedDate is not null) { - _logger.LogInformation(Constants.BypassFiltersEventId, "Successfully validated domain"); + logger.LogInformation(Constants.BypassFiltersEventId, "Successfully validated domain"); - await _eventService.LogOrganizationDomainEventAsync(domainVerificationResult, + await eventService.LogOrganizationDomainEventAsync(domainVerificationResult, EventType.OrganizationDomain_Verified, EventSystemUser.DomainVerification); } else { - domainVerificationResult.SetNextRunDate(_globalSettings.DomainVerification.VerificationInterval); + domainVerificationResult.SetNextRunDate(globalSettings.DomainVerification.VerificationInterval); - await _eventService.LogOrganizationDomainEventAsync(domainVerificationResult, + await eventService.LogOrganizationDomainEventAsync(domainVerificationResult, EventType.OrganizationDomain_NotVerified, EventSystemUser.DomainVerification); - _logger.LogInformation(Constants.BypassFiltersEventId, + logger.LogInformation(Constants.BypassFiltersEventId, "Verification for organization {OrgId} with domain {Domain} failed", domainVerificationResult.OrganizationId, domainVerificationResult.DomainName); } - await _organizationDomainRepository.ReplaceAsync(domainVerificationResult); + await organizationDomainRepository.ReplaceAsync(domainVerificationResult); return domainVerificationResult; } - private async Task VerifyOrganizationDomainAsync(OrganizationDomain domain) + private async Task VerifyOrganizationDomainAsync(OrganizationDomain domain, IActingUser actingUser) { domain.SetLastCheckedDate(); if (domain.VerifiedDate is not null) { - await _organizationDomainRepository.ReplaceAsync(domain); + await organizationDomainRepository.ReplaceAsync(domain); throw new ConflictException("Domain has already been verified."); } var claimedDomain = - await _organizationDomainRepository.GetClaimedDomainsByDomainNameAsync(domain.DomainName); + await organizationDomainRepository.GetClaimedDomainsByDomainNameAsync(domain.DomainName); if (claimedDomain.Count > 0) { - await _organizationDomainRepository.ReplaceAsync(domain); + await organizationDomainRepository.ReplaceAsync(domain); throw new ConflictException("The domain is not available to be claimed."); } try { - if (await _dnsResolverService.ResolveAsync(domain.DomainName, domain.Txt)) + if (await dnsResolverService.ResolveAsync(domain.DomainName, domain.Txt)) { domain.SetVerifiedDate(); - await EnableSingleOrganizationPolicyAsync(domain.OrganizationId); + await EnableSingleOrganizationPolicyAsync(domain.OrganizationId, actingUser); } } catch (Exception e) { - _logger.LogError("Error verifying Organization domain: {domain}. {errorMessage}", + logger.LogError("Error verifying Organization domain: {domain}. {errorMessage}", domain.DomainName, e.Message); } return domain; } - private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId) + private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IActingUser actingUser) { - if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)) + if (featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)) { - await _policyService.SaveAsync( - new Policy { OrganizationId = organizationId, Type = PolicyType.SingleOrg, Enabled = true }, null); + await policyService.SaveAsync( + new Policy { OrganizationId = organizationId, Type = PolicyType.SingleOrg, Enabled = true }, + savingUserId: actingUser is StandardUser standardUser ? standardUser.UserId : null, + eventSystemUser: actingUser is SystemUser systemUser ? systemUser.SystemUserType : null); } } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeNonCompliantOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeNonCompliantOrganizationUserCommand.cs new file mode 100644 index 000000000000..c9768a890576 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRevokeNonCompliantOrganizationUserCommand.cs @@ -0,0 +1,9 @@ +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; +using Bit.Core.Models.Commands; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; + +public interface IRevokeNonCompliantOrganizationUserCommand +{ + Task RevokeNonCompliantOrganizationUsersAsync(RevokeOrganizationUsersRequest request); +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Requests/RevokeOrganizationUserRequest.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Requests/RevokeOrganizationUserRequest.cs new file mode 100644 index 000000000000..88f1dc8aa1f0 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Requests/RevokeOrganizationUserRequest.cs @@ -0,0 +1,13 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; + +public record RevokeOrganizationUsersRequest( + Guid OrganizationId, + IEnumerable OrganizationUsers, + IActingUser ActionPerformedBy) +{ + public RevokeOrganizationUsersRequest(Guid organizationId, OrganizationUserUserDetails organizationUser, IActingUser actionPerformedBy) + : this(organizationId, [organizationUser], actionPerformedBy) { } +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeNonCompliantOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeNonCompliantOrganizationUserCommand.cs new file mode 100644 index 000000000000..971ed02b29e9 --- /dev/null +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeNonCompliantOrganizationUserCommand.cs @@ -0,0 +1,112 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; +using Bit.Core.Enums; +using Bit.Core.Models.Commands; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Services; + +namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; + +public class RevokeNonCompliantOrganizationUserCommand(IOrganizationUserRepository organizationUserRepository, + IEventService eventService, + IHasConfirmedOwnersExceptQuery confirmedOwnersExceptQuery, + TimeProvider timeProvider) : IRevokeNonCompliantOrganizationUserCommand +{ + public const string ErrorCannotRevokeSelf = "You cannot revoke yourself."; + public const string ErrorOnlyOwnersCanRevokeOtherOwners = "Only owners can revoke other owners."; + public const string ErrorUserAlreadyRevoked = "User is already revoked."; + public const string ErrorOrgMustHaveAtLeastOneOwner = "Organization must have at least one confirmed owner."; + public const string ErrorInvalidUsers = "Invalid users."; + public const string ErrorRequestedByWasNotValid = "Action was performed by an unexpected type."; + + public async Task RevokeNonCompliantOrganizationUsersAsync(RevokeOrganizationUsersRequest request) + { + var validationResult = await ValidateAsync(request); + + if (validationResult.HasErrors) + { + return validationResult; + } + + await organizationUserRepository.RevokeManyByIdAsync(request.OrganizationUsers.Select(x => x.Id)); + + var now = timeProvider.GetUtcNow(); + + switch (request.ActionPerformedBy) + { + case StandardUser: + await eventService.LogOrganizationUserEventsAsync( + request.OrganizationUsers.Select(x => GetRevokedUserEventTuple(x, now))); + break; + case SystemUser { SystemUserType: not null } loggableSystem: + await eventService.LogOrganizationUserEventsAsync( + request.OrganizationUsers.Select(x => + GetRevokedUserEventBySystemUserTuple(x, loggableSystem.SystemUserType.Value, now))); + break; + } + + return validationResult; + } + + private static (OrganizationUserUserDetails organizationUser, EventType eventType, DateTime? time) GetRevokedUserEventTuple( + OrganizationUserUserDetails organizationUser, DateTimeOffset dateTimeOffset) => + new(organizationUser, EventType.OrganizationUser_Revoked, dateTimeOffset.UtcDateTime); + + private static (OrganizationUserUserDetails organizationUser, EventType eventType, EventSystemUser eventSystemUser, DateTime? time) GetRevokedUserEventBySystemUserTuple( + OrganizationUserUserDetails organizationUser, EventSystemUser systemUser, DateTimeOffset dateTimeOffset) => new(organizationUser, + EventType.OrganizationUser_Revoked, systemUser, dateTimeOffset.UtcDateTime); + + private async Task ValidateAsync(RevokeOrganizationUsersRequest request) + { + if (!PerformedByIsAnExpectedType(request.ActionPerformedBy)) + { + return new CommandResult(ErrorRequestedByWasNotValid); + } + + if (request.ActionPerformedBy is StandardUser user + && request.OrganizationUsers.Any(x => x.UserId == user.UserId)) + { + return new CommandResult(ErrorCannotRevokeSelf); + } + + if (request.OrganizationUsers.Any(x => x.OrganizationId != request.OrganizationId)) + { + return new CommandResult(ErrorInvalidUsers); + } + + if (!await confirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync( + request.OrganizationId, + request.OrganizationUsers.Select(x => x.Id))) + { + return new CommandResult(ErrorOrgMustHaveAtLeastOneOwner); + } + + return request.OrganizationUsers.Aggregate(new CommandResult(), (result, userToRevoke) => + { + if (IsAlreadyRevoked(userToRevoke)) + { + result.ErrorMessages.Add($"{ErrorUserAlreadyRevoked} Id: {userToRevoke.Id}"); + return result; + } + + if (NonOwnersCannotRevokeOwners(userToRevoke, request.ActionPerformedBy)) + { + result.ErrorMessages.Add($"{ErrorOnlyOwnersCanRevokeOtherOwners}"); + return result; + } + + return result; + }); + } + + private static bool PerformedByIsAnExpectedType(IActingUser entity) => entity is SystemUser or StandardUser; + + private static bool IsAlreadyRevoked(OrganizationUserUserDetails organizationUser) => + organizationUser is { Status: OrganizationUserStatusType.Revoked }; + + private static bool NonOwnersCannotRevokeOwners(OrganizationUserUserDetails organizationUser, + IActingUser actingUser) => + actingUser is StandardUser { IsOrganizationOwnerOrProvider: false } && organizationUser.Type == OrganizationUserType.Owner; +} diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs index 117a7ec73396..d1a52f008011 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Models/PolicyUpdate.cs @@ -1,6 +1,7 @@ #nullable enable using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.Utilities; @@ -15,6 +16,7 @@ public record PolicyUpdate public PolicyType Type { get; set; } public string? Data { get; set; } public bool Enabled { get; set; } + public IActingUser? PerformedBy { get; set; } public T GetDataModel() where T : IPolicyDataModel, new() { diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs index cc6971f9462a..050949ee7f91 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidator.cs @@ -2,8 +2,10 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; @@ -18,6 +20,8 @@ namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; public class SingleOrgPolicyValidator : IPolicyValidator { public PolicyType Type => PolicyType.SingleOrg; + private const string OrganizationNotFoundErrorMessage = "Organization not found."; + private const string ClaimedDomainSingleOrganizationRequiredErrorMessage = "The Single organization policy is required for organizations that have enabled domain verification."; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IMailService _mailService; @@ -27,6 +31,7 @@ public class SingleOrgPolicyValidator : IPolicyValidator private readonly IFeatureService _featureService; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IOrganizationHasVerifiedDomainsQuery _organizationHasVerifiedDomainsQuery; + private readonly IRevokeNonCompliantOrganizationUserCommand _revokeNonCompliantOrganizationUserCommand; public SingleOrgPolicyValidator( IOrganizationUserRepository organizationUserRepository, @@ -36,7 +41,8 @@ public SingleOrgPolicyValidator( ICurrentContext currentContext, IFeatureService featureService, IRemoveOrganizationUserCommand removeOrganizationUserCommand, - IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery) + IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, + IRevokeNonCompliantOrganizationUserCommand revokeNonCompliantOrganizationUserCommand) { _organizationUserRepository = organizationUserRepository; _mailService = mailService; @@ -46,6 +52,7 @@ public SingleOrgPolicyValidator( _featureService = featureService; _removeOrganizationUserCommand = removeOrganizationUserCommand; _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; + _revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand; } public IEnumerable RequiredPolicies => []; @@ -54,8 +61,52 @@ public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? curr { if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) { - await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId); + if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)) + { + var currentUser = _currentContext.UserId ?? Guid.Empty; + var isOwnerOrProvider = await _currentContext.OrganizationOwner(policyUpdate.OrganizationId); + await RevokeNonCompliantUsersAsync(policyUpdate.OrganizationId, policyUpdate.PerformedBy ?? new StandardUser(currentUser, isOwnerOrProvider)); + } + else + { + await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId); + } + } + } + + private async Task RevokeNonCompliantUsersAsync(Guid organizationId, IActingUser performedBy) + { + var organization = await _organizationRepository.GetByIdAsync(organizationId); + + if (organization is null) + { + throw new NotFoundException(OrganizationNotFoundErrorMessage); } + + var currentActiveRevocableOrganizationUsers = + (await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId)) + .Where(ou => ou.Status != OrganizationUserStatusType.Invited && + ou.Status != OrganizationUserStatusType.Revoked && + ou.Type != OrganizationUserType.Owner && + ou.Type != OrganizationUserType.Admin && + !(performedBy is StandardUser stdUser && stdUser.UserId == ou.UserId)) + .ToList(); + + if (currentActiveRevocableOrganizationUsers.Count == 0) + { + return; + } + + var commandResult = await _revokeNonCompliantOrganizationUserCommand.RevokeNonCompliantOrganizationUsersAsync( + new RevokeOrganizationUsersRequest(organizationId, currentActiveRevocableOrganizationUsers, performedBy)); + + if (commandResult.HasErrors) + { + throw new BadRequestException(string.Join(", ", commandResult.ErrorMessages)); + } + + await Task.WhenAll(currentActiveRevocableOrganizationUsers.Select(x => + _mailService.SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), x.Email))); } private async Task RemoveNonCompliantUsersAsync(Guid organizationId) @@ -67,7 +118,7 @@ private async Task RemoveNonCompliantUsersAsync(Guid organizationId) var org = await _organizationRepository.GetByIdAsync(organizationId); if (org == null) { - throw new NotFoundException("Organization not found."); + throw new NotFoundException(OrganizationNotFoundErrorMessage); } var removableOrgUsers = orgUsers.Where(ou => @@ -76,18 +127,17 @@ private async Task RemoveNonCompliantUsersAsync(Guid organizationId) ou.Type != OrganizationUserType.Owner && ou.Type != OrganizationUserType.Admin && ou.UserId != savingUserId - ).ToList(); + ).ToList(); var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( - removableOrgUsers.Select(ou => ou.UserId!.Value)); + removableOrgUsers.Select(ou => ou.UserId!.Value)); foreach (var orgUser in removableOrgUsers) { if (userOrgs.Any(ou => ou.UserId == orgUser.UserId - && ou.OrganizationId != org.Id - && ou.Status != OrganizationUserStatusType.Invited)) + && ou.OrganizationId != org.Id + && ou.Status != OrganizationUserStatusType.Invited)) { - await _removeOrganizationUserCommand.RemoveUserAsync(organizationId, orgUser.Id, - savingUserId); + await _removeOrganizationUserCommand.RemoveUserAsync(organizationId, orgUser.Id, savingUserId); await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( org.DisplayName(), orgUser.Email); @@ -111,7 +161,7 @@ public async Task ValidateAsync(PolicyUpdate policyUpdate, Policy? curre if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning) && await _organizationHasVerifiedDomainsQuery.HasVerifiedDomainsAsync(policyUpdate.OrganizationId)) { - return "The Single organization policy is required for organizations that have enabled domain verification."; + return ClaimedDomainSingleOrganizationRequiredErrorMessage; } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs index ef896bbb9b57..c2dd8cff91c8 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidator.cs @@ -2,12 +2,15 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; @@ -21,6 +24,10 @@ public class TwoFactorAuthenticationPolicyValidator : IPolicyValidator private readonly ICurrentContext _currentContext; private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; + private readonly IFeatureService _featureService; + private readonly IRevokeNonCompliantOrganizationUserCommand _revokeNonCompliantOrganizationUserCommand; + + public const string NonCompliantMembersWillLoseAccessMessage = "Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page."; public PolicyType Type => PolicyType.TwoFactorAuthentication; public IEnumerable RequiredPolicies => []; @@ -31,7 +38,9 @@ public TwoFactorAuthenticationPolicyValidator( IOrganizationRepository organizationRepository, ICurrentContext currentContext, ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, - IRemoveOrganizationUserCommand removeOrganizationUserCommand) + IRemoveOrganizationUserCommand removeOrganizationUserCommand, + IFeatureService featureService, + IRevokeNonCompliantOrganizationUserCommand revokeNonCompliantOrganizationUserCommand) { _organizationUserRepository = organizationUserRepository; _mailService = mailService; @@ -39,14 +48,63 @@ public TwoFactorAuthenticationPolicyValidator( _currentContext = currentContext; _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; _removeOrganizationUserCommand = removeOrganizationUserCommand; + _featureService = featureService; + _revokeNonCompliantOrganizationUserCommand = revokeNonCompliantOrganizationUserCommand; } public async Task OnSaveSideEffectsAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) { if (currentPolicy is not { Enabled: true } && policyUpdate is { Enabled: true }) { - await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId); + if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)) + { + var currentUser = _currentContext.UserId ?? Guid.Empty; + var isOwnerOrProvider = await _currentContext.OrganizationOwner(policyUpdate.OrganizationId); + await RevokeNonCompliantUsersAsync(policyUpdate.OrganizationId, policyUpdate.PerformedBy ?? new StandardUser(currentUser, isOwnerOrProvider)); + } + else + { + await RemoveNonCompliantUsersAsync(policyUpdate.OrganizationId); + } + } + } + + private async Task RevokeNonCompliantUsersAsync(Guid organizationId, IActingUser performedBy) + { + var organization = await _organizationRepository.GetByIdAsync(organizationId); + + var currentActiveRevocableOrganizationUsers = + (await _organizationUserRepository.GetManyDetailsByOrganizationAsync(organizationId)) + .Where(ou => ou.Status != OrganizationUserStatusType.Invited && + ou.Status != OrganizationUserStatusType.Revoked && + ou.Type != OrganizationUserType.Owner && + ou.Type != OrganizationUserType.Admin && + !(performedBy is StandardUser stdUser && stdUser.UserId == ou.UserId)) + .ToList(); + + if (currentActiveRevocableOrganizationUsers.Count == 0) + { + return; + } + + var organizationUsersTwoFactorEnabled = + await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(currentActiveRevocableOrganizationUsers); + + if (NonCompliantMembersWillLoseAccess(currentActiveRevocableOrganizationUsers, organizationUsersTwoFactorEnabled)) + { + throw new BadRequestException(NonCompliantMembersWillLoseAccessMessage); } + + var commandResult = await _revokeNonCompliantOrganizationUserCommand.RevokeNonCompliantOrganizationUsersAsync( + new RevokeOrganizationUsersRequest(organizationId, currentActiveRevocableOrganizationUsers, performedBy)); + + if (commandResult.HasErrors) + { + throw new BadRequestException(string.Join(", ", commandResult.ErrorMessages)); + } + + await Task.WhenAll(currentActiveRevocableOrganizationUsers.Select(x => + _mailService.SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), x.Email))); } private async Task RemoveNonCompliantUsersAsync(Guid organizationId) @@ -83,5 +141,12 @@ await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( } } + private static bool NonCompliantMembersWillLoseAccess( + IEnumerable orgUserDetails, + IEnumerable<(OrganizationUserUserDetails user, bool isTwoFactorEnabled)> organizationUsersTwoFactorEnabled) => + orgUserDetails.Any(x => + !x.HasMasterPassword && !organizationUsersTwoFactorEnabled.FirstOrDefault(u => u.user.Id == x.Id) + .isTwoFactorEnabled); + public Task ValidateAsync(PolicyUpdate policyUpdate, Policy? currentPolicy) => Task.FromResult(""); } diff --git a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs index cb540c212b5b..516b4614af49 100644 --- a/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs +++ b/src/Core/AdminConsole/Repositories/IOrganizationUserRepository.cs @@ -58,4 +58,6 @@ UpdateEncryptedDataForKeyRotation UpdateForKeyRotation(Guid userId, /// Returns a list of OrganizationUsers with email domains that match one of the Organization's claimed domains. /// Task> GetManyByOrganizationWithClaimedDomainsAsync(Guid organizationId); + + Task RevokeManyByIdAsync(IEnumerable organizationUserIds); } diff --git a/src/Core/AdminConsole/Services/IPolicyService.cs b/src/Core/AdminConsole/Services/IPolicyService.cs index 16ff2f4fa11c..715c3a34d934 100644 --- a/src/Core/AdminConsole/Services/IPolicyService.cs +++ b/src/Core/AdminConsole/Services/IPolicyService.cs @@ -9,7 +9,7 @@ namespace Bit.Core.AdminConsole.Services; public interface IPolicyService { - Task SaveAsync(Policy policy, Guid? savingUserId); + Task SaveAsync(Policy policy, Guid? savingUserId, EventSystemUser? eventSystemUser = null); /// /// Get the combined master password policy options for the specified user. diff --git a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs index 42655040a3c1..69ec27fd8a2a 100644 --- a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs +++ b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs @@ -1,5 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; @@ -9,6 +10,7 @@ using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; +using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -34,6 +36,7 @@ public class PolicyService : IPolicyService private readonly ISavePolicyCommand _savePolicyCommand; private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; private readonly IOrganizationHasVerifiedDomainsQuery _organizationHasVerifiedDomainsQuery; + private readonly ICurrentContext _currentContext; public PolicyService( IApplicationCacheService applicationCacheService, @@ -48,7 +51,8 @@ public PolicyService( IFeatureService featureService, ISavePolicyCommand savePolicyCommand, IRemoveOrganizationUserCommand removeOrganizationUserCommand, - IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery) + IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, + ICurrentContext currentContext) { _applicationCacheService = applicationCacheService; _eventService = eventService; @@ -63,19 +67,24 @@ public PolicyService( _savePolicyCommand = savePolicyCommand; _removeOrganizationUserCommand = removeOrganizationUserCommand; _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; + _currentContext = currentContext; } - public async Task SaveAsync(Policy policy, Guid? savingUserId) + public async Task SaveAsync(Policy policy, Guid? savingUserId, EventSystemUser? eventSystemUser = null) { if (_featureService.IsEnabled(FeatureFlagKeys.Pm13322AddPolicyDefinitions)) { // Transitional mapping - this will be moved to callers once the feature flag is removed + // TODO make sure to populate with SystemUser if not an actual user var policyUpdate = new PolicyUpdate { OrganizationId = policy.OrganizationId, Type = policy.Type, Enabled = policy.Enabled, - Data = policy.Data + Data = policy.Data, + PerformedBy = savingUserId.HasValue + ? new StandardUser(savingUserId.Value, await _currentContext.OrganizationOwner(policy.OrganizationId)) + : new SystemUser(eventSystemUser ?? EventSystemUser.Unknown) }; await _savePolicyCommand.SaveAsync(policyUpdate); diff --git a/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForSingleOrgPolicy.html.hbs b/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForSingleOrgPolicy.html.hbs new file mode 100644 index 000000000000..d04abe86c935 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForSingleOrgPolicy.html.hbs @@ -0,0 +1,14 @@ +{{#>FullHtmlLayout}} + + + + + + + +
+ Your user account has been revoked from the {{OrganizationName}} organization because your account is part of multiple organizations. Before you can re-join {{OrganizationName}}, you must first leave all other organizations. +
+ To leave an organization, first log into the web app, select the three dot menu next to the organization name, and select Leave. +
+{{/FullHtmlLayout}} diff --git a/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForSingleOrgPolicy.text.hbs b/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForSingleOrgPolicy.text.hbs new file mode 100644 index 000000000000..f933e8cf625e --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForSingleOrgPolicy.text.hbs @@ -0,0 +1,5 @@ +{{#>BasicTextLayout}} +Your user account has been revoked from the {{OrganizationName}} organization because your account is part of multiple organizations. Before you can rejoin {{OrganizationName}}, you must first leave all other organizations. + +To leave an organization, first log in the web app (https://vault.bitwarden.com/#/login), select the three dot menu next to the organization name, and select Leave. +{{/BasicTextLayout}} diff --git a/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForTwoFactorPolicy.html.hbs b/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForTwoFactorPolicy.html.hbs new file mode 100644 index 000000000000..cf38632a9e1f --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForTwoFactorPolicy.html.hbs @@ -0,0 +1,15 @@ +{{#>FullHtmlLayout}} + + + + + + + +
+ Your user account has been revoked from the {{OrganizationName}} organization because you do not have two-step login configured. Before you can re-join {{OrganizationName}}, you need to set up two-step login on your user account. +
+ Learn how to enable two-step login on your user account at + https://help.bitwarden.com/article/setup-two-step-login/ +
+{{/FullHtmlLayout}} diff --git a/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForTwoFactorPolicy.text.hbs b/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForTwoFactorPolicy.text.hbs new file mode 100644 index 000000000000..f197f37f00e5 --- /dev/null +++ b/src/Core/MailTemplates/Handlebars/AdminConsole/OrganizationUserRevokedForTwoFactorPolicy.text.hbs @@ -0,0 +1,7 @@ +{{#>BasicTextLayout}} + Your user account has been removed from the {{OrganizationName}} organization because you do not have two-step login + configured. Before you can re-join this organization you need to set up two-step login on your user account. + + Learn how to enable two-step login on your user account at + https://help.bitwarden.com/article/setup-two-step-login/ +{{/BasicTextLayout}} diff --git a/src/Core/Models/Commands/CommandResult.cs b/src/Core/Models/Commands/CommandResult.cs new file mode 100644 index 000000000000..4ac2d6249918 --- /dev/null +++ b/src/Core/Models/Commands/CommandResult.cs @@ -0,0 +1,12 @@ +namespace Bit.Core.Models.Commands; + +public class CommandResult(IEnumerable errors) +{ + public CommandResult(string error) : this([error]) { } + + public bool Success => ErrorMessages.Count == 0; + public bool HasErrors => ErrorMessages.Count > 0; + public List ErrorMessages { get; } = errors.ToList(); + + public CommandResult() : this([]) { } +} diff --git a/src/Core/Models/Mail/OrganizationUserRevokedForPolicySingleOrgViewModel.cs b/src/Core/Models/Mail/OrganizationUserRevokedForPolicySingleOrgViewModel.cs new file mode 100644 index 000000000000..27c784bd15dd --- /dev/null +++ b/src/Core/Models/Mail/OrganizationUserRevokedForPolicySingleOrgViewModel.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Models.Mail; + +public class OrganizationUserRevokedForPolicySingleOrgViewModel : BaseMailModel +{ + public string OrganizationName { get; set; } +} diff --git a/src/Core/Models/Mail/OrganizationUserRevokedForPolicyTwoFactorViewModel.cs b/src/Core/Models/Mail/OrganizationUserRevokedForPolicyTwoFactorViewModel.cs new file mode 100644 index 000000000000..9286ee74b3d2 --- /dev/null +++ b/src/Core/Models/Mail/OrganizationUserRevokedForPolicyTwoFactorViewModel.cs @@ -0,0 +1,6 @@ +namespace Bit.Core.Models.Mail; + +public class OrganizationUserRevokedForPolicyTwoFactorViewModel : BaseMailModel +{ + public string OrganizationName { get; set; } +} diff --git a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs index d11da2119aa5..96fdefcfad44 100644 --- a/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs +++ b/src/Core/OrganizationFeatures/OrganizationServiceCollectionExtensions.cs @@ -91,6 +91,7 @@ private static void AddOrganizationSponsorshipCommands(this IServiceCollection s private static void AddOrganizationUserCommands(this IServiceCollection services) { services.AddScoped(); + services.AddScoped(); services.AddScoped(); services.AddScoped(); services.AddScoped(); diff --git a/src/Core/Services/IMailService.cs b/src/Core/Services/IMailService.cs index 5514cd507d32..bc8d1440f155 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Services/IMailService.cs @@ -35,6 +35,8 @@ Task SendTrialInitiationSignupEmailAsync( Task SendOrganizationAcceptedEmailAsync(Organization organization, string userIdentifier, IEnumerable adminEmails, bool hasAccessSecretsManager = false); Task SendOrganizationConfirmedEmailAsync(string organizationName, string email, bool hasAccessSecretsManager = false); Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email); + Task SendOrganizationUserRevokedForTwoFactoryPolicyEmailAsync(string organizationName, string email); + Task SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(string organizationName, string email); Task SendPasswordlessSignInAsync(string returnUrl, string token, string email); Task SendInvoiceUpcoming( string email, diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index e1943b0e3c78..acc729e53ce0 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -25,8 +25,7 @@ public class HandlebarsMailService : IMailService private readonly GlobalSettings _globalSettings; private readonly IMailDeliveryService _mailDeliveryService; private readonly IMailEnqueuingService _mailEnqueuingService; - private readonly Dictionary> _templateCache = - new Dictionary>(); + private readonly Dictionary> _templateCache = new(); private bool _registeredHelpersAndPartials = false; @@ -295,6 +294,20 @@ public async Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string o await _mailDeliveryService.SendEmailAsync(message); } + public async Task SendOrganizationUserRevokedForTwoFactoryPolicyEmailAsync(string organizationName, string email) + { + var message = CreateDefaultMessage($"You have been revoked from {organizationName}", email); + var model = new OrganizationUserRevokedForPolicyTwoFactorViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "AdminConsole.OrganizationUserRevokedForTwoFactorPolicy", model); + message.Category = "OrganizationUserRevokedForTwoFactorPolicy"; + await _mailDeliveryService.SendEmailAsync(message); + } + public async Task SendWelcomeEmailAsync(User user) { var message = CreateDefaultMessage("Welcome to Bitwarden!", user.Email); @@ -496,6 +509,20 @@ public async Task SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(string await _mailDeliveryService.SendEmailAsync(message); } + public async Task SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(string organizationName, string email) + { + var message = CreateDefaultMessage($"You have been revoked from {organizationName}", email); + var model = new OrganizationUserRevokedForPolicySingleOrgViewModel + { + OrganizationName = CoreHelpers.SanitizeForEmail(organizationName, false), + WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, + SiteName = _globalSettings.SiteName + }; + await AddMessageContentAsync(message, "AdminConsole.OrganizationUserRevokedForSingleOrgPolicy", model); + message.Category = "OrganizationUserRevokedForSingleOrgPolicy"; + await _mailDeliveryService.SendEmailAsync(message); + } + public async Task SendEnqueuedMailMessageAsync(IMailQueueMessage queueMessage) { var message = CreateDefaultMessage(queueMessage.Subject, queueMessage.ToEmails); diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Services/NoopImplementations/NoopMailService.cs index a56858fb96f5..399874eee73d 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailService.cs @@ -79,6 +79,12 @@ public Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organiz return Task.FromResult(0); } + public Task SendOrganizationUserRevokedForTwoFactoryPolicyEmailAsync(string organizationName, string email) => + Task.CompletedTask; + + public Task SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(string organizationName, string email) => + Task.CompletedTask; + public Task SendTwoFactorEmailAsync(string email, string token) { return Task.FromResult(0); diff --git a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs index d5bdd3b6a270..42f79852f3c8 100644 --- a/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.Dapper/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -557,4 +557,14 @@ public async Task> GetManyByOrganizationWithClaime return results.ToList(); } } + + public async Task RevokeManyByIdAsync(IEnumerable organizationUserIds) + { + await using var connection = new SqlConnection(ConnectionString); + + await connection.ExecuteAsync( + "[dbo].[OrganizationUser_SetStatusForUsersById]", + new { OrganizationUserIds = JsonSerializer.Serialize(organizationUserIds), Status = OrganizationUserStatusType.Revoked }, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs index a64c19704de2..007ff1a7ff6e 100644 --- a/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs +++ b/src/Infrastructure.EntityFramework/AdminConsole/Repositories/OrganizationUserRepository.cs @@ -721,4 +721,16 @@ public UpdateEncryptedDataForKeyRotation UpdateForKeyRotation( return data; } } + + public async Task RevokeManyByIdAsync(IEnumerable organizationUserIds) + { + using var scope = ServiceScopeFactory.CreateScope(); + + var dbContext = GetDatabaseContext(scope); + + await dbContext.OrganizationUsers.Where(x => organizationUserIds.Contains(x.Id)) + .ExecuteUpdateAsync(s => s.SetProperty(x => x.Status, OrganizationUserStatusType.Revoked)); + + await dbContext.UserBumpAccountRevisionDateByOrganizationUserIdsAsync(organizationUserIds); + } } diff --git a/src/Sql/dbo/Stored Procedures/OrganizationUser_SetStatusForUsersById.sql b/src/Sql/dbo/Stored Procedures/OrganizationUser_SetStatusForUsersById.sql new file mode 100644 index 000000000000..95ed5a3155d0 --- /dev/null +++ b/src/Sql/dbo/Stored Procedures/OrganizationUser_SetStatusForUsersById.sql @@ -0,0 +1,29 @@ +CREATE PROCEDURE [dbo].[OrganizationUser_SetStatusForUsersById] + @OrganizationUserIds AS NVARCHAR(MAX), + @Status SMALLINT +AS +BEGIN + SET NOCOUNT ON + + -- Declare a table variable to hold the parsed JSON data + DECLARE @ParsedIds TABLE (Id UNIQUEIDENTIFIER); + + -- Parse the JSON input into the table variable + INSERT INTO @ParsedIds (Id) + SELECT value + FROM OPENJSON(@OrganizationUserIds); + + -- Check if the input table is empty + IF (SELECT COUNT(1) FROM @ParsedIds) < 1 + BEGIN + RETURN(-1); + END + + UPDATE + [dbo].[OrganizationUser] + SET [Status] = @Status + WHERE [Id] IN (SELECT Id from @ParsedIds) + + EXEC [dbo].[User_BumpAccountRevisionDateByOrganizationUserIds] @OrganizationUserIds +END + diff --git a/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs b/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs index dff9b571782c..794f6fddf351 100644 --- a/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs +++ b/test/Core.Test/AdminConsole/AutoFixture/PolicyUpdateFixtures.cs @@ -2,6 +2,7 @@ using AutoFixture; using AutoFixture.Xunit2; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; namespace Bit.Core.Test.AdminConsole.AutoFixture; @@ -12,7 +13,8 @@ public void Customize(IFixture fixture) { fixture.Customize(composer => composer .With(o => o.Type, type) - .With(o => o.Enabled, enabled)); + .With(o => o.Enabled, enabled) + .With(o => o.PerformedBy, new StandardUser(Guid.NewGuid(), false))); } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs index 2fcaf8134c80..8dbe5331317c 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs @@ -2,6 +2,7 @@ using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains; using Bit.Core.AdminConsole.Services; +using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -28,7 +29,12 @@ public async Task UserVerifyOrganizationDomainAsync_ShouldThrowConflict_WhenDoma DomainName = "Test Domain", Txt = "btw+test18383838383" }; + + sutProvider.GetDependency() + .UserId.Returns(Guid.NewGuid()); + expected.SetVerifiedDate(); + sutProvider.GetDependency() .GetByIdAsync(id) .Returns(expected); @@ -53,6 +59,10 @@ public async Task UserVerifyOrganizationDomainAsync_ShouldThrowConflict_WhenDoma sutProvider.GetDependency() .GetByIdAsync(id) .Returns(expected); + + sutProvider.GetDependency() + .UserId.Returns(Guid.NewGuid()); + sutProvider.GetDependency() .GetClaimedDomainsByDomainNameAsync(expected.DomainName) .Returns(new List { expected }); @@ -77,9 +87,14 @@ public async Task UserVerifyOrganizationDomainAsync_ShouldVerifyDomainUpdateAndL sutProvider.GetDependency() .GetByIdAsync(id) .Returns(expected); + + sutProvider.GetDependency() + .UserId.Returns(Guid.NewGuid()); + sutProvider.GetDependency() .GetClaimedDomainsByDomainNameAsync(expected.DomainName) .Returns(new List()); + sutProvider.GetDependency() .ResolveAsync(expected.DomainName, Arg.Any()) .Returns(true); @@ -107,9 +122,14 @@ public async Task UserVerifyOrganizationDomainAsync_ShouldNotSetVerifiedDate_Whe sutProvider.GetDependency() .GetByIdAsync(id) .Returns(expected); + + sutProvider.GetDependency() + .UserId.Returns(Guid.NewGuid()); + sutProvider.GetDependency() .GetClaimedDomainsByDomainNameAsync(expected.DomainName) .Returns(new List()); + sutProvider.GetDependency() .ResolveAsync(expected.DomainName, Arg.Any()) .Returns(false); @@ -143,7 +163,7 @@ await sutProvider.GetDependency().ReceivedWithAnyArgs(1) [Theory, BitAutoData] public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithAccountDeprovisioningEnabled_WhenDomainIsVerified_ThenSingleOrgPolicyShouldBeEnabled( - OrganizationDomain domain, SutProvider sutProvider) + OrganizationDomain domain, Guid userId, SutProvider sutProvider) { sutProvider.GetDependency() .GetClaimedDomainsByDomainNameAsync(domain.DomainName) @@ -157,11 +177,14 @@ public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithA .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) .Returns(true); + sutProvider.GetDependency() + .UserId.Returns(userId); + _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); await sutProvider.GetDependency() .Received(1) - .SaveAsync(Arg.Is(x => x.Type == PolicyType.SingleOrg && x.OrganizationId == domain.OrganizationId && x.Enabled), null); + .SaveAsync(Arg.Is(x => x.Type == PolicyType.SingleOrg && x.OrganizationId == domain.OrganizationId && x.Enabled), userId); } [Theory, BitAutoData] @@ -176,6 +199,9 @@ public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithA .ResolveAsync(domain.DomainName, domain.Txt) .Returns(true); + sutProvider.GetDependency() + .UserId.Returns(Guid.NewGuid()); + sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) .Returns(false); @@ -189,7 +215,7 @@ await sutProvider.GetDependency() [Theory, BitAutoData] public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithAccountDeprovisioningEnabled_WhenDomainIsNotVerified_ThenSingleOrgPolicyShouldNotBeEnabled( - OrganizationDomain domain, SutProvider sutProvider) + OrganizationDomain domain, SutProvider sutProvider) { sutProvider.GetDependency() .GetClaimedDomainsByDomainNameAsync(domain.DomainName) @@ -199,6 +225,9 @@ public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithA .ResolveAsync(domain.DomainName, domain.Txt) .Returns(false); + sutProvider.GetDependency() + .UserId.Returns(Guid.NewGuid()); + sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) .Returns(true); @@ -223,6 +252,9 @@ public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithA .ResolveAsync(domain.DomainName, domain.Txt) .Returns(false); + sutProvider.GetDependency() + .UserId.Returns(Guid.NewGuid()); + sutProvider.GetDependency() .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) .Returns(true); diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeNonCompliantOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeNonCompliantOrganizationUserCommandTests.cs new file mode 100644 index 000000000000..3653cd27d7df --- /dev/null +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RevokeNonCompliantOrganizationUserCommandTests.cs @@ -0,0 +1,185 @@ +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; +using Bit.Core.Enums; +using Bit.Core.Models.Data.Organizations.OrganizationUsers; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using NSubstitute; +using Xunit; + +namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers; + +[SutProviderCustomize] +public class RevokeNonCompliantOrganizationUserCommandTests +{ + [Theory, BitAutoData] + public async Task RevokeNonCompliantOrganizationUsersAsync_GivenUnrecognizedUserType_WhenAttemptingToRevoke_ThenErrorShouldBeReturned( + Guid organizationId, SutProvider sutProvider) + { + var command = new RevokeOrganizationUsersRequest(organizationId, [], new InvalidUser()); + + var result = await sutProvider.Sut.RevokeNonCompliantOrganizationUsersAsync(command); + + Assert.True(result.HasErrors); + Assert.Contains(RevokeNonCompliantOrganizationUserCommand.ErrorRequestedByWasNotValid, result.ErrorMessages); + } + + [Theory, BitAutoData] + public async Task RevokeNonCompliantOrganizationUsersAsync_GivenPopulatedRequest_WhenUserAttemptsToRevokeThemselves_ThenErrorShouldBeReturned( + Guid organizationId, OrganizationUserUserDetails revokingUser, + SutProvider sutProvider) + { + var command = new RevokeOrganizationUsersRequest(organizationId, revokingUser, + new StandardUser(revokingUser?.UserId ?? Guid.NewGuid(), true)); + + var result = await sutProvider.Sut.RevokeNonCompliantOrganizationUsersAsync(command); + + Assert.True(result.HasErrors); + Assert.Contains(RevokeNonCompliantOrganizationUserCommand.ErrorCannotRevokeSelf, result.ErrorMessages); + } + + [Theory, BitAutoData] + public async Task RevokeNonCompliantOrganizationUsersAsync_GivenPopulatedRequest_WhenUserAttemptsToRevokeOrgUsersFromAnotherOrg_ThenErrorShouldBeReturned( + Guid organizationId, OrganizationUserUserDetails userFromAnotherOrg, + SutProvider sutProvider) + { + userFromAnotherOrg.OrganizationId = Guid.NewGuid(); + + var command = new RevokeOrganizationUsersRequest(organizationId, userFromAnotherOrg, + new StandardUser(Guid.NewGuid(), true)); + + var result = await sutProvider.Sut.RevokeNonCompliantOrganizationUsersAsync(command); + + Assert.True(result.HasErrors); + Assert.Contains(RevokeNonCompliantOrganizationUserCommand.ErrorInvalidUsers, result.ErrorMessages); + } + + [Theory, BitAutoData] + public async Task RevokeNonCompliantOrganizationUsersAsync_GivenPopulatedRequest_WhenUserAttemptsToRevokeAllOwnersFromOrg_ThenErrorShouldBeReturned( + Guid organizationId, OrganizationUserUserDetails userToRevoke, + SutProvider sutProvider) + { + userToRevoke.OrganizationId = organizationId; + + var command = new RevokeOrganizationUsersRequest(organizationId, userToRevoke, + new StandardUser(Guid.NewGuid(), true)); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(false); + + var result = await sutProvider.Sut.RevokeNonCompliantOrganizationUsersAsync(command); + + Assert.True(result.HasErrors); + Assert.Contains(RevokeNonCompliantOrganizationUserCommand.ErrorOrgMustHaveAtLeastOneOwner, result.ErrorMessages); + } + + [Theory, BitAutoData] + public async Task RevokeNonCompliantOrganizationUsersAsync_GivenPopulatedRequest_WhenUserAttemptsToRevokeOwnerWhenNotAnOwner_ThenErrorShouldBeReturned( + Guid organizationId, OrganizationUserUserDetails userToRevoke, + SutProvider sutProvider) + { + userToRevoke.OrganizationId = organizationId; + userToRevoke.Type = OrganizationUserType.Owner; + + var command = new RevokeOrganizationUsersRequest(organizationId, userToRevoke, + new StandardUser(Guid.NewGuid(), false)); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + var result = await sutProvider.Sut.RevokeNonCompliantOrganizationUsersAsync(command); + + Assert.True(result.HasErrors); + Assert.Contains(RevokeNonCompliantOrganizationUserCommand.ErrorOnlyOwnersCanRevokeOtherOwners, result.ErrorMessages); + } + + [Theory, BitAutoData] + public async Task RevokeNonCompliantOrganizationUsersAsync_GivenPopulatedRequest_WhenUserAttemptsToRevokeUserWhoIsAlreadyRevoked_ThenErrorShouldBeReturned( + Guid organizationId, OrganizationUserUserDetails userToRevoke, + SutProvider sutProvider) + { + userToRevoke.OrganizationId = organizationId; + userToRevoke.Status = OrganizationUserStatusType.Revoked; + + var command = new RevokeOrganizationUsersRequest(organizationId, userToRevoke, + new StandardUser(Guid.NewGuid(), true)); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + var result = await sutProvider.Sut.RevokeNonCompliantOrganizationUsersAsync(command); + + Assert.True(result.HasErrors); + Assert.Contains($"{RevokeNonCompliantOrganizationUserCommand.ErrorUserAlreadyRevoked} Id: {userToRevoke.Id}", result.ErrorMessages); + } + + [Theory, BitAutoData] + public async Task RevokeNonCompliantOrganizationUsersAsync_GivenPopulatedRequest_WhenUserHasMultipleInvalidUsers_ThenErrorShouldBeReturned( + Guid organizationId, IEnumerable usersToRevoke, + SutProvider sutProvider) + { + var revocableUsers = usersToRevoke.ToList(); + revocableUsers.ForEach(user => user.OrganizationId = organizationId); + revocableUsers[0].Type = OrganizationUserType.Owner; + revocableUsers[1].Status = OrganizationUserStatusType.Revoked; + + var command = new RevokeOrganizationUsersRequest(organizationId, revocableUsers, + new StandardUser(Guid.NewGuid(), false)); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + var result = await sutProvider.Sut.RevokeNonCompliantOrganizationUsersAsync(command); + + Assert.True(result.HasErrors); + Assert.True(result.ErrorMessages.Count > 1); + } + + [Theory, BitAutoData] + public async Task RevokeNonCompliantOrganizationUsersAsync_GivenValidPopulatedRequest_WhenUserAttemptsToRevokeAUser_ThenUserShouldBeRevoked( + Guid organizationId, OrganizationUserUserDetails userToRevoke, + SutProvider sutProvider) + { + userToRevoke.OrganizationId = organizationId; + userToRevoke.Type = OrganizationUserType.Admin; + + var command = new RevokeOrganizationUsersRequest(organizationId, userToRevoke, + new StandardUser(Guid.NewGuid(), false)); + + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(organizationId, Arg.Any>()) + .Returns(true); + + var result = await sutProvider.Sut.RevokeNonCompliantOrganizationUsersAsync(command); + + await sutProvider.GetDependency() + .Received(1) + .RevokeManyByIdAsync(Arg.Any>()); + + Assert.True(result.Success); + + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync( + Arg.Is>( + x => x.Any(y => + y.organizationUser.Id == userToRevoke.Id && y.eventType == EventType.OrganizationUser_Revoked) + )); + } + + public class InvalidUser : IActingUser + { + public Guid? UserId => Guid.Empty; + public bool IsOrganizationOwnerOrProvider => false; + public EventSystemUser? SystemUserType => null; + } +} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs index 76ee5748403f..0731920757a6 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/SingleOrgPolicyValidatorTests.cs @@ -1,6 +1,7 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; using Bit.Core.Auth.Entities; @@ -10,6 +11,7 @@ using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; +using Bit.Core.Models.Commands; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; @@ -61,6 +63,79 @@ public async Task ValidateAsync_DisablingPolicy_KeyConnectorNotEnabled_Success( Assert.True(string.IsNullOrEmpty(result)); } + [Theory, BitAutoData] + public async Task OnSaveSideEffectsAsync_RevokesNonCompliantUsers( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy policy, + Guid savingUserId, + Guid nonCompliantUserId, + Organization organization, SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + + var compliantUser1 = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = new Guid(), + Email = "user1@example.com" + }; + + var compliantUser2 = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = new Guid(), + Email = "user2@example.com" + }; + + var nonCompliantUser = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = nonCompliantUserId, + Email = "user3@example.com" + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([compliantUser1, compliantUser2, nonCompliantUser]); + + var otherOrganizationUser = new OrganizationUser + { + OrganizationId = new Guid(), + UserId = nonCompliantUserId, + Status = OrganizationUserStatusType.Confirmed + }; + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Is>(ids => ids.Contains(nonCompliantUserId))) + .Returns([otherOrganizationUser]); + + sutProvider.GetDependency().UserId.Returns(savingUserId); + sutProvider.GetDependency().GetByIdAsync(policyUpdate.OrganizationId).Returns(organization); + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + + sutProvider.GetDependency() + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()) + .Returns(new CommandResult()); + + await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy); + + await sutProvider.GetDependency() + .Received(1) + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()); + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), + "user3@example.com"); + } + [Theory, BitAutoData] public async Task OnSaveSideEffectsAsync_RemovesNonCompliantUsers( [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, @@ -116,6 +191,13 @@ public async Task OnSaveSideEffectsAsync_RemovesNonCompliantUsers( sutProvider.GetDependency().UserId.Returns(savingUserId); sutProvider.GetDependency().GetByIdAsync(policyUpdate.OrganizationId).Returns(organization); + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(false); + + sutProvider.GetDependency() + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()) + .Returns(new CommandResult()); + await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy); await sutProvider.GetDependency() @@ -126,4 +208,73 @@ await sutProvider.GetDependency() .SendOrganizationUserRemovedForPolicySingleOrgEmailAsync(organization.DisplayName(), "user3@example.com"); } + + [Theory, BitAutoData] + public async Task OnSaveSideEffectsAsync_WhenAccountDeprovisioningIsEnabled_ThenUsersAreRevoked( + [PolicyUpdate(PolicyType.SingleOrg)] PolicyUpdate policyUpdate, + [Policy(PolicyType.SingleOrg, false)] Policy policy, + Guid savingUserId, + Guid nonCompliantUserId, + Organization organization, SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + + var compliantUser1 = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = new Guid(), + Email = "user1@example.com" + }; + + var compliantUser2 = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = new Guid(), + Email = "user2@example.com" + }; + + var nonCompliantUser = new OrganizationUserUserDetails + { + OrganizationId = organization.Id, + Type = OrganizationUserType.User, + Status = OrganizationUserStatusType.Confirmed, + UserId = nonCompliantUserId, + Email = "user3@example.com" + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([compliantUser1, compliantUser2, nonCompliantUser]); + + var otherOrganizationUser = new OrganizationUser + { + OrganizationId = new Guid(), + UserId = nonCompliantUserId, + Status = OrganizationUserStatusType.Confirmed + }; + + sutProvider.GetDependency() + .GetManyByManyUsersAsync(Arg.Is>(ids => ids.Contains(nonCompliantUserId))) + .Returns([otherOrganizationUser]); + + sutProvider.GetDependency().UserId.Returns(savingUserId); + sutProvider.GetDependency().GetByIdAsync(policyUpdate.OrganizationId) + .Returns(organization); + + sutProvider.GetDependency().IsEnabled(FeatureFlagKeys.AccountDeprovisioning).Returns(true); + + sutProvider.GetDependency() + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()) + .Returns(new CommandResult()); + + await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy); + + await sutProvider.GetDependency() + .Received() + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs index 4dce13174934..4e5f1816a5a4 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/PolicyValidators/TwoFactorAuthenticationPolicyValidatorTests.cs @@ -1,12 +1,14 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Requests; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyValidators; using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Context; using Bit.Core.Enums; using Bit.Core.Exceptions; +using Bit.Core.Models.Commands; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; @@ -176,6 +178,10 @@ public async Task OnSaveSideEffectsAsync_UsersToBeRemovedDontHaveMasterPasswords HasMasterPassword = false }; + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(false); + sutProvider.GetDependency() .GetManyDetailsByOrganizationAsync(policy.OrganizationId) .Returns(new List @@ -201,9 +207,151 @@ public async Task OnSaveSideEffectsAsync_UsersToBeRemovedDontHaveMasterPasswords var badRequestException = await Assert.ThrowsAsync( () => sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy)); - Assert.Contains("Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); + Assert.Equal(TwoFactorAuthenticationPolicyValidator.NonCompliantMembersWillLoseAccessMessage, badRequestException.Message); await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() .RemoveUserAsync(organizationId: default, organizationUserId: default, deletingUserId: default); } + + [Theory, BitAutoData] + public async Task OnSaveSideEffectsAsync_GivenUpdateTo2faPolicy_WhenAccountProvisioningIsDisabled_ThenRevokeUserCommandShouldNotBeCalled( + Organization organization, + [PolicyUpdate(PolicyType.TwoFactorAuthentication)] + PolicyUpdate policyUpdate, + [Policy(PolicyType.TwoFactorAuthentication, false)] + Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(false); + + var orgUserDetailUserAcceptedWithout2Fa = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Accepted, + Type = OrganizationUserType.User, + Email = "user3@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = true + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns(new List + { + orgUserDetailUserAcceptedWithout2Fa + }); + + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(OrganizationUserUserDetails user, bool hasTwoFactor)>() + { + (orgUserDetailUserAcceptedWithout2Fa, false), + }); + + await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy); + + await sutProvider.GetDependency() + .DidNotReceive() + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()); + } + + [Theory, BitAutoData] + public async Task OnSaveSideEffectsAsync_GivenUpdateTo2faPolicy_WhenAccountProvisioningIsEnabledAndUserDoesNotHaveMasterPassword_ThenNonCompliantMembersErrorMessageWillReturn( + Organization organization, + [PolicyUpdate(PolicyType.TwoFactorAuthentication)] PolicyUpdate policyUpdate, + [Policy(PolicyType.TwoFactorAuthentication, false)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + + var orgUserDetailUserWithout2Fa = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User, + Email = "user3@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = false + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([orgUserDetailUserWithout2Fa]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(OrganizationUserUserDetails user, bool hasTwoFactor)>() + { + (orgUserDetailUserWithout2Fa, false), + }); + + var exception = await Assert.ThrowsAsync(() => sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy)); + + Assert.Equal(TwoFactorAuthenticationPolicyValidator.NonCompliantMembersWillLoseAccessMessage, exception.Message); + } + + [Theory, BitAutoData] + public async Task OnSaveSideEffectsAsync_WhenAccountProvisioningIsEnabledAndUserHasMasterPassword_ThenUserWillBeRevoked( + Organization organization, + [PolicyUpdate(PolicyType.TwoFactorAuthentication)] PolicyUpdate policyUpdate, + [Policy(PolicyType.TwoFactorAuthentication, false)] Policy policy, + SutProvider sutProvider) + { + policy.OrganizationId = organization.Id = policyUpdate.OrganizationId; + sutProvider.GetDependency().GetByIdAsync(organization.Id).Returns(organization); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + + var orgUserDetailUserWithout2Fa = new OrganizationUserUserDetails + { + Id = Guid.NewGuid(), + Status = OrganizationUserStatusType.Confirmed, + Type = OrganizationUserType.User, + Email = "user3@test.com", + Name = "TEST", + UserId = Guid.NewGuid(), + HasMasterPassword = true + }; + + sutProvider.GetDependency() + .GetManyDetailsByOrganizationAsync(policyUpdate.OrganizationId) + .Returns([orgUserDetailUserWithout2Fa]); + + sutProvider.GetDependency() + .TwoFactorIsEnabledAsync(Arg.Any>()) + .Returns(new List<(OrganizationUserUserDetails user, bool hasTwoFactor)>() + { + (orgUserDetailUserWithout2Fa, true), + }); + + sutProvider.GetDependency() + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()) + .Returns(new CommandResult()); + + await sutProvider.Sut.OnSaveSideEffectsAsync(policyUpdate, policy); + + await sutProvider.GetDependency() + .Received(1) + .RevokeNonCompliantOrganizationUsersAsync(Arg.Any()); + + await sutProvider.GetDependency() + .Received(1) + .SendOrganizationUserRevokedForPolicySingleOrgEmailAsync(organization.DisplayName(), + "user3@test.com"); + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs index 342ede9c829d..3ca7004e70ba 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/Policies/SavePolicyCommandTests.cs @@ -100,7 +100,7 @@ public void Constructor_DuplicatePolicyValidators_Throws() } [Theory, BitAutoData] - public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest(PolicyUpdate policyUpdate) + public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest([PolicyUpdate(PolicyType.ActivateAutofill)] PolicyUpdate policyUpdate) { var sutProvider = SutProviderFactory(); sutProvider.GetDependency() @@ -115,7 +115,7 @@ public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest(PolicyUpda } [Theory, BitAutoData] - public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest(PolicyUpdate policyUpdate) + public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest([PolicyUpdate(PolicyType.ActivateAutofill)] PolicyUpdate policyUpdate) { var sutProvider = SutProviderFactory(); sutProvider.GetDependency() diff --git a/test/Core.Test/AdminConsole/Services/EventServiceTests.cs b/test/Core.Test/AdminConsole/Services/EventServiceTests.cs index 18f5371b4943..d064fce2ecd8 100644 --- a/test/Core.Test/AdminConsole/Services/EventServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/EventServiceTests.cs @@ -169,7 +169,6 @@ public async Task LogOrganizationUserEvent_WithEventSystemUser_LogsRequiredInfo( new EventMessage() { IpAddress = ipAddress, - DeviceType = DeviceType.Server, OrganizationId = orgUser.OrganizationId, UserId = orgUser.UserId, OrganizationUserId = orgUser.Id, diff --git a/util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql b/util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql new file mode 100644 index 000000000000..95151f90ded5 --- /dev/null +++ b/util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql @@ -0,0 +1,28 @@ +CREATE PROCEDURE [dbo].[OrganizationUser_SetStatusForUsersById] + @OrganizationUserIds AS NVARCHAR(MAX), + @Status SMALLINT +AS +BEGIN + SET NOCOUNT ON + + -- Declare a table variable to hold the parsed JSON data + DECLARE @ParsedIds TABLE (Id UNIQUEIDENTIFIER); + + -- Parse the JSON input into the table variable + INSERT INTO @ParsedIds (Id) + SELECT value + FROM OPENJSON(@OrganizationUserIds); + + -- Check if the input table is empty + IF (SELECT COUNT(1) FROM @ParsedIds) < 1 + BEGIN + RETURN(-1); + END + + UPDATE + [dbo].[OrganizationUser] + SET [Status] = @Status + WHERE [Id] IN (SELECT Id from @ParsedIds) + + EXEC [dbo].[User_BumpAccountRevisionDateByOrganizationUserIds] @OrganizationUserIds +END From 674bd1e495b343e68174165cebb1f2458efcff0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Wed, 27 Nov 2024 12:26:42 +0000 Subject: [PATCH 11/28] [PM-13026] Refactor remove and bulkremove methods to throw error if user is managed by an organization (#5034) * Enhance RemoveOrganizationUserCommand to block removing managed users when account deprovisioning is enabled * Refactor RemoveUsersAsync method to return just the OrgUserId and update related logic. * Refactor RemoveOrganizationUserCommand to improve variable naming and remove unused logging method * Add support for event system user in RemoveUsersAsync method. Refactor unit tests. * Add xmldoc to IRemoveOrganizationUserCommand methods * Refactor RemoveOrganizationUserCommand to use TimeProvider for event date retrieval and update unit tests accordingly * Refactor RemoveOrganizationUserCommand to use constants for error messages * Refactor unit tests to separate feature flag tests * refactor: Update parameter names for clarity in RemoveOrganizationUserCommand * refactor: Rename validation and repository methods for user removal clarity --- .../OrganizationUsersController.cs | 2 +- .../IRemoveOrganizationUserCommand.cs | 49 +- .../RemoveOrganizationUserCommand.cs | 199 +++-- .../RemoveOrganizationUserCommandTests.cs | 685 +++++++++++++++--- 4 files changed, 757 insertions(+), 178 deletions(-) diff --git a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs index 3193962fa949..12d11fbc1885 100644 --- a/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs +++ b/src/Api/AdminConsole/Controllers/OrganizationUsersController.cs @@ -539,7 +539,7 @@ public async Task> BulkRemo var userId = _userService.GetProperUserId(User); var result = await _removeOrganizationUserCommand.RemoveUsersAsync(orgId, model.Ids, userId.Value); return new ListResponseModel(result.Select(r => - new OrganizationUserBulkResponseModel(r.Item1.Id, r.Item2))); + new OrganizationUserBulkResponseModel(r.OrganizationUserId, r.ErrorMessage))); } [RequireFeature(FeatureFlagKeys.AccountDeprovisioning)] diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRemoveOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRemoveOrganizationUserCommand.cs index 583645a89030..7c1cdf05f8ea 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRemoveOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/Interfaces/IRemoveOrganizationUserCommand.cs @@ -1,14 +1,53 @@ -using Bit.Core.Entities; -using Bit.Core.Enums; +using Bit.Core.Enums; namespace Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; public interface IRemoveOrganizationUserCommand { + /// + /// Removes a user from an organization. + /// + /// The ID of the organization. + /// The ID of the user to remove. + Task RemoveUserAsync(Guid organizationId, Guid userId); + + /// + /// Removes a user from an organization with a specified deleting user. + /// + /// The ID of the organization. + /// The ID of the organization user to remove. + /// The ID of the user performing the removal operation. Task RemoveUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId); + /// + /// Removes a user from an organization using a system user. + /// + /// The ID of the organization. + /// The ID of the organization user to remove. + /// The system user performing the removal operation. Task RemoveUserAsync(Guid organizationId, Guid organizationUserId, EventSystemUser eventSystemUser); - Task RemoveUserAsync(Guid organizationId, Guid userId); - Task>> RemoveUsersAsync(Guid organizationId, - IEnumerable organizationUserIds, Guid? deletingUserId); + + /// + /// Removes multiple users from an organization with a specified deleting user. + /// + /// The ID of the organization. + /// The collection of organization user IDs to remove. + /// The ID of the user performing the removal operation. + /// + /// A list of tuples containing the organization user id and the error message for each user that could not be removed, otherwise empty. + /// + Task> RemoveUsersAsync( + Guid organizationId, IEnumerable organizationUserIds, Guid? deletingUserId); + + /// + /// Removes multiple users from an organization using a system user. + /// + /// The ID of the organization. + /// The collection of organization user IDs to remove. + /// The system user performing the removal operation. + /// + /// A list of tuples containing the organization user id and the error message for each user that could not be removed, otherwise empty. + /// + Task> RemoveUsersAsync( + Guid organizationId, IEnumerable organizationUserIds, EventSystemUser eventSystemUser); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs index e6d56ea878a2..fa027f8e470a 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommand.cs @@ -17,6 +17,16 @@ public class RemoveOrganizationUserCommand : IRemoveOrganizationUserCommand private readonly IPushRegistrationService _pushRegistrationService; private readonly ICurrentContext _currentContext; private readonly IHasConfirmedOwnersExceptQuery _hasConfirmedOwnersExceptQuery; + private readonly IGetOrganizationUsersManagementStatusQuery _getOrganizationUsersManagementStatusQuery; + private readonly IFeatureService _featureService; + private readonly TimeProvider _timeProvider; + + public const string UserNotFoundErrorMessage = "User not found."; + public const string UsersInvalidErrorMessage = "Users invalid."; + public const string RemoveYourselfErrorMessage = "You cannot remove yourself."; + public const string RemoveOwnerByNonOwnerErrorMessage = "Only owners can delete other owners."; + public const string RemoveLastConfirmedOwnerErrorMessage = "Organization must have at least one confirmed owner."; + public const string RemoveClaimedAccountErrorMessage = "Cannot remove member accounts claimed by the organization. To offboard a member, revoke or delete the account."; public RemoveOrganizationUserCommand( IDeviceRepository deviceRepository, @@ -25,7 +35,10 @@ public RemoveOrganizationUserCommand( IPushNotificationService pushNotificationService, IPushRegistrationService pushRegistrationService, ICurrentContext currentContext, - IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery) + IHasConfirmedOwnersExceptQuery hasConfirmedOwnersExceptQuery, + IGetOrganizationUsersManagementStatusQuery getOrganizationUsersManagementStatusQuery, + IFeatureService featureService, + TimeProvider timeProvider) { _deviceRepository = deviceRepository; _organizationUserRepository = organizationUserRepository; @@ -34,123 +47,107 @@ public RemoveOrganizationUserCommand( _pushRegistrationService = pushRegistrationService; _currentContext = currentContext; _hasConfirmedOwnersExceptQuery = hasConfirmedOwnersExceptQuery; + _getOrganizationUsersManagementStatusQuery = getOrganizationUsersManagementStatusQuery; + _featureService = featureService; + _timeProvider = timeProvider; } - public async Task RemoveUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId) + public async Task RemoveUserAsync(Guid organizationId, Guid userId) { - var organizationUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - ValidateDeleteUser(organizationId, organizationUser); + var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); + ValidateRemoveUser(organizationId, organizationUser); - await RepositoryDeleteUserAsync(organizationUser, deletingUserId); + await RepositoryRemoveUserAsync(organizationUser, deletingUserId: null, eventSystemUser: null); await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed); } - public async Task RemoveUserAsync(Guid organizationId, Guid organizationUserId, EventSystemUser eventSystemUser) + public async Task RemoveUserAsync(Guid organizationId, Guid organizationUserId, Guid? deletingUserId) { var organizationUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); - ValidateDeleteUser(organizationId, organizationUser); + ValidateRemoveUser(organizationId, organizationUser); - await RepositoryDeleteUserAsync(organizationUser, null); + await RepositoryRemoveUserAsync(organizationUser, deletingUserId, eventSystemUser: null); - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed, eventSystemUser); + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed); } - public async Task RemoveUserAsync(Guid organizationId, Guid userId) + public async Task RemoveUserAsync(Guid organizationId, Guid organizationUserId, EventSystemUser eventSystemUser) { - var organizationUser = await _organizationUserRepository.GetByOrganizationAsync(organizationId, userId); - ValidateDeleteUser(organizationId, organizationUser); + var organizationUser = await _organizationUserRepository.GetByIdAsync(organizationUserId); + ValidateRemoveUser(organizationId, organizationUser); - await RepositoryDeleteUserAsync(organizationUser, null); + await RepositoryRemoveUserAsync(organizationUser, deletingUserId: null, eventSystemUser); - await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed); + await _eventService.LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed, eventSystemUser); } - public async Task>> RemoveUsersAsync(Guid organizationId, - IEnumerable organizationUsersId, - Guid? deletingUserId) + public async Task> RemoveUsersAsync( + Guid organizationId, IEnumerable organizationUserIds, Guid? deletingUserId) { - var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); - var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId) - .ToList(); + var result = await RemoveUsersInternalAsync(organizationId, organizationUserIds, deletingUserId, eventSystemUser: null); - if (!filteredUsers.Any()) + var removedUsers = result.Where(r => string.IsNullOrEmpty(r.ErrorMessage)).Select(r => r.OrganizationUser).ToList(); + if (removedUsers.Any()) { - throw new BadRequestException("Users invalid."); + DateTime? eventDate = _timeProvider.GetUtcNow().UtcDateTime; + await _eventService.LogOrganizationUserEventsAsync( + removedUsers.Select(ou => (ou, EventType.OrganizationUser_Removed, eventDate))); } - if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(organizationId, organizationUsersId)) - { - throw new BadRequestException("Organization must have at least one confirmed owner."); - } + return result.Select(r => (r.OrganizationUser.Id, r.ErrorMessage)); + } - var deletingUserIsOwner = false; - if (deletingUserId.HasValue) - { - deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); - } + public async Task> RemoveUsersAsync( + Guid organizationId, IEnumerable organizationUserIds, EventSystemUser eventSystemUser) + { + var result = await RemoveUsersInternalAsync(organizationId, organizationUserIds, deletingUserId: null, eventSystemUser); - var result = new List>(); - var deletedUserIds = new List(); - foreach (var orgUser in filteredUsers) + var removedUsers = result.Where(r => string.IsNullOrEmpty(r.ErrorMessage)).Select(r => r.OrganizationUser).ToList(); + if (removedUsers.Any()) { - try - { - if (deletingUserId.HasValue && orgUser.UserId == deletingUserId) - { - throw new BadRequestException("You cannot remove yourself."); - } - - if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && !deletingUserIsOwner) - { - throw new BadRequestException("Only owners can delete other owners."); - } - - await _eventService.LogOrganizationUserEventAsync(orgUser, EventType.OrganizationUser_Removed); - - if (orgUser.UserId.HasValue) - { - await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); - } - result.Add(Tuple.Create(orgUser, "")); - deletedUserIds.Add(orgUser.Id); - } - catch (BadRequestException e) - { - result.Add(Tuple.Create(orgUser, e.Message)); - } - - await _organizationUserRepository.DeleteManyAsync(deletedUserIds); + DateTime? eventDate = _timeProvider.GetUtcNow().UtcDateTime; + await _eventService.LogOrganizationUserEventsAsync( + removedUsers.Select(ou => (ou, EventType.OrganizationUser_Removed, eventSystemUser, eventDate))); } - return result; + return result.Select(r => (r.OrganizationUser.Id, r.ErrorMessage)); } - private void ValidateDeleteUser(Guid organizationId, OrganizationUser orgUser) + private void ValidateRemoveUser(Guid organizationId, OrganizationUser orgUser) { if (orgUser == null || orgUser.OrganizationId != organizationId) { - throw new NotFoundException("User not found."); + throw new NotFoundException(UserNotFoundErrorMessage); } } - private async Task RepositoryDeleteUserAsync(OrganizationUser orgUser, Guid? deletingUserId) + private async Task RepositoryRemoveUserAsync(OrganizationUser orgUser, Guid? deletingUserId, EventSystemUser? eventSystemUser) { if (deletingUserId.HasValue && orgUser.UserId == deletingUserId.Value) { - throw new BadRequestException("You cannot remove yourself."); + throw new BadRequestException(RemoveYourselfErrorMessage); } if (orgUser.Type == OrganizationUserType.Owner) { if (deletingUserId.HasValue && !await _currentContext.OrganizationOwner(orgUser.OrganizationId)) { - throw new BadRequestException("Only owners can delete other owners."); + throw new BadRequestException(RemoveOwnerByNonOwnerErrorMessage); } if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(orgUser.OrganizationId, new[] { orgUser.Id }, includeProvider: true)) { - throw new BadRequestException("Organization must have at least one confirmed owner."); + throw new BadRequestException(RemoveLastConfirmedOwnerErrorMessage); + } + } + + if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning) && deletingUserId.HasValue && eventSystemUser == null) + { + var managementStatus = await _getOrganizationUsersManagementStatusQuery.GetUsersOrganizationManagementStatusAsync(orgUser.OrganizationId, new[] { orgUser.Id }); + if (managementStatus.TryGetValue(orgUser.Id, out var isManaged) && isManaged) + { + throw new BadRequestException(RemoveClaimedAccountErrorMessage); } } @@ -177,4 +174,70 @@ await _pushRegistrationService.DeleteUserRegistrationOrganizationAsync(devices, organizationId.ToString()); await _pushNotificationService.PushSyncOrgKeysAsync(userId); } + + private async Task> RemoveUsersInternalAsync( + Guid organizationId, IEnumerable organizationUsersId, Guid? deletingUserId, EventSystemUser? eventSystemUser) + { + var orgUsers = await _organizationUserRepository.GetManyAsync(organizationUsersId); + var filteredUsers = orgUsers.Where(u => u.OrganizationId == organizationId).ToList(); + + if (!filteredUsers.Any()) + { + throw new BadRequestException(UsersInvalidErrorMessage); + } + + if (!await _hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync(organizationId, organizationUsersId)) + { + throw new BadRequestException(RemoveLastConfirmedOwnerErrorMessage); + } + + var deletingUserIsOwner = false; + if (deletingUserId.HasValue) + { + deletingUserIsOwner = await _currentContext.OrganizationOwner(organizationId); + } + + var managementStatus = _featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning) && deletingUserId.HasValue && eventSystemUser == null + ? await _getOrganizationUsersManagementStatusQuery.GetUsersOrganizationManagementStatusAsync(organizationId, filteredUsers.Select(u => u.Id)) + : filteredUsers.ToDictionary(u => u.Id, u => false); + var result = new List<(OrganizationUser OrganizationUser, string ErrorMessage)>(); + foreach (var orgUser in filteredUsers) + { + try + { + if (deletingUserId.HasValue && orgUser.UserId == deletingUserId) + { + throw new BadRequestException(RemoveYourselfErrorMessage); + } + + if (orgUser.Type == OrganizationUserType.Owner && deletingUserId.HasValue && !deletingUserIsOwner) + { + throw new BadRequestException(RemoveOwnerByNonOwnerErrorMessage); + } + + if (managementStatus.TryGetValue(orgUser.Id, out var isManaged) && isManaged) + { + throw new BadRequestException(RemoveClaimedAccountErrorMessage); + } + + result.Add((orgUser, string.Empty)); + } + catch (BadRequestException e) + { + result.Add((orgUser, e.Message)); + } + } + + var organizationUsersToRemove = result.Where(r => string.IsNullOrEmpty(r.ErrorMessage)).Select(r => r.OrganizationUser).ToList(); + if (organizationUsersToRemove.Any()) + { + await _organizationUserRepository.DeleteManyAsync(organizationUsersToRemove.Select(ou => ou.Id)); + foreach (var orgUser in organizationUsersToRemove.Where(ou => ou.UserId.HasValue)) + { + await DeleteAndPushUserRegistrationAsync(organizationId, orgUser.UserId.Value); + } + } + + return result; + } } diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommandTests.cs index 2d10ce626b21..61371b756e8c 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationUsers/RemoveOrganizationUserCommandTests.cs @@ -9,6 +9,7 @@ using Bit.Core.Test.AutoFixture.OrganizationUserFixtures; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; +using Microsoft.Extensions.Time.Testing; using NSubstitute; using Xunit; @@ -18,38 +19,93 @@ namespace Bit.Core.Test.AdminConsole.OrganizationFeatures.OrganizationUsers; public class RemoveOrganizationUserCommandTests { [Theory, BitAutoData] - public async Task RemoveUser_Success( + public async Task RemoveUser_WithDeletingUserId_Success( [OrganizationUser(type: OrganizationUserType.User)] OrganizationUser organizationUser, [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser deletingUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); + // Arrange + organizationUser.OrganizationId = deletingUser.OrganizationId; + + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + sutProvider.GetDependency() + .GetByIdAsync(deletingUser.Id) + .Returns(deletingUser); + sutProvider.GetDependency() + .OrganizationOwner(deletingUser.OrganizationId) + .Returns(true); + + // Act + await sutProvider.Sut.RemoveUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId); + // Assert + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetUsersOrganizationManagementStatusAsync(default, default); + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(organizationUser); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed); + } + + [Theory, BitAutoData] + public async Task RemoveUser_WithDeletingUserId_WithAccountDeprovisioningEnabled_Success( + [OrganizationUser(type: OrganizationUserType.User)] OrganizationUser organizationUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser deletingUser, + SutProvider sutProvider) + { + // Arrange organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); - currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + sutProvider.GetDependency() + .GetByIdAsync(deletingUser.Id) + .Returns(deletingUser); + sutProvider.GetDependency() + .OrganizationOwner(deletingUser.OrganizationId) + .Returns(true); + + // Act await sutProvider.Sut.RemoveUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId); - await organizationUserRepository.Received(1).DeleteAsync(organizationUser); - await sutProvider.GetDependency().Received(1).LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed); + // Assert + await sutProvider.GetDependency() + .Received(1) + .GetUsersOrganizationManagementStatusAsync( + organizationUser.OrganizationId, + Arg.Is>(i => i.Contains(organizationUser.Id))); + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(organizationUser); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed); } - [Theory] - [BitAutoData] - public async Task RemoveUser_NotFound_ThrowsException(SutProvider sutProvider, + [Theory, BitAutoData] + public async Task RemoveUser_WithDeletingUserId_NotFound_ThrowsException( + SutProvider sutProvider, Guid organizationId, Guid organizationUserId) { - await Assert.ThrowsAsync(async () => await sutProvider.Sut.RemoveUserAsync(organizationId, organizationUserId, null)); + // Act & Assert + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.RemoveUserAsync(organizationId, organizationUserId, null)); } - [Theory] - [BitAutoData] - public async Task RemoveUser_MismatchingOrganizationId_ThrowsException( + [Theory, BitAutoData] + public async Task RemoveUser_WithDeletingUserId_MismatchingOrganizationId_ThrowsException( SutProvider sutProvider, Guid organizationId, Guid organizationUserId) { + // Arrange sutProvider.GetDependency() .GetByIdAsync(organizationUserId) .Returns(new OrganizationUser @@ -58,92 +114,231 @@ public async Task RemoveUser_MismatchingOrganizationId_ThrowsException( OrganizationId = Guid.NewGuid() }); - await Assert.ThrowsAsync(async () => await sutProvider.Sut.RemoveUserAsync(organizationId, organizationUserId, null)); + // Act & Assert + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.RemoveUserAsync(organizationId, organizationUserId, null)); } [Theory, BitAutoData] - public async Task RemoveUser_InvalidUser_ThrowsException( - OrganizationUser organizationUser, OrganizationUser deletingUser, - SutProvider sutProvider) + public async Task RemoveUser_WithDeletingUserId_InvalidUser_ThrowsException( + OrganizationUser organizationUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + // Act & Assert var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveUserAsync(Guid.NewGuid(), organizationUser.Id, deletingUser.UserId)); - Assert.Contains("User not found.", exception.Message); + () => sutProvider.Sut.RemoveUserAsync(Guid.NewGuid(), organizationUser.Id, null)); + Assert.Contains(RemoveOrganizationUserCommand.UserNotFoundErrorMessage, exception.Message); } [Theory, BitAutoData] - public async Task RemoveUser_RemoveYourself_ThrowsException(OrganizationUser deletingUser, SutProvider sutProvider) + public async Task RemoveUser_WithDeletingUserId_RemoveYourself_ThrowsException( + OrganizationUser deletingUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(deletingUser.Id) + .Returns(deletingUser); + // Act & Assert var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RemoveUserAsync(deletingUser.OrganizationId, deletingUser.Id, deletingUser.UserId)); - Assert.Contains("You cannot remove yourself.", exception.Message); + Assert.Contains(RemoveOrganizationUserCommand.RemoveYourselfErrorMessage, exception.Message); } [Theory, BitAutoData] - public async Task RemoveUser_NonOwnerRemoveOwner_ThrowsException( + public async Task RemoveUser_WithDeletingUserId_NonOwnerRemoveOwner_ThrowsException( [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); - + // Arrange organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - currentContext.OrganizationAdmin(deletingUser.OrganizationId).Returns(true); + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + sutProvider.GetDependency() + .OrganizationAdmin(organizationUser.OrganizationId) + .Returns(true); + + // Act & Assert var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveUserAsync(deletingUser.OrganizationId, organizationUser.Id, deletingUser.UserId)); - Assert.Contains("Only owners can delete other owners.", exception.Message); + () => sutProvider.Sut.RemoveUserAsync(organizationUser.OrganizationId, organizationUser.Id, deletingUser.UserId)); + Assert.Contains(RemoveOrganizationUserCommand.RemoveOwnerByNonOwnerErrorMessage, exception.Message); } [Theory, BitAutoData] - public async Task RemoveUser_RemovingLastOwner_ThrowsException( + public async Task RemoveUser_WithDeletingUserId_RemovingLastOwner_ThrowsException( [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, OrganizationUser deletingUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); - var hasConfirmedOwnersExceptQuery = sutProvider.GetDependency(); - + // Arrange organizationUser.OrganizationId = deletingUser.OrganizationId; - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); - hasConfirmedOwnersExceptQuery.HasConfirmedOwnersExceptAsync( - deletingUser.OrganizationId, - Arg.Is>(i => i.Contains(organizationUser.Id)), Arg.Any()) + + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync( + organizationUser.OrganizationId, + Arg.Is>(i => i.Contains(organizationUser.Id)), + Arg.Any()) .Returns(false); + sutProvider.GetDependency() + .OrganizationOwner(deletingUser.OrganizationId) + .Returns(true); + // Act & Assert var exception = await Assert.ThrowsAsync( - () => sutProvider.Sut.RemoveUserAsync(deletingUser.OrganizationId, organizationUser.Id, null)); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); - hasConfirmedOwnersExceptQuery + () => sutProvider.Sut.RemoveUserAsync(organizationUser.OrganizationId, organizationUser.Id, deletingUser.UserId)); + Assert.Contains(RemoveOrganizationUserCommand.RemoveLastConfirmedOwnerErrorMessage, exception.Message); + await sutProvider.GetDependency() .Received(1) .HasConfirmedOwnersExceptAsync( organizationUser.OrganizationId, Arg.Is>(i => i.Contains(organizationUser.Id)), true); } + [Theory, BitAutoData] + public async Task RemoveUserAsync_WithDeletingUserId_WithAccountDeprovisioningEnabled_WhenUserIsManaged_ThrowsException( + [OrganizationUser(status: OrganizationUserStatusType.Confirmed)] OrganizationUser orgUser, + Guid deletingUserId, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + sutProvider.GetDependency() + .GetByIdAsync(orgUser.Id) + .Returns(orgUser); + sutProvider.GetDependency() + .GetUsersOrganizationManagementStatusAsync(orgUser.OrganizationId, Arg.Is>(i => i.Contains(orgUser.Id))) + .Returns(new Dictionary { { orgUser.Id, true } }); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveUserAsync(orgUser.OrganizationId, orgUser.Id, deletingUserId)); + Assert.Contains(RemoveOrganizationUserCommand.RemoveClaimedAccountErrorMessage, exception.Message); + await sutProvider.GetDependency() + .Received(1) + .GetUsersOrganizationManagementStatusAsync(orgUser.OrganizationId, Arg.Is>(i => i.Contains(orgUser.Id))); + } + [Theory, BitAutoData] public async Task RemoveUser_WithEventSystemUser_Success( [OrganizationUser(type: OrganizationUserType.User)] OrganizationUser organizationUser, - EventSystemUser eventSystemUser, - SutProvider sutProvider) + EventSystemUser eventSystemUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + + // Act + await sutProvider.Sut.RemoveUserAsync(organizationUser.OrganizationId, organizationUser.Id, eventSystemUser); - organizationUserRepository.GetByIdAsync(organizationUser.Id).Returns(organizationUser); + // Assert + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetUsersOrganizationManagementStatusAsync(default, default); + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(organizationUser); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed, eventSystemUser); + } + + [Theory, BitAutoData] + public async Task RemoveUser_WithEventSystemUser_WithAccountDeprovisioningEnabled_Success( + [OrganizationUser(type: OrganizationUserType.User)] OrganizationUser organizationUser, + EventSystemUser eventSystemUser, SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + // Act await sutProvider.Sut.RemoveUserAsync(organizationUser.OrganizationId, organizationUser.Id, eventSystemUser); - await sutProvider.GetDependency().Received(1).LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed, eventSystemUser); + // Assert + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetUsersOrganizationManagementStatusAsync(default, default); + await sutProvider.GetDependency() + .Received(1) + .DeleteAsync(organizationUser); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventAsync(organizationUser, EventType.OrganizationUser_Removed, eventSystemUser); + } + + [Theory] + [BitAutoData] + public async Task RemoveUser_WithEventSystemUser_NotFound_ThrowsException( + SutProvider sutProvider, + Guid organizationId, Guid organizationUserId, EventSystemUser eventSystemUser) + { + // Act & Assert + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.RemoveUserAsync(organizationId, organizationUserId, eventSystemUser)); + } + + [Theory] + [BitAutoData] + public async Task RemoveUser_WithEventSystemUser_MismatchingOrganizationId_ThrowsException( + SutProvider sutProvider, Guid organizationId, Guid organizationUserId, EventSystemUser eventSystemUser) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(organizationUserId) + .Returns(new OrganizationUser + { + Id = organizationUserId, + OrganizationId = Guid.NewGuid() + }); + + // Act & Assert + await Assert.ThrowsAsync(async () => + await sutProvider.Sut.RemoveUserAsync(organizationId, organizationUserId, eventSystemUser)); + } + + [Theory, BitAutoData] + public async Task RemoveUser_WithEventSystemUser_RemovingLastOwner_ThrowsException( + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, + EventSystemUser eventSystemUser, + SutProvider sutProvider) + { + // Arrange + sutProvider.GetDependency() + .GetByIdAsync(organizationUser.Id) + .Returns(organizationUser); + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync( + organizationUser.OrganizationId, + Arg.Is>(i => i.Contains(organizationUser.Id)), + Arg.Any()) + .Returns(false); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveUserAsync(organizationUser.OrganizationId, organizationUser.Id, eventSystemUser)); + Assert.Contains(RemoveOrganizationUserCommand.RemoveLastConfirmedOwnerErrorMessage, exception.Message); + await sutProvider.GetDependency() + .Received(1) + .HasConfirmedOwnersExceptAsync( + organizationUser.OrganizationId, + Arg.Is>(i => i.Contains(organizationUser.Id)), true); } [Theory, BitAutoData] @@ -170,23 +365,26 @@ public async Task RemoveUser_ByUserId_Success( } [Theory, BitAutoData] - public async Task RemoveUser_ByUserId_NotFound_ThrowsException(SutProvider sutProvider, - Guid organizationId, Guid userId) + public async Task RemoveUser_ByUserId_NotFound_ThrowsException( + SutProvider sutProvider, Guid organizationId, Guid userId) { + // Act & Assert await Assert.ThrowsAsync(async () => await sutProvider.Sut.RemoveUserAsync(organizationId, userId)); } [Theory, BitAutoData] - public async Task RemoveUser_ByUserId_InvalidUser_ThrowsException(OrganizationUser organizationUser, - SutProvider sutProvider) + public async Task RemoveUser_ByUserId_InvalidUser_ThrowsException( + OrganizationUser organizationUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); - - organizationUserRepository.GetByOrganizationAsync(organizationUser.OrganizationId, organizationUser.UserId!.Value).Returns(organizationUser); + // Arrange + sutProvider.GetDependency() + .GetByOrganizationAsync(organizationUser.OrganizationId, organizationUser.UserId!.Value) + .Returns(organizationUser); + // Act & Assert var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RemoveUserAsync(Guid.NewGuid(), organizationUser.UserId.Value)); - Assert.Contains("User not found.", exception.Message); + Assert.Contains(RemoveOrganizationUserCommand.UserNotFoundErrorMessage, exception.Message); } [Theory, BitAutoData] @@ -194,21 +392,22 @@ public async Task RemoveUser_ByUserId_RemovingLastOwner_ThrowsException( [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser organizationUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); - var hasConfirmedOwnersExceptQuery = sutProvider.GetDependency(); - - organizationUserRepository.GetByOrganizationAsync(organizationUser.OrganizationId, organizationUser.UserId!.Value).Returns(organizationUser); - hasConfirmedOwnersExceptQuery + // Arrange + sutProvider.GetDependency() + .GetByOrganizationAsync(organizationUser.OrganizationId, organizationUser.UserId!.Value) + .Returns(organizationUser); + sutProvider.GetDependency() .HasConfirmedOwnersExceptAsync( organizationUser.OrganizationId, Arg.Is>(i => i.Contains(organizationUser.Id)), Arg.Any()) .Returns(false); + // Act & Assert var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RemoveUserAsync(organizationUser.OrganizationId, organizationUser.UserId.Value)); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); - hasConfirmedOwnersExceptQuery + Assert.Contains(RemoveOrganizationUserCommand.RemoveLastConfirmedOwnerErrorMessage, exception.Message); + await sutProvider.GetDependency() .Received(1) .HasConfirmedOwnersExceptAsync( organizationUser.OrganizationId, @@ -217,93 +416,371 @@ public async Task RemoveUser_ByUserId_RemovingLastOwner_ThrowsException( } [Theory, BitAutoData] - public async Task RemoveUsers_FilterInvalid_ThrowsException(OrganizationUser organizationUser, OrganizationUser deletingUser, - SutProvider sutProvider) + public async Task RemoveUsers_WithDeletingUserId_Success( + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, OrganizationUser orgUser2) { - var organizationUserRepository = sutProvider.GetDependency(); + // Arrange + var sutProvider = SutProviderFactory(); + var eventDate = sutProvider.GetDependency().GetUtcNow().UtcDateTime; + orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; + var organizationUsers = new[] { orgUser1, orgUser2 }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); + sutProvider.GetDependency() + .GetByIdAsync(deletingUser.Id) + .Returns(deletingUser); + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(deletingUser.OrganizationId, Arg.Any>()) + .Returns(true); + sutProvider.GetDependency() + .OrganizationOwner(deletingUser.OrganizationId) + .Returns(true); + sutProvider.GetDependency() + .GetUsersOrganizationManagementStatusAsync( + deletingUser.OrganizationId, + Arg.Is>(i => i.Contains(orgUser1.Id) && i.Contains(orgUser2.Id))) + .Returns(new Dictionary { { orgUser1.Id, false }, { orgUser2.Id, false } }); + + // Act + var result = await sutProvider.Sut.RemoveUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + + // Assert + Assert.Equal(2, result.Count()); + Assert.All(result, r => Assert.Empty(r.ErrorMessage)); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetUsersOrganizationManagementStatusAsync(default, default); + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Is>(i => i.Contains(orgUser1.Id) && i.Contains(orgUser2.Id))); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync( + Arg.Is>(i => + i.First().OrganizationUser.Id == orgUser1.Id + && i.Last().OrganizationUser.Id == orgUser2.Id + && i.All(u => u.DateTime == eventDate))); + } + + [Theory, BitAutoData] + public async Task RemoveUsers_WithDeletingUserId_WithAccountDeprovisioningEnabled_Success( + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, OrganizationUser orgUser2) + { + // Arrange + var sutProvider = SutProviderFactory(); + var eventDate = sutProvider.GetDependency().GetUtcNow().UtcDateTime; + orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; + var organizationUsers = new[] { orgUser1, orgUser2 }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); + sutProvider.GetDependency() + .GetByIdAsync(deletingUser.Id) + .Returns(deletingUser); + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(deletingUser.OrganizationId, Arg.Any>()) + .Returns(true); + sutProvider.GetDependency() + .OrganizationOwner(deletingUser.OrganizationId) + .Returns(true); + sutProvider.GetDependency() + .GetUsersOrganizationManagementStatusAsync( + deletingUser.OrganizationId, + Arg.Is>(i => i.Contains(orgUser1.Id) && i.Contains(orgUser2.Id))) + .Returns(new Dictionary { { orgUser1.Id, false }, { orgUser2.Id, false } }); + + // Act + var result = await sutProvider.Sut.RemoveUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + + // Assert + Assert.Equal(2, result.Count()); + Assert.All(result, r => Assert.Empty(r.ErrorMessage)); + await sutProvider.GetDependency() + .Received(1) + .GetUsersOrganizationManagementStatusAsync( + deletingUser.OrganizationId, + Arg.Is>(i => i.Contains(orgUser1.Id) && i.Contains(orgUser2.Id))); + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Is>(i => i.Contains(orgUser1.Id) && i.Contains(orgUser2.Id))); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync( + Arg.Is>(i => + i.First().OrganizationUser.Id == orgUser1.Id + && i.Last().OrganizationUser.Id == orgUser2.Id + && i.All(u => u.DateTime == eventDate))); + } + + [Theory, BitAutoData] + public async Task RemoveUsers_WithDeletingUserId_WithMismatchingOrganizationId_ThrowsException(OrganizationUser organizationUser, + OrganizationUser deletingUser, SutProvider sutProvider) + { + // Arrange var organizationUsers = new[] { organizationUser }; var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); + // Act & Assert var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RemoveUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId)); - Assert.Contains("Users invalid.", exception.Message); + Assert.Contains(RemoveOrganizationUserCommand.UsersInvalidErrorMessage, exception.Message); } [Theory, BitAutoData] - public async Task RemoveUsers_RemoveYourself_ThrowsException( - OrganizationUser deletingUser, - SutProvider sutProvider) + public async Task RemoveUsers_WithDeletingUserId_RemoveYourself_ThrowsException( + OrganizationUser deletingUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); + // Arrange var organizationUsers = new[] { deletingUser }; var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); sutProvider.GetDependency() .HasConfirmedOwnersExceptAsync(deletingUser.OrganizationId, Arg.Any>()) .Returns(true); + // Act var result = await sutProvider.Sut.RemoveUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); - Assert.Contains("You cannot remove yourself.", result[0].Item2); + + // Assert + Assert.Contains(RemoveOrganizationUserCommand.RemoveYourselfErrorMessage, result.First().ErrorMessage); } [Theory, BitAutoData] - public async Task RemoveUsers_NonOwnerRemoveOwner_ThrowsException( - [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, + public async Task RemoveUsers_WithDeletingUserId_NonOwnerRemoveOwner_ThrowsException( [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, [OrganizationUser(OrganizationUserStatusType.Confirmed)] OrganizationUser orgUser2, + [OrganizationUser(type: OrganizationUserType.Admin)] OrganizationUser deletingUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); - + // Arrange orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; - var organizationUsers = new[] { orgUser1 }; + var organizationUsers = new[] { orgUser1, orgUser2 }; var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); + + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); sutProvider.GetDependency() .HasConfirmedOwnersExceptAsync(deletingUser.OrganizationId, Arg.Any>()) .Returns(true); + // Act var result = await sutProvider.Sut.RemoveUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); - Assert.Contains("Only owners can delete other owners.", result[0].Item2); + + // Assert + Assert.Contains(RemoveOrganizationUserCommand.RemoveOwnerByNonOwnerErrorMessage, result.First().ErrorMessage); } [Theory, BitAutoData] - public async Task RemoveUsers_LastOwner_ThrowsException( - [OrganizationUser(status: OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, + public async Task RemoveUsers_WithDeletingUserId_RemovingManagedUser_WithAccountDeprovisioningEnabled_ThrowsException( + [OrganizationUser(status: OrganizationUserStatusType.Confirmed, OrganizationUserType.User)] OrganizationUser orgUser, + OrganizationUser deletingUser, SutProvider sutProvider) { - var organizationUserRepository = sutProvider.GetDependency(); + // Arrange + orgUser.OrganizationId = deletingUser.OrganizationId; + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + + sutProvider.GetDependency() + .GetManyAsync(Arg.Is>(i => i.Contains(orgUser.Id))) + .Returns(new[] { orgUser }); + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(orgUser.OrganizationId, Arg.Any>()) + .Returns(true); + + sutProvider.GetDependency() + .GetUsersOrganizationManagementStatusAsync(orgUser.OrganizationId, Arg.Is>(i => i.Contains(orgUser.Id))) + .Returns(new Dictionary { { orgUser.Id, true } }); + + // Act + var result = await sutProvider.Sut.RemoveUsersAsync(orgUser.OrganizationId, new[] { orgUser.Id }, deletingUser.UserId); + + // Assert + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .DeleteManyAsync(default); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .LogOrganizationUserEventsAsync(Arg.Any>()); + Assert.Contains(RemoveOrganizationUserCommand.RemoveClaimedAccountErrorMessage, result.First().ErrorMessage); + } + + [Theory, BitAutoData] + public async Task RemoveUsers_WithDeletingUserId_LastOwner_ThrowsException( + [OrganizationUser(status: OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, + SutProvider sutProvider) + { + // Arrange var organizationUsers = new[] { orgUser }; var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetManyByOrganizationAsync(orgUser.OrganizationId, OrganizationUserType.Owner).Returns(organizationUsers); + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(orgUser.OrganizationId, OrganizationUserType.Owner) + .Returns(organizationUsers); + + // Act & Assert var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RemoveUsersAsync(orgUser.OrganizationId, organizationUserIds, null)); - Assert.Contains("Organization must have at least one confirmed owner.", exception.Message); + Assert.Contains(RemoveOrganizationUserCommand.RemoveLastConfirmedOwnerErrorMessage, exception.Message); } [Theory, BitAutoData] - public async Task RemoveUsers_Success( - [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser deletingUser, - [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, OrganizationUser orgUser2, - SutProvider sutProvider) + public async Task RemoveUsers_WithEventSystemUser_Success( + EventSystemUser eventSystemUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, + OrganizationUser orgUser2) { - var organizationUserRepository = sutProvider.GetDependency(); - var currentContext = sutProvider.GetDependency(); + // Arrange + var sutProvider = SutProviderFactory(); + var eventDate = sutProvider.GetDependency().GetUtcNow().UtcDateTime; + orgUser1.OrganizationId = orgUser2.OrganizationId; + var organizationUsers = new[] { orgUser1, orgUser2 }; + var organizationUserIds = organizationUsers.Select(u => u.Id); - orgUser1.OrganizationId = orgUser2.OrganizationId = deletingUser.OrganizationId; + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); + sutProvider.GetDependency() + .HasConfirmedOwnersExceptAsync(orgUser1.OrganizationId, Arg.Any>()) + .Returns(true); + + // Act + var result = await sutProvider.Sut.RemoveUsersAsync(orgUser1.OrganizationId, organizationUserIds, eventSystemUser); + + // Assert + Assert.Equal(2, result.Count()); + Assert.All(result, r => Assert.Empty(r.ErrorMessage)); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetUsersOrganizationManagementStatusAsync(default, default); + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Is>(i => i.Contains(orgUser1.Id) && i.Contains(orgUser2.Id))); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync( + Arg.Is>( + i => i.First().OrganizationUser.Id == orgUser1.Id + && i.Last().OrganizationUser.Id == orgUser2.Id + && i.All(u => u.EventSystemUser == eventSystemUser + && u.DateTime == eventDate))); + } + + [Theory, BitAutoData] + public async Task RemoveUsers_WithEventSystemUser_WithAccountDeprovisioningEnabled_Success( + EventSystemUser eventSystemUser, + [OrganizationUser(type: OrganizationUserType.Owner)] OrganizationUser orgUser1, + OrganizationUser orgUser2) + { + // Arrange + var sutProvider = SutProviderFactory(); + var eventDate = sutProvider.GetDependency().GetUtcNow().UtcDateTime; + orgUser1.OrganizationId = orgUser2.OrganizationId; var organizationUsers = new[] { orgUser1, orgUser2 }; var organizationUserIds = organizationUsers.Select(u => u.Id); - organizationUserRepository.GetManyAsync(default).ReturnsForAnyArgs(organizationUsers); - organizationUserRepository.GetByIdAsync(deletingUser.Id).Returns(deletingUser); + + sutProvider.GetDependency() + .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) + .Returns(true); + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); sutProvider.GetDependency() - .HasConfirmedOwnersExceptAsync(deletingUser.OrganizationId, Arg.Any>()) + .HasConfirmedOwnersExceptAsync(orgUser1.OrganizationId, Arg.Any>()) .Returns(true); - currentContext.OrganizationOwner(deletingUser.OrganizationId).Returns(true); - await sutProvider.Sut.RemoveUsersAsync(deletingUser.OrganizationId, organizationUserIds, deletingUser.UserId); + // Act + var result = await sutProvider.Sut.RemoveUsersAsync(orgUser1.OrganizationId, organizationUserIds, eventSystemUser); + + // Assert + Assert.Equal(2, result.Count()); + Assert.All(result, r => Assert.Empty(r.ErrorMessage)); + await sutProvider.GetDependency() + .DidNotReceiveWithAnyArgs() + .GetUsersOrganizationManagementStatusAsync(default, default); + await sutProvider.GetDependency() + .Received(1) + .DeleteManyAsync(Arg.Is>(i => i.Contains(orgUser1.Id) && i.Contains(orgUser2.Id))); + await sutProvider.GetDependency() + .Received(1) + .LogOrganizationUserEventsAsync( + Arg.Is>( + i => i.First().OrganizationUser.Id == orgUser1.Id + && i.Last().OrganizationUser.Id == orgUser2.Id + && i.All(u => u.EventSystemUser == eventSystemUser + && u.DateTime == eventDate))); + } + + [Theory, BitAutoData] + public async Task RemoveUsers_WithEventSystemUser_WithMismatchingOrganizationId_ThrowsException( + EventSystemUser eventSystemUser, + [OrganizationUser(type: OrganizationUserType.User)] OrganizationUser organizationUser, + SutProvider sutProvider) + { + // Arrange + var organizationUsers = new[] { organizationUser }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveUsersAsync(Guid.NewGuid(), organizationUserIds, eventSystemUser)); + Assert.Contains(RemoveOrganizationUserCommand.UsersInvalidErrorMessage, exception.Message); + } + + [Theory, BitAutoData] + public async Task RemoveUsers_WithEventSystemUser_LastOwner_ThrowsException( + [OrganizationUser(status: OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser orgUser, + EventSystemUser eventSystemUser, SutProvider sutProvider) + { + // Arrange + var organizationUsers = new[] { orgUser }; + var organizationUserIds = organizationUsers.Select(u => u.Id); + + sutProvider.GetDependency() + .GetManyAsync(default) + .ReturnsForAnyArgs(organizationUsers); + sutProvider.GetDependency() + .GetManyByOrganizationAsync(orgUser.OrganizationId, OrganizationUserType.Owner) + .Returns(organizationUsers); + + // Act & Assert + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RemoveUsersAsync(orgUser.OrganizationId, organizationUserIds, eventSystemUser)); + Assert.Contains(RemoveOrganizationUserCommand.RemoveLastConfirmedOwnerErrorMessage, exception.Message); + } + + /// + /// Returns a new SutProvider with a FakeTimeProvider registered in the Sut. + /// + private static SutProvider SutProviderFactory() + { + return new SutProvider() + .WithFakeTimeProvider() + .Create(); } } From c8930d44f26a1003cae1f803215ea61653ab2347 Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Wed, 27 Nov 2024 06:47:18 -0600 Subject: [PATCH 12/28] Swapping [] for Array.Empty (#5092) --- src/Core/Models/Commands/CommandResult.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Core/Models/Commands/CommandResult.cs b/src/Core/Models/Commands/CommandResult.cs index 4ac2d6249918..9e5d91e09cbb 100644 --- a/src/Core/Models/Commands/CommandResult.cs +++ b/src/Core/Models/Commands/CommandResult.cs @@ -8,5 +8,5 @@ public CommandResult(string error) : this([error]) { } public bool HasErrors => ErrorMessages.Count > 0; public List ErrorMessages { get; } = errors.ToList(); - public CommandResult() : this([]) { } + public CommandResult() : this(Array.Empty()) { } } From aa364cacef9cd0e3bc3746497d0ed6593143589b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:49:20 +0000 Subject: [PATCH 13/28] [PM-14876] Update admin panel copy from 'Domain Verified' to 'Claimed Account' and rename associated ViewModel properties (#5058) --- src/Admin/Models/UserEditModel.cs | 10 +++------- src/Admin/Models/UserViewModel.cs | 14 +++++++------- src/Admin/Views/Users/_ViewInformation.cshtml | 7 ++++--- test/Admin.Test/Models/UserViewModelTests.cs | 14 ++++---------- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/src/Admin/Models/UserEditModel.cs b/src/Admin/Models/UserEditModel.cs index 2ad0b27cbdd3..ed2d653246fc 100644 --- a/src/Admin/Models/UserEditModel.cs +++ b/src/Admin/Models/UserEditModel.cs @@ -9,10 +9,7 @@ namespace Bit.Admin.Models; public class UserEditModel { - public UserEditModel() - { - - } + public UserEditModel() { } public UserEditModel( User user, @@ -21,10 +18,9 @@ public UserEditModel( BillingInfo billingInfo, BillingHistoryInfo billingHistoryInfo, GlobalSettings globalSettings, - bool? domainVerified - ) + bool? claimedAccount) { - User = UserViewModel.MapViewModel(user, isTwoFactorEnabled, ciphers, domainVerified); + User = UserViewModel.MapViewModel(user, isTwoFactorEnabled, ciphers, claimedAccount); BillingInfo = billingInfo; BillingHistoryInfo = billingHistoryInfo; diff --git a/src/Admin/Models/UserViewModel.cs b/src/Admin/Models/UserViewModel.cs index 75c089ee5f61..7fddbc0f547f 100644 --- a/src/Admin/Models/UserViewModel.cs +++ b/src/Admin/Models/UserViewModel.cs @@ -14,7 +14,7 @@ public class UserViewModel public bool Premium { get; } public short? MaxStorageGb { get; } public bool EmailVerified { get; } - public bool? DomainVerified { get; } + public bool? ClaimedAccount { get; } public bool TwoFactorEnabled { get; } public DateTime AccountRevisionDate { get; } public DateTime RevisionDate { get; } @@ -36,7 +36,7 @@ public UserViewModel(Guid id, bool premium, short? maxStorageGb, bool emailVerified, - bool? domainVerified, + bool? claimedAccount, bool twoFactorEnabled, DateTime accountRevisionDate, DateTime revisionDate, @@ -58,7 +58,7 @@ public UserViewModel(Guid id, Premium = premium; MaxStorageGb = maxStorageGb; EmailVerified = emailVerified; - DomainVerified = domainVerified; + ClaimedAccount = claimedAccount; TwoFactorEnabled = twoFactorEnabled; AccountRevisionDate = accountRevisionDate; RevisionDate = revisionDate; @@ -79,7 +79,7 @@ public static IEnumerable MapViewModels( users.Select(user => MapViewModel(user, lookup, false)); public static UserViewModel MapViewModel(User user, - IEnumerable<(Guid userId, bool twoFactorIsEnabled)> lookup, bool? domainVerified) => + IEnumerable<(Guid userId, bool twoFactorIsEnabled)> lookup, bool? claimedAccount) => new( user.Id, user.Name, @@ -89,7 +89,7 @@ public static UserViewModel MapViewModel(User user, user.Premium, user.MaxStorageGb, user.EmailVerified, - domainVerified, + claimedAccount, IsTwoFactorEnabled(user, lookup), user.AccountRevisionDate, user.RevisionDate, @@ -106,7 +106,7 @@ public static UserViewModel MapViewModel(User user, public static UserViewModel MapViewModel(User user, bool isTwoFactorEnabled) => MapViewModel(user, isTwoFactorEnabled, Array.Empty(), false); - public static UserViewModel MapViewModel(User user, bool isTwoFactorEnabled, IEnumerable ciphers, bool? domainVerified) => + public static UserViewModel MapViewModel(User user, bool isTwoFactorEnabled, IEnumerable ciphers, bool? claimedAccount) => new( user.Id, user.Name, @@ -116,7 +116,7 @@ public static UserViewModel MapViewModel(User user, bool isTwoFactorEnabled, IEn user.Premium, user.MaxStorageGb, user.EmailVerified, - domainVerified, + claimedAccount, isTwoFactorEnabled, user.AccountRevisionDate, user.RevisionDate, diff --git a/src/Admin/Views/Users/_ViewInformation.cshtml b/src/Admin/Views/Users/_ViewInformation.cshtml index 00afcc19dfee..ae8bcb073731 100644 --- a/src/Admin/Views/Users/_ViewInformation.cshtml +++ b/src/Admin/Views/Users/_ViewInformation.cshtml @@ -12,9 +12,10 @@
Email Verified
@(Model.EmailVerified ? "Yes" : "No")
- @if(Model.DomainVerified.HasValue){ -
Domain Verified
-
@(Model.DomainVerified.Value == true ? "Yes" : "No")
+ @if(Model.ClaimedAccount.HasValue) + { +
Claimed Account
+
@(Model.ClaimedAccount.Value ? "Yes" : "No")
}
Using 2FA
diff --git a/test/Admin.Test/Models/UserViewModelTests.cs b/test/Admin.Test/Models/UserViewModelTests.cs index fac5d5f0eb10..d015b983286e 100644 --- a/test/Admin.Test/Models/UserViewModelTests.cs +++ b/test/Admin.Test/Models/UserViewModelTests.cs @@ -1,6 +1,4 @@ -#nullable enable - -using Bit.Admin.Models; +using Bit.Admin.Models; using Bit.Core.Entities; using Bit.Core.Vault.Entities; using Bit.Test.Common.AutoFixture.Attributes; @@ -116,30 +114,26 @@ public void MapUserViewModel_WithVerifiedDomain_ReturnsUserViewModel(User user) var actual = UserViewModel.MapViewModel(user, true, Array.Empty(), verifiedDomain); - Assert.True(actual.DomainVerified); + Assert.True(actual.ClaimedAccount); } [Theory] [BitAutoData] public void MapUserViewModel_WithoutVerifiedDomain_ReturnsUserViewModel(User user) { - var verifiedDomain = false; var actual = UserViewModel.MapViewModel(user, true, Array.Empty(), verifiedDomain); - Assert.False(actual.DomainVerified); + Assert.False(actual.ClaimedAccount); } [Theory] [BitAutoData] public void MapUserViewModel_WithNullVerifiedDomain_ReturnsUserViewModel(User user) { - var actual = UserViewModel.MapViewModel(user, true, Array.Empty(), null); - Assert.Null(actual.DomainVerified); + Assert.Null(actual.ClaimedAccount); } - - } From e9297f85e9593abafb870f833dfcfad4c67e21ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rui=20Tom=C3=A9?= <108268980+r-tome@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:55:05 +0000 Subject: [PATCH 14/28] [PM-12684] Remove deprecated feature flag for Members TwoFA query optimization (#5076) --- src/Core/Constants.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 8fe388653923..8077494794c3 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -129,7 +129,6 @@ public static class FeatureFlagKeys public const string UnauthenticatedExtensionUIRefresh = "unauth-ui-refresh"; public const string GenerateIdentityFillScriptRefactor = "generate-identity-fill-script-refactor"; public const string DelayFido2PageScriptInitWithinMv2 = "delay-fido2-page-script-init-within-mv2"; - public const string MembersTwoFAQueryOptimization = "ac-1698-members-two-fa-query-optimization"; public const string NativeCarouselFlow = "native-carousel-flow"; public const string NativeCreateAccountFlow = "native-create-account-flow"; public const string AccountDeprovisioning = "pm-10308-account-deprovisioning"; From c703390ba2944621f541e4bd737f27a3fe0136be Mon Sep 17 00:00:00 2001 From: Andreas Coroiu Date: Thu, 28 Nov 2024 09:49:09 +0100 Subject: [PATCH 15/28] feat: add credential sync feature flag (#5052) --- src/Core/Constants.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 8077494794c3..561f5f27c86c 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -155,6 +155,7 @@ public static class FeatureFlagKeys public const string SecurityTasks = "security-tasks"; public const string PM14401_ScaleMSPOnClientOrganizationUpdate = "PM-14401-scale-msp-on-client-organization-update"; public const string DisableFreeFamiliesSponsorship = "PM-12274-disable-free-families-sponsorship"; + public const string MacOsNativeCredentialSync = "macos-native-credential-sync"; public static List GetAllKeys() { From 193f8d661216042b2ce423c6d3e5e0a9a1a96eac Mon Sep 17 00:00:00 2001 From: Addison Beck Date: Mon, 2 Dec 2024 06:42:17 -0500 Subject: [PATCH 16/28] Update version to 2024.12.0 (#5099) --- Directory.Build.props | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Directory.Build.props b/Directory.Build.props index 4e252c82edeb..8639ac4a0d9a 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -3,7 +3,7 @@ net8.0 - 2024.11.0 + 2024.12.0 Bit.$(MSBuildProjectName) enable @@ -64,4 +64,4 @@ - \ No newline at end of file + From ac42b81f7c426db11a8a7fb4632a1fc295cf830c Mon Sep 17 00:00:00 2001 From: Jimmy Vo Date: Mon, 2 Dec 2024 10:19:21 -0500 Subject: [PATCH 17/28] [PM-14862] Update documentation response type. (#5083) Update documentation to align with the code's response type. --- .../AdminConsole/Public/Controllers/OrganizationController.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Api/AdminConsole/Public/Controllers/OrganizationController.cs b/src/Api/AdminConsole/Public/Controllers/OrganizationController.cs index 9d0b902f89b1..c1715f471c1a 100644 --- a/src/Api/AdminConsole/Public/Controllers/OrganizationController.cs +++ b/src/Api/AdminConsole/Public/Controllers/OrganizationController.cs @@ -1,6 +1,5 @@ using System.Net; using Bit.Api.AdminConsole.Public.Models.Request; -using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; using Bit.Core.Context; using Bit.Core.Enums; @@ -38,7 +37,7 @@ public OrganizationController( /// /// The request model. [HttpPost("import")] - [ProducesResponseType(typeof(MemberResponseModel), (int)HttpStatusCode.OK)] + [ProducesResponseType(typeof(OkResult), (int)HttpStatusCode.OK)] [ProducesResponseType(typeof(ErrorResponseModel), (int)HttpStatusCode.BadRequest)] public async Task Import([FromBody] OrganizationImportRequestModel model) { From 6a77a6d8eebace51a8057b3f643135541cfbf589 Mon Sep 17 00:00:00 2001 From: Brandon Treston Date: Tue, 3 Dec 2024 09:58:46 -0500 Subject: [PATCH 18/28] [PM-14552] Update error messages copy (#5059) * update error messages * fix tests --- .../Implementations/OrganizationService.cs | 32 ++++++-- .../Services/OrganizationServiceTests.cs | 78 ++++++++++++++++--- 2 files changed, 94 insertions(+), 16 deletions(-) diff --git a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs index 47c79aa13e8a..e8756fb325d2 100644 --- a/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs +++ b/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs @@ -2332,10 +2332,13 @@ private async Task CheckPoliciesBeforeRestoreAsync(OrganizationUser orgUser, boo PolicyType.SingleOrg, OrganizationUserStatusType.Revoked); var singleOrgPolicyApplies = singleOrgPoliciesApplyingToRevokedUsers.Any(p => p.OrganizationId == orgUser.OrganizationId); + var singleOrgCompliant = true; + var belongsToOtherOrgCompliant = true; + var twoFactorCompliant = true; + if (hasOtherOrgs && singleOrgPolicyApplies) { - throw new BadRequestException("You cannot restore this user until " + - "they leave or remove all other organizations."); + singleOrgCompliant = false; } // Enforce Single Organization Policy of other organizations user is a member of @@ -2343,8 +2346,7 @@ private async Task CheckPoliciesBeforeRestoreAsync(OrganizationUser orgUser, boo PolicyType.SingleOrg); if (anySingleOrgPolicies) { - throw new BadRequestException("You cannot restore this user because they are a member of " + - "another organization which forbids it"); + belongsToOtherOrgCompliant = false; } // Enforce Two Factor Authentication Policy of organization user is trying to join @@ -2354,10 +2356,28 @@ private async Task CheckPoliciesBeforeRestoreAsync(OrganizationUser orgUser, boo PolicyType.TwoFactorAuthentication, OrganizationUserStatusType.Invited); if (invitedTwoFactorPolicies.Any(p => p.OrganizationId == orgUser.OrganizationId)) { - throw new BadRequestException("You cannot restore this user until they enable " + - "two-step login on their user account."); + twoFactorCompliant = false; } } + + var user = await _userRepository.GetByIdAsync(userId); + + if (!singleOrgCompliant && !twoFactorCompliant) + { + throw new BadRequestException(user.Email + " is not compliant with the single organization and two-step login polciy"); + } + else if (!singleOrgCompliant) + { + throw new BadRequestException(user.Email + " is not compliant with the single organization policy"); + } + else if (!belongsToOtherOrgCompliant) + { + throw new BadRequestException(user.Email + " belongs to an organization that doesn't allow them to join multiple organizations"); + } + else if (!twoFactorCompliant) + { + throw new BadRequestException(user.Email + " is not compliant with the two-step login policy"); + } } static OrganizationUserStatusType GetPriorActiveOrganizationUserStatusType(OrganizationUser organizationUser) diff --git a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs index e09293f32d66..f0b7084fe985 100644 --- a/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/OrganizationServiceTests.cs @@ -1833,11 +1833,14 @@ public async Task RestoreUser_WithOtherOrganizationSingleOrgPolicyEnabled_Fails( .AnyPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.SingleOrg, Arg.Any()) .Returns(true); + var user = new User(); + user.Email = "test@bitwarden.com"; + sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); + var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); - Assert.Contains("you cannot restore this user because they are a member of " + - "another organization which forbids it", exception.Message.ToLowerInvariant()); + Assert.Contains("test@bitwarden.com belongs to an organization that doesn't allow them to join multiple organizations", exception.Message.ToLowerInvariant()); await organizationUserRepository.DidNotReceiveWithAnyArgs().RestoreAsync(Arg.Any(), Arg.Any()); await eventService.DidNotReceiveWithAnyArgs() @@ -1865,11 +1868,14 @@ public async Task RestoreUser_With2FAPolicyEnabled_WithoutUser2FAConfigured_Fail .GetPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.TwoFactorAuthentication, Arg.Any()) .Returns(new[] { new OrganizationUserPolicyDetails { OrganizationId = organizationUser.OrganizationId, PolicyType = PolicyType.TwoFactorAuthentication } }); + var user = new User(); + user.Email = "test@bitwarden.com"; + sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); + var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); - Assert.Contains("you cannot restore this user until they enable " + - "two-step login on their user account.", exception.Message.ToLowerInvariant()); + Assert.Contains("test@bitwarden.com is not compliant with the two-step login policy", exception.Message.ToLowerInvariant()); await organizationUserRepository.DidNotReceiveWithAnyArgs().RestoreAsync(Arg.Any(), Arg.Any()); await eventService.DidNotReceiveWithAnyArgs() @@ -1924,11 +1930,14 @@ public async Task RestoreUser_WithSingleOrgPolicyEnabled_Fails( new OrganizationUserPolicyDetails { OrganizationId = organizationUser.OrganizationId, PolicyType = PolicyType.SingleOrg, OrganizationUserStatus = OrganizationUserStatusType.Revoked } }); + var user = new User(); + user.Email = "test@bitwarden.com"; + sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); + var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); - Assert.Contains("you cannot restore this user until " + - "they leave or remove all other organizations.", exception.Message.ToLowerInvariant()); + Assert.Contains("test@bitwarden.com is not compliant with the single organization policy", exception.Message.ToLowerInvariant()); await organizationUserRepository.DidNotReceiveWithAnyArgs().RestoreAsync(Arg.Any(), Arg.Any()); await eventService.DidNotReceiveWithAnyArgs() @@ -1958,11 +1967,57 @@ public async Task RestoreUser_vNext_WithOtherOrganizationSingleOrgPolicyEnabled_ .AnyPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.SingleOrg, Arg.Any()) .Returns(true); + var user = new User(); + user.Email = "test@bitwarden.com"; + sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); + + var exception = await Assert.ThrowsAsync( + () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); + + Assert.Contains("test@bitwarden.com belongs to an organization that doesn't allow them to join multiple organizations", exception.Message.ToLowerInvariant()); + + await organizationUserRepository.DidNotReceiveWithAnyArgs().RestoreAsync(Arg.Any(), Arg.Any()); + await eventService.DidNotReceiveWithAnyArgs() + .LogOrganizationUserEventAsync(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Theory, BitAutoData] + public async Task RestoreUser_WithSingleOrgPolicyEnabled_And_2FA_Policy_Fails( + Organization organization, + [OrganizationUser(OrganizationUserStatusType.Confirmed, OrganizationUserType.Owner)] OrganizationUser owner, + [OrganizationUser(OrganizationUserStatusType.Revoked)] OrganizationUser organizationUser, + [OrganizationUser(OrganizationUserStatusType.Accepted)] OrganizationUser secondOrganizationUser, + SutProvider sutProvider) + { + organizationUser.Email = null; // this is required to mock that the user as had already been confirmed before the revoke + secondOrganizationUser.UserId = organizationUser.UserId; + RestoreRevokeUser_Setup(organization, owner, organizationUser, sutProvider); + var organizationUserRepository = sutProvider.GetDependency(); + var eventService = sutProvider.GetDependency(); + + organizationUserRepository.GetManyByUserAsync(organizationUser.UserId.Value).Returns(new[] { organizationUser, secondOrganizationUser }); + sutProvider.GetDependency() + .GetPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.SingleOrg, Arg.Any()) + .Returns(new[] + { + new OrganizationUserPolicyDetails { OrganizationId = organizationUser.OrganizationId, PolicyType = PolicyType.SingleOrg, OrganizationUserStatus = OrganizationUserStatusType.Revoked } + }); + + sutProvider.GetDependency() + .GetPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.TwoFactorAuthentication, Arg.Any()) + .Returns(new[] + { + new OrganizationUserPolicyDetails { OrganizationId = organizationUser.OrganizationId, PolicyType = PolicyType.TwoFactorAuthentication, OrganizationUserStatus = OrganizationUserStatusType.Revoked } + }); + + var user = new User(); + user.Email = "test@bitwarden.com"; + sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); + var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); - Assert.Contains("you cannot restore this user because they are a member of " + - "another organization which forbids it", exception.Message.ToLowerInvariant()); + Assert.Contains("test@bitwarden.com is not compliant with the single organization and two-step login polciy", exception.Message.ToLowerInvariant()); await organizationUserRepository.DidNotReceiveWithAnyArgs().RestoreAsync(Arg.Any(), Arg.Any()); await eventService.DidNotReceiveWithAnyArgs() @@ -1986,11 +2041,14 @@ public async Task RestoreUser_vNext_With2FAPolicyEnabled_WithoutUser2FAConfigure .GetPoliciesApplicableToUserAsync(organizationUser.UserId.Value, PolicyType.TwoFactorAuthentication, Arg.Any()) .Returns(new[] { new OrganizationUserPolicyDetails { OrganizationId = organizationUser.OrganizationId, PolicyType = PolicyType.TwoFactorAuthentication } }); + var user = new User(); + user.Email = "test@bitwarden.com"; + sutProvider.GetDependency().GetByIdAsync(organizationUser.UserId.Value).Returns(user); + var exception = await Assert.ThrowsAsync( () => sutProvider.Sut.RestoreUserAsync(organizationUser, owner.Id)); - Assert.Contains("you cannot restore this user until they enable " + - "two-step login on their user account.", exception.Message.ToLowerInvariant()); + Assert.Contains("test@bitwarden.com is not compliant with the two-step login policy", exception.Message.ToLowerInvariant()); await organizationUserRepository.DidNotReceiveWithAnyArgs().RestoreAsync(Arg.Any(), Arg.Any()); await eventService.DidNotReceiveWithAnyArgs() From 059e6816f2b24854663e9d0e7304072019cb80ce Mon Sep 17 00:00:00 2001 From: Jared McCannon Date: Tue, 3 Dec 2024 11:01:45 -0600 Subject: [PATCH 19/28] Fixing migration script. (#5093) --- util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql b/util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql index 95151f90ded5..5c51f9da40c7 100644 --- a/util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql +++ b/util/Migrator/DbScripts/2024-11-26-00_OrgUserSetStatusBulk.sql @@ -1,4 +1,4 @@ -CREATE PROCEDURE [dbo].[OrganizationUser_SetStatusForUsersById] +CREATE OR ALTER PROCEDURE[dbo].[OrganizationUser_SetStatusForUsersById] @OrganizationUserIds AS NVARCHAR(MAX), @Status SMALLINT AS From b580d7c02210e9d18094f6f713ada52ce387198a Mon Sep 17 00:00:00 2001 From: Conner Turnbull <133619638+cturnbull-bitwarden@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:06:38 -0500 Subject: [PATCH 20/28] Automatically forwarding ports 1080 and 1433 in VS Code/Cursor (#5104) --- .devcontainer/internal_dev/devcontainer.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.devcontainer/internal_dev/devcontainer.json b/.devcontainer/internal_dev/devcontainer.json index ee9ab7a96da5..78a79180ee77 100644 --- a/.devcontainer/internal_dev/devcontainer.json +++ b/.devcontainer/internal_dev/devcontainer.json @@ -21,10 +21,15 @@ } }, "postCreateCommand": "bash .devcontainer/internal_dev/postCreateCommand.sh", + "forwardPorts": [1080, 1433], "portsAttributes": { "1080": { "label": "Mail Catcher", "onAutoForward": "notify" + }, + "1433": { + "label": "SQL Server", + "onAutoForward": "notify" } } } From c9aa61b0cf7382630553782d0301c0f9987fffe5 Mon Sep 17 00:00:00 2001 From: Conner Turnbull <133619638+cturnbull-bitwarden@users.noreply.github.com> Date: Tue, 3 Dec 2024 15:38:34 -0500 Subject: [PATCH 21/28] Updated dev container to give the option of installing the Stripe CLI (#5105) --- .devcontainer/internal_dev/postCreateCommand.sh | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.devcontainer/internal_dev/postCreateCommand.sh b/.devcontainer/internal_dev/postCreateCommand.sh index b013be1cec27..668b776447c9 100755 --- a/.devcontainer/internal_dev/postCreateCommand.sh +++ b/.devcontainer/internal_dev/postCreateCommand.sh @@ -70,6 +70,22 @@ Press to continue." sleep 5 # wait for DB container to start dotnet run --project ./util/MsSqlMigratorUtility "$SQL_CONNECTION_STRING" fi + read -r -p "Would you like to install the Stripe CLI? [y/N] " stripe_response + if [[ "$stripe_response" =~ ^([yY][eE][sS]|[yY])+$ ]]; then + install_stripe_cli + fi +} + +# Install Stripe CLI +install_stripe_cli() { + echo "Installing Stripe CLI..." + # Add Stripe CLI GPG key so that apt can verify the packages authenticity. + # If Stripe ever changes the key, we'll need to update this. Visit https://docs.stripe.com/stripe-cli?install-method=apt if so + curl -s https://packages.stripe.dev/api/security/keypair/stripe-cli-gpg/public | gpg --dearmor | sudo tee /usr/share/keyrings/stripe.gpg >/dev/null + # Add Stripe CLI repository to apt sources + echo "deb [signed-by=/usr/share/keyrings/stripe.gpg] https://packages.stripe.dev/stripe-cli-debian-local stable main" | sudo tee -a /etc/apt/sources.list.d/stripe.list >/dev/null + sudo apt update + sudo apt install -y stripe } # main From 44b687922d7756e4771e0572a19af596b447a9df Mon Sep 17 00:00:00 2001 From: Thomas Rittson <31796059+eliykat@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:50:47 +1000 Subject: [PATCH 22/28] [PM-14245] Remove policy definitions feature flag (#5095) * Remove PolicyService.SaveAsync and use command instead * Delete feature flag definition * Add public api integration tests --- .../Controllers/PoliciesController.cs | 29 +- .../Models/Request/PolicyRequestModel.cs | 25 +- .../Public/Controllers/PoliciesController.cs | 20 +- .../Request/PolicyUpdateRequestModel.cs | 27 +- .../Models/Response/PolicyResponseModel.cs | 6 +- .../VerifyOrganizationDomainCommand.cs | 21 +- .../Policies/ISavePolicyCommand.cs | 5 +- .../Implementations/SavePolicyCommand.cs | 4 +- .../AdminConsole/Services/IPolicyService.cs | 5 +- .../Services/Implementations/PolicyService.cs | 276 +------ .../Implementations/SsoConfigService.cs | 51 +- src/Core/Constants.cs | 1 - .../Controllers/PoliciesControllerTests.cs | 163 ++++ .../Controllers/PoliciesControllerTests.cs | 33 - .../VerifyOrganizationDomainCommandTests.cs | 28 +- .../Services/PolicyServiceTests.cs | 701 ------------------ .../Auth/Services/SsoConfigServiceTests.cs | 25 +- 17 files changed, 292 insertions(+), 1128 deletions(-) create mode 100644 test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs delete mode 100644 test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs diff --git a/src/Api/AdminConsole/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Controllers/PoliciesController.cs index ee48cdd5d470..1167d7a86ccc 100644 --- a/src/Api/AdminConsole/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Controllers/PoliciesController.cs @@ -6,8 +6,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Models.Business.Tokenables; using Bit.Core.Context; using Bit.Core.Enums; @@ -28,7 +28,6 @@ namespace Bit.Api.AdminConsole.Controllers; public class PoliciesController : Controller { private readonly IPolicyRepository _policyRepository; - private readonly IPolicyService _policyService; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IUserService _userService; private readonly ICurrentContext _currentContext; @@ -37,10 +36,10 @@ public class PoliciesController : Controller private readonly IDataProtectorTokenFactory _orgUserInviteTokenDataFactory; private readonly IFeatureService _featureService; private readonly IOrganizationHasVerifiedDomainsQuery _organizationHasVerifiedDomainsQuery; + private readonly ISavePolicyCommand _savePolicyCommand; public PoliciesController( IPolicyRepository policyRepository, - IPolicyService policyService, IOrganizationUserRepository organizationUserRepository, IUserService userService, ICurrentContext currentContext, @@ -48,10 +47,10 @@ public PoliciesController( IDataProtectionProvider dataProtectionProvider, IDataProtectorTokenFactory orgUserInviteTokenDataFactory, IFeatureService featureService, - IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery) + IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, + ISavePolicyCommand savePolicyCommand) { _policyRepository = policyRepository; - _policyService = policyService; _organizationUserRepository = organizationUserRepository; _userService = userService; _currentContext = currentContext; @@ -62,6 +61,7 @@ public PoliciesController( _orgUserInviteTokenDataFactory = orgUserInviteTokenDataFactory; _featureService = featureService; _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; + _savePolicyCommand = savePolicyCommand; } [HttpGet("{type}")] @@ -178,25 +178,20 @@ public async Task GetMasterPasswordPolicy(Guid orgId) } [HttpPut("{type}")] - public async Task Put(string orgId, int type, [FromBody] PolicyRequestModel model) + public async Task Put(Guid orgId, PolicyType type, [FromBody] PolicyRequestModel model) { - var orgIdGuid = new Guid(orgId); - if (!await _currentContext.ManagePolicies(orgIdGuid)) + if (!await _currentContext.ManagePolicies(orgId)) { throw new NotFoundException(); } - var policy = await _policyRepository.GetByOrganizationIdTypeAsync(new Guid(orgId), (PolicyType)type); - if (policy == null) - { - policy = model.ToPolicy(orgIdGuid); - } - else + + if (type != model.Type) { - policy = model.ToPolicy(policy); + throw new BadRequestException("Mismatched policy type"); } - var userId = _userService.GetProperUserId(User); - await _policyService.SaveAsync(policy, userId); + var policyUpdate = await model.ToPolicyUpdateAsync(orgId, _currentContext); + var policy = await _savePolicyCommand.SaveAsync(policyUpdate); return new PolicyResponseModel(policy); } } diff --git a/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs b/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs index db191194d733..a243f46b2eb8 100644 --- a/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs +++ b/src/Api/AdminConsole/Models/Request/PolicyRequestModel.cs @@ -1,7 +1,9 @@ using System.ComponentModel.DataAnnotations; using System.Text.Json; -using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Context; namespace Bit.Api.AdminConsole.Models.Request; @@ -13,19 +15,12 @@ public class PolicyRequestModel public bool? Enabled { get; set; } public Dictionary Data { get; set; } - public Policy ToPolicy(Guid orgId) + public async Task ToPolicyUpdateAsync(Guid organizationId, ICurrentContext currentContext) => new() { - return ToPolicy(new Policy - { - Type = Type.Value, - OrganizationId = orgId - }); - } - - public Policy ToPolicy(Policy existingPolicy) - { - existingPolicy.Enabled = Enabled.GetValueOrDefault(); - existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; - return existingPolicy; - } + Type = Type!.Value, + OrganizationId = organizationId, + Data = Data != null ? JsonSerializer.Serialize(Data) : null, + Enabled = Enabled.GetValueOrDefault(), + PerformedBy = new StandardUser(currentContext.UserId!.Value, await currentContext.OrganizationOwner(organizationId)) + }; } diff --git a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs index f2e7c35d2466..a22c05ed629f 100644 --- a/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs +++ b/src/Api/AdminConsole/Public/Controllers/PoliciesController.cs @@ -3,6 +3,7 @@ using Bit.Api.AdminConsole.Public.Models.Response; using Bit.Api.Models.Public.Response; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services; using Bit.Core.Context; @@ -18,15 +19,18 @@ public class PoliciesController : Controller private readonly IPolicyRepository _policyRepository; private readonly IPolicyService _policyService; private readonly ICurrentContext _currentContext; + private readonly ISavePolicyCommand _savePolicyCommand; public PoliciesController( IPolicyRepository policyRepository, IPolicyService policyService, - ICurrentContext currentContext) + ICurrentContext currentContext, + ISavePolicyCommand savePolicyCommand) { _policyRepository = policyRepository; _policyService = policyService; _currentContext = currentContext; + _savePolicyCommand = savePolicyCommand; } /// @@ -80,17 +84,9 @@ public async Task List() [ProducesResponseType((int)HttpStatusCode.NotFound)] public async Task Put(PolicyType type, [FromBody] PolicyUpdateRequestModel model) { - var policy = await _policyRepository.GetByOrganizationIdTypeAsync( - _currentContext.OrganizationId.Value, type); - if (policy == null) - { - policy = model.ToPolicy(_currentContext.OrganizationId.Value, type); - } - else - { - policy = model.ToPolicy(policy); - } - await _policyService.SaveAsync(policy, null); + var policyUpdate = model.ToPolicyUpdate(_currentContext.OrganizationId!.Value, type); + var policy = await _savePolicyCommand.SaveAsync(policyUpdate); + var response = new PolicyResponseModel(policy); return new JsonResult(response); } diff --git a/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs b/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs index f859686b8133..eb5669046243 100644 --- a/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs +++ b/src/Api/AdminConsole/Public/Models/Request/PolicyUpdateRequestModel.cs @@ -1,26 +1,19 @@ using System.Text.Json; -using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.Enums; namespace Bit.Api.AdminConsole.Public.Models.Request; public class PolicyUpdateRequestModel : PolicyBaseModel { - public Policy ToPolicy(Guid orgId, PolicyType type) + public PolicyUpdate ToPolicyUpdate(Guid organizationId, PolicyType type) => new() { - return ToPolicy(new Policy - { - OrganizationId = orgId, - Enabled = Enabled.GetValueOrDefault(), - Data = Data != null ? JsonSerializer.Serialize(Data) : null, - Type = type - }); - } - - public virtual Policy ToPolicy(Policy existingPolicy) - { - existingPolicy.Enabled = Enabled.GetValueOrDefault(); - existingPolicy.Data = Data != null ? JsonSerializer.Serialize(Data) : null; - return existingPolicy; - } + Type = type, + OrganizationId = organizationId, + Data = Data != null ? JsonSerializer.Serialize(Data) : null, + Enabled = Enabled.GetValueOrDefault(), + PerformedBy = new SystemUser(EventSystemUser.PublicApi) + }; } diff --git a/src/Api/AdminConsole/Public/Models/Response/PolicyResponseModel.cs b/src/Api/AdminConsole/Public/Models/Response/PolicyResponseModel.cs index 27da5cc56120..8da7d93cf106 100644 --- a/src/Api/AdminConsole/Public/Models/Response/PolicyResponseModel.cs +++ b/src/Api/AdminConsole/Public/Models/Response/PolicyResponseModel.cs @@ -1,8 +1,9 @@ using System.ComponentModel.DataAnnotations; -using System.Text.Json; using Bit.Api.Models.Public.Response; using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; +using Newtonsoft.Json; +using JsonSerializer = System.Text.Json.JsonSerializer; namespace Bit.Api.AdminConsole.Public.Models.Response; @@ -11,6 +12,9 @@ namespace Bit.Api.AdminConsole.Public.Models.Response; /// public class PolicyResponseModel : PolicyBaseModel, IResponseModel { + [JsonConstructor] + public PolicyResponseModel() { } + public PolicyResponseModel(Policy policy) { if (policy == null) diff --git a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs index dd6add669f2d..375c6326ef27 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommand.cs @@ -1,8 +1,8 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; -using Bit.Core.AdminConsole.Services; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -19,9 +19,9 @@ public class VerifyOrganizationDomainCommand( IDnsResolverService dnsResolverService, IEventService eventService, IGlobalSettings globalSettings, - IPolicyService policyService, IFeatureService featureService, ICurrentContext currentContext, + ISavePolicyCommand savePolicyCommand, ILogger logger) : IVerifyOrganizationDomainCommand { @@ -125,10 +125,15 @@ private async Task EnableSingleOrganizationPolicyAsync(Guid organizationId, IAct { if (featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning)) { - await policyService.SaveAsync( - new Policy { OrganizationId = organizationId, Type = PolicyType.SingleOrg, Enabled = true }, - savingUserId: actingUser is StandardUser standardUser ? standardUser.UserId : null, - eventSystemUser: actingUser is SystemUser systemUser ? systemUser.SystemUserType : null); + var policyUpdate = new PolicyUpdate + { + OrganizationId = organizationId, + Type = PolicyType.SingleOrg, + Enabled = true, + PerformedBy = actingUser + }; + + await savePolicyCommand.SaveAsync(policyUpdate); } } } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs index 5bfdfc6aa7d2..6ca842686ecb 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/ISavePolicyCommand.cs @@ -1,8 +1,9 @@ -using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; namespace Bit.Core.AdminConsole.OrganizationFeatures.Policies; public interface ISavePolicyCommand { - Task SaveAsync(PolicyUpdate policy); + Task SaveAsync(PolicyUpdate policy); } diff --git a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs index f193aeabd1f6..cf332e689a62 100644 --- a/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs +++ b/src/Core/AdminConsole/OrganizationFeatures/Policies/Implementations/SavePolicyCommand.cs @@ -42,7 +42,7 @@ public SavePolicyCommand( _policyValidators = policyValidatorsDict; } - public async Task SaveAsync(PolicyUpdate policyUpdate) + public async Task SaveAsync(PolicyUpdate policyUpdate) { var org = await _applicationCacheService.GetOrganizationAbilityAsync(policyUpdate.OrganizationId); if (org == null) @@ -74,6 +74,8 @@ public async Task SaveAsync(PolicyUpdate policyUpdate) await _policyRepository.UpsertAsync(policy); await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); + + return policy; } private async Task RunValidatorAsync(IPolicyValidator validator, PolicyUpdate policyUpdate) diff --git a/src/Core/AdminConsole/Services/IPolicyService.cs b/src/Core/AdminConsole/Services/IPolicyService.cs index 715c3a34d934..4f9a25f90455 100644 --- a/src/Core/AdminConsole/Services/IPolicyService.cs +++ b/src/Core/AdminConsole/Services/IPolicyService.cs @@ -1,5 +1,4 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; using Bit.Core.Entities; using Bit.Core.Enums; @@ -9,8 +8,6 @@ namespace Bit.Core.AdminConsole.Services; public interface IPolicyService { - Task SaveAsync(Policy policy, Guid? savingUserId, EventSystemUser? eventSystemUser = null); - /// /// Get the combined master password policy options for the specified user. /// diff --git a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs index 69ec27fd8a2a..c3eb2272d084 100644 --- a/src/Core/AdminConsole/Services/Implementations/PolicyService.cs +++ b/src/Core/AdminConsole/Services/Implementations/PolicyService.cs @@ -1,19 +1,8 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.Models.Data; +using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies; -using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Auth.Enums; -using Bit.Core.Auth.Repositories; -using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; -using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; -using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; using Bit.Core.Services; @@ -24,107 +13,20 @@ namespace Bit.Core.AdminConsole.Services.Implementations; public class PolicyService : IPolicyService { private readonly IApplicationCacheService _applicationCacheService; - private readonly IEventService _eventService; - private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IPolicyRepository _policyRepository; - private readonly ISsoConfigRepository _ssoConfigRepository; - private readonly IMailService _mailService; private readonly GlobalSettings _globalSettings; - private readonly ITwoFactorIsEnabledQuery _twoFactorIsEnabledQuery; - private readonly IFeatureService _featureService; - private readonly ISavePolicyCommand _savePolicyCommand; - private readonly IRemoveOrganizationUserCommand _removeOrganizationUserCommand; - private readonly IOrganizationHasVerifiedDomainsQuery _organizationHasVerifiedDomainsQuery; - private readonly ICurrentContext _currentContext; public PolicyService( IApplicationCacheService applicationCacheService, - IEventService eventService, - IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, IPolicyRepository policyRepository, - ISsoConfigRepository ssoConfigRepository, - IMailService mailService, - GlobalSettings globalSettings, - ITwoFactorIsEnabledQuery twoFactorIsEnabledQuery, - IFeatureService featureService, - ISavePolicyCommand savePolicyCommand, - IRemoveOrganizationUserCommand removeOrganizationUserCommand, - IOrganizationHasVerifiedDomainsQuery organizationHasVerifiedDomainsQuery, - ICurrentContext currentContext) + GlobalSettings globalSettings) { _applicationCacheService = applicationCacheService; - _eventService = eventService; - _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _policyRepository = policyRepository; - _ssoConfigRepository = ssoConfigRepository; - _mailService = mailService; _globalSettings = globalSettings; - _twoFactorIsEnabledQuery = twoFactorIsEnabledQuery; - _featureService = featureService; - _savePolicyCommand = savePolicyCommand; - _removeOrganizationUserCommand = removeOrganizationUserCommand; - _organizationHasVerifiedDomainsQuery = organizationHasVerifiedDomainsQuery; - _currentContext = currentContext; - } - - public async Task SaveAsync(Policy policy, Guid? savingUserId, EventSystemUser? eventSystemUser = null) - { - if (_featureService.IsEnabled(FeatureFlagKeys.Pm13322AddPolicyDefinitions)) - { - // Transitional mapping - this will be moved to callers once the feature flag is removed - // TODO make sure to populate with SystemUser if not an actual user - var policyUpdate = new PolicyUpdate - { - OrganizationId = policy.OrganizationId, - Type = policy.Type, - Enabled = policy.Enabled, - Data = policy.Data, - PerformedBy = savingUserId.HasValue - ? new StandardUser(savingUserId.Value, await _currentContext.OrganizationOwner(policy.OrganizationId)) - : new SystemUser(eventSystemUser ?? EventSystemUser.Unknown) - }; - - await _savePolicyCommand.SaveAsync(policyUpdate); - return; - } - - var org = await _organizationRepository.GetByIdAsync(policy.OrganizationId); - if (org == null) - { - throw new BadRequestException("Organization not found"); - } - - if (!org.UsePolicies) - { - throw new BadRequestException("This organization cannot use policies."); - } - - // FIXME: This method will throw a bunch of errors based on if the - // policy that is being applied requires some other policy that is - // not enabled. It may be advisable to refactor this into a domain - // object and get this kind of stuff out of the service. - await HandleDependentPoliciesAsync(policy, org); - - var now = DateTime.UtcNow; - if (policy.Id == default(Guid)) - { - policy.CreationDate = now; - } - - policy.RevisionDate = now; - - // We can exit early for disable operations, because they are - // simpler. - if (!policy.Enabled) - { - await SetPolicyConfiguration(policy); - return; - } - - await EnablePolicyAsync(policy, org, savingUserId); } public async Task GetMasterPasswordPolicyForUserAsync(User user) @@ -190,178 +92,4 @@ private OrganizationUserType[] GetUserTypesExcludedFromPolicy(PolicyType policyT return new[] { OrganizationUserType.Owner, OrganizationUserType.Admin }; } - - private async Task DependsOnSingleOrgAsync(Organization org) - { - var singleOrg = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.SingleOrg); - if (singleOrg?.Enabled != true) - { - throw new BadRequestException("Single Organization policy not enabled."); - } - } - - private async Task RequiredBySsoAsync(Organization org) - { - var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.RequireSso); - if (requireSso?.Enabled == true) - { - throw new BadRequestException("Single Sign-On Authentication policy is enabled."); - } - } - - private async Task RequiredByKeyConnectorAsync(Organization org) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id); - if (ssoConfig?.GetData()?.MemberDecryptionType == MemberDecryptionType.KeyConnector) - { - throw new BadRequestException("Key Connector is enabled."); - } - } - - private async Task RequiredByAccountRecoveryAsync(Organization org) - { - var requireSso = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.ResetPassword); - if (requireSso?.Enabled == true) - { - throw new BadRequestException("Account recovery policy is enabled."); - } - } - - private async Task RequiredByVaultTimeoutAsync(Organization org) - { - var vaultTimeout = await _policyRepository.GetByOrganizationIdTypeAsync(org.Id, PolicyType.MaximumVaultTimeout); - if (vaultTimeout?.Enabled == true) - { - throw new BadRequestException("Maximum Vault Timeout policy is enabled."); - } - } - - private async Task RequiredBySsoTrustedDeviceEncryptionAsync(Organization org) - { - var ssoConfig = await _ssoConfigRepository.GetByOrganizationIdAsync(org.Id); - if (ssoConfig?.GetData()?.MemberDecryptionType == MemberDecryptionType.TrustedDeviceEncryption) - { - throw new BadRequestException("Trusted device encryption is on and requires this policy."); - } - } - - private async Task HandleDependentPoliciesAsync(Policy policy, Organization org) - { - switch (policy.Type) - { - case PolicyType.SingleOrg: - if (!policy.Enabled) - { - await HasVerifiedDomainsAsync(org); - await RequiredBySsoAsync(org); - await RequiredByVaultTimeoutAsync(org); - await RequiredByKeyConnectorAsync(org); - await RequiredByAccountRecoveryAsync(org); - } - break; - - case PolicyType.RequireSso: - if (policy.Enabled) - { - await DependsOnSingleOrgAsync(org); - } - else - { - await RequiredByKeyConnectorAsync(org); - await RequiredBySsoTrustedDeviceEncryptionAsync(org); - } - break; - - case PolicyType.ResetPassword: - if (!policy.Enabled || policy.GetDataModel()?.AutoEnrollEnabled == false) - { - await RequiredBySsoTrustedDeviceEncryptionAsync(org); - } - - if (policy.Enabled) - { - await DependsOnSingleOrgAsync(org); - } - break; - - case PolicyType.MaximumVaultTimeout: - if (policy.Enabled) - { - await DependsOnSingleOrgAsync(org); - } - break; - } - } - - private async Task HasVerifiedDomainsAsync(Organization org) - { - if (_featureService.IsEnabled(FeatureFlagKeys.AccountDeprovisioning) - && await _organizationHasVerifiedDomainsQuery.HasVerifiedDomainsAsync(org.Id)) - { - throw new BadRequestException("The Single organization policy is required for organizations that have enabled domain verification."); - } - } - - private async Task SetPolicyConfiguration(Policy policy) - { - await _policyRepository.UpsertAsync(policy); - await _eventService.LogPolicyEventAsync(policy, EventType.Policy_Updated); - } - - private async Task EnablePolicyAsync(Policy policy, Organization org, Guid? savingUserId) - { - var currentPolicy = await _policyRepository.GetByIdAsync(policy.Id); - if (!currentPolicy?.Enabled ?? true) - { - var orgUsers = await _organizationUserRepository.GetManyDetailsByOrganizationAsync(policy.OrganizationId); - var organizationUsersTwoFactorEnabled = await _twoFactorIsEnabledQuery.TwoFactorIsEnabledAsync(orgUsers); - var removableOrgUsers = orgUsers.Where(ou => - ou.Status != OrganizationUserStatusType.Invited && ou.Status != OrganizationUserStatusType.Revoked && - ou.Type != OrganizationUserType.Owner && ou.Type != OrganizationUserType.Admin && - ou.UserId != savingUserId); - switch (policy.Type) - { - case PolicyType.TwoFactorAuthentication: - // Reorder by HasMasterPassword to prioritize checking users without a master if they have 2FA enabled - foreach (var orgUser in removableOrgUsers.OrderBy(ou => ou.HasMasterPassword)) - { - var userTwoFactorEnabled = organizationUsersTwoFactorEnabled.FirstOrDefault(u => u.user.Id == orgUser.Id).twoFactorIsEnabled; - if (!userTwoFactorEnabled) - { - if (!orgUser.HasMasterPassword) - { - throw new BadRequestException( - "Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page."); - } - - await _removeOrganizationUserCommand.RemoveUserAsync(policy.OrganizationId, orgUser.Id, - savingUserId); - await _mailService.SendOrganizationUserRemovedForPolicyTwoStepEmailAsync( - org.DisplayName(), orgUser.Email); - } - } - break; - case PolicyType.SingleOrg: - var userOrgs = await _organizationUserRepository.GetManyByManyUsersAsync( - removableOrgUsers.Select(ou => ou.UserId.Value)); - foreach (var orgUser in removableOrgUsers) - { - if (userOrgs.Any(ou => ou.UserId == orgUser.UserId - && ou.OrganizationId != org.Id - && ou.Status != OrganizationUserStatusType.Invited)) - { - await _removeOrganizationUserCommand.RemoveUserAsync(policy.OrganizationId, orgUser.Id, - savingUserId); - await _mailService.SendOrganizationUserRemovedForPolicySingleOrgEmailAsync( - org.DisplayName(), orgUser.Email); - } - } - break; - default: - break; - } - } - - await SetPolicyConfiguration(policy); - } } diff --git a/src/Core/Auth/Services/Implementations/SsoConfigService.cs b/src/Core/Auth/Services/Implementations/SsoConfigService.cs index 532f000394c7..bf7e2d56fe25 100644 --- a/src/Core/Auth/Services/Implementations/SsoConfigService.cs +++ b/src/Core/Auth/Services/Implementations/SsoConfigService.cs @@ -1,8 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Repositories; @@ -17,25 +18,25 @@ public class SsoConfigService : ISsoConfigService { private readonly ISsoConfigRepository _ssoConfigRepository; private readonly IPolicyRepository _policyRepository; - private readonly IPolicyService _policyService; private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly IEventService _eventService; + private readonly ISavePolicyCommand _savePolicyCommand; public SsoConfigService( ISsoConfigRepository ssoConfigRepository, IPolicyRepository policyRepository, - IPolicyService policyService, IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, - IEventService eventService) + IEventService eventService, + ISavePolicyCommand savePolicyCommand) { _ssoConfigRepository = ssoConfigRepository; _policyRepository = policyRepository; - _policyService = policyService; _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _eventService = eventService; + _savePolicyCommand = savePolicyCommand; } public async Task SaveAsync(SsoConfig config, Organization organization) @@ -63,25 +64,29 @@ public async Task SaveAsync(SsoConfig config, Organization organization) // Automatically enable account recovery, SSO required, and single org policies if trusted device encryption is selected if (config.GetData().MemberDecryptionType == MemberDecryptionType.TrustedDeviceEncryption) { - var singleOrgPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.SingleOrg) ?? - new Policy { OrganizationId = config.OrganizationId, Type = PolicyType.SingleOrg }; - singleOrgPolicy.Enabled = true; - - await _policyService.SaveAsync(singleOrgPolicy, null); - - var resetPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.ResetPassword) ?? - new Policy { OrganizationId = config.OrganizationId, Type = PolicyType.ResetPassword, }; - - resetPolicy.Enabled = true; - resetPolicy.SetDataModel(new ResetPasswordDataModel { AutoEnrollEnabled = true }); - await _policyService.SaveAsync(resetPolicy, null); - - var ssoRequiredPolicy = await _policyRepository.GetByOrganizationIdTypeAsync(config.OrganizationId, PolicyType.RequireSso) ?? - new Policy { OrganizationId = config.OrganizationId, Type = PolicyType.RequireSso, }; - - ssoRequiredPolicy.Enabled = true; - await _policyService.SaveAsync(ssoRequiredPolicy, null); + await _savePolicyCommand.SaveAsync(new() + { + OrganizationId = config.OrganizationId, + Type = PolicyType.SingleOrg, + Enabled = true + }); + + var resetPasswordPolicy = new PolicyUpdate + { + OrganizationId = config.OrganizationId, + Type = PolicyType.ResetPassword, + Enabled = true, + }; + resetPasswordPolicy.SetDataModel(new ResetPasswordDataModel { AutoEnrollEnabled = true }); + await _savePolicyCommand.SaveAsync(resetPasswordPolicy); + + await _savePolicyCommand.SaveAsync(new() + { + OrganizationId = config.OrganizationId, + Type = PolicyType.RequireSso, + Enabled = true + }); } await LogEventsAsync(config, oldConfig); diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 561f5f27c86c..6efdedfeed92 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -144,7 +144,6 @@ public static class FeatureFlagKeys public const string AccessIntelligence = "pm-13227-access-intelligence"; public const string VerifiedSsoDomainEndpoint = "pm-12337-refactor-sso-details-endpoint"; public const string PM12275_MultiOrganizationEnterprises = "pm-12275-multi-organization-enterprises"; - public const string Pm13322AddPolicyDefinitions = "pm-13322-add-policy-definitions"; public const string LimitCollectionCreationDeletionSplit = "pm-10863-limit-collection-creation-deletion-split"; public const string GeneratorToolsModernization = "generator-tools-modernization"; public const string NewDeviceVerification = "new-device-verification"; diff --git a/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs new file mode 100644 index 000000000000..f034426f98ee --- /dev/null +++ b/test/Api.IntegrationTest/AdminConsole/Public/Controllers/PoliciesControllerTests.cs @@ -0,0 +1,163 @@ +using System.Net; +using System.Text.Json; +using Bit.Api.AdminConsole.Public.Models.Request; +using Bit.Api.AdminConsole.Public.Models.Response; +using Bit.Api.IntegrationTest.Factories; +using Bit.Api.IntegrationTest.Helpers; +using Bit.Core.AdminConsole.Entities; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.Repositories; +using Bit.Core.Billing.Enums; +using Bit.Core.Enums; +using Bit.Test.Common.Helpers; +using Xunit; + +namespace Bit.Api.IntegrationTest.AdminConsole.Public.Controllers; + +public class PoliciesControllerTests : IClassFixture, IAsyncLifetime +{ + private readonly HttpClient _client; + private readonly ApiApplicationFactory _factory; + private readonly LoginHelper _loginHelper; + + // These will get set in `InitializeAsync` which is ran before all tests + private Organization _organization = null!; + private string _ownerEmail = null!; + + public PoliciesControllerTests(ApiApplicationFactory factory) + { + _factory = factory; + _client = factory.CreateClient(); + _loginHelper = new LoginHelper(_factory, _client); + } + + public async Task InitializeAsync() + { + // Create the owner account + _ownerEmail = $"integration-test{Guid.NewGuid()}@bitwarden.com"; + await _factory.LoginWithNewAccount(_ownerEmail); + + // Create the organization + (_organization, _) = await OrganizationTestHelpers.SignUpAsync(_factory, plan: PlanType.EnterpriseAnnually2023, + ownerEmail: _ownerEmail, passwordManagerSeats: 10, paymentMethod: PaymentMethodType.Card); + + // Authorize with the organization api key + await _loginHelper.LoginWithOrganizationApiKeyAsync(_organization.Id); + } + + public Task DisposeAsync() + { + _client.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public async Task Post_NewPolicy() + { + var policyType = PolicyType.MasterPassword; + var request = new PolicyUpdateRequestModel + { + Enabled = true, + Data = new Dictionary + { + { "minComplexity", 15}, + { "requireLower", true} + } + }; + + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert against the response + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + var result = await response.Content.ReadFromJsonAsync(); + Assert.NotNull(result); + + Assert.True(result.Enabled); + Assert.Equal(policyType, result.Type); + Assert.IsType(result.Id); + Assert.NotEqual(default, result.Id); + Assert.NotNull(result.Data); + Assert.Equal(15, ((JsonElement)result.Data["minComplexity"]).GetInt32()); + Assert.True(((JsonElement)result.Data["requireLower"]).GetBoolean()); + + // Assert against the database values + var policyRepository = _factory.GetService(); + var policy = await policyRepository.GetByOrganizationIdTypeAsync(_organization.Id, policyType); + Assert.NotNull(policy); + + Assert.True(policy.Enabled); + Assert.Equal(policyType, policy.Type); + Assert.IsType(policy.Id); + Assert.NotEqual(default, policy.Id); + Assert.Equal(_organization.Id, policy.OrganizationId); + + Assert.NotNull(policy.Data); + var data = policy.GetDataModel(); + var expectedData = new MasterPasswordPolicyData { MinComplexity = 15, RequireLower = true }; + AssertHelper.AssertPropertyEqual(expectedData, data); + } + + [Fact] + public async Task Post_UpdatePolicy() + { + var policyType = PolicyType.MasterPassword; + var existingPolicy = new Policy + { + OrganizationId = _organization.Id, + Enabled = true, + Type = policyType + }; + existingPolicy.SetDataModel(new MasterPasswordPolicyData + { + EnforceOnLogin = true, + MinLength = 22, + RequireSpecial = true + }); + + var policyRepository = _factory.GetService(); + await policyRepository.UpsertAsync(existingPolicy); + + // The Id isn't set until it's created in the database, get it back out to get the id + var createdPolicy = await policyRepository.GetByOrganizationIdTypeAsync(_organization.Id, policyType); + var expectedId = createdPolicy!.Id; + + var request = new PolicyUpdateRequestModel + { + Enabled = false, + Data = new Dictionary + { + { "minLength", 15}, + { "requireUpper", true} + } + }; + + var response = await _client.PutAsync($"/public/policies/{policyType}", JsonContent.Create(request)); + + // Assert against the response + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + var result = await response.Content.ReadFromJsonAsync(); + Assert.NotNull(result); + + Assert.False(result.Enabled); + Assert.Equal(policyType, result.Type); + Assert.Equal(expectedId, result.Id); + Assert.NotNull(result.Data); + Assert.Equal(15, ((JsonElement)result.Data["minLength"]).GetInt32()); + Assert.True(((JsonElement)result.Data["requireUpper"]).GetBoolean()); + + // Assert against the database values + var policy = await policyRepository.GetByOrganizationIdTypeAsync(_organization.Id, policyType); + Assert.NotNull(policy); + + Assert.False(policy.Enabled); + Assert.Equal(policyType, policy.Type); + Assert.Equal(expectedId, policy.Id); + Assert.Equal(_organization.Id, policy.OrganizationId); + + Assert.NotNull(policy.Data); + var data = policy.GetDataModel(); + Assert.Equal(15, data.MinLength); + Assert.Equal(true, data.RequireUpper); + } +} diff --git a/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs b/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs deleted file mode 100644 index 71d04cae33a6..000000000000 --- a/test/Api.Test/AdminConsole/Public/Controllers/PoliciesControllerTests.cs +++ /dev/null @@ -1,33 +0,0 @@ -using Bit.Api.AdminConsole.Public.Controllers; -using Bit.Api.AdminConsole.Public.Models.Request; -using Bit.Api.AdminConsole.Public.Models.Response; -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.Repositories; -using Bit.Core.Context; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Microsoft.AspNetCore.Mvc; -using NSubstitute; -using Xunit; - -namespace Bit.Api.Test.AdminConsole.Public.Controllers; - -[ControllerCustomize(typeof(PoliciesController))] -[SutProviderCustomize] -public class PoliciesControllerTests -{ - [Theory] - [BitAutoData] - [BitAutoData(PolicyType.SendOptions)] - public async Task Put_NewPolicy_AppliesCorrectType(PolicyType type, Organization organization, PolicyUpdateRequestModel model, SutProvider sutProvider) - { - sutProvider.GetDependency().OrganizationId.Returns(organization.Id); - sutProvider.GetDependency().GetByOrganizationIdTypeAsync(organization.Id, type).Returns((Policy)null); - - var response = await sutProvider.Sut.Put(type, model) as JsonResult; - var responseValue = response.Value as PolicyResponseModel; - - Assert.Equal(type, responseValue.Type); - } -} diff --git a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs index 8dbe5331317c..700df88d5482 100644 --- a/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs +++ b/test/Core.Test/AdminConsole/OrganizationFeatures/OrganizationDomains/VerifyOrganizationDomainCommandTests.cs @@ -1,7 +1,8 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Enums; +using Bit.Core.AdminConsole.Models.Data; using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains; -using Bit.Core.AdminConsole.Services; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.Context; using Bit.Core.Entities; using Bit.Core.Enums; @@ -182,9 +183,13 @@ public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithA _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .Received(1) - .SaveAsync(Arg.Is(x => x.Type == PolicyType.SingleOrg && x.OrganizationId == domain.OrganizationId && x.Enabled), userId); + .SaveAsync(Arg.Is(x => x.Type == PolicyType.SingleOrg && + x.OrganizationId == domain.OrganizationId && + x.Enabled && + x.PerformedBy is StandardUser && + x.PerformedBy.UserId == userId)); } [Theory, BitAutoData] @@ -208,9 +213,9 @@ public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithA _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() - .SaveAsync(Arg.Any(), null); + .SaveAsync(Arg.Any()); } [Theory, BitAutoData] @@ -234,10 +239,9 @@ public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithA _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() - .SaveAsync(Arg.Any(), null); - + .SaveAsync(Arg.Any()); } [Theory, BitAutoData] @@ -261,8 +265,8 @@ public async Task UserVerifyOrganizationDomainAsync_GivenOrganizationDomainWithA _ = await sutProvider.Sut.UserVerifyOrganizationDomainAsync(domain); - await sutProvider.GetDependency() + await sutProvider.GetDependency() .DidNotReceive() - .SaveAsync(Arg.Any(), null); + .SaveAsync(Arg.Any()); } } diff --git a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs index 68f36e37ce83..62ab584c4bc6 100644 --- a/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs +++ b/test/Core.Test/AdminConsole/Services/PolicyServiceTests.cs @@ -1,25 +1,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; -using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationDomains.Interfaces; -using Bit.Core.AdminConsole.OrganizationFeatures.OrganizationUsers.Interfaces; -using Bit.Core.AdminConsole.Repositories; using Bit.Core.AdminConsole.Services.Implementations; -using Bit.Core.Auth.Entities; -using Bit.Core.Auth.Enums; -using Bit.Core.Auth.Models.Data; -using Bit.Core.Auth.Repositories; -using Bit.Core.Auth.UserFeatures.TwoFactorAuth.Interfaces; using Bit.Core.Enums; -using Bit.Core.Exceptions; using Bit.Core.Models.Data.Organizations.OrganizationUsers; using Bit.Core.Repositories; -using Bit.Core.Services; using Bit.Test.Common.AutoFixture; using Bit.Test.Common.AutoFixture.Attributes; using NSubstitute; using Xunit; -using AdminConsoleFixtures = Bit.Core.Test.AdminConsole.AutoFixture; using GlobalSettings = Bit.Core.Settings.GlobalSettings; namespace Bit.Core.Test.AdminConsole.Services; @@ -27,667 +15,6 @@ namespace Bit.Core.Test.AdminConsole.Services; [SutProviderCustomize] public class PolicyServiceTests { - [Theory, BitAutoData] - public async Task SaveAsync_OrganizationDoesNotExist_ThrowsBadRequest( - [AdminConsoleFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) - { - SetupOrg(sutProvider, policy.OrganizationId, null); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Organization not found", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_OrganizationCannotUsePolicies_ThrowsBadRequest( - [AdminConsoleFixtures.Policy(PolicyType.DisableSend)] Policy policy, SutProvider sutProvider) - { - var orgId = Guid.NewGuid(); - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - UsePolicies = false, - }); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("cannot use policies", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_SingleOrg_RequireSsoEnabled_ThrowsBadRequest( - [AdminConsoleFixtures.Policy(PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = false; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.RequireSso) - .Returns(Task.FromResult(new Policy { Enabled = true })); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Single Sign-On Authentication policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_SingleOrg_VaultTimeoutEnabled_ThrowsBadRequest([AdminConsoleFixtures.Policy(PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = false; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.MaximumVaultTimeout) - .Returns(new Policy { Enabled = true }); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Maximum Vault Timeout policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory] - [BitAutoData(PolicyType.SingleOrg)] - [BitAutoData(PolicyType.RequireSso)] - public async Task SaveAsync_PolicyRequiredByKeyConnector_DisablePolicy_ThrowsBadRequest( - PolicyType policyType, - Policy policy, - SutProvider sutProvider) - { - policy.Enabled = false; - policy.Type = policyType; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - var ssoConfig = new SsoConfig { Enabled = true }; - var data = new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.KeyConnector }; - ssoConfig.SetData(data); - - sutProvider.GetDependency() - .GetByOrganizationIdAsync(policy.OrganizationId) - .Returns(ssoConfig); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Key Connector is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_RequireSsoPolicy_NotEnabled_ThrowsBadRequestAsync( - [AdminConsoleFixtures.Policy(PolicyType.RequireSso)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = true; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.SingleOrg) - .Returns(Task.FromResult(new Policy { Enabled = false })); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_NewPolicy_Created( - [AdminConsoleFixtures.Policy(PolicyType.ResetPassword)] Policy policy, SutProvider sutProvider) - { - policy.Id = default; - policy.Data = null; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.SingleOrg) - .Returns(Task.FromResult(new Policy { Enabled = true })); - - var utcNow = DateTime.UtcNow; - - await sutProvider.Sut.SaveAsync(policy, Guid.NewGuid()); - - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); - - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); - - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, BitAutoData] - public async Task SaveAsync_VaultTimeoutPolicy_NotEnabled_ThrowsBadRequestAsync( - [AdminConsoleFixtures.Policy(PolicyType.MaximumVaultTimeout)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = true; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.SingleOrg) - .Returns(Task.FromResult(new Policy { Enabled = false })); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_ExistingPolicy_UpdateTwoFactor( - Organization organization, - [AdminConsoleFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, - SutProvider sutProvider) - { - // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - - organization.UsePolicies = true; - policy.OrganizationId = organization.Id; - - SetupOrg(sutProvider, organization.Id, organization); - - sutProvider.GetDependency() - .GetByIdAsync(policy.Id) - .Returns(new Policy - { - Id = policy.Id, - Type = PolicyType.TwoFactorAuthentication, - Enabled = false - }); - - var orgUserDetailUserInvited = new OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Invited, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "user1@test.com", - Name = "TEST", - UserId = Guid.NewGuid(), - HasMasterPassword = false - }; - var orgUserDetailUserAcceptedWith2FA = new OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Accepted, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "user2@test.com", - Name = "TEST", - UserId = Guid.NewGuid(), - HasMasterPassword = true - }; - var orgUserDetailUserAcceptedWithout2FA = new OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Accepted, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "user3@test.com", - Name = "TEST", - UserId = Guid.NewGuid(), - HasMasterPassword = true - }; - var orgUserDetailAdmin = new OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.Admin, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "admin@test.com", - Name = "ADMIN", - UserId = Guid.NewGuid(), - HasMasterPassword = false - }; - - sutProvider.GetDependency() - .GetManyDetailsByOrganizationAsync(policy.OrganizationId) - .Returns(new List - { - orgUserDetailUserInvited, - orgUserDetailUserAcceptedWith2FA, - orgUserDetailUserAcceptedWithout2FA, - orgUserDetailAdmin - }); - - sutProvider.GetDependency() - .TwoFactorIsEnabledAsync(Arg.Any>()) - .Returns(new List<(OrganizationUserUserDetails user, bool hasTwoFactor)>() - { - (orgUserDetailUserInvited, false), - (orgUserDetailUserAcceptedWith2FA, true), - (orgUserDetailUserAcceptedWithout2FA, false), - (orgUserDetailAdmin, false), - }); - - var removeOrganizationUserCommand = sutProvider.GetDependency(); - - var utcNow = DateTime.UtcNow; - - var savingUserId = Guid.NewGuid(); - - await sutProvider.Sut.SaveAsync(policy, savingUserId); - - await removeOrganizationUserCommand.Received() - .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserAcceptedWithout2FA.Id, savingUserId); - await sutProvider.GetDependency().Received() - .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserAcceptedWithout2FA.Email); - - await removeOrganizationUserCommand.DidNotReceive() - .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserInvited.Id, savingUserId); - await sutProvider.GetDependency().DidNotReceive() - .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserInvited.Email); - await removeOrganizationUserCommand.DidNotReceive() - .RemoveUserAsync(policy.OrganizationId, orgUserDetailUserAcceptedWith2FA.Id, savingUserId); - await sutProvider.GetDependency().DidNotReceive() - .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailUserAcceptedWith2FA.Email); - await removeOrganizationUserCommand.DidNotReceive() - .RemoveUserAsync(policy.OrganizationId, orgUserDetailAdmin.Id, savingUserId); - await sutProvider.GetDependency().DidNotReceive() - .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(organization.DisplayName(), orgUserDetailAdmin.Email); - - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); - - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); - - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory, BitAutoData] - public async Task SaveAsync_EnableTwoFactor_WithoutMasterPasswordOr2FA_ThrowsBadRequest( - Organization organization, - [AdminConsoleFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, - SutProvider sutProvider) - { - organization.UsePolicies = true; - policy.OrganizationId = organization.Id; - - SetupOrg(sutProvider, organization.Id, organization); - - var orgUserDetailUserWith2FAAndMP = new OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "user1@test.com", - Name = "TEST", - UserId = Guid.NewGuid(), - HasMasterPassword = true - }; - var orgUserDetailUserWith2FANoMP = new OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "user2@test.com", - Name = "TEST", - UserId = Guid.NewGuid(), - HasMasterPassword = false - }; - var orgUserDetailUserWithout2FA = new OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "user3@test.com", - Name = "TEST", - UserId = Guid.NewGuid(), - HasMasterPassword = false - }; - var orgUserDetailAdmin = new OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Confirmed, - Type = OrganizationUserType.Admin, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "admin@test.com", - Name = "ADMIN", - UserId = Guid.NewGuid(), - HasMasterPassword = false - }; - - sutProvider.GetDependency() - .GetManyDetailsByOrganizationAsync(policy.OrganizationId) - .Returns(new List - { - orgUserDetailUserWith2FAAndMP, - orgUserDetailUserWith2FANoMP, - orgUserDetailUserWithout2FA, - orgUserDetailAdmin - }); - - sutProvider.GetDependency() - .TwoFactorIsEnabledAsync(Arg.Is>(ids => - ids.Contains(orgUserDetailUserWith2FANoMP.UserId.Value) - && ids.Contains(orgUserDetailUserWithout2FA.UserId.Value) - && ids.Contains(orgUserDetailAdmin.UserId.Value))) - .Returns(new List<(Guid userId, bool hasTwoFactor)>() - { - (orgUserDetailUserWith2FANoMP.UserId.Value, true), - (orgUserDetailUserWithout2FA.UserId.Value, false), - (orgUserDetailAdmin.UserId.Value, false), - }); - - var removeOrganizationUserCommand = sutProvider.GetDependency(); - - var savingUserId = Guid.NewGuid(); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, savingUserId)); - - Assert.Contains("Policy could not be enabled. Non-compliant members will lose access to their accounts. Identify members without two-step login from the policies column in the members page.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await removeOrganizationUserCommand.DidNotReceiveWithAnyArgs() - .RemoveUserAsync(organizationId: default, organizationUserId: default, deletingUserId: default); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(default, default); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default); - - await sutProvider.GetDependency().DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_ExistingPolicy_UpdateSingleOrg( - [AdminConsoleFixtures.Policy(PolicyType.TwoFactorAuthentication)] Policy policy, SutProvider sutProvider) - { - // If the policy that this is updating isn't enabled then do some work now that the current one is enabled - - var org = new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - Name = "TEST", - }; - - SetupOrg(sutProvider, policy.OrganizationId, org); - - sutProvider.GetDependency() - .GetByIdAsync(policy.Id) - .Returns(new Policy - { - Id = policy.Id, - Type = PolicyType.SingleOrg, - Enabled = false, - }); - - var orgUserDetail = new Core.Models.Data.Organizations.OrganizationUsers.OrganizationUserUserDetails - { - Id = Guid.NewGuid(), - Status = OrganizationUserStatusType.Accepted, - Type = OrganizationUserType.User, - // Needs to be different from what is passed in as the savingUserId to Sut.SaveAsync - Email = "test@bitwarden.com", - Name = "TEST", - UserId = Guid.NewGuid(), - HasMasterPassword = true - }; - - sutProvider.GetDependency() - .GetManyDetailsByOrganizationAsync(policy.OrganizationId) - .Returns(new List - { - orgUserDetail, - }); - - sutProvider.GetDependency() - .TwoFactorIsEnabledAsync(Arg.Is>(ids => ids.Contains(orgUserDetail.UserId.Value))) - .Returns(new List<(Guid userId, bool hasTwoFactor)>() - { - (orgUserDetail.UserId.Value, false), - }); - - var utcNow = DateTime.UtcNow; - - var savingUserId = Guid.NewGuid(); - - await sutProvider.Sut.SaveAsync(policy, savingUserId); - - await sutProvider.GetDependency().Received() - .LogPolicyEventAsync(policy, EventType.Policy_Updated); - - await sutProvider.GetDependency().Received() - .UpsertAsync(policy); - - Assert.True(policy.CreationDate - utcNow < TimeSpan.FromSeconds(1)); - Assert.True(policy.RevisionDate - utcNow < TimeSpan.FromSeconds(1)); - } - - [Theory] - [BitAutoData(true, false)] - [BitAutoData(false, true)] - [BitAutoData(false, false)] - public async Task SaveAsync_ResetPasswordPolicyRequiredByTrustedDeviceEncryption_DisablePolicyOrDisableAutomaticEnrollment_ThrowsBadRequest( - bool policyEnabled, - bool autoEnrollEnabled, - [AdminConsoleFixtures.Policy(PolicyType.ResetPassword)] Policy policy, - SutProvider sutProvider) - { - policy.Enabled = policyEnabled; - policy.SetDataModel(new ResetPasswordDataModel - { - AutoEnrollEnabled = autoEnrollEnabled - }); - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - var ssoConfig = new SsoConfig { Enabled = true }; - ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption }); - - sutProvider.GetDependency() - .GetByOrganizationIdAsync(policy.OrganizationId) - .Returns(ssoConfig); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Trusted device encryption is on and requires this policy.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_RequireSsoPolicyRequiredByTrustedDeviceEncryption_DisablePolicy_ThrowsBadRequest( - [AdminConsoleFixtures.Policy(PolicyType.RequireSso)] Policy policy, - SutProvider sutProvider) - { - policy.Enabled = false; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - var ssoConfig = new SsoConfig { Enabled = true }; - ssoConfig.SetData(new SsoConfigurationData { MemberDecryptionType = MemberDecryptionType.TrustedDeviceEncryption }); - - sutProvider.GetDependency() - .GetByOrganizationIdAsync(policy.OrganizationId) - .Returns(ssoConfig); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Trusted device encryption is on and requires this policy.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - [Theory, BitAutoData] - public async Task SaveAsync_PolicyRequiredForAccountRecovery_NotEnabled_ThrowsBadRequestAsync( - [AdminConsoleFixtures.Policy(PolicyType.ResetPassword)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = true; - policy.SetDataModel(new ResetPasswordDataModel()); - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.SingleOrg) - .Returns(Task.FromResult(new Policy { Enabled = false })); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Single Organization policy not enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .LogPolicyEventAsync(default, default, default); - } - - - [Theory, BitAutoData] - public async Task SaveAsync_SingleOrg_AccountRecoveryEnabled_ThrowsBadRequest( - [AdminConsoleFixtures.Policy(PolicyType.SingleOrg)] Policy policy, SutProvider sutProvider) - { - policy.Enabled = false; - - SetupOrg(sutProvider, policy.OrganizationId, new Organization - { - Id = policy.OrganizationId, - UsePolicies = true, - }); - - sutProvider.GetDependency() - .GetByOrganizationIdTypeAsync(policy.OrganizationId, PolicyType.ResetPassword) - .Returns(new Policy { Enabled = true }); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, - Guid.NewGuid())); - - Assert.Contains("Account recovery policy is enabled.", badRequestException.Message, StringComparison.OrdinalIgnoreCase); - - await sutProvider.GetDependency() - .DidNotReceiveWithAnyArgs() - .UpsertAsync(default); - } - [Theory, BitAutoData] public async Task GetPoliciesApplicableToUserAsync_WithRequireSsoTypeFilter_WithDefaultOrganizationUserStatusFilter_ReturnsNoPolicies(Guid userId, SutProvider sutProvider) { @@ -816,32 +143,4 @@ private static void SetupUserPolicies(Guid userId, SutProvider su new() { OrganizationId = Guid.NewGuid(), PolicyType = PolicyType.DisableSend, PolicyEnabled = true, OrganizationUserType = OrganizationUserType.User, OrganizationUserStatus = OrganizationUserStatusType.Invited, IsProvider = true } }); } - - - [Theory, BitAutoData] - public async Task SaveAsync_GivenOrganizationUsingPoliciesAndHasVerifiedDomains_WhenSingleOrgPolicyIsDisabled_ThenAnErrorShouldBeThrownOrganizationHasVerifiedDomains( - [AdminConsoleFixtures.Policy(PolicyType.SingleOrg)] Policy policy, Organization org, SutProvider sutProvider) - { - org.Id = policy.OrganizationId; - org.UsePolicies = true; - - policy.Enabled = false; - - sutProvider.GetDependency() - .IsEnabled(FeatureFlagKeys.AccountDeprovisioning) - .Returns(true); - - sutProvider.GetDependency() - .GetByIdAsync(policy.OrganizationId) - .Returns(org); - - sutProvider.GetDependency() - .HasVerifiedDomainsAsync(org.Id) - .Returns(true); - - var badRequestException = await Assert.ThrowsAsync( - () => sutProvider.Sut.SaveAsync(policy, null)); - - Assert.Equal("The Single organization policy is required for organizations that have enabled domain verification.", badRequestException.Message); - } } diff --git a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs index e397c838c68f..7beb772b9582 100644 --- a/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs +++ b/test/Core.Test/Auth/Services/SsoConfigServiceTests.cs @@ -1,8 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Enums; using Bit.Core.AdminConsole.Models.Data.Organizations.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies; +using Bit.Core.AdminConsole.OrganizationFeatures.Policies.Models; using Bit.Core.AdminConsole.Repositories; -using Bit.Core.AdminConsole.Services; using Bit.Core.Auth.Entities; using Bit.Core.Auth.Enums; using Bit.Core.Auth.Models.Data; @@ -338,16 +339,26 @@ public async Task SaveAsync_Tde_Enable_Required_Policies(SutProvider().Received(1) + await sutProvider.GetDependency().Received(1) .SaveAsync( - Arg.Is(t => t.Type == PolicyType.SingleOrg), - null + Arg.Is(t => t.Type == PolicyType.SingleOrg && + t.OrganizationId == organization.Id && + t.Enabled) ); - await sutProvider.GetDependency().Received(1) + await sutProvider.GetDependency().Received(1) .SaveAsync( - Arg.Is(t => t.Type == PolicyType.ResetPassword && t.GetDataModel().AutoEnrollEnabled), - null + Arg.Is(t => t.Type == PolicyType.ResetPassword && + t.GetDataModel().AutoEnrollEnabled && + t.OrganizationId == organization.Id && + t.Enabled) + ); + + await sutProvider.GetDependency().Received(1) + .SaveAsync( + Arg.Is(t => t.Type == PolicyType.RequireSso && + t.OrganizationId == organization.Id && + t.Enabled) ); await sutProvider.GetDependency().ReceivedWithAnyArgs() From 94fdfa40e8af9c9b788aafe2cf89eacc2913eeea Mon Sep 17 00:00:00 2001 From: Jonas Hendrickx Date: Wed, 4 Dec 2024 11:45:11 +0100 Subject: [PATCH 23/28] [PM-13999] Show estimated tax for taxable countries (#5077) --- .../Billing/TaxServiceTests.cs | 151 +++ .../Controllers/AccountsBillingController.cs | 15 + .../Billing/Controllers/InvoicesController.cs | 42 + .../Controllers/ProviderBillingController.cs | 1 + .../Billing/Controllers/StripeController.cs | 14 +- .../Requests/TaxInformationRequestBody.cs | 2 + .../Billing/Extensions/CurrencyExtensions.cs | 33 + .../Extensions/ServiceCollectionExtensions.cs | 1 + .../PreviewIndividualInvoiceRequestModel.cs | 18 + .../PreviewOrganizationInvoiceRequestModel.cs | 37 + .../Requests/TaxInformationRequestModel.cs | 14 + .../Responses/PreviewInvoiceResponseModel.cs | 7 + src/Core/Billing/Models/PreviewInvoiceInfo.cs | 7 + .../Billing/Models/Sales/OrganizationSale.cs | 1 + src/Core/Billing/Models/TaxIdType.cs | 22 + src/Core/Billing/Models/TaxInformation.cs | 160 +--- src/Core/Billing/Services/ITaxService.cs | 22 + .../OrganizationBillingService.cs | 35 +- .../PremiumUserBillingService.cs | 37 +- .../Implementations/SubscriberService.cs | 49 +- src/Core/Billing/Services/TaxService.cs | 901 ++++++++++++++++++ src/Core/Billing/Utilities.cs | 1 + src/Core/Models/Business/TaxInfo.cs | 208 +--- src/Core/Services/IPaymentService.cs | 6 + src/Core/Services/IStripeAdapter.cs | 2 + .../Services/Implementations/StripeAdapter.cs | 12 + .../Implementations/StripePaymentService.cs | 387 +++++++- .../Services/SubscriberServiceTests.cs | 3 +- .../Core.Test/Models/Business/TaxInfoTests.cs | 114 --- .../Services/StripePaymentServiceTests.cs | 18 +- 30 files changed, 1791 insertions(+), 529 deletions(-) create mode 100644 bitwarden_license/test/Commercial.Core.Test/Billing/TaxServiceTests.cs create mode 100644 src/Api/Billing/Controllers/InvoicesController.cs create mode 100644 src/Core/Billing/Extensions/CurrencyExtensions.cs create mode 100644 src/Core/Billing/Models/Api/Requests/Accounts/PreviewIndividualInvoiceRequestModel.cs create mode 100644 src/Core/Billing/Models/Api/Requests/Organizations/PreviewOrganizationInvoiceRequestModel.cs create mode 100644 src/Core/Billing/Models/Api/Requests/TaxInformationRequestModel.cs create mode 100644 src/Core/Billing/Models/Api/Responses/PreviewInvoiceResponseModel.cs create mode 100644 src/Core/Billing/Models/PreviewInvoiceInfo.cs create mode 100644 src/Core/Billing/Models/TaxIdType.cs create mode 100644 src/Core/Billing/Services/ITaxService.cs create mode 100644 src/Core/Billing/Services/TaxService.cs delete mode 100644 test/Core.Test/Models/Business/TaxInfoTests.cs diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/TaxServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/TaxServiceTests.cs new file mode 100644 index 000000000000..3995fb9de6b1 --- /dev/null +++ b/bitwarden_license/test/Commercial.Core.Test/Billing/TaxServiceTests.cs @@ -0,0 +1,151 @@ +using Bit.Core.Billing.Services; +using Bit.Test.Common.AutoFixture; +using Bit.Test.Common.AutoFixture.Attributes; +using Xunit; + +namespace Bit.Commercial.Core.Test.Billing; + +[SutProviderCustomize] +public class TaxServiceTests +{ + [Theory] + [BitAutoData("AD", "A-123456-Z", "ad_nrt")] + [BitAutoData("AD", "A123456Z", "ad_nrt")] + [BitAutoData("AR", "20-12345678-9", "ar_cuit")] + [BitAutoData("AR", "20123456789", "ar_cuit")] + [BitAutoData("AU", "01259983598", "au_abn")] + [BitAutoData("AU", "123456789123", "au_arn")] + [BitAutoData("AT", "ATU12345678", "eu_vat")] + [BitAutoData("BH", "123456789012345", "bh_vat")] + [BitAutoData("BY", "123456789", "by_tin")] + [BitAutoData("BE", "BE0123456789", "eu_vat")] + [BitAutoData("BO", "123456789", "bo_tin")] + [BitAutoData("BR", "01.234.456/5432-10", "br_cnpj")] + [BitAutoData("BR", "01234456543210", "br_cnpj")] + [BitAutoData("BR", "123.456.789-87", "br_cpf")] + [BitAutoData("BR", "12345678987", "br_cpf")] + [BitAutoData("BG", "123456789", "bg_uic")] + [BitAutoData("BG", "BG012100705", "eu_vat")] + [BitAutoData("CA", "100728494", "ca_bn")] + [BitAutoData("CA", "123456789RT0001", "ca_gst_hst")] + [BitAutoData("CA", "PST-1234-1234", "ca_pst_bc")] + [BitAutoData("CA", "123456-7", "ca_pst_mb")] + [BitAutoData("CA", "1234567", "ca_pst_sk")] + [BitAutoData("CA", "1234567890TQ1234", "ca_qst")] + [BitAutoData("CL", "11.121.326-1", "cl_tin")] + [BitAutoData("CL", "11121326-1", "cl_tin")] + [BitAutoData("CL", "23.121.326-K", "cl_tin")] + [BitAutoData("CL", "43651326-K", "cl_tin")] + [BitAutoData("CN", "123456789012345678", "cn_tin")] + [BitAutoData("CN", "123456789012345", "cn_tin")] + [BitAutoData("CO", "123.456.789-0", "co_nit")] + [BitAutoData("CO", "1234567890", "co_nit")] + [BitAutoData("CR", "1-234-567890", "cr_tin")] + [BitAutoData("CR", "1234567890", "cr_tin")] + [BitAutoData("HR", "HR12345678912", "eu_vat")] + [BitAutoData("HR", "12345678901", "hr_oib")] + [BitAutoData("CY", "CY12345678X", "eu_vat")] + [BitAutoData("CZ", "CZ12345678", "eu_vat")] + [BitAutoData("DK", "DK12345678", "eu_vat")] + [BitAutoData("DO", "123-4567890-1", "do_rcn")] + [BitAutoData("DO", "12345678901", "do_rcn")] + [BitAutoData("EC", "1234567890001", "ec_ruc")] + [BitAutoData("EG", "123456789", "eg_tin")] + [BitAutoData("SV", "1234-567890-123-4", "sv_nit")] + [BitAutoData("SV", "12345678901234", "sv_nit")] + [BitAutoData("EE", "EE123456789", "eu_vat")] + [BitAutoData("EU", "EU123456789", "eu_oss_vat")] + [BitAutoData("FI", "FI12345678", "eu_vat")] + [BitAutoData("FR", "FR12345678901", "eu_vat")] + [BitAutoData("GE", "123456789", "ge_vat")] + [BitAutoData("DE", "1234567890", "de_stn")] + [BitAutoData("DE", "DE123456789", "eu_vat")] + [BitAutoData("GR", "EL123456789", "eu_vat")] + [BitAutoData("HK", "12345678", "hk_br")] + [BitAutoData("HU", "HU12345678", "eu_vat")] + [BitAutoData("HU", "12345678-1-23", "hu_tin")] + [BitAutoData("HU", "12345678123", "hu_tin")] + [BitAutoData("IS", "123456", "is_vat")] + [BitAutoData("IN", "12ABCDE1234F1Z5", "in_gst")] + [BitAutoData("IN", "12ABCDE3456FGZH", "in_gst")] + [BitAutoData("ID", "012.345.678.9-012.345", "id_npwp")] + [BitAutoData("ID", "0123456789012345", "id_npwp")] + [BitAutoData("IE", "IE1234567A", "eu_vat")] + [BitAutoData("IE", "IE1234567AB", "eu_vat")] + [BitAutoData("IL", "000012345", "il_vat")] + [BitAutoData("IL", "123456789", "il_vat")] + [BitAutoData("IT", "IT12345678901", "eu_vat")] + [BitAutoData("JP", "1234567890123", "jp_cn")] + [BitAutoData("JP", "12345", "jp_rn")] + [BitAutoData("KZ", "123456789012", "kz_bin")] + [BitAutoData("KE", "P000111111A", "ke_pin")] + [BitAutoData("LV", "LV12345678912", "eu_vat")] + [BitAutoData("LI", "CHE123456789", "li_uid")] + [BitAutoData("LI", "12345", "li_vat")] + [BitAutoData("LT", "LT123456789123", "eu_vat")] + [BitAutoData("LU", "LU12345678", "eu_vat")] + [BitAutoData("MY", "12345678", "my_frp")] + [BitAutoData("MY", "C 1234567890", "my_itn")] + [BitAutoData("MY", "C1234567890", "my_itn")] + [BitAutoData("MY", "A12-3456-78912345", "my_sst")] + [BitAutoData("MY", "A12345678912345", "my_sst")] + [BitAutoData("MT", "MT12345678", "eu_vat")] + [BitAutoData("MX", "ABC010203AB9", "mx_rfc")] + [BitAutoData("MD", "1003600", "md_vat")] + [BitAutoData("MA", "12345678", "ma_vat")] + [BitAutoData("NL", "NL123456789B12", "eu_vat")] + [BitAutoData("NZ", "123456789", "nz_gst")] + [BitAutoData("NG", "12345678-0001", "ng_tin")] + [BitAutoData("NO", "123456789MVA", "no_vat")] + [BitAutoData("NO", "1234567", "no_voec")] + [BitAutoData("OM", "OM1234567890", "om_vat")] + [BitAutoData("PE", "12345678901", "pe_ruc")] + [BitAutoData("PH", "123456789012", "ph_tin")] + [BitAutoData("PL", "PL1234567890", "eu_vat")] + [BitAutoData("PT", "PT123456789", "eu_vat")] + [BitAutoData("RO", "RO1234567891", "eu_vat")] + [BitAutoData("RO", "1234567890123", "ro_tin")] + [BitAutoData("RU", "1234567891", "ru_inn")] + [BitAutoData("RU", "123456789", "ru_kpp")] + [BitAutoData("SA", "123456789012345", "sa_vat")] + [BitAutoData("RS", "123456789", "rs_pib")] + [BitAutoData("SG", "M12345678X", "sg_gst")] + [BitAutoData("SG", "123456789F", "sg_uen")] + [BitAutoData("SK", "SK1234567891", "eu_vat")] + [BitAutoData("SI", "SI12345678", "eu_vat")] + [BitAutoData("SI", "12345678", "si_tin")] + [BitAutoData("ZA", "4123456789", "za_vat")] + [BitAutoData("KR", "123-45-67890", "kr_brn")] + [BitAutoData("KR", "1234567890", "kr_brn")] + [BitAutoData("ES", "A12345678", "es_cif")] + [BitAutoData("ES", "ESX1234567X", "eu_vat")] + [BitAutoData("SE", "SE123456789012", "eu_vat")] + [BitAutoData("CH", "CHE-123.456.789 HR", "ch_uid")] + [BitAutoData("CH", "CHE123456789HR", "ch_uid")] + [BitAutoData("CH", "CHE-123.456.789 MWST", "ch_vat")] + [BitAutoData("CH", "CHE123456789MWST", "ch_vat")] + [BitAutoData("TW", "12345678", "tw_vat")] + [BitAutoData("TH", "1234567890123", "th_vat")] + [BitAutoData("TR", "0123456789", "tr_tin")] + [BitAutoData("UA", "123456789", "ua_vat")] + [BitAutoData("AE", "123456789012345", "ae_trn")] + [BitAutoData("GB", "XI123456789", "eu_vat")] + [BitAutoData("GB", "GB123456789", "gb_vat")] + [BitAutoData("US", "12-3456789", "us_ein")] + [BitAutoData("UY", "123456789012", "uy_ruc")] + [BitAutoData("UZ", "123456789", "uz_tin")] + [BitAutoData("UZ", "123456789012", "uz_vat")] + [BitAutoData("VE", "A-12345678-9", "ve_rif")] + [BitAutoData("VE", "A123456789", "ve_rif")] + [BitAutoData("VN", "1234567890", "vn_tin")] + public void GetStripeTaxCode_WithValidCountryAndTaxId_ReturnsExpectedTaxIdType( + string country, + string taxId, + string expected, + SutProvider sutProvider) + { + var result = sutProvider.Sut.GetStripeTaxCode(country, taxId); + + Assert.Equal(expected, result); + } +} diff --git a/src/Api/Billing/Controllers/AccountsBillingController.cs b/src/Api/Billing/Controllers/AccountsBillingController.cs index 574ac3e65e26..fcb89226e75d 100644 --- a/src/Api/Billing/Controllers/AccountsBillingController.cs +++ b/src/Api/Billing/Controllers/AccountsBillingController.cs @@ -1,5 +1,6 @@ #nullable enable using Bit.Api.Billing.Models.Responses; +using Bit.Core.Billing.Models.Api.Requests.Accounts; using Bit.Core.Billing.Services; using Bit.Core.Services; using Bit.Core.Utilities; @@ -77,4 +78,18 @@ public async Task GetTransactionsAsync([FromQuery] DateTime? startAfter return TypedResults.Ok(transactions); } + + [HttpPost("preview-invoice")] + public async Task PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model) + { + var user = await userService.GetUserByPrincipalAsync(User); + if (user == null) + { + throw new UnauthorizedAccessException(); + } + + var invoice = await paymentService.PreviewInvoiceAsync(model, user.GatewayCustomerId, user.GatewaySubscriptionId); + + return TypedResults.Ok(invoice); + } } diff --git a/src/Api/Billing/Controllers/InvoicesController.cs b/src/Api/Billing/Controllers/InvoicesController.cs new file mode 100644 index 000000000000..686d9b964387 --- /dev/null +++ b/src/Api/Billing/Controllers/InvoicesController.cs @@ -0,0 +1,42 @@ +using Bit.Core.AdminConsole.Entities; +using Bit.Core.Billing.Models.Api.Requests.Organizations; +using Bit.Core.Context; +using Bit.Core.Repositories; +using Bit.Core.Services; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; + +namespace Bit.Api.Billing.Controllers; + +[Route("invoices")] +[Authorize("Application")] +public class InvoicesController : BaseBillingController +{ + [HttpPost("preview-organization")] + public async Task PreviewInvoiceAsync( + [FromBody] PreviewOrganizationInvoiceRequestBody model, + [FromServices] ICurrentContext currentContext, + [FromServices] IOrganizationRepository organizationRepository, + [FromServices] IPaymentService paymentService) + { + Organization organization = null; + if (model.OrganizationId != default) + { + if (!await currentContext.EditPaymentMethods(model.OrganizationId)) + { + return Error.Unauthorized(); + } + + organization = await organizationRepository.GetByIdAsync(model.OrganizationId); + if (organization == null) + { + return Error.NotFound(); + } + } + + var invoice = await paymentService.PreviewInvoiceAsync(model, organization?.GatewayCustomerId, + organization?.GatewaySubscriptionId); + + return TypedResults.Ok(invoice); + } +} diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index f7ddf0853e20..c5de63c69b9f 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -119,6 +119,7 @@ public async Task UpdateTaxInformationAsync( requestBody.Country, requestBody.PostalCode, requestBody.TaxId, + requestBody.TaxIdType, requestBody.Line1, requestBody.Line2, requestBody.City, diff --git a/src/Api/Billing/Controllers/StripeController.cs b/src/Api/Billing/Controllers/StripeController.cs index a4a974bb992c..f5e8253bfa49 100644 --- a/src/Api/Billing/Controllers/StripeController.cs +++ b/src/Api/Billing/Controllers/StripeController.cs @@ -1,4 +1,5 @@ -using Bit.Core.Services; +using Bit.Core.Billing.Services; +using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Mvc; @@ -46,4 +47,15 @@ public async Task> CreateSetupIntentForCardAsync() return TypedResults.Ok(setupIntent.ClientSecret); } + + [HttpGet] + [Route("~/tax/is-country-supported")] + public IResult IsCountrySupported( + [FromQuery] string country, + [FromServices] ITaxService taxService) + { + var isSupported = taxService.IsSupported(country); + + return TypedResults.Ok(isSupported); + } } diff --git a/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs b/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs index c5c0fde00b93..32ba2effb29b 100644 --- a/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs +++ b/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs @@ -10,6 +10,7 @@ public class TaxInformationRequestBody [Required] public string PostalCode { get; set; } public string TaxId { get; set; } + public string TaxIdType { get; set; } public string Line1 { get; set; } public string Line2 { get; set; } public string City { get; set; } @@ -19,6 +20,7 @@ public class TaxInformationRequestBody Country, PostalCode, TaxId, + TaxIdType, Line1, Line2, City, diff --git a/src/Core/Billing/Extensions/CurrencyExtensions.cs b/src/Core/Billing/Extensions/CurrencyExtensions.cs new file mode 100644 index 000000000000..cde1a7bea846 --- /dev/null +++ b/src/Core/Billing/Extensions/CurrencyExtensions.cs @@ -0,0 +1,33 @@ +namespace Bit.Core.Billing.Extensions; + +public static class CurrencyExtensions +{ + /// + /// Converts a currency amount in major units to minor units. + /// + /// 123.99 USD returns 12399 in minor units. + public static long ToMinor(this decimal amount) + { + return Convert.ToInt64(amount * 100); + } + + /// + /// Converts a currency amount in minor units to major units. + /// + /// + /// 12399 in minor units returns 123.99 USD. + public static decimal? ToMajor(this long? amount) + { + return amount?.ToMajor(); + } + + /// + /// Converts a currency amount in minor units to major units. + /// + /// + /// 12399 in minor units returns 123.99 USD. + public static decimal ToMajor(this long amount) + { + return Convert.ToDecimal(amount) / 100; + } +} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index abfceac7366e..8f2803d920af 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -11,6 +11,7 @@ public static class ServiceCollectionExtensions { public static void AddBillingOperations(this IServiceCollection services) { + services.AddSingleton(); services.AddTransient(); services.AddTransient(); services.AddTransient(); diff --git a/src/Core/Billing/Models/Api/Requests/Accounts/PreviewIndividualInvoiceRequestModel.cs b/src/Core/Billing/Models/Api/Requests/Accounts/PreviewIndividualInvoiceRequestModel.cs new file mode 100644 index 000000000000..6dfb9894d50d --- /dev/null +++ b/src/Core/Billing/Models/Api/Requests/Accounts/PreviewIndividualInvoiceRequestModel.cs @@ -0,0 +1,18 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Core.Billing.Models.Api.Requests.Accounts; + +public class PreviewIndividualInvoiceRequestBody +{ + [Required] + public PasswordManagerRequestModel PasswordManager { get; set; } + + [Required] + public TaxInformationRequestModel TaxInformation { get; set; } +} + +public class PasswordManagerRequestModel +{ + [Range(0, int.MaxValue)] + public int AdditionalStorage { get; set; } +} diff --git a/src/Core/Billing/Models/Api/Requests/Organizations/PreviewOrganizationInvoiceRequestModel.cs b/src/Core/Billing/Models/Api/Requests/Organizations/PreviewOrganizationInvoiceRequestModel.cs new file mode 100644 index 000000000000..18d9c352d7c4 --- /dev/null +++ b/src/Core/Billing/Models/Api/Requests/Organizations/PreviewOrganizationInvoiceRequestModel.cs @@ -0,0 +1,37 @@ +using System.ComponentModel.DataAnnotations; +using Bit.Core.Billing.Enums; + +namespace Bit.Core.Billing.Models.Api.Requests.Organizations; + +public class PreviewOrganizationInvoiceRequestBody +{ + public Guid OrganizationId { get; set; } + + [Required] + public PasswordManagerRequestModel PasswordManager { get; set; } + + public SecretsManagerRequestModel SecretsManager { get; set; } + + [Required] + public TaxInformationRequestModel TaxInformation { get; set; } +} + +public class PasswordManagerRequestModel +{ + public PlanType Plan { get; set; } + + [Range(0, int.MaxValue)] + public int Seats { get; set; } + + [Range(0, int.MaxValue)] + public int AdditionalStorage { get; set; } +} + +public class SecretsManagerRequestModel +{ + [Range(0, int.MaxValue)] + public int Seats { get; set; } + + [Range(0, int.MaxValue)] + public int AdditionalMachineAccounts { get; set; } +} diff --git a/src/Core/Billing/Models/Api/Requests/TaxInformationRequestModel.cs b/src/Core/Billing/Models/Api/Requests/TaxInformationRequestModel.cs new file mode 100644 index 000000000000..9cb43645c6cd --- /dev/null +++ b/src/Core/Billing/Models/Api/Requests/TaxInformationRequestModel.cs @@ -0,0 +1,14 @@ +using System.ComponentModel.DataAnnotations; + +namespace Bit.Core.Billing.Models.Api.Requests; + +public class TaxInformationRequestModel +{ + [Length(2, 2), Required] + public string Country { get; set; } + + [Required] + public string PostalCode { get; set; } + + public string TaxId { get; set; } +} diff --git a/src/Core/Billing/Models/Api/Responses/PreviewInvoiceResponseModel.cs b/src/Core/Billing/Models/Api/Responses/PreviewInvoiceResponseModel.cs new file mode 100644 index 000000000000..fdde7dae1e48 --- /dev/null +++ b/src/Core/Billing/Models/Api/Responses/PreviewInvoiceResponseModel.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Billing.Models.Api.Responses; + +public record PreviewInvoiceResponseModel( + decimal EffectiveTaxRate, + decimal TaxableBaseAmount, + decimal TaxAmount, + decimal TotalAmount); diff --git a/src/Core/Billing/Models/PreviewInvoiceInfo.cs b/src/Core/Billing/Models/PreviewInvoiceInfo.cs new file mode 100644 index 000000000000..16a2019c202b --- /dev/null +++ b/src/Core/Billing/Models/PreviewInvoiceInfo.cs @@ -0,0 +1,7 @@ +namespace Bit.Core.Billing.Models; + +public record PreviewInvoiceInfo( + decimal EffectiveTaxRate, + decimal TaxableBaseAmount, + decimal TaxAmount, + decimal TotalAmount); diff --git a/src/Core/Billing/Models/Sales/OrganizationSale.cs b/src/Core/Billing/Models/Sales/OrganizationSale.cs index a19c278c68f1..43852bb320bd 100644 --- a/src/Core/Billing/Models/Sales/OrganizationSale.cs +++ b/src/Core/Billing/Models/Sales/OrganizationSale.cs @@ -65,6 +65,7 @@ private static CustomerSetup GetCustomerSetup(OrganizationSignup signup) signup.TaxInfo.BillingAddressCountry, signup.TaxInfo.BillingAddressPostalCode, signup.TaxInfo.TaxIdNumber, + signup.TaxInfo.TaxIdType, signup.TaxInfo.BillingAddressLine1, signup.TaxInfo.BillingAddressLine2, signup.TaxInfo.BillingAddressCity, diff --git a/src/Core/Billing/Models/TaxIdType.cs b/src/Core/Billing/Models/TaxIdType.cs new file mode 100644 index 000000000000..3fc246d68bbc --- /dev/null +++ b/src/Core/Billing/Models/TaxIdType.cs @@ -0,0 +1,22 @@ +using System.Text.RegularExpressions; + +namespace Bit.Core.Billing.Models; + +public class TaxIdType +{ + /// + /// ISO-3166-2 code for the country. + /// + public string Country { get; set; } + + /// + /// The identifier in Stripe for the tax ID type. + /// + public string Code { get; set; } + + public Regex ValidationExpression { get; set; } + + public string Description { get; set; } + + public string Example { get; set; } +} diff --git a/src/Core/Billing/Models/TaxInformation.cs b/src/Core/Billing/Models/TaxInformation.cs index 5403f9469022..23ed3e5faa64 100644 --- a/src/Core/Billing/Models/TaxInformation.cs +++ b/src/Core/Billing/Models/TaxInformation.cs @@ -1,5 +1,4 @@ using Bit.Core.Models.Business; -using Stripe; namespace Bit.Core.Billing.Models; @@ -7,6 +6,7 @@ public record TaxInformation( string Country, string PostalCode, string TaxId, + string TaxIdType, string Line1, string Line2, string City, @@ -16,165 +16,9 @@ public record TaxInformation( taxInfo.BillingAddressCountry, taxInfo.BillingAddressPostalCode, taxInfo.TaxIdNumber, + taxInfo.TaxIdType, taxInfo.BillingAddressLine1, taxInfo.BillingAddressLine2, taxInfo.BillingAddressCity, taxInfo.BillingAddressState); - - public (AddressOptions, List) GetStripeOptions() - { - var address = new AddressOptions - { - Country = Country, - PostalCode = PostalCode, - Line1 = Line1, - Line2 = Line2, - City = City, - State = State - }; - - var customerTaxIdDataOptionsList = !string.IsNullOrEmpty(TaxId) - ? new List { new() { Type = GetTaxIdType(), Value = TaxId } } - : null; - - return (address, customerTaxIdDataOptionsList); - } - - public string GetTaxIdType() - { - if (string.IsNullOrEmpty(Country) || string.IsNullOrEmpty(TaxId)) - { - return null; - } - - switch (Country.ToUpper()) - { - case "AD": - return "ad_nrt"; - case "AE": - return "ae_trn"; - case "AR": - return "ar_cuit"; - case "AU": - return "au_abn"; - case "BO": - return "bo_tin"; - case "BR": - return "br_cnpj"; - case "CA": - // May break for those in Québec given the assumption of QST - if (State?.Contains("bec") ?? false) - { - return "ca_qst"; - } - return "ca_bn"; - case "CH": - return "ch_vat"; - case "CL": - return "cl_tin"; - case "CN": - return "cn_tin"; - case "CO": - return "co_nit"; - case "CR": - return "cr_tin"; - case "DO": - return "do_rcn"; - case "EC": - return "ec_ruc"; - case "EG": - return "eg_tin"; - case "GE": - return "ge_vat"; - case "ID": - return "id_npwp"; - case "IL": - return "il_vat"; - case "IS": - return "is_vat"; - case "KE": - return "ke_pin"; - case "AT": - case "BE": - case "BG": - case "CY": - case "CZ": - case "DE": - case "DK": - case "EE": - case "ES": - case "FI": - case "FR": - case "GB": - case "GR": - case "HR": - case "HU": - case "IE": - case "IT": - case "LT": - case "LU": - case "LV": - case "MT": - case "NL": - case "PL": - case "PT": - case "RO": - case "SE": - case "SI": - case "SK": - return "eu_vat"; - case "HK": - return "hk_br"; - case "IN": - return "in_gst"; - case "JP": - return "jp_cn"; - case "KR": - return "kr_brn"; - case "LI": - return "li_uid"; - case "MX": - return "mx_rfc"; - case "MY": - return "my_sst"; - case "NO": - return "no_vat"; - case "NZ": - return "nz_gst"; - case "PE": - return "pe_ruc"; - case "PH": - return "ph_tin"; - case "RS": - return "rs_pib"; - case "RU": - return "ru_inn"; - case "SA": - return "sa_vat"; - case "SG": - return "sg_gst"; - case "SV": - return "sv_nit"; - case "TH": - return "th_vat"; - case "TR": - return "tr_tin"; - case "TW": - return "tw_vat"; - case "UA": - return "ua_vat"; - case "US": - return "us_ein"; - case "UY": - return "uy_ruc"; - case "VE": - return "ve_rif"; - case "VN": - return "vn_tin"; - case "ZA": - return "za_vat"; - default: - return null; - } - } } diff --git a/src/Core/Billing/Services/ITaxService.cs b/src/Core/Billing/Services/ITaxService.cs new file mode 100644 index 000000000000..beee113d17bd --- /dev/null +++ b/src/Core/Billing/Services/ITaxService.cs @@ -0,0 +1,22 @@ +namespace Bit.Core.Billing.Services; + +public interface ITaxService +{ + /// + /// Retrieves the Stripe tax code for a given country and tax ID. + /// + /// + /// + /// + /// Returns the Stripe tax code if the tax ID is valid for the country. + /// Returns null if the tax ID is invalid or the country is not supported. + /// + string GetStripeTaxCode(string country, string taxId); + + /// + /// Returns true or false whether charging or storing tax is supported for the given country. + /// + /// + /// + bool IsSupported(string country); +} diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index eadc5896258f..b186a99d93e7 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -28,7 +28,8 @@ public class OrganizationBillingService( IOrganizationRepository organizationRepository, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ISubscriberService subscriberService) : IOrganizationBillingService + ISubscriberService subscriberService, + ITaxService taxService) : IOrganizationBillingService { public async Task Finalize(OrganizationSale sale) { @@ -167,14 +168,38 @@ private async Task CreateCustomerAsync( throw new BillingException(); } - var (address, taxIdData) = customerSetup.TaxInformation.GetStripeOptions(); - - customerCreateOptions.Address = address; + customerCreateOptions.Address = new AddressOptions + { + Line1 = customerSetup.TaxInformation.Line1, + Line2 = customerSetup.TaxInformation.Line2, + City = customerSetup.TaxInformation.City, + PostalCode = customerSetup.TaxInformation.PostalCode, + State = customerSetup.TaxInformation.State, + Country = customerSetup.TaxInformation.Country, + }; customerCreateOptions.Tax = new CustomerTaxOptions { ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately }; - customerCreateOptions.TaxIdData = taxIdData; + + if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId)) + { + var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country, + customerSetup.TaxInformation.TaxId); + + if (taxIdType == null) + { + logger.LogWarning("Could not determine tax ID type for organization '{OrganizationID}' in country '{Country}' with tax ID '{TaxID}'.", + organization.Id, + customerSetup.TaxInformation.Country, + customerSetup.TaxInformation.TaxId); + } + + customerCreateOptions.TaxIdData = + [ + new() { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId } + ]; + } var (paymentMethodType, paymentMethodToken) = customerSetup.TokenizedPaymentSource; diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 92c81dae1cf5..5351888ad9d3 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -24,7 +24,8 @@ public class PremiumUserBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - IUserRepository userRepository) : IPremiumUserBillingService + IUserRepository userRepository, + ITaxService taxService) : IPremiumUserBillingService { public async Task Finalize(PremiumUserSale sale) { @@ -82,13 +83,19 @@ private async Task CreateCustomerAsync( throw new BillingException(); } - var (address, taxIdData) = customerSetup.TaxInformation.GetStripeOptions(); - var subscriberName = user.SubscriberName(); var customerCreateOptions = new CustomerCreateOptions { - Address = address, + Address = new AddressOptions + { + Line1 = customerSetup.TaxInformation.Line1, + Line2 = customerSetup.TaxInformation.Line2, + City = customerSetup.TaxInformation.City, + PostalCode = customerSetup.TaxInformation.PostalCode, + State = customerSetup.TaxInformation.State, + Country = customerSetup.TaxInformation.Country, + }, Description = user.Name, Email = user.Email, Expand = ["tax"], @@ -113,10 +120,28 @@ private async Task CreateCustomerAsync( Tax = new CustomerTaxOptions { ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately - }, - TaxIdData = taxIdData + } }; + if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId)) + { + var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country, + customerSetup.TaxInformation.TaxId); + + if (taxIdType == null) + { + logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", + customerSetup.TaxInformation.Country, + customerSetup.TaxInformation.TaxId); + throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError"); + } + + customerCreateOptions.TaxIdData = + [ + new() { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId } + ]; + } + var (paymentMethodType, paymentMethodToken) = customerSetup.TokenizedPaymentSource; var braintreeCustomerId = ""; diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 9b8f64be82d7..6125d1541909 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -23,7 +23,8 @@ public class SubscriberService( IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, - IStripeAdapter stripeAdapter) : ISubscriberService + IStripeAdapter stripeAdapter, + ITaxService taxService) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -618,16 +619,47 @@ public async Task UpdateTaxInformation( await stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); } - var taxIdType = taxInformation.GetTaxIdType(); + if (string.IsNullOrWhiteSpace(taxInformation.TaxId)) + { + return; + } - if (!string.IsNullOrWhiteSpace(taxInformation.TaxId) && - !string.IsNullOrWhiteSpace(taxIdType)) + var taxIdType = taxInformation.TaxIdType; + if (string.IsNullOrWhiteSpace(taxIdType)) { - await stripeAdapter.TaxIdCreateAsync(customer.Id, new TaxIdCreateOptions + taxIdType = taxService.GetStripeTaxCode(taxInformation.Country, + taxInformation.TaxId); + + if (taxIdType == null) { - Type = taxIdType, - Value = taxInformation.TaxId, - }); + logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", + taxInformation.Country, + taxInformation.TaxId); + throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError"); + } + } + + try + { + await stripeAdapter.TaxIdCreateAsync(customer.Id, + new TaxIdCreateOptions { Type = taxIdType, Value = taxInformation.TaxId }); + } + catch (StripeException e) + { + switch (e.StripeError.Code) + { + case StripeConstants.ErrorCodes.TaxIdInvalid: + logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", + taxInformation.TaxId, + taxInformation.Country); + throw new Exceptions.BadRequestException("billingInvalidTaxIdError"); + default: + logger.LogError(e, "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", + taxInformation.TaxId, + taxInformation.Country, + customer.Id); + throw new Exceptions.BadRequestException("billingTaxIdCreationError"); + } } } @@ -770,6 +802,7 @@ private static TaxInformation GetTaxInformation( customer.Address.Country, customer.Address.PostalCode, customer.TaxIds?.FirstOrDefault()?.Value, + customer.TaxIds?.FirstOrDefault()?.Type, customer.Address.Line1, customer.Address.Line2, customer.Address.City, diff --git a/src/Core/Billing/Services/TaxService.cs b/src/Core/Billing/Services/TaxService.cs new file mode 100644 index 000000000000..3066be92d1d8 --- /dev/null +++ b/src/Core/Billing/Services/TaxService.cs @@ -0,0 +1,901 @@ +using System.Text.RegularExpressions; +using Bit.Core.Billing.Models; + +namespace Bit.Core.Billing.Services; + +public class TaxService : ITaxService +{ + /// + /// Retrieves a list of supported tax ID types for customers. + /// + /// Compiled list from Stripe + private static readonly IEnumerable _taxIdTypes = + [ + new() + { + Country = "AD", + Code = "ad_nrt", + Description = "Andorran NRT number", + Example = "A-123456-Z", + ValidationExpression = new Regex("^([A-Z]{1})-?([0-9]{6})-?([A-Z]{1})$") + }, + new() + { + Country = "AR", + Code = "ar_cuit", + Description = "Argentinian tax ID number", + Example = "12-34567890-1", + ValidationExpression = new Regex("^([0-9]{2})-?([0-9]{8})-?([0-9]{1})$") + }, + new() + { + Country = "AU", + Code = "au_abn", + Description = "Australian Business Number (AU ABN)", + Example = "123456789012", + ValidationExpression = new Regex("^[0-9]{11}$") + }, + new() + { + Country = "AU", + Code = "au_arn", + Description = "Australian Taxation Office Reference Number", + Example = "123456789123", + ValidationExpression = new Regex("^[0-9]{12}$") + }, + new() + { + Country = "AT", + Code = "eu_vat", + Description = "European VAT number (Austria)", + Example = "ATU12345678", + ValidationExpression = new Regex("^ATU[0-9]{8}$") + }, + new() + { + Country = "BH", + Code = "bh_vat", + Description = "Bahraini VAT Number", + Example = "123456789012345", + ValidationExpression = new Regex("^[0-9]{15}$") + }, + new() + { + Country = "BY", + Code = "by_tin", + Description = "Belarus TIN Number", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "BE", + Code = "eu_vat", + Description = "European VAT number (Belgium)", + Example = "BE0123456789", + ValidationExpression = new Regex("^BE[0-9]{10}$") + }, + new() + { + Country = "BO", + Code = "bo_tin", + Description = "Bolivian tax ID", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "BR", + Code = "br_cnpj", + Description = "Brazilian CNPJ number", + Example = "01.234.456/5432-10", + ValidationExpression = new Regex("^[0-9]{2}.?[0-9]{3}.?[0-9]{3}/?[0-9]{4}-?[0-9]{2}$") + }, + new() + { + Country = "BR", + Code = "br_cpf", + Description = "Brazilian CPF number", + Example = "123.456.789-87", + ValidationExpression = new Regex("^[0-9]{3}.?[0-9]{3}.?[0-9]{3}-?[0-9]{2}$") + }, + new() + { + Country = "BG", + Code = "bg_uic", + Description = "Bulgaria Unified Identification Code", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "BG", + Code = "eu_vat", + Description = "European VAT number (Bulgaria)", + Example = "BG0123456789", + ValidationExpression = new Regex("^BG[0-9]{9,10}$") + }, + new() + { + Country = "CA", + Code = "ca_bn", + Description = "Canadian BN", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "CA", + Code = "ca_gst_hst", + Description = "Canadian GST/HST number", + Example = "123456789RT0002", + ValidationExpression = new Regex("^[0-9]{9}RT[0-9]{4}$") + }, + new() + { + Country = "CA", + Code = "ca_pst_bc", + Description = "Canadian PST number (British Columbia)", + Example = "PST-1234-5678", + ValidationExpression = new Regex("^PST-[0-9]{4}-[0-9]{4}$") + }, + new() + { + Country = "CA", + Code = "ca_pst_mb", + Description = "Canadian PST number (Manitoba)", + Example = "123456-7", + ValidationExpression = new Regex("^[0-9]{6}-[0-9]{1}$") + }, + new() + { + Country = "CA", + Code = "ca_pst_sk", + Description = "Canadian PST number (Saskatchewan)", + Example = "1234567", + ValidationExpression = new Regex("^[0-9]{7}$") + }, + new() + { + Country = "CA", + Code = "ca_qst", + Description = "Canadian QST number (Québec)", + Example = "1234567890TQ1234", + ValidationExpression = new Regex("^[0-9]{10}TQ[0-9]{4}$") + }, + new() + { + Country = "CL", + Code = "cl_tin", + Description = "Chilean TIN", + Example = "12.345.678-K", + ValidationExpression = new Regex("^[0-9]{2}.?[0-9]{3}.?[0-9]{3}-?[0-9A-Z]{1}$") + }, + new() + { + Country = "CN", + Code = "cn_tin", + Description = "Chinese tax ID", + Example = "123456789012345678", + ValidationExpression = new Regex("^[0-9]{15,18}$") + }, + new() + { + Country = "CO", + Code = "co_nit", + Description = "Colombian NIT number", + Example = "123.456.789-0", + ValidationExpression = new Regex("^[0-9]{3}.?[0-9]{3}.?[0-9]{3}-?[0-9]{1}$") + }, + new() + { + Country = "CR", + Code = "cr_tin", + Description = "Costa Rican tax ID", + Example = "1-234-567890", + ValidationExpression = new Regex("^[0-9]{1}-?[0-9]{3}-?[0-9]{6}$") + }, + new() + { + Country = "HR", + Code = "eu_vat", + Description = "European VAT number (Croatia)", + Example = "HR12345678912", + ValidationExpression = new Regex("^HR[0-9]{11}$") + }, + new() + { + Country = "HR", + Code = "hr_oib", + Description = "Croatian Personal Identification Number", + Example = "12345678901", + ValidationExpression = new Regex("^[0-9]{11}$") + }, + new() + { + Country = "CY", + Code = "eu_vat", + Description = "European VAT number (Cyprus)", + Example = "CY12345678X", + ValidationExpression = new Regex("^CY[0-9]{8}[A-Z]{1}$") + }, + new() + { + Country = "CZ", + Code = "eu_vat", + Description = "European VAT number (Czech Republic)", + Example = "CZ12345678", + ValidationExpression = new Regex("^CZ[0-9]{8,10}$") + }, + new() + { + Country = "DK", + Code = "eu_vat", + Description = "European VAT number (Denmark)", + Example = "DK12345678", + ValidationExpression = new Regex("^DK[0-9]{8}$") + }, + new() + { + Country = "DO", + Code = "do_rcn", + Description = "Dominican RCN number", + Example = "123-4567890-1", + ValidationExpression = new Regex("^[0-9]{3}-?[0-9]{7}-?[0-9]{1}$") + }, + new() + { + Country = "EC", + Code = "ec_ruc", + Description = "Ecuadorian RUC number", + Example = "1234567890001", + ValidationExpression = new Regex("^[0-9]{13}$") + }, + new() + { + Country = "EG", + Code = "eg_tin", + Description = "Egyptian Tax Identification Number", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + + new() + { + Country = "SV", + Code = "sv_nit", + Description = "El Salvadorian NIT number", + Example = "1234-567890-123-4", + ValidationExpression = new Regex("^[0-9]{4}-?[0-9]{6}-?[0-9]{3}-?[0-9]{1}$") + }, + + new() + { + Country = "EE", + Code = "eu_vat", + Description = "European VAT number (Estonia)", + Example = "EE123456789", + ValidationExpression = new Regex("^EE[0-9]{9}$") + }, + + new() + { + Country = "EU", + Code = "eu_oss_vat", + Description = "European One Stop Shop VAT number for non-Union scheme", + Example = "EU123456789", + ValidationExpression = new Regex("^EU[0-9]{9}$") + }, + new() + { + Country = "FI", + Code = "eu_vat", + Description = "European VAT number (Finland)", + Example = "FI12345678", + ValidationExpression = new Regex("^FI[0-9]{8}$") + }, + new() + { + Country = "FR", + Code = "eu_vat", + Description = "European VAT number (France)", + Example = "FR12345678901", + ValidationExpression = new Regex("^FR[0-9A-Z]{2}[0-9]{9}$") + }, + new() + { + Country = "GE", + Code = "ge_vat", + Description = "Georgian VAT", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "DE", + Code = "de_stn", + Description = "German Tax Number (Steuernummer)", + Example = "1234567890", + ValidationExpression = new Regex("^[0-9]{10}$") + }, + new() + { + Country = "DE", + Code = "eu_vat", + Description = "European VAT number (Germany)", + Example = "DE123456789", + ValidationExpression = new Regex("^DE[0-9]{9}$") + }, + new() + { + Country = "GR", + Code = "eu_vat", + Description = "European VAT number (Greece)", + Example = "EL123456789", + ValidationExpression = new Regex("^EL[0-9]{9}$") + }, + new() + { + Country = "HK", + Code = "hk_br", + Description = "Hong Kong BR number", + Example = "12345678", + ValidationExpression = new Regex("^[0-9]{8}$") + }, + new() + { + Country = "HU", + Code = "eu_vat", + Description = "European VAT number (Hungaria)", + Example = "HU12345678", + ValidationExpression = new Regex("^HU[0-9]{8}$") + }, + new() + { + Country = "HU", + Code = "hu_tin", + Description = "Hungary tax number (adószám)", + Example = "12345678-1-23", + ValidationExpression = new Regex("^[0-9]{8}-?[0-9]-?[0-9]{2}$") + }, + new() + { + Country = "IS", + Code = "is_vat", + Description = "Icelandic VAT", + Example = "123456", + ValidationExpression = new Regex("^[0-9]{6}$") + }, + new() + { + Country = "IN", + Code = "in_gst", + Description = "Indian GST number", + Example = "12ABCDE3456FGZH", + ValidationExpression = new Regex("^[0-9]{2}[A-Z]{5}[0-9]{4}[A-Z]{1}[1-9A-Z]{1}Z[0-9A-Z]{1}$") + }, + new() + { + Country = "ID", + Code = "id_npwp", + Description = "Indonesian NPWP number", + Example = "012.345.678.9-012.345", + ValidationExpression = new Regex("^[0-9]{3}.?[0-9]{3}.?[0-9]{3}.?[0-9]{1}-?[0-9]{3}.?[0-9]{3}$") + }, + new() + { + Country = "IE", + Code = "eu_vat", + Description = "European VAT number (Ireland)", + Example = "IE1234567AB", + ValidationExpression = new Regex("^IE[0-9]{7}[A-Z]{1,2}$") + }, + new() + { + Country = "IL", + Code = "il_vat", + Description = "Israel VAT", + Example = "000012345", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "IT", + Code = "eu_vat", + Description = "European VAT number (Italy)", + Example = "IT12345678912", + ValidationExpression = new Regex("^IT[0-9]{11}$") + }, + new() + { + Country = "JP", + Code = "jp_cn", + Description = "Japanese Corporate Number (*Hōjin Bangō*)", + Example = "1234567891234", + ValidationExpression = new Regex("^[0-9]{13}$") + }, + new() + { + Country = "JP", + Code = "jp_rn", + Description = + "Japanese Registered Foreign Businesses' Registration Number (*Tōroku Kokugai Jigyōsha no Tōroku Bangō*)", + Example = "12345", + ValidationExpression = new Regex("^[0-9]{5}$") + }, + new() + { + Country = "JP", + Code = "jp_trn", + Description = "Japanese Tax Registration Number (*Tōroku Bangō*)", + Example = "T1234567891234", + ValidationExpression = new Regex("^T[0-9]{13}$") + }, + new() + { + Country = "KZ", + Code = "kz_bin", + Description = "Kazakhstani Business Identification Number", + Example = "123456789012", + ValidationExpression = new Regex("^[0-9]{12}$") + }, + new() + { + Country = "KE", + Code = "ke_pin", + Description = "Kenya Revenue Authority Personal Identification Number", + Example = "P000111111A", + ValidationExpression = new Regex("^[A-Z]{1}[0-9]{9}[A-Z]{1}$") + }, + new() + { + Country = "LV", + Code = "eu_vat", + Description = "European VAT number", + Example = "LV12345678912", + ValidationExpression = new Regex("^LV[0-9]{11}$") + }, + new() + { + Country = "LI", + Code = "li_uid", + Description = "Liechtensteinian UID number", + Example = "CHE123456789", + ValidationExpression = new Regex("^CHE[0-9]{9}$") + }, + new() + { + Country = "LI", + Code = "li_vat", + Description = "Liechtensteinian VAT number", + Example = "12345", + ValidationExpression = new Regex("^[0-9]{5}$") + }, + new() + { + Country = "LT", + Code = "eu_vat", + Description = "European VAT number (Lithuania)", + Example = "LT123456789123", + ValidationExpression = new Regex("^LT[0-9]{9,12}$") + }, + new() + { + Country = "LU", + Code = "eu_vat", + Description = "European VAT number (Luxembourg)", + Example = "LU12345678", + ValidationExpression = new Regex("^LU[0-9]{8}$") + }, + new() + { + Country = "MY", + Code = "my_frp", + Description = "Malaysian FRP number", + Example = "12345678", + ValidationExpression = new Regex("^[0-9]{8}$") + }, + new() + { + Country = "MY", + Code = "my_itn", + Description = "Malaysian ITN", + Example = "C 1234567890", + ValidationExpression = new Regex("^[A-Z]{1} ?[0-9]{10}$") + }, + new() + { + Country = "MY", + Code = "my_sst", + Description = "Malaysian SST number", + Example = "A12-3456-78912345", + ValidationExpression = new Regex("^[A-Z]{1}[0-9]{2}-?[0-9]{4}-?[0-9]{8}$") + }, + new() + { + Country = "MT", + Code = "eu_vat", + Description = "European VAT number (Malta)", + Example = "MT12345678", + ValidationExpression = new Regex("^MT[0-9]{8}$") + }, + new() + { + Country = "MX", + Code = "mx_rfc", + Description = "Mexican RFC number", + Example = "ABC010203AB9", + ValidationExpression = new Regex("^[A-Z]{3}[0-9]{6}[A-Z0-9]{3}$") + }, + new() + { + Country = "MD", + Code = "md_vat", + Description = "Moldova VAT Number", + Example = "1234567", + ValidationExpression = new Regex("^[0-9]{7}$") + }, + new() + { + Country = "MA", + Code = "ma_vat", + Description = "Morocco VAT Number", + Example = "12345678", + ValidationExpression = new Regex("^[0-9]{8}$") + }, + new() + { + Country = "NL", + Code = "eu_vat", + Description = "European VAT number (Netherlands)", + Example = "NL123456789B12", + ValidationExpression = new Regex("^NL[0-9]{9}B[0-9]{2}$") + }, + new() + { + Country = "NZ", + Code = "nz_gst", + Description = "New Zealand GST number", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "NG", + Code = "ng_tin", + Description = "Nigerian TIN Number", + Example = "12345678-0001", + ValidationExpression = new Regex("^[0-9]{8}-[0-9]{4}$") + }, + new() + { + Country = "NO", + Code = "no_vat", + Description = "Norwegian VAT number", + Example = "123456789MVA", + ValidationExpression = new Regex("^[0-9]{9}MVA$") + }, + new() + { + Country = "NO", + Code = "no_voec", + Description = "Norwegian VAT on e-commerce number", + Example = "1234567", + ValidationExpression = new Regex("^[0-9]{7}$") + }, + new() + { + Country = "OM", + Code = "om_vat", + Description = "Omani VAT Number", + Example = "OM1234567890", + ValidationExpression = new Regex("^OM[0-9]{10}$") + }, + new() + { + Country = "PE", + Code = "pe_ruc", + Description = "Peruvian RUC number", + Example = "12345678901", + ValidationExpression = new Regex("^[0-9]{11}$") + }, + new() + { + Country = "PH", + Code = "ph_tin", + Description = "Philippines Tax Identification Number", + Example = "123456789012", + ValidationExpression = new Regex("^[0-9]{12}$") + }, + new() + { + Country = "PL", + Code = "eu_vat", + Description = "European VAT number (Poland)", + Example = "PL1234567890", + ValidationExpression = new Regex("^PL[0-9]{10}$") + }, + new() + { + Country = "PT", + Code = "eu_vat", + Description = "European VAT number (Portugal)", + Example = "PT123456789", + ValidationExpression = new Regex("^PT[0-9]{9}$") + }, + new() + { + Country = "RO", + Code = "eu_vat", + Description = "European VAT number (Romania)", + Example = "RO1234567891", + ValidationExpression = new Regex("^RO[0-9]{2,10}$") + }, + new() + { + Country = "RO", + Code = "ro_tin", + Description = "Romanian tax ID number", + Example = "1234567890123", + ValidationExpression = new Regex("^[0-9]{13}$") + }, + new() + { + Country = "RU", + Code = "ru_inn", + Description = "Russian INN", + Example = "1234567891", + ValidationExpression = new Regex("^[0-9]{10,12}$") + }, + new() + { + Country = "RU", + Code = "ru_kpp", + Description = "Russian KPP", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "SA", + Code = "sa_vat", + Description = "Saudi Arabia VAT", + Example = "123456789012345", + ValidationExpression = new Regex("^[0-9]{15}$") + }, + new() + { + Country = "RS", + Code = "rs_pib", + Description = "Serbian PIB number", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "SG", + Code = "sg_gst", + Description = "Singaporean GST", + Example = "M12345678X", + ValidationExpression = new Regex("^[A-Z]{1}[0-9]{8}[A-Z]{1}$") + }, + new() + { + Country = "SG", + Code = "sg_uen", + Description = "Singaporean UEN", + Example = "123456789F", + ValidationExpression = new Regex("^[0-9]{9}[A-Z]{1}$") + }, + new() + { + Country = "SK", + Code = "eu_vat", + Description = "European VAT number (Slovakia)", + Example = "SK1234567891", + ValidationExpression = new Regex("^SK[0-9]{10}$") + }, + new() + { + Country = "SI", + Code = "eu_vat", + Description = "European VAT number (Slovenia)", + Example = "SI12345678", + ValidationExpression = new Regex("^SI[0-9]{8}$") + }, + new() + { + Country = "SI", + Code = "si_tin", + Description = "Slovenia tax number (davčna številka)", + Example = "12345678", + ValidationExpression = new Regex("^[0-9]{8}$") + }, + new() + { + Country = "ZA", + Code = "za_vat", + Description = "South African VAT number", + Example = "4123456789", + ValidationExpression = new Regex("^[0-9]{10}$") + }, + new() + { + Country = "KR", + Code = "kr_brn", + Description = "Korean BRN", + Example = "123-45-67890", + ValidationExpression = new Regex("^[0-9]{3}-?[0-9]{2}-?[0-9]{5}$") + }, + new() + { + Country = "ES", + Code = "es_cif", + Description = "Spanish NIF/CIF number", + Example = "A12345678", + ValidationExpression = new Regex("^[A-Z]{1}[0-9]{8}$") + }, + new() + { + Country = "ES", + Code = "eu_vat", + Description = "European VAT number (Spain)", + Example = "ESA1234567Z", + ValidationExpression = new Regex("^ES[A-Z]{1}[0-9]{7}[A-Z]{1}$") + }, + new() + { + Country = "SE", + Code = "eu_vat", + Description = "European VAT number (Sweden)", + Example = "SE123456789123", + ValidationExpression = new Regex("^SE[0-9]{12}$") + }, + new() + { + Country = "CH", + Code = "ch_uid", + Description = "Switzerland UID number", + Example = "CHE-123.456.789 HR", + ValidationExpression = new Regex("^CHE-?[0-9]{3}.?[0-9]{3}.?[0-9]{3} ?HR$") + }, + new() + { + Country = "CH", + Code = "ch_vat", + Description = "Switzerland VAT number", + Example = "CHE-123.456.789 MWST", + ValidationExpression = new Regex("^CHE-?[0-9]{3}.?[0-9]{3}.?[0-9]{3} ?MWST$") + }, + new() + { + Country = "TW", + Code = "tw_vat", + Description = "Taiwanese VAT", + Example = "12345678", + ValidationExpression = new Regex("^[0-9]{8}$") + }, + new() + { + Country = "TZ", + Code = "tz_vat", + Description = "Tanzania VAT Number", + Example = "12345678A", + ValidationExpression = new Regex("^[0-9]{8}[A-Z]{1}$") + }, + new() + { + Country = "TH", + Code = "th_vat", + Description = "Thai VAT", + Example = "1234567891234", + ValidationExpression = new Regex("^[0-9]{13}$") + }, + new() + { + Country = "TR", + Code = "tr_tin", + Description = "Turkish TIN Number", + Example = "0123456789", + ValidationExpression = new Regex("^[0-9]{10}$") + }, + new() + { + Country = "UA", + Code = "ua_vat", + Description = "Ukrainian VAT", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "AE", + Code = "ae_trn", + Description = "United Arab Emirates TRN", + Example = "123456789012345", + ValidationExpression = new Regex("^[0-9]{15}$") + }, + new() + { + Country = "GB", + Code = "eu_vat", + Description = "Northern Ireland VAT number", + Example = "XI123456789", + ValidationExpression = new Regex("^XI[0-9]{9}$") + }, + new() + { + Country = "GB", + Code = "gb_vat", + Description = "United Kingdom VAT number", + Example = "GB123456789", + ValidationExpression = new Regex("^GB[0-9]{9}$") + }, + new() + { + Country = "US", + Code = "us_ein", + Description = "United States EIN", + Example = "12-3456789", + ValidationExpression = new Regex("^[0-9]{2}-?[0-9]{7}$") + }, + new() + { + Country = "UY", + Code = "uy_ruc", + Description = "Uruguayan RUC number", + Example = "123456789012", + ValidationExpression = new Regex("^[0-9]{12}$") + }, + new() + { + Country = "UZ", + Code = "uz_tin", + Description = "Uzbekistan TIN Number", + Example = "123456789", + ValidationExpression = new Regex("^[0-9]{9}$") + }, + new() + { + Country = "UZ", + Code = "uz_vat", + Description = "Uzbekistan VAT Number", + Example = "123456789012", + ValidationExpression = new Regex("^[0-9]{12}$") + }, + new() + { + Country = "VE", + Code = "ve_rif", + Description = "Venezuelan RIF number", + Example = "A-12345678-9", + ValidationExpression = new Regex("^[A-Z]{1}-?[0-9]{8}-?[0-9]{1}$") + }, + new() + { + Country = "VN", + Code = "vn_tin", + Description = "Vietnamese tax ID number", + Example = "1234567890", + ValidationExpression = new Regex("^[0-9]{10}$") + } + ]; + + public string GetStripeTaxCode(string country, string taxId) + { + foreach (var taxIdType in _taxIdTypes.Where(x => x.Country == country)) + { + if (taxIdType.ValidationExpression.IsMatch(taxId)) + { + return taxIdType.Code; + } + } + + return null; + } + + public bool IsSupported(string country) + { + return _taxIdTypes.Any(x => x.Country == country); + } +} diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 28527af0c038..695a3b1bb485 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -83,6 +83,7 @@ public static TaxInformation GetTaxInformation(Customer customer) customer.Address.Country, customer.Address.PostalCode, customer.TaxIds?.FirstOrDefault()?.Value, + customer.TaxIds?.FirstOrDefault()?.Type, customer.Address.Line1, customer.Address.Line2, customer.Address.City, diff --git a/src/Core/Models/Business/TaxInfo.cs b/src/Core/Models/Business/TaxInfo.cs index 4424576ec91f..b12c5229b3e6 100644 --- a/src/Core/Models/Business/TaxInfo.cs +++ b/src/Core/Models/Business/TaxInfo.cs @@ -2,18 +2,9 @@ public class TaxInfo { - private string _taxIdNumber = null; - private string _taxIdType = null; + public string TaxIdNumber { get; set; } + public string TaxIdType { get; set; } - public string TaxIdNumber - { - get => _taxIdNumber; - set - { - _taxIdNumber = value; - _taxIdType = null; - } - } public string StripeTaxRateId { get; set; } public string BillingAddressLine1 { get; set; } public string BillingAddressLine2 { get; set; } @@ -21,201 +12,6 @@ public string TaxIdNumber public string BillingAddressState { get; set; } public string BillingAddressPostalCode { get; set; } public string BillingAddressCountry { get; set; } = "US"; - public string TaxIdType - { - get - { - if (string.IsNullOrWhiteSpace(BillingAddressCountry) || - string.IsNullOrWhiteSpace(TaxIdNumber)) - { - return null; - } - if (!string.IsNullOrWhiteSpace(_taxIdType)) - { - return _taxIdType; - } - - switch (BillingAddressCountry.ToUpper()) - { - case "AD": - _taxIdType = "ad_nrt"; - break; - case "AE": - _taxIdType = "ae_trn"; - break; - case "AR": - _taxIdType = "ar_cuit"; - break; - case "AU": - _taxIdType = "au_abn"; - break; - case "BO": - _taxIdType = "bo_tin"; - break; - case "BR": - _taxIdType = "br_cnpj"; - break; - case "CA": - // May break for those in Québec given the assumption of QST - if (BillingAddressState?.Contains("bec") ?? false) - { - _taxIdType = "ca_qst"; - break; - } - _taxIdType = "ca_bn"; - break; - case "CH": - _taxIdType = "ch_vat"; - break; - case "CL": - _taxIdType = "cl_tin"; - break; - case "CN": - _taxIdType = "cn_tin"; - break; - case "CO": - _taxIdType = "co_nit"; - break; - case "CR": - _taxIdType = "cr_tin"; - break; - case "DO": - _taxIdType = "do_rcn"; - break; - case "EC": - _taxIdType = "ec_ruc"; - break; - case "EG": - _taxIdType = "eg_tin"; - break; - case "GE": - _taxIdType = "ge_vat"; - break; - case "ID": - _taxIdType = "id_npwp"; - break; - case "IL": - _taxIdType = "il_vat"; - break; - case "IS": - _taxIdType = "is_vat"; - break; - case "KE": - _taxIdType = "ke_pin"; - break; - case "AT": - case "BE": - case "BG": - case "CY": - case "CZ": - case "DE": - case "DK": - case "EE": - case "ES": - case "FI": - case "FR": - case "GB": - case "GR": - case "HR": - case "HU": - case "IE": - case "IT": - case "LT": - case "LU": - case "LV": - case "MT": - case "NL": - case "PL": - case "PT": - case "RO": - case "SE": - case "SI": - case "SK": - _taxIdType = "eu_vat"; - break; - case "HK": - _taxIdType = "hk_br"; - break; - case "IN": - _taxIdType = "in_gst"; - break; - case "JP": - _taxIdType = "jp_cn"; - break; - case "KR": - _taxIdType = "kr_brn"; - break; - case "LI": - _taxIdType = "li_uid"; - break; - case "MX": - _taxIdType = "mx_rfc"; - break; - case "MY": - _taxIdType = "my_sst"; - break; - case "NO": - _taxIdType = "no_vat"; - break; - case "NZ": - _taxIdType = "nz_gst"; - break; - case "PE": - _taxIdType = "pe_ruc"; - break; - case "PH": - _taxIdType = "ph_tin"; - break; - case "RS": - _taxIdType = "rs_pib"; - break; - case "RU": - _taxIdType = "ru_inn"; - break; - case "SA": - _taxIdType = "sa_vat"; - break; - case "SG": - _taxIdType = "sg_gst"; - break; - case "SV": - _taxIdType = "sv_nit"; - break; - case "TH": - _taxIdType = "th_vat"; - break; - case "TR": - _taxIdType = "tr_tin"; - break; - case "TW": - _taxIdType = "tw_vat"; - break; - case "UA": - _taxIdType = "ua_vat"; - break; - case "US": - _taxIdType = "us_ein"; - break; - case "UY": - _taxIdType = "uy_ruc"; - break; - case "VE": - _taxIdType = "ve_rif"; - break; - case "VN": - _taxIdType = "vn_tin"; - break; - case "ZA": - _taxIdType = "za_vat"; - break; - default: - _taxIdType = null; - break; - } - - return _taxIdType; - } - } public bool HasTaxId { diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index bf9d047029b5..7d0f9d3c630b 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -1,6 +1,9 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Models.Api.Requests.Accounts; +using Bit.Core.Billing.Models.Api.Requests.Organizations; +using Bit.Core.Billing.Models.Api.Responses; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -59,4 +62,7 @@ Task AddSecretsManagerToSubscription(Organization org, Plan plan, int ad Task RisksSubscriptionFailure(Organization organization); Task HasSecretsManagerStandalone(Organization organization); Task<(DateTime?, DateTime?)> GetSuspensionDateAsync(Stripe.Subscription subscription); + Task PreviewInvoiceAsync(PreviewIndividualInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); + Task PreviewInvoiceAsync(PreviewOrganizationInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); + } diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs index 30583ef0b3e2..ef2e3ab76655 100644 --- a/src/Core/Services/IStripeAdapter.cs +++ b/src/Core/Services/IStripeAdapter.cs @@ -31,6 +31,7 @@ Task CustomerBalanceTransactionCreate(string custome Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); Task> InvoiceListAsync(StripeInvoiceListOptions options); + Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options); Task> InvoiceSearchAsync(InvoiceSearchOptions options); Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); @@ -42,6 +43,7 @@ Task CustomerBalanceTransactionCreate(string custome IAsyncEnumerable PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options); Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null); Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null); + Task PlanGetAsync(string id, Stripe.PlanGetOptions options = null); Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options); Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options); Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options); diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index 8d183314567f..f4f8efe75fb1 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -15,6 +15,7 @@ public class StripeAdapter : IStripeAdapter private readonly Stripe.RefundService _refundService; private readonly Stripe.CardService _cardService; private readonly Stripe.BankAccountService _bankAccountService; + private readonly Stripe.PlanService _planService; private readonly Stripe.PriceService _priceService; private readonly Stripe.SetupIntentService _setupIntentService; private readonly Stripe.TestHelpers.TestClockService _testClockService; @@ -33,6 +34,7 @@ public StripeAdapter() _cardService = new Stripe.CardService(); _bankAccountService = new Stripe.BankAccountService(); _priceService = new Stripe.PriceService(); + _planService = new Stripe.PlanService(); _setupIntentService = new SetupIntentService(); _testClockService = new Stripe.TestHelpers.TestClockService(); _customerBalanceTransactionService = new CustomerBalanceTransactionService(); @@ -133,6 +135,11 @@ public async Task ProviderSubscriptionGetAsync( return invoices; } + public Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options) + { + return _invoiceService.CreatePreviewAsync(options); + } + public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) => (await _invoiceService.SearchAsync(options)).Data; @@ -184,6 +191,11 @@ public async Task ProviderSubscriptionGetAsync( return _paymentMethodService.DetachAsync(id, options); } + public Task PlanGetAsync(string id, Stripe.PlanGetOptions options = null) + { + return _planService.GetAsync(id, options); + } + public Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options) { return _taxRateService.CreateAsync(options); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index 259a4eb7572a..c32dcd43ade8 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1,8 +1,13 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; +using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; +using Bit.Core.Billing.Models.Api.Requests.Accounts; +using Bit.Core.Billing.Models.Api.Requests.Organizations; +using Bit.Core.Billing.Models.Api.Responses; using Bit.Core.Billing.Models.Business; +using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -32,6 +37,7 @@ public class StripePaymentService : IPaymentService private readonly IStripeAdapter _stripeAdapter; private readonly IGlobalSettings _globalSettings; private readonly IFeatureService _featureService; + private readonly ITaxService _taxService; public StripePaymentService( ITransactionRepository transactionRepository, @@ -40,7 +46,8 @@ public StripePaymentService( IStripeAdapter stripeAdapter, Braintree.IBraintreeGateway braintreeGateway, IGlobalSettings globalSettings, - IFeatureService featureService) + IFeatureService featureService, + ITaxService taxService) { _transactionRepository = transactionRepository; _logger = logger; @@ -49,6 +56,7 @@ public StripePaymentService( _btGateway = braintreeGateway; _globalSettings = globalSettings; _featureService = featureService; + _taxService = taxService; } public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, @@ -112,6 +120,20 @@ public async Task PurchaseOrganizationAsync(Organization org, PaymentMet Subscription subscription; try { + if (taxInfo.TaxIdNumber != null && taxInfo.TaxIdType == null) + { + taxInfo.TaxIdType = _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, + taxInfo.TaxIdNumber); + + if (taxInfo.TaxIdType == null) + { + _logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", + taxInfo.BillingAddressCountry, + taxInfo.TaxIdNumber); + throw new BadRequestException("billingTaxIdTypeInferenceError"); + } + } + var customerCreateOptions = new CustomerCreateOptions { Description = org.DisplayBusinessName(), @@ -146,12 +168,9 @@ public async Task PurchaseOrganizationAsync(Organization org, PaymentMet City = taxInfo?.BillingAddressCity, State = taxInfo?.BillingAddressState, }, - TaxIdData = taxInfo?.HasTaxId != true - ? null - : - [ - new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber, } - ], + TaxIdData = taxInfo.HasTaxId + ? [new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }] + : null }; customerCreateOptions.AddExpand("tax"); @@ -1659,6 +1678,7 @@ public async Task GetTaxInfoAsync(ISubscriber subscriber) return new TaxInfo { TaxIdNumber = taxId?.Value, + TaxIdType = taxId?.Type, BillingAddressLine1 = address?.Line1, BillingAddressLine2 = address?.Line2, BillingAddressCity = address?.City, @@ -1670,9 +1690,13 @@ public async Task GetTaxInfoAsync(ISubscriber subscriber) public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) { - if (subscriber != null && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) + if (string.IsNullOrWhiteSpace(subscriber?.GatewayCustomerId) || subscriber.IsUser()) { - var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions + return; + } + + var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, + new CustomerUpdateOptions { Address = new AddressOptions { @@ -1686,23 +1710,59 @@ public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) Expand = ["tax_ids"] }); - if (!subscriber.IsUser() && customer != null) + if (customer == null) + { + return; + } + + var taxId = customer.TaxIds?.FirstOrDefault(); + + if (taxId != null) + { + await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); + } + + if (string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)) + { + return; + } + + var taxIdType = taxInfo.TaxIdType; + + if (string.IsNullOrWhiteSpace(taxIdType)) + { + taxIdType = _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber); + + if (taxIdType == null) { - var taxId = customer.TaxIds?.FirstOrDefault(); + _logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", + taxInfo.BillingAddressCountry, + taxInfo.TaxIdNumber); + throw new BadRequestException("billingTaxIdTypeInferenceError"); + } + } - if (taxId != null) - { - await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); - } - if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber) && - !string.IsNullOrWhiteSpace(taxInfo.TaxIdType)) - { - await _stripeAdapter.TaxIdCreateAsync(customer.Id, new TaxIdCreateOptions - { - Type = taxInfo.TaxIdType, - Value = taxInfo.TaxIdNumber, - }); - } + try + { + await _stripeAdapter.TaxIdCreateAsync(customer.Id, + new TaxIdCreateOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber, }); + } + catch (StripeException e) + { + switch (e.StripeError.Code) + { + case StripeConstants.ErrorCodes.TaxIdInvalid: + _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", + taxInfo.TaxIdNumber, + taxInfo.BillingAddressCountry); + throw new BadRequestException("billingInvalidTaxIdError"); + default: + _logger.LogError(e, + "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", + taxInfo.TaxIdNumber, + taxInfo.BillingAddressCountry, + customer.Id); + throw new BadRequestException("billingTaxIdCreationError"); } } } @@ -1835,6 +1895,285 @@ public async Task HasSecretsManagerStandalone(Organization organization) } } + public async Task PreviewInvoiceAsync( + PreviewIndividualInvoiceRequestBody parameters, + string gatewayCustomerId, + string gatewaySubscriptionId) + { + var options = new InvoiceCreatePreviewOptions + { + AutomaticTax = new InvoiceAutomaticTaxOptions + { + Enabled = true, + }, + Currency = "usd", + Discounts = new List(), + SubscriptionDetails = new InvoiceSubscriptionDetailsOptions + { + Items = + [ + new() + { + Quantity = 1, + Plan = "premium-annually" + }, + + new() + { + Quantity = parameters.PasswordManager.AdditionalStorage, + Plan = "storage-gb-annually" + } + ] + }, + CustomerDetails = new InvoiceCustomerDetailsOptions + { + Address = new AddressOptions + { + PostalCode = parameters.TaxInformation.PostalCode, + Country = parameters.TaxInformation.Country, + } + }, + }; + + if (!string.IsNullOrEmpty(parameters.TaxInformation.TaxId)) + { + var taxIdType = _taxService.GetStripeTaxCode( + options.CustomerDetails.Address.Country, + parameters.TaxInformation.TaxId); + + if (taxIdType == null) + { + _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", + parameters.TaxInformation.TaxId, + parameters.TaxInformation.Country); + throw new BadRequestException("billingPreviewInvalidTaxIdError"); + } + + options.CustomerDetails.TaxIds = [ + new InvoiceCustomerDetailsTaxIdOptions + { + Type = taxIdType, + Value = parameters.TaxInformation.TaxId + } + ]; + } + + if (gatewayCustomerId != null) + { + var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + + if (gatewayCustomer.Discount != null) + { + options.Discounts.Add(new InvoiceDiscountOptions + { + Discount = gatewayCustomer.Discount.Id + }); + } + + if (gatewaySubscriptionId != null) + { + var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); + + if (gatewaySubscription?.Discount != null) + { + options.Discounts.Add(new InvoiceDiscountOptions + { + Discount = gatewaySubscription.Discount.Id + }); + } + } + } + + try + { + var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); + + var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null + ? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() + : 0M; + + var result = new PreviewInvoiceResponseModel( + effectiveTaxRate, + invoice.TotalExcludingTax.ToMajor() ?? 0, + invoice.Tax.ToMajor() ?? 0, + invoice.Total.ToMajor()); + return result; + } + catch (StripeException e) + { + switch (e.StripeError.Code) + { + case StripeConstants.ErrorCodes.TaxIdInvalid: + _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", + parameters.TaxInformation.TaxId, + parameters.TaxInformation.Country); + throw new BadRequestException("billingPreviewInvalidTaxIdError"); + default: + _logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", + parameters.TaxInformation.TaxId, + parameters.TaxInformation.Country); + throw new BadRequestException("billingPreviewInvoiceError"); + } + } + } + + public async Task PreviewInvoiceAsync( + PreviewOrganizationInvoiceRequestBody parameters, + string gatewayCustomerId, + string gatewaySubscriptionId) + { + var plan = Utilities.StaticStore.GetPlan(parameters.PasswordManager.Plan); + + var options = new InvoiceCreatePreviewOptions + { + AutomaticTax = new InvoiceAutomaticTaxOptions + { + Enabled = true, + }, + Currency = "usd", + Discounts = new List(), + SubscriptionDetails = new InvoiceSubscriptionDetailsOptions + { + Items = + [ + new() + { + Quantity = parameters.PasswordManager.AdditionalStorage, + Plan = plan.PasswordManager.StripeStoragePlanId + } + ] + }, + CustomerDetails = new InvoiceCustomerDetailsOptions + { + Address = new AddressOptions + { + PostalCode = parameters.TaxInformation.PostalCode, + Country = parameters.TaxInformation.Country, + } + }, + }; + + if (plan.PasswordManager.HasAdditionalSeatsOption) + { + options.SubscriptionDetails.Items.Add( + new() + { + Quantity = parameters.PasswordManager.Seats, + Plan = plan.PasswordManager.StripeSeatPlanId + } + ); + } + else + { + options.SubscriptionDetails.Items.Add( + new() + { + Quantity = 1, + Plan = plan.PasswordManager.StripePlanId + } + ); + } + + if (plan.SupportsSecretsManager) + { + if (plan.SecretsManager.HasAdditionalSeatsOption) + { + options.SubscriptionDetails.Items.Add(new() + { + Quantity = parameters.SecretsManager?.Seats ?? 0, + Plan = plan.SecretsManager.StripeSeatPlanId + }); + } + + if (plan.SecretsManager.HasAdditionalServiceAccountOption) + { + options.SubscriptionDetails.Items.Add(new() + { + Quantity = parameters.SecretsManager?.AdditionalMachineAccounts ?? 0, + Plan = plan.SecretsManager.StripeServiceAccountPlanId + }); + } + } + + if (!string.IsNullOrEmpty(parameters.TaxInformation.TaxId)) + { + var taxIdType = _taxService.GetStripeTaxCode( + options.CustomerDetails.Address.Country, + parameters.TaxInformation.TaxId); + + if (taxIdType == null) + { + _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", + parameters.TaxInformation.TaxId, + parameters.TaxInformation.Country); + throw new BadRequestException("billingTaxIdTypeInferenceError"); + } + + options.CustomerDetails.TaxIds = [ + new InvoiceCustomerDetailsTaxIdOptions + { + Type = taxIdType, + Value = parameters.TaxInformation.TaxId + } + ]; + } + + if (gatewayCustomerId != null) + { + var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); + + if (gatewayCustomer.Discount != null) + { + options.Discounts.Add(new InvoiceDiscountOptions + { + Discount = gatewayCustomer.Discount.Id + }); + } + + var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); + + if (gatewaySubscription?.Discount != null) + { + options.Discounts.Add(new InvoiceDiscountOptions + { + Discount = gatewaySubscription.Discount.Id + }); + } + } + + try + { + var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); + + var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null + ? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() + : 0M; + + var result = new PreviewInvoiceResponseModel( + effectiveTaxRate, + invoice.TotalExcludingTax.ToMajor() ?? 0, + invoice.Tax.ToMajor() ?? 0, + invoice.Total.ToMajor()); + return result; + } + catch (StripeException e) + { + switch (e.StripeError.Code) + { + case StripeConstants.ErrorCodes.TaxIdInvalid: + _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", + parameters.TaxInformation.TaxId, + parameters.TaxInformation.Country); + throw new BadRequestException("billingPreviewInvalidTaxIdError"); + default: + _logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", + parameters.TaxInformation.TaxId, + parameters.TaxInformation.Country); + throw new BadRequestException("billingPreviewInvoiceError"); + } + } + } + private PaymentMethod GetLatestCardPaymentMethod(string customerId) { var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 385b185ffe1b..9c25ffdc5574 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -1545,7 +1545,7 @@ public async Task UpdateTaxInformation_NonUser_MakesCorrectInvocations( { var stripeAdapter = sutProvider.GetDependency(); - var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1" }] } }; + var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); @@ -1554,6 +1554,7 @@ public async Task UpdateTaxInformation_NonUser_MakesCorrectInvocations( "US", "12345", "123456789", + "us_ein", "123 Example St.", null, "Example Town", diff --git a/test/Core.Test/Models/Business/TaxInfoTests.cs b/test/Core.Test/Models/Business/TaxInfoTests.cs deleted file mode 100644 index 197948006e97..000000000000 --- a/test/Core.Test/Models/Business/TaxInfoTests.cs +++ /dev/null @@ -1,114 +0,0 @@ -using Bit.Core.Models.Business; -using Xunit; - -namespace Bit.Core.Test.Models.Business; - -public class TaxInfoTests -{ - // PH = Placeholder - [Theory] - [InlineData(null, null, null, null)] - [InlineData("", "", null, null)] - [InlineData("PH", "", null, null)] - [InlineData("", "PH", null, null)] - [InlineData("AE", "PH", null, "ae_trn")] - [InlineData("AU", "PH", null, "au_abn")] - [InlineData("BR", "PH", null, "br_cnpj")] - [InlineData("CA", "PH", "bec", "ca_qst")] - [InlineData("CA", "PH", null, "ca_bn")] - [InlineData("CL", "PH", null, "cl_tin")] - [InlineData("AT", "PH", null, "eu_vat")] - [InlineData("BE", "PH", null, "eu_vat")] - [InlineData("BG", "PH", null, "eu_vat")] - [InlineData("CY", "PH", null, "eu_vat")] - [InlineData("CZ", "PH", null, "eu_vat")] - [InlineData("DE", "PH", null, "eu_vat")] - [InlineData("DK", "PH", null, "eu_vat")] - [InlineData("EE", "PH", null, "eu_vat")] - [InlineData("ES", "PH", null, "eu_vat")] - [InlineData("FI", "PH", null, "eu_vat")] - [InlineData("FR", "PH", null, "eu_vat")] - [InlineData("GB", "PH", null, "eu_vat")] - [InlineData("GR", "PH", null, "eu_vat")] - [InlineData("HR", "PH", null, "eu_vat")] - [InlineData("HU", "PH", null, "eu_vat")] - [InlineData("IE", "PH", null, "eu_vat")] - [InlineData("IT", "PH", null, "eu_vat")] - [InlineData("LT", "PH", null, "eu_vat")] - [InlineData("LU", "PH", null, "eu_vat")] - [InlineData("LV", "PH", null, "eu_vat")] - [InlineData("MT", "PH", null, "eu_vat")] - [InlineData("NL", "PH", null, "eu_vat")] - [InlineData("PL", "PH", null, "eu_vat")] - [InlineData("PT", "PH", null, "eu_vat")] - [InlineData("RO", "PH", null, "eu_vat")] - [InlineData("SE", "PH", null, "eu_vat")] - [InlineData("SI", "PH", null, "eu_vat")] - [InlineData("SK", "PH", null, "eu_vat")] - [InlineData("HK", "PH", null, "hk_br")] - [InlineData("IN", "PH", null, "in_gst")] - [InlineData("JP", "PH", null, "jp_cn")] - [InlineData("KR", "PH", null, "kr_brn")] - [InlineData("LI", "PH", null, "li_uid")] - [InlineData("MX", "PH", null, "mx_rfc")] - [InlineData("MY", "PH", null, "my_sst")] - [InlineData("NO", "PH", null, "no_vat")] - [InlineData("NZ", "PH", null, "nz_gst")] - [InlineData("RU", "PH", null, "ru_inn")] - [InlineData("SA", "PH", null, "sa_vat")] - [InlineData("SG", "PH", null, "sg_gst")] - [InlineData("TH", "PH", null, "th_vat")] - [InlineData("TW", "PH", null, "tw_vat")] - [InlineData("US", "PH", null, "us_ein")] - [InlineData("ZA", "PH", null, "za_vat")] - [InlineData("ABCDEF", "PH", null, null)] - public void GetTaxIdType_Success(string billingAddressCountry, - string taxIdNumber, - string billingAddressState, - string expectedTaxIdType) - { - var taxInfo = new TaxInfo - { - BillingAddressCountry = billingAddressCountry, - TaxIdNumber = taxIdNumber, - BillingAddressState = billingAddressState, - }; - - Assert.Equal(expectedTaxIdType, taxInfo.TaxIdType); - } - - [Fact] - public void GetTaxIdType_CreateOnce_ReturnCacheSecondTime() - { - var taxInfo = new TaxInfo - { - BillingAddressCountry = "US", - TaxIdNumber = "PH", - BillingAddressState = null, - }; - - Assert.Equal("us_ein", taxInfo.TaxIdType); - - // Per the current spec even if the values change to something other than null it - // will return the cached version of TaxIdType. - taxInfo.BillingAddressCountry = "ZA"; - - Assert.Equal("us_ein", taxInfo.TaxIdType); - } - - [Theory] - [InlineData(null, null, false)] - [InlineData("123", "US", true)] - [InlineData("123", "ZQ12", false)] - [InlineData(" ", "US", false)] - public void HasTaxId_ReturnsExpected(string taxIdNumber, string billingAddressCountry, bool expected) - { - var taxInfo = new TaxInfo - { - TaxIdNumber = taxIdNumber, - BillingAddressCountry = billingAddressCountry, - }; - - Assert.Equal(expected, taxInfo.HasTaxId); - } -} diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs index e15f07b11384..35e1901a2fca 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Services/StripePaymentServiceTests.cs @@ -77,7 +77,8 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -134,7 +135,8 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -190,7 +192,8 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -247,7 +250,8 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -441,7 +445,8 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -510,7 +515,8 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => From 470a12640ee9f42b725cdf4d2587a59ec4dc121b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Ch=C4=99ci=C5=84ski?= Date: Wed, 4 Dec 2024 14:18:58 +0100 Subject: [PATCH 24/28] Trigger unified build on rc and hotfix-rc branches (#5108) --- .github/workflows/build.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1d092d8b4dcb..0fa03312b879 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -530,7 +530,9 @@ jobs: self-host-build: name: Trigger self-host build - if: github.event_name != 'pull_request_target' && github.ref == 'refs/heads/main' + if: | + github.event_name != 'pull_request_target' + && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/rc' || github.ref == 'refs/heads/hotfix-rc') runs-on: ubuntu-22.04 needs: - build-docker From 90a9473a5ecf4e0d2e912bfb7dd7209b310a9673 Mon Sep 17 00:00:00 2001 From: Jonas Hendrickx Date: Wed, 4 Dec 2024 15:36:11 +0100 Subject: [PATCH 25/28] Revert "[PM-13999] Show estimated tax for taxable countries (#5077)" (#5109) This reverts commit 94fdfa40e8af9c9b788aafe2cf89eacc2913eeea. Co-authored-by: Conner Turnbull <133619638+cturnbull-bitwarden@users.noreply.github.com> --- .../Billing/TaxServiceTests.cs | 151 --- .../Controllers/AccountsBillingController.cs | 15 - .../Billing/Controllers/InvoicesController.cs | 42 - .../Controllers/ProviderBillingController.cs | 1 - .../Billing/Controllers/StripeController.cs | 14 +- .../Requests/TaxInformationRequestBody.cs | 2 - .../Billing/Extensions/CurrencyExtensions.cs | 33 - .../Extensions/ServiceCollectionExtensions.cs | 1 - .../PreviewIndividualInvoiceRequestModel.cs | 18 - .../PreviewOrganizationInvoiceRequestModel.cs | 37 - .../Requests/TaxInformationRequestModel.cs | 14 - .../Responses/PreviewInvoiceResponseModel.cs | 7 - src/Core/Billing/Models/PreviewInvoiceInfo.cs | 7 - .../Billing/Models/Sales/OrganizationSale.cs | 1 - src/Core/Billing/Models/TaxIdType.cs | 22 - src/Core/Billing/Models/TaxInformation.cs | 160 +++- src/Core/Billing/Services/ITaxService.cs | 22 - .../OrganizationBillingService.cs | 35 +- .../PremiumUserBillingService.cs | 37 +- .../Implementations/SubscriberService.cs | 49 +- src/Core/Billing/Services/TaxService.cs | 901 ------------------ src/Core/Billing/Utilities.cs | 1 - src/Core/Models/Business/TaxInfo.cs | 208 +++- src/Core/Services/IPaymentService.cs | 6 - src/Core/Services/IStripeAdapter.cs | 2 - .../Services/Implementations/StripeAdapter.cs | 12 - .../Implementations/StripePaymentService.cs | 387 +------- .../Services/SubscriberServiceTests.cs | 3 +- .../Core.Test/Models/Business/TaxInfoTests.cs | 114 +++ .../Services/StripePaymentServiceTests.cs | 18 +- 30 files changed, 529 insertions(+), 1791 deletions(-) delete mode 100644 bitwarden_license/test/Commercial.Core.Test/Billing/TaxServiceTests.cs delete mode 100644 src/Api/Billing/Controllers/InvoicesController.cs delete mode 100644 src/Core/Billing/Extensions/CurrencyExtensions.cs delete mode 100644 src/Core/Billing/Models/Api/Requests/Accounts/PreviewIndividualInvoiceRequestModel.cs delete mode 100644 src/Core/Billing/Models/Api/Requests/Organizations/PreviewOrganizationInvoiceRequestModel.cs delete mode 100644 src/Core/Billing/Models/Api/Requests/TaxInformationRequestModel.cs delete mode 100644 src/Core/Billing/Models/Api/Responses/PreviewInvoiceResponseModel.cs delete mode 100644 src/Core/Billing/Models/PreviewInvoiceInfo.cs delete mode 100644 src/Core/Billing/Models/TaxIdType.cs delete mode 100644 src/Core/Billing/Services/ITaxService.cs delete mode 100644 src/Core/Billing/Services/TaxService.cs create mode 100644 test/Core.Test/Models/Business/TaxInfoTests.cs diff --git a/bitwarden_license/test/Commercial.Core.Test/Billing/TaxServiceTests.cs b/bitwarden_license/test/Commercial.Core.Test/Billing/TaxServiceTests.cs deleted file mode 100644 index 3995fb9de6b1..000000000000 --- a/bitwarden_license/test/Commercial.Core.Test/Billing/TaxServiceTests.cs +++ /dev/null @@ -1,151 +0,0 @@ -using Bit.Core.Billing.Services; -using Bit.Test.Common.AutoFixture; -using Bit.Test.Common.AutoFixture.Attributes; -using Xunit; - -namespace Bit.Commercial.Core.Test.Billing; - -[SutProviderCustomize] -public class TaxServiceTests -{ - [Theory] - [BitAutoData("AD", "A-123456-Z", "ad_nrt")] - [BitAutoData("AD", "A123456Z", "ad_nrt")] - [BitAutoData("AR", "20-12345678-9", "ar_cuit")] - [BitAutoData("AR", "20123456789", "ar_cuit")] - [BitAutoData("AU", "01259983598", "au_abn")] - [BitAutoData("AU", "123456789123", "au_arn")] - [BitAutoData("AT", "ATU12345678", "eu_vat")] - [BitAutoData("BH", "123456789012345", "bh_vat")] - [BitAutoData("BY", "123456789", "by_tin")] - [BitAutoData("BE", "BE0123456789", "eu_vat")] - [BitAutoData("BO", "123456789", "bo_tin")] - [BitAutoData("BR", "01.234.456/5432-10", "br_cnpj")] - [BitAutoData("BR", "01234456543210", "br_cnpj")] - [BitAutoData("BR", "123.456.789-87", "br_cpf")] - [BitAutoData("BR", "12345678987", "br_cpf")] - [BitAutoData("BG", "123456789", "bg_uic")] - [BitAutoData("BG", "BG012100705", "eu_vat")] - [BitAutoData("CA", "100728494", "ca_bn")] - [BitAutoData("CA", "123456789RT0001", "ca_gst_hst")] - [BitAutoData("CA", "PST-1234-1234", "ca_pst_bc")] - [BitAutoData("CA", "123456-7", "ca_pst_mb")] - [BitAutoData("CA", "1234567", "ca_pst_sk")] - [BitAutoData("CA", "1234567890TQ1234", "ca_qst")] - [BitAutoData("CL", "11.121.326-1", "cl_tin")] - [BitAutoData("CL", "11121326-1", "cl_tin")] - [BitAutoData("CL", "23.121.326-K", "cl_tin")] - [BitAutoData("CL", "43651326-K", "cl_tin")] - [BitAutoData("CN", "123456789012345678", "cn_tin")] - [BitAutoData("CN", "123456789012345", "cn_tin")] - [BitAutoData("CO", "123.456.789-0", "co_nit")] - [BitAutoData("CO", "1234567890", "co_nit")] - [BitAutoData("CR", "1-234-567890", "cr_tin")] - [BitAutoData("CR", "1234567890", "cr_tin")] - [BitAutoData("HR", "HR12345678912", "eu_vat")] - [BitAutoData("HR", "12345678901", "hr_oib")] - [BitAutoData("CY", "CY12345678X", "eu_vat")] - [BitAutoData("CZ", "CZ12345678", "eu_vat")] - [BitAutoData("DK", "DK12345678", "eu_vat")] - [BitAutoData("DO", "123-4567890-1", "do_rcn")] - [BitAutoData("DO", "12345678901", "do_rcn")] - [BitAutoData("EC", "1234567890001", "ec_ruc")] - [BitAutoData("EG", "123456789", "eg_tin")] - [BitAutoData("SV", "1234-567890-123-4", "sv_nit")] - [BitAutoData("SV", "12345678901234", "sv_nit")] - [BitAutoData("EE", "EE123456789", "eu_vat")] - [BitAutoData("EU", "EU123456789", "eu_oss_vat")] - [BitAutoData("FI", "FI12345678", "eu_vat")] - [BitAutoData("FR", "FR12345678901", "eu_vat")] - [BitAutoData("GE", "123456789", "ge_vat")] - [BitAutoData("DE", "1234567890", "de_stn")] - [BitAutoData("DE", "DE123456789", "eu_vat")] - [BitAutoData("GR", "EL123456789", "eu_vat")] - [BitAutoData("HK", "12345678", "hk_br")] - [BitAutoData("HU", "HU12345678", "eu_vat")] - [BitAutoData("HU", "12345678-1-23", "hu_tin")] - [BitAutoData("HU", "12345678123", "hu_tin")] - [BitAutoData("IS", "123456", "is_vat")] - [BitAutoData("IN", "12ABCDE1234F1Z5", "in_gst")] - [BitAutoData("IN", "12ABCDE3456FGZH", "in_gst")] - [BitAutoData("ID", "012.345.678.9-012.345", "id_npwp")] - [BitAutoData("ID", "0123456789012345", "id_npwp")] - [BitAutoData("IE", "IE1234567A", "eu_vat")] - [BitAutoData("IE", "IE1234567AB", "eu_vat")] - [BitAutoData("IL", "000012345", "il_vat")] - [BitAutoData("IL", "123456789", "il_vat")] - [BitAutoData("IT", "IT12345678901", "eu_vat")] - [BitAutoData("JP", "1234567890123", "jp_cn")] - [BitAutoData("JP", "12345", "jp_rn")] - [BitAutoData("KZ", "123456789012", "kz_bin")] - [BitAutoData("KE", "P000111111A", "ke_pin")] - [BitAutoData("LV", "LV12345678912", "eu_vat")] - [BitAutoData("LI", "CHE123456789", "li_uid")] - [BitAutoData("LI", "12345", "li_vat")] - [BitAutoData("LT", "LT123456789123", "eu_vat")] - [BitAutoData("LU", "LU12345678", "eu_vat")] - [BitAutoData("MY", "12345678", "my_frp")] - [BitAutoData("MY", "C 1234567890", "my_itn")] - [BitAutoData("MY", "C1234567890", "my_itn")] - [BitAutoData("MY", "A12-3456-78912345", "my_sst")] - [BitAutoData("MY", "A12345678912345", "my_sst")] - [BitAutoData("MT", "MT12345678", "eu_vat")] - [BitAutoData("MX", "ABC010203AB9", "mx_rfc")] - [BitAutoData("MD", "1003600", "md_vat")] - [BitAutoData("MA", "12345678", "ma_vat")] - [BitAutoData("NL", "NL123456789B12", "eu_vat")] - [BitAutoData("NZ", "123456789", "nz_gst")] - [BitAutoData("NG", "12345678-0001", "ng_tin")] - [BitAutoData("NO", "123456789MVA", "no_vat")] - [BitAutoData("NO", "1234567", "no_voec")] - [BitAutoData("OM", "OM1234567890", "om_vat")] - [BitAutoData("PE", "12345678901", "pe_ruc")] - [BitAutoData("PH", "123456789012", "ph_tin")] - [BitAutoData("PL", "PL1234567890", "eu_vat")] - [BitAutoData("PT", "PT123456789", "eu_vat")] - [BitAutoData("RO", "RO1234567891", "eu_vat")] - [BitAutoData("RO", "1234567890123", "ro_tin")] - [BitAutoData("RU", "1234567891", "ru_inn")] - [BitAutoData("RU", "123456789", "ru_kpp")] - [BitAutoData("SA", "123456789012345", "sa_vat")] - [BitAutoData("RS", "123456789", "rs_pib")] - [BitAutoData("SG", "M12345678X", "sg_gst")] - [BitAutoData("SG", "123456789F", "sg_uen")] - [BitAutoData("SK", "SK1234567891", "eu_vat")] - [BitAutoData("SI", "SI12345678", "eu_vat")] - [BitAutoData("SI", "12345678", "si_tin")] - [BitAutoData("ZA", "4123456789", "za_vat")] - [BitAutoData("KR", "123-45-67890", "kr_brn")] - [BitAutoData("KR", "1234567890", "kr_brn")] - [BitAutoData("ES", "A12345678", "es_cif")] - [BitAutoData("ES", "ESX1234567X", "eu_vat")] - [BitAutoData("SE", "SE123456789012", "eu_vat")] - [BitAutoData("CH", "CHE-123.456.789 HR", "ch_uid")] - [BitAutoData("CH", "CHE123456789HR", "ch_uid")] - [BitAutoData("CH", "CHE-123.456.789 MWST", "ch_vat")] - [BitAutoData("CH", "CHE123456789MWST", "ch_vat")] - [BitAutoData("TW", "12345678", "tw_vat")] - [BitAutoData("TH", "1234567890123", "th_vat")] - [BitAutoData("TR", "0123456789", "tr_tin")] - [BitAutoData("UA", "123456789", "ua_vat")] - [BitAutoData("AE", "123456789012345", "ae_trn")] - [BitAutoData("GB", "XI123456789", "eu_vat")] - [BitAutoData("GB", "GB123456789", "gb_vat")] - [BitAutoData("US", "12-3456789", "us_ein")] - [BitAutoData("UY", "123456789012", "uy_ruc")] - [BitAutoData("UZ", "123456789", "uz_tin")] - [BitAutoData("UZ", "123456789012", "uz_vat")] - [BitAutoData("VE", "A-12345678-9", "ve_rif")] - [BitAutoData("VE", "A123456789", "ve_rif")] - [BitAutoData("VN", "1234567890", "vn_tin")] - public void GetStripeTaxCode_WithValidCountryAndTaxId_ReturnsExpectedTaxIdType( - string country, - string taxId, - string expected, - SutProvider sutProvider) - { - var result = sutProvider.Sut.GetStripeTaxCode(country, taxId); - - Assert.Equal(expected, result); - } -} diff --git a/src/Api/Billing/Controllers/AccountsBillingController.cs b/src/Api/Billing/Controllers/AccountsBillingController.cs index fcb89226e75d..574ac3e65e26 100644 --- a/src/Api/Billing/Controllers/AccountsBillingController.cs +++ b/src/Api/Billing/Controllers/AccountsBillingController.cs @@ -1,6 +1,5 @@ #nullable enable using Bit.Api.Billing.Models.Responses; -using Bit.Core.Billing.Models.Api.Requests.Accounts; using Bit.Core.Billing.Services; using Bit.Core.Services; using Bit.Core.Utilities; @@ -78,18 +77,4 @@ public async Task GetTransactionsAsync([FromQuery] DateTime? startAfter return TypedResults.Ok(transactions); } - - [HttpPost("preview-invoice")] - public async Task PreviewInvoiceAsync([FromBody] PreviewIndividualInvoiceRequestBody model) - { - var user = await userService.GetUserByPrincipalAsync(User); - if (user == null) - { - throw new UnauthorizedAccessException(); - } - - var invoice = await paymentService.PreviewInvoiceAsync(model, user.GatewayCustomerId, user.GatewaySubscriptionId); - - return TypedResults.Ok(invoice); - } } diff --git a/src/Api/Billing/Controllers/InvoicesController.cs b/src/Api/Billing/Controllers/InvoicesController.cs deleted file mode 100644 index 686d9b964387..000000000000 --- a/src/Api/Billing/Controllers/InvoicesController.cs +++ /dev/null @@ -1,42 +0,0 @@ -using Bit.Core.AdminConsole.Entities; -using Bit.Core.Billing.Models.Api.Requests.Organizations; -using Bit.Core.Context; -using Bit.Core.Repositories; -using Bit.Core.Services; -using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Mvc; - -namespace Bit.Api.Billing.Controllers; - -[Route("invoices")] -[Authorize("Application")] -public class InvoicesController : BaseBillingController -{ - [HttpPost("preview-organization")] - public async Task PreviewInvoiceAsync( - [FromBody] PreviewOrganizationInvoiceRequestBody model, - [FromServices] ICurrentContext currentContext, - [FromServices] IOrganizationRepository organizationRepository, - [FromServices] IPaymentService paymentService) - { - Organization organization = null; - if (model.OrganizationId != default) - { - if (!await currentContext.EditPaymentMethods(model.OrganizationId)) - { - return Error.Unauthorized(); - } - - organization = await organizationRepository.GetByIdAsync(model.OrganizationId); - if (organization == null) - { - return Error.NotFound(); - } - } - - var invoice = await paymentService.PreviewInvoiceAsync(model, organization?.GatewayCustomerId, - organization?.GatewaySubscriptionId); - - return TypedResults.Ok(invoice); - } -} diff --git a/src/Api/Billing/Controllers/ProviderBillingController.cs b/src/Api/Billing/Controllers/ProviderBillingController.cs index c5de63c69b9f..f7ddf0853e20 100644 --- a/src/Api/Billing/Controllers/ProviderBillingController.cs +++ b/src/Api/Billing/Controllers/ProviderBillingController.cs @@ -119,7 +119,6 @@ public async Task UpdateTaxInformationAsync( requestBody.Country, requestBody.PostalCode, requestBody.TaxId, - requestBody.TaxIdType, requestBody.Line1, requestBody.Line2, requestBody.City, diff --git a/src/Api/Billing/Controllers/StripeController.cs b/src/Api/Billing/Controllers/StripeController.cs index f5e8253bfa49..a4a974bb992c 100644 --- a/src/Api/Billing/Controllers/StripeController.cs +++ b/src/Api/Billing/Controllers/StripeController.cs @@ -1,5 +1,4 @@ -using Bit.Core.Billing.Services; -using Bit.Core.Services; +using Bit.Core.Services; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.AspNetCore.Mvc; @@ -47,15 +46,4 @@ public async Task> CreateSetupIntentForCardAsync() return TypedResults.Ok(setupIntent.ClientSecret); } - - [HttpGet] - [Route("~/tax/is-country-supported")] - public IResult IsCountrySupported( - [FromQuery] string country, - [FromServices] ITaxService taxService) - { - var isSupported = taxService.IsSupported(country); - - return TypedResults.Ok(isSupported); - } } diff --git a/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs b/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs index 32ba2effb29b..c5c0fde00b93 100644 --- a/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs +++ b/src/Api/Billing/Models/Requests/TaxInformationRequestBody.cs @@ -10,7 +10,6 @@ public class TaxInformationRequestBody [Required] public string PostalCode { get; set; } public string TaxId { get; set; } - public string TaxIdType { get; set; } public string Line1 { get; set; } public string Line2 { get; set; } public string City { get; set; } @@ -20,7 +19,6 @@ public class TaxInformationRequestBody Country, PostalCode, TaxId, - TaxIdType, Line1, Line2, City, diff --git a/src/Core/Billing/Extensions/CurrencyExtensions.cs b/src/Core/Billing/Extensions/CurrencyExtensions.cs deleted file mode 100644 index cde1a7bea846..000000000000 --- a/src/Core/Billing/Extensions/CurrencyExtensions.cs +++ /dev/null @@ -1,33 +0,0 @@ -namespace Bit.Core.Billing.Extensions; - -public static class CurrencyExtensions -{ - /// - /// Converts a currency amount in major units to minor units. - /// - /// 123.99 USD returns 12399 in minor units. - public static long ToMinor(this decimal amount) - { - return Convert.ToInt64(amount * 100); - } - - /// - /// Converts a currency amount in minor units to major units. - /// - /// - /// 12399 in minor units returns 123.99 USD. - public static decimal? ToMajor(this long? amount) - { - return amount?.ToMajor(); - } - - /// - /// Converts a currency amount in minor units to major units. - /// - /// - /// 12399 in minor units returns 123.99 USD. - public static decimal ToMajor(this long amount) - { - return Convert.ToDecimal(amount) / 100; - } -} diff --git a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs index 8f2803d920af..abfceac7366e 100644 --- a/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs +++ b/src/Core/Billing/Extensions/ServiceCollectionExtensions.cs @@ -11,7 +11,6 @@ public static class ServiceCollectionExtensions { public static void AddBillingOperations(this IServiceCollection services) { - services.AddSingleton(); services.AddTransient(); services.AddTransient(); services.AddTransient(); diff --git a/src/Core/Billing/Models/Api/Requests/Accounts/PreviewIndividualInvoiceRequestModel.cs b/src/Core/Billing/Models/Api/Requests/Accounts/PreviewIndividualInvoiceRequestModel.cs deleted file mode 100644 index 6dfb9894d50d..000000000000 --- a/src/Core/Billing/Models/Api/Requests/Accounts/PreviewIndividualInvoiceRequestModel.cs +++ /dev/null @@ -1,18 +0,0 @@ -using System.ComponentModel.DataAnnotations; - -namespace Bit.Core.Billing.Models.Api.Requests.Accounts; - -public class PreviewIndividualInvoiceRequestBody -{ - [Required] - public PasswordManagerRequestModel PasswordManager { get; set; } - - [Required] - public TaxInformationRequestModel TaxInformation { get; set; } -} - -public class PasswordManagerRequestModel -{ - [Range(0, int.MaxValue)] - public int AdditionalStorage { get; set; } -} diff --git a/src/Core/Billing/Models/Api/Requests/Organizations/PreviewOrganizationInvoiceRequestModel.cs b/src/Core/Billing/Models/Api/Requests/Organizations/PreviewOrganizationInvoiceRequestModel.cs deleted file mode 100644 index 18d9c352d7c4..000000000000 --- a/src/Core/Billing/Models/Api/Requests/Organizations/PreviewOrganizationInvoiceRequestModel.cs +++ /dev/null @@ -1,37 +0,0 @@ -using System.ComponentModel.DataAnnotations; -using Bit.Core.Billing.Enums; - -namespace Bit.Core.Billing.Models.Api.Requests.Organizations; - -public class PreviewOrganizationInvoiceRequestBody -{ - public Guid OrganizationId { get; set; } - - [Required] - public PasswordManagerRequestModel PasswordManager { get; set; } - - public SecretsManagerRequestModel SecretsManager { get; set; } - - [Required] - public TaxInformationRequestModel TaxInformation { get; set; } -} - -public class PasswordManagerRequestModel -{ - public PlanType Plan { get; set; } - - [Range(0, int.MaxValue)] - public int Seats { get; set; } - - [Range(0, int.MaxValue)] - public int AdditionalStorage { get; set; } -} - -public class SecretsManagerRequestModel -{ - [Range(0, int.MaxValue)] - public int Seats { get; set; } - - [Range(0, int.MaxValue)] - public int AdditionalMachineAccounts { get; set; } -} diff --git a/src/Core/Billing/Models/Api/Requests/TaxInformationRequestModel.cs b/src/Core/Billing/Models/Api/Requests/TaxInformationRequestModel.cs deleted file mode 100644 index 9cb43645c6cd..000000000000 --- a/src/Core/Billing/Models/Api/Requests/TaxInformationRequestModel.cs +++ /dev/null @@ -1,14 +0,0 @@ -using System.ComponentModel.DataAnnotations; - -namespace Bit.Core.Billing.Models.Api.Requests; - -public class TaxInformationRequestModel -{ - [Length(2, 2), Required] - public string Country { get; set; } - - [Required] - public string PostalCode { get; set; } - - public string TaxId { get; set; } -} diff --git a/src/Core/Billing/Models/Api/Responses/PreviewInvoiceResponseModel.cs b/src/Core/Billing/Models/Api/Responses/PreviewInvoiceResponseModel.cs deleted file mode 100644 index fdde7dae1e48..000000000000 --- a/src/Core/Billing/Models/Api/Responses/PreviewInvoiceResponseModel.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Bit.Core.Billing.Models.Api.Responses; - -public record PreviewInvoiceResponseModel( - decimal EffectiveTaxRate, - decimal TaxableBaseAmount, - decimal TaxAmount, - decimal TotalAmount); diff --git a/src/Core/Billing/Models/PreviewInvoiceInfo.cs b/src/Core/Billing/Models/PreviewInvoiceInfo.cs deleted file mode 100644 index 16a2019c202b..000000000000 --- a/src/Core/Billing/Models/PreviewInvoiceInfo.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Bit.Core.Billing.Models; - -public record PreviewInvoiceInfo( - decimal EffectiveTaxRate, - decimal TaxableBaseAmount, - decimal TaxAmount, - decimal TotalAmount); diff --git a/src/Core/Billing/Models/Sales/OrganizationSale.cs b/src/Core/Billing/Models/Sales/OrganizationSale.cs index 43852bb320bd..a19c278c68f1 100644 --- a/src/Core/Billing/Models/Sales/OrganizationSale.cs +++ b/src/Core/Billing/Models/Sales/OrganizationSale.cs @@ -65,7 +65,6 @@ private static CustomerSetup GetCustomerSetup(OrganizationSignup signup) signup.TaxInfo.BillingAddressCountry, signup.TaxInfo.BillingAddressPostalCode, signup.TaxInfo.TaxIdNumber, - signup.TaxInfo.TaxIdType, signup.TaxInfo.BillingAddressLine1, signup.TaxInfo.BillingAddressLine2, signup.TaxInfo.BillingAddressCity, diff --git a/src/Core/Billing/Models/TaxIdType.cs b/src/Core/Billing/Models/TaxIdType.cs deleted file mode 100644 index 3fc246d68bbc..000000000000 --- a/src/Core/Billing/Models/TaxIdType.cs +++ /dev/null @@ -1,22 +0,0 @@ -using System.Text.RegularExpressions; - -namespace Bit.Core.Billing.Models; - -public class TaxIdType -{ - /// - /// ISO-3166-2 code for the country. - /// - public string Country { get; set; } - - /// - /// The identifier in Stripe for the tax ID type. - /// - public string Code { get; set; } - - public Regex ValidationExpression { get; set; } - - public string Description { get; set; } - - public string Example { get; set; } -} diff --git a/src/Core/Billing/Models/TaxInformation.cs b/src/Core/Billing/Models/TaxInformation.cs index 23ed3e5faa64..5403f9469022 100644 --- a/src/Core/Billing/Models/TaxInformation.cs +++ b/src/Core/Billing/Models/TaxInformation.cs @@ -1,4 +1,5 @@ using Bit.Core.Models.Business; +using Stripe; namespace Bit.Core.Billing.Models; @@ -6,7 +7,6 @@ public record TaxInformation( string Country, string PostalCode, string TaxId, - string TaxIdType, string Line1, string Line2, string City, @@ -16,9 +16,165 @@ public record TaxInformation( taxInfo.BillingAddressCountry, taxInfo.BillingAddressPostalCode, taxInfo.TaxIdNumber, - taxInfo.TaxIdType, taxInfo.BillingAddressLine1, taxInfo.BillingAddressLine2, taxInfo.BillingAddressCity, taxInfo.BillingAddressState); + + public (AddressOptions, List) GetStripeOptions() + { + var address = new AddressOptions + { + Country = Country, + PostalCode = PostalCode, + Line1 = Line1, + Line2 = Line2, + City = City, + State = State + }; + + var customerTaxIdDataOptionsList = !string.IsNullOrEmpty(TaxId) + ? new List { new() { Type = GetTaxIdType(), Value = TaxId } } + : null; + + return (address, customerTaxIdDataOptionsList); + } + + public string GetTaxIdType() + { + if (string.IsNullOrEmpty(Country) || string.IsNullOrEmpty(TaxId)) + { + return null; + } + + switch (Country.ToUpper()) + { + case "AD": + return "ad_nrt"; + case "AE": + return "ae_trn"; + case "AR": + return "ar_cuit"; + case "AU": + return "au_abn"; + case "BO": + return "bo_tin"; + case "BR": + return "br_cnpj"; + case "CA": + // May break for those in Québec given the assumption of QST + if (State?.Contains("bec") ?? false) + { + return "ca_qst"; + } + return "ca_bn"; + case "CH": + return "ch_vat"; + case "CL": + return "cl_tin"; + case "CN": + return "cn_tin"; + case "CO": + return "co_nit"; + case "CR": + return "cr_tin"; + case "DO": + return "do_rcn"; + case "EC": + return "ec_ruc"; + case "EG": + return "eg_tin"; + case "GE": + return "ge_vat"; + case "ID": + return "id_npwp"; + case "IL": + return "il_vat"; + case "IS": + return "is_vat"; + case "KE": + return "ke_pin"; + case "AT": + case "BE": + case "BG": + case "CY": + case "CZ": + case "DE": + case "DK": + case "EE": + case "ES": + case "FI": + case "FR": + case "GB": + case "GR": + case "HR": + case "HU": + case "IE": + case "IT": + case "LT": + case "LU": + case "LV": + case "MT": + case "NL": + case "PL": + case "PT": + case "RO": + case "SE": + case "SI": + case "SK": + return "eu_vat"; + case "HK": + return "hk_br"; + case "IN": + return "in_gst"; + case "JP": + return "jp_cn"; + case "KR": + return "kr_brn"; + case "LI": + return "li_uid"; + case "MX": + return "mx_rfc"; + case "MY": + return "my_sst"; + case "NO": + return "no_vat"; + case "NZ": + return "nz_gst"; + case "PE": + return "pe_ruc"; + case "PH": + return "ph_tin"; + case "RS": + return "rs_pib"; + case "RU": + return "ru_inn"; + case "SA": + return "sa_vat"; + case "SG": + return "sg_gst"; + case "SV": + return "sv_nit"; + case "TH": + return "th_vat"; + case "TR": + return "tr_tin"; + case "TW": + return "tw_vat"; + case "UA": + return "ua_vat"; + case "US": + return "us_ein"; + case "UY": + return "uy_ruc"; + case "VE": + return "ve_rif"; + case "VN": + return "vn_tin"; + case "ZA": + return "za_vat"; + default: + return null; + } + } } diff --git a/src/Core/Billing/Services/ITaxService.cs b/src/Core/Billing/Services/ITaxService.cs deleted file mode 100644 index beee113d17bd..000000000000 --- a/src/Core/Billing/Services/ITaxService.cs +++ /dev/null @@ -1,22 +0,0 @@ -namespace Bit.Core.Billing.Services; - -public interface ITaxService -{ - /// - /// Retrieves the Stripe tax code for a given country and tax ID. - /// - /// - /// - /// - /// Returns the Stripe tax code if the tax ID is valid for the country. - /// Returns null if the tax ID is invalid or the country is not supported. - /// - string GetStripeTaxCode(string country, string taxId); - - /// - /// Returns true or false whether charging or storing tax is supported for the given country. - /// - /// - /// - bool IsSupported(string country); -} diff --git a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs index b186a99d93e7..eadc5896258f 100644 --- a/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs +++ b/src/Core/Billing/Services/Implementations/OrganizationBillingService.cs @@ -28,8 +28,7 @@ public class OrganizationBillingService( IOrganizationRepository organizationRepository, ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, - ISubscriberService subscriberService, - ITaxService taxService) : IOrganizationBillingService + ISubscriberService subscriberService) : IOrganizationBillingService { public async Task Finalize(OrganizationSale sale) { @@ -168,38 +167,14 @@ private async Task CreateCustomerAsync( throw new BillingException(); } - customerCreateOptions.Address = new AddressOptions - { - Line1 = customerSetup.TaxInformation.Line1, - Line2 = customerSetup.TaxInformation.Line2, - City = customerSetup.TaxInformation.City, - PostalCode = customerSetup.TaxInformation.PostalCode, - State = customerSetup.TaxInformation.State, - Country = customerSetup.TaxInformation.Country, - }; + var (address, taxIdData) = customerSetup.TaxInformation.GetStripeOptions(); + + customerCreateOptions.Address = address; customerCreateOptions.Tax = new CustomerTaxOptions { ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately }; - - if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId)) - { - var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country, - customerSetup.TaxInformation.TaxId); - - if (taxIdType == null) - { - logger.LogWarning("Could not determine tax ID type for organization '{OrganizationID}' in country '{Country}' with tax ID '{TaxID}'.", - organization.Id, - customerSetup.TaxInformation.Country, - customerSetup.TaxInformation.TaxId); - } - - customerCreateOptions.TaxIdData = - [ - new() { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId } - ]; - } + customerCreateOptions.TaxIdData = taxIdData; var (paymentMethodType, paymentMethodToken) = customerSetup.TokenizedPaymentSource; diff --git a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs index 5351888ad9d3..92c81dae1cf5 100644 --- a/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs +++ b/src/Core/Billing/Services/Implementations/PremiumUserBillingService.cs @@ -24,8 +24,7 @@ public class PremiumUserBillingService( ISetupIntentCache setupIntentCache, IStripeAdapter stripeAdapter, ISubscriberService subscriberService, - IUserRepository userRepository, - ITaxService taxService) : IPremiumUserBillingService + IUserRepository userRepository) : IPremiumUserBillingService { public async Task Finalize(PremiumUserSale sale) { @@ -83,19 +82,13 @@ private async Task CreateCustomerAsync( throw new BillingException(); } + var (address, taxIdData) = customerSetup.TaxInformation.GetStripeOptions(); + var subscriberName = user.SubscriberName(); var customerCreateOptions = new CustomerCreateOptions { - Address = new AddressOptions - { - Line1 = customerSetup.TaxInformation.Line1, - Line2 = customerSetup.TaxInformation.Line2, - City = customerSetup.TaxInformation.City, - PostalCode = customerSetup.TaxInformation.PostalCode, - State = customerSetup.TaxInformation.State, - Country = customerSetup.TaxInformation.Country, - }, + Address = address, Description = user.Name, Email = user.Email, Expand = ["tax"], @@ -120,28 +113,10 @@ private async Task CreateCustomerAsync( Tax = new CustomerTaxOptions { ValidateLocation = StripeConstants.ValidateTaxLocationTiming.Immediately - } + }, + TaxIdData = taxIdData }; - if (!string.IsNullOrEmpty(customerSetup.TaxInformation.TaxId)) - { - var taxIdType = taxService.GetStripeTaxCode(customerSetup.TaxInformation.Country, - customerSetup.TaxInformation.TaxId); - - if (taxIdType == null) - { - logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", - customerSetup.TaxInformation.Country, - customerSetup.TaxInformation.TaxId); - throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError"); - } - - customerCreateOptions.TaxIdData = - [ - new() { Type = taxIdType, Value = customerSetup.TaxInformation.TaxId } - ]; - } - var (paymentMethodType, paymentMethodToken) = customerSetup.TokenizedPaymentSource; var braintreeCustomerId = ""; diff --git a/src/Core/Billing/Services/Implementations/SubscriberService.cs b/src/Core/Billing/Services/Implementations/SubscriberService.cs index 6125d1541909..9b8f64be82d7 100644 --- a/src/Core/Billing/Services/Implementations/SubscriberService.cs +++ b/src/Core/Billing/Services/Implementations/SubscriberService.cs @@ -23,8 +23,7 @@ public class SubscriberService( IGlobalSettings globalSettings, ILogger logger, ISetupIntentCache setupIntentCache, - IStripeAdapter stripeAdapter, - ITaxService taxService) : ISubscriberService + IStripeAdapter stripeAdapter) : ISubscriberService { public async Task CancelSubscription( ISubscriber subscriber, @@ -619,47 +618,16 @@ public async Task UpdateTaxInformation( await stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); } - if (string.IsNullOrWhiteSpace(taxInformation.TaxId)) - { - return; - } - - var taxIdType = taxInformation.TaxIdType; - if (string.IsNullOrWhiteSpace(taxIdType)) - { - taxIdType = taxService.GetStripeTaxCode(taxInformation.Country, - taxInformation.TaxId); - - if (taxIdType == null) - { - logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", - taxInformation.Country, - taxInformation.TaxId); - throw new Exceptions.BadRequestException("billingTaxIdTypeInferenceError"); - } - } + var taxIdType = taxInformation.GetTaxIdType(); - try - { - await stripeAdapter.TaxIdCreateAsync(customer.Id, - new TaxIdCreateOptions { Type = taxIdType, Value = taxInformation.TaxId }); - } - catch (StripeException e) + if (!string.IsNullOrWhiteSpace(taxInformation.TaxId) && + !string.IsNullOrWhiteSpace(taxIdType)) { - switch (e.StripeError.Code) + await stripeAdapter.TaxIdCreateAsync(customer.Id, new TaxIdCreateOptions { - case StripeConstants.ErrorCodes.TaxIdInvalid: - logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - taxInformation.TaxId, - taxInformation.Country); - throw new Exceptions.BadRequestException("billingInvalidTaxIdError"); - default: - logger.LogError(e, "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", - taxInformation.TaxId, - taxInformation.Country, - customer.Id); - throw new Exceptions.BadRequestException("billingTaxIdCreationError"); - } + Type = taxIdType, + Value = taxInformation.TaxId, + }); } } @@ -802,7 +770,6 @@ private static TaxInformation GetTaxInformation( customer.Address.Country, customer.Address.PostalCode, customer.TaxIds?.FirstOrDefault()?.Value, - customer.TaxIds?.FirstOrDefault()?.Type, customer.Address.Line1, customer.Address.Line2, customer.Address.City, diff --git a/src/Core/Billing/Services/TaxService.cs b/src/Core/Billing/Services/TaxService.cs deleted file mode 100644 index 3066be92d1d8..000000000000 --- a/src/Core/Billing/Services/TaxService.cs +++ /dev/null @@ -1,901 +0,0 @@ -using System.Text.RegularExpressions; -using Bit.Core.Billing.Models; - -namespace Bit.Core.Billing.Services; - -public class TaxService : ITaxService -{ - /// - /// Retrieves a list of supported tax ID types for customers. - /// - /// Compiled list from Stripe - private static readonly IEnumerable _taxIdTypes = - [ - new() - { - Country = "AD", - Code = "ad_nrt", - Description = "Andorran NRT number", - Example = "A-123456-Z", - ValidationExpression = new Regex("^([A-Z]{1})-?([0-9]{6})-?([A-Z]{1})$") - }, - new() - { - Country = "AR", - Code = "ar_cuit", - Description = "Argentinian tax ID number", - Example = "12-34567890-1", - ValidationExpression = new Regex("^([0-9]{2})-?([0-9]{8})-?([0-9]{1})$") - }, - new() - { - Country = "AU", - Code = "au_abn", - Description = "Australian Business Number (AU ABN)", - Example = "123456789012", - ValidationExpression = new Regex("^[0-9]{11}$") - }, - new() - { - Country = "AU", - Code = "au_arn", - Description = "Australian Taxation Office Reference Number", - Example = "123456789123", - ValidationExpression = new Regex("^[0-9]{12}$") - }, - new() - { - Country = "AT", - Code = "eu_vat", - Description = "European VAT number (Austria)", - Example = "ATU12345678", - ValidationExpression = new Regex("^ATU[0-9]{8}$") - }, - new() - { - Country = "BH", - Code = "bh_vat", - Description = "Bahraini VAT Number", - Example = "123456789012345", - ValidationExpression = new Regex("^[0-9]{15}$") - }, - new() - { - Country = "BY", - Code = "by_tin", - Description = "Belarus TIN Number", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "BE", - Code = "eu_vat", - Description = "European VAT number (Belgium)", - Example = "BE0123456789", - ValidationExpression = new Regex("^BE[0-9]{10}$") - }, - new() - { - Country = "BO", - Code = "bo_tin", - Description = "Bolivian tax ID", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "BR", - Code = "br_cnpj", - Description = "Brazilian CNPJ number", - Example = "01.234.456/5432-10", - ValidationExpression = new Regex("^[0-9]{2}.?[0-9]{3}.?[0-9]{3}/?[0-9]{4}-?[0-9]{2}$") - }, - new() - { - Country = "BR", - Code = "br_cpf", - Description = "Brazilian CPF number", - Example = "123.456.789-87", - ValidationExpression = new Regex("^[0-9]{3}.?[0-9]{3}.?[0-9]{3}-?[0-9]{2}$") - }, - new() - { - Country = "BG", - Code = "bg_uic", - Description = "Bulgaria Unified Identification Code", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "BG", - Code = "eu_vat", - Description = "European VAT number (Bulgaria)", - Example = "BG0123456789", - ValidationExpression = new Regex("^BG[0-9]{9,10}$") - }, - new() - { - Country = "CA", - Code = "ca_bn", - Description = "Canadian BN", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "CA", - Code = "ca_gst_hst", - Description = "Canadian GST/HST number", - Example = "123456789RT0002", - ValidationExpression = new Regex("^[0-9]{9}RT[0-9]{4}$") - }, - new() - { - Country = "CA", - Code = "ca_pst_bc", - Description = "Canadian PST number (British Columbia)", - Example = "PST-1234-5678", - ValidationExpression = new Regex("^PST-[0-9]{4}-[0-9]{4}$") - }, - new() - { - Country = "CA", - Code = "ca_pst_mb", - Description = "Canadian PST number (Manitoba)", - Example = "123456-7", - ValidationExpression = new Regex("^[0-9]{6}-[0-9]{1}$") - }, - new() - { - Country = "CA", - Code = "ca_pst_sk", - Description = "Canadian PST number (Saskatchewan)", - Example = "1234567", - ValidationExpression = new Regex("^[0-9]{7}$") - }, - new() - { - Country = "CA", - Code = "ca_qst", - Description = "Canadian QST number (Québec)", - Example = "1234567890TQ1234", - ValidationExpression = new Regex("^[0-9]{10}TQ[0-9]{4}$") - }, - new() - { - Country = "CL", - Code = "cl_tin", - Description = "Chilean TIN", - Example = "12.345.678-K", - ValidationExpression = new Regex("^[0-9]{2}.?[0-9]{3}.?[0-9]{3}-?[0-9A-Z]{1}$") - }, - new() - { - Country = "CN", - Code = "cn_tin", - Description = "Chinese tax ID", - Example = "123456789012345678", - ValidationExpression = new Regex("^[0-9]{15,18}$") - }, - new() - { - Country = "CO", - Code = "co_nit", - Description = "Colombian NIT number", - Example = "123.456.789-0", - ValidationExpression = new Regex("^[0-9]{3}.?[0-9]{3}.?[0-9]{3}-?[0-9]{1}$") - }, - new() - { - Country = "CR", - Code = "cr_tin", - Description = "Costa Rican tax ID", - Example = "1-234-567890", - ValidationExpression = new Regex("^[0-9]{1}-?[0-9]{3}-?[0-9]{6}$") - }, - new() - { - Country = "HR", - Code = "eu_vat", - Description = "European VAT number (Croatia)", - Example = "HR12345678912", - ValidationExpression = new Regex("^HR[0-9]{11}$") - }, - new() - { - Country = "HR", - Code = "hr_oib", - Description = "Croatian Personal Identification Number", - Example = "12345678901", - ValidationExpression = new Regex("^[0-9]{11}$") - }, - new() - { - Country = "CY", - Code = "eu_vat", - Description = "European VAT number (Cyprus)", - Example = "CY12345678X", - ValidationExpression = new Regex("^CY[0-9]{8}[A-Z]{1}$") - }, - new() - { - Country = "CZ", - Code = "eu_vat", - Description = "European VAT number (Czech Republic)", - Example = "CZ12345678", - ValidationExpression = new Regex("^CZ[0-9]{8,10}$") - }, - new() - { - Country = "DK", - Code = "eu_vat", - Description = "European VAT number (Denmark)", - Example = "DK12345678", - ValidationExpression = new Regex("^DK[0-9]{8}$") - }, - new() - { - Country = "DO", - Code = "do_rcn", - Description = "Dominican RCN number", - Example = "123-4567890-1", - ValidationExpression = new Regex("^[0-9]{3}-?[0-9]{7}-?[0-9]{1}$") - }, - new() - { - Country = "EC", - Code = "ec_ruc", - Description = "Ecuadorian RUC number", - Example = "1234567890001", - ValidationExpression = new Regex("^[0-9]{13}$") - }, - new() - { - Country = "EG", - Code = "eg_tin", - Description = "Egyptian Tax Identification Number", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - - new() - { - Country = "SV", - Code = "sv_nit", - Description = "El Salvadorian NIT number", - Example = "1234-567890-123-4", - ValidationExpression = new Regex("^[0-9]{4}-?[0-9]{6}-?[0-9]{3}-?[0-9]{1}$") - }, - - new() - { - Country = "EE", - Code = "eu_vat", - Description = "European VAT number (Estonia)", - Example = "EE123456789", - ValidationExpression = new Regex("^EE[0-9]{9}$") - }, - - new() - { - Country = "EU", - Code = "eu_oss_vat", - Description = "European One Stop Shop VAT number for non-Union scheme", - Example = "EU123456789", - ValidationExpression = new Regex("^EU[0-9]{9}$") - }, - new() - { - Country = "FI", - Code = "eu_vat", - Description = "European VAT number (Finland)", - Example = "FI12345678", - ValidationExpression = new Regex("^FI[0-9]{8}$") - }, - new() - { - Country = "FR", - Code = "eu_vat", - Description = "European VAT number (France)", - Example = "FR12345678901", - ValidationExpression = new Regex("^FR[0-9A-Z]{2}[0-9]{9}$") - }, - new() - { - Country = "GE", - Code = "ge_vat", - Description = "Georgian VAT", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "DE", - Code = "de_stn", - Description = "German Tax Number (Steuernummer)", - Example = "1234567890", - ValidationExpression = new Regex("^[0-9]{10}$") - }, - new() - { - Country = "DE", - Code = "eu_vat", - Description = "European VAT number (Germany)", - Example = "DE123456789", - ValidationExpression = new Regex("^DE[0-9]{9}$") - }, - new() - { - Country = "GR", - Code = "eu_vat", - Description = "European VAT number (Greece)", - Example = "EL123456789", - ValidationExpression = new Regex("^EL[0-9]{9}$") - }, - new() - { - Country = "HK", - Code = "hk_br", - Description = "Hong Kong BR number", - Example = "12345678", - ValidationExpression = new Regex("^[0-9]{8}$") - }, - new() - { - Country = "HU", - Code = "eu_vat", - Description = "European VAT number (Hungaria)", - Example = "HU12345678", - ValidationExpression = new Regex("^HU[0-9]{8}$") - }, - new() - { - Country = "HU", - Code = "hu_tin", - Description = "Hungary tax number (adószám)", - Example = "12345678-1-23", - ValidationExpression = new Regex("^[0-9]{8}-?[0-9]-?[0-9]{2}$") - }, - new() - { - Country = "IS", - Code = "is_vat", - Description = "Icelandic VAT", - Example = "123456", - ValidationExpression = new Regex("^[0-9]{6}$") - }, - new() - { - Country = "IN", - Code = "in_gst", - Description = "Indian GST number", - Example = "12ABCDE3456FGZH", - ValidationExpression = new Regex("^[0-9]{2}[A-Z]{5}[0-9]{4}[A-Z]{1}[1-9A-Z]{1}Z[0-9A-Z]{1}$") - }, - new() - { - Country = "ID", - Code = "id_npwp", - Description = "Indonesian NPWP number", - Example = "012.345.678.9-012.345", - ValidationExpression = new Regex("^[0-9]{3}.?[0-9]{3}.?[0-9]{3}.?[0-9]{1}-?[0-9]{3}.?[0-9]{3}$") - }, - new() - { - Country = "IE", - Code = "eu_vat", - Description = "European VAT number (Ireland)", - Example = "IE1234567AB", - ValidationExpression = new Regex("^IE[0-9]{7}[A-Z]{1,2}$") - }, - new() - { - Country = "IL", - Code = "il_vat", - Description = "Israel VAT", - Example = "000012345", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "IT", - Code = "eu_vat", - Description = "European VAT number (Italy)", - Example = "IT12345678912", - ValidationExpression = new Regex("^IT[0-9]{11}$") - }, - new() - { - Country = "JP", - Code = "jp_cn", - Description = "Japanese Corporate Number (*Hōjin Bangō*)", - Example = "1234567891234", - ValidationExpression = new Regex("^[0-9]{13}$") - }, - new() - { - Country = "JP", - Code = "jp_rn", - Description = - "Japanese Registered Foreign Businesses' Registration Number (*Tōroku Kokugai Jigyōsha no Tōroku Bangō*)", - Example = "12345", - ValidationExpression = new Regex("^[0-9]{5}$") - }, - new() - { - Country = "JP", - Code = "jp_trn", - Description = "Japanese Tax Registration Number (*Tōroku Bangō*)", - Example = "T1234567891234", - ValidationExpression = new Regex("^T[0-9]{13}$") - }, - new() - { - Country = "KZ", - Code = "kz_bin", - Description = "Kazakhstani Business Identification Number", - Example = "123456789012", - ValidationExpression = new Regex("^[0-9]{12}$") - }, - new() - { - Country = "KE", - Code = "ke_pin", - Description = "Kenya Revenue Authority Personal Identification Number", - Example = "P000111111A", - ValidationExpression = new Regex("^[A-Z]{1}[0-9]{9}[A-Z]{1}$") - }, - new() - { - Country = "LV", - Code = "eu_vat", - Description = "European VAT number", - Example = "LV12345678912", - ValidationExpression = new Regex("^LV[0-9]{11}$") - }, - new() - { - Country = "LI", - Code = "li_uid", - Description = "Liechtensteinian UID number", - Example = "CHE123456789", - ValidationExpression = new Regex("^CHE[0-9]{9}$") - }, - new() - { - Country = "LI", - Code = "li_vat", - Description = "Liechtensteinian VAT number", - Example = "12345", - ValidationExpression = new Regex("^[0-9]{5}$") - }, - new() - { - Country = "LT", - Code = "eu_vat", - Description = "European VAT number (Lithuania)", - Example = "LT123456789123", - ValidationExpression = new Regex("^LT[0-9]{9,12}$") - }, - new() - { - Country = "LU", - Code = "eu_vat", - Description = "European VAT number (Luxembourg)", - Example = "LU12345678", - ValidationExpression = new Regex("^LU[0-9]{8}$") - }, - new() - { - Country = "MY", - Code = "my_frp", - Description = "Malaysian FRP number", - Example = "12345678", - ValidationExpression = new Regex("^[0-9]{8}$") - }, - new() - { - Country = "MY", - Code = "my_itn", - Description = "Malaysian ITN", - Example = "C 1234567890", - ValidationExpression = new Regex("^[A-Z]{1} ?[0-9]{10}$") - }, - new() - { - Country = "MY", - Code = "my_sst", - Description = "Malaysian SST number", - Example = "A12-3456-78912345", - ValidationExpression = new Regex("^[A-Z]{1}[0-9]{2}-?[0-9]{4}-?[0-9]{8}$") - }, - new() - { - Country = "MT", - Code = "eu_vat", - Description = "European VAT number (Malta)", - Example = "MT12345678", - ValidationExpression = new Regex("^MT[0-9]{8}$") - }, - new() - { - Country = "MX", - Code = "mx_rfc", - Description = "Mexican RFC number", - Example = "ABC010203AB9", - ValidationExpression = new Regex("^[A-Z]{3}[0-9]{6}[A-Z0-9]{3}$") - }, - new() - { - Country = "MD", - Code = "md_vat", - Description = "Moldova VAT Number", - Example = "1234567", - ValidationExpression = new Regex("^[0-9]{7}$") - }, - new() - { - Country = "MA", - Code = "ma_vat", - Description = "Morocco VAT Number", - Example = "12345678", - ValidationExpression = new Regex("^[0-9]{8}$") - }, - new() - { - Country = "NL", - Code = "eu_vat", - Description = "European VAT number (Netherlands)", - Example = "NL123456789B12", - ValidationExpression = new Regex("^NL[0-9]{9}B[0-9]{2}$") - }, - new() - { - Country = "NZ", - Code = "nz_gst", - Description = "New Zealand GST number", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "NG", - Code = "ng_tin", - Description = "Nigerian TIN Number", - Example = "12345678-0001", - ValidationExpression = new Regex("^[0-9]{8}-[0-9]{4}$") - }, - new() - { - Country = "NO", - Code = "no_vat", - Description = "Norwegian VAT number", - Example = "123456789MVA", - ValidationExpression = new Regex("^[0-9]{9}MVA$") - }, - new() - { - Country = "NO", - Code = "no_voec", - Description = "Norwegian VAT on e-commerce number", - Example = "1234567", - ValidationExpression = new Regex("^[0-9]{7}$") - }, - new() - { - Country = "OM", - Code = "om_vat", - Description = "Omani VAT Number", - Example = "OM1234567890", - ValidationExpression = new Regex("^OM[0-9]{10}$") - }, - new() - { - Country = "PE", - Code = "pe_ruc", - Description = "Peruvian RUC number", - Example = "12345678901", - ValidationExpression = new Regex("^[0-9]{11}$") - }, - new() - { - Country = "PH", - Code = "ph_tin", - Description = "Philippines Tax Identification Number", - Example = "123456789012", - ValidationExpression = new Regex("^[0-9]{12}$") - }, - new() - { - Country = "PL", - Code = "eu_vat", - Description = "European VAT number (Poland)", - Example = "PL1234567890", - ValidationExpression = new Regex("^PL[0-9]{10}$") - }, - new() - { - Country = "PT", - Code = "eu_vat", - Description = "European VAT number (Portugal)", - Example = "PT123456789", - ValidationExpression = new Regex("^PT[0-9]{9}$") - }, - new() - { - Country = "RO", - Code = "eu_vat", - Description = "European VAT number (Romania)", - Example = "RO1234567891", - ValidationExpression = new Regex("^RO[0-9]{2,10}$") - }, - new() - { - Country = "RO", - Code = "ro_tin", - Description = "Romanian tax ID number", - Example = "1234567890123", - ValidationExpression = new Regex("^[0-9]{13}$") - }, - new() - { - Country = "RU", - Code = "ru_inn", - Description = "Russian INN", - Example = "1234567891", - ValidationExpression = new Regex("^[0-9]{10,12}$") - }, - new() - { - Country = "RU", - Code = "ru_kpp", - Description = "Russian KPP", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "SA", - Code = "sa_vat", - Description = "Saudi Arabia VAT", - Example = "123456789012345", - ValidationExpression = new Regex("^[0-9]{15}$") - }, - new() - { - Country = "RS", - Code = "rs_pib", - Description = "Serbian PIB number", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "SG", - Code = "sg_gst", - Description = "Singaporean GST", - Example = "M12345678X", - ValidationExpression = new Regex("^[A-Z]{1}[0-9]{8}[A-Z]{1}$") - }, - new() - { - Country = "SG", - Code = "sg_uen", - Description = "Singaporean UEN", - Example = "123456789F", - ValidationExpression = new Regex("^[0-9]{9}[A-Z]{1}$") - }, - new() - { - Country = "SK", - Code = "eu_vat", - Description = "European VAT number (Slovakia)", - Example = "SK1234567891", - ValidationExpression = new Regex("^SK[0-9]{10}$") - }, - new() - { - Country = "SI", - Code = "eu_vat", - Description = "European VAT number (Slovenia)", - Example = "SI12345678", - ValidationExpression = new Regex("^SI[0-9]{8}$") - }, - new() - { - Country = "SI", - Code = "si_tin", - Description = "Slovenia tax number (davčna številka)", - Example = "12345678", - ValidationExpression = new Regex("^[0-9]{8}$") - }, - new() - { - Country = "ZA", - Code = "za_vat", - Description = "South African VAT number", - Example = "4123456789", - ValidationExpression = new Regex("^[0-9]{10}$") - }, - new() - { - Country = "KR", - Code = "kr_brn", - Description = "Korean BRN", - Example = "123-45-67890", - ValidationExpression = new Regex("^[0-9]{3}-?[0-9]{2}-?[0-9]{5}$") - }, - new() - { - Country = "ES", - Code = "es_cif", - Description = "Spanish NIF/CIF number", - Example = "A12345678", - ValidationExpression = new Regex("^[A-Z]{1}[0-9]{8}$") - }, - new() - { - Country = "ES", - Code = "eu_vat", - Description = "European VAT number (Spain)", - Example = "ESA1234567Z", - ValidationExpression = new Regex("^ES[A-Z]{1}[0-9]{7}[A-Z]{1}$") - }, - new() - { - Country = "SE", - Code = "eu_vat", - Description = "European VAT number (Sweden)", - Example = "SE123456789123", - ValidationExpression = new Regex("^SE[0-9]{12}$") - }, - new() - { - Country = "CH", - Code = "ch_uid", - Description = "Switzerland UID number", - Example = "CHE-123.456.789 HR", - ValidationExpression = new Regex("^CHE-?[0-9]{3}.?[0-9]{3}.?[0-9]{3} ?HR$") - }, - new() - { - Country = "CH", - Code = "ch_vat", - Description = "Switzerland VAT number", - Example = "CHE-123.456.789 MWST", - ValidationExpression = new Regex("^CHE-?[0-9]{3}.?[0-9]{3}.?[0-9]{3} ?MWST$") - }, - new() - { - Country = "TW", - Code = "tw_vat", - Description = "Taiwanese VAT", - Example = "12345678", - ValidationExpression = new Regex("^[0-9]{8}$") - }, - new() - { - Country = "TZ", - Code = "tz_vat", - Description = "Tanzania VAT Number", - Example = "12345678A", - ValidationExpression = new Regex("^[0-9]{8}[A-Z]{1}$") - }, - new() - { - Country = "TH", - Code = "th_vat", - Description = "Thai VAT", - Example = "1234567891234", - ValidationExpression = new Regex("^[0-9]{13}$") - }, - new() - { - Country = "TR", - Code = "tr_tin", - Description = "Turkish TIN Number", - Example = "0123456789", - ValidationExpression = new Regex("^[0-9]{10}$") - }, - new() - { - Country = "UA", - Code = "ua_vat", - Description = "Ukrainian VAT", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "AE", - Code = "ae_trn", - Description = "United Arab Emirates TRN", - Example = "123456789012345", - ValidationExpression = new Regex("^[0-9]{15}$") - }, - new() - { - Country = "GB", - Code = "eu_vat", - Description = "Northern Ireland VAT number", - Example = "XI123456789", - ValidationExpression = new Regex("^XI[0-9]{9}$") - }, - new() - { - Country = "GB", - Code = "gb_vat", - Description = "United Kingdom VAT number", - Example = "GB123456789", - ValidationExpression = new Regex("^GB[0-9]{9}$") - }, - new() - { - Country = "US", - Code = "us_ein", - Description = "United States EIN", - Example = "12-3456789", - ValidationExpression = new Regex("^[0-9]{2}-?[0-9]{7}$") - }, - new() - { - Country = "UY", - Code = "uy_ruc", - Description = "Uruguayan RUC number", - Example = "123456789012", - ValidationExpression = new Regex("^[0-9]{12}$") - }, - new() - { - Country = "UZ", - Code = "uz_tin", - Description = "Uzbekistan TIN Number", - Example = "123456789", - ValidationExpression = new Regex("^[0-9]{9}$") - }, - new() - { - Country = "UZ", - Code = "uz_vat", - Description = "Uzbekistan VAT Number", - Example = "123456789012", - ValidationExpression = new Regex("^[0-9]{12}$") - }, - new() - { - Country = "VE", - Code = "ve_rif", - Description = "Venezuelan RIF number", - Example = "A-12345678-9", - ValidationExpression = new Regex("^[A-Z]{1}-?[0-9]{8}-?[0-9]{1}$") - }, - new() - { - Country = "VN", - Code = "vn_tin", - Description = "Vietnamese tax ID number", - Example = "1234567890", - ValidationExpression = new Regex("^[0-9]{10}$") - } - ]; - - public string GetStripeTaxCode(string country, string taxId) - { - foreach (var taxIdType in _taxIdTypes.Where(x => x.Country == country)) - { - if (taxIdType.ValidationExpression.IsMatch(taxId)) - { - return taxIdType.Code; - } - } - - return null; - } - - public bool IsSupported(string country) - { - return _taxIdTypes.Any(x => x.Country == country); - } -} diff --git a/src/Core/Billing/Utilities.cs b/src/Core/Billing/Utilities.cs index 695a3b1bb485..28527af0c038 100644 --- a/src/Core/Billing/Utilities.cs +++ b/src/Core/Billing/Utilities.cs @@ -83,7 +83,6 @@ public static TaxInformation GetTaxInformation(Customer customer) customer.Address.Country, customer.Address.PostalCode, customer.TaxIds?.FirstOrDefault()?.Value, - customer.TaxIds?.FirstOrDefault()?.Type, customer.Address.Line1, customer.Address.Line2, customer.Address.City, diff --git a/src/Core/Models/Business/TaxInfo.cs b/src/Core/Models/Business/TaxInfo.cs index b12c5229b3e6..4424576ec91f 100644 --- a/src/Core/Models/Business/TaxInfo.cs +++ b/src/Core/Models/Business/TaxInfo.cs @@ -2,9 +2,18 @@ public class TaxInfo { - public string TaxIdNumber { get; set; } - public string TaxIdType { get; set; } + private string _taxIdNumber = null; + private string _taxIdType = null; + public string TaxIdNumber + { + get => _taxIdNumber; + set + { + _taxIdNumber = value; + _taxIdType = null; + } + } public string StripeTaxRateId { get; set; } public string BillingAddressLine1 { get; set; } public string BillingAddressLine2 { get; set; } @@ -12,6 +21,201 @@ public class TaxInfo public string BillingAddressState { get; set; } public string BillingAddressPostalCode { get; set; } public string BillingAddressCountry { get; set; } = "US"; + public string TaxIdType + { + get + { + if (string.IsNullOrWhiteSpace(BillingAddressCountry) || + string.IsNullOrWhiteSpace(TaxIdNumber)) + { + return null; + } + if (!string.IsNullOrWhiteSpace(_taxIdType)) + { + return _taxIdType; + } + + switch (BillingAddressCountry.ToUpper()) + { + case "AD": + _taxIdType = "ad_nrt"; + break; + case "AE": + _taxIdType = "ae_trn"; + break; + case "AR": + _taxIdType = "ar_cuit"; + break; + case "AU": + _taxIdType = "au_abn"; + break; + case "BO": + _taxIdType = "bo_tin"; + break; + case "BR": + _taxIdType = "br_cnpj"; + break; + case "CA": + // May break for those in Québec given the assumption of QST + if (BillingAddressState?.Contains("bec") ?? false) + { + _taxIdType = "ca_qst"; + break; + } + _taxIdType = "ca_bn"; + break; + case "CH": + _taxIdType = "ch_vat"; + break; + case "CL": + _taxIdType = "cl_tin"; + break; + case "CN": + _taxIdType = "cn_tin"; + break; + case "CO": + _taxIdType = "co_nit"; + break; + case "CR": + _taxIdType = "cr_tin"; + break; + case "DO": + _taxIdType = "do_rcn"; + break; + case "EC": + _taxIdType = "ec_ruc"; + break; + case "EG": + _taxIdType = "eg_tin"; + break; + case "GE": + _taxIdType = "ge_vat"; + break; + case "ID": + _taxIdType = "id_npwp"; + break; + case "IL": + _taxIdType = "il_vat"; + break; + case "IS": + _taxIdType = "is_vat"; + break; + case "KE": + _taxIdType = "ke_pin"; + break; + case "AT": + case "BE": + case "BG": + case "CY": + case "CZ": + case "DE": + case "DK": + case "EE": + case "ES": + case "FI": + case "FR": + case "GB": + case "GR": + case "HR": + case "HU": + case "IE": + case "IT": + case "LT": + case "LU": + case "LV": + case "MT": + case "NL": + case "PL": + case "PT": + case "RO": + case "SE": + case "SI": + case "SK": + _taxIdType = "eu_vat"; + break; + case "HK": + _taxIdType = "hk_br"; + break; + case "IN": + _taxIdType = "in_gst"; + break; + case "JP": + _taxIdType = "jp_cn"; + break; + case "KR": + _taxIdType = "kr_brn"; + break; + case "LI": + _taxIdType = "li_uid"; + break; + case "MX": + _taxIdType = "mx_rfc"; + break; + case "MY": + _taxIdType = "my_sst"; + break; + case "NO": + _taxIdType = "no_vat"; + break; + case "NZ": + _taxIdType = "nz_gst"; + break; + case "PE": + _taxIdType = "pe_ruc"; + break; + case "PH": + _taxIdType = "ph_tin"; + break; + case "RS": + _taxIdType = "rs_pib"; + break; + case "RU": + _taxIdType = "ru_inn"; + break; + case "SA": + _taxIdType = "sa_vat"; + break; + case "SG": + _taxIdType = "sg_gst"; + break; + case "SV": + _taxIdType = "sv_nit"; + break; + case "TH": + _taxIdType = "th_vat"; + break; + case "TR": + _taxIdType = "tr_tin"; + break; + case "TW": + _taxIdType = "tw_vat"; + break; + case "UA": + _taxIdType = "ua_vat"; + break; + case "US": + _taxIdType = "us_ein"; + break; + case "UY": + _taxIdType = "uy_ruc"; + break; + case "VE": + _taxIdType = "ve_rif"; + break; + case "VN": + _taxIdType = "vn_tin"; + break; + case "ZA": + _taxIdType = "za_vat"; + break; + default: + _taxIdType = null; + break; + } + + return _taxIdType; + } + } public bool HasTaxId { diff --git a/src/Core/Services/IPaymentService.cs b/src/Core/Services/IPaymentService.cs index 7d0f9d3c630b..bf9d047029b5 100644 --- a/src/Core/Services/IPaymentService.cs +++ b/src/Core/Services/IPaymentService.cs @@ -1,9 +1,6 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Models.Api.Requests.Accounts; -using Bit.Core.Billing.Models.Api.Requests.Organizations; -using Bit.Core.Billing.Models.Api.Responses; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Models.Business; @@ -62,7 +59,4 @@ Task AddSecretsManagerToSubscription(Organization org, Plan plan, int ad Task RisksSubscriptionFailure(Organization organization); Task HasSecretsManagerStandalone(Organization organization); Task<(DateTime?, DateTime?)> GetSuspensionDateAsync(Stripe.Subscription subscription); - Task PreviewInvoiceAsync(PreviewIndividualInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); - Task PreviewInvoiceAsync(PreviewOrganizationInvoiceRequestBody parameters, string gatewayCustomerId, string gatewaySubscriptionId); - } diff --git a/src/Core/Services/IStripeAdapter.cs b/src/Core/Services/IStripeAdapter.cs index ef2e3ab76655..30583ef0b3e2 100644 --- a/src/Core/Services/IStripeAdapter.cs +++ b/src/Core/Services/IStripeAdapter.cs @@ -31,7 +31,6 @@ Task CustomerBalanceTransactionCreate(string custome Task InvoiceUpcomingAsync(Stripe.UpcomingInvoiceOptions options); Task InvoiceGetAsync(string id, Stripe.InvoiceGetOptions options); Task> InvoiceListAsync(StripeInvoiceListOptions options); - Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options); Task> InvoiceSearchAsync(InvoiceSearchOptions options); Task InvoiceUpdateAsync(string id, Stripe.InvoiceUpdateOptions options); Task InvoiceFinalizeInvoiceAsync(string id, Stripe.InvoiceFinalizeOptions options); @@ -43,7 +42,6 @@ Task CustomerBalanceTransactionCreate(string custome IAsyncEnumerable PaymentMethodListAutoPagingAsync(Stripe.PaymentMethodListOptions options); Task PaymentMethodAttachAsync(string id, Stripe.PaymentMethodAttachOptions options = null); Task PaymentMethodDetachAsync(string id, Stripe.PaymentMethodDetachOptions options = null); - Task PlanGetAsync(string id, Stripe.PlanGetOptions options = null); Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options); Task TaxRateUpdateAsync(string id, Stripe.TaxRateUpdateOptions options); Task TaxIdCreateAsync(string id, Stripe.TaxIdCreateOptions options); diff --git a/src/Core/Services/Implementations/StripeAdapter.cs b/src/Core/Services/Implementations/StripeAdapter.cs index f4f8efe75fb1..8d183314567f 100644 --- a/src/Core/Services/Implementations/StripeAdapter.cs +++ b/src/Core/Services/Implementations/StripeAdapter.cs @@ -15,7 +15,6 @@ public class StripeAdapter : IStripeAdapter private readonly Stripe.RefundService _refundService; private readonly Stripe.CardService _cardService; private readonly Stripe.BankAccountService _bankAccountService; - private readonly Stripe.PlanService _planService; private readonly Stripe.PriceService _priceService; private readonly Stripe.SetupIntentService _setupIntentService; private readonly Stripe.TestHelpers.TestClockService _testClockService; @@ -34,7 +33,6 @@ public StripeAdapter() _cardService = new Stripe.CardService(); _bankAccountService = new Stripe.BankAccountService(); _priceService = new Stripe.PriceService(); - _planService = new Stripe.PlanService(); _setupIntentService = new SetupIntentService(); _testClockService = new Stripe.TestHelpers.TestClockService(); _customerBalanceTransactionService = new CustomerBalanceTransactionService(); @@ -135,11 +133,6 @@ public async Task ProviderSubscriptionGetAsync( return invoices; } - public Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions options) - { - return _invoiceService.CreatePreviewAsync(options); - } - public async Task> InvoiceSearchAsync(InvoiceSearchOptions options) => (await _invoiceService.SearchAsync(options)).Data; @@ -191,11 +184,6 @@ public Task InvoiceCreatePreviewAsync(InvoiceCreatePreviewOptions optio return _paymentMethodService.DetachAsync(id, options); } - public Task PlanGetAsync(string id, Stripe.PlanGetOptions options = null) - { - return _planService.GetAsync(id, options); - } - public Task TaxRateCreateAsync(Stripe.TaxRateCreateOptions options) { return _taxRateService.CreateAsync(options); diff --git a/src/Core/Services/Implementations/StripePaymentService.cs b/src/Core/Services/Implementations/StripePaymentService.cs index c32dcd43ade8..259a4eb7572a 100644 --- a/src/Core/Services/Implementations/StripePaymentService.cs +++ b/src/Core/Services/Implementations/StripePaymentService.cs @@ -1,13 +1,8 @@ using Bit.Core.AdminConsole.Entities; using Bit.Core.AdminConsole.Entities.Provider; using Bit.Core.Billing.Constants; -using Bit.Core.Billing.Extensions; using Bit.Core.Billing.Models; -using Bit.Core.Billing.Models.Api.Requests.Accounts; -using Bit.Core.Billing.Models.Api.Requests.Organizations; -using Bit.Core.Billing.Models.Api.Responses; using Bit.Core.Billing.Models.Business; -using Bit.Core.Billing.Services; using Bit.Core.Entities; using Bit.Core.Enums; using Bit.Core.Exceptions; @@ -37,7 +32,6 @@ public class StripePaymentService : IPaymentService private readonly IStripeAdapter _stripeAdapter; private readonly IGlobalSettings _globalSettings; private readonly IFeatureService _featureService; - private readonly ITaxService _taxService; public StripePaymentService( ITransactionRepository transactionRepository, @@ -46,8 +40,7 @@ public StripePaymentService( IStripeAdapter stripeAdapter, Braintree.IBraintreeGateway braintreeGateway, IGlobalSettings globalSettings, - IFeatureService featureService, - ITaxService taxService) + IFeatureService featureService) { _transactionRepository = transactionRepository; _logger = logger; @@ -56,7 +49,6 @@ public StripePaymentService( _btGateway = braintreeGateway; _globalSettings = globalSettings; _featureService = featureService; - _taxService = taxService; } public async Task PurchaseOrganizationAsync(Organization org, PaymentMethodType paymentMethodType, @@ -120,20 +112,6 @@ public async Task PurchaseOrganizationAsync(Organization org, PaymentMet Subscription subscription; try { - if (taxInfo.TaxIdNumber != null && taxInfo.TaxIdType == null) - { - taxInfo.TaxIdType = _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - - if (taxInfo.TaxIdType == null) - { - _logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", - taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - } - var customerCreateOptions = new CustomerCreateOptions { Description = org.DisplayBusinessName(), @@ -168,9 +146,12 @@ public async Task PurchaseOrganizationAsync(Organization org, PaymentMet City = taxInfo?.BillingAddressCity, State = taxInfo?.BillingAddressState, }, - TaxIdData = taxInfo.HasTaxId - ? [new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber }] - : null + TaxIdData = taxInfo?.HasTaxId != true + ? null + : + [ + new CustomerTaxIdDataOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber, } + ], }; customerCreateOptions.AddExpand("tax"); @@ -1678,7 +1659,6 @@ public async Task GetTaxInfoAsync(ISubscriber subscriber) return new TaxInfo { TaxIdNumber = taxId?.Value, - TaxIdType = taxId?.Type, BillingAddressLine1 = address?.Line1, BillingAddressLine2 = address?.Line2, BillingAddressCity = address?.City, @@ -1690,13 +1670,9 @@ public async Task GetTaxInfoAsync(ISubscriber subscriber) public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) { - if (string.IsNullOrWhiteSpace(subscriber?.GatewayCustomerId) || subscriber.IsUser()) + if (subscriber != null && !string.IsNullOrWhiteSpace(subscriber.GatewayCustomerId)) { - return; - } - - var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, - new CustomerUpdateOptions + var customer = await _stripeAdapter.CustomerUpdateAsync(subscriber.GatewayCustomerId, new CustomerUpdateOptions { Address = new AddressOptions { @@ -1710,59 +1686,23 @@ public async Task SaveTaxInfoAsync(ISubscriber subscriber, TaxInfo taxInfo) Expand = ["tax_ids"] }); - if (customer == null) - { - return; - } - - var taxId = customer.TaxIds?.FirstOrDefault(); - - if (taxId != null) - { - await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); - } - - if (string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber)) - { - return; - } - - var taxIdType = taxInfo.TaxIdType; - - if (string.IsNullOrWhiteSpace(taxIdType)) - { - taxIdType = _taxService.GetStripeTaxCode(taxInfo.BillingAddressCountry, taxInfo.TaxIdNumber); - - if (taxIdType == null) + if (!subscriber.IsUser() && customer != null) { - _logger.LogWarning("Could not infer tax ID type in country '{Country}' with tax ID '{TaxID}'.", - taxInfo.BillingAddressCountry, - taxInfo.TaxIdNumber); - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - } + var taxId = customer.TaxIds?.FirstOrDefault(); - try - { - await _stripeAdapter.TaxIdCreateAsync(customer.Id, - new TaxIdCreateOptions { Type = taxInfo.TaxIdType, Value = taxInfo.TaxIdNumber, }); - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - taxInfo.TaxIdNumber, - taxInfo.BillingAddressCountry); - throw new BadRequestException("billingInvalidTaxIdError"); - default: - _logger.LogError(e, - "Error creating tax ID '{TaxId}' in country '{Country}' for customer '{CustomerID}'.", - taxInfo.TaxIdNumber, - taxInfo.BillingAddressCountry, - customer.Id); - throw new BadRequestException("billingTaxIdCreationError"); + if (taxId != null) + { + await _stripeAdapter.TaxIdDeleteAsync(customer.Id, taxId.Id); + } + if (!string.IsNullOrWhiteSpace(taxInfo.TaxIdNumber) && + !string.IsNullOrWhiteSpace(taxInfo.TaxIdType)) + { + await _stripeAdapter.TaxIdCreateAsync(customer.Id, new TaxIdCreateOptions + { + Type = taxInfo.TaxIdType, + Value = taxInfo.TaxIdNumber, + }); + } } } } @@ -1895,285 +1835,6 @@ public async Task HasSecretsManagerStandalone(Organization organization) } } - public async Task PreviewInvoiceAsync( - PreviewIndividualInvoiceRequestBody parameters, - string gatewayCustomerId, - string gatewaySubscriptionId) - { - var options = new InvoiceCreatePreviewOptions - { - AutomaticTax = new InvoiceAutomaticTaxOptions - { - Enabled = true, - }, - Currency = "usd", - Discounts = new List(), - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = - [ - new() - { - Quantity = 1, - Plan = "premium-annually" - }, - - new() - { - Quantity = parameters.PasswordManager.AdditionalStorage, - Plan = "storage-gb-annually" - } - ] - }, - CustomerDetails = new InvoiceCustomerDetailsOptions - { - Address = new AddressOptions - { - PostalCode = parameters.TaxInformation.PostalCode, - Country = parameters.TaxInformation.Country, - } - }, - }; - - if (!string.IsNullOrEmpty(parameters.TaxInformation.TaxId)) - { - var taxIdType = _taxService.GetStripeTaxCode( - options.CustomerDetails.Address.Country, - parameters.TaxInformation.TaxId); - - if (taxIdType == null) - { - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - } - - options.CustomerDetails.TaxIds = [ - new InvoiceCustomerDetailsTaxIdOptions - { - Type = taxIdType, - Value = parameters.TaxInformation.TaxId - } - ]; - } - - if (gatewayCustomerId != null) - { - var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); - - if (gatewayCustomer.Discount != null) - { - options.Discounts.Add(new InvoiceDiscountOptions - { - Discount = gatewayCustomer.Discount.Id - }); - } - - if (gatewaySubscriptionId != null) - { - var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); - - if (gatewaySubscription?.Discount != null) - { - options.Discounts.Add(new InvoiceDiscountOptions - { - Discount = gatewaySubscription.Discount.Id - }); - } - } - } - - try - { - var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); - - var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null - ? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() - : 0M; - - var result = new PreviewInvoiceResponseModel( - effectiveTaxRate, - invoice.TotalExcludingTax.ToMajor() ?? 0, - invoice.Tax.ToMajor() ?? 0, - invoice.Total.ToMajor()); - return result; - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - default: - _logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvoiceError"); - } - } - } - - public async Task PreviewInvoiceAsync( - PreviewOrganizationInvoiceRequestBody parameters, - string gatewayCustomerId, - string gatewaySubscriptionId) - { - var plan = Utilities.StaticStore.GetPlan(parameters.PasswordManager.Plan); - - var options = new InvoiceCreatePreviewOptions - { - AutomaticTax = new InvoiceAutomaticTaxOptions - { - Enabled = true, - }, - Currency = "usd", - Discounts = new List(), - SubscriptionDetails = new InvoiceSubscriptionDetailsOptions - { - Items = - [ - new() - { - Quantity = parameters.PasswordManager.AdditionalStorage, - Plan = plan.PasswordManager.StripeStoragePlanId - } - ] - }, - CustomerDetails = new InvoiceCustomerDetailsOptions - { - Address = new AddressOptions - { - PostalCode = parameters.TaxInformation.PostalCode, - Country = parameters.TaxInformation.Country, - } - }, - }; - - if (plan.PasswordManager.HasAdditionalSeatsOption) - { - options.SubscriptionDetails.Items.Add( - new() - { - Quantity = parameters.PasswordManager.Seats, - Plan = plan.PasswordManager.StripeSeatPlanId - } - ); - } - else - { - options.SubscriptionDetails.Items.Add( - new() - { - Quantity = 1, - Plan = plan.PasswordManager.StripePlanId - } - ); - } - - if (plan.SupportsSecretsManager) - { - if (plan.SecretsManager.HasAdditionalSeatsOption) - { - options.SubscriptionDetails.Items.Add(new() - { - Quantity = parameters.SecretsManager?.Seats ?? 0, - Plan = plan.SecretsManager.StripeSeatPlanId - }); - } - - if (plan.SecretsManager.HasAdditionalServiceAccountOption) - { - options.SubscriptionDetails.Items.Add(new() - { - Quantity = parameters.SecretsManager?.AdditionalMachineAccounts ?? 0, - Plan = plan.SecretsManager.StripeServiceAccountPlanId - }); - } - } - - if (!string.IsNullOrEmpty(parameters.TaxInformation.TaxId)) - { - var taxIdType = _taxService.GetStripeTaxCode( - options.CustomerDetails.Address.Country, - parameters.TaxInformation.TaxId); - - if (taxIdType == null) - { - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingTaxIdTypeInferenceError"); - } - - options.CustomerDetails.TaxIds = [ - new InvoiceCustomerDetailsTaxIdOptions - { - Type = taxIdType, - Value = parameters.TaxInformation.TaxId - } - ]; - } - - if (gatewayCustomerId != null) - { - var gatewayCustomer = await _stripeAdapter.CustomerGetAsync(gatewayCustomerId); - - if (gatewayCustomer.Discount != null) - { - options.Discounts.Add(new InvoiceDiscountOptions - { - Discount = gatewayCustomer.Discount.Id - }); - } - - var gatewaySubscription = await _stripeAdapter.SubscriptionGetAsync(gatewaySubscriptionId); - - if (gatewaySubscription?.Discount != null) - { - options.Discounts.Add(new InvoiceDiscountOptions - { - Discount = gatewaySubscription.Discount.Id - }); - } - } - - try - { - var invoice = await _stripeAdapter.InvoiceCreatePreviewAsync(options); - - var effectiveTaxRate = invoice.Tax != null && invoice.TotalExcludingTax != null - ? invoice.Tax.Value.ToMajor() / invoice.TotalExcludingTax.Value.ToMajor() - : 0M; - - var result = new PreviewInvoiceResponseModel( - effectiveTaxRate, - invoice.TotalExcludingTax.ToMajor() ?? 0, - invoice.Tax.ToMajor() ?? 0, - invoice.Total.ToMajor()); - return result; - } - catch (StripeException e) - { - switch (e.StripeError.Code) - { - case StripeConstants.ErrorCodes.TaxIdInvalid: - _logger.LogWarning("Invalid tax ID '{TaxID}' for country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvalidTaxIdError"); - default: - _logger.LogError(e, "Unexpected error previewing invoice with tax ID '{TaxId}' in country '{Country}'.", - parameters.TaxInformation.TaxId, - parameters.TaxInformation.Country); - throw new BadRequestException("billingPreviewInvoiceError"); - } - } - } - private PaymentMethod GetLatestCardPaymentMethod(string customerId) { var cardPaymentMethods = _stripeAdapter.PaymentMethodListAutoPaging( diff --git a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs index 9c25ffdc5574..385b185ffe1b 100644 --- a/test/Core.Test/Billing/Services/SubscriberServiceTests.cs +++ b/test/Core.Test/Billing/Services/SubscriberServiceTests.cs @@ -1545,7 +1545,7 @@ public async Task UpdateTaxInformation_NonUser_MakesCorrectInvocations( { var stripeAdapter = sutProvider.GetDependency(); - var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1", Type = "us_ein" }] } }; + var customer = new Customer { Id = provider.GatewayCustomerId, TaxIds = new StripeList { Data = [new TaxId { Id = "tax_id_1" }] } }; stripeAdapter.CustomerGetAsync(provider.GatewayCustomerId, Arg.Is( options => options.Expand.Contains("tax_ids"))).Returns(customer); @@ -1554,7 +1554,6 @@ public async Task UpdateTaxInformation_NonUser_MakesCorrectInvocations( "US", "12345", "123456789", - "us_ein", "123 Example St.", null, "Example Town", diff --git a/test/Core.Test/Models/Business/TaxInfoTests.cs b/test/Core.Test/Models/Business/TaxInfoTests.cs new file mode 100644 index 000000000000..197948006e97 --- /dev/null +++ b/test/Core.Test/Models/Business/TaxInfoTests.cs @@ -0,0 +1,114 @@ +using Bit.Core.Models.Business; +using Xunit; + +namespace Bit.Core.Test.Models.Business; + +public class TaxInfoTests +{ + // PH = Placeholder + [Theory] + [InlineData(null, null, null, null)] + [InlineData("", "", null, null)] + [InlineData("PH", "", null, null)] + [InlineData("", "PH", null, null)] + [InlineData("AE", "PH", null, "ae_trn")] + [InlineData("AU", "PH", null, "au_abn")] + [InlineData("BR", "PH", null, "br_cnpj")] + [InlineData("CA", "PH", "bec", "ca_qst")] + [InlineData("CA", "PH", null, "ca_bn")] + [InlineData("CL", "PH", null, "cl_tin")] + [InlineData("AT", "PH", null, "eu_vat")] + [InlineData("BE", "PH", null, "eu_vat")] + [InlineData("BG", "PH", null, "eu_vat")] + [InlineData("CY", "PH", null, "eu_vat")] + [InlineData("CZ", "PH", null, "eu_vat")] + [InlineData("DE", "PH", null, "eu_vat")] + [InlineData("DK", "PH", null, "eu_vat")] + [InlineData("EE", "PH", null, "eu_vat")] + [InlineData("ES", "PH", null, "eu_vat")] + [InlineData("FI", "PH", null, "eu_vat")] + [InlineData("FR", "PH", null, "eu_vat")] + [InlineData("GB", "PH", null, "eu_vat")] + [InlineData("GR", "PH", null, "eu_vat")] + [InlineData("HR", "PH", null, "eu_vat")] + [InlineData("HU", "PH", null, "eu_vat")] + [InlineData("IE", "PH", null, "eu_vat")] + [InlineData("IT", "PH", null, "eu_vat")] + [InlineData("LT", "PH", null, "eu_vat")] + [InlineData("LU", "PH", null, "eu_vat")] + [InlineData("LV", "PH", null, "eu_vat")] + [InlineData("MT", "PH", null, "eu_vat")] + [InlineData("NL", "PH", null, "eu_vat")] + [InlineData("PL", "PH", null, "eu_vat")] + [InlineData("PT", "PH", null, "eu_vat")] + [InlineData("RO", "PH", null, "eu_vat")] + [InlineData("SE", "PH", null, "eu_vat")] + [InlineData("SI", "PH", null, "eu_vat")] + [InlineData("SK", "PH", null, "eu_vat")] + [InlineData("HK", "PH", null, "hk_br")] + [InlineData("IN", "PH", null, "in_gst")] + [InlineData("JP", "PH", null, "jp_cn")] + [InlineData("KR", "PH", null, "kr_brn")] + [InlineData("LI", "PH", null, "li_uid")] + [InlineData("MX", "PH", null, "mx_rfc")] + [InlineData("MY", "PH", null, "my_sst")] + [InlineData("NO", "PH", null, "no_vat")] + [InlineData("NZ", "PH", null, "nz_gst")] + [InlineData("RU", "PH", null, "ru_inn")] + [InlineData("SA", "PH", null, "sa_vat")] + [InlineData("SG", "PH", null, "sg_gst")] + [InlineData("TH", "PH", null, "th_vat")] + [InlineData("TW", "PH", null, "tw_vat")] + [InlineData("US", "PH", null, "us_ein")] + [InlineData("ZA", "PH", null, "za_vat")] + [InlineData("ABCDEF", "PH", null, null)] + public void GetTaxIdType_Success(string billingAddressCountry, + string taxIdNumber, + string billingAddressState, + string expectedTaxIdType) + { + var taxInfo = new TaxInfo + { + BillingAddressCountry = billingAddressCountry, + TaxIdNumber = taxIdNumber, + BillingAddressState = billingAddressState, + }; + + Assert.Equal(expectedTaxIdType, taxInfo.TaxIdType); + } + + [Fact] + public void GetTaxIdType_CreateOnce_ReturnCacheSecondTime() + { + var taxInfo = new TaxInfo + { + BillingAddressCountry = "US", + TaxIdNumber = "PH", + BillingAddressState = null, + }; + + Assert.Equal("us_ein", taxInfo.TaxIdType); + + // Per the current spec even if the values change to something other than null it + // will return the cached version of TaxIdType. + taxInfo.BillingAddressCountry = "ZA"; + + Assert.Equal("us_ein", taxInfo.TaxIdType); + } + + [Theory] + [InlineData(null, null, false)] + [InlineData("123", "US", true)] + [InlineData("123", "ZQ12", false)] + [InlineData(" ", "US", false)] + public void HasTaxId_ReturnsExpected(string taxIdNumber, string billingAddressCountry, bool expected) + { + var taxInfo = new TaxInfo + { + TaxIdNumber = taxIdNumber, + BillingAddressCountry = billingAddressCountry, + }; + + Assert.Equal(expected, taxInfo.HasTaxId); + } +} diff --git a/test/Core.Test/Services/StripePaymentServiceTests.cs b/test/Core.Test/Services/StripePaymentServiceTests.cs index 35e1901a2fca..e15f07b11384 100644 --- a/test/Core.Test/Services/StripePaymentServiceTests.cs +++ b/test/Core.Test/Services/StripePaymentServiceTests.cs @@ -77,8 +77,7 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -135,8 +134,7 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -192,8 +190,7 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -250,8 +247,7 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -445,8 +441,7 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => @@ -515,8 +510,7 @@ await stripeAdapter.Received().CustomerCreateAsync(Arg.Is(s => From 3c75ff335b3faee4f924aa97ea089a8454cc4698 Mon Sep 17 00:00:00 2001 From: Alex Morask <144709477+amorask-bitwarden@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:36:37 -0500 Subject: [PATCH 26/28] [PM-15536] Allow reseller to add organization (#5111) * Allow reseller to add organization * Run dotnet format --- .../AdminConsole/Services/ProviderService.cs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs index e384d71df904..864466ad4536 100644 --- a/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs +++ b/bitwarden_license/src/Commercial.Core/AdminConsole/Services/ProviderService.cs @@ -27,7 +27,11 @@ namespace Bit.Commercial.Core.AdminConsole.Services; public class ProviderService : IProviderService { - public static PlanType[] ProviderDisallowedOrganizationTypes = new[] { PlanType.Free, PlanType.FamiliesAnnually, PlanType.FamiliesAnnually2019 }; + private static readonly PlanType[] _resellerDisallowedOrganizationTypes = [ + PlanType.Free, + PlanType.FamiliesAnnually, + PlanType.FamiliesAnnually2019 + ]; private readonly IDataProtector _dataProtector; private readonly IMailService _mailService; @@ -690,13 +694,14 @@ private void ThrowOnInvalidPlanType(ProviderType providerType, PlanType requeste throw new BadRequestException($"Multi-organization Enterprise Providers cannot manage organizations with the plan type {requestedType}. Only Enterprise (Monthly) and Enterprise (Annually) are allowed."); } break; + case ProviderType.Reseller: + if (_resellerDisallowedOrganizationTypes.Contains(requestedType)) + { + throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed."); + } + break; default: throw new BadRequestException($"Unsupported provider type {providerType}."); } - - if (ProviderDisallowedOrganizationTypes.Contains(requestedType)) - { - throw new BadRequestException($"Providers cannot manage organizations with the requested plan type ({requestedType}). Only Teams and Enterprise accounts are allowed."); - } } } From 74e86935a45dee256154a6fad92f80f54f3a4887 Mon Sep 17 00:00:00 2001 From: Nick Krantz <125900171+nick-livefront@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:19:10 -0600 Subject: [PATCH 27/28] add `PM9111ExtensionPersistAddEditForm` feature flag (#5106) --- src/Core/Constants.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 6efdedfeed92..1d807149a924 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -155,6 +155,7 @@ public static class FeatureFlagKeys public const string PM14401_ScaleMSPOnClientOrganizationUpdate = "PM-14401-scale-msp-on-client-organization-update"; public const string DisableFreeFamiliesSponsorship = "PM-12274-disable-free-families-sponsorship"; public const string MacOsNativeCredentialSync = "macos-native-credential-sync"; + public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form"; public static List GetAllKeys() { From 0e32dcccadd759e938542dc8069bc7be9b4e297b Mon Sep 17 00:00:00 2001 From: Robyn MacCallum Date: Wed, 4 Dec 2024 14:42:12 -0500 Subject: [PATCH 28/28] Update Constants.cs (#5112) --- src/Core/Constants.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Core/Constants.cs b/src/Core/Constants.cs index 1d807149a924..46ed985cb8b0 100644 --- a/src/Core/Constants.cs +++ b/src/Core/Constants.cs @@ -156,6 +156,7 @@ public static class FeatureFlagKeys public const string DisableFreeFamiliesSponsorship = "PM-12274-disable-free-families-sponsorship"; public const string MacOsNativeCredentialSync = "macos-native-credential-sync"; public const string PM9111ExtensionPersistAddEditForm = "pm-9111-extension-persist-add-edit-form"; + public const string InlineMenuTotp = "inline-menu-totp"; public static List GetAllKeys() {