diff --git a/src/Billing/Controllers/StripeController.cs b/src/Billing/Controllers/StripeController.cs index fb63e1993e36..e71e025dffc5 100644 --- a/src/Billing/Controllers/StripeController.cs +++ b/src/Billing/Controllers/StripeController.cs @@ -52,6 +52,7 @@ public class StripeController : Controller private readonly ICurrentContext _currentContext; private readonly GlobalSettings _globalSettings; private readonly IStripeEventService _stripeEventService; + private readonly IStripeFacade _stripeFacade; public StripeController( GlobalSettings globalSettings, @@ -70,7 +71,8 @@ public StripeController( ITaxRateRepository taxRateRepository, IUserRepository userRepository, ICurrentContext currentContext, - IStripeEventService stripeEventService) + IStripeEventService stripeEventService, + IStripeFacade stripeFacade) { _billingSettings = billingSettings?.Value; _hostingEnvironment = hostingEnvironment; @@ -97,6 +99,7 @@ public StripeController( _currentContext = currentContext; _globalSettings = globalSettings; _stripeEventService = stripeEventService; + _stripeFacade = stripeFacade; } [HttpPost("webhook")] @@ -209,48 +212,71 @@ await _userService.UpdatePremiumExpirationAsync(userId, else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice)) { var invoice = await _stripeEventService.GetInvoice(parsedEvent); - var subscriptionService = new SubscriptionService(); - var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId); + + if (string.IsNullOrEmpty(invoice.SubscriptionId)) + { + _logger.LogWarning("Received 'invoice.upcoming' Event with ID '{eventId}' that did not include a Subscription ID", parsedEvent.Id); + return new OkResult(); + } + + var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId); + if (subscription == null) { - throw new Exception("Invoice subscription is null. " + invoice.Id); + throw new Exception( + $"Received null Subscription from Stripe for ID '{invoice.SubscriptionId}' while processing Event with ID '{parsedEvent.Id}'"); } - subscription = await VerifyCorrectTaxRateForCharge(invoice, subscription); + var updatedSubscription = await VerifyCorrectTaxRateForCharge(invoice, subscription); - string email = null; - var ids = GetIdsFromMetaData(subscription.Metadata); - // org - if (ids.Item1.HasValue) + var (organizationId, userId) = GetIdsFromMetaData(updatedSubscription.Metadata); + + var invoiceLineItemDescriptions = invoice.Lines.Select(i => i.Description).ToList(); + + async Task SendEmails(IEnumerable emails) { - // sponsored org - if (IsSponsoredSubscription(subscription)) + var validEmails = emails.Where(e => !string.IsNullOrEmpty(e)); + + if (invoice.NextPaymentAttempt.HasValue) { - await _validateSponsorshipCommand.ValidateSponsorshipAsync(ids.Item1.Value); + await _mailService.SendInvoiceUpcoming( + validEmails, + invoice.AmountDue / 100M, + invoice.NextPaymentAttempt.Value, + invoiceLineItemDescriptions, + true); } + } - var org = await _organizationRepository.GetByIdAsync(ids.Item1.Value); - if (org != null && OrgPlanForInvoiceNotifications(org)) + if (organizationId.HasValue) + { + if (IsSponsoredSubscription(updatedSubscription)) + { + await _validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value); + } + + var organization = await _organizationRepository.GetByIdAsync(organizationId.Value); + + if (organization == null || !OrgPlanForInvoiceNotifications(organization)) { - email = org.BillingEmail; + return new OkResult(); } + + await SendEmails(new List { organization.BillingEmail }); + + var ownerEmails = await _organizationRepository.GetOwnerEmailAddressesById(organization.Id); + + await SendEmails(ownerEmails); } - // user - else if (ids.Item2.HasValue) + else if (userId.HasValue) { - var user = await _userService.GetUserByIdAsync(ids.Item2.Value); + var user = await _userService.GetUserByIdAsync(userId.Value); + if (user.Premium) { - email = user.Email; + await SendEmails(new List { user.Email }); } } - - if (!string.IsNullOrWhiteSpace(email) && invoice.NextPaymentAttempt.HasValue) - { - var items = invoice.Lines.Select(i => i.Description).ToList(); - await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M, - invoice.NextPaymentAttempt.Value, items, true); - } } else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded)) { diff --git a/src/Billing/Services/Implementations/StripeEventService.cs b/src/Billing/Services/Implementations/StripeEventService.cs index 076602e3d209..ce7ab311ffd0 100644 --- a/src/Billing/Services/Implementations/StripeEventService.cs +++ b/src/Billing/Services/Implementations/StripeEventService.cs @@ -7,13 +7,16 @@ namespace Bit.Billing.Services.Implementations; public class StripeEventService : IStripeEventService { private readonly GlobalSettings _globalSettings; + private readonly ILogger _logger; private readonly IStripeFacade _stripeFacade; public StripeEventService( GlobalSettings globalSettings, + ILogger logger, IStripeFacade stripeFacade) { _globalSettings = globalSettings; + _logger = logger; _stripeFacade = stripeFacade; } @@ -26,6 +29,12 @@ public async Task GetCharge(Event stripeEvent, bool fresh = false, List< return eventCharge; } + if (string.IsNullOrEmpty(eventCharge.Id)) + { + _logger.LogWarning("Cannot retrieve up-to-date Charge for Event with ID '{eventId}' because no Charge ID was included in the Event.", stripeEvent.Id); + return eventCharge; + } + var charge = await _stripeFacade.GetCharge(eventCharge.Id, new ChargeGetOptions { Expand = expand }); if (charge == null) @@ -46,6 +55,12 @@ public async Task GetCustomer(Event stripeEvent, bool fresh = false, L return eventCustomer; } + if (string.IsNullOrEmpty(eventCustomer.Id)) + { + _logger.LogWarning("Cannot retrieve up-to-date Customer for Event with ID '{eventId}' because no Customer ID was included in the Event.", stripeEvent.Id); + return eventCustomer; + } + var customer = await _stripeFacade.GetCustomer(eventCustomer.Id, new CustomerGetOptions { Expand = expand }); if (customer == null) @@ -66,6 +81,12 @@ public async Task GetInvoice(Event stripeEvent, bool fresh = false, Lis return eventInvoice; } + if (string.IsNullOrEmpty(eventInvoice.Id)) + { + _logger.LogWarning("Cannot retrieve up-to-date Invoice for Event with ID '{eventId}' because no Invoice ID was included in the Event.", stripeEvent.Id); + return eventInvoice; + } + var invoice = await _stripeFacade.GetInvoice(eventInvoice.Id, new InvoiceGetOptions { Expand = expand }); if (invoice == null) @@ -86,6 +107,12 @@ public async Task GetPaymentMethod(Event stripeEvent, bool fresh return eventPaymentMethod; } + if (string.IsNullOrEmpty(eventPaymentMethod.Id)) + { + _logger.LogWarning("Cannot retrieve up-to-date Payment Method for Event with ID '{eventId}' because no Payment Method ID was included in the Event.", stripeEvent.Id); + return eventPaymentMethod; + } + var paymentMethod = await _stripeFacade.GetPaymentMethod(eventPaymentMethod.Id, new PaymentMethodGetOptions { Expand = expand }); if (paymentMethod == null) @@ -106,6 +133,12 @@ public async Task GetSubscription(Event stripeEvent, bool fresh = return eventSubscription; } + if (string.IsNullOrEmpty(eventSubscription.Id)) + { + _logger.LogWarning("Cannot retrieve up-to-date Subscription for Event with ID '{eventId}' because no Subscription ID was included in the Event.", stripeEvent.Id); + return eventSubscription; + } + var subscription = await _stripeFacade.GetSubscription(eventSubscription.Id, new SubscriptionGetOptions { Expand = expand }); if (subscription == null) @@ -132,7 +165,7 @@ public async Task ValidateCloudRegion(Event stripeEvent) (await GetCharge(stripeEvent, true, customerExpansion))?.Customer?.Metadata, HandledStripeWebhook.UpcomingInvoice => - (await GetInvoice(stripeEvent, true, customerExpansion))?.Customer?.Metadata, + await GetCustomerMetadataFromUpcomingInvoiceEvent(stripeEvent), HandledStripeWebhook.PaymentSucceeded or HandledStripeWebhook.PaymentFailed or HandledStripeWebhook.InvoiceCreated => (await GetInvoice(stripeEvent, true, customerExpansion))?.Customer?.Metadata, @@ -154,6 +187,20 @@ public async Task ValidateCloudRegion(Event stripeEvent) var customerRegion = GetCustomerRegion(customerMetadata); return customerRegion == serverRegion; + + /* In Stripe, when we receive an invoice.upcoming event, the event does not include an Invoice ID because + the invoice hasn't been created yet. Therefore, rather than retrieving the fresh Invoice with a 'customer' + expansion, we need to use the Customer ID on the event to retrieve the metadata. */ + async Task> GetCustomerMetadataFromUpcomingInvoiceEvent(Event localStripeEvent) + { + var invoice = await GetInvoice(localStripeEvent); + + var customer = !string.IsNullOrEmpty(invoice.CustomerId) + ? await _stripeFacade.GetCustomer(invoice.CustomerId) + : null; + + return customer?.Metadata; + } } private static T Extract(Event stripeEvent) diff --git a/src/Core/Repositories/IOrganizationRepository.cs b/src/Core/Repositories/IOrganizationRepository.cs index 14126adb0aad..4ac518489bde 100644 --- a/src/Core/Repositories/IOrganizationRepository.cs +++ b/src/Core/Repositories/IOrganizationRepository.cs @@ -14,4 +14,5 @@ public interface IOrganizationRepository : IRepository Task GetByLicenseKeyAsync(string licenseKey); Task GetSelfHostedOrganizationDetailsById(Guid id); Task> SearchUnassignedToProviderAsync(string name, string ownerEmail, int skip, int take); + Task> GetOwnerEmailAddressesById(Guid organizationId); } diff --git a/src/Core/Services/IMailService.cs b/src/Core/Services/IMailService.cs index 0e5831082f8a..6350a0e46183 100644 --- a/src/Core/Services/IMailService.cs +++ b/src/Core/Services/IMailService.cs @@ -24,7 +24,17 @@ public interface IMailService Task SendOrganizationConfirmedEmailAsync(string organizationName, string email); Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email); Task SendPasswordlessSignInAsync(string returnUrl, string token, string email); - Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, List items, + Task SendInvoiceUpcoming( + string email, + decimal amount, + DateTime dueDate, + List items, + bool mentionInvoices); + Task SendInvoiceUpcoming( + IEnumerable email, + decimal amount, + DateTime dueDate, + List items, bool mentionInvoices); Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices); Task SendAddedCreditAsync(string email, decimal amount); diff --git a/src/Core/Services/Implementations/HandlebarsMailService.cs b/src/Core/Services/Implementations/HandlebarsMailService.cs index 98ff7df07bd7..24974f7ff061 100644 --- a/src/Core/Services/Implementations/HandlebarsMailService.cs +++ b/src/Core/Services/Implementations/HandlebarsMailService.cs @@ -285,10 +285,21 @@ public async Task SendPasswordlessSignInAsync(string returnUrl, string token, st await _mailDeliveryService.SendEmailAsync(message); } - public async Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, - List items, bool mentionInvoices) - { - var message = CreateDefaultMessage("Your Subscription Will Renew Soon", email); + public async Task SendInvoiceUpcoming( + string email, + decimal amount, + DateTime dueDate, + List items, + bool mentionInvoices) => await SendInvoiceUpcoming(new List { email }, amount, dueDate, items, mentionInvoices); + + public async Task SendInvoiceUpcoming( + IEnumerable emails, + decimal amount, + DateTime dueDate, + List items, + bool mentionInvoices) + { + var message = CreateDefaultMessage("Your Subscription Will Renew Soon", emails); var model = new InvoiceUpcomingViewModel { WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash, diff --git a/src/Core/Services/NoopImplementations/NoopMailService.cs b/src/Core/Services/NoopImplementations/NoopMailService.cs index 97d69cfa48bf..089ae18f181c 100644 --- a/src/Core/Services/NoopImplementations/NoopMailService.cs +++ b/src/Core/Services/NoopImplementations/NoopMailService.cs @@ -88,11 +88,19 @@ public Task SendPasswordlessSignInAsync(string returnUrl, string token, string e return Task.FromResult(0); } - public Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, - List items, bool mentionInvoices) - { - return Task.FromResult(0); - } + public Task SendInvoiceUpcoming( + string email, + decimal amount, + DateTime dueDate, + List items, + bool mentionInvoices) => Task.FromResult(0); + + public Task SendInvoiceUpcoming( + IEnumerable emails, + decimal amount, + DateTime dueDate, + List items, + bool mentionInvoices) => Task.FromResult(0); public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices) { diff --git a/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs b/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs index 9d8cad0f9cef..9329e23790e1 100644 --- a/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs @@ -149,4 +149,14 @@ public async Task> SearchUnassignedToProviderAsync(str return results.ToList(); } } + + public async Task> GetOwnerEmailAddressesById(Guid organizationId) + { + await using var connection = new SqlConnection(ConnectionString); + + return await connection.QueryAsync( + $"[{Schema}].[{Table}_ReadOwnerEmailAddressesById]", + new { OrganizationId = organizationId }, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs b/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs index b7ffb9978a49..62f4df63e366 100644 --- a/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs +++ b/src/Infrastructure.EntityFramework/Repositories/OrganizationRepository.cs @@ -224,4 +224,24 @@ public async Task GetSelfHostedOrganizationDetail return selfHostedOrganization; } } + + public async Task> GetOwnerEmailAddressesById(Guid organizationId) + { + using var scope = ServiceScopeFactory.CreateScope(); + + var dbContext = GetDatabaseContext(scope); + + var query = + from u in dbContext.Users + join ou in dbContext.OrganizationUsers on u.Id equals ou.UserId + where + ou.OrganizationId == organizationId && + ou.Type == OrganizationUserType.Owner && + ou.Status == OrganizationUserStatusType.Confirmed + group u by u.Email + into grouped + select grouped.Key; + + return await query.ToListAsync(); + } } diff --git a/test/Billing.Test/Services/StripeEventServiceTests.cs b/test/Billing.Test/Services/StripeEventServiceTests.cs index 5b1642413d9f..f2fe1c8d19fb 100644 --- a/test/Billing.Test/Services/StripeEventServiceTests.cs +++ b/test/Billing.Test/Services/StripeEventServiceTests.cs @@ -3,6 +3,7 @@ using Bit.Billing.Test.Utilities; using Bit.Core.Settings; using FluentAssertions; +using Microsoft.Extensions.Logging; using NSubstitute; using Stripe; using Xunit; @@ -21,7 +22,7 @@ public StripeEventServiceTests() globalSettings.BaseServiceUri = baseServiceUriSettings; _stripeFacade = Substitute.For(); - _stripeEventService = new StripeEventService(globalSettings, _stripeFacade); + _stripeEventService = new StripeEventService(globalSettings, Substitute.For>(), _stripeFacade); } #region GetCharge diff --git a/util/Migrator/DbScripts/2023-10-03_00_OrganizationReadOwnerEmailAddresses.sql b/util/Migrator/DbScripts/2023-10-03_00_OrganizationReadOwnerEmailAddresses.sql new file mode 100644 index 000000000000..c88b12af0348 --- /dev/null +++ b/util/Migrator/DbScripts/2023-10-03_00_OrganizationReadOwnerEmailAddresses.sql @@ -0,0 +1,17 @@ +CREATE OR ALTER PROCEDURE [dbo].[Organization_ReadOwnerEmailAddressesById] + @OrganizationId UNIQUEIDENTIFIER +AS +BEGIN + SET NOCOUNT ON + + SELECT + [U].[Email] + FROM [User] AS [U] + INNER JOIN [OrganizationUser] AS [OU] ON [U].[Id] = [OU].[UserId] + WHERE + [OU].[OrganizationId] = @OrganizationId AND + [OU].[Type] = 0 AND -- Owner + [OU].[Status] = 2 -- Confirmed + GROUP BY [U].[Email] +END +GO