Skip to content

Commit

Permalink
Handle customer.updated event in StripeController
Browse files Browse the repository at this point in the history
  • Loading branch information
amorask-bitwarden committed Oct 2, 2023
1 parent da43a16 commit 715bcff
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 207 deletions.
1 change: 1 addition & 0 deletions src/Billing/Constants/HandledStripeWebhook.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ public static class HandledStripeWebhook
public const string PaymentFailed = "invoice.payment_failed";
public const string InvoiceCreated = "invoice.created";
public const string PaymentMethodAttached = "payment_method.attached";
public const string CustomerUpdated = "customer.updated";
}
243 changes: 36 additions & 207 deletions src/Billing/Controllers/StripeController.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Bit.Billing.Constants;
using Bit.Billing.Services;
using Bit.Core.Context;
using Bit.Core.Entities;
using Bit.Core.Enums;
Expand Down Expand Up @@ -44,12 +45,13 @@ public class StripeController : Controller
private readonly IAppleIapService _appleIapService;
private readonly IMailService _mailService;
private readonly ILogger<StripeController> _logger;
private readonly Braintree.BraintreeGateway _btGateway;
private readonly BraintreeGateway _btGateway;
private readonly IReferenceEventService _referenceEventService;
private readonly ITaxRateRepository _taxRateRepository;
private readonly IUserRepository _userRepository;
private readonly ICurrentContext _currentContext;
private readonly GlobalSettings _globalSettings;
private readonly IStripeEventService _stripeEventService;

public StripeController(
GlobalSettings globalSettings,
Expand All @@ -67,7 +69,8 @@ public StripeController(
ILogger<StripeController> logger,
ITaxRateRepository taxRateRepository,
IUserRepository userRepository,
ICurrentContext currentContext)
ICurrentContext currentContext,
IStripeEventService stripeEventService)
{
_billingSettings = billingSettings?.Value;
_hostingEnvironment = hostingEnvironment;
Expand All @@ -83,7 +86,7 @@ public StripeController(
_taxRateRepository = taxRateRepository;
_userRepository = userRepository;
_logger = logger;
_btGateway = new Braintree.BraintreeGateway
_btGateway = new BraintreeGateway
{
Environment = globalSettings.Braintree.Production ?
Braintree.Environment.PRODUCTION : Braintree.Environment.SANDBOX,
Expand All @@ -93,6 +96,7 @@ public StripeController(
};
_currentContext = currentContext;
_globalSettings = globalSettings;
_stripeEventService = stripeEventService;
}

[HttpPost("webhook")]
Expand All @@ -103,7 +107,7 @@ public async Task<IActionResult> PostWebhook([FromQuery] string key)
return new BadRequestResult();
}

Stripe.Event parsedEvent;
Event parsedEvent;
using (var sr = new StreamReader(HttpContext.Request.Body))
{
var json = await sr.ReadToEndAsync();
Expand All @@ -125,7 +129,7 @@ public async Task<IActionResult> PostWebhook([FromQuery] string key)
}

// If the customer and server cloud regions don't match, early return 200 to avoid unnecessary errors
if (!await ValidateCloudRegionAsync(parsedEvent))
if (!await _stripeEventService.ValidateCloudRegion(parsedEvent))
{
return new OkResult();
}
Expand All @@ -135,7 +139,7 @@ public async Task<IActionResult> PostWebhook([FromQuery] string key)

if (subDeleted || subUpdated)
{
var subscription = await GetSubscriptionAsync(parsedEvent, true);
var subscription = await _stripeEventService.GetSubscription(parsedEvent, true);
var ids = GetIdsFromMetaData(subscription.Metadata);
var organizationId = ids.Item1 ?? Guid.Empty;
var userId = ids.Item2 ?? Guid.Empty;
Expand Down Expand Up @@ -204,7 +208,7 @@ await _userService.UpdatePremiumExpirationAsync(userId,
}
else if (parsedEvent.Type.Equals(HandledStripeWebhook.UpcomingInvoice))
{
var invoice = await GetInvoiceAsync(parsedEvent);
var invoice = await _stripeEventService.GetInvoice(parsedEvent);
var subscriptionService = new SubscriptionService();
var subscription = await subscriptionService.GetAsync(invoice.SubscriptionId);
if (subscription == null)
Expand Down Expand Up @@ -250,7 +254,7 @@ await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M,
}
else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeSucceeded))
{
var charge = await GetChargeAsync(parsedEvent);
var charge = await _stripeEventService.GetCharge(parsedEvent);
var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync(
GatewayType.Stripe, charge.Id);
if (chargeTransaction != null)
Expand Down Expand Up @@ -377,7 +381,7 @@ await _mailService.SendInvoiceUpcomingAsync(email, invoice.AmountDue / 100M,
}
else if (parsedEvent.Type.Equals(HandledStripeWebhook.ChargeRefunded))
{
var charge = await GetChargeAsync(parsedEvent);
var charge = await _stripeEventService.GetCharge(parsedEvent);
var chargeTransaction = await _transactionRepository.GetByGatewayIdAsync(
GatewayType.Stripe, charge.Id);
if (chargeTransaction == null)
Expand Down Expand Up @@ -427,7 +431,7 @@ await _transactionRepository.CreateAsync(new Transaction
}
else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentSucceeded))
{
var invoice = await GetInvoiceAsync(parsedEvent, true);
var invoice = await _stripeEventService.GetInvoice(parsedEvent, true);
if (invoice.Paid && invoice.BillingReason == "subscription_create")
{
var subscriptionService = new SubscriptionService();
Expand Down Expand Up @@ -479,125 +483,53 @@ await _referenceEventService.RaiseEventAsync(
}
else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentFailed))
{
await HandlePaymentFailed(await GetInvoiceAsync(parsedEvent, true));
await HandlePaymentFailed(await _stripeEventService.GetInvoice(parsedEvent, true));
}
else if (parsedEvent.Type.Equals(HandledStripeWebhook.InvoiceCreated))
{
var invoice = await GetInvoiceAsync(parsedEvent, true);
var invoice = await _stripeEventService.GetInvoice(parsedEvent, true);
if (!invoice.Paid && UnpaidAutoChargeInvoiceForSubscriptionCycle(invoice))
{
await AttemptToPayInvoiceAsync(invoice);
}
}
else if (parsedEvent.Type.Equals(HandledStripeWebhook.PaymentMethodAttached))
{
var paymentMethod = await GetPaymentMethodAsync(parsedEvent);
var paymentMethod = await _stripeEventService.GetPaymentMethod(parsedEvent);
await HandlePaymentMethodAttachedAsync(paymentMethod);
}
else
{
_logger.LogWarning("Unsupported event received. " + parsedEvent.Type);
}

