Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AC-1693] Send InvoiceUpcoming Notification to Client Owners #3319

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 52 additions & 26 deletions src/Billing/Controllers/StripeController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public class StripeController : Controller
private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;
private readonly IStripeEventService _stripeEventService;
private readonly IStripeFacade _stripeFacade;

public StripeController(
GlobalSettings globalSettings,
Expand All @@ -70,7 +71,8 @@ public StripeController(
ITaxRateRepository taxRateRepository,
IUserRepository userRepository,
ICurrentContext currentContext,
IStripeEventService stripeEventService)
IStripeEventService stripeEventService,
IStripeFacade stripeFacade)
{
_billingSettings = billingSettings?.Value;
_hostingEnvironment = hostingEnvironment;
Expand All @@ -97,6 +99,7 @@ public StripeController(
_currentContext = currentContext;
_globalSettings = globalSettings;
_stripeEventService = stripeEventService;
_stripeFacade = stripeFacade;
}

[HttpPost("webhook")]
Expand Down Expand Up @@ -209,48 +212,71 @@ await _userService.UpdatePremiumExpirationAsync(userId,
else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice))
{
var invoice = await _stripeEventService.GetInvoice(parsedEvent);
var subscriptionService = new SubscriptionService();
var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId);

if (string.IsNullOrEmpty(invoice.SubscriptionId))
{
_logger.LogWarning("Received 'invoice.upcoming' Event with ID '{eventId}' that did not include a Subscription ID", parsedEvent.Id);
return new OkResult();
}

var subscription = await _stripeFacade.GetSubscription(invoice.SubscriptionId);

if (subscription == null)
{
throw new Exception("Invoice subscription is null. " + invoice.Id);
throw new Exception(
$"Received null Subscription from Stripe for ID '{invoice.SubscriptionId}' while processing Event with ID '{parsedEvent.Id}'");
}

subscription = await VerifyCorrectTaxRateForCharge(invoice, subscription);
var updatedSubscription = await VerifyCorrectTaxRateForCharge(invoice, subscription);

string email = null;
var ids = GetIdsFromMetaData(subscription.Metadata);
// org
if (ids.Item1.HasValue)
var (organizationId, userId) = GetIdsFromMetaData(updatedSubscription.Metadata);

var invoiceLineItemDescriptions = invoice.Lines.Select(i => i.Description).ToList();

async Task SendEmails(IEnumerable<string> emails)
{
// sponsored org
if (IsSponsoredSubscription(subscription))
var validEmails = emails.Where(e => !string.IsNullOrEmpty(e));

if (invoice.NextPaymentAttempt.HasValue)
{
await _validateSponsorshipCommand.ValidateSponsorshipAsync(ids.Item1.Value);
await _mailService.SendInvoiceUpcoming(
validEmails,
invoice.AmountDue / 100M,
invoice.NextPaymentAttempt.Value,
invoiceLineItemDescriptions,
true);
}
}

var org = await _organizationRepository.GetByIdAsync(ids.Item1.Value);
if (org != null && OrgPlanForInvoiceNotifications(org))
if (organizationId.HasValue)
{
if (IsSponsoredSubscription(updatedSubscription))
{
await _validateSponsorshipCommand.ValidateSponsorshipAsync(organizationId.Value);
}

var organization = await _organizationRepository.GetByIdAsync(organizationId.Value);

if (organization == null || !OrgPlanForInvoiceNotifications(organization))
{
email = org.BillingEmail;
return new OkResult();
}

await SendEmails(new List<string> { organization.BillingEmail });

var ownerEmails = await _organizationRepository.GetOwnerEmailAddressesById(organization.Id);

await SendEmails(ownerEmails);
}
// user
else if (ids.Item2.HasValue)
else if (userId.HasValue)
{
var user = await _userService.GetUserByIdAsync(ids.Item2.Value);
var user = await _userService.GetUserByIdAsync(userId.Value);

if (user.Premium)
{
email = user.Email;
await SendEmails(new List<string> { user.Email });
}
}

