Skip to content

Commit

Permalink
fix(payments): pools fixes (#952)
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-nicolas authored Dec 5, 2023
1 parent 29f6584 commit 6539dd0
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 4 deletions.
6 changes: 4 additions & 2 deletions components/payments/cmd/api/internal/storage/pools.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ func (s *Storage) AddAccountToPool(ctx context.Context, poolAccount *models.Pool
func (s *Storage) RemoveAccountFromPool(ctx context.Context, poolAccount *models.PoolAccounts) error {
_, err := s.db.NewDelete().
Model(poolAccount).
Where("pool_id = ?", poolAccount.PoolID).
Where("account_id = ?", poolAccount.AccountID).
Exec(ctx)
if err != nil {
return e("failed to remove account from pool", err)
Expand All @@ -63,7 +65,7 @@ func (s *Storage) ListPools(ctx context.Context, pagination PaginatorQuery) ([]*
Model(&pools).
Relation("PoolAccounts")

query = pagination.apply(query, "pools.created_at")
query = pagination.apply(query, "pool.created_at")

err := query.Scan(ctx)
if err != nil {
Expand Down Expand Up @@ -94,7 +96,7 @@ func (s *Storage) ListPools(ctx context.Context, pagination PaginatorQuery) ([]*

query = s.db.NewSelect().Model(&pools)

hasPrevious, err = pagination.hasPrevious(ctx, query, "pools.created_at", firstReference)
hasPrevious, err = pagination.hasPrevious(ctx, query, "pool.created_at", firstReference)
if err != nil {
return nil, PaginationDetails{}, e("failed to check if there is a previous page", err)
}
Expand Down
197 changes: 197 additions & 0 deletions components/payments/cmd/api/internal/storage/pools_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package storage_test

import (
"context"
"testing"
"time"

"github.com/formancehq/payments/cmd/api/internal/storage"
"github.com/formancehq/payments/internal/models"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)

func insertPools(t *testing.T, store *storage.Storage, accountIDs []models.AccountID) []uuid.UUID {
pool1 := models.Pool{
Name: "test",
CreatedAt: time.Date(2023, 11, 14, 8, 0, 0, 0, time.UTC),
}
var uuid1 uuid.UUID
err := store.DB().NewInsert().
Model(&pool1).
Returning("id").
Scan(context.Background(), &uuid1)
require.NoError(t, err)

poolAccounts1 := models.PoolAccounts{
PoolID: uuid1,
AccountID: accountIDs[0],
}
_, err = store.DB().NewInsert().
Model(&poolAccounts1).
Exec(context.Background())

var uuid2 uuid.UUID
pool2 := models.Pool{
Name: "test2",
CreatedAt: time.Date(2023, 11, 14, 9, 0, 0, 0, time.UTC),
}
err = store.DB().NewInsert().
Model(&pool2).
Returning("id").
Scan(context.Background(), &uuid2)
require.NoError(t, err)

poolAccounts2 := []*models.PoolAccounts{
{
PoolID: uuid2,
AccountID: accountIDs[0],
},
{
PoolID: uuid2,
AccountID: accountIDs[1],
},
}
_, err = store.DB().NewInsert().
Model(&poolAccounts2).
Exec(context.Background())
require.NoError(t, err)

return []uuid.UUID{uuid1, uuid2}
}

func TestCreatePools(t *testing.T) {
t.Parallel()

store := newStore(t)

connectorID := installConnector(t, store)
accounts := insertAccounts(t, store, connectorID)

pool := &models.Pool{
Name: "test",
CreatedAt: time.Date(2023, 11, 14, 8, 0, 0, 0, time.UTC),
PoolAccounts: []*models.PoolAccounts{},
}
for _, account := range accounts {
pool.PoolAccounts = append(pool.PoolAccounts, &models.PoolAccounts{
AccountID: account,
})
}

err := store.CreatePool(context.Background(), pool)
require.NoError(t, err)
require.NotEqual(t, uuid.Nil, pool.ID)
}

func TestAddAccountsToPool(t *testing.T) {
t.Parallel()

store := newStore(t)

connectorID := installConnector(t, store)
accounts := insertAccounts(t, store, connectorID)
poolIDs := insertPools(t, store, accounts)

poolAccounts := []*models.PoolAccounts{
{
PoolID: poolIDs[0],
AccountID: accounts[1],
},
}

err := store.AddAccountsToPool(context.Background(), poolAccounts)
require.NoError(t, err)

pool, err := store.GetPool(context.Background(), poolIDs[0])
require.NoError(t, err)
require.Equal(t, 2, len(pool.PoolAccounts))
require.Equal(t, accounts[0], pool.PoolAccounts[0].AccountID)
require.Equal(t, accounts[1], pool.PoolAccounts[1].AccountID)
}

func TestRemoveAccoutsToPool(t *testing.T) {
t.Parallel()

store := newStore(t)

connectorID := installConnector(t, store)
accounts := insertAccounts(t, store, connectorID)
poolIDs := insertPools(t, store, accounts)

poolAccounts := []*models.PoolAccounts{
{
PoolID: poolIDs[0],
AccountID: accounts[0],
},
}

err := store.RemoveAccountFromPool(context.Background(), poolAccounts[0])
require.NoError(t, err)

pool, err := store.GetPool(context.Background(), poolIDs[0])
require.NoError(t, err)
require.Equal(t, 0, len(pool.PoolAccounts))
}

func TestListPools(t *testing.T) {
t.Parallel()

store := newStore(t)

connectorID := installConnector(t, store)
accounts := insertAccounts(t, store, connectorID)
insertedPools := insertPools(t, store, accounts)

t.Run("list all pools", func(t *testing.T) {
t.Parallel()

pools, paginationDetails, err := store.ListPools(context.Background(), storage.NewPaginatorQuery(15, nil, nil))
require.NoError(t, err)
require.Equal(t, 2, len(pools))
require.Equal(t, 15, paginationDetails.PageSize)
require.Equal(t, false, paginationDetails.HasMore)
require.Equal(t, "", paginationDetails.PreviousPage)
require.Equal(t, "", paginationDetails.NextPage)
require.Equal(t, insertedPools[1], pools[0].ID)
require.Equal(t, 2, len(pools[0].PoolAccounts))
require.Equal(t, insertedPools[0], pools[1].ID)
require.Equal(t, 1, len(pools[1].PoolAccounts))
})

t.Run("list all pools with page size 1", func(t *testing.T) {
t.Parallel()

query, err := storage.Paginate(1, "", nil, nil)
require.NoError(t, err)

pools, paginationDetails, err := store.ListPools(context.Background(), query)
require.NoError(t, err)
require.Equal(t, 1, len(pools))
require.Equal(t, 1, paginationDetails.PageSize)
require.Equal(t, true, paginationDetails.HasMore)
require.Equal(t, "", paginationDetails.PreviousPage)
require.Equal(t, insertedPools[1], pools[0].ID)
require.Equal(t, 2, len(pools[0].PoolAccounts))

query, err = storage.Paginate(1, paginationDetails.NextPage, nil, nil)
require.NoError(t, err)
pools, paginationDetails, err = store.ListPools(context.Background(), query)
require.NoError(t, err)
require.Equal(t, 1, len(pools))
require.Equal(t, 1, paginationDetails.PageSize)
require.Equal(t, false, paginationDetails.HasMore)
require.Equal(t, insertedPools[0], pools[0].ID)
require.Equal(t, 1, len(pools[0].PoolAccounts))

query, err = storage.Paginate(1, paginationDetails.PreviousPage, nil, nil)
require.NoError(t, err)
pools, paginationDetails, err = store.ListPools(context.Background(), query)
require.NoError(t, err)
require.Equal(t, 1, len(pools))
require.Equal(t, 1, paginationDetails.PageSize)
require.Equal(t, true, paginationDetails.HasMore)
require.Equal(t, insertedPools[1], pools[0].ID)
require.Equal(t, 2, len(pools[0].PoolAccounts))
})
}
2 changes: 1 addition & 1 deletion components/payments/internal/models/pools.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type PoolAccounts struct {
type Pool struct {
bun.BaseModel `bun:"accounts.pools"`

ID uuid.UUID `bun:",pk,notnull"`
ID uuid.UUID `bun:",pk,nullzero"`
Name string
CreatedAt time.Time

Expand Down
14 changes: 13 additions & 1 deletion components/payments/internal/storage/migrations_v1.x.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func registerMigrationsV1(ctx context.Context, migrator *migrations.Migrator) {
Up: func(tx bun.Tx) error {
_, err := tx.Exec(`
CREATE TABLE IF NOT EXISTS accounts.pool_accounts (
pool_id uuid NOT NULL DEFAULT gen_random_uuid(),
pool_id uuid NOT NULL,
account_id CHARACTER VARYING NOT NULL,
CONSTRAINT pool_accounts_pk PRIMARY KEY (pool_id, account_id)
);
Expand Down Expand Up @@ -100,6 +100,18 @@ func registerMigrationsV1(ctx context.Context, migrator *migrations.Migrator) {
return err
}

return nil
},
},
migrations.Migration{
Up: func(tx bun.Tx) error {
_, err := tx.Exec(`
ALTER TABLE accounts.pools ALTER COLUMN id SET DEFAULT gen_random_uuid();
`)
if err != nil {
return err
}

return nil
},
},
Expand Down

0 comments on commit 6539dd0

Please sign in to comment.