diff --git a/bbq/commons/types.go b/bbq/commons/types.go index 757a5299c7..b3faab4055 100644 --- a/bbq/commons/types.go +++ b/bbq/commons/types.go @@ -20,6 +20,7 @@ package commons import ( "github.com/onflow/cadence/common" + "github.com/onflow/cadence/interpreter" "github.com/onflow/cadence/sema" ) @@ -45,6 +46,15 @@ func TypeQualifiedName(typ sema.Type, functionName string) string { return typeQualifier + "." + functionName } +func StaticTypeQualifiedName(typ interpreter.StaticType, functionName string) string { + if typ == nil { + return functionName + } + + typeQualifier := StaticTypeQualifier(typ) + return typeQualifier + "." + functionName +} + func QualifiedName(typeName, functionName string) string { if typeName == "" { return functionName @@ -61,6 +71,8 @@ func QualifiedName(typeName, functionName string) string { // TODO: Add other types // TODO: Maybe make this a method on the type func TypeQualifier(typ sema.Type) string { + // IMPORTANT: Ensure this is in sync with `StaticTypeQualifier` method below. + switch typ := typ.(type) { case *sema.ConstantSizedType: return TypeQualifierArrayConstantSized @@ -94,6 +106,51 @@ func TypeQualifier(typ sema.Type) string { } } +func StaticTypeQualifier(typ interpreter.StaticType) string { + // IMPORTANT: Ensure this is in sync with `TypeQualifier` method above. + // TODO: Try to unify. Maybe generate the two functions from a single definition. + + switch typ := typ.(type) { + case *interpreter.ConstantSizedStaticType: + return TypeQualifierArrayConstantSized + case *interpreter.VariableSizedStaticType: + return TypeQualifierArrayVariableSized + case *interpreter.DictionaryStaticType: + return TypeQualifierDictionary + case interpreter.FunctionStaticType: + // This is only applicable for types that also has a constructor with the same name. + // e.g: `String` type has the `String()` constructor as well as the type on which + // functions can be called (`String.join()`). + // Thus, if a constructor function is used as a type-qualifier, + // then used the actual type associated with it (i.e: the return type). + if typ.TypeFunctionType != nil { + return TypeQualifier(typ.TypeFunctionType) + } + return TypeQualifierFunction + case *interpreter.OptionalStaticType: + return TypeQualifierOptional + case *interpreter.ReferenceStaticType: + return StaticTypeQualifier(typ.ReferencedType) + case *interpreter.IntersectionStaticType: + // TODO: Revisit. Probably this is not needed here? + return StaticTypeQualifier(typ.Types[0]) + case *interpreter.CapabilityStaticType: + return TypeQualifierCapability + case *interpreter.InclusiveRangeStaticType: + return TypeQualifierInclusiveRange + + // In addition to the `TypeQualifier` method above, + // following are needed. + case *interpreter.CompositeStaticType: + return typ.QualifiedIdentifier + case *interpreter.InterfaceStaticType: + return typ.QualifiedIdentifier + + default: + return typ.String() + } +} + func LocationQualifier(typ sema.Type) string { switch typ := typ.(type) { case *sema.ReferenceType: diff --git a/bbq/vm/context.go b/bbq/vm/context.go index 9053802e58..dc72f10695 100644 --- a/bbq/vm/context.go +++ b/bbq/vm/context.go @@ -332,14 +332,16 @@ func (c *Context) GetMethod( ) interpreter.FunctionValue { staticType := value.StaticType(c) - semaType := c.SemaTypeFromStaticType(staticType) - var location common.Location - if locatedType, ok := semaType.(sema.LocatedType); ok { - location = locatedType.GetLocation() + + switch staticType := staticType.(type) { + case *interpreter.CompositeStaticType: + location = staticType.Location + case *interpreter.InterfaceStaticType: + location = staticType.Location } - qualifiedFuncName := commons.TypeQualifiedName(semaType, name) + qualifiedFuncName := commons.StaticTypeQualifiedName(staticType, name) method := c.GetFunction(location, qualifiedFuncName) if method == nil { @@ -423,22 +425,48 @@ func (c *Context) DefaultDestroyEvents(resourceValue *interpreter.CompositeValue return eventValues } -func (c *Context) SemaTypeFromStaticType(staticType interpreter.StaticType) sema.Type { - typeID := staticType.ID() - semaType, ok := c.semaTypeCache[typeID] - if ok { - return semaType +func (c *Context) SemaTypeFromStaticType(staticType interpreter.StaticType) (semaType sema.Type) { + _, isPrimitiveType := staticType.(interpreter.PrimitiveStaticType) + + // For primitive types, conversion is just a switch-case and returning a constant. + // It is efficient than a map lookup/update. + // So don't bother using the cache for primitive static types. + if !isPrimitiveType { + typeID := staticType.ID() + cachedSemaType, ok := c.semaTypeCache[typeID] + if ok { + return cachedSemaType + } + + defer func() { + if c.semaTypeCache == nil { + c.semaTypeCache = make(map[sema.TypeID]sema.Type) + } + c.semaTypeCache[typeID] = semaType + }() } // TODO: avoid the sema-type conversion - semaType = interpreter.MustConvertStaticToSemaType(staticType, c) + return interpreter.MustConvertStaticToSemaType(staticType, c) +} - if c.semaTypeCache == nil { - c.semaTypeCache = make(map[sema.TypeID]sema.Type) +func (c *Context) SemaAccessFromStaticAuthorization(auth interpreter.Authorization) (sema.Access, error) { + semaAccess, ok := c.semaAccessCache[auth] + if ok { + return semaAccess, nil + } + + semaAccess, err := interpreter.ConvertStaticAuthorizationToSemaAccess(auth, c) + if err != nil { + return nil, err + } + + if c.semaAccessCache == nil { + c.semaAccessCache = make(map[interpreter.Authorization]sema.Access) } - c.semaTypeCache[typeID] = semaType + c.semaAccessCache[auth] = semaAccess - return semaType + return semaAccess, nil } func (c *Context) GetContractValue(contractLocation common.AddressLocation) *interpreter.CompositeValue { @@ -532,22 +560,3 @@ func (c *Context) GetEntitlementMapType( func (c *Context) LocationRange() interpreter.LocationRange { return c.getLocationRange() } - -func (c *Context) SemaAccessFromStaticAuthorization(auth interpreter.Authorization) (sema.Access, error) { - semaAccess, ok := c.semaAccessCache[auth] - if ok { - return semaAccess, nil - } - - semaAccess, err := interpreter.ConvertStaticAuthorizationToSemaAccess(auth, c) - if err != nil { - return nil, err - } - - if c.semaAccessCache == nil { - c.semaAccessCache = make(map[interpreter.Authorization]sema.Access) - } - c.semaAccessCache[auth] = semaAccess - - return semaAccess, nil -} diff --git a/bbq/vm/linker.go b/bbq/vm/linker.go index 7ca6b2d8c5..aba0ac8264 100644 --- a/bbq/vm/linker.go +++ b/bbq/vm/linker.go @@ -201,16 +201,12 @@ func linkImportedGlobal( contractValue := global.GetValue(context) staticType := contractValue.StaticType(context) - semaType, err := interpreter.ConvertStaticToSemaType(context, staticType) - if err != nil { - panic(err) - } return interpreter.NewEphemeralReferenceValue( context, interpreter.UnauthorizedAccess, contractValue, - semaType, + staticType, ) }, ) diff --git a/bbq/vm/test/ft_test.go b/bbq/vm/test/ft_test.go index 4527259cb0..7813446b94 100644 --- a/bbq/vm/test/ft_test.go +++ b/bbq/vm/test/ft_test.go @@ -168,8 +168,8 @@ func compiledFTTransfer(tb testing.TB) { context interpreter.BorrowCapabilityControllerContext, address interpreter.AddressValue, capabilityID interpreter.UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) interpreter.ReferenceValue { return stdlib.BorrowCapabilityController( context, diff --git a/bbq/vm/test/interpreter_test.go b/bbq/vm/test/interpreter_test.go index 319e9ab38c..c415fbc1aa 100644 --- a/bbq/vm/test/interpreter_test.go +++ b/bbq/vm/test/interpreter_test.go @@ -331,8 +331,8 @@ func interpreterFTTransfer(tb testing.TB) { context interpreter.BorrowCapabilityControllerContext, address interpreter.AddressValue, capabilityID interpreter.UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) interpreter.ReferenceValue { return stdlib.BorrowCapabilityController( context, diff --git a/bbq/vm/test/vm_test.go b/bbq/vm/test/vm_test.go index 2337d04bb3..ce4002f876 100644 --- a/bbq/vm/test/vm_test.go +++ b/bbq/vm/test/vm_test.go @@ -5266,8 +5266,8 @@ func TestCasting(t *testing.T) { assert.Equal( t, &interpreter.ForceCastTypeMismatchError{ - ExpectedType: sema.IntType, - ActualType: sema.BoolType, + ExpectedType: interpreter.PrimitiveStaticTypeInt, + ActualType: interpreter.PrimitiveStaticTypeBool, LocationRange: interpreter.LocationRange{ Location: TestLocation, HasPosition: bbq.Position{ @@ -11729,7 +11729,11 @@ func TestBorrowContractLinksGlobals(t *testing.T) { context, contractAddress, interpreter.NewUnmeteredStringValue(contractName), - sema.NewReferenceType(nil, sema.UnauthorizedAccess, sema.AnyStructType), + interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + interpreter.PrimitiveStaticTypeAnyStruct, + ), accountHandler, ) diff --git a/bbq/vm/types.go b/bbq/vm/types.go deleted file mode 100644 index 593e4078e4..0000000000 --- a/bbq/vm/types.go +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Cadence - The resource-oriented smart contract programming language - * - * Copyright Flow Foundation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package vm - -import ( - "github.com/onflow/cadence/bbq" - "github.com/onflow/cadence/interpreter" -) - -// UnwrapOptionalType returns the type if it is not an optional type, -// or the inner-most type if it is (optional types are repeatedly unwrapped) -func UnwrapOptionalType(ty bbq.StaticType) bbq.StaticType { - for { - optionalType, ok := ty.(*interpreter.OptionalStaticType) - if !ok { - return ty - } - ty = optionalType.Type - } -} diff --git a/bbq/vm/value.go b/bbq/vm/value.go index fb278d7ef7..1f85b5e46b 100644 --- a/bbq/vm/value.go +++ b/bbq/vm/value.go @@ -31,13 +31,10 @@ func ConvertAndBox( value Value, valueType, targetType bbq.StaticType, ) Value { - valueSemaType := context.SemaTypeFromStaticType(valueType) - targetSemaType := context.SemaTypeFromStaticType(targetType) - - return interpreter.ConvertAndBox( + return interpreter.ConvertAndBoxToStaticType( context, value, - valueSemaType, - targetSemaType, + valueType, + targetType, ) } diff --git a/bbq/vm/value_implicit_reference.go b/bbq/vm/value_implicit_reference.go index 334a7bb1e3..1e072f6446 100644 --- a/bbq/vm/value_implicit_reference.go +++ b/bbq/vm/value_implicit_reference.go @@ -47,7 +47,7 @@ func NewImplicitReferenceValue(context interpreter.ReferenceCreationContext, val } } - semaType := interpreter.MustSemaTypeOfValue(value, context) + staticType := value.StaticType(context) // Create an explicit reference to represent the implicit reference behavior of 'self' value. // Authorization doesn't matter, we just need a reference to add to tracking. @@ -55,7 +55,7 @@ func NewImplicitReferenceValue(context interpreter.ReferenceCreationContext, val context, interpreter.UnauthorizedAccess, value, - semaType, + staticType, ) return ImplicitReferenceValue{ diff --git a/bbq/vm/vm.go b/bbq/vm/vm.go index ac45ec162d..9b7bc221a5 100644 --- a/bbq/vm/vm.go +++ b/bbq/vm/vm.go @@ -1085,13 +1085,10 @@ func checkMemberAccessTargetType( context := vm.context - // TODO: Avoid sema type conversion. - accessedSemaType := context.SemaTypeFromStaticType(accessedType) - interpreter.CheckMemberAccessTargetType( context, accessedValue, - accessedSemaType, + accessedType, ) } @@ -1127,11 +1124,11 @@ func opTransferAndConvert(vm *VM, ins opcode.InstructionTransferAndConvert) { value := vm.peek() valueType := value.StaticType(context) - transferredValue := interpreter.TransferAndConvert( + transferredValue := interpreter.TransferAndConvertToStaticType( context, value, - context.SemaTypeFromStaticType(valueType), - context.SemaTypeFromStaticType(targetType), + valueType, + targetType, ) vm.replaceTop(transferredValue) @@ -1163,11 +1160,11 @@ func opConvert(vm *VM, ins opcode.InstructionConvert) { value := vm.peek() valueType := value.StaticType(context) - transferredValue := interpreter.ConvertAndBoxWithValidation( + transferredValue := interpreter.ConvertAndBoxToStaticTypeWithValidation( context, value, - context.SemaTypeFromStaticType(valueType), - context.SemaTypeFromStaticType(targetType), + valueType, + targetType, ) vm.replaceTop(transferredValue) @@ -1242,12 +1239,9 @@ func opForceCast(vm *VM, ins opcode.InstructionForceCast) { var result Value if !isSubType { - targetSemaType := context.SemaTypeFromStaticType(targetType) - valueSemaType := context.SemaTypeFromStaticType(valueType) - panic(&interpreter.ForceCastTypeMismatchError{ - ExpectedType: targetSemaType, - ActualType: valueSemaType, + ExpectedType: targetType, + ActualType: valueType, }) } @@ -1268,7 +1262,7 @@ func castValueAndValueType(context *Context, targetType bbq.StaticType, value Va // so we don't substitute them. // If the target is `AnyStruct` or `AnyResource` we want to preserve optionals - unboxedExpectedType := UnwrapOptionalType(targetType) + unboxedExpectedType := interpreter.UnwrapOptionalType(targetType) if !(unboxedExpectedType == interpreter.PrimitiveStaticTypeAnyStruct || unboxedExpectedType == interpreter.PrimitiveStaticTypeAnyResource) { // otherwise dynamic cast now always unboxes optionals @@ -1372,11 +1366,9 @@ func opNewRef(vm *VM, ins opcode.InstructionNewRef) { context := vm.context - semaBorrowedType := context.SemaTypeFromStaticType(borrowedType) - - ref := interpreter.CreateReferenceValue( + ref := interpreter.CreateReferenceValueFromStaticType( context, - semaBorrowedType, + borrowedType, value, ins.IsImplicit, ) diff --git a/interpreter/account_test.go b/interpreter/account_test.go index f9577ad83d..4db0ad5e1b 100644 --- a/interpreter/account_test.go +++ b/interpreter/account_test.go @@ -554,7 +554,7 @@ func testAccountWithErrorHandlerWithCompiler( NoOpReferenceCreationContext{}, interpreter.FullyEntitledAccountAccess, account, - sema.AccountType, + interpreter.PrimitiveStaticTypeAccount, ), Kind: common.DeclarationKindConstant, } @@ -567,7 +567,7 @@ func testAccountWithErrorHandlerWithCompiler( NoOpReferenceCreationContext{}, interpreter.UnauthorizedAccess, account, - sema.AccountType, + interpreter.PrimitiveStaticTypeAccount, ), Kind: common.DeclarationKindConstant, } diff --git a/interpreter/container_mutation_test.go b/interpreter/container_mutation_test.go index f74d125e20..8fbac944a4 100644 --- a/interpreter/container_mutation_test.go +++ b/interpreter/container_mutation_test.go @@ -123,8 +123,8 @@ func TestInterpretArrayMutation(t *testing.T) { var mutationError *interpreter.ContainerMutationError require.ErrorAs(t, err, &mutationError) - assert.Equal(t, sema.StringType, mutationError.ExpectedType) - assert.Equal(t, sema.IntType, mutationError.ActualType) + assert.Equal(t, interpreter.PrimitiveStaticTypeString, mutationError.ExpectedType) + assert.Equal(t, interpreter.PrimitiveStaticTypeInt, mutationError.ActualType) }) t.Run("nested array invalid", func(t *testing.T) { @@ -143,8 +143,8 @@ func TestInterpretArrayMutation(t *testing.T) { var mutationError *interpreter.ContainerMutationError require.ErrorAs(t, err, &mutationError) - assert.Equal(t, sema.StringType, mutationError.ExpectedType) - assert.Equal(t, sema.IntType, mutationError.ActualType) + assert.Equal(t, interpreter.PrimitiveStaticTypeString, mutationError.ExpectedType) + assert.Equal(t, interpreter.PrimitiveStaticTypeInt, mutationError.ActualType) }) t.Run("array append valid", func(t *testing.T) { @@ -197,8 +197,8 @@ func TestInterpretArrayMutation(t *testing.T) { var mutationError *interpreter.ContainerMutationError require.ErrorAs(t, err, &mutationError) - assert.Equal(t, sema.StringType, mutationError.ExpectedType) - assert.Equal(t, sema.IntType, mutationError.ActualType) + assert.Equal(t, interpreter.PrimitiveStaticTypeString, mutationError.ExpectedType) + assert.Equal(t, interpreter.PrimitiveStaticTypeInt, mutationError.ActualType) }) t.Run("array appendAll invalid", func(t *testing.T) { @@ -217,8 +217,8 @@ func TestInterpretArrayMutation(t *testing.T) { var mutationError *interpreter.ContainerMutationError require.ErrorAs(t, err, &mutationError) - assert.Equal(t, sema.StringType, mutationError.ExpectedType) - assert.Equal(t, sema.IntType, mutationError.ActualType) + assert.Equal(t, interpreter.PrimitiveStaticTypeString, mutationError.ExpectedType) + assert.Equal(t, interpreter.PrimitiveStaticTypeInt, mutationError.ActualType) }) t.Run("array insert valid", func(t *testing.T) { @@ -271,8 +271,8 @@ func TestInterpretArrayMutation(t *testing.T) { var mutationError *interpreter.ContainerMutationError require.ErrorAs(t, err, &mutationError) - assert.Equal(t, sema.StringType, mutationError.ExpectedType) - assert.Equal(t, sema.IntType, mutationError.ActualType) + assert.Equal(t, interpreter.PrimitiveStaticTypeString, mutationError.ExpectedType) + assert.Equal(t, interpreter.PrimitiveStaticTypeInt, mutationError.ActualType) }) t.Run("array concat mismatching values", func(t *testing.T) { @@ -332,8 +332,8 @@ func TestInterpretArrayMutation(t *testing.T) { var mutationError *interpreter.ContainerMutationError require.ErrorAs(t, err, &mutationError) - assert.Equal(t, sema.StringType, mutationError.ExpectedType) - assert.Equal(t, sema.IntType, mutationError.ActualType) + assert.Equal(t, interpreter.PrimitiveStaticTypeString, mutationError.ExpectedType) + assert.Equal(t, interpreter.PrimitiveStaticTypeInt, mutationError.ActualType) }) t.Run("host function mutation", func(t *testing.T) { @@ -520,18 +520,18 @@ func TestInterpretArrayMutation(t *testing.T) { require.ErrorAs(t, err, &mutationError) // Expected type - require.IsType(t, &sema.OptionalType{}, mutationError.ExpectedType) - optionalType := mutationError.ExpectedType.(*sema.OptionalType) + require.IsType(t, &interpreter.OptionalStaticType{}, mutationError.ExpectedType) + optionalType := mutationError.ExpectedType.(*interpreter.OptionalStaticType) - require.IsType(t, &sema.FunctionType{}, optionalType.Type) - funcType := optionalType.Type.(*sema.FunctionType) + require.IsType(t, interpreter.FunctionStaticType{}, optionalType.Type) + funcType := optionalType.Type.(interpreter.FunctionStaticType) assert.Equal(t, sema.VoidType, funcType.ReturnTypeAnnotation.Type) assert.Empty(t, funcType.Parameters) // Actual type - assert.IsType(t, &sema.FunctionType{}, mutationError.ActualType) - actualFuncType := mutationError.ActualType.(*sema.FunctionType) + assert.IsType(t, interpreter.FunctionStaticType{}, mutationError.ActualType) + actualFuncType := mutationError.ActualType.(interpreter.FunctionStaticType) assert.Equal(t, sema.VoidType, actualFuncType.ReturnTypeAnnotation.Type) assert.Len(t, actualFuncType.Parameters, 1) @@ -586,15 +586,15 @@ func TestInterpretDictionaryMutation(t *testing.T) { require.ErrorAs(t, err, &mutationError) assert.Equal(t, - &sema.OptionalType{ - Type: sema.StringType, + &interpreter.OptionalStaticType{ + Type: interpreter.PrimitiveStaticTypeString, }, mutationError.ExpectedType, ) assert.Equal(t, - &sema.OptionalType{ - Type: sema.IntType, + &interpreter.OptionalStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, }, mutationError.ActualType, ) @@ -663,8 +663,8 @@ func TestInterpretDictionaryMutation(t *testing.T) { var mutationError *interpreter.ContainerMutationError require.ErrorAs(t, err, &mutationError) - assert.Equal(t, sema.StringType, mutationError.ExpectedType) - assert.Equal(t, sema.IntType, mutationError.ActualType) + assert.Equal(t, interpreter.PrimitiveStaticTypeString, mutationError.ExpectedType) + assert.Equal(t, interpreter.PrimitiveStaticTypeInt, mutationError.ActualType) }) t.Run("dictionary insert invalid key", func(t *testing.T) { @@ -683,8 +683,8 @@ func TestInterpretDictionaryMutation(t *testing.T) { var mutationError *interpreter.ContainerMutationError require.ErrorAs(t, err, &mutationError) - assert.Equal(t, sema.PublicPathType, mutationError.ExpectedType) - assert.Equal(t, sema.PrivatePathType, mutationError.ActualType) + assert.Equal(t, interpreter.PrimitiveStaticTypePublicPath, mutationError.ExpectedType) + assert.Equal(t, interpreter.PrimitiveStaticTypePrivatePath, mutationError.ActualType) }) t.Run("invalid update through reference", func(t *testing.T) { @@ -705,14 +705,14 @@ func TestInterpretDictionaryMutation(t *testing.T) { require.ErrorAs(t, err, &mutationError) assert.Equal(t, - &sema.OptionalType{ - Type: sema.StringType, + &interpreter.OptionalStaticType{ + Type: interpreter.PrimitiveStaticTypeString, }, mutationError.ExpectedType, ) assert.Equal(t, - &sema.OptionalType{ - Type: sema.IntType, + &interpreter.OptionalStaticType{ + Type: interpreter.PrimitiveStaticTypeInt, }, mutationError.ActualType, ) @@ -903,21 +903,21 @@ func TestInterpretDictionaryMutation(t *testing.T) { require.ErrorAs(t, err, &mutationError) // Expected type - require.IsType(t, &sema.OptionalType{}, mutationError.ExpectedType) - optionalType := mutationError.ExpectedType.(*sema.OptionalType) + require.IsType(t, &interpreter.OptionalStaticType{}, mutationError.ExpectedType) + optionalType := mutationError.ExpectedType.(*interpreter.OptionalStaticType) - require.IsType(t, &sema.FunctionType{}, optionalType.Type) - funcType := optionalType.Type.(*sema.FunctionType) + require.IsType(t, interpreter.FunctionStaticType{}, optionalType.Type) + funcType := optionalType.Type.(interpreter.FunctionStaticType) assert.Equal(t, sema.VoidType, funcType.ReturnTypeAnnotation.Type) assert.Empty(t, funcType.Parameters) // Actual type - require.IsType(t, &sema.OptionalType{}, mutationError.ActualType) - actualOptionalType := mutationError.ActualType.(*sema.OptionalType) + require.IsType(t, &interpreter.OptionalStaticType{}, mutationError.ActualType) + actualOptionalType := mutationError.ActualType.(*interpreter.OptionalStaticType) - require.IsType(t, &sema.FunctionType{}, actualOptionalType.Type) - actualFuncType := actualOptionalType.Type.(*sema.FunctionType) + require.IsType(t, interpreter.FunctionStaticType{}, actualOptionalType.Type) + actualFuncType := actualOptionalType.Type.(interpreter.FunctionStaticType) assert.Equal(t, sema.VoidType, actualFuncType.ReturnTypeAnnotation.Type) assert.Len(t, actualFuncType.Parameters, 1) diff --git a/interpreter/errors.go b/interpreter/errors.go index 1e2063a30a..de8aa4bb15 100644 --- a/interpreter/errors.go +++ b/interpreter/errors.go @@ -251,8 +251,8 @@ func (e RedeclarationError) Error() string { type DereferenceError struct { Cause string - ExpectedType sema.Type - ActualType sema.Type + ExpectedType StaticType + ActualType StaticType LocationRange } @@ -270,7 +270,7 @@ func (e *DereferenceError) SecondaryError() string { if e.Cause != "" { return e.Cause } - expected, actual := sema.ErrorMessageExpectedActualTypes( + expected, actual := ErrorMessageExpectedActualTypes( e.ExpectedType, e.ActualType, ) @@ -422,8 +422,8 @@ func (e *ForceNilError) SetLocationRange(locationRange LocationRange) { // ForceCastTypeMismatchError type ForceCastTypeMismatchError struct { - ExpectedType sema.Type - ActualType sema.Type + ExpectedType StaticType + ActualType StaticType LocationRange } @@ -433,7 +433,7 @@ var _ HasLocationRange = &ForceCastTypeMismatchError{} func (*ForceCastTypeMismatchError) IsUserError() {} func (e *ForceCastTypeMismatchError) Error() string { - expected, actual := sema.ErrorMessageExpectedActualTypes( + expected, actual := ErrorMessageExpectedActualTypes( e.ExpectedType, e.ActualType, ) @@ -451,8 +451,8 @@ func (e *ForceCastTypeMismatchError) SetLocationRange(locationRange LocationRang // TypeMismatchError type TypeMismatchError struct { - ExpectedType sema.Type - ActualType sema.Type + ExpectedType StaticType + ActualType StaticType LocationRange } @@ -462,7 +462,7 @@ var _ HasLocationRange = &TypeMismatchError{} func (*TypeMismatchError) IsUserError() {} func (e *TypeMismatchError) Error() string { - expected, actual := sema.ErrorMessageExpectedActualTypes( + expected, actual := ErrorMessageExpectedActualTypes( e.ExpectedType, e.ActualType, ) @@ -480,8 +480,8 @@ func (e *TypeMismatchError) SetLocationRange(locationRange LocationRange) { // InvalidMemberReferenceError type InvalidMemberReferenceError struct { - ExpectedType sema.Type - ActualType sema.Type + ExpectedType StaticType + ActualType StaticType LocationRange } @@ -491,7 +491,7 @@ var _ HasLocationRange = &InvalidMemberReferenceError{} func (*InvalidMemberReferenceError) IsUserError() {} func (e *InvalidMemberReferenceError) Error() string { - expected, actual := sema.ErrorMessageExpectedActualTypes( + expected, actual := ErrorMessageExpectedActualTypes( e.ExpectedType, e.ActualType, ) @@ -754,8 +754,8 @@ func (e *UseBeforeInitializationError) SetLocationRange(locationRange LocationRa // MemberAccessTypeError type MemberAccessTypeError struct { - ExpectedType sema.Type - ActualType sema.Type + ExpectedType StaticType + ActualType StaticType LocationRange } @@ -765,11 +765,12 @@ var _ HasLocationRange = &MemberAccessTypeError{} func (*MemberAccessTypeError) IsInternalError() {} func (e *MemberAccessTypeError) Error() string { + expected, actual := ErrorMessageExpectedActualTypes(e.ExpectedType, e.ActualType) return fmt.Sprintf( "%s invalid member access: expected `%s`, got `%s`", errors.InternalErrorMessagePrefix, - e.ExpectedType.QualifiedString(), - e.ActualType.QualifiedString(), + expected, + actual, ) } @@ -779,8 +780,8 @@ func (e *MemberAccessTypeError) SetLocationRange(locationRange LocationRange) { // ValueTransferTypeError type ValueTransferTypeError struct { - ExpectedType sema.Type - ActualType sema.Type + ExpectedType StaticType + ActualType StaticType LocationRange } @@ -790,7 +791,7 @@ var _ HasLocationRange = &ValueTransferTypeError{} func (*ValueTransferTypeError) IsInternalError() {} func (e *ValueTransferTypeError) Error() string { - expected, actual := sema.ErrorMessageExpectedActualTypes( + expected, actual := ErrorMessageExpectedActualTypes( e.ExpectedType, e.ActualType, ) @@ -809,7 +810,7 @@ func (e *ValueTransferTypeError) SetLocationRange(locationRange LocationRange) { // UnexpectedMappedEntitlementError type UnexpectedMappedEntitlementError struct { - Type sema.Type + Type StaticType LocationRange } @@ -822,7 +823,7 @@ func (e *UnexpectedMappedEntitlementError) Error() string { return fmt.Sprintf( "%s invalid transfer of value: found an unexpected runtime mapped entitlement `%s`", errors.InternalErrorMessagePrefix, - e.Type.QualifiedString(), + e.Type.ID(), ) } @@ -856,8 +857,8 @@ func (e *ResourceConstructionError) SetLocationRange(locationRange LocationRange // ContainerMutationError type ContainerMutationError struct { - ExpectedType sema.Type - ActualType sema.Type + ExpectedType StaticType + ActualType StaticType LocationRange } @@ -867,10 +868,15 @@ var _ HasLocationRange = &ContainerMutationError{} func (*ContainerMutationError) IsUserError() {} func (e *ContainerMutationError) Error() string { + expected, actual := ErrorMessageExpectedActualTypes( + e.ExpectedType, + e.ActualType, + ) + return fmt.Sprintf( "invalid container update: expected a subtype of `%s`, found `%s`", - e.ExpectedType.QualifiedString(), - e.ActualType.QualifiedString(), + expected, + actual, ) } @@ -1307,7 +1313,7 @@ func (e *NestedReferenceError) SetLocationRange(locationRange LocationRange) { // NonOptionalReferenceToNilError type NonOptionalReferenceToNilError struct { - ReferenceType sema.Type + ReferenceType StaticType LocationRange } @@ -1354,7 +1360,7 @@ func (e *InclusiveRangeConstructionError) SetLocationRange(locationRange Locatio // InvalidCapabilityIssueTypeError type InvalidCapabilityIssueTypeError struct { ExpectedTypeDescription string - ActualType sema.Type + ActualType StaticType LocationRange } @@ -1367,7 +1373,7 @@ func (e *InvalidCapabilityIssueTypeError) Error() string { return fmt.Sprintf( "invalid type: expected %s, got `%s`", e.ExpectedTypeDescription, - e.ActualType.QualifiedString(), + e.ActualType.String(), ) } @@ -1496,8 +1502,8 @@ func (e *CallStackLimitExceededError) SetLocationRange(locationRange LocationRan // StoredValueTypeMismatchError type StoredValueTypeMismatchError struct { - ExpectedType sema.Type - ActualType sema.Type + ExpectedType StaticType + ActualType StaticType LocationRange } @@ -1507,7 +1513,7 @@ var _ HasLocationRange = &StoredValueTypeMismatchError{} func (*StoredValueTypeMismatchError) IsUserError() {} func (e *StoredValueTypeMismatchError) Error() string { - expected, actual := sema.ErrorMessageExpectedActualTypes( + expected, actual := ErrorMessageExpectedActualTypes( e.ExpectedType, e.ActualType, ) @@ -1522,3 +1528,21 @@ func (e *StoredValueTypeMismatchError) Error() string { func (e *StoredValueTypeMismatchError) SetLocationRange(locationRange LocationRange) { e.LocationRange = locationRange } + +func ErrorMessageExpectedActualTypes( + expectedType StaticType, + actualType StaticType, +) ( + expected string, + actual string, +) { + expected = expectedType.String() + actual = actualType.String() + + if expected == actual { + expected = string(expectedType.ID()) + actual = string(actualType.ID()) + } + + return +} diff --git a/interpreter/idcapability_test.go b/interpreter/idcapability_test.go index 3e28c0f28e..60c9776675 100644 --- a/interpreter/idcapability_test.go +++ b/interpreter/idcapability_test.go @@ -140,7 +140,11 @@ func TestInterpretIDCapability(t *testing.T) { noopReferenceTracker{}, interpreter.UnauthorizedAccess, interpreter.NewUnmeteredStringValue("mock"), - sema.NewReferenceType(nil, sema.UnauthorizedAccess, sema.StringType), + interpreter.NewReferenceStaticType( + nil, + interpreter.UnauthorizedAccess, + interpreter.PrimitiveStaticTypeString, + ), ) inter, err := test(t, @@ -154,8 +158,8 @@ func TestInterpretIDCapability(t *testing.T) { _ interpreter.BorrowCapabilityControllerContext, address interpreter.AddressValue, capabilityID interpreter.UInt64Value, - _ *sema.ReferenceType, - _ *sema.ReferenceType, + _ *interpreter.ReferenceStaticType, + _ *interpreter.ReferenceStaticType, ) interpreter.ReferenceValue { assert.Equal(t, interpreter.AddressValue{0x42}, address) assert.Equal(t, interpreter.UInt64Value(id), capabilityID) @@ -188,8 +192,8 @@ func TestInterpretIDCapability(t *testing.T) { _ interpreter.CheckCapabilityControllerContext, address interpreter.AddressValue, capabilityID interpreter.UInt64Value, - _ *sema.ReferenceType, - _ *sema.ReferenceType, + _ *interpreter.ReferenceStaticType, + _ *interpreter.ReferenceStaticType, ) interpreter.BoolValue { assert.Equal(t, interpreter.AddressValue{0x42}, address) assert.Equal(t, interpreter.UInt64Value(id), capabilityID) diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 6b6621d9b4..76c9bd36c1 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -119,8 +119,8 @@ type CapabilityBorrowHandlerFunc func( context BorrowCapabilityControllerContext, address AddressValue, capabilityID UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *ReferenceStaticType, + capabilityBorrowType *ReferenceStaticType, ) ReferenceValue // CapabilityCheckHandlerFunc is a function that is used to check ID capabilities. @@ -128,8 +128,8 @@ type CapabilityCheckHandlerFunc func( context CheckCapabilityControllerContext, address AddressValue, capabilityID UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *ReferenceStaticType, + capabilityBorrowType *ReferenceStaticType, ) BoolValue // InjectedCompositeFieldsHandlerFunc is a function that handles storage reads. @@ -165,8 +165,8 @@ type ValidateAccountCapabilitiesGetHandlerFunc func( context AccountCapabilityGetValidationContext, address AddressValue, path PathValue, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *ReferenceStaticType, + capabilityBorrowType *ReferenceStaticType, ) (bool, error) // ValidateAccountCapabilitiesPublishHandlerFunc is a function that is used to handle when a capability of an account is got. @@ -935,7 +935,9 @@ func (interpreter *Interpreter) resultValue(returnValue Value, returnType sema.T return auth } - if optionalType, ok := returnType.(*sema.OptionalType); ok { + returnStaticType := ConvertSemaToStaticType(interpreter, returnType) + + if optionalType, ok := returnStaticType.(*OptionalStaticType); ok { switch returnValue := returnValue.(type) { // If this value is an optional value (T?), then transform it into an optional reference (&T)?. case *SomeValue: @@ -958,7 +960,7 @@ func (interpreter *Interpreter) resultValue(returnValue Value, returnType sema.T interpreter, resultAuth(returnType), returnValue, - returnType, + returnStaticType, ) } @@ -1511,14 +1513,15 @@ func (interpreter *Interpreter) declareNonEnumCompositeValue( var self Value = value if declaration.Kind() == common.CompositeKindAttachment { - attachmentType := MustSemaTypeOfValue(value, invocationContext).(*sema.CompositeType) + attachmentStaticType := value.StaticType(invocationContext) + attachmentType := invocationContext.SemaTypeFromStaticType(attachmentStaticType).(*sema.CompositeType) // Self's type in the constructor is fully entitled, since // the constructor can only be called when in possession of the base resource access := attachmentType.SupportedEntitlements().Access() auth := ConvertSemaAccessToStaticAuthorization(invocationContext, access) - self = NewEphemeralReferenceValue(invocationContext, auth, value, attachmentType) + self = NewEphemeralReferenceValue(invocationContext, auth, value, attachmentStaticType) // set the base to the implicitly provided value, and remove this implicit argument from the list implicitArgumentPos := len(invocation.Arguments) - 1 @@ -1604,7 +1607,8 @@ func (interpreter *Interpreter) declareEnumLookupFunction( location := interpreter.Location - intType := sema.IntType + intType := PrimitiveStaticTypeInt + enumRawStaticType := ConvertSemaToStaticType(interpreter, compositeType.EnumRawType) enumCases := declaration.Members.EnumCases() caseValues := make([]EnumCase, len(enumCases)) @@ -1618,7 +1622,7 @@ func (interpreter *Interpreter) declareEnumLookupFunction( interpreter, NewIntValueFromInt64(interpreter, int64(i)), intType, - compositeType.EnumRawType, + enumRawStaticType, ).(IntegerValue) caseValueFields := []CompositeField{ @@ -1899,6 +1903,22 @@ func TransferAndConvert( value Value, valueType, targetType sema.Type, ) Value { + valueStaticType := ConvertSemaToStaticType(context, valueType) + targetStaticType := ConvertSemaToStaticType(context, targetType) + + return TransferAndConvertToStaticType( + context, + value, + valueStaticType, + targetStaticType, + ) +} + +func TransferAndConvertToStaticType( + context ValueConversionContext, + value Value, + valueType, targetType StaticType, +) Value { transferredValue := value.Transfer( context, @@ -1909,7 +1929,7 @@ func TransferAndConvert( true, // value is standalone. ) - return ConvertAndBoxWithValidation( + return ConvertAndBoxToStaticTypeWithValidation( context, transferredValue, valueType, @@ -1923,7 +1943,24 @@ func ConvertAndBoxWithValidation( valueType sema.Type, targetType sema.Type, ) Value { - result := ConvertAndBox( + valueStaticType := ConvertSemaToStaticType(context, valueType) + targetStaticType := ConvertSemaToStaticType(context, targetType) + + return ConvertAndBoxToStaticTypeWithValidation( + context, + transferredValue, + valueStaticType, + targetStaticType, + ) +} + +func ConvertAndBoxToStaticTypeWithValidation( + context ValueConversionContext, + transferredValue Value, + valueType StaticType, + targetType StaticType, +) Value { + result := ConvertAndBoxToStaticType( context, transferredValue, valueType, @@ -1934,13 +1971,11 @@ func ConvertAndBoxWithValidation( resultStaticType := result.StaticType(context) if targetType != nil && - !IsSubTypeOfSemaType(context, resultStaticType, targetType) { - - resultSemaType := context.SemaTypeFromStaticType(resultStaticType) + !IsSubType(context, resultStaticType, targetType) { panic(&ValueTransferTypeError{ ExpectedType: targetType, - ActualType: resultSemaType, + ActualType: resultStaticType, }) } @@ -1977,6 +2012,20 @@ func ConvertAndBox( context ValueCreationContext, value Value, valueType, targetType sema.Type, +) Value { + valueStaticType := ConvertSemaToStaticType(context, valueType) + targetStaticType := ConvertSemaToStaticType(context, targetType) + + value = convert(context, value, valueStaticType, targetStaticType) + return BoxOptional(context, value, targetStaticType) +} + +// ConvertAndBoxToStaticType converts a value to a target static type, +// and boxes in optionals and any value, if necessary. +func ConvertAndBoxToStaticType( + context ValueCreationContext, + value Value, + valueType, targetType StaticType, ) Value { value = convert(context, value, valueType, targetType) return BoxOptional(context, value, targetType) @@ -1989,20 +2038,20 @@ func ConvertAndBox( func convertStaticType( gauge common.MemoryGauge, valueStaticType StaticType, - targetSemaType sema.Type, + targetSemaType StaticType, ) StaticType { switch valueStaticType := valueStaticType.(type) { case *ReferenceStaticType: - if targetReferenceType, isReferenceType := targetSemaType.(*sema.ReferenceType); isReferenceType { + if targetReferenceType, isReferenceType := targetSemaType.(*ReferenceStaticType); isReferenceType { return NewReferenceStaticType( gauge, - ConvertSemaAccessToStaticAuthorization(gauge, targetReferenceType.Authorization), + targetReferenceType.Authorization, valueStaticType.ReferencedType, ) } case *OptionalStaticType: - if targetOptionalType, isOptionalType := targetSemaType.(*sema.OptionalType); isOptionalType { + if targetOptionalType, isOptionalType := targetSemaType.(*OptionalStaticType); isOptionalType { return NewOptionalStaticType( gauge, convertStaticType( @@ -2014,7 +2063,7 @@ func convertStaticType( } case *DictionaryStaticType: - if targetDictionaryType, isDictionaryType := targetSemaType.(*sema.DictionaryType); isDictionaryType { + if targetDictionaryType, isDictionaryType := targetSemaType.(*DictionaryStaticType); isDictionaryType { return NewDictionaryStaticType( gauge, convertStaticType( @@ -2031,7 +2080,7 @@ func convertStaticType( } case *VariableSizedStaticType: - if targetArrayType, isArrayType := targetSemaType.(*sema.VariableSizedType); isArrayType { + if targetArrayType, isArrayType := targetSemaType.(*VariableSizedStaticType); isArrayType { return NewVariableSizedStaticType( gauge, convertStaticType( @@ -2043,7 +2092,7 @@ func convertStaticType( } case *ConstantSizedStaticType: - if targetArrayType, isArrayType := targetSemaType.(*sema.ConstantSizedType); isArrayType { + if targetArrayType, isArrayType := targetSemaType.(*ConstantSizedStaticType); isArrayType { return NewConstantSizedStaticType( gauge, convertStaticType( @@ -2056,7 +2105,7 @@ func convertStaticType( } case *CapabilityStaticType: - if targetCapabilityType, isCapabilityType := targetSemaType.(*sema.CapabilityType); isCapabilityType { + if targetCapabilityType, isCapabilityType := targetSemaType.(*CapabilityStaticType); isCapabilityType { return NewCapabilityStaticType( gauge, convertStaticType( @@ -2074,16 +2123,16 @@ func convert( context ValueCreationContext, value Value, valueType, - targetType sema.Type, + targetType StaticType, ) Value { if valueType == nil { return value } - unwrappedTargetType := sema.UnwrapOptionalType(targetType) + unwrappedTargetType := UnwrapOptionalType(targetType) // if the value is optional, convert the inner value to the unwrapped target type - if optionalValueType, valueIsOptional := valueType.(*sema.OptionalType); valueIsOptional { + if optionalValueType, valueIsOptional := valueType.(*OptionalStaticType); valueIsOptional { switch value := value.(type) { case NilValue: return value @@ -2103,129 +2152,130 @@ func convert( } switch unwrappedTargetType { - case sema.IntType: + case PrimitiveStaticTypeInt: if !valueType.Equal(unwrappedTargetType) { return ConvertInt(context, value) } - case sema.UIntType: + case PrimitiveStaticTypeUInt: if !valueType.Equal(unwrappedTargetType) { return ConvertUInt(context, value) } // Int* - case sema.Int8Type: + case PrimitiveStaticTypeInt8: if !valueType.Equal(unwrappedTargetType) { return ConvertInt8(context, value) } - case sema.Int16Type: + case PrimitiveStaticTypeInt16: if !valueType.Equal(unwrappedTargetType) { return ConvertInt16(context, value) } - case sema.Int32Type: + case PrimitiveStaticTypeInt32: if !valueType.Equal(unwrappedTargetType) { return ConvertInt32(context, value) } - case sema.Int64Type: + case PrimitiveStaticTypeInt64: if !valueType.Equal(unwrappedTargetType) { return ConvertInt64(context, value) } - case sema.Int128Type: + case PrimitiveStaticTypeInt128: if !valueType.Equal(unwrappedTargetType) { return ConvertInt128(context, value) } - case sema.Int256Type: + case PrimitiveStaticTypeInt256: if !valueType.Equal(unwrappedTargetType) { return ConvertInt256(context, value) } // UInt* - case sema.UInt8Type: + case PrimitiveStaticTypeUInt8: if !valueType.Equal(unwrappedTargetType) { return ConvertUInt8(context, value) } - case sema.UInt16Type: + case PrimitiveStaticTypeUInt16: if !valueType.Equal(unwrappedTargetType) { return ConvertUInt16(context, value) } - case sema.UInt32Type: + case PrimitiveStaticTypeUInt32: if !valueType.Equal(unwrappedTargetType) { return ConvertUInt32(context, value) } - case sema.UInt64Type: + case PrimitiveStaticTypeUInt64: if !valueType.Equal(unwrappedTargetType) { return ConvertUInt64(context, value) } - case sema.UInt128Type: + case PrimitiveStaticTypeUInt128: if !valueType.Equal(unwrappedTargetType) { return ConvertUInt128(context, value) } - case sema.UInt256Type: + case PrimitiveStaticTypeUInt256: if !valueType.Equal(unwrappedTargetType) { return ConvertUInt256(context, value) } // Word* - case sema.Word8Type: + case PrimitiveStaticTypeWord8: if !valueType.Equal(unwrappedTargetType) { return ConvertWord8(context, value) } - case sema.Word16Type: + case PrimitiveStaticTypeWord16: if !valueType.Equal(unwrappedTargetType) { return ConvertWord16(context, value) } - case sema.Word32Type: + case PrimitiveStaticTypeWord32: if !valueType.Equal(unwrappedTargetType) { return ConvertWord32(context, value) } - case sema.Word64Type: + case PrimitiveStaticTypeWord64: if !valueType.Equal(unwrappedTargetType) { return ConvertWord64(context, value) } - case sema.Word128Type: + case PrimitiveStaticTypeWord128: if !valueType.Equal(unwrappedTargetType) { return ConvertWord128(context, value) } - case sema.Word256Type: + case PrimitiveStaticTypeWord256: if !valueType.Equal(unwrappedTargetType) { return ConvertWord256(context, value) } // Fix* - case sema.Fix64Type: + case PrimitiveStaticTypeFix64: if !valueType.Equal(unwrappedTargetType) { return ConvertFix64(context, value) } - case sema.UFix64Type: + case PrimitiveStaticTypeUFix64: if !valueType.Equal(unwrappedTargetType) { return ConvertUFix64(context, value) } - } - switch unwrappedTargetType := unwrappedTargetType.(type) { - case *sema.AddressType: + case PrimitiveStaticTypeAddress: if !valueType.Equal(unwrappedTargetType) { return ConvertAddress(context, value) } + } + + switch unwrappedTargetType := unwrappedTargetType.(type) { - case sema.ArrayType: + case ArrayStaticType: if arrayValue, isArray := value.(*ArrayValue); isArray && !valueType.Equal(unwrappedTargetType) { oldArrayStaticType := arrayValue.StaticType(context) @@ -2235,7 +2285,7 @@ func convert( return value } - targetElementType := context.SemaTypeFromStaticType(arrayStaticType.ElementType()) + targetElementType := arrayStaticType.ElementType() array := arrayValue.array @@ -2259,13 +2309,13 @@ func convert( } value := MustConvertStoredValue(context, element) - valueType := context.SemaTypeFromStaticType(value.StaticType(context)) + valueType := value.StaticType(context) return convert(context, value, valueType, targetElementType) }, ) } - case *sema.DictionaryType: + case *DictionaryStaticType: if dictValue, isDict := value.(*DictionaryValue); isDict && !valueType.Equal(unwrappedTargetType) { oldDictStaticType := dictValue.StaticType(context) @@ -2275,8 +2325,8 @@ func convert( return value } - targetKeyType := context.SemaTypeFromStaticType(dictStaticType.KeyType) - targetValueType := context.SemaTypeFromStaticType(dictStaticType.ValueType) + targetKeyType := dictStaticType.KeyType + targetValueType := dictStaticType.ValueType dictionary := dictValue.dictionary @@ -2304,8 +2354,8 @@ func convert( key := MustConvertStoredValue(context, k) value := MustConvertStoredValue(context, v) - keyType := context.SemaTypeFromStaticType(key.StaticType(context)) - valueType := context.SemaTypeFromStaticType(value.StaticType(context)) + keyType := key.StaticType(context) + valueType := value.StaticType(context) convertedKey := convert(context, key, keyType, targetKeyType) convertedValue := convert(context, value, valueType, targetValueType) @@ -2315,9 +2365,9 @@ func convert( ) } - case *sema.CapabilityType: + case *CapabilityStaticType: if !valueType.Equal(unwrappedTargetType) && unwrappedTargetType.BorrowType != nil { - targetBorrowType := unwrappedTargetType.BorrowType.(*sema.ReferenceType) + targetBorrowType := unwrappedTargetType.BorrowType.(*ReferenceStaticType) switch capability := value.(type) { case *IDCapabilityValue: @@ -2338,8 +2388,8 @@ func convert( } } - case *sema.ReferenceType: - targetAuthorization := ConvertSemaAccessToStaticAuthorization(context, unwrappedTargetType.Authorization) + case *ReferenceStaticType: + targetAuthorization := unwrappedTargetType.Authorization switch ref := value.(type) { case *EphemeralReferenceValue: if shouldConvertReference(ref, valueType, unwrappedTargetType, targetAuthorization) { @@ -2348,7 +2398,7 @@ func convert( context, targetAuthorization, ref.Value, - unwrappedTargetType.Type, + unwrappedTargetType.ReferencedType, ) } @@ -2360,7 +2410,7 @@ func convert( targetAuthorization, ref.TargetStorageAddress, ref.TargetPath, - unwrappedTargetType.Type, + unwrappedTargetType.ReferencedType, ) } @@ -2372,23 +2422,35 @@ func convert( return value } +// UnwrapOptionalType returns the type if it is not an optional type, +// or the inner-most type if it is (optional types are repeatedly unwrapped) +func UnwrapOptionalType(ty StaticType) StaticType { + for { + optionalType, ok := ty.(*OptionalStaticType) + if !ok { + return ty + } + ty = optionalType.Type + } +} + func shouldConvertReference( ref ReferenceValue, - valueType sema.Type, - unwrappedTargetType *sema.ReferenceType, + valueType StaticType, + unwrappedTargetType *ReferenceStaticType, targetAuthorization Authorization, ) bool { if !valueType.Equal(unwrappedTargetType) { return true } - return !ref.BorrowType().Equal(unwrappedTargetType.Type) || + return !ref.BorrowType().Equal(unwrappedTargetType.ReferencedType) || !ref.GetAuthorization().Equal(targetAuthorization) } -func checkMappedEntitlements(unwrappedTargetType *sema.ReferenceType) { +func checkMappedEntitlements(unwrappedTargetType *ReferenceStaticType) { // check defensively that we never create a runtime mapped entitlement value - if _, isMappedAuth := unwrappedTargetType.Authorization.(*sema.EntitlementMapAccess); isMappedAuth { + if _, isMappedAuth := unwrappedTargetType.Authorization.(*EntitlementMapAuthorization); isMappedAuth { panic(&UnexpectedMappedEntitlementError{ Type: unwrappedTargetType, }) @@ -2396,12 +2458,12 @@ func checkMappedEntitlements(unwrappedTargetType *sema.ReferenceType) { } // BoxOptional boxes a value in optionals, if necessary -func BoxOptional(gauge common.MemoryGauge, value Value, targetType sema.Type) Value { +func BoxOptional(gauge common.MemoryGauge, value Value, targetType StaticType) Value { inner := value for { - optionalType, ok := targetType.(*sema.OptionalType) + optionalType, ok := targetType.(*OptionalStaticType) if !ok { break } @@ -3693,9 +3755,10 @@ func ConstructDictionaryTypeValue( // if the given key is not a valid dictionary key, it wouldn't make sense to create this type if keyType == nil || - !sema.IsSubType( - context.SemaTypeFromStaticType(keyType), - sema.HashableStructType, + !IsSubType( + context, + keyType, + PrimitiveStaticTypeHashableStruct, ) { return Nil } @@ -3936,8 +3999,7 @@ func ConstructInclusiveRangeTypeValue( ty := typeValue.Type // InclusiveRanges must hold integers - elemSemaTy := context.SemaTypeFromStaticType(ty) - if !sema.IsSameTypeKind(elemSemaTy, sema.IntegerType) { + if !IsSameTypeKind(context, ty, PrimitiveStaticTypeInteger) { return Nil } @@ -4393,9 +4455,7 @@ func IsSubType(typeConverter TypeConverter, subType StaticType, superType Static return true } - semaType := typeConverter.SemaTypeFromStaticType(superType) - - return IsSubTypeOfSemaType(typeConverter, subType, semaType) + return CheckSubTypeWithoutEquality_gen(typeConverter, subType, superType) } func IsSubTypeOfSemaType(typeConverter TypeConverter, staticSubType StaticType, superType sema.Type) bool { @@ -4668,6 +4728,12 @@ func checkValue( } }() + // For all values, try to load the type and see if it's not broken. + _, valueError = ConvertStaticToSemaType(context, staticType) + if valueError != nil { + return valueError + } + // Here, the value at the path could be either: // 1) The actual stored value (storage path) // 2) A capability to the value at the storage (private/public paths) @@ -4680,13 +4746,7 @@ func checkValue( // Capability values always have a `CapabilityStaticType` static type. borrowType := staticType.(*CapabilityStaticType).BorrowType - var borrowSemaType sema.Type - borrowSemaType, valueError = ConvertStaticToSemaType(context, borrowType) - if valueError != nil { - return valueError - } - - referenceType, ok := borrowSemaType.(*sema.ReferenceType) + referenceType, ok := borrowType.(*ReferenceStaticType) if !ok { panic(errors.NewUnreachableError()) } @@ -4700,11 +4760,6 @@ func checkValue( referenceType, referenceType, ) - - } else { - // For all other values, trying to load the type is sufficient. - // Here it is only interested in whether the type can be properly loaded. - _, valueError = ConvertStaticToSemaType(context, staticType) } return @@ -4898,12 +4953,12 @@ func NativeAccountStorageReadFunction( args []Value, ) Value { address := GetAddressValue(receiver, addressPointer).ToAddress() - semaBorrowType := typeArguments.NextSema() + borrowType := typeArguments.NextStatic() return AccountStorageRead( context, args, - semaBorrowType, + borrowType, address, clear, ) @@ -4929,7 +4984,7 @@ func authAccountReadFunction( func AccountStorageRead( invocationContext InvocationContext, arguments []Value, - typeParameter sema.Type, + typeParameter StaticType, address common.Address, clear bool, ) Value { @@ -4954,12 +5009,10 @@ func AccountStorageRead( valueStaticType := value.StaticType(invocationContext) - if !IsSubTypeOfSemaType(invocationContext, valueStaticType, typeParameter) { - valueSemaType := invocationContext.SemaTypeFromStaticType(valueStaticType) - + if !IsSubType(invocationContext, valueStaticType, typeParameter) { panic(&StoredValueTypeMismatchError{ ExpectedType: typeParameter, - ActualType: valueSemaType, + ActualType: valueStaticType, }) } @@ -4999,7 +5052,7 @@ func NativeAccountStorageBorrowFunction( args []Value, ) Value { address := GetAddressValue(receiver, addressPointer).ToAddress() - typeParameter := typeArguments.NextSema() + typeParameter := typeArguments.NextStatic() return AccountStorageBorrow( context, @@ -5027,7 +5080,7 @@ func authAccountStorageBorrowFunction( func AccountStorageBorrow( invocationContext InvocationContext, arguments []Value, - typeParameter sema.Type, + typeParameter StaticType, address common.Address, ) Value { path, ok := arguments[0].(PathValue) @@ -5035,17 +5088,17 @@ func AccountStorageBorrow( panic(errors.NewUnreachableError()) } - referenceType, ok := typeParameter.(*sema.ReferenceType) + referenceType, ok := typeParameter.(*ReferenceStaticType) if !ok { panic(errors.NewUnreachableError()) } reference := NewStorageReferenceValue( invocationContext, - ConvertSemaAccessToStaticAuthorization(invocationContext, referenceType.Authorization), + referenceType.Authorization, address, path, - referenceType.Type, + referenceType.ReferencedType, ) // Attempt to dereference, @@ -5572,16 +5625,14 @@ func setMember( func ExpectType( context ValueStaticTypeContext, value Value, - expectedType sema.Type, + expectedType StaticType, ) { valueStaticType := value.StaticType(context) - if !IsSubTypeOfSemaType(context, valueStaticType, expectedType) { - valueSemaType := context.SemaTypeFromStaticType(valueStaticType) - + if !IsSubType(context, valueStaticType, expectedType) { panic(&TypeMismatchError{ ExpectedType: expectedType, - ActualType: valueSemaType, + ActualType: valueStaticType, }) } } @@ -5595,8 +5646,8 @@ func checkContainerMutation( if !IsSubType(context, actualElementType, elementType) { panic(&ContainerMutationError{ - ExpectedType: context.SemaTypeFromStaticType(elementType), - ActualType: MustSemaTypeOfValue(element, context), + ExpectedType: elementType, + ActualType: actualElementType, }) } } @@ -5986,7 +6037,7 @@ func (interpreter *Interpreter) Storage() Storage { func NativeCapabilityBorrowFunction( addressValuePointer *AddressValue, capabilityIDPointer *UInt64Value, - capabilityBorrowTypePointer *sema.ReferenceType, + capabilityBorrowTypePointer *ReferenceStaticType, ) NativeFunction { return func( context NativeFunctionContext, @@ -5994,7 +6045,7 @@ func NativeCapabilityBorrowFunction( receiver Value, args []Value, ) Value { - var capabilityBorrowType *sema.ReferenceType + var capabilityBorrowType *ReferenceStaticType var capabilityID UInt64Value var addressValue AddressValue @@ -6019,7 +6070,7 @@ func NativeCapabilityBorrowFunction( return Nil } - capabilityBorrowType = context.SemaTypeFromStaticType(idCapabilityValue.BorrowType).(*sema.ReferenceType) + capabilityBorrowType = idCapabilityValue.BorrowType.(*ReferenceStaticType) addressValue = idCapabilityValue.Address() } else { capabilityBorrowType = capabilityBorrowTypePointer @@ -6027,7 +6078,7 @@ func NativeCapabilityBorrowFunction( addressValue = *addressValuePointer } - typeArgument := typeArguments.NextSema() + typeArgument := typeArguments.NextStatic() return CapabilityBorrow( context, @@ -6044,32 +6095,34 @@ func capabilityBorrowFunction( capabilityValue CapabilityValue, addressValue AddressValue, capabilityID UInt64Value, - capabilityBorrowType *sema.ReferenceType, + capabilityBorrowType *ReferenceStaticType, ) FunctionValue { + capabilityBorrowSemaType := context.SemaTypeFromStaticType(capabilityBorrowType) + return NewBoundHostFunctionValue( context, capabilityValue, - sema.CapabilityTypeBorrowFunctionType(capabilityBorrowType), + sema.CapabilityTypeBorrowFunctionType(capabilityBorrowSemaType), NativeCapabilityBorrowFunction(&addressValue, &capabilityID, capabilityBorrowType), ) } func CapabilityBorrow( invocationContext InvocationContext, - typeArgument sema.Type, + typeArgument StaticType, addressValue AddressValue, capabilityID UInt64Value, - capabilityBorrowType *sema.ReferenceType, + capabilityBorrowType *ReferenceStaticType, ) Value { if capabilityID == InvalidCapabilityID { return Nil } - var wantedBorrowType *sema.ReferenceType + var wantedBorrowType *ReferenceStaticType if typeArgument != nil { var ok bool - wantedBorrowType, ok = typeArgument.(*sema.ReferenceType) + wantedBorrowType, ok = typeArgument.(*ReferenceStaticType) if !ok { panic(errors.NewUnreachableError()) } @@ -6093,7 +6146,7 @@ func CapabilityBorrow( func NativeCapabilityCheckFunction( addressValuePointer *AddressValue, capabilityIDPointer *UInt64Value, - capabilityBorrowTypePointer *sema.ReferenceType, + capabilityBorrowTypePointer *ReferenceStaticType, ) NativeFunction { return func( context NativeFunctionContext, @@ -6101,7 +6154,7 @@ func NativeCapabilityCheckFunction( receiver Value, args []Value, ) Value { - var capabilityBorrowType *sema.ReferenceType + var capabilityBorrowType *ReferenceStaticType var capabilityID UInt64Value var addressValue AddressValue @@ -6127,7 +6180,7 @@ func NativeCapabilityCheckFunction( return FalseValue } - capabilityBorrowType = context.SemaTypeFromStaticType(idCapabilityValue.BorrowType).(*sema.ReferenceType) + capabilityBorrowType = idCapabilityValue.BorrowType.(*ReferenceStaticType) addressValue = idCapabilityValue.Address() } else { capabilityBorrowType = capabilityBorrowTypePointer @@ -6135,7 +6188,7 @@ func NativeCapabilityCheckFunction( addressValue = *addressValuePointer } - typeArgument := typeArguments.NextSema() + typeArgument := typeArguments.NextStatic() return CapabilityCheck( context, @@ -6152,33 +6205,35 @@ func capabilityCheckFunction( capabilityValue CapabilityValue, addressValue AddressValue, capabilityID UInt64Value, - capabilityBorrowType *sema.ReferenceType, + capabilityBorrowType *ReferenceStaticType, ) FunctionValue { + capabilityBorrowSemaType := context.SemaTypeFromStaticType(capabilityBorrowType) + return NewBoundHostFunctionValue( context, capabilityValue, - sema.CapabilityTypeCheckFunctionType(capabilityBorrowType), + sema.CapabilityTypeCheckFunctionType(capabilityBorrowSemaType), NativeCapabilityCheckFunction(&addressValue, &capabilityID, capabilityBorrowType), ) } func CapabilityCheck( invocationContext InvocationContext, - typeArgument sema.Type, + typeArgument StaticType, addressValue AddressValue, capabilityID UInt64Value, - capabilityBorrowType *sema.ReferenceType, + capabilityBorrowType *ReferenceStaticType, ) Value { if capabilityID == InvalidCapabilityID { return FalseValue } - var wantedBorrowType *sema.ReferenceType + var wantedBorrowType *ReferenceStaticType if typeArgument != nil { var ok bool - wantedBorrowType, ok = typeArgument.(*sema.ReferenceType) + wantedBorrowType, ok = typeArgument.(*ReferenceStaticType) if !ok { panic(errors.NewUnreachableError()) } @@ -6381,6 +6436,10 @@ func (interpreter *Interpreter) SemaTypeFromStaticType(staticType StaticType) se return MustConvertStaticToSemaType(staticType, interpreter) } +func (interpreter *Interpreter) SemaAccessFromStaticAuthorization(auth Authorization) (sema.Access, error) { + return ConvertStaticAuthorizationToSemaAccess(auth, interpreter) +} + func (interpreter *Interpreter) MaybeUpdateStorageReferenceMemberReceiver( storageReference *StorageReferenceValue, referencedValue Value, @@ -6398,10 +6457,6 @@ func (interpreter *Interpreter) MaybeUpdateStorageReferenceMemberReceiver( return member } -func (interpreter *Interpreter) SemaAccessFromStaticAuthorization(auth Authorization) (sema.Access, error) { - return ConvertStaticAuthorizationToSemaAccess(auth, interpreter) -} - func StorageReference( context ValueStaticTypeContext, storageReference *StorageReferenceValue, @@ -6428,6 +6483,6 @@ func StorageReference( storageReference.Authorization, storageReference.TargetStorageAddress, storageReference.TargetPath, - context.SemaTypeFromStaticType(referencedValueStaticType), + referencedValueStaticType, ) } diff --git a/interpreter/interpreter_expression.go b/interpreter/interpreter_expression.go index 316e6d730a..15ed1c5f0d 100644 --- a/interpreter/interpreter_expression.go +++ b/interpreter/interpreter_expression.go @@ -320,54 +320,51 @@ func (interpreter *Interpreter) checkMemberAccess( memberInfo, _ := interpreter.Program.Elaboration.MemberExpressionMemberAccessInfo(memberExpression) expectedType := memberInfo.AccessedType + expectedStaticType := ConvertSemaToStaticType(interpreter, expectedType) + CheckMemberAccessTargetType( interpreter, target, - expectedType, + expectedStaticType, ) } func CheckMemberAccessTargetType( context ValueStaticTypeContext, target Value, - expectedType sema.Type, + expectedType StaticType, ) { - switch expectedType := expectedType.(type) { - case *sema.TransactionType: - // TODO: maybe also check transactions. - // they are composites with a type ID which has an empty qualified ID, i.e. no type is available - - return + targetStaticType := target.StaticType(context) - case *sema.CompositeType: - // TODO: also check built-in values. - // blocked by standard library values (RLP, BLS, etc.), - // which are implemented as contracts, but currently do not have their type registered + switch expectedType := expectedType.(type) { + case *CompositeStaticType: + switch expectedType.Location.(type) { + case nil: + // `location == nil` means this is a built-in type. Skip them for now. + // TODO: also check built-in values. + // blocked by standard library values (RLP, BLS, etc.), + // which are implemented as contracts, but currently do not have their type registered + return - if expectedType.Location == nil { + case common.TransactionLocation: + // Also skip transactions for now. + // TODO: Check transactions. return } - } - targetStaticType := target.StaticType(context) - - if _, ok := expectedType.(*sema.OptionalType); ok { + case *OptionalStaticType: if _, ok := targetStaticType.(*OptionalStaticType); !ok { - targetSemaType := MustConvertStaticToSemaType(targetStaticType, context) - panic(&MemberAccessTypeError{ ExpectedType: expectedType, - ActualType: targetSemaType, + ActualType: targetStaticType, }) } } - if !IsSubTypeOfSemaType(context, targetStaticType, expectedType) { - targetSemaType := MustConvertStaticToSemaType(targetStaticType, context) - + if !IsSubType(context, targetStaticType, expectedType) { panic(&MemberAccessTypeError{ ExpectedType: expectedType, - ActualType: targetSemaType, + ActualType: targetStaticType, }) } } @@ -1248,6 +1245,7 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx castingExpressionTypes := interpreter.Program.Elaboration.CastingExpressionTypes(expression) expectedType := castingExpressionTypes.TargetType + expectedStaticType := ConvertSemaToStaticType(interpreter, expectedType) switch expression.Operation { case ast.OperationFailableCast, ast.OperationForceCast: @@ -1267,9 +1265,9 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx // otherwise dynamic cast now always unboxes optionals value = Unbox(value) } - valueSemaType := MustSemaTypeOfValue(value, interpreter) - valueStaticType := ConvertSemaToStaticType(interpreter, valueSemaType) - isSubType := IsSubTypeOfSemaType(interpreter, valueStaticType, expectedType) + + valueStaticType := value.StaticType(interpreter) + isSubType := IsSubType(interpreter, valueStaticType, expectedStaticType) switch expression.Operation { case ast.OperationFailableCast: @@ -1280,8 +1278,8 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx case ast.OperationForceCast: if !isSubType { panic(&ForceCastTypeMismatchError{ - ExpectedType: expectedType, - ActualType: valueSemaType, + ExpectedType: expectedStaticType, + ActualType: valueStaticType, }) } @@ -1290,7 +1288,7 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx } // The failable cast may upcast to an optional type, e.g. `1 as? Int?`, so box - value = ConvertAndBox(interpreter, value, valueSemaType, expectedType) + value = ConvertAndBoxToStaticType(interpreter, value, valueStaticType, expectedStaticType) if expression.Operation == ast.OperationFailableCast { // Failable casting is a resource invalidation @@ -1302,9 +1300,9 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx return value case ast.OperationCast: - staticValueType := castingExpressionTypes.StaticValueType + staticValueType := ConvertSemaToStaticType(interpreter, castingExpressionTypes.StaticValueType) // The cast may upcast to an optional type, e.g. `1 as Int?`, so box - return ConvertAndBox(interpreter, value, staticValueType, expectedType) + return ConvertAndBoxToStaticType(interpreter, value, staticValueType, expectedStaticType) default: panic(errors.NewUnreachableError()) @@ -1345,6 +1343,21 @@ func CreateReferenceValue( value Value, isImplicit bool, ) Value { + borrowStaticType := ConvertSemaToStaticType(context, borrowType) + return CreateReferenceValueFromStaticType( + context, + borrowStaticType, + value, + isImplicit, + ) +} + +func CreateReferenceValueFromStaticType( + context ReferenceCreationContext, + borrowType StaticType, + value Value, + isImplicit bool, +) Value { // There are four potential cases: // (1) Target type is optional, actual value is also optional @@ -1357,7 +1370,7 @@ func CreateReferenceValue( // (4) Target type is non-optional, actual value is non-optional switch typ := borrowType.(type) { - case *sema.OptionalType: + case *OptionalStaticType: innerType := typ.Type @@ -1371,7 +1384,7 @@ func CreateReferenceValue( innerValue := value.InnerValue() - referenceValue := CreateReferenceValue(context, innerType, innerValue, false) + referenceValue := CreateReferenceValueFromStaticType(context, innerType, innerValue, false) // Wrap the reference with an optional (since an optional is expected). return NewSomeValueNonCopying(context, referenceValue) @@ -1385,20 +1398,20 @@ func CreateReferenceValue( // Case (2): // If the referenced value is non-optional, // but the target type is optional. - referenceValue := CreateReferenceValue(context, innerType, value, false) + referenceValue := CreateReferenceValueFromStaticType(context, innerType, value, false) // Wrap the reference with an optional (since an optional is expected). return NewSomeValueNonCopying(context, referenceValue) } - case *sema.ReferenceType: + case *ReferenceStaticType: switch value := value.(type) { case *SomeValue: // Case (3.a): target type is non-optional, actual value is optional. innerValue := value.InnerValue() - return CreateReferenceValue(context, typ, innerValue, false) + return CreateReferenceValueFromStaticType(context, typ, innerValue, false) case NilValue: // Case (3.b) value is nil. @@ -1416,10 +1429,10 @@ func CreateReferenceValue( // Additionally, it is only safe to "compress" reference types like this when the desired // result reference type is unauthorized staticType := value.StaticType(context) - if typ.Authorization != sema.UnauthorizedAccess || !IsSubTypeOfSemaType(context, staticType, typ) { + if typ.Authorization != UnauthorizedAccess || !IsSubType(context, staticType, typ) { panic(&InvalidMemberReferenceError{ ExpectedType: typ, - ActualType: MustConvertStaticToSemaType(staticType, context), + ActualType: staticType, }) } @@ -1440,16 +1453,13 @@ func CreateReferenceValue( func newEphemeralReference( context ReferenceCreationContext, value Value, - typ *sema.ReferenceType, + typ *ReferenceStaticType, ) *EphemeralReferenceValue { - - auth := ConvertSemaAccessToStaticAuthorization(context, typ.Authorization) - return NewEphemeralReferenceValue( context, - auth, + typ.Authorization, value, - typ.Type, + typ.ReferencedType, ) } @@ -1511,11 +1521,13 @@ func (interpreter *Interpreter) VisitAttachExpression(attachExpression *ast.Atta attachmentType := interpreter.Program.Elaboration.AttachTypes(attachExpression) + baseStaticType := base.StaticType(interpreter) + baseValue := NewEphemeralReferenceValue( interpreter, auth, base, - MustSemaTypeOfValue(base, interpreter).(*sema.CompositeType), + baseStaticType, ) attachment, ok := interpreter.visitInvocationExpressionWithImplicitArgument( diff --git a/interpreter/interpreter_import.go b/interpreter/interpreter_import.go index 079fe37e21..47743a0184 100644 --- a/interpreter/interpreter_import.go +++ b/interpreter/interpreter_import.go @@ -106,16 +106,12 @@ func (interpreter *Interpreter) importResolvedLocation(resolvedLocation sema.Res } staticType := compositeValue.StaticType(interpreter) - semaType, err := ConvertStaticToSemaType(interpreter, staticType) - if err != nil { - panic(err) - } return NewEphemeralReferenceValue( interpreter, UnauthorizedAccess, compositeValue, - semaType, + staticType, ) } diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index db6f2813a1..b15772cec5 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -113,7 +113,9 @@ func TestInterpreterOptionalBoxing(t *testing.T) { value := BoxOptional( inter, TrueValue, - &sema.OptionalType{Type: sema.BoolType}, + &OptionalStaticType{ + Type: PrimitiveStaticTypeBool, + }, ) assert.Equal(t, NewUnmeteredSomeValueNonCopying(TrueValue), @@ -127,7 +129,9 @@ func TestInterpreterOptionalBoxing(t *testing.T) { value := BoxOptional( inter, NewUnmeteredSomeValueNonCopying(TrueValue), - &sema.OptionalType{Type: sema.BoolType}, + &OptionalStaticType{ + Type: PrimitiveStaticTypeBool, + }, ) assert.Equal(t, NewUnmeteredSomeValueNonCopying(TrueValue), @@ -141,9 +145,9 @@ func TestInterpreterOptionalBoxing(t *testing.T) { value := BoxOptional( inter, NewUnmeteredSomeValueNonCopying(TrueValue), - &sema.OptionalType{ - Type: &sema.OptionalType{ - Type: sema.BoolType, + &OptionalStaticType{ + Type: &OptionalStaticType{ + Type: PrimitiveStaticTypeBool, }, }, ) @@ -162,9 +166,9 @@ func TestInterpreterOptionalBoxing(t *testing.T) { value := BoxOptional( inter, Nil, - &sema.OptionalType{ - Type: &sema.OptionalType{ - Type: sema.BoolType, + &OptionalStaticType{ + Type: &OptionalStaticType{ + Type: PrimitiveStaticTypeBool, }, }, ) @@ -181,9 +185,9 @@ func TestInterpreterOptionalBoxing(t *testing.T) { value := BoxOptional( inter, NewUnmeteredSomeValueNonCopying(Nil), - &sema.OptionalType{ - Type: &sema.OptionalType{ - Type: sema.BoolType, + &OptionalStaticType{ + Type: &OptionalStaticType{ + Type: PrimitiveStaticTypeBool, }, }, ) diff --git a/interpreter/interpreter_transaction.go b/interpreter/interpreter_transaction.go index cd5752e411..96c9021b67 100644 --- a/interpreter/interpreter_transaction.go +++ b/interpreter/interpreter_transaction.go @@ -51,8 +51,13 @@ func (interpreter *Interpreter) declareTransactionEntryPoint(declaration *ast.Tr postConditionsRewrite := interpreter.Program.Elaboration.PostConditionsRewrite(declaration.PostConditions) - const qualifiedIdentifier = "" - staticType := NewCompositeStaticTypeComputeTypeID(interpreter, interpreter.Location, qualifiedIdentifier) + staticType := NewCompositeStaticTypeComputeTypeID( + interpreter, + interpreter.Location, + + // This is to be consistent with `sema.TransactionType` + sema.TransactionTypeName, + ) self := NewSimpleCompositeValue( interpreter, diff --git a/interpreter/member_test.go b/interpreter/member_test.go index 7c95e9f48d..dc603e3de3 100644 --- a/interpreter/member_test.go +++ b/interpreter/member_test.go @@ -422,12 +422,13 @@ func TestInterpretMemberAccessType(t *testing.T) { require.NoError(t, err) sType := RequireGlobalType(t, inter, "S") + sStaticType := interpreter.ConvertSemaToStaticType(nil, sType) ref := interpreter.NewUnmeteredEphemeralReferenceValue( inter, interpreter.UnauthorizedAccess, value, - sType, + sStaticType, ) _, err = inter.Invoke("get", ref) @@ -474,12 +475,13 @@ func TestInterpretMemberAccessType(t *testing.T) { require.NoError(t, err) sType := RequireGlobalType(t, inter, "S") + sStaticType := interpreter.ConvertSemaToStaticType(nil, sType) ref := interpreter.NewUnmeteredEphemeralReferenceValue( inter, interpreter.UnauthorizedAccess, value, - sType, + sStaticType, ) _, err = inter.Invoke("get", ref) @@ -521,12 +523,13 @@ func TestInterpretMemberAccessType(t *testing.T) { require.NoError(t, err) sType := RequireGlobalType(t, inter, "S") + sStaticType := interpreter.ConvertSemaToStaticType(nil, sType) ref := interpreter.NewUnmeteredEphemeralReferenceValue( inter, interpreter.UnauthorizedAccess, value, - sType, + sStaticType, ) _, err = inter.Invoke( @@ -571,12 +574,13 @@ func TestInterpretMemberAccessType(t *testing.T) { require.NoError(t, err) sType := RequireGlobalType(t, inter, "S") + sStaticType := interpreter.ConvertSemaToStaticType(nil, sType) ref := interpreter.NewUnmeteredEphemeralReferenceValue( inter, interpreter.UnauthorizedAccess, value, - sType, + sStaticType, ) _, err = inter.Invoke( diff --git a/interpreter/memory_metering_test.go b/interpreter/memory_metering_test.go index a256cd058d..ec4ea2fd06 100644 --- a/interpreter/memory_metering_test.go +++ b/interpreter/memory_metering_test.go @@ -176,8 +176,8 @@ func TestInterpretMemoryMeteringArray(t *testing.T) { // 1 Int8 for type // 2 String: 1 for type, 1 for value // 3 Bool: 1 for type, 2 for value - assert.Equal(t, uint64(6), meter.getMemory(common.MemoryKindPrimitiveStaticType)) - assert.Equal(t, uint64(10), meter.getMemory(common.MemoryKindVariableSizedStaticType)) + assert.Equal(t, uint64(18), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, uint64(30), meter.getMemory(common.MemoryKindVariableSizedStaticType)) } }) @@ -210,8 +210,8 @@ func TestInterpretMemoryMeteringArray(t *testing.T) { assert.Equal(t, uint64(8), meter.getMemory(common.MemoryKindVariable)) // 4 Int8: 1 for type, 3 for values - assert.Equal(t, uint64(4), meter.getMemory(common.MemoryKindPrimitiveStaticType)) - assert.Equal(t, uint64(5), meter.getMemory(common.MemoryKindVariableSizedStaticType)) + assert.Equal(t, uint64(18), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, uint64(21), meter.getMemory(common.MemoryKindVariableSizedStaticType)) } }) @@ -236,7 +236,7 @@ func TestInterpretMemoryMeteringArray(t *testing.T) { assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindAtreeArrayDataSlab)) assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindAtreeArrayMetaDataSlab)) assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindAtreeArrayElementOverhead)) - assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, ifCompile[uint64](2, 9), meter.getMemory(common.MemoryKindPrimitiveStaticType)) }) t.Run("append with packing", func(t *testing.T) { @@ -366,11 +366,11 @@ func TestInterpretMemoryMeteringArray(t *testing.T) { assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindAtreeArrayMetaDataSlab)) assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindAtreeArrayElementOverhead)) - assert.Equal(t, ifCompile[uint64](10, 7), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, ifCompile[uint64](10, 23), meter.getMemory(common.MemoryKindPrimitiveStaticType)) // TODO: assert equivalent for compiler/VM if !*compile { - assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindVariableSizedStaticType)) + assert.Equal(t, uint64(5), meter.getMemory(common.MemoryKindVariableSizedStaticType)) } }) @@ -451,7 +451,7 @@ func TestInterpretMemoryMeteringArray(t *testing.T) { // TODO: assert equivalent for compiler/VM if !*compile { - assert.Equal(t, uint64(12), meter.getMemory(common.MemoryKindConstantSizedStaticType)) + assert.Equal(t, uint64(36), meter.getMemory(common.MemoryKindConstantSizedStaticType)) } }) @@ -485,11 +485,11 @@ func TestInterpretMemoryMeteringArray(t *testing.T) { // 1 Int8 for `w` element // 2 Int8 for `r` elements // 2 Int8 for `q` elements - assert.Equal(t, ifCompile[uint64](30, 19), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, ifCompile[uint64](30, 63), meter.getMemory(common.MemoryKindPrimitiveStaticType)) // TODO: assert equivalent for compiler/VM if !*compile { - assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindVariableSizedStaticType)) + assert.Equal(t, uint64(9), meter.getMemory(common.MemoryKindVariableSizedStaticType)) } }) } @@ -519,12 +519,12 @@ func TestInterpretMemoryMeteringDictionary(t *testing.T) { assert.Equal(t, uint64(8), meter.getMemory(common.MemoryKindAtreeMapDataSlab)) assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindAtreeMapMetaDataSlab)) assert.Equal(t, uint64(159), meter.getMemory(common.MemoryKindAtreeMapPreAllocatedElement)) - assert.Equal(t, ifCompile[uint64](3, 9), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, ifCompile[uint64](3, 25), meter.getMemory(common.MemoryKindPrimitiveStaticType)) // TODO: assert equivalent for compiler/VM if !*compile { assert.Equal(t, uint64(3), meter.getMemory(common.MemoryKindVariable)) - assert.Equal(t, uint64(4), meter.getMemory(common.MemoryKindDictionaryStaticType)) + assert.Equal(t, uint64(12), meter.getMemory(common.MemoryKindDictionaryStaticType)) } }) @@ -556,8 +556,8 @@ func TestInterpretMemoryMeteringDictionary(t *testing.T) { // 4 Int8: 1 for type, 3 for values // 4 String: 1 for type, 3 for values - assert.Equal(t, uint64(8), meter.getMemory(common.MemoryKindPrimitiveStaticType)) - assert.Equal(t, uint64(4), meter.getMemory(common.MemoryKindDictionaryStaticType)) + assert.Equal(t, uint64(36), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, uint64(18), meter.getMemory(common.MemoryKindDictionaryStaticType)) } }) @@ -579,7 +579,7 @@ func TestInterpretMemoryMeteringDictionary(t *testing.T) { _, err = inter.Invoke("main") require.NoError(t, err) - assert.Equal(t, ifCompile[uint64](2, 3), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, ifCompile[uint64](2, 13), meter.getMemory(common.MemoryKindPrimitiveStaticType)) }) t.Run("insert", func(t *testing.T) { @@ -606,11 +606,11 @@ func TestInterpretMemoryMeteringDictionary(t *testing.T) { assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindAtreeMapMetaDataSlab)) assert.Equal(t, uint64(32), meter.getMemory(common.MemoryKindAtreeMapPreAllocatedElement)) - assert.Equal(t, ifCompile[uint64](12, 10), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, ifCompile[uint64](12, 30), meter.getMemory(common.MemoryKindPrimitiveStaticType)) // TODO: assert equivalent for compiler/VM if !*compile { - assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindDictionaryStaticType)) + assert.Equal(t, uint64(5), meter.getMemory(common.MemoryKindDictionaryStaticType)) } }) @@ -769,7 +769,7 @@ func TestInterpretMemoryMeteringComposite(t *testing.T) { assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindAtreeMapMetaDataSlab)) assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindAtreeMapElementOverhead)) assert.Equal(t, uint64(32), meter.getMemory(common.MemoryKindAtreeMapPreAllocatedElement)) - assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindCompositeStaticType)) + assert.Equal(t, ifCompile[uint64](2, 12), meter.getMemory(common.MemoryKindCompositeStaticType)) assert.Equal(t, uint64(4), meter.getMemory(common.MemoryKindCompositeTypeInfo)) // TODO: assert equivalent for compiler/VM @@ -806,7 +806,7 @@ func TestInterpretMemoryMeteringComposite(t *testing.T) { assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindAtreeMapElementOverhead)) assert.Equal(t, uint64(480), meter.getMemory(common.MemoryKindAtreeMapPreAllocatedElement)) - assert.Equal(t, ifCompile[uint64](6, 7), meter.getMemory(common.MemoryKindCompositeStaticType)) + assert.Equal(t, ifCompile[uint64](6, 27), meter.getMemory(common.MemoryKindCompositeStaticType)) assert.Equal(t, uint64(18), meter.getMemory(common.MemoryKindCompositeTypeInfo)) assert.Equal(t, uint64(0), meter.getMemory(common.MemoryKindCompositeField)) @@ -1485,11 +1485,11 @@ func TestInterpretMemoryMeteringOptionalValue(t *testing.T) { // 2 for `z` assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindOptionalValue)) - assert.Equal(t, ifCompile[uint64](20, 14), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, ifCompile[uint64](20, 34), meter.getMemory(common.MemoryKindPrimitiveStaticType)) // TODO: assert equivalent for compiler/VM if !*compile { - assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindDictionaryStaticType)) + assert.Equal(t, uint64(3), meter.getMemory(common.MemoryKindDictionaryStaticType)) } }) @@ -9048,7 +9048,7 @@ func TestInterpretMemoryMeteringIdentifier(t *testing.T) { _, err = inter.Invoke("main") require.NoError(t, err) assert.Equal(t, uint64(14), meter.getMemory(common.MemoryKindIdentifier)) - assert.Equal(t, ifCompile[uint64](4, 3), meter.getMemory(common.MemoryKindPrimitiveStaticType)) + assert.Equal(t, ifCompile[uint64](4, 17), meter.getMemory(common.MemoryKindPrimitiveStaticType)) }) } @@ -9127,7 +9127,7 @@ func TestInterpretFunctionStaticType(t *testing.T) { // TODO: assert equivalent for compiler/VM if !*compile { - assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindFunctionStaticType)) + assert.Equal(t, uint64(6), meter.getMemory(common.MemoryKindFunctionStaticType)) } }) @@ -9152,7 +9152,7 @@ func TestInterpretFunctionStaticType(t *testing.T) { _, err = inter.Invoke("main") require.NoError(t, err) - assert.Equal(t, ifCompile[uint64](2, 1), meter.getMemory(common.MemoryKindFunctionStaticType)) + assert.Equal(t, ifCompile[uint64](2, 3), meter.getMemory(common.MemoryKindFunctionStaticType)) }) t.Run("isInstance", func(t *testing.T) { @@ -9178,7 +9178,7 @@ func TestInterpretFunctionStaticType(t *testing.T) { assert.Equal( t, - ifCompile[uint64](2, 3), + ifCompile[uint64](2, 4), meter.getMemory(common.MemoryKindFunctionStaticType), ) }) @@ -9287,13 +9287,13 @@ func TestInterpretMemoryMeteringStaticTypeConversion(t *testing.T) { _, err = inter.Invoke("main") require.NoError(t, err) - assert.Equal(t, ifCompile[uint64](1, 2), meter.getMemory(common.MemoryKindDictionarySemaType)) - assert.Equal(t, ifCompile[uint64](2, 4), meter.getMemory(common.MemoryKindVariableSizedSemaType)) - assert.Equal(t, ifCompile[uint64](1, 2), meter.getMemory(common.MemoryKindConstantSizedSemaType)) - assert.Equal(t, ifCompile[uint64](2, 3), meter.getMemory(common.MemoryKindIntersectionSemaType)) - assert.Equal(t, ifCompile[uint64](2, 4), meter.getMemory(common.MemoryKindReferenceSemaType)) - assert.Equal(t, ifCompile[uint64](1, 2), meter.getMemory(common.MemoryKindCapabilitySemaType)) - assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindOptionalSemaType)) + assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindDictionarySemaType)) + assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindVariableSizedSemaType)) + assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindConstantSizedSemaType)) + assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindIntersectionSemaType)) + assert.Equal(t, uint64(2), meter.getMemory(common.MemoryKindReferenceSemaType)) + assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindCapabilitySemaType)) + assert.Equal(t, uint64(1), meter.getMemory(common.MemoryKindOptionalSemaType)) }) } diff --git a/interpreter/misc_test.go b/interpreter/misc_test.go index 5017c853ed..c7c4bdc635 100644 --- a/interpreter/misc_test.go +++ b/interpreter/misc_test.go @@ -4965,15 +4965,16 @@ func TestInterpretReferenceFailableDowncasting(t *testing.T) { ) } - riType := getType("RI").(*sema.InterfaceType) + riType := getType("RI") + riStaticType := interpreter.ConvertSemaToStaticType(nil, riType).(*interpreter.InterfaceStaticType) return &interpreter.StorageReferenceValue{ Authorization: auth, TargetStorageAddress: storageAddress, TargetPath: storagePath, - BorrowedType: &sema.IntersectionType{ - Types: []*sema.InterfaceType{ - riType, + BorrowedType: &interpreter.IntersectionStaticType{ + Types: []*interpreter.InterfaceStaticType{ + riStaticType, }, }, } @@ -7438,7 +7439,7 @@ func TestInterpretReferenceEventParameter(t *testing.T) { valueCreationContext, interpreter.UnauthorizedAccess, arrayValue, - interpreter.MustConvertStaticToSemaType(arrayStaticType, inter), + arrayStaticType, ) _, err = inter.Invoke("test", ref) @@ -11876,7 +11877,7 @@ func TestInterpretNilCoalesceReference(t *testing.T) { t, &interpreter.EphemeralReferenceValue{ Value: interpreter.NewUnmeteredIntValueFromInt64(2), - BorrowedType: sema.IntType, + BorrowedType: interpreter.PrimitiveStaticTypeInt, Authorization: interpreter.UnauthorizedAccess, }, variable, diff --git a/interpreter/reference_test.go b/interpreter/reference_test.go index c8dd4fb4b8..4802cbf8d7 100644 --- a/interpreter/reference_test.go +++ b/interpreter/reference_test.go @@ -637,6 +637,7 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { address := common.Address{0x1} rType := RequireGlobalType(t, inter, "R").(*sema.CompositeType) + rStaticType := interpreter.ConvertSemaToStaticType(nil, rType) array := interpreter.NewArrayValue( inter, @@ -655,8 +656,8 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { sema.Conjunction, ), array, - &sema.VariableSizedType{ - Type: rType, + &interpreter.VariableSizedStaticType{ + Type: rStaticType, }, ) @@ -742,13 +743,14 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { `) rType := RequireGlobalType(t, inter, "R").(*sema.CompositeType) + rStaticType := interpreter.ConvertSemaToStaticType(nil, rType) // Resource array in account 0x01 array1 := interpreter.NewArrayValue( inter, &interpreter.VariableSizedStaticType{ - Type: interpreter.ConvertSemaToStaticType(nil, rType), + Type: rStaticType, }, common.Address{0x1}, ) @@ -762,8 +764,8 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { sema.Conjunction, ), array1, - &sema.VariableSizedType{ - Type: rType, + &interpreter.VariableSizedStaticType{ + Type: rStaticType, }, ) @@ -772,7 +774,7 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { array2 := interpreter.NewArrayValue( inter, &interpreter.VariableSizedStaticType{ - Type: interpreter.ConvertSemaToStaticType(nil, rType), + Type: rStaticType, }, common.Address{0x2}, ) @@ -786,8 +788,8 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { sema.Conjunction, ), array2, - &sema.VariableSizedType{ - Type: rType, + &interpreter.VariableSizedStaticType{ + Type: rStaticType, }, ) @@ -840,6 +842,7 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { address := common.Address{0x1} rType := RequireGlobalType(t, inter, "R").(*sema.CompositeType) + rStaticType := interpreter.ConvertSemaToStaticType(nil, rType) array := interpreter.NewArrayValue( inter, @@ -858,8 +861,8 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { sema.Conjunction, ), array, - &sema.VariableSizedType{ - Type: rType, + &interpreter.VariableSizedStaticType{ + Type: rStaticType, }, ) @@ -969,6 +972,7 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { address := common.Address{0x1} rType := RequireGlobalType(t, inter, "R").(*sema.CompositeType) + rStaticType := interpreter.ConvertSemaToStaticType(nil, rType) array := interpreter.NewArrayValue( inter, @@ -987,8 +991,8 @@ func TestInterpretResourceReferenceInvalidationOnMove(t *testing.T) { sema.Conjunction, ), array, - &sema.VariableSizedType{ - Type: rType, + &interpreter.VariableSizedStaticType{ + Type: rStaticType, }, ) @@ -3144,7 +3148,7 @@ func TestInterpretOptionalReference(t *testing.T) { t, &interpreter.EphemeralReferenceValue{ Value: interpreter.NewUnmeteredIntValueFromInt64(1), - BorrowedType: sema.IntType, + BorrowedType: interpreter.PrimitiveStaticTypeInt, Authorization: interpreter.UnauthorizedAccess, }, value, diff --git a/interpreter/simplecompositevalue.go b/interpreter/simplecompositevalue.go index 2b7a915713..5ba6abead1 100644 --- a/interpreter/simplecompositevalue.go +++ b/interpreter/simplecompositevalue.go @@ -118,7 +118,7 @@ func (v *SimpleCompositeValue) StaticType(_ ValueStaticTypeContext) StaticType { func (v *SimpleCompositeValue) IsImportable(context ValueImportableContext) bool { // Check type is importable staticType := v.StaticType(context) - semaType := MustConvertStaticToSemaType(staticType, context) + semaType := context.SemaTypeFromStaticType(staticType) if !semaType.IsImportable(map[*sema.Member]bool{}) { return false } diff --git a/interpreter/statictype.go b/interpreter/statictype.go index 419da3ee33..23a47edc15 100644 --- a/interpreter/statictype.go +++ b/interpreter/statictype.go @@ -342,6 +342,10 @@ func (t InclusiveRangeStaticType) Equal(other StaticType) bool { return false } + if t.ElementType == nil { + return otherRangeType.ElementType == nil + } + return t.ElementType.Equal(otherRangeType.ElementType) } @@ -951,7 +955,7 @@ var _ ParameterizedStaticType = &CapabilityStaticType{} func NewCapabilityStaticType( memoryGauge common.MemoryGauge, borrowType StaticType, -) *CapabilityStaticType { +) StaticType { common.UseMemory(memoryGauge, common.CapabilityStaticTypeMemoryUsage) return &CapabilityStaticType{ @@ -1018,7 +1022,7 @@ func (t *CapabilityStaticType) BaseType() StaticType { return nil } - return PrimitiveStaticTypeCapability + return &CapabilityStaticType{} } func (t *CapabilityStaticType) TypeArguments() []StaticType { @@ -1258,7 +1262,7 @@ func ConvertStaticAuthorizationToSemaAccess( return sema.NewEntitlementMapAccess(entitlement), nil case EntitlementSetAuthorization: - var entitlements []*sema.EntitlementType + entitlements := make([]*sema.EntitlementType, 0, auth.Entitlements.Len()) err := auth.Entitlements.ForeachWithError(func(id common.TypeID, value struct{}) error { entitlement, err := handler.GetEntitlementType(id) if err != nil { diff --git a/interpreter/storage_test.go b/interpreter/storage_test.go index 316272a0a3..bd71a1dc37 100644 --- a/interpreter/storage_test.go +++ b/interpreter/storage_test.go @@ -553,6 +553,8 @@ func TestNestedContainerMutationAfterMove(t *testing.T) { Members: &sema.StringMemberOrderedMap{}, } + testResourcStaticType := ConvertSemaToStaticType(nil, testResourceType) + const fieldName = "test" for _, testCompositeType := range []*sema.CompositeType{ @@ -826,7 +828,7 @@ func TestNestedContainerMutationAfterMove(t *testing.T) { inter, UnauthorizedAccess, childValue1, - testResourceType, + testResourcStaticType, ) containerValue1.Append(inter, childValue1) diff --git a/interpreter/subtype_check.go b/interpreter/subtype_check.go index 3e35f8447f..970ac18559 100644 --- a/interpreter/subtype_check.go +++ b/interpreter/subtype_check.go @@ -27,6 +27,23 @@ import ( var FunctionPurityView = sema.FunctionPurityView +// IsSameTypeKind determines if the given subtype belongs to the +// same kind as the supertype. +// +// e.g: 'Never' type is a subtype of 'Integer', but not of the +// same kind as 'Integer'. Whereas, 'Int8' is both a subtype +// and also of same kind as 'Integer'. +// +// Note: Must be equivalent to `sema.IsSameTypeKind` method. +func IsSameTypeKind(context TypeConverter, subType StaticType, superType StaticType) bool { + + if subType == PrimitiveStaticTypeNever { + return false + } + + return IsSubType(context, subType, superType) +} + func isAttachmentType(typeConverter TypeConverter, typ StaticType) bool { switch typ { case PrimitiveStaticTypeAnyResourceAttachment, PrimitiveStaticTypeAnyStructAttachment: diff --git a/interpreter/value_accountcapabilitycontroller.go b/interpreter/value_accountcapabilitycontroller.go index d16ec5ea78..c2c333a150 100644 --- a/interpreter/value_accountcapabilitycontroller.go +++ b/interpreter/value_accountcapabilitycontroller.go @@ -296,7 +296,7 @@ func (v *AccountCapabilityControllerValue) ControllerCapabilityID() UInt64Value func (v *AccountCapabilityControllerValue) ReferenceValue( context ValueCapabilityControllerReferenceValueContext, capabilityAddress common.Address, - resultBorrowType *sema.ReferenceType, + resultBorrowType *ReferenceStaticType, ) ReferenceValue { accountHandler := context.GetAccountHandlerFunc() @@ -306,18 +306,14 @@ func (v *AccountCapabilityControllerValue) ReferenceValue( ExpectType( context, account, - sema.AccountType, + PrimitiveStaticTypeAccount, ) - authorization := ConvertSemaAccessToStaticAuthorization( - context, - resultBorrowType.Authorization, - ) return NewEphemeralReferenceValue( context, - authorization, + resultBorrowType.Authorization, account, - resultBorrowType.Type, + resultBorrowType.ReferencedType, ) } diff --git a/interpreter/value_capability.go b/interpreter/value_capability.go index 11658c2cf9..361abbee19 100644 --- a/interpreter/value_capability.go +++ b/interpreter/value_capability.go @@ -161,12 +161,12 @@ func (v *IDCapabilityValue) GetMethod(context MemberAccessibleContext, name stri switch name { case sema.CapabilityTypeBorrowFunctionName: // this function will panic already if this conversion fails - borrowType, _ := MustConvertStaticToSemaType(v.BorrowType, context).(*sema.ReferenceType) + borrowType, _ := v.BorrowType.(*ReferenceStaticType) return capabilityBorrowFunction(context, v, v.address, v.ID, borrowType) case sema.CapabilityTypeCheckFunctionName: // this function will panic already if this conversion fails - borrowType, _ := MustConvertStaticToSemaType(v.BorrowType, context).(*sema.ReferenceType) + borrowType, _ := v.BorrowType.(*ReferenceStaticType) return capabilityCheckFunction(context, v, v.address, v.ID, borrowType) } diff --git a/interpreter/value_composite.go b/interpreter/value_composite.go index 23e0e8a579..18f65973bd 100644 --- a/interpreter/value_composite.go +++ b/interpreter/value_composite.go @@ -662,14 +662,14 @@ func (v *CompositeValue) OwnerValue(context MemberAccessibleContext) OptionalVal ExpectType( context, ownerAccount, - sema.AccountType, + PrimitiveStaticTypeAccount, ) reference := NewEphemeralReferenceValue( context, UnauthorizedAccess, ownerAccount, - sema.AccountType, + PrimitiveStaticTypeAccount, ) return NewSomeValueNonCopying(context, reference) @@ -1788,7 +1788,8 @@ func (v *CompositeValue) getBaseValue( baseType = ty } - return NewEphemeralReferenceValue(context, functionAuthorization, v.base, baseType) + baseStaticType := ConvertSemaToStaticType(context, baseType) + return NewEphemeralReferenceValue(context, functionAuthorization, v.base, baseStaticType) } func (v *CompositeValue) SetBaseValue(base *CompositeValue) { @@ -1858,14 +1859,15 @@ func (v *CompositeValue) ForEachAttachment( returnType := functionValueType.ReturnTypeAnnotation.Type fn := func(attachment *CompositeValue) { - attachmentType := MustSemaTypeOfValue(attachment, context).(*sema.CompositeType) + attachmentStaticType := attachment.StaticType(context) + attachmentType := context.SemaTypeFromStaticType(attachmentStaticType).(*sema.CompositeType) attachmentReference := NewEphemeralReferenceValue( context, // attachments are unauthorized during iteration UnauthorizedAccess, attachment, - attachmentType, + attachmentStaticType, ) referenceType := sema.NewReferenceType( @@ -1897,12 +1899,14 @@ func AttachmentBaseAndSelfValues( attachmentReferenceAuth := ConvertSemaAccessToStaticAuthorization(context, fnAccess) base = v.getBaseValue(context, attachmentReferenceAuth) + valueStaticType := v.StaticType(context) + // in attachment functions, self is a reference value self = NewEphemeralReferenceValue( context, attachmentReferenceAuth, v, - MustSemaTypeOfValue(v, context), + valueStaticType, ) return @@ -1915,7 +1919,7 @@ func (v *CompositeValue) forEachAttachment( // The attachment iteration creates an implicit reference to the composite, and holds onto that referenced-value. // But the reference could get invalidated during the iteration, making that referenced-value invalid. // We create a reference here for the purposes of tracking it during iteration. - vType := MustSemaTypeOfValue(v, context) + vType := v.StaticType(context) compositeReference := NewEphemeralReferenceValue(context, UnauthorizedAccess, v, vType) forEachAttachment(context, compositeReference, f) } @@ -1981,12 +1985,14 @@ func (v *CompositeValue) getTypeKey( // dynamically set the attachment's base to this composite attachment.SetBaseValue(v) + attachmentStaticType := ConvertSemaToStaticType(context, attachmentType) + // The attachment reference has the same entitlements as the base access attachmentRef := NewEphemeralReferenceValue( context, ConvertSemaAccessToStaticAuthorization(context, baseAccess), attachment, - attachmentType, + attachmentStaticType, ) return NewSomeValueNonCopying(context, attachmentRef) diff --git a/interpreter/value_ephemeral_reference.go b/interpreter/value_ephemeral_reference.go index 50c30b5ba8..d37e6cd8e0 100644 --- a/interpreter/value_ephemeral_reference.go +++ b/interpreter/value_ephemeral_reference.go @@ -31,7 +31,7 @@ import ( type EphemeralReferenceValue struct { Value Value // BorrowedType is the T in &T - BorrowedType sema.Type + BorrowedType StaticType Authorization Authorization } @@ -48,7 +48,7 @@ func NewUnmeteredEphemeralReferenceValue( referenceTracker ReferenceTracker, authorization Authorization, value Value, - borrowedType sema.Type, + borrowedType StaticType, ) *EphemeralReferenceValue { if reference, isReference := value.(ReferenceValue); isReference { panic(&NestedReferenceError{ @@ -71,7 +71,7 @@ func NewEphemeralReferenceValue( context ReferenceCreationContext, authorization Authorization, value Value, - borrowedType sema.Type, + borrowedType StaticType, ) *EphemeralReferenceValue { common.UseMemory(context, common.EphemeralReferenceValueMemoryUsage) return NewUnmeteredEphemeralReferenceValue(context, authorization, value, borrowedType) @@ -186,7 +186,6 @@ func (v *EphemeralReferenceValue) GetTypeKey(context MemberAccessibleContext, ke self := v.Value if selfComposite, isComposite := self.(*CompositeValue); isComposite { - semaAccess, err := context.SemaAccessFromStaticAuthorization(v.Authorization) if err != nil { panic(err) @@ -237,7 +236,7 @@ func (v *EphemeralReferenceValue) ConformsToStaticType( staticType := v.Value.StaticType(context) - if !IsSubTypeOfSemaType(context, staticType, v.BorrowedType) { + if !IsSubType(context, staticType, v.BorrowedType) { return false } @@ -321,7 +320,7 @@ func (v *EphemeralReferenceValue) ForEach( ) } -func (v *EphemeralReferenceValue) BorrowType() sema.Type { +func (v *EphemeralReferenceValue) BorrowType() StaticType { return v.BorrowedType } diff --git a/interpreter/value_function.go b/interpreter/value_function.go index de9e11afab..452226ccb2 100644 --- a/interpreter/value_function.go +++ b/interpreter/value_function.go @@ -367,14 +367,14 @@ func ReceiverReference(context ReferenceCreationContext, receiver Value) (Refere selfRef, selfIsRef := receiver.(ReferenceValue) if !selfIsRef { - semaType := MustSemaTypeOfValue(receiver, context) + receiverType := receiver.StaticType(context) // Create an unauthorized reference. The purpose of it is only to track and invalidate resource moves, // it is not directly exposed to the users selfRef = NewEphemeralReferenceValue( context, UnauthorizedAccess, receiver, - semaType, + receiverType, ) } return selfRef, selfIsRef diff --git a/interpreter/value_reference.go b/interpreter/value_reference.go index 37b0a8ce57..db7a7373f3 100644 --- a/interpreter/value_reference.go +++ b/interpreter/value_reference.go @@ -22,7 +22,6 @@ import ( "github.com/onflow/atree" "github.com/onflow/cadence/errors" - "github.com/onflow/cadence/sema" ) type ReferenceValue interface { @@ -30,7 +29,7 @@ type ReferenceValue interface { AuthorizedValue isReference() ReferencedValue(context ValueStaticTypeContext, errorOnFailedDereference bool) *Value - BorrowType() sema.Type + BorrowType() StaticType } func DereferenceValue( diff --git a/interpreter/value_storage_reference.go b/interpreter/value_storage_reference.go index a2d33c943a..b92be3c84f 100644 --- a/interpreter/value_storage_reference.go +++ b/interpreter/value_storage_reference.go @@ -29,7 +29,7 @@ import ( // StorageReferenceValue type StorageReferenceValue struct { - BorrowedType sema.Type + BorrowedType StaticType TargetPath PathValue TargetStorageAddress common.Address Authorization Authorization @@ -48,7 +48,7 @@ func NewUnmeteredStorageReferenceValue( authorization Authorization, targetStorageAddress common.Address, targetPath PathValue, - borrowedType sema.Type, + borrowedType StaticType, ) *StorageReferenceValue { return &StorageReferenceValue{ Authorization: authorization, @@ -63,7 +63,7 @@ func NewStorageReferenceValue( authorization Authorization, targetStorageAddress common.Address, targetPath PathValue, - borrowedType sema.Type, + borrowedType StaticType, ) *StorageReferenceValue { common.UseMemory(memoryGauge, common.StorageReferenceValueMemoryUsage) return NewUnmeteredStorageReferenceValue( @@ -140,12 +140,14 @@ func (v *StorageReferenceValue) dereference(context ValueStaticTypeContext) (*Va if v.BorrowedType != nil { staticType := referenced.StaticType(context) - if !IsSubTypeOfSemaType(context, staticType, v.BorrowedType) { - semaType := context.SemaTypeFromStaticType(staticType) + // Try to convert the static-type to sema-type, to see if the type is broken. + // This is unfortunately needed only to maintain backward compatibility. + _ = context.SemaTypeFromStaticType(staticType) + if !IsSubType(context, staticType, v.BorrowedType) { return nil, &StoredValueTypeMismatchError{ ExpectedType: v.BorrowedType, - ActualType: semaType, + ActualType: staticType, } } } @@ -346,7 +348,7 @@ func (v *StorageReferenceValue) ConformsToStaticType( staticType := self.StaticType(context) - if !IsSubTypeOfSemaType(context, staticType, v.BorrowedType) { + if !IsSubType(context, staticType, v.BorrowedType) { return false } @@ -462,7 +464,7 @@ func forEachReference( ) } -func (v *StorageReferenceValue) BorrowType() sema.Type { +func (v *StorageReferenceValue) BorrowType() StaticType { return v.BorrowedType } diff --git a/interpreter/value_storagecapabilitycontroller.go b/interpreter/value_storagecapabilitycontroller.go index a3c76c6478..789fb46b61 100644 --- a/interpreter/value_storagecapabilitycontroller.go +++ b/interpreter/value_storagecapabilitycontroller.go @@ -35,7 +35,7 @@ type CapabilityControllerValue interface { ReferenceValue( context ValueCapabilityControllerReferenceValueContext, capabilityAddress common.Address, - resultBorrowType *sema.ReferenceType, + resultBorrowType *ReferenceStaticType, ) ReferenceValue ControllerCapabilityID() UInt64Value } @@ -331,17 +331,17 @@ func (v *StorageCapabilityControllerValue) ControllerCapabilityID() UInt64Value return v.CapabilityID } -func (v *StorageCapabilityControllerValue) ReferenceValue(context ValueCapabilityControllerReferenceValueContext, capabilityAddress common.Address, resultBorrowType *sema.ReferenceType) ReferenceValue { - authorization := ConvertSemaAccessToStaticAuthorization( - context, - resultBorrowType.Authorization, - ) +func (v *StorageCapabilityControllerValue) ReferenceValue( + context ValueCapabilityControllerReferenceValueContext, + capabilityAddress common.Address, + resultBorrowType *ReferenceStaticType, +) ReferenceValue { return NewStorageReferenceValue( context, - authorization, + resultBorrowType.Authorization, capabilityAddress, v.TargetPath, - resultBorrowType.Type, + resultBorrowType.ReferencedType, ) } diff --git a/interpreter/value_test.go b/interpreter/value_test.go index 3e6d01101b..7a0ef7d98c 100644 --- a/interpreter/value_test.go +++ b/interpreter/value_test.go @@ -999,8 +999,8 @@ func TestStringer(t *testing.T) { inter, UnauthorizedAccess, array, - &sema.VariableSizedType{ - Type: sema.AnyStructType, + &VariableSizedStaticType{ + Type: PrimitiveStaticTypeAnyStruct, }, ) @@ -1064,9 +1064,9 @@ func TestStringer(t *testing.T) { inter, UnauthorizedAccess, NewUnmeteredStringValue("hello"), - &sema.ReferenceType{ - Authorization: sema.UnauthorizedAccess, - Type: sema.StringType, + &ReferenceStaticType{ + Authorization: UnauthorizedAccess, + ReferencedType: PrimitiveStaticTypeString, }, ) }, @@ -3739,7 +3739,7 @@ func TestValue_ConformsToStaticType(t *testing.T) { inter, UnauthorizedAccess, TrueValue, - sema.BoolType, + PrimitiveStaticTypeBool, ) }, true, @@ -3751,7 +3751,7 @@ func TestValue_ConformsToStaticType(t *testing.T) { inter, UnauthorizedAccess, TrueValue, - sema.StringType, + PrimitiveStaticTypeString, ) }, false, @@ -3768,7 +3768,7 @@ func TestValue_ConformsToStaticType(t *testing.T) { UnauthorizedAccess, testAddress, NewUnmeteredPathValue(common.PathDomainStorage, "test"), - sema.BoolType, + PrimitiveStaticTypeBool, ) }, true, @@ -3780,7 +3780,7 @@ func TestValue_ConformsToStaticType(t *testing.T) { UnauthorizedAccess, testAddress, NewUnmeteredPathValue(common.PathDomainStorage, "test"), - sema.StringType, + PrimitiveStaticTypeString, ) }, false, diff --git a/interpreter/value_type.go b/interpreter/value_type.go index a4e0b8c196..f7dee14b7f 100644 --- a/interpreter/value_type.go +++ b/interpreter/value_type.go @@ -256,9 +256,10 @@ func MetaTypeIsSubType( return FalseValue } - result := sema.IsSubType( - invocationContext.SemaTypeFromStaticType(staticType), - invocationContext.SemaTypeFromStaticType(otherStaticType), + result := IsSubType( + invocationContext, + staticType, + otherStaticType, ) return BoolValue(result) } diff --git a/interpreter/variable.go b/interpreter/variable.go index 4f645d3971..4b3592130a 100644 --- a/interpreter/variable.go +++ b/interpreter/variable.go @@ -111,11 +111,11 @@ var _ Variable = &SelfVariable{} func NewSelfVariableWithValue(interpreter *Interpreter, value Value) Variable { common.UseMemory(interpreter, variableMemoryUsage) - semaType := MustSemaTypeOfValue(value, interpreter) + staticType := value.StaticType(interpreter) // Create an explicit reference to represent the implicit reference behavior of 'self' value. // Authorization doesn't matter, we just need a reference to add to tracking. - selfRef := NewEphemeralReferenceValue(interpreter, UnauthorizedAccess, value, semaType) + selfRef := NewEphemeralReferenceValue(interpreter, UnauthorizedAccess, value, staticType) return &SelfVariable{ value: value, diff --git a/runtime/authorizer.go b/runtime/authorizer.go index 6501188070..bc4eaf88f5 100644 --- a/runtime/authorizer.go +++ b/runtime/authorizer.go @@ -44,7 +44,7 @@ func newAccountReferenceValueFromAddress( context, staticAuthorization, accountValue, - sema.AccountType, + interpreter.PrimitiveStaticTypeAccount, ) } diff --git a/runtime/contract_function_executor.go b/runtime/contract_function_executor.go index ca61bd0935..6a8faafb56 100644 --- a/runtime/contract_function_executor.go +++ b/runtime/contract_function_executor.go @@ -310,8 +310,7 @@ func (executor *contractFunctionExecutor) executeWithVM( } staticType := contractValue.StaticType(context) - semaType := context.SemaTypeFromStaticType(staticType) - qualifiedFuncName := commons.TypeQualifiedName(semaType, executor.functionName) + qualifiedFuncName := commons.StaticTypeQualifiedName(staticType, executor.functionName) value, err := executor.vm.InvokeMethodExternally( qualifiedFuncName, diff --git a/runtime/contract_update_validation_test.go b/runtime/contract_update_validation_test.go index 1038627d99..01f7fede1a 100644 --- a/runtime/contract_update_validation_test.go +++ b/runtime/contract_update_validation_test.go @@ -3080,7 +3080,7 @@ func TestRuntimeContractUpdateProgramCaching(t *testing.T) { expectedGets := locationAccessCounts{} if *compile { - expectedGets[txLocation] = 2 + expectedGets[txLocation] = 1 } require.Equal(t, expectedGets, programGets1) @@ -3131,7 +3131,7 @@ func TestRuntimeContractUpdateProgramCaching(t *testing.T) { contractLocation: 1, } if *compile { - expectedGets[txLocation] = 2 + expectedGets[txLocation] = 1 } assert.Equal(t, @@ -3170,7 +3170,7 @@ func TestRuntimeContractUpdateProgramCaching(t *testing.T) { expectedGets1 := locationAccessCounts{} if *compile { - expectedGets1[txLocation1] = 2 + expectedGets1[txLocation1] = 1 } assert.Equal(t, @@ -3195,7 +3195,7 @@ func TestRuntimeContractUpdateProgramCaching(t *testing.T) { expectedGets2 := locationAccessCounts{} if *compile { - expectedGets2[txLocation2] = 2 + expectedGets2[txLocation2] = 1 } assert.Equal(t, expectedGets2, diff --git a/runtime/empty.go b/runtime/empty.go index aaf11f2b51..5207840118 100644 --- a/runtime/empty.go +++ b/runtime/empty.go @@ -28,7 +28,6 @@ import ( "github.com/onflow/cadence/ast" "github.com/onflow/cadence/common" "github.com/onflow/cadence/interpreter" - "github.com/onflow/cadence/sema" ) // EmptyRuntimeInterface is an empty implementation of runtime.Interface. @@ -214,8 +213,8 @@ func (EmptyRuntimeInterface) ValidateAccountCapabilitiesGet( _ interpreter.AccountCapabilityGetValidationContext, _ interpreter.AddressValue, _ interpreter.PathValue, - _ *sema.ReferenceType, - _ *sema.ReferenceType, + _ *interpreter.ReferenceStaticType, + _ *interpreter.ReferenceStaticType, ) (bool, error) { panic("unexpected call to ValidateAccountCapabilitiesGet") } diff --git a/runtime/external.go b/runtime/external.go index 5f7b3e9c62..53435fe2d3 100644 --- a/runtime/external.go +++ b/runtime/external.go @@ -29,7 +29,6 @@ import ( "github.com/onflow/cadence/common" "github.com/onflow/cadence/errors" "github.com/onflow/cadence/interpreter" - "github.com/onflow/cadence/sema" ) // ExternalInterface is an implementation of runtime.Interface which forwards all calls to the embedded Interface. @@ -484,8 +483,8 @@ func (e ExternalInterface) ValidateAccountCapabilitiesGet( context interpreter.AccountCapabilityGetValidationContext, address interpreter.AddressValue, path interpreter.PathValue, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) ( valid bool, err error, diff --git a/runtime/handlers.go b/runtime/handlers.go index df37b407da..301febdf27 100644 --- a/runtime/handlers.go +++ b/runtime/handlers.go @@ -46,8 +46,8 @@ func newCapabilityBorrowHandler(handler stdlib.CapabilityControllerHandler) inte context interpreter.BorrowCapabilityControllerContext, address interpreter.AddressValue, capabilityID interpreter.UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) interpreter.ReferenceValue { return stdlib.BorrowCapabilityController( context, @@ -65,8 +65,8 @@ func newCapabilityCheckHandler(handler stdlib.CapabilityControllerHandler) inter context interpreter.CheckCapabilityControllerContext, address interpreter.AddressValue, capabilityID interpreter.UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) interpreter.BoolValue { return stdlib.CheckCapabilityController( context, @@ -84,8 +84,8 @@ func newValidateAccountCapabilitiesGetHandler(i *Interface) interpreter.Validate context interpreter.AccountCapabilityGetValidationContext, address interpreter.AddressValue, path interpreter.PathValue, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) (bool, error) { return (*i).ValidateAccountCapabilitiesGet( diff --git a/runtime/imported_values_memory_metering_test.go b/runtime/imported_values_memory_metering_test.go index f9d6d50772..2bbd8958e0 100644 --- a/runtime/imported_values_memory_metering_test.go +++ b/runtime/imported_values_memory_metering_test.go @@ -115,7 +115,7 @@ func TestRuntimeImportedValueMemoryMetering(t *testing.T) { ) assert.Equal(t, uint64(1), meter[common.MemoryKindOptionalValue]) - assert.Equal(t, uint64(2), meter[common.MemoryKindOptionalStaticType]) + assert.Equal(t, uint64(3), meter[common.MemoryKindOptionalStaticType]) }) t.Run("UInt", func(t *testing.T) { @@ -474,7 +474,7 @@ func TestRuntimeImportedValueMemoryMetering(t *testing.T) { executeScript(t, script, meter, inclusiveRangeValue) assert.Equal(t, uint64(1), meter[common.MemoryKindCompositeValueBase]) - assert.Equal(t, uint64(1), meter[common.MemoryKindInclusiveRangeStaticType]) + assert.Equal(t, uint64(2), meter[common.MemoryKindInclusiveRangeStaticType]) assert.Equal(t, uint64(1), meter[common.MemoryKindCadenceInclusiveRangeValue]) }) } @@ -540,7 +540,7 @@ func TestRuntimeImportedValueMemoryMeteringForSimpleTypes(t *testing.T) { { TypeName: "String?", MemoryKind: common.MemoryKindOptionalStaticType, - Weight: 2, + Weight: 3, TypeInstance: cadence.NewOptional(cadence.String("hello")), }, { diff --git a/runtime/interface.go b/runtime/interface.go index bc4e1b9572..6248b41e8d 100644 --- a/runtime/interface.go +++ b/runtime/interface.go @@ -28,7 +28,6 @@ import ( "github.com/onflow/cadence/ast" "github.com/onflow/cadence/common" "github.com/onflow/cadence/interpreter" - "github.com/onflow/cadence/sema" ) type Interface interface { @@ -143,8 +142,8 @@ type Interface interface { context interpreter.AccountCapabilityGetValidationContext, address interpreter.AddressValue, path interpreter.PathValue, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) (bool, error) ValidateAccountCapabilitiesPublish( context interpreter.AccountCapabilityPublishValidationContext, diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 405b3d9a2e..061c69020f 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -6981,14 +6981,6 @@ func TestRuntimeOnGetOrLoadProgramHits(t *testing.T) { 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, }, - common.TransactionLocation{ - 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - }, - common.TransactionLocation{ - 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - }, common.TransactionLocation{ 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, @@ -7017,18 +7009,6 @@ func TestRuntimeOnGetOrLoadProgramHits(t *testing.T) { Address: Address{0x1}, Name: "HelloWorld", }, - common.TransactionLocation{ - 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - }, - common.AddressLocation{ - Address: Address{0x1}, - Name: "HelloWorld", - }, - common.AddressLocation{ - Address: Address{0x1}, - Name: "HelloWorld", - }, } } else { expectedHits = []common.Location{ @@ -11593,13 +11573,13 @@ func TestRuntimeForbidPublicEntitlementBorrow(t *testing.T) { _ interpreter.AccountCapabilityGetValidationContext, _ interpreter.AddressValue, path interpreter.PathValue, - wantedBorrowType *sema.ReferenceType, - _ *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + _ *interpreter.ReferenceStaticType, ) (bool, error) { validatedPaths = append(validatedPaths, path) - _, wantedHasEntitlements := wantedBorrowType.Authorization.(sema.EntitlementSetAccess) + _, wantedHasEntitlements := wantedBorrowType.Authorization.(interpreter.EntitlementSetAuthorization) return !wantedHasEntitlements, nil }, } @@ -11683,13 +11663,13 @@ func TestRuntimeForbidPublicEntitlementGet(t *testing.T) { _ interpreter.AccountCapabilityGetValidationContext, _ interpreter.AddressValue, path interpreter.PathValue, - wantedBorrowType *sema.ReferenceType, - _ *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + _ *interpreter.ReferenceStaticType, ) (bool, error) { validatedPaths = append(validatedPaths, path) - _, wantedHasEntitlements := wantedBorrowType.Authorization.(sema.EntitlementSetAccess) + _, wantedHasEntitlements := wantedBorrowType.Authorization.(interpreter.EntitlementSetAuthorization) return !wantedHasEntitlements, nil }, } diff --git a/stdlib/account.go b/stdlib/account.go index c3f0eb4bc5..93388e1ab5 100644 --- a/stdlib/account.go +++ b/stdlib/account.go @@ -145,6 +145,8 @@ func NewVMAccountConstructor(creator AccountCreator) StandardLibraryValue { ) } +var AccountReferenceStaticType = interpreter.ConvertSemaToStaticType(nil, sema.AccountReferenceType) + func NewAccount( context interpreter.MemberAccessibleContext, payer interpreter.MemberAccessibleValue, @@ -154,7 +156,7 @@ func NewAccount( interpreter.ExpectType( context, payer, - sema.AccountReferenceType, + AccountReferenceStaticType, ) payerValue := payer.GetMember(context, sema.AccountTypeAddressFieldName) @@ -289,7 +291,7 @@ func NewAccountReferenceValue( context, authorization, account, - sema.AccountType, + interpreter.PrimitiveStaticTypeAccount, ) } @@ -1149,7 +1151,7 @@ func nativeAccountInboxUnpublishFunction( args []interpreter.Value, ) interpreter.Value { nameValue := interpreter.AssertValueOfType[*interpreter.StringValue](args[0]) - borrowType := typeArguments.NextSema() + borrowType := typeArguments.NextStatic() providerValue := interpreter.GetAddressValue(receiver, providerPointer) @@ -1194,7 +1196,7 @@ func NewVMAccountInboxUnpublishFunction( func AccountInboxUnpublish( context interpreter.InvocationContext, providerValue interpreter.AddressValue, - borrowType sema.Type, + borrowType interpreter.StaticType, nameValue *interpreter.StringValue, handler EventEmitter, ) interpreter.Value { @@ -1215,12 +1217,12 @@ func AccountInboxUnpublish( panic(errors.NewUnreachableError()) } - capabilityType := sema.NewCapabilityType(context, borrowType) + capabilityType := interpreter.NewCapabilityStaticType(context, borrowType) publishedType := publishedValue.Value.StaticType(context) - if !interpreter.IsSubTypeOfSemaType(context, publishedType, capabilityType) { + if !interpreter.IsSubType(context, publishedType, capabilityType) { panic(&interpreter.ForceCastTypeMismatchError{ ExpectedType: capabilityType, - ActualType: context.SemaTypeFromStaticType(publishedType), + ActualType: publishedType, }) } @@ -1260,7 +1262,7 @@ func nativeAccountInboxClaimFunction( ) interpreter.Value { nameValue := interpreter.AssertValueOfType[*interpreter.StringValue](args[0]) providerValue := interpreter.AssertValueOfType[interpreter.AddressValue](args[1]) - borrowType := typeArguments.NextSema() + borrowType := typeArguments.NextStatic() recipientValue := interpreter.GetAddressValue(receiver, recipientPointer) @@ -1308,7 +1310,7 @@ func AccountInboxClaim( providerValue interpreter.AddressValue, recipientValue interpreter.AddressValue, nameValue *interpreter.StringValue, - borrowType sema.Type, + borrowType interpreter.StaticType, handler EventEmitter, ) interpreter.Value { providerAddress := providerValue.ToAddress() @@ -1333,12 +1335,12 @@ func AccountInboxClaim( return interpreter.Nil } - ty := sema.NewCapabilityType(context, borrowType) + ty := interpreter.NewCapabilityStaticType(context, borrowType) publishedType := publishedValue.Value.StaticType(context) - if !interpreter.IsSubTypeOfSemaType(context, publishedType, ty) { + if !interpreter.IsSubType(context, publishedType, ty) { panic(&interpreter.ForceCastTypeMismatchError{ ExpectedType: ty, - ActualType: context.SemaTypeFromStaticType(publishedType), + ActualType: publishedType, }) } @@ -1535,7 +1537,7 @@ func nativeAccountContractsBorrowFunction( args []interpreter.Value, ) interpreter.Value { nameValue := interpreter.AssertValueOfType[*interpreter.StringValue](args[0]) - borrowType := typeArguments.NextSema() + borrowType := typeArguments.NextStatic() address := interpreter.GetAddress(receiver, addressPointer) @@ -1583,18 +1585,18 @@ func AccountContractsBorrow( invocationContext interpreter.InvocationContext, address common.Address, nameValue *interpreter.StringValue, - borrowType sema.Type, + borrowType interpreter.StaticType, handler AccountContractsHandler, ) interpreter.Value { name := nameValue.Str location := common.NewAddressLocation(invocationContext, address, name) - referenceType, ok := borrowType.(*sema.ReferenceType) + referenceType, ok := borrowType.(*interpreter.ReferenceStaticType) if !ok { panic(errors.NewUnreachableError()) } - if referenceType.Authorization != sema.UnauthorizedAccess { + if referenceType.Authorization != interpreter.UnauthorizedAccess { panic(errors.NewDefaultUserError("cannot borrow a reference with an authorization")) } @@ -1622,7 +1624,7 @@ func AccountContractsBorrow( // Check the type staticType := contractValue.StaticType(invocationContext) - if !interpreter.IsSubTypeOfSemaType(invocationContext, staticType, referenceType.Type) { + if !interpreter.IsSubType(invocationContext, staticType, referenceType.ReferencedType) { return interpreter.Nil } @@ -1632,7 +1634,7 @@ func AccountContractsBorrow( invocationContext, interpreter.UnauthorizedAccess, contractValue, - referenceType.Type, + referenceType.ReferencedType, ) return interpreter.NewSomeValueNonCopying( @@ -3079,7 +3081,7 @@ func nativeAccountStorageCapabilitiesIssueFunction( receiver interpreter.Value, args []interpreter.Value, ) interpreter.Value { - borrowType := typeArguments.NextSema() + borrowType := typeArguments.NextStatic() address := interpreter.GetAddress(receiver, addressPointer) @@ -3127,7 +3129,7 @@ func AccountStorageCapabilitiesIssue( invocationContext interpreter.InvocationContext, handler CapabilityControllerIssueHandler, address common.Address, - typeParameter sema.Type, + typeParameter interpreter.StaticType, ) interpreter.Value { // Get path argument @@ -3213,11 +3215,6 @@ func AccountStorageCapabilitiesIssueWithType( panic(errors.NewUnreachableError()) } - ty, err := interpreter.ConvertStaticToSemaType(invocationContext, typeValue.Type) - if err != nil { - panic(errors.NewUnexpectedErrorFromCause(err)) - } - // Issue capability controller and return capability return checkAndIssueStorageCapabilityControllerWithType( @@ -3225,7 +3222,7 @@ func AccountStorageCapabilitiesIssueWithType( handler, address, targetPathValue, - ty, + typeValue.Type, ) } @@ -3234,10 +3231,10 @@ func checkAndIssueStorageCapabilityControllerWithType( handler CapabilityControllerIssueHandler, address common.Address, targetPathValue interpreter.PathValue, - ty sema.Type, + ty interpreter.StaticType, ) *interpreter.IDCapabilityValue { - borrowType, ok := ty.(*sema.ReferenceType) + borrowType, ok := ty.(*interpreter.ReferenceStaticType) if !ok { panic(&interpreter.InvalidCapabilityIssueTypeError{ ExpectedTypeDescription: "reference type", @@ -3247,13 +3244,11 @@ func checkAndIssueStorageCapabilityControllerWithType( // Issue capability controller - borrowStaticType := interpreter.ConvertSemaReferenceTypeToStaticReferenceType(context, borrowType) - capabilityIDValue := IssueStorageCapabilityController( context, handler, address, - borrowStaticType, + borrowType, targetPathValue, ) @@ -3267,7 +3262,7 @@ func checkAndIssueStorageCapabilityControllerWithType( context, capabilityIDValue, interpreter.NewAddressValue(context, address), - borrowStaticType, + borrowType, ) } @@ -3330,7 +3325,7 @@ func nativeAccountAccountCapabilitiesIssueFunction( receiver interpreter.Value, args []interpreter.Value, ) interpreter.Value { - borrowType := typeArguments.NextSema() + borrowType := typeArguments.NextStatic() address := interpreter.GetAddress(receiver, addressPointer) @@ -3383,18 +3378,13 @@ func nativeAccountAccountCapabilitiesIssueWithTypeFunction( args []interpreter.Value, ) interpreter.Value { typeValue := interpreter.AssertValueOfType[interpreter.TypeValue](args[0]) - ty, err := interpreter.ConvertStaticToSemaType(context, typeValue.Type) - if err != nil { - panic(errors.NewUnexpectedErrorFromCause(err)) - } - address := interpreter.GetAddress(receiver, addressPointer) return checkAndIssueAccountCapabilityControllerWithType( context, handler, address, - ty, + typeValue.Type, ) } } @@ -3432,34 +3422,32 @@ func checkAndIssueAccountCapabilityControllerWithType( context interpreter.CapabilityControllerContext, handler CapabilityControllerIssueHandler, address common.Address, - ty sema.Type, + ty interpreter.StaticType, ) *interpreter.IDCapabilityValue { // Get and check borrow type - typeBound := sema.AccountReferenceType - if !sema.IsSubType(ty, typeBound) { + typeBound := AccountReferenceStaticType + if !interpreter.IsSubType(context, ty, typeBound) { panic(&interpreter.InvalidCapabilityIssueTypeError{ - ExpectedTypeDescription: fmt.Sprintf("`%s`", typeBound.QualifiedString()), + ExpectedTypeDescription: fmt.Sprintf("`%s`", typeBound.String()), ActualType: ty, }) } - borrowType, ok := ty.(*sema.ReferenceType) + borrowType, ok := ty.(*interpreter.ReferenceStaticType) if !ok { panic(errors.NewUnreachableError()) } // Issue capability controller - borrowStaticType := interpreter.ConvertSemaReferenceTypeToStaticReferenceType(context, borrowType) - capabilityIDValue := IssueAccountCapabilityController( context, handler, address, - borrowStaticType, + borrowType, ) if capabilityIDValue == interpreter.InvalidCapabilityID { @@ -3472,7 +3460,7 @@ func checkAndIssueAccountCapabilityControllerWithType( context, capabilityIDValue, interpreter.NewAddressValue(context, address), - borrowStaticType, + borrowType, ) } @@ -3664,7 +3652,7 @@ func getStorageCapabilityControllerReference( context, interpreter.UnauthorizedAccess, storageCapabilityController, - sema.StorageCapabilityControllerType, + interpreter.PrimitiveStaticTypeStorageCapabilityController, ) } @@ -4086,12 +4074,8 @@ func AccountCapabilitiesPublish( domain := pathValue.Domain.StorageDomain() identifier := pathValue.Identifier - capabilityType, ok := capabilityValue.StaticType(invocationContext).(*interpreter.CapabilityStaticType) - if !ok { - panic(errors.NewUnreachableError()) - } - - borrowType := capabilityType.BorrowType + staticType := capabilityValue.StaticType(invocationContext) + borrowType := staticType.(*interpreter.CapabilityStaticType).BorrowType // It is possible to have legacy capabilities without borrow type. // So perform the validation only if the borrow type is present. @@ -4137,7 +4121,7 @@ func AccountCapabilitiesPublish( }) } - capabilityValue, ok = capabilityValue.Transfer( + capabilityValue, ok := capabilityValue.Transfer( invocationContext, atree.Address(accountAddress), true, @@ -4300,38 +4284,50 @@ func AccountCapabilitiesUnpublish( } func canBorrow( - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + typeConverter interpreter.TypeConverter, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) bool { // Ensure the wanted borrow type is not more permissive than the capability borrow type - if !wantedBorrowType.Authorization. - PermitsAccess(capabilityBorrowType.Authorization) { + if !interpreter.PermitsAccess( + typeConverter, + wantedBorrowType.Authorization, + capabilityBorrowType.Authorization, + ) { return false } // Ensure the wanted borrow type is a subtype or supertype of the capability borrow type - return sema.IsSubType(wantedBorrowType.Type, capabilityBorrowType.Type) || - sema.IsSubType(capabilityBorrowType.Type, wantedBorrowType.Type) + return interpreter.IsSubType( + typeConverter, + wantedBorrowType.ReferencedType, + capabilityBorrowType.ReferencedType, + ) || + interpreter.IsSubType( + typeConverter, + capabilityBorrowType.ReferencedType, + wantedBorrowType.ReferencedType, + ) } func getCheckedCapabilityController( context interpreter.GetCapabilityControllerContext, capabilityAddressValue interpreter.AddressValue, capabilityIDValue interpreter.UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, handler CapabilityControllerHandler, ) ( interpreter.CapabilityControllerValue, - *sema.ReferenceType, + *interpreter.ReferenceStaticType, ) { if wantedBorrowType == nil { wantedBorrowType = capabilityBorrowType - } else if !canBorrow(wantedBorrowType, capabilityBorrowType) { + } else if !canBorrow(context, wantedBorrowType, capabilityBorrowType) { return nil, nil } @@ -4350,13 +4346,7 @@ func getCheckedCapabilityController( controllerBorrowStaticType := controller.CapabilityControllerBorrowType() - controllerBorrowType, ok := - interpreter.MustConvertStaticToSemaType(controllerBorrowStaticType, context).(*sema.ReferenceType) - if !ok { - panic(errors.NewUnreachableError()) - } - - if !canBorrow(wantedBorrowType, controllerBorrowType) { + if !canBorrow(context, wantedBorrowType, controllerBorrowStaticType) { return nil, nil } @@ -4367,8 +4357,8 @@ func GetCheckedCapabilityControllerReference( context interpreter.GetCapabilityControllerReferenceContext, capabilityAddressValue interpreter.AddressValue, capabilityIDValue interpreter.UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, handler CapabilityControllerHandler, ) interpreter.ReferenceValue { controller, resultBorrowType := getCheckedCapabilityController( @@ -4392,8 +4382,8 @@ func BorrowCapabilityController( context interpreter.BorrowCapabilityControllerContext, capabilityAddress interpreter.AddressValue, capabilityID interpreter.UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, handler CapabilityControllerHandler, ) interpreter.ReferenceValue { referenceValue := GetCheckedCapabilityControllerReference( @@ -4424,8 +4414,8 @@ func CheckCapabilityController( context interpreter.CheckCapabilityControllerContext, capabilityAddress interpreter.AddressValue, capabilityID interpreter.UInt64Value, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, handler CapabilityControllerHandler, ) interpreter.BoolValue { @@ -4462,7 +4452,7 @@ func nativeAccountCapabilitiesGetFunction( args []interpreter.Value, ) interpreter.Value { pathValue := interpreter.AssertValueOfType[interpreter.PathValue](args[0]) - typeArgument := typeArguments.NextSema() + typeArgument := typeArguments.NextStatic() addressValue := interpreter.GetAddressValue(receiver, addressPointer) @@ -4531,7 +4521,7 @@ func AccountCapabilitiesGet( invocationContext interpreter.InvocationContext, controllerHandler CapabilityControllerHandler, pathValue interpreter.PathValue, - typeParameter sema.Type, + typeParameter interpreter.StaticType, borrow bool, addressValue interpreter.AddressValue, ) interpreter.Value { @@ -4545,7 +4535,7 @@ func AccountCapabilitiesGet( // Get borrow type type argument // `Never` is never a supertype of any stored value - if typeParameter.Equal(sema.NeverType) { + if typeParameter == interpreter.PrimitiveStaticTypeNever { if borrow { return interpreter.Nil } else { @@ -4557,7 +4547,7 @@ func AccountCapabilitiesGet( } } - wantedBorrowType, ok := typeParameter.(*sema.ReferenceType) + wantedBorrowType, ok := typeParameter.(*interpreter.ReferenceStaticType) if !ok { panic(errors.NewUnreachableError()) } @@ -4570,7 +4560,7 @@ func AccountCapabilitiesGet( interpreter.NewInvalidCapabilityValue( invocationContext, addressValue, - interpreter.ConvertSemaToStaticType(invocationContext, wantedBorrowType), + wantedBorrowType, ) } @@ -4618,8 +4608,7 @@ func AccountCapabilitiesGet( panic(errors.NewUnreachableError()) } - capabilityBorrowType, ok := - interpreter.MustConvertStaticToSemaType(capabilityStaticBorrowType, invocationContext).(*sema.ReferenceType) + capabilityBorrowType, ok := capabilityStaticBorrowType.(*interpreter.ReferenceStaticType) if !ok { panic(errors.NewUnreachableError()) } @@ -4669,17 +4658,11 @@ func AccountCapabilitiesGet( controllerHandler, ) if controller != nil { - resultBorrowStaticType := - interpreter.ConvertSemaReferenceTypeToStaticReferenceType(invocationContext, resultBorrowType) - if !ok { - panic(errors.NewUnreachableError()) - } - resultValue = interpreter.NewCapabilityValue( invocationContext, capabilityID, capabilityAddress, - resultBorrowStaticType, + resultBorrowType, ) } } @@ -4790,7 +4773,7 @@ func getAccountCapabilityControllerReference( context, interpreter.UnauthorizedAccess, accountCapabilityController, - sema.AccountCapabilityControllerType, + interpreter.PrimitiveStaticTypeAccountCapabilityController, ) } diff --git a/stdlib/account_test.go b/stdlib/account_test.go index 2218655c86..a72537182f 100644 --- a/stdlib/account_test.go +++ b/stdlib/account_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/require" "github.com/onflow/cadence/common" + "github.com/onflow/cadence/interpreter" "github.com/onflow/cadence/sema" . "github.com/onflow/cadence/test_utils/common_utils" ) @@ -114,10 +115,13 @@ func TestCanBorrow(t *testing.T) { for _, b := range types { a2, b2 := instantiate(a, b) - t.Run(fmt.Sprintf("%s / %s", a2, b2), func(t *testing.T) { + staticA2 := interpreter.ConvertSemaToStaticType(nil, a2).(*interpreter.ReferenceStaticType) + staticB2 := interpreter.ConvertSemaToStaticType(nil, b2).(*interpreter.ReferenceStaticType) + + t.Run(fmt.Sprintf("%s / %s", staticA2, staticB2), func(t *testing.T) { t.Parallel() - require.Equal(t, expected, canBorrow(a2, b2)) + require.Equal(t, expected, canBorrow(inter, staticA2, staticB2)) }) } } diff --git a/stdlib/bls.go b/stdlib/bls.go index e8cb7134a0..59d0064213 100644 --- a/stdlib/bls.go +++ b/stdlib/bls.go @@ -79,6 +79,8 @@ func NewVMBLSAggregatePublicKeysFunction( } } +var PublicKeyArrayStaticType = interpreter.ConvertSemaToStaticType(nil, sema.PublicKeyArrayType) + func BLSAggregatePublicKeys( context interpreter.InvocationContext, publicKeysValue *interpreter.ArrayValue, @@ -88,7 +90,7 @@ func BLSAggregatePublicKeys( interpreter.ExpectType( context, publicKeysValue, - sema.PublicKeyArrayType, + PublicKeyArrayStaticType, ) publicKeys := make([]*PublicKey, 0, publicKeysValue.Count()) @@ -180,6 +182,8 @@ func NewVMBLSAggregateSignaturesFunction( } } +var ByteArrayArrayStaticType = interpreter.ConvertSemaToStaticType(nil, sema.ByteArrayArrayType) + func BLSAggregateSignatures( context interpreter.InvocationContext, signaturesValue *interpreter.ArrayValue, @@ -189,7 +193,7 @@ func BLSAggregateSignatures( interpreter.ExpectType( context, signaturesValue, - sema.ByteArrayArrayType, + ByteArrayArrayStaticType, ) bytesArray := make([][]byte, 0, signaturesValue.Count()) diff --git a/stdlib/publickey.go b/stdlib/publickey.go index ba712cd8ad..6623da02f9 100644 --- a/stdlib/publickey.go +++ b/stdlib/publickey.go @@ -281,7 +281,7 @@ func PublicKeyVerifySignature( interpreter.ExpectType( context, publicKeyValue, - sema.PublicKeyType, + PublicKeyStaticType, ) signature, err := interpreter.ByteArrayValueToByteSlice(context, signatureValue) @@ -374,6 +374,8 @@ func NewVMPublicKeyVerifyPoPFunction(verifier BLSPoPVerifier) VMFunction { } } +var PublicKeyStaticType = interpreter.ConvertSemaToStaticType(nil, sema.PublicKeyType) + func PublicKeyVerifyPoP( context interpreter.InvocationContext, publicKeyValue *interpreter.CompositeValue, @@ -384,7 +386,7 @@ func PublicKeyVerifyPoP( interpreter.ExpectType( context, publicKeyValue, - sema.PublicKeyType, + PublicKeyStaticType, ) publicKey, err := NewPublicKeyFromValue(context, publicKeyValue) diff --git a/stdlib/test.go b/stdlib/test.go index fc403760ce..6ab575b07d 100644 --- a/stdlib/test.go +++ b/stdlib/test.go @@ -427,12 +427,12 @@ func newMatcherWithGenericTestFunction( for _, argument := range invocation.Arguments { argumentStaticType := argument.StaticType(invocationContext) - if !interpreter.IsSubTypeOfSemaType(invocationContext, argumentStaticType, parameterType) { - argumentSemaType := interpreter.MustConvertStaticToSemaType(argumentStaticType, invocationContext) + parameterStaticType := interpreter.ConvertSemaToStaticType(invocationContext, parameterType) + if !interpreter.IsSubType(invocationContext, argumentStaticType, parameterStaticType) { panic(&interpreter.TypeMismatchError{ - ExpectedType: parameterType, - ActualType: argumentSemaType, + ExpectedType: parameterStaticType, + ActualType: argumentStaticType, }) } } diff --git a/test_utils/runtime_utils/testinterface.go b/test_utils/runtime_utils/testinterface.go index 3f0361eee4..0953bc9c23 100644 --- a/test_utils/runtime_utils/testinterface.go +++ b/test_utils/runtime_utils/testinterface.go @@ -112,8 +112,8 @@ type TestRuntimeInterface struct { context interpreter.AccountCapabilityGetValidationContext, address interpreter.AddressValue, path interpreter.PathValue, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) (bool, error) OnValidateAccountCapabilitiesPublish func( context interpreter.AccountCapabilityPublishValidationContext, @@ -560,8 +560,8 @@ func (i *TestRuntimeInterface) ValidateAccountCapabilitiesGet( context interpreter.AccountCapabilityGetValidationContext, address interpreter.AddressValue, path interpreter.PathValue, - wantedBorrowType *sema.ReferenceType, - capabilityBorrowType *sema.ReferenceType, + wantedBorrowType *interpreter.ReferenceStaticType, + capabilityBorrowType *interpreter.ReferenceStaticType, ) (bool, error) { if i.OnValidateAccountCapabilitiesGet == nil { return true, nil