if (!string.IsNullOrWhiteSpace(email) && invoice.NextPaymentAttempt.HasValue)
{
var items = invoice.Lines.Select(i => i.Description).ToList();
await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M,
invoice.NextPaymentAttempt.Value, items, true);
}
}
else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded))
{
Expand Down
49 changes: 48 additions & 1 deletion src/Billing/Services/Implementations/StripeEventService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ namespace Bit.Billing.Services.Implementations;
public class StripeEventService : IStripeEventService
{
private readonly GlobalSettings _globalSettings;
private readonly ILogger<StripeEventService> _logger;
private readonly IStripeFacade _stripeFacade;

public StripeEventService(
GlobalSettings globalSettings,
ILogger<StripeEventService> logger,
IStripeFacade stripeFacade)
{
_globalSettings = globalSettings;
_logger = logger;
_stripeFacade = stripeFacade;
}

Expand All @@ -26,6 +29,12 @@ public async Task<Charge> GetCharge(Event stripeEvent, bool fresh = false, List<
return eventCharge;
}

if (string.IsNullOrEmpty(eventCharge.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Charge for Event with ID '{eventId}' because no Charge ID was included in the Event.", stripeEvent.Id);
return eventCharge;
}

var charge = await _stripeFacade.GetCharge(eventCharge.Id, new ChargeGetOptions { Expand = expand });

if (charge == null)
Expand All @@ -46,6 +55,12 @@ public async Task<Customer> GetCustomer(Event stripeEvent, bool fresh = false, L
return eventCustomer;
}

if (string.IsNullOrEmpty(eventCustomer.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Customer for Event with ID '{eventId}' because no Customer ID was included in the Event.", stripeEvent.Id);
return eventCustomer;
}

var customer = await _stripeFacade.GetCustomer(eventCustomer.Id, new CustomerGetOptions { Expand = expand });

if (customer == null)
Expand All @@ -66,6 +81,12 @@ public async Task<Invoice> GetInvoice(Event stripeEvent, bool fresh = false, Lis
return eventInvoice;
}

if (string.IsNullOrEmpty(eventInvoice.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Invoice for Event with ID '{eventId}' because no Invoice ID was included in the Event.", stripeEvent.Id);
return eventInvoice;
}

var invoice = await _stripeFacade.GetInvoice(eventInvoice.Id, new InvoiceGetOptions { Expand = expand });

if (invoice == null)
Expand All @@ -86,6 +107,12 @@ public async Task<PaymentMethod> GetPaymentMethod(Event stripeEvent, bool fresh
return eventPaymentMethod;
}

if (string.IsNullOrEmpty(eventPaymentMethod.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Payment Method for Event with ID '{eventId}' because no Payment Method ID was included in the Event.", stripeEvent.Id);
return eventPaymentMethod;
}

var paymentMethod = await _stripeFacade.GetPaymentMethod(eventPaymentMethod.Id, new PaymentMethodGetOptions { Expand = expand });

if (paymentMethod == null)
Expand All @@ -106,6 +133,12 @@ public async Task<Subscription> GetSubscription(Event stripeEvent, bool fresh =
return eventSubscription;
}

if (string.IsNullOrEmpty(eventSubscription.Id))
{
_logger.LogWarning("Cannot retrieve up-to-date Subscription for Event with ID '{eventId}' because no Subscription ID was included in the Event.", stripeEvent.Id);
return eventSubscription;
}

var subscription = await _stripeFacade.GetSubscription(eventSubscription.Id, new SubscriptionGetOptions { Expand = expand });

if (subscription == null)
Expand All @@ -132,7 +165,7 @@ public async Task<bool> ValidateCloudRegion(Event stripeEvent)
(await GetCharge(stripeEvent, true, customerExpansion))?.Customer?.Metadata,

HandledStripeWebhook.UpcomingInvoice =>
(await GetInvoice(stripeEvent, true, customerExpansion))?.Customer?.Metadata,
await GetCustomerMetadataFromUpcomingInvoiceEvent(stripeEvent),

HandledStripeWebhook.PaymentSucceeded or HandledStripeWebhook.PaymentFailed or HandledStripeWebhook.InvoiceCreated =>
(await GetInvoice(stripeEvent, true, customerExpansion))?.Customer?.Metadata,
Expand All @@ -154,6 +187,20 @@ public async Task<bool> ValidateCloudRegion(Event stripeEvent)
var customerRegion = GetCustomerRegion(customerMetadata);

return customerRegion == serverRegion;

/* In Stripe, when we receive an invoice.upcoming event, the event does not include an Invoice ID because
the invoice hasn't been created yet. Therefore, rather than retrieving the fresh Invoice with a 'customer'
expansion, we need to use the Customer ID on the event to retrieve the metadata. */
async Task<Dictionary<string, string>> GetCustomerMetadataFromUpcomingInvoiceEvent(Event localStripeEvent)
{
var invoice = await GetInvoice(localStripeEvent);

var customer = !string.IsNullOrEmpty(invoice.CustomerId)
? await _stripeFacade.GetCustomer(invoice.CustomerId)
: null;

return customer?.Metadata;
}
}

private static T Extract<T>(Event stripeEvent)
Expand Down
1 change: 1 addition & 0 deletions src/Core/Repositories/IOrganizationRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ public interface IOrganizationRepository : IRepository<Organization, Guid>
Task<Organization> GetByLicenseKeyAsync(string licenseKey);
Task<SelfHostedOrganizationDetails> GetSelfHostedOrganizationDetailsById(Guid id);
Task<ICollection<Organization>> SearchUnassignedToProviderAsync(string name, string ownerEmail, int skip, int take);
Task<IEnumerable<string>> GetOwnerEmailAddressesById(Guid organizationId);
}
12 changes: 11 additions & 1 deletion src/Core/Services/IMailService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,17 @@ public interface IMailService
Task SendOrganizationConfirmedEmailAsync(string organizationName, string email);
Task SendOrganizationUserRemovedForPolicyTwoStepEmailAsync(string organizationName, string email);
Task SendPasswordlessSignInAsync(string returnUrl, string token, string email);
Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate, List<string> items,
Task SendInvoiceUpcoming(
string email,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices);
Task SendInvoiceUpcoming(
IEnumerable<string> email,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices);
Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices);
Task SendAddedCreditAsync(string email, decimal amount);
Expand Down
19 changes: 15 additions & 4 deletions src/Core/Services/Implementations/HandlebarsMailService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,21 @@ public async Task SendPasswordlessSignInAsync(string returnUrl, string token, st
await _mailDeliveryService.SendEmailAsync(message);
}

public async Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate,
List<string> items, bool mentionInvoices)
{
var message = CreateDefaultMessage("Your Subscription Will Renew Soon", email);
public async Task SendInvoiceUpcoming(
string email,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices) => await SendInvoiceUpcoming(new List<string> { email }, amount, dueDate, items, mentionInvoices);

public async Task SendInvoiceUpcoming(
IEnumerable<string> emails,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices)
{
var message = CreateDefaultMessage("Your Subscription Will Renew Soon", emails);
var model = new InvoiceUpcomingViewModel
{
WebVaultUrl = _globalSettings.BaseServiceUri.VaultWithHash,
Expand Down
18 changes: 13 additions & 5 deletions src/Core/Services/NoopImplementations/NoopMailService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,19 @@ public Task SendPasswordlessSignInAsync(string returnUrl, string token, string e
return Task.FromResult(0);
}

public Task SendInvoiceUpcomingAsync(string email, decimal amount, DateTime dueDate,
List<string> items, bool mentionInvoices)
{
return Task.FromResult(0);
}
public Task SendInvoiceUpcoming(
string email,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices) => Task.FromResult(0);

public Task SendInvoiceUpcoming(
IEnumerable<string> emails,
decimal amount,
DateTime dueDate,
List<string> items,
bool mentionInvoices) => Task.FromResult(0);

public Task SendPaymentFailedAsync(string email, decimal amount, bool mentionInvoices)
{
Expand Down
10 changes: 10 additions & 0 deletions src/Infrastructure.Dapper/Repositories/OrganizationRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,14 @@ public async Task<ICollection<Organization>> SearchUnassignedToProviderAsync(str
return results.ToList();
}
}

public async Task<IEnumerable<string>> GetOwnerEmailAddressesById(Guid organizationId)
{
await using var connection = new SqlConnection(ConnectionString);

return await connection.QueryAsync<string>(
$"[{Schema}].[{Table}_ReadOwnerEmailAddressesById]",
new { OrganizationId = organizationId },
commandType: CommandType.StoredProcedure);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,24 @@ public async Task<SelfHostedOrganizationDetails> GetSelfHostedOrganizationDetail
return selfHostedOrganization;
}
}

public async Task<IEnumerable<string>> GetOwnerEmailAddressesById(Guid organizationId)
{
using var scope = ServiceScopeFactory.CreateScope();

var dbContext = GetDatabaseContext(scope);

cyprain-okeke marked this conversation as resolved.
Show resolved Hide resolved
var query =
from u in dbContext.Users
join ou in dbContext.OrganizationUsers on u.Id equals ou.UserId
where
ou.OrganizationId == organizationId &&
ou.Type == OrganizationUserType.Owner &&
ou.Status == OrganizationUserStatusType.Confirmed
group u by u.Email
into grouped
select grouped.Key;

return await query.ToListAsync();
}
}
3 changes: 2 additions & 1 deletion test/Billing.Test/Services/StripeEventServiceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Bit.Billing.Test.Utilities;
using Bit.Core.Settings;
using FluentAssertions;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Stripe;
using Xunit;
Expand All @@ -21,7 +22,7 @@
globalSettings.BaseServiceUri = baseServiceUriSettings;

_stripeFacade = Substitute.For<IStripeFacade>();
_stripeEventService = new StripeEventService(globalSettings, _stripeFacade);
_stripeEventService = new StripeEventService(globalSettings, Substitute.For<ILogger<StripeEventService>>(), _stripeFacade);
}

#region GetCharge
Expand Down Expand Up @@ -476,7 +477,7 @@
var cloudRegionValid = await _stripeEventService.ValidateCloudRegion(stripeEvent);

// Assert
cloudRegionValid.Should().BeTrue();

Check failure on line 480 in test/Billing.Test/Services/StripeEventServiceTests.cs

View workflow job for this annotation

GitHub Actions / Test Results

Bit.Billing.Test.Services.StripeEventServiceTests ► ValidateCloudRegion_UpcomingInvoice_Success

Failed test found in: test/Billing.Test/TestResults/oss-test-results.trx Error: Expected cloudRegionValid to be true, but found False.
Raw output
Expected cloudRegionValid to be true, but found False.
   at FluentAssertions.Execution.XUnit2TestFramework.Throw(String message)
   at FluentAssertions.Execution.TestFrameworkProvider.Throw(String message)
   at FluentAssertions.Execution.DefaultAssertionStrategy.HandleFailure(String message)
   at FluentAssertions.Execution.AssertionScope.FailWith(Func`1 failReasonFunc)
   at FluentAssertions.Execution.AssertionScope.FailWith(Func`1 failReasonFunc)
   at FluentAssertions.Execution.AssertionScope.FailWith(String message, Object[] args)
   at FluentAssertions.Primitives.BooleanAssertions`1.BeTrue(String because, Object[] becauseArgs)
   at Bit.Billing.Test.Services.StripeEventServiceTests.ValidateCloudRegion_UpcomingInvoice_Success() in /home/runner/work/server/server/test/Billing.Test/Services/StripeEventServiceTests.cs:line 480
--- End of stack trace from previous location ---

await _stripeFacade.Received(1).GetInvoice(
invoice.Id,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
CREATE OR ALTER PROCEDURE [dbo].[Organization_ReadOwnerEmailAddressesById]
@OrganizationId UNIQUEIDENTIFIER
AS
BEGIN
SET NOCOUNT ON

SELECT
cyprain-okeke marked this conversation as resolved.
Show resolved Hide resolved
[U].[Email]
FROM [User] AS [U]
INNER JOIN [OrganizationUser] AS [OU] ON [U].[Id] = [OU].[UserId]
WHERE
[OU].[OrganizationId] = @OrganizationId AND
[OU].[Type] = 0 AND -- Owner
[OU].[Status] = 2 -- Confirmed
GROUP BY [U].[Email]
END
GO
Loading