Skip to content

Commit 3f1fc4f

Browse files
authored
Improve google refund polling handling. (#982)
General improvements around IAP validation.
1 parent ada6f94 commit 3f1fc4f

File tree

4 files changed

+54
-60
lines changed

4 files changed

+54
-60
lines changed

iap/iap.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,11 @@ func init() {
9292

9393
type ValidateReceiptAppleResponseReceiptInApp struct {
9494
OriginalTransactionID string `json:"original_transaction_id"`
95-
TransactionId string `json:"transaction_id"` // Different than OriginalTransactionId if the user Auto-renews subscription or restores a purchase.
95+
TransactionId string `json:"transaction_id"` // Different from OriginalTransactionId if the user Auto-renews subscription or restores a purchase.
9696
ProductID string `json:"product_id"`
9797
ExpiresDateMs string `json:"expires_date_ms"` // Subscription expiration or renewal date.
9898
PurchaseDateMs string `json:"purchase_date_ms"`
99+
CancellationDateMs string `json:"cancellation_date_ms"`
99100
}
100101

101102
type ValidateReceiptAppleResponseReceipt struct {

server/core_purchase.go

+26-41
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ func ValidatePurchaseGoogle(ctx context.Context, logger *zap.Logger, db *sql.DB,
172172
if !persist {
173173
validatedPurchases := []*api.ValidatedPurchase{
174174
{
175+
UserId: userID.String(),
175176
ProductId: sPurchase.productId,
176177
TransactionId: sPurchase.transactionId,
177178
Store: sPurchase.store,
@@ -446,6 +447,7 @@ func ListPurchases(ctx context.Context, logger *zap.Logger, db *sql.DB, userID s
446447
purchase_time,
447448
create_time,
448449
update_time,
450+
refund_time,
449451
environment
450452
FROM
451453
purchase
@@ -472,9 +474,10 @@ func ListPurchases(ctx context.Context, logger *zap.Logger, db *sql.DB, userID s
472474
var purchaseTime pgtype.Timestamptz
473475
var createTime pgtype.Timestamptz
474476
var updateTime pgtype.Timestamptz
477+
var refundTime pgtype.Timestamptz
475478
var environment api.StoreEnvironment
476479

477-
if err = rows.Scan(&dbUserID, &transactionId, &productId, &store, &rawResponse, &purchaseTime, &createTime, &updateTime, &environment); err != nil {
480+
if err = rows.Scan(&dbUserID, &transactionId, &productId, &store, &rawResponse, &purchaseTime, &createTime, &updateTime, &refundTime, &environment); err != nil {
478481
logger.Error("Error retrieving purchases.", zap.Error(err))
479482
return nil, err
480483
}
@@ -500,6 +503,9 @@ func ListPurchases(ctx context.Context, logger *zap.Logger, db *sql.DB, userID s
500503
ProviderResponse: rawResponse,
501504
Environment: environment,
502505
}
506+
if refundTime.Time.Unix() != 0 {
507+
purchase.RefundTime = timestamppb.New(purchase.RefundTime.AsTime())
508+
}
503509

504510
purchases = append(purchases, purchase)
505511

@@ -575,6 +581,8 @@ func upsertPurchases(ctx context.Context, db *sql.DB, purchases []*storagePurcha
575581
return nil, errors.New("expects at least one receipt")
576582
}
577583

584+
userIDIn := purchases[0].userID
585+
578586
statements := make([]string, 0, len(purchases))
579587
params := make([]interface{}, 0, len(purchases)*8)
580588
transactionIDsToPurchase := make(map[string]*storagePurchase)
@@ -613,72 +621,49 @@ VALUES
613621
ON CONFLICT
614622
(transaction_id)
615623
DO UPDATE SET
616-
refund_time = $8, update_time = now()
624+
refund_time = $8,
625+
update_time = now()
617626
RETURNING
618-
transaction_id, create_time, update_time, refund_time
627+
user_id,
628+
transaction_id,
629+
create_time,
630+
update_time,
631+
refund_time
619632
`
620-
insertedTransactionIDs := make(map[string]struct{})
621633
rows, err := db.QueryContext(ctx, query, params...)
622634
if err != nil {
623635
return nil, err
624636
}
625637
for rows.Next() {
626638
// Newly inserted purchases
639+
var dbUserID uuid.UUID
627640
var transactionId string
628641
var createTime pgtype.Timestamptz
629642
var updateTime pgtype.Timestamptz
630643
var refundTime pgtype.Timestamptz
631-
if err = rows.Scan(&transactionId, &createTime, &updateTime, &refundTime); err != nil {
644+
if err = rows.Scan(&dbUserID, &transactionId, &createTime, &updateTime, &refundTime); err != nil {
632645
rows.Close()
633646
return nil, err
634647
}
635648
storedPurchase, _ := transactionIDsToPurchase[transactionId]
636649
storedPurchase.createTime = createTime.Time
637650
storedPurchase.updateTime = updateTime.Time
638-
storedPurchase.refundTime = refundTime.Time
639-
storedPurchase.seenBefore = false
640-
insertedTransactionIDs[storedPurchase.transactionId] = struct{}{}
651+
storedPurchase.seenBefore = updateTime.Time.After(createTime.Time)
652+
if refundTime.Time.Unix() != 0 {
653+
storedPurchase.refundTime = refundTime.Time
654+
}
641655
}
642656
rows.Close()
643657
if err := rows.Err(); err != nil {
644658
return nil, err
645659
}
646660

647-
// Go over purchases that have not been inserted (already exist in the DB) and fetch createTime and updateTime
648-
if len(transactionIDsToPurchase) > len(insertedTransactionIDs) {
649-
seenIDs := make([]string, 0, len(transactionIDsToPurchase))
650-
for tID, _ := range transactionIDsToPurchase {
651-
if _, ok := insertedTransactionIDs[tID]; !ok {
652-
seenIDs = append(seenIDs, tID)
653-
}
654-
}
655-
656-
rows, err = db.QueryContext(ctx, "SELECT transaction_id, create_time, update_time FROM purchase WHERE transaction_id IN ($1)", strings.Join(seenIDs, ", "))
657-
if err != nil {
658-
return nil, err
659-
}
660-
for rows.Next() {
661-
// Already seen purchases
662-
var transactionId string
663-
var createTime pgtype.Timestamptz
664-
var updateTime pgtype.Timestamptz
665-
if err = rows.Scan(&transactionId, &createTime, &updateTime); err != nil {
666-
rows.Close()
667-
return nil, err
668-
}
669-
storedPurchase, _ := transactionIDsToPurchase[transactionId]
670-
storedPurchase.createTime = createTime.Time
671-
storedPurchase.updateTime = updateTime.Time
672-
storedPurchase.seenBefore = true
673-
}
674-
rows.Close()
675-
if err := rows.Err(); err != nil {
676-
return nil, err
677-
}
678-
}
679-
680661
storedPurchases := make([]*storagePurchase, 0, len(transactionIDsToPurchase))
681662
for _, purchase := range transactionIDsToPurchase {
663+
if purchase.seenBefore && purchase.userID != userIDIn {
664+
// Mismatch between userID requesting validation and existing receipt userID, return error.
665+
return nil, status.Error(codes.FailedPrecondition, "Invalid receipt for userID.")
666+
}
682667
storedPurchases = append(storedPurchases, purchase)
683668
}
684669

server/core_subscription.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -581,20 +581,26 @@ DO
581581
raw_notification = coalesce(to_jsonb(nullif($9, '')), subscription.raw_notification::jsonb),
582582
refund_time = coalesce($10, subscription.refund_time)
583583
RETURNING
584-
create_time, update_time, expire_time, refund_time, raw_response, raw_notification
584+
user_id, create_time, update_time, expire_time, refund_time, raw_response, raw_notification
585585
`
586586
var (
587+
userID uuid.UUID
587588
createTime pgtype.Timestamptz
588589
updateTime pgtype.Timestamptz
589590
expireTime pgtype.Timestamptz
590591
refundTime pgtype.Timestamptz
591592
rawResponse string
592593
rawNotification string
593594
)
594-
if err := db.QueryRowContext(ctx, query, sub.userID, sub.store, sub.originalTransactionId, sub.productId, sub.purchaseTime, sub.environment, sub.expireTime, sub.rawResponse, sub.rawNotification, sub.refundTime).Scan(&createTime, &updateTime, &expireTime, &refundTime, &rawResponse, &rawNotification); err != nil {
595+
if err := db.QueryRowContext(ctx, query, sub.userID, sub.store, sub.originalTransactionId, sub.productId, sub.purchaseTime, sub.environment, sub.expireTime, sub.rawResponse, sub.rawNotification, sub.refundTime).Scan(&userID, &createTime, &updateTime, &expireTime, &refundTime, &rawResponse, &rawNotification); err != nil {
595596
return err
596597
}
597598

599+
if sub.userID != userID {
600+
// Subscription receipt has been seen before for a different user.
601+
return status.Error(codes.FailedPrecondition, "Invalid receipt for userID")
602+
}
603+
598604
sub.createTime = createTime.Time
599605
sub.updateTime = updateTime.Time
600606
sub.expireTime = expireTime.Time

server/google_refund_scheduler.go

+18-16
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,14 @@ func (g *LocalGoogleRefundScheduler) Start(runtime *Runtime) {
9696
}
9797

9898
for _, vr := range voidedReceipts {
99-
switch vr.Kind {
100-
case "androidpublisher#productPurchase":
101-
purchase, err := getPurchaseByTransactionId(g.ctx, g.db, vr.PurchaseToken)
102-
if err != nil && err != sql.ErrNoRows {
103-
g.logger.Warn("Failed to find purchase for Google refund callback", zap.Error(err), zap.String("purchase_token", vr.PurchaseToken))
104-
continue
105-
}
99+
purchase, err := getPurchaseByTransactionId(g.ctx, g.db, vr.PurchaseToken)
100+
if err != nil && err != sql.ErrNoRows {
101+
g.logger.Error("Failed to get purchase by transaction_id", zap.Error(err), zap.String("purchase_token", vr.PurchaseToken))
102+
continue
103+
}
106104

105+
if purchase != nil {
106+
// Refunded purchase.
107107
if purchase.RefundTime.Seconds != 0 {
108108
// Purchase refund already handled, skip it.
109109
continue
@@ -144,8 +144,9 @@ func (g *LocalGoogleRefundScheduler) Start(runtime *Runtime) {
144144
PurchaseTime: timestamppb.New(dbPurchase.purchaseTime),
145145
CreateTime: timestamppb.New(dbPurchase.createTime),
146146
UpdateTime: timestamppb.New(dbPurchase.updateTime),
147-
RefundTime: timestamppb.New(refundTime),
147+
RefundTime: timestamppb.New(dbPurchase.refundTime),
148148
Environment: purchase.Environment,
149+
SeenBefore: dbPurchase.seenBefore,
149150
}
150151

151152
json, err := json.Marshal(vr)
@@ -159,16 +160,19 @@ func (g *LocalGoogleRefundScheduler) Start(runtime *Runtime) {
159160
g.logger.Warn("Failed to invoke Google purchase refund hook", zap.Error(err))
160161
}
161162
}
162-
163-
case "androidpublisher#subscriptionPurchase":
163+
} else {
164164
subscription, err := getSubscriptionByOriginalTransactionId(g.ctx, g.db, vr.PurchaseToken)
165-
if err != nil {
166-
if err != sql.ErrNoRows {
167-
g.logger.Error("Failed to find subscription for Google refund callback", zap.Error(err), zap.String("transaction_id", vr.PurchaseToken))
168-
}
165+
if err != nil && err != sql.ErrNoRows {
166+
g.logger.Error("Failed to get subscription by original_transaction_id", zap.Error(err), zap.String("original_transaction_id", vr.PurchaseToken))
167+
continue
168+
}
169+
170+
if subscription == nil {
171+
// No subscription was found.
169172
continue
170173
}
171174

175+
// Refunded subscription.
172176
if subscription.RefundTime.Seconds != 0 {
173177
// Subscription refund already handled, skip it.
174178
continue
@@ -231,8 +235,6 @@ func (g *LocalGoogleRefundScheduler) Start(runtime *Runtime) {
231235
g.logger.Warn("Failed to invoke Google subscription refund hook", zap.Error(err))
232236
}
233237
}
234-
default:
235-
g.logger.Warn("Unhandled IAP Google voided receipt kind", zap.String("kind", vr.Kind))
236238
}
237239
}
238240
}

0 commit comments

Comments
 (0)