Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove IncreaseAllowance & add shared nonce #1

Merged
merged 4 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions channel/conclude.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const (
// - it searches for a past concluded event by calling `isConcluded`
// - if found, channel is already concluded and success is returned
// - if none found, conclude/concludeFinal is called on the adjudicator
//
// - it waits for a Concluded event from the blockchain.
func (a *Adjudicator) ensureConcluded(ctx context.Context, req channel.AdjudicatorReq, subStates channel.StateMap) error {
// Check whether it is already concluded.
Expand Down
34 changes: 27 additions & 7 deletions channel/contractbackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ const (
// create a TxTimedoutError with additional context.
var errTxTimedOut = errors.New("")

// SharedExpected Nonce is a map of each expected next nonce of all clients.
var (
SharedExpectedNonces map[ChainID]map[common.Address]uint64
SharedNonceMtx map[ChainID]map[common.Address]*sync.Mutex
)

// SharedMutex controls the reads and writes on the nonceMtx and ecpectedNextNonce of the ContractBackend.
var SharedMutex = &sync.Mutex{}

// ContractInterface provides all functions needed by an ethereum backend.
// Both test.SimulatedBackend and ethclient.Client implement this interface.
type ContractInterface interface {
Expand All @@ -63,7 +72,7 @@ type Transactor interface {
type ContractBackend struct {
ContractInterface
tr Transactor
nonceMtx *sync.Mutex
nonceMtx map[common.Address]*sync.Mutex
expectedNextNonce map[common.Address]uint64
txFinalityDepth uint64
chainID ChainID
Expand All @@ -73,11 +82,24 @@ type ContractBackend struct {
// txFinalityDepth defines in how many consecutive blocks a TX has to be
// included to be considered final. Must be at least 1.
func NewContractBackend(cf ContractInterface, chainID ChainID, tr Transactor, txFinalityDepth uint64) ContractBackend {
// Check if the shared maps are initialized, if not, initialize them.
if SharedExpectedNonces == nil {
SharedExpectedNonces = make(map[ChainID]map[common.Address]uint64)
}
if SharedNonceMtx == nil {
SharedNonceMtx = make(map[ChainID]map[common.Address]*sync.Mutex)
}

// Check if the specific chainID entry exists in the shared maps, if not, create it.
if _, exists := SharedExpectedNonces[chainID]; !exists {
SharedExpectedNonces[chainID] = make(map[common.Address]uint64)
SharedNonceMtx[chainID] = make(map[common.Address]*sync.Mutex)
}
return ContractBackend{
ContractInterface: cf,
tr: tr,
expectedNextNonce: make(map[common.Address]uint64),
nonceMtx: &sync.Mutex{},
expectedNextNonce: SharedExpectedNonces[chainID],
nonceMtx: SharedNonceMtx[chainID],
txFinalityDepth: txFinalityDepth,
chainID: chainID,
}
Expand Down Expand Up @@ -165,10 +187,7 @@ func (c *ContractBackend) nonce(ctx context.Context, sender common.Address) (uin
err = cherrors.CheckIsChainNotReachableError(err)
return 0, errors.WithMessage(err, "fetching nonce")
}

// Look up expected next nonce locally.
c.nonceMtx.Lock()
defer c.nonceMtx.Unlock()
SharedMutex.Lock()
expectedNextNonce, found := c.expectedNextNonce[sender]
if !found {
c.expectedNextNonce[sender] = 0
Expand All @@ -181,6 +200,7 @@ func (c *ContractBackend) nonce(ctx context.Context, sender common.Address) (uin

// Update local expectation.
c.expectedNextNonce[sender] = nonce + 1
SharedMutex.Unlock()
return nonce, nil
}

Expand Down
135 changes: 118 additions & 17 deletions channel/erc20_depositor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@ package channel

import (
"context"
"fmt"
"math/big"
"sync"

"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/pkg/errors"
"perun.network/go-perun/log"

"github.com/perun-network/perun-eth-backend/bindings/assetholdererc20"
"github.com/perun-network/perun-eth-backend/bindings/peruntoken"
Expand All @@ -40,42 +45,138 @@ const ERC20DepositorTXGasLimit = 100000
// Return value of ERC20Depositor.NumTx.
const erc20DepositorNumTx = 2

// Keep track of the increase allowance and deposit processes.
var mu sync.Mutex
var locks = make(map[string]*sync.Mutex)

// DepositResult is created to keep track of the returned values.
type DepositResult struct {
Transactions types.Transactions
Error error
}

// Create key from account address and asset to only lock the process when hub deposits the same asset at the same time.
func lockKey(account common.Address, asset common.Address) string {
return fmt.Sprintf("%s-%s", account.Hex(), asset.Hex())
}

// Retrieves Lock for specific key.
func handleLock(lockKey string) *sync.Mutex {
mu.Lock()
defer mu.Unlock()

if lock, exists := locks[lockKey]; exists {
return lock
}

lock := &sync.Mutex{}
locks[lockKey] = lock
return lock
}

// Locks the lock argument, runs the given function and then unlocks the lock argument.
func lockAndUnlock(lock *sync.Mutex, fn func()) {
mu.Lock()
defer mu.Unlock()
lock.Lock()
defer lock.Unlock()
fn()
}

// NewERC20Depositor creates a new ERC20Depositor.
func NewERC20Depositor(token common.Address) *ERC20Depositor {
return &ERC20Depositor{Token: token}
}

// Deposit deposits ERC20 tokens into the ERC20 AssetHolder specified at the
// request's asset address.
// Deposit approves the value to be swapped and calls DepositOnly.
//
//nolint:funlen
func (d *ERC20Depositor) Deposit(ctx context.Context, req DepositReq) (types.Transactions, error) {
// Bind a `AssetHolderERC20` instance.
assetholder, err := assetholdererc20.NewAssetholdererc20(req.Asset.EthAddress(), req.CB)
if err != nil {
return nil, errors.Wrapf(err, "binding AssetHolderERC20 contract at: %x", req.Asset)
}
lockKey := lockKey(req.Account.Address, req.Asset.EthAddress())
lock := handleLock(lockKey)

// Bind an `ERC20` instance.
token, err := peruntoken.NewPeruntoken(d.Token, req.CB)
if err != nil {
return nil, errors.Wrapf(err, "binding ERC20 contract at: %x", d.Token)
}
// Increase the allowance.
opts, err := req.CB.NewTransactor(ctx, ERC20DepositorTXGasLimit, req.Account)
if err != nil {
return nil, errors.WithMessagef(err, "creating transactor for asset: %x", req.Asset)
callOpts := bind.CallOpts{
Pending: false,
Context: ctx,
}
tx1, err := token.IncreaseAllowance(opts, req.Asset.EthAddress(), req.Balance)
// variables for the return value.
var depResult DepositResult
var approvalReceived bool
var tx1 *types.Transaction
var err1 error
lockAndUnlock(lock, func() {
allowance, err := token.Allowance(&callOpts, req.Account.Address, req.Asset.EthAddress())
if err != nil {
depResult.Transactions = nil
depResult.Error = errors.WithMessagef(err, "could not get Allowance for asset: %x", req.Asset)
}
result := new(big.Int).Add(req.Balance, allowance)

// Increase the allowance.
opts, err := req.CB.NewTransactor(ctx, ERC20DepositorTXGasLimit, req.Account)
if err != nil {
depResult.Transactions = nil
depResult.Error = errors.WithMessagef(err, "creating transactor for asset: %x", req.Asset)
}
// Create a channel for receiving PeruntokenApproval events
eventSink := make(chan *peruntoken.PeruntokenApproval)

// Create a channel for receiving the Approval event
eventReceived := make(chan bool)

// Watch for Approval events and send them to the eventSink
subscription, err := token.WatchApproval(&bind.WatchOpts{Start: nil, Context: ctx}, eventSink, []common.Address{req.Account.Address}, []common.Address{req.Asset.EthAddress()})
if err != nil {
depResult.Transactions = nil
depResult.Error = errors.WithMessagef(err, "Cannot listen for event")
}
tx1, err1 = token.Approve(opts, req.Asset.EthAddress(), result)
if err1 != nil {
err = cherrors.CheckIsChainNotReachableError(err)
depResult.Transactions = nil
depResult.Error = errors.WithMessagef(err, "increasing allowance for asset: %x", req.Asset)
}

go func() {
select {
case event := <-eventSink:
log.Printf("Received Approval event: Owner: %s, Spender: %s, Value: %s\n", event.Owner.Hex(), event.Spender.Hex(), event.Value.String())
eventReceived <- true
case err := <-subscription.Err():
log.Println("Subscription error:", err)
}
}()
approvalReceived = <-eventReceived
})
if approvalReceived {
tx2, err := d.DepositOnly(ctx, req)
depResult.Transactions = []*types.Transaction{tx1, tx2}
depResult.Error = errors.WithMessage(err, "AssetHolderERC20 depositing")
}
return depResult.Transactions, depResult.Error
}

// DepositOnly deposits ERC20 tokens into the ERC20 AssetHolder specified at the
// requests asset address.
func (d *ERC20Depositor) DepositOnly(ctx context.Context, req DepositReq) (*types.Transaction, error) {
// Bind a `AssetHolderERC20` instance.
assetholder, err := assetholdererc20.NewAssetholdererc20(req.Asset.EthAddress(), req.CB)
if err != nil {
err = cherrors.CheckIsChainNotReachableError(err)
return nil, errors.WithMessagef(err, "increasing allowance for asset: %x", req.Asset)
return nil, errors.Wrapf(err, "binding AssetHolderERC20 contract at: %x", req.Asset)
}
// Deposit.
opts, err = req.CB.NewTransactor(ctx, ERC20DepositorTXGasLimit, req.Account)
opts, err := req.CB.NewTransactor(ctx, ERC20DepositorTXGasLimit, req.Account)
if err != nil {
return nil, errors.WithMessagef(err, "creating transactor for asset: %x", req.Asset)
}

tx2, err := assetholder.Deposit(opts, req.FundingID, req.Balance)
err = cherrors.CheckIsChainNotReachableError(err)
return []*types.Transaction{tx1, tx2}, errors.WithMessage(err, "AssetHolderERC20 depositing")
return tx2, err
}

// NumTX returns 2 since it does IncreaseAllowance and Deposit.
Expand Down
4 changes: 2 additions & 2 deletions channel/funder.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ func (f *Funder) Fund(ctx context.Context, request channel.FundingReq) error {
nonFundingErrg := perror.NewGatherer()
for _, err := range perror.Causes(errg.Wait()) {
if channel.IsAssetFundingError(err) && err != nil {
fudingErr, ok := err.(*channel.AssetFundingError)
fundingErr, ok := err.(*channel.AssetFundingError)
if !ok {
return fmt.Errorf("wrong type: expected %T, got %T", &channel.AssetFundingError{}, err)
}
fundingErrs = append(fundingErrs, fudingErr)
fundingErrs = append(fundingErrs, fundingErr)
} else if err != nil {
nonFundingErrg.Add(err)
}
Expand Down
13 changes: 5 additions & 8 deletions channel/withdraw.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,19 @@ import (
"perun.network/go-perun/log"
)

// Withdraw ensures that a channel has been concluded and the final outcome
// Withdraw ensures that a channel has been concluded and the final outcome.
// withdrawn from the asset holders.
func (a *Adjudicator) Withdraw(ctx context.Context, req channel.AdjudicatorReq, subStates channel.StateMap) error {
if err := a.ensureConcluded(ctx, req, subStates); err != nil {
return errors.WithMessage(err, "ensure Concluded")
}

if err := a.checkConcludedState(ctx, req, subStates); err != nil {
return errors.WithMessage(err, "check concluded state")
}

return errors.WithMessage(a.ensureWithdrawn(ctx, req), "ensure Withdrawn")
}

// ensureWithdrawn ensures that the channel has been withdrawn from the asset.
func (a *Adjudicator) ensureWithdrawn(ctx context.Context, req channel.AdjudicatorReq) error {
g, ctx := errgroup.WithContext(ctx)

Expand All @@ -75,7 +74,7 @@ func (a *Adjudicator) ensureWithdrawn(ctx context.Context, req channel.Adjudicat
}
defer sub.Close()

// Check for past event
// Check for past event.
if err := sub.ReadPast(ctx, events); err != nil {
return errors.WithMessage(err, "reading past events")
}
Expand All @@ -90,7 +89,7 @@ func (a *Adjudicator) ensureWithdrawn(ctx context.Context, req channel.Adjudicat
return errors.WithMessage(err, "withdrawing assets failed")
}

// Wait for event
// Wait for event.
go func() {
subErr <- sub.Read(ctx, events)
}()
Expand Down Expand Up @@ -146,13 +145,11 @@ func (a *Adjudicator) callAssetWithdraw(ctx context.Context, request channel.Adj
if err != nil {
return nil, errors.WithMessagef(err, "creating transactor for asset %d", asset.assetIndex)
}

tx, err := asset.Withdraw(trans, auth, sig)
if err != nil {
err = cherrors.CheckIsChainNotReachableError(err)
return nil, errors.WithMessagef(err, "withdrawing asset %d", asset.assetIndex)
return nil, errors.WithMessagef(err, "withdrawing asset %d with transaction nonce %d", asset.assetIndex, trans.Nonce)
}
log.Debugf("Sent transaction %v", tx.Hash().Hex())
return tx, nil
}()
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion client/fund_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
func TestFundRecovery(t *testing.T) {
rng := test.Prng(t)

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

ctest.TestFundRecovery(
Expand Down
2 changes: 1 addition & 1 deletion client/payment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import (
)

const (
twoPartyTestTimeout = 10 * time.Second
twoPartyTestTimeout = 60 * time.Second
TxFinalityDepth = 3
)

Expand Down
6 changes: 3 additions & 3 deletions client/test/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import (

const (
// DefaultTimeout is the default timeout for client tests.
DefaultTimeout = 5 * time.Second
DefaultTimeout = 20 * time.Second
// BlockInterval is the default block interval for the simulated chain.
BlockInterval = 100 * time.Millisecond
BlockInterval = 200 * time.Millisecond
// challenge duration in blocks that is used by MakeRoleSetups.
challengeDurationBlocks = 60
challengeDurationBlocks = 90
)

// MakeRoleSetups creates a two party client test setup with the provided names.
Expand Down
Loading