return new OkResult();
}

/// <summary>
/// Ensures that the customer associated with the parsed event's data is in the correct region for this server.
/// We use the customer instead of the subscription given that all subscriptions have customers, but not all
/// customers have subscriptions
/// </summary>
/// <param name="parsedEvent"></param>
/// <returns>true if the customer's region and the server's region match, otherwise false</returns>
/// <exception cref="Exception"></exception>
private async Task<bool> ValidateCloudRegionAsync(Event parsedEvent)
{
var serverRegion = _globalSettings.BaseServiceUri.CloudRegion;
var eventType = parsedEvent.Type;
var expandOptions = new List<string> { "customer" };

try
else if (parsedEvent.Type.Equals(HandledStripeWebhook.CustomerUpdated))
{
Dictionary<string, string> customerMetadata;
switch (eventType)
{
case HandledStripeWebhook.SubscriptionDeleted:
case HandledStripeWebhook.SubscriptionUpdated:
customerMetadata = (await GetSubscriptionAsync(parsedEvent, true, expandOptions))?.Customer
?.Metadata;
break;
case HandledStripeWebhook.ChargeSucceeded:
case HandledStripeWebhook.ChargeRefunded:
customerMetadata = (await GetChargeAsync(parsedEvent, true, expandOptions))?.Customer?.Metadata;
break;
case HandledStripeWebhook.UpcomingInvoice:
customerMetadata = (await GetInvoiceAsync(parsedEvent))?.Customer?.Metadata;
break;
case HandledStripeWebhook.PaymentSucceeded:
case HandledStripeWebhook.PaymentFailed:
case HandledStripeWebhook.InvoiceCreated:
customerMetadata = (await GetInvoiceAsync(parsedEvent, true, expandOptions))?.Customer?.Metadata;
break;
case HandledStripeWebhook.PaymentMethodAttached:
customerMetadata = (await GetPaymentMethodAsync(parsedEvent, true, expandOptions))
?.Customer
?.Metadata;
break;
default:
customerMetadata = null;
break;
}
var customer =
await _stripeEventService.GetCustomer(parsedEvent, true, new List<string> { "subscriptions" });

if (customerMetadata is null)
if (customer.Subscriptions == null || !customer.Subscriptions.Any())
{
return false;
return new OkResult();
}

var customerRegion = GetCustomerRegionFromMetadata(customerMetadata);
var subscription = customer.Subscriptions.First();

return customerRegion == serverRegion;
}
catch (Exception e)
{
_logger.LogError(e, "Encountered unexpected error while validating cloud region");
throw;
}
}
var (organizationId, _) = GetIdsFromMetaData(subscription.Metadata);

/// <summary>
/// Gets the customer's region from the metadata.
/// </summary>
/// <param name="customerMetadata">The metadata of the customer.</param>
/// <returns>The region of the customer. If the region is not specified, it returns "US", if metadata is null,
/// it returns null. It is case insensitive.</returns>
private static string GetCustomerRegionFromMetadata(IDictionary<string, string> customerMetadata)
{
const string defaultRegion = "US";
if (!organizationId.HasValue)
{
return new OkResult();
}

if (customerMetadata is null)
{
return null;
}
var organization = await _organizationRepository.GetByIdAsync(organizationId.Value);
organization.BillingEmail = customer.Email;
await _organizationRepository.ReplaceAsync(organization);

if (customerMetadata.TryGetValue("region", out var value))
{
return value;
await _referenceEventService.RaiseEventAsync(
new ReferenceEvent(ReferenceEventType.OrganizationEditedInStripe, organization, _currentContext));
}

var miscasedRegionKey = customerMetadata.Keys
.FirstOrDefault(key =>
key.Equals("region", StringComparison.OrdinalIgnoreCase));

if (miscasedRegionKey is null)
else
{
return defaultRegion;
_logger.LogWarning("Unsupported event received. " + parsedEvent.Type);
}

_ = customerMetadata.TryGetValue(miscasedRegionKey, out var regionValue);

return !string.IsNullOrWhiteSpace(regionValue)
? regionValue
: defaultRegion;
return new OkResult();
}

