diff --git a/pkg/api/controllers/account_controller.go b/pkg/api/controllers/account_controller.go index bb75e1e37..1330c01e0 100644 --- a/pkg/api/controllers/account_controller.go +++ b/pkg/api/controllers/account_controller.go @@ -142,6 +142,7 @@ func (ctl *AccountController) GetAccounts(c *gin.Context) { WithAddressFilter(c.Query("address")). WithBalanceFilter(balance). WithBalanceOperatorFilter(balanceOperator). + WithBalanceAssetFilter(c.Query("balanceAsset")). WithMetadataFilter(c.QueryMap("metadata")). WithPageSize(pageSize) } diff --git a/pkg/api/controllers/account_controller_test.go b/pkg/api/controllers/account_controller_test.go index 2dd6008a5..74259182f 100644 --- a/pkg/api/controllers/account_controller_test.go +++ b/pkg/api/controllers/account_controller_test.go @@ -52,6 +52,18 @@ func TestGetAccounts(t *testing.T) { }, false) require.Equal(t, http.StatusOK, rsp.Result().StatusCode) + rsp = internal.PostTransaction(t, api, controllers.PostTransaction{ + Postings: core.Postings{ + { + Source: "world", + Destination: "fred", + Amount: core.NewMonetaryInt(10), + Asset: "EUR", + }, + }, + }, false) + require.Equal(t, http.StatusOK, rsp.Result().StatusCode) + meta := core.Metadata{ "roles": "admin", "accountId": float64(3), @@ -67,16 +79,17 @@ func TestGetAccounts(t *testing.T) { rsp = internal.CountAccounts(api, url.Values{}) require.Equal(t, http.StatusOK, rsp.Result().StatusCode) - require.Equal(t, "3", rsp.Header().Get("Count")) + require.Equal(t, "4", rsp.Header().Get("Count")) t.Run("all", func(t *testing.T) { rsp = internal.GetAccounts(api, url.Values{}) assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) // 3 accounts: world, bob, alice - assert.Len(t, cursor.Data, 3) + assert.Len(t, cursor.Data, 4) assert.Equal(t, []core.Account{ {Address: "world", Metadata: core.Metadata{}}, + {Address: "fred", Metadata: core.Metadata{}}, {Address: "bob", Metadata: meta}, {Address: "alice", Metadata: core.Metadata{}}, }, cursor.Data) @@ -261,6 +274,18 @@ func TestGetAccounts(t *testing.T) { assert.Equal(t, "alice", string(cursor.Data[0].Address)) }) + t.Run("filter by balance >= 0 and asset specified", func(t *testing.T) { + rsp = internal.GetAccounts(api, url.Values{ + "balanceAsset": []string{"EUR"}, + "balance": []string{"0"}, + controllers.QueryKeyBalanceOperator: []string{"gte"}, + }) + assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) + cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) + assert.Len(t, cursor.Data, 1) + assert.Equal(t, "fred", string(cursor.Data[0].Address)) + }) + t.Run("filter by balance > 120", func(t *testing.T) { rsp = internal.GetAccounts(api, url.Values{ "balance": []string{"120"}, @@ -290,8 +315,9 @@ func TestGetAccounts(t *testing.T) { }) assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - assert.Len(t, cursor.Data, 1) + assert.Len(t, cursor.Data, 2) assert.Equal(t, "world", string(cursor.Data[0].Address)) + assert.Equal(t, "fred", string(cursor.Data[1].Address)) }) t.Run("filter by balance <= 100", func(t *testing.T) { @@ -301,9 +327,10 @@ func TestGetAccounts(t *testing.T) { }) assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - assert.Len(t, cursor.Data, 2) + assert.Len(t, cursor.Data, 3) assert.Equal(t, "world", string(cursor.Data[0].Address)) - assert.Equal(t, "bob", string(cursor.Data[1].Address)) + assert.Equal(t, "fred", string(cursor.Data[1].Address)) + assert.Equal(t, "bob", string(cursor.Data[2].Address)) }) t.Run("filter by balance = 100", func(t *testing.T) { @@ -325,9 +352,10 @@ func TestGetAccounts(t *testing.T) { }) assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) cursor := internal.DecodeCursorResponse[core.Account](t, rsp.Body) - assert.Len(t, cursor.Data, 2) + assert.Len(t, cursor.Data, 3) assert.Equal(t, "world", string(cursor.Data[0].Address)) - assert.Equal(t, "alice", string(cursor.Data[1].Address)) + assert.Equal(t, "fred", string(cursor.Data[1].Address)) + assert.Equal(t, "alice", string(cursor.Data[2].Address)) }) t.Run("invalid balance", func(t *testing.T) { diff --git a/pkg/api/controllers/swagger.yaml b/pkg/api/controllers/swagger.yaml index 2b1117b0c..25f55eb65 100644 --- a/pkg/api/controllers/swagger.yaml +++ b/pkg/api/controllers/swagger.yaml @@ -168,6 +168,11 @@ paths: type: integer format: int64 example: 2400 + - name: balanceAsset + in: query + description: Filter accounts by their balance asset + schema: + type: string - name: balanceOperator x-speakeasy-ignore: true in: query diff --git a/pkg/ledger/storage.go b/pkg/ledger/storage.go index 28674d602..f85ad030c 100644 --- a/pkg/ledger/storage.go +++ b/pkg/ledger/storage.go @@ -133,6 +133,7 @@ type AccountsQueryFilters struct { Balance string BalanceOperator BalanceOperator Metadata map[string]string + BalanceAsset string } type BalanceOperator string @@ -220,6 +221,12 @@ func (a *AccountsQuery) WithMetadataFilter(metadata map[string]string) *Accounts return a } +func (a *AccountsQuery) WithBalanceAssetFilter(value string) *AccountsQuery { + a.Filters.BalanceAsset = value + + return a +} + type BalancesQuery struct { PageSize uint Offset uint diff --git a/pkg/storage/sqlstorage/accounts.go b/pkg/storage/sqlstorage/accounts.go index c2769e1f1..c1bc47d1c 100644 --- a/pkg/storage/sqlstorage/accounts.go +++ b/pkg/storage/sqlstorage/accounts.go @@ -27,11 +27,13 @@ func (s *Store) buildAccountsQuery(p ledger.AccountsQuery) (*sqlbuilder.SelectBu sb := sqlbuilder.NewSelectBuilder() t := AccPaginationToken{} sb.From(s.schema.Table("accounts")) + sb.Distinct() var ( address = p.Filters.Address metadata = p.Filters.Metadata balance = p.Filters.Balance + balanceAsset = p.Filters.BalanceAsset balanceOperator = p.Filters.BalanceOperator ) @@ -78,6 +80,9 @@ func (s *Store) buildAccountsQuery(p ledger.AccountsQuery) (*sqlbuilder.SelectBu if balance != "" { sb.Join(s.schema.Table("volumes"), "accounts.address = volumes.account") + if balanceAsset != "" { + sb = sb.Where(sb.E("volumes.asset", balanceAsset)) + } balanceOperation := "volumes.input - volumes.output" balanceValue, err := strconv.ParseInt(balance, 10, 0) diff --git a/pkg/storage/sqlstorage/accounts_test.go b/pkg/storage/sqlstorage/accounts_test.go index a746531d6..4c93a90fc 100644 --- a/pkg/storage/sqlstorage/accounts_test.go +++ b/pkg/storage/sqlstorage/accounts_test.go @@ -2,15 +2,72 @@ package sqlstorage_test import ( "context" - "testing" - "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/ledger" "github.com/numary/ledger/pkg/storage/sqlstorage" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" ) func testAccounts(t *testing.T, store *sqlstorage.Store) { + + err := store.Commit(context.Background(), + core.ExpandedTransaction{ + Transaction: core.Transaction{ + TransactionData: core.TransactionData{ + Postings: []core.Posting{ + { + Source: "world", + Destination: "us_bank", + Amount: core.NewMonetaryInt(100), + Asset: "USD/2", + }, + { + Source: "world", + Destination: "eu_bank", + Amount: core.NewMonetaryInt(100), + Asset: "EUR/2", + }, + }, + }, + }, + PreCommitVolumes: map[string]core.AssetsVolumes{ + "world": map[string]core.Volumes{ + "USD/2": {}, + "EUR/2": {}, + }, + "us_bank": map[string]core.Volumes{ + "USD/2": {}, + }, + "eu_bank": map[string]core.Volumes{ + "EUR/2": {}, + }, + }, + PostCommitVolumes: map[string]core.AssetsVolumes{ + "world": map[string]core.Volumes{ + "USD/2": { + Output: core.NewMonetaryInt(100), + }, + "EUR/2": { + Output: core.NewMonetaryInt(100), + }, + }, + "us_bank": map[string]core.Volumes{ + "USD/2": { + Input: core.NewMonetaryInt(100), + }, + }, + "eu_bank": map[string]core.Volumes{ + "EUR/2": { + Input: core.NewMonetaryInt(100), + }, + }, + }, + }, + ) + require.NoError(t, err) + t.Run("success balance", func(t *testing.T) { q := ledger.AccountsQuery{ PageSize: 10, @@ -22,6 +79,35 @@ func testAccounts(t *testing.T, store *sqlstorage.Store) { _, err := store.GetAccounts(context.Background(), q) assert.NoError(t, err, "balance filter should not fail") }) + t.Run("filter balance when multiple assets match", func(t *testing.T) { + q := ledger.AccountsQuery{ + PageSize: 10, + Filters: ledger.AccountsQueryFilters{ + Balance: "0", + BalanceOperator: "lt", + }, + } + + accounts, err := store.GetAccounts(context.Background(), q) + require.NoError(t, err, "balance filter should not fail") + require.Len(t, accounts.Data, 1) + require.EqualValues(t, "world", accounts.Data[0].Address) + }) + t.Run("filter balance when specifying asset", func(t *testing.T) { + q := ledger.AccountsQuery{ + PageSize: 10, + Filters: ledger.AccountsQueryFilters{ + Balance: "0", + BalanceOperator: "gt", + BalanceAsset: "USD/2", + }, + } + + accounts, err := store.GetAccounts(context.Background(), q) + require.NoError(t, err, "balance filter should not fail") + require.Len(t, accounts.Data, 1) + require.EqualValues(t, "us_bank", accounts.Data[0].Address) + }) t.Run("panic invalid balance", func(t *testing.T) { q := ledger.AccountsQuery{ @@ -67,19 +153,16 @@ func testAccounts(t *testing.T, store *sqlstorage.Store) { }) t.Run("success get accounts with address filters", func(t *testing.T) { - err := store.Commit(context.Background(), tx1, tx2, tx3, tx4) - assert.NoError(t, err) - q := ledger.AccountsQuery{ PageSize: 10, Filters: ledger.AccountsQueryFilters{ - Address: "users:1", + Address: "us_bank", }, } accounts, err := store.GetAccounts(context.Background(), q) assert.NoError(t, err, "balance operator filter should not fail") assert.Equal(t, len(accounts.Data), 1) - assert.Equal(t, accounts.Data[0].Address, core.AccountAddress("users:1")) + assert.Equal(t, accounts.Data[0].Address, core.AccountAddress("us_bank")) }) }