Skip to content

Commit

Permalink
fix(balances): missing delimiters for regexp (#469)
Browse files Browse the repository at this point in the history
* fix(balances): missing delimiters for regexp
  • Loading branch information
paul-nicolas committed Oct 17, 2023
1 parent 0a8144b commit 93ed88f
Show file tree
Hide file tree
Showing 6 changed files with 470 additions and 272 deletions.
6 changes: 3 additions & 3 deletions pkg/ledgertesting/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func StorageDriver(multipleInstance bool) (*sqlstorage.Driver, func(), error) {
return nil, nil, errors.New("not found driver")
}

func ProvideStorageDriver() fx.Option {
func ProvideStorageDriver(withMultipleInstance bool) fx.Option {
return fx.Provide(func(lc fx.Lifecycle) (*sqlstorage.Driver, error) {
driver, stopFn, err := StorageDriver(false)
driver, stopFn, err := StorageDriver(withMultipleInstance)
if err != nil {
return nil, err
}
Expand All @@ -66,7 +66,7 @@ func ProvideStorageDriver() fx.Option {

func ProvideLedgerStorageDriver() fx.Option {
return fx.Options(
ProvideStorageDriver(),
ProvideStorageDriver(false),
fx.Provide(
fx.Annotate(sqlstorage.NewLedgerStorageDriverFromRawDriver,
fx.As(new(storage.Driver[ledger.Store]))),
Expand Down
67 changes: 21 additions & 46 deletions pkg/storage/sqlstorage/accounts_test.go
Original file line number Diff line number Diff line change
@@ -1,58 +1,16 @@
package sqlstorage
package sqlstorage_test

import (
"context"
"os"
"testing"

"github.com/numary/ledger/pkg/core"
"github.com/numary/ledger/pkg/ledger"
"github.com/pborman/uuid"
"github.com/numary/ledger/pkg/storage/sqlstorage"
"github.com/stretchr/testify/assert"
)

func TestAccounts(t *testing.T) {
d := NewDriver("sqlite", &sqliteDB{
directory: os.TempDir(),
dbName: uuid.New(),
}, false)

assert.NoError(t, d.Initialize(context.Background()))

defer func(d *Driver, ctx context.Context) {
assert.NoError(t, d.Close(ctx))
}(d, context.Background())

store, _, err := d.GetLedgerStore(context.Background(), "foo", true)
assert.NoError(t, err)

_, err = store.Initialize(context.Background())
assert.NoError(t, err)

accountTests(t, store)
}

func TestAccountsMultipleInstance(t *testing.T) {
d := NewDriver("sqlite", &sqliteDB{
directory: os.TempDir(),
dbName: uuid.New(),
}, true)

assert.NoError(t, d.Initialize(context.Background()))

defer func(d *Driver, ctx context.Context) {
assert.NoError(t, d.Close(ctx))
}(d, context.Background())

store, _, err := d.GetLedgerStore(context.Background(), "foo", true)
assert.NoError(t, err)

_, err = store.Initialize(context.Background())
assert.NoError(t, err)

accountTests(t, store)
}

func accountTests(t *testing.T, store *Store) {
func testAccounts(t *testing.T, store *sqlstorage.Store) {
t.Run("success balance", func(t *testing.T) {
q := ledger.AccountsQuery{
PageSize: 10,
Expand Down Expand Up @@ -107,4 +65,21 @@ func accountTests(t *testing.T, store *Store) {
_, err := store.GetAccounts(context.Background(), q)
assert.NoError(t, err, "balance operator filter should not fail")
})

t.Run("success get accounts with address filters", func(t *testing.T) {
err := store.Commit(context.Background(), tx1, tx2, tx3, tx4)
assert.NoError(t, err)

q := ledger.AccountsQuery{
PageSize: 10,
Filters: ledger.AccountsQueryFilters{
Address: "users:1",
},
}

accounts, err := store.GetAccounts(context.Background(), q)
assert.NoError(t, err, "balance operator filter should not fail")
assert.Equal(t, len(accounts.Data), 1)
assert.Equal(t, accounts.Data[0].Address, core.AccountAddress("users:1"))
})
}
2 changes: 1 addition & 1 deletion pkg/storage/sqlstorage/balances.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (s *Store) GetBalances(ctx context.Context, q ledger.BalancesQuery) (api.Cu
continue
}

arg := sb.Args.Add(strings.ReplaceAll(segment, "\\", "\\\\"))
arg := sb.Args.Add("^" + strings.ReplaceAll(segment, "\\", "\\\\") + "$")
sb.Where(fmt.Sprintf("account_json @@ ('$[%d] like_regex \"' || %s::text || '\"')::jsonpath", i, arg))
}
} else {
Expand Down
47 changes: 46 additions & 1 deletion pkg/storage/sqlstorage/balances_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package sqlstorage_test

import (
"context"
"github.com/numary/ledger/pkg/ledgertesting"
"os"
"testing"
"time"

"github.com/numary/ledger/pkg/ledgertesting"

"github.com/numary/ledger/pkg/core"
"github.com/numary/ledger/pkg/ledger"
"github.com/numary/ledger/pkg/storage/sqlstorage"
Expand Down Expand Up @@ -160,6 +161,33 @@ func testGetBalances(t *testing.T, store *sqlstorage.Store) {
})
}

func testGetBalancesOn1Account(t *testing.T, store *sqlstorage.Store) {
err := store.Commit(context.Background(), tx1, tx2, tx3, tx4)
assert.NoError(t, err)

t.Run("on 1 accounts", func(t *testing.T) {
cursor, err := store.GetBalances(context.Background(),
ledger.BalancesQuery{
Filters: ledger.BalancesQueryFilters{
AddressRegexp: []string{"users:1"},
},
PageSize: 10,
})
assert.NoError(t, err)
assert.Equal(t, 10, cursor.PageSize)
assert.Equal(t, false, cursor.HasMore)
assert.Equal(t, "", cursor.Previous)
assert.Equal(t, "", cursor.Next)
assert.Equal(t, []core.AccountsBalances{
{
"users:1": core.AssetsBalances{
"USD": core.NewMonetaryInt(1),
},
},
}, cursor.Data)
})
}

func testGetBalancesAggregated(t *testing.T, store *sqlstorage.Store) {
err := store.Commit(context.Background(), tx1, tx2, tx3)
assert.NoError(t, err)
Expand All @@ -174,6 +202,23 @@ func testGetBalancesAggregated(t *testing.T, store *sqlstorage.Store) {
}, cursor)
}

func testGetBalancesAggregatedByAccount(t *testing.T, store *sqlstorage.Store) {
err := store.Commit(context.Background(), tx1, tx2, tx3, tx4)
assert.NoError(t, err)

q := ledger.AggregatedBalancesQuery{
PageSize: 10,
Filters: ledger.AggregatedBalancesQueryFilters{
AddressRegexp: "users:1",
},
}
cursor, err := store.GetBalancesAggregated(context.Background(), q)
assert.NoError(t, err)
assert.Equal(t, core.AssetsBalances{
"USD": core.NewMonetaryInt(1),
}, cursor)
}

func testGetBalancesBigInts(t *testing.T, store *sqlstorage.Store) {

if os.Getenv("NUMARY_STORAGE_POSTGRES_CONN_STRING") != "" ||
Expand Down
Loading

0 comments on commit 93ed88f

Please sign in to comment.