Skip to content

Commit

Permalink
Fix built-in contract imports (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
bartolomej authored Sep 18, 2024
1 parent dcc100e commit d3bdc58
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 15 deletions.
10 changes: 7 additions & 3 deletions internal/v1_1/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,14 @@ func (g Generator) processDependencies(ctx context.Context, program *ast.Program
// fill in dependence information
g.template.Data.Dependencies = make([]Dependency, 0)
for _, imp := range imports {
contractName, err := ExtractContractName(imp.String())
if err != nil {
return err

// Built-in contracts imports are represented with identifier location
_, isBuiltInContract := imp.Location.(common.IdentifierLocation)
if isBuiltInContract {
continue
}

contractName := imp.Location.String()
networks, err := g.generateDependenceInfo(ctx, contractName)
if err != nil {
return err
Expand Down
58 changes: 58 additions & 0 deletions internal/v1_1/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,64 @@ func TestHelloScript(t *testing.T) {

}

func TestValidImports(t *testing.T) {
contracts := []Contract{
{
Contract: "Alice",
Networks: []Network{
{
Network: "testnet",
Address: "0x0000000000000001",
},
{
Network: "mainnet",
Address: "0x0000000000000001",
},
{
Network: "emulator",
Address: "0x0000000000000001",
},
},
},
{
Contract: "Bob",
Networks: []Network{
{
Network: "testnet",
Address: "0x0000000000000002",
},
{
Network: "mainnet",
Address: "0x0000000000000002",
},
{
Network: "emulator",
Address: "0x0000000000000002",
},
},
},
}

generator := Generator{
deployedContracts: contracts,
}

assert := assert.New(t)
code := `
import "Alice"
import Bob from 0x0000000000000002
import Joe
access(all)
fun main(): Void {}
`
ctx := context.Background()
template, err := generator.CreateTemplate(ctx, code, "")
assert.NoError(err, "Generate should not return an error")
autogold.ExpectFile(t, template)

}

func TestTransactionValue(t *testing.T) {
contracts := []Contract{
{
Expand Down
71 changes: 71 additions & 0 deletions internal/v1_1/testdata/TestValidImports.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
`{
"f_type": "InteractionTemplate",
"f_version": "1.1.0",
"id": "585d2dd4fc3523ba140fbe0f798c07f1d4c3792aaed72a3b0d7c32c23a6457e7",
"data": {
"type": "script",
"interface": "",
"messages": null,
"cadence": {
"body": "\n\timport \"Alice\"\n\timport \"Bob\"\n\timport Joe\n\n\taccess(all)\n\tfun main(): Void {}\n",
"network_pins": []
},
"dependencies": [
{
"contracts": [
{
"contract": "Alice",
"networks": [
{
"network": "testnet",
"address": "0x0000000000000001",
"dependency_pin_block_height": 0
},
{
"network": "mainnet",
"address": "0x0000000000000001",
"dependency_pin_block_height": 0
},
{
"network": "emulator",
"address": "0x0000000000000001",
"dependency_pin_block_height": 0
}
]
}
]
},
{
"contracts": [
{
"contract": "Bob",
"networks": [
{
"network": "testnet",
"address": "0x0000000000000002",
"dependency_pin_block_height": 0
},
{
"network": "mainnet",
"address": "0x0000000000000002",
"dependency_pin_block_height": 0
},
{
"network": "emulator",
"address": "0x0000000000000002",
"dependency_pin_block_height": 0
}
]
}
]
}
],
"parameters": null,
"output": {
"label": "result",
"index": 0,
"type": "Void",
"messages": []
}
}
}`
12 changes: 0 additions & 12 deletions internal/v1_1/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,18 +588,6 @@ func convertToBytes(value interface{}) ([]byte, error) {
}
}

func ExtractContractName(importStr string) (string, error) {
// Create a regex pattern to find the contract name inside the quotes
pattern := regexp.MustCompile(`import "([^"]+)"`)
matches := pattern.FindStringSubmatch(importStr)

if len(matches) >= 2 {
return matches[1], nil
}

return "", fmt.Errorf("no contract name found in string")
}

func isItemInArray[T comparable](item T, slice []T) bool {
for _, s := range slice {
if s == item {
Expand Down

0 comments on commit d3bdc58

Please sign in to comment.