From 1775ebd3bfa90a29653c16d99d4a0253ee86f3b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Thu, 2 Mar 2023 11:36:16 -0800 Subject: [PATCH] enforce account links to be private change the linkAccount function parameter type to PrivatePath --- docs/language/accounts.mdx | 1 + runtime/account_test.go | 105 ------------------ runtime/sema/authaccount_type.go | 2 +- runtime/tests/checker/account_test.go | 32 +++--- runtime/tests/interpreter/capability_test.go | 76 ++++--------- .../tests/interpreter/memory_metering_test.go | 10 +- 6 files changed, 48 insertions(+), 178 deletions(-) diff --git a/docs/language/accounts.mdx b/docs/language/accounts.mdx index d3cd026453..5e5ad01a98 100644 --- a/docs/language/accounts.mdx +++ b/docs/language/accounts.mdx @@ -137,6 +137,7 @@ to the `prepare` phase of the transaction. fun borrow(from: StoragePath): T? fun link(_ newCapabilityPath: CapabilityPath, target: Path): Capability? + fun linkAccount(_ newCapabilityPath: PrivatePath): Capability<&AuthAccount>? fun getCapability(_ path: CapabilityPath): Capability fun getLinkTarget(_ path: CapabilityPath): Path? fun unlink(_ path: CapabilityPath) diff --git a/runtime/account_test.go b/runtime/account_test.go index a5b82036fd..e952340444 100644 --- a/runtime/account_test.go +++ b/runtime/account_test.go @@ -2595,111 +2595,6 @@ func TestRuntimeAccountLink(t *testing.T) { assert.ErrorContains(t, err, "value of type `AuthAccount` has no member `linkAccount`") }) - t.Run("enabled, pragma", func(t *testing.T) { - - t.Parallel() - - runtime := NewInterpreterRuntime(Config{ - AtreeValidationEnabled: true, - AccountLinkingEnabled: true, - }) - - address1 := common.MustBytesToAddress([]byte{0x1}) - address2 := common.MustBytesToAddress([]byte{0x2}) - - accountCodes := map[Location][]byte{} - var logs []string - - signerAccount := address1 - - runtimeInterface := &testRuntimeInterface{ - getCode: func(location Location) (bytes []byte, err error) { - return accountCodes[location], nil - }, - storage: newTestLedger(nil, nil), - getSigningAccounts: func() ([]Address, error) { - return []Address{signerAccount}, nil - }, - resolveLocation: singleIdentifierLocationResolver(t), - getAccountContractCode: func(address Address, name string) (code []byte, err error) { - location := common.AddressLocation{ - Address: address, - Name: name, - } - return accountCodes[location], nil - }, - updateAccountContractCode: func(address Address, name string, code []byte) (err error) { - location := common.AddressLocation{ - Address: address, - Name: name, - } - accountCodes[location] = code - return nil - }, - log: func(message string) { - logs = append(logs, message) - }, - } - - nextTransactionLocation := newTransactionLocationGenerator() - - // Set up account - - setupTransaction := []byte(` - #allowAccountLinking - - transaction { - prepare(acct: AuthAccount) { - acct.linkAccount(/public/foo) - } - } - `) - - signerAccount = address1 - - err := runtime.ExecuteTransaction( - Script{ - Source: setupTransaction, - }, - Context{ - Interface: runtimeInterface, - Location: nextTransactionLocation(), - }, - ) - require.NoError(t, err) - - // Access - - accessTransaction := []byte(` - transaction { - prepare(acct: AuthAccount) { - let ref = getAccount(0x1) - .getCapability<&AuthAccount>(/public/foo) - .borrow()! - log(ref.address) - } - } - `) - - signerAccount = address2 - - err = runtime.ExecuteTransaction( - Script{ - Source: accessTransaction, - }, - Context{ - Interface: runtimeInterface, - Location: nextTransactionLocation(), - }, - ) - require.NoError(t, err) - - require.Equal(t, - []string{"0x0000000000000001"}, - logs, - ) - }) - t.Run("publish and claim", func(t *testing.T) { t.Parallel() diff --git a/runtime/sema/authaccount_type.go b/runtime/sema/authaccount_type.go index c8c03c80aa..df1ee6c736 100644 --- a/runtime/sema/authaccount_type.go +++ b/runtime/sema/authaccount_type.go @@ -80,7 +80,7 @@ var AuthAccountType = func() *CompositeType { { Label: ArgumentLabelNotRequired, Identifier: "newCapabilityPath", - TypeAnnotation: NewTypeAnnotation(CapabilityPathType), + TypeAnnotation: NewTypeAnnotation(PrivatePathType), }, }, ReturnTypeAnnotation: NewTypeAnnotation( diff --git a/runtime/tests/checker/account_test.go b/runtime/tests/checker/account_test.go index c67160ccb7..807dad6016 100644 --- a/runtime/tests/checker/account_test.go +++ b/runtime/tests/checker/account_test.go @@ -1163,27 +1163,29 @@ func TestCheckAccount_linkAccount(t *testing.T) { }, ) - if tc.enabled { - if tc.allowed { - switch tc.domain { - case common.PathDomainPrivate, common.PathDomainPublic: - require.NoError(t, err) - - default: - errs := RequireCheckerErrors(t, err, 1) + if !tc.enabled { + errs := RequireCheckerErrors(t, err, 1) - require.IsType(t, &sema.TypeMismatchError{}, errs[0]) - } - } else { - errs := RequireCheckerErrors(t, err, 1) + require.IsType(t, &sema.NotDeclaredMemberError{}, errs[0]) + return + } - require.IsType(t, &sema.NotDeclaredMemberError{}, errs[0]) - } - } else { + if !tc.allowed { errs := RequireCheckerErrors(t, err, 1) require.IsType(t, &sema.NotDeclaredMemberError{}, errs[0]) + + return } + + if tc.domain != common.PathDomainPrivate { + errs := RequireCheckerErrors(t, err, 1) + + require.IsType(t, &sema.TypeMismatchError{}, errs[0]) + return + } + + require.NoError(t, err) }) } diff --git a/runtime/tests/interpreter/capability_test.go b/runtime/tests/interpreter/capability_test.go index 2d6dd575cc..28517c2f1e 100644 --- a/runtime/tests/interpreter/capability_test.go +++ b/runtime/tests/interpreter/capability_test.go @@ -473,30 +473,26 @@ func TestInterpretCapability_borrow(t *testing.T) { ` #allowAccountLinking - fun link() { - account.linkAccount(/public/acct) + fun link(): Capability { + return account.linkAccount(/private/acct)! } - fun address(_ path: CapabilityPath): Address { - return account.getCapability(path).borrow<&AuthAccount>()!.address + fun address(_ cap: Capability): Address { + return cap.borrow<&AuthAccount>()!.address } - fun borrow(): Address { - return address(/public/acct) + fun borrow(_ cap: Capability): Address { + return address(cap) } - fun borrowAuth(): auth &AuthAccount? { - return account.getCapability(/public/acct).borrow() + fun borrowAuth(_ cap: Capability): auth &AuthAccount? { + return cap.borrow() } - fun nonExistent(): Address { - return address(/public/nonExistent) - } - - fun unlinkAfterBorrow(): Address { - let ref = account.getCapability(/public/acct).borrow<&AuthAccount>()! + fun unlinkAfterBorrow(_ cap: Capability): Address { + let ref = cap.borrow<&AuthAccount>()! - account.unlink(/public/acct) + account.unlink(/private/acct) return ref.address } @@ -508,12 +504,12 @@ func TestInterpretCapability_borrow(t *testing.T) { // link - _, err := inter.Invoke("link") + capability, err := inter.Invoke("link") require.NoError(t, err) t.Run("borrow", func(t *testing.T) { - value, err := inter.Invoke("borrow") + value, err := inter.Invoke("borrow", capability) require.NoError(t, err) RequireValuesEqual(t, @@ -525,23 +521,15 @@ func TestInterpretCapability_borrow(t *testing.T) { t.Run("borrowAuth", func(t *testing.T) { - value, err := inter.Invoke("borrowAuth") + value, err := inter.Invoke("borrowAuth", capability) require.NoError(t, err) require.Equal(t, interpreter.NilValue{}, value) }) - t.Run("nonExistent", func(t *testing.T) { - - _, err := inter.Invoke("nonExistent") - RequireError(t, err) - - require.ErrorAs(t, err, &interpreter.ForceNilError{}) - }) - t.Run("unlink after borrow", func(t *testing.T) { - _, err := inter.Invoke("unlinkAfterBorrow") + _, err := inter.Invoke("unlinkAfterBorrow", capability) RequireError(t, err) require.ErrorAs(t, err, &interpreter.DereferenceError{}) @@ -922,24 +910,16 @@ func TestInterpretCapability_check(t *testing.T) { ` #allowAccountLinking - fun link() { - account.linkAccount(/public/acct) - } - - fun checkPath(_ path: CapabilityPath): Bool { - return account.getCapability(path).check<&AuthAccount>() + fun link(): Capability { + return account.linkAccount(/private/acct)! } - fun check(): Bool { - return checkPath(/public/acct) + fun check(_ cap: Capability): Bool { + return cap.check<&AuthAccount>() } - fun checkAuth(): Bool { - return account.getCapability(/public/acct).check() - } - - fun nonExistent(): Bool { - return checkPath(/public/nonExistent) + fun checkAuth(_ cap: Capability): Bool { + return cap.check() } `, sema.Config{ @@ -949,12 +929,12 @@ func TestInterpretCapability_check(t *testing.T) { // link - _, err := inter.Invoke("link") + capability, err := inter.Invoke("link") require.NoError(t, err) t.Run("check", func(t *testing.T) { - value, err := inter.Invoke("check") + value, err := inter.Invoke("check", capability) require.NoError(t, err) require.Equal(t, interpreter.TrueValue, value) @@ -962,15 +942,7 @@ func TestInterpretCapability_check(t *testing.T) { t.Run("checkAuth", func(t *testing.T) { - value, err := inter.Invoke("checkAuth") - require.NoError(t, err) - - require.Equal(t, interpreter.FalseValue, value) - }) - - t.Run("nonExistent", func(t *testing.T) { - - value, err := inter.Invoke("nonExistent") + value, err := inter.Invoke("checkAuth", capability) require.NoError(t, err) require.Equal(t, interpreter.FalseValue, value) diff --git a/runtime/tests/interpreter/memory_metering_test.go b/runtime/tests/interpreter/memory_metering_test.go index 39762705cc..d6b09fc7ad 100644 --- a/runtime/tests/interpreter/memory_metering_test.go +++ b/runtime/tests/interpreter/memory_metering_test.go @@ -6976,7 +6976,7 @@ func TestInterpretStorageCapabilityValueMetering(t *testing.T) { pub fun main(account: AuthAccount) { let r <- create R() account.save(<-r, to: /storage/r) - let x = account.link<&R>(/public/capo, target: /storage/r) + let x = account.link<&R>(/public/cap, target: /storage/r) } ` meter := newTestMemoryGauge() @@ -7000,7 +7000,7 @@ func TestInterpretStorageCapabilityValueMetering(t *testing.T) { pub fun main(account: AuthAccount) { let r <- create R() account.save(<-r, to: /storage/r) - let x = account.link<&R>(/public/capo, target: /storage/r) + let x = account.link<&R>(/public/cap, target: /storage/r) let y = [x] } @@ -7026,7 +7026,7 @@ func TestInterpretPathLinkValueMetering(t *testing.T) { resource R {} pub fun main(account: AuthAccount) { - account.link<&R>(/public/capo, target: /private/p) + account.link<&R>(/public/cap, target: /private/p) } ` meter := newTestMemoryGauge() @@ -7052,7 +7052,7 @@ func TestInterpretAccountLinkValueMetering(t *testing.T) { #allowAccountLinking pub fun main(account: AuthAccount) { - account.linkAccount(/public/capo) + account.linkAccount(/private/cap) } ` @@ -8705,7 +8705,7 @@ func TestInterpretStorageMapMetering(t *testing.T) { pub fun main(account: AuthAccount) { let r <- create R() account.save(<-r, to: /storage/r) - account.link<&R>(/public/capo, target: /storage/r) + account.link<&R>(/public/cap, target: /storage/r) account.borrow<&R>(from: /storage/r) } `