diff --git a/components/ledger/internal/engine/command/commander.go b/components/ledger/internal/engine/command/commander.go index 13fc302ac0..2f41c8c897 100644 --- a/components/ledger/internal/engine/command/commander.go +++ b/components/ledger/internal/engine/command/commander.go @@ -17,7 +17,6 @@ import ( "github.com/formancehq/ledger/internal/bus" "github.com/formancehq/ledger/internal/engine/utils/batching" "github.com/formancehq/ledger/internal/machine/vm" - "github.com/formancehq/stack/libs/go-libs/collectionutils" "github.com/formancehq/stack/libs/go-libs/metadata" "github.com/pkg/errors" ) @@ -126,22 +125,13 @@ func (commander *Commander) exec(ctx context.Context, parameters Parameters, scr return nil, NewErrCompilationFailed(err) } - involvedAccounts, involvedSources, err := func() ([]string, []string, error) { - involvedAccounts, involvedSources, err := m.ResolveResources(ctx, commander.store) - if err != nil { - return nil, nil, NewErrCompilationFailed(err) - } - - return involvedAccounts, involvedSources, nil - }() + readLockAccounts, writeLockAccounts, err := m.ResolveResources(ctx, commander.store) if err != nil { - return nil, err + return nil, NewErrCompilationFailed(err) } - - worldFilter := collectionutils.FilterNot(collectionutils.FilterEq("world")) lockAccounts := Accounts{ - Read: collectionutils.Filter(involvedAccounts, worldFilter), - Write: collectionutils.Filter(involvedSources, worldFilter), + Read: readLockAccounts, + Write: writeLockAccounts, } unlock, err := func() (Unlock, error) { diff --git a/components/ledger/internal/machine/script/compiler/compiler.go b/components/ledger/internal/machine/script/compiler/compiler.go index 207673918c..7507963233 100644 --- a/components/ledger/internal/machine/script/compiler/compiler.go +++ b/components/ledger/internal/machine/script/compiler/compiler.go @@ -26,6 +26,15 @@ type parseVisitor struct { varIdx map[string]machine.Address // needBalances store for each account, the set of assets needed neededBalances map[machine.Address]map[machine.Address]struct{} + + // The sources accounts that aren't unbounded + // that is, @world or sources that appear within a + // '.. allowing unboundeed overdraft' clause + writeLockAccounts map[machine.Address]struct{} + + // all the accounts that appear in either the destination + // or in the balance() function + readLockAccounts map[machine.Address]struct{} } // Allocates constants if it hasn't already been, @@ -580,6 +589,7 @@ func (p *parseVisitor) VisitVars(c *parser.VarListDeclContext) *CompileError { Account: *accAddr, Asset: *assAddr, }) + p.readLockAccounts[*accAddr] = struct{}{} if err != nil { return LogicError(c, err) } @@ -674,12 +684,14 @@ func CompileFull(input string) CompileArtifacts { } visitor := parseVisitor{ - errListener: errListener, - instructions: make([]byte, 0), - resources: make([]program.Resource, 0), - varIdx: make(map[string]machine.Address), - neededBalances: make(map[machine.Address]map[machine.Address]struct{}), - sources: map[machine.Address]struct{}{}, + errListener: errListener, + instructions: make([]byte, 0), + resources: make([]program.Resource, 0), + varIdx: make(map[string]machine.Address), + neededBalances: make(map[machine.Address]map[machine.Address]struct{}), + sources: map[machine.Address]struct{}{}, + writeLockAccounts: map[machine.Address]struct{}{}, + readLockAccounts: map[machine.Address]struct{}{}, } err := visitor.VisitScript(tree) @@ -688,17 +700,24 @@ func CompileFull(input string) CompileArtifacts { return artifacts } - sources := make(machine.Addresses, 0) - for address := range visitor.sources { - sources = append(sources, address) + readLockAccounts := make(machine.Addresses, 0) + for address := range visitor.readLockAccounts { + readLockAccounts = append(readLockAccounts, address) + } + sort.Stable(readLockAccounts) + + writeLockAccounts := make(machine.Addresses, 0) + for address := range visitor.writeLockAccounts { + writeLockAccounts = append(writeLockAccounts, address) } - sort.Stable(sources) + sort.Stable(writeLockAccounts) artifacts.Program = &program.Program{ - Instructions: visitor.instructions, - Resources: visitor.resources, - NeededBalances: visitor.neededBalances, - Sources: sources, + Instructions: visitor.instructions, + Resources: visitor.resources, + NeededBalances: visitor.neededBalances, + ReadLockAccounts: readLockAccounts, + WriteLockAccounts: writeLockAccounts, } return artifacts diff --git a/components/ledger/internal/machine/script/compiler/compiler_test.go b/components/ledger/internal/machine/script/compiler/compiler_test.go index b8b02870d8..a3b694022d 100644 --- a/components/ledger/internal/machine/script/compiler/compiler_test.go +++ b/components/ledger/internal/machine/script/compiler/compiler_test.go @@ -1045,7 +1045,7 @@ func TestSetAccountMeta(t *testing.T) { program2.OP_ASSET, program2.OP_APUSH, 04, 00, program2.OP_MONETARY_NEW, - program2.OP_TAKE_ALL, + program2.OP_TAKE_ALWAYS, program2.OP_APUSH, 02, 00, program2.OP_TAKE_MAX, program2.OP_APUSH, 05, 00, @@ -1177,7 +1177,7 @@ func TestVariableBalance(t *testing.T) { program2.OP_ASSET, program2.OP_APUSH, 04, 00, program2.OP_MONETARY_NEW, - program2.OP_TAKE_ALL, + program2.OP_TAKE_ALWAYS, program2.OP_APUSH, 02, 00, program2.OP_TAKE_MAX, program2.OP_APUSH, 05, 00, diff --git a/components/ledger/internal/machine/script/compiler/destination.go b/components/ledger/internal/machine/script/compiler/destination.go index 6b941143ee..1b454c13fa 100644 --- a/components/ledger/internal/machine/script/compiler/destination.go +++ b/components/ledger/internal/machine/script/compiler/destination.go @@ -23,7 +23,7 @@ func (p *parseVisitor) VisitDestinationRecursive(c parser.IDestinationContext) * case *parser.DestAccountContext: p.AppendInstruction(program.OP_FUNDING_SUM) p.AppendInstruction(program.OP_TAKE) - ty, _, err := p.VisitExpr(c.Expression(), true) + ty, destAddr, err := p.VisitExpr(c.Expression(), true) if err != nil { return err } @@ -32,6 +32,9 @@ func (p *parseVisitor) VisitDestinationRecursive(c parser.IDestinationContext) * errors.New("wrong type: expected account as destination"), ) } + if !p.isWorld(*destAddr) { + p.readLockAccounts[*destAddr] = struct{}{} + } p.AppendInstruction(program.OP_SEND) return nil case *parser.DestInOrderContext: diff --git a/components/ledger/internal/machine/script/compiler/source.go b/components/ledger/internal/machine/script/compiler/source.go index 9db97a8d47..a3d3dc0fb3 100644 --- a/components/ledger/internal/machine/script/compiler/source.go +++ b/components/ledger/internal/machine/script/compiler/source.go @@ -99,6 +99,25 @@ func (p *parseVisitor) TakeFromSource(fallback *FallbackAccount) error { return nil } +func (p parseVisitor) isOverdraftUnbounded(overdraftCtx parser.ISourceAccountOverdraftContext) bool { + if overdraftCtx == nil { + return false + } + + switch overdraftCtx.(type) { + case *parser.SrcAccountOverdraftUnboundedContext: + return true + case *parser.SrcAccountOverdraftSpecificContext: + return false + + default: + // even though this branch should be unreachable, + // we default to `false` instead of panicking + // in order to have a more conservative behaviour + return false + } +} + // VisitSource returns the resource addresses of all the accounts, // the addresses of accounts already emptied, // and possibly a fallback account if the source has an unbounded overdraft allowance or contains @world @@ -129,7 +148,11 @@ func (p *parseVisitor) VisitSource(c parser.ISourceContext, pushAsset func(), is return nil, nil, nil, LogicError(c, err) } p.AppendInstruction(program.OP_MONETARY_NEW) - p.AppendInstruction(program.OP_TAKE_ALL) + if p.isWorld(*accAddr) { + p.AppendInstruction(program.OP_TAKE_ALWAYS) + } else { + p.AppendInstruction(program.OP_TAKE_ALL) + } } else { if p.isWorld(*accAddr) { return nil, nil, nil, LogicError(c, errors.New("@world is already set to an unbounded overdraft")) @@ -151,12 +174,18 @@ func (p *parseVisitor) VisitSource(c parser.ISourceContext, pushAsset func(), is return nil, nil, nil, LogicError(c, err) } p.AppendInstruction(program.OP_MONETARY_NEW) - p.AppendInstruction(program.OP_TAKE_ALL) + p.AppendInstruction(program.OP_TAKE_ALWAYS) f := FallbackAccount(*accAddr) fallback = &f } } - neededAccounts[*accAddr] = struct{}{} + + isUnboundedOverdraft := p.isWorld(*accAddr) || p.isOverdraftUnbounded(overdraft) + if !isUnboundedOverdraft { + p.writeLockAccounts[*accAddr] = struct{}{} + neededAccounts[*accAddr] = struct{}{} + } + emptiedAccounts[*accAddr] = struct{}{} if fallback != nil && isAll { diff --git a/components/ledger/internal/machine/vm/machine.go b/components/ledger/internal/machine/vm/machine.go index 276fe94496..c6b6123ec5 100644 --- a/components/ledger/internal/machine/vm/machine.go +++ b/components/ledger/internal/machine/vm/machine.go @@ -13,6 +13,7 @@ import ( "encoding/binary" "fmt" "math/big" + "slices" "github.com/formancehq/ledger/internal/machine" @@ -139,16 +140,16 @@ func (m *Machine) withdrawAlways(account machine.AccountAddress, mon machine.Mon if accBalance, ok := m.Balances[account]; ok { if balance, ok := accBalance[mon.Asset]; ok { accBalance[mon.Asset] = balance.Sub(mon.Amount) - return &machine.Funding{ - Asset: mon.Asset, - Parts: []machine.FundingPart{{ - Account: account, - Amount: mon.Amount, - }}, - }, nil } } - return nil, fmt.Errorf("missing %v balance from %v", mon.Asset, account) + + return &machine.Funding{ + Asset: mon.Asset, + Parts: []machine.FundingPart{{ + Account: account, + Amount: mon.Amount, + }}, + }, nil } func (m *Machine) credit(account machine.AccountAddress, funding machine.Funding) { @@ -170,8 +171,16 @@ func (m *Machine) repay(funding machine.Funding) { if part.Account == "world" { continue } - balance := m.Balances[part.Account][funding.Asset] - m.Balances[part.Account][funding.Asset] = balance.Add(part.Amount) + accountBalance, ok := m.Balances[part.Account] + if !ok { + // no asset: the source has to be an unbounded source + // which NEVER appears as bounded + // this means we don't need to track it's balance + continue + } + + balance := accountBalance[funding.Asset] + accountBalance[funding.Asset] = balance.Add(part.Amount) } } @@ -574,6 +583,9 @@ func (m *Machine) ResolveResources(ctx context.Context, store Store) ([]string, if err != nil { return nil, nil, err } + if val.GetType() == machine.TypeAccount { + involvedAccountsMap[machine.Address(idx)] = string(val.(machine.AccountAddress)) + } case program.VariableAccountBalance: acc, _ := m.getResource(res.Account) address := string((*acc).(machine.AccountAddress)) @@ -607,16 +619,19 @@ func (m *Machine) ResolveResources(ctx context.Context, store Store) ([]string, m.Resources = append(m.Resources, val) } - involvedAccounts := make([]string, 0) - involvedSources := make([]string, 0) - for _, accountAddress := range involvedAccountsMap { - involvedAccounts = append(involvedAccounts, accountAddress) + readLockAccounts := make([]string, 0) + for _, accountAddress := range m.Program.ReadLockAccounts { + readLockAccounts = append(readLockAccounts, involvedAccountsMap[accountAddress]) } - for _, machineAddress := range m.Program.Sources { - involvedSources = append(involvedSources, involvedAccountsMap[machineAddress]) + + writeLockAccounts := make([]string, 0) + for _, machineAddress := range m.Program.WriteLockAccounts { + writeLockAccounts = append(writeLockAccounts, involvedAccountsMap[machineAddress]) } - return involvedAccounts, involvedSources, nil + slices.Sort(readLockAccounts) + slices.Sort(writeLockAccounts) + return readLockAccounts, writeLockAccounts, nil } func (m *Machine) SetVarsFromJSON(vars map[string]string) error { diff --git a/components/ledger/internal/machine/vm/machine_test.go b/components/ledger/internal/machine/vm/machine_test.go index c60b73d30f..e2de91469c 100644 --- a/components/ledger/internal/machine/vm/machine_test.go +++ b/components/ledger/internal/machine/vm/machine_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math/big" + "slices" "sync" "testing" @@ -528,6 +529,59 @@ func TestWorldSource(t *testing.T) { test(t, tc) } +func TestUnboundedSourceSimple(t *testing.T) { + tc := NewTestCase() + tc.compile(t, `send [GEM 15] ( + source = @unbounded allowing unbounded overdraft + destination = @b + )`) + tc.expected = CaseResult{ + Printed: []machine.Value{}, + Postings: []Posting{ + + { + Asset: "GEM", + Amount: machine.NewMonetaryInt(15), + Source: "unbounded", + Destination: "b", + }, + }, + Error: nil, + } + test(t, tc) +} + +func TestUnboundedSourceInorder(t *testing.T) { + tc := NewTestCase() + tc.compile(t, `send [GEM 15] ( + source = { + @a + @unbounded allowing unbounded overdraft + } + destination = @b + )`) + tc.setBalance("a", "GEM", 1) + tc.expected = CaseResult{ + Printed: []machine.Value{}, + Postings: []Posting{ + { + Asset: "GEM", + Amount: machine.NewMonetaryInt(1), + Source: "a", + Destination: "b", + }, + { + Asset: "GEM", + Amount: machine.NewMonetaryInt(14), + Source: "unbounded", + Destination: "b", + }, + }, + Error: nil, + } + test(t, tc) +} + func TestNoEmptyPostings(t *testing.T) { tc := NewTestCase() tc.compile(t, `send [GEM 2] ( @@ -924,11 +978,23 @@ func TestNeededBalances(t *testing.T) { } send [GEM 15] ( source = { + // normal accounts are tracked $a @b - @world + + // we don't want to track world, as it is an unbounded account + max [GEM 1] from @world + + // we want to lock bounded overdrafts account + @bounded allowing overdraft up to [GEM 1] + + // we don't want to lock unbounded overdrafts account + @unb allowing unbounded overdraft + } + destination = { + max [GEM 1] to @c + remaining to @world } - destination = @c )`) if err != nil { @@ -943,11 +1009,102 @@ func TestNeededBalances(t *testing.T) { if err != nil { t.Fatalf("did not expect error on SetVars, got: %v", err) } - _, _, err = m.ResolveResources(context.Background(), EmptyStore) + readLockAccounts, writeLockAccounts, err := m.ResolveResources(context.Background(), EmptyStore) require.NoError(t, err) + require.Equalf(t, []string{"c"}, readLockAccounts, "readlock") + require.Equalf(t, []string{"a", "b", "bounded"}, writeLockAccounts, "writelock") - err = m.ResolveBalances(context.Background(), EmptyStore) + store := mockStore{} + err = m.ResolveBalances(context.Background(), &store) require.NoError(t, err) + + require.Equal(t, []string{"a", "b", "bounded"}, store.GetRequestedAccounts()) +} + +func TestNeededBalances2(t *testing.T) { + p, err := compiler.Compile(` + send [GEM 15] ( + source = { + // we want to track a balance even if it appears later on + // as an unbounded overdraft + max [GEM 1] from @a + @a allowing unbounded overdraft + } + destination = @c + )`) + + if err != nil { + t.Fatalf("did not expect error on Compile, got: %v", err) + } + + m := NewMachine(*p) + _, involvedSources, err := m.ResolveResources(context.Background(), EmptyStore) + require.NoError(t, err) + require.Equal(t, []string{"a"}, involvedSources) + +} + +func TestNeededBalancesBalanceFn(t *testing.T) { + p, err := compiler.Compile(`vars { + monetary $balance = balance(@acc, COIN) +} + +send $balance ( + source = @a + destination = @b +)`) + + if err != nil { + t.Fatalf("did not expect error on Compile, got: %v", err) + } + + m := NewMachine(*p) + rlAccounts, wlAccounts, err := m.ResolveResources(context.Background(), EmptyStore) + require.NoError(t, err) + require.Equal(t, []string{"a"}, wlAccounts) + require.Equal(t, []string{"acc", "b"}, rlAccounts) + + store := mockStore{} + err = m.ResolveBalances(context.Background(), &store) + require.NoError(t, err) + require.Equal(t, []string{"a", "acc"}, store.GetRequestedAccounts()) +} + +func TestNeededBalancesBalanceOfMeta(t *testing.T) { + p, err := compiler.Compile(`vars { + account $src = meta(@x, "k") +} + +send [COIN 1] ( + source = $src + destination = @dest +)`) + + if err != nil { + t.Fatalf("did not expect error on Compile, got: %v", err) + } + m := NewMachine(*p) + + staticStore := StaticStore{ + "x": &AccountWithBalances{ + Account: ledger.Account{ + Address: "x", + Metadata: metadata.Metadata{ + "k": "src", + }, + }, + Balances: map[string]*big.Int{}, + }, + } + rlAccounts, wlAccounts, err := m.ResolveResources(context.Background(), staticStore) + require.NoError(t, err) + require.Equal(t, []string{"src"}, wlAccounts) + require.Equal(t, []string{"dest"}, rlAccounts) + + store := mockStore{} + err = m.ResolveBalances(context.Background(), &store) + require.NoError(t, err) + require.Equal(t, []string{"src"}, store.GetRequestedAccounts()) } func TestSetTxMeta(t *testing.T) { @@ -2181,3 +2338,80 @@ send [COIN 100] ( } test(t, tc) } + +func TestRepayUnboundedMinimal(t *testing.T) { + tc := NewTestCase() + + tc.compile(t, ` +send [COIN 100]( + source = @src allowing unbounded overdraft + destination = { + max [COIN 1] to @d1 + remaining to @d2 + } + ) +`) + + tc.expected = CaseResult{ + Printed: []machine.Value{}, + Postings: []Posting{ + {"src", "d1", machine.NewMonetaryInt(1), "COIN"}, + {"src", "d2", machine.NewMonetaryInt(99), "COIN"}, + }, + } + test(t, tc) +} + +func TestRepayUnboundedComplex(t *testing.T) { + tc := NewTestCase() + + tc.compile(t, ` +send [EGP 86640]( + source = { + max [EGP 86640] from @asset:current_assets allowing unbounded overdraft + } + destination = { + max [EGP 86466] to @liability:client_balances + max [EGP 9] to @liability:current_liabilities:1 + max [EGP 9] to @liability:current_liabilities:2 + max [EGP 100] to @liability:current_liabilities:3 + max [EGP 4] to @liability:current_liabilities:4 + max [EGP 43] to @liability:current_liabilities:checks:5 + remaining to @liability:current_liabilities:6 + } + ) + +`) + + tc.expected = CaseResult{ + Printed: []machine.Value{}, + Postings: []Posting{ + {"asset:current_assets", "liability:client_balances", machine.NewMonetaryInt(86466), "EGP"}, + {"asset:current_assets", "liability:current_liabilities:1", machine.NewMonetaryInt(9), "EGP"}, + {"asset:current_assets", "liability:current_liabilities:2", machine.NewMonetaryInt(9), "EGP"}, + {"asset:current_assets", "liability:current_liabilities:3", machine.NewMonetaryInt(100), "EGP"}, + {"asset:current_assets", "liability:current_liabilities:4", machine.NewMonetaryInt(4), "EGP"}, + {"asset:current_assets", "liability:current_liabilities:checks:5", machine.NewMonetaryInt(43), "EGP"}, + {"asset:current_assets", "liability:current_liabilities:6", machine.NewMonetaryInt(9), "EGP"}, + }, + } + test(t, tc) +} + +type mockStore struct { + requestedAccounts []string +} + +func (s *mockStore) GetRequestedAccounts() []string { + slices.Sort(s.requestedAccounts) + return s.requestedAccounts +} + +func (s *mockStore) GetBalance(ctx context.Context, address, asset string) (*big.Int, error) { + s.requestedAccounts = append(s.requestedAccounts, address) + return big.NewInt(0), nil +} + +func (s *mockStore) GetAccount(ctx context.Context, address string) (*ledger.Account, error) { + panic("not implemented") +} diff --git a/components/ledger/internal/machine/vm/program/program.go b/components/ledger/internal/machine/vm/program/program.go index 882b8202bb..fb93e8f9ae 100644 --- a/components/ledger/internal/machine/vm/program/program.go +++ b/components/ledger/internal/machine/vm/program/program.go @@ -12,8 +12,10 @@ import ( type Program struct { Instructions []byte Resources []Resource - Sources []machine.Address NeededBalances map[machine.Address]map[machine.Address]struct{} + + ReadLockAccounts []machine.Address + WriteLockAccounts []machine.Address } func (p Program) String() string {