private async Task HandlePaymentMethodAttachedAsync(PaymentMethod paymentMethod)
Expand Down Expand Up @@ -975,109 +907,6 @@ private bool UnpaidAutoChargeInvoiceForSubscriptionCycle(Invoice invoice)
invoice.BillingReason == "subscription_cycle" && invoice.SubscriptionId != null;
}

private async Task<Charge> GetChargeAsync(Event parsedEvent, bool fresh = false, List<string> expandOptions = null)
{
if (!(parsedEvent.Data.Object is Charge eventCharge))
{
throw new Exception("Charge is null (from parsed event). " + parsedEvent.Id);
}
if (!fresh)
{
return eventCharge;
}
var chargeService = new ChargeService();
var chargeGetOptions = new ChargeGetOptions { Expand = expandOptions };
var charge = await chargeService.GetAsync(eventCharge.Id, chargeGetOptions);
if (charge == null)
{
throw new Exception("Charge is null. " + eventCharge.Id);
}
return charge;
}

private async Task<Invoice> GetInvoiceAsync(Stripe.Event parsedEvent, bool fresh = false, List<string> expandOptions = null)
{
if (!(parsedEvent.Data.Object is Invoice eventInvoice))
{
throw new Exception("Invoice is null (from parsed event). " + parsedEvent.Id);
}
if (!fresh)
{
return eventInvoice;
}
var invoiceService = new InvoiceService();
var invoiceGetOptions = new InvoiceGetOptions { Expand = expandOptions };
var invoice = await invoiceService.GetAsync(eventInvoice.Id, invoiceGetOptions);
if (invoice == null)
{
throw new Exception("Invoice is null. " + eventInvoice.Id);
}
return invoice;
}

private async Task<Subscription> GetSubscriptionAsync(Stripe.Event parsedEvent, bool fresh = false,
List<string> expandOptions = null)
{
if (parsedEvent.Data.Object is not Subscription eventSubscription)
{
throw new Exception("Subscription is null (from parsed event). " + parsedEvent.Id);
}
if (!fresh)
{
return eventSubscription;
}
var subscriptionService = new SubscriptionService();
var subscriptionGetOptions = new SubscriptionGetOptions { Expand = expandOptions };
var subscription = await subscriptionService.GetAsync(eventSubscription.Id, subscriptionGetOptions);
if (subscription == null)
{
throw new Exception("Subscription is null. " + eventSubscription.Id);
}
return subscription;
}

private async Task<Customer> GetCustomerAsync(string customerId)
{
if (string.IsNullOrWhiteSpace(customerId))
{
throw new Exception("Customer ID cannot be empty when attempting to get a customer from Stripe");
}

var customerService = new CustomerService();
var customer = await customerService.GetAsync(customerId);
if (customer == null)
{
throw new Exception($"Customer is null. {customerId}");
}

return customer;
}

private async Task<PaymentMethod> GetPaymentMethodAsync(Event parsedEvent, bool fresh = false,
List<string> expandOptions = null)
{
if (parsedEvent.Data.Object is not PaymentMethod eventPaymentMethod)
{
throw new Exception("Invoice is null (from parsed event). " + parsedEvent.Id);
}

if (!fresh)
{
return eventPaymentMethod;
}

var paymentMethodService = new PaymentMethodService();
var paymentMethodGetOptions = new PaymentMethodGetOptions { Expand = expandOptions };
var paymentMethod = await paymentMethodService.GetAsync(eventPaymentMethod.Id, paymentMethodGetOptions);

if (paymentMethod == null)
{
throw new Exception($"Payment method is null. {eventPaymentMethod.Id}");
}

return paymentMethod;
}

private async Task<Subscription> VerifyCorrectTaxRateForCharge(Invoice invoice, Subscription subscription)
{
if (!string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.Country) && !string.IsNullOrWhiteSpace(invoice?.CustomerAddress?.PostalCode))
Expand Down
2 changes: 2 additions & 0 deletions src/Core/Tools/Enums/ReferenceEventType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public enum ReferenceEventType
OrganizationEditedByAdmin,
[EnumMember(Value = "organization-created-by-admin")]
OrganizationCreatedByAdmin,
[EnumMember(Value = "organization-edited-in-stripe")]
OrganizationEditedInStripe,
[EnumMember(Value = "sm-service-account-accessed-secret")]
SmServiceAccountAccessedSecret,
}

0 comments on commit 715bcff

Please sign in to comment.