Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Annotated types in plugins #8665

Open
wants to merge 38 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
869603d
Add support for native union types in plugins
arnaudsjs Jan 17, 2025
15129ab
Add changelog entry
arnaudsjs Jan 17, 2025
e2c4c32
Fix typing
arnaudsjs Jan 17, 2025
719fbb3
Fix typing
arnaudsjs Jan 17, 2025
001f50c
Small improvements
arnaudsjs Jan 17, 2025
c1c7df8
Use callable from abc
arnaudsjs Jan 17, 2025
53a0a5c
Improve error message
arnaudsjs Jan 17, 2025
505c41e
Fix formatting
arnaudsjs Jan 17, 2025
fdeea80
Initial commit
jptrindade Jan 20, 2025
23cd1e7
small adjustments
jptrindade Jan 21, 2025
2bc3eaa
fixed mypy
jptrindade Jan 21, 2025
8c30ca3
test with type alias
jptrindade Jan 21, 2025
bdf64cb
Update mypy-baseline.txt
jptrindade Jan 21, 2025
d1f7288
fix mypy
jptrindade Jan 21, 2025
5ce1296
revert mypy
jptrindade Jan 21, 2025
3bbe231
renamed exception
sanderr Jan 22, 2025
ec784bf
dedicated plugin type exception
sanderr Jan 22, 2025
c16d8b4
added test scenario
sanderr Jan 22, 2025
dac146d
import fix
sanderr Jan 22, 2025
f64c3a9
addressed comments
jptrindade Jan 22, 2025
db3c27e
added to_dsl_type_simple
jptrindade Jan 22, 2025
378ca58
fix mypy
jptrindade Jan 22, 2025
0127969
some test fixes
sanderr Jan 22, 2025
092261b
test fixes
sanderr Jan 22, 2025
6a6d6ec
pep8
sanderr Jan 22, 2025
28e62aa
rejection test first try
jptrindade Jan 22, 2025
69f3697
renamed change entry
sanderr Jan 22, 2025
1fdd946
include iso7
sanderr Jan 22, 2025
e57db25
Merge branch 'refs/heads/issue/8574-add-support-union-types' into iss…
jptrindade Jan 23, 2025
b07e091
merge and refactor
jptrindade Jan 23, 2025
2a1a57d
mypy
jptrindade Jan 23, 2025
3484f4a
Merge remote-tracking branch 'refs/remotes/origin/master' into issue/…
jptrindade Jan 23, 2025
0b21043
Revert "merge and refactor"
jptrindade Jan 23, 2025
001456a
Merge branch 'refs/heads/master' into issue/8573-add-support-for-anno…
jptrindade Jan 27, 2025
9e5679e
attempted merge
jptrindade Jan 27, 2025
2fc91a9
fixed mypy
jptrindade Jan 27, 2025
48a1a97
fixed test_plugin_types.py
jptrindade Jan 27, 2025
a5135da
WIP, still problems with pylance
jptrindade Jan 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
description: Add support for annotated types to plugins.
issue-nr: 8573
change-type: minor
destination-branches: [master, iso8, iso7]
75 changes: 48 additions & 27 deletions src/inmanta/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
import warnings
from collections import abc
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Type, TypeVar
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Self, Type, TypeVar

import typing_inspect

Expand Down Expand Up @@ -247,14 +248,39 @@ def __eq__(self, other: object) -> bool:
}


def to_dsl_type(python_type: type[object]) -> inmanta_type.Type:
@dataclass(frozen=True)
class ModelType:
model_type: str

def __class_getitem__(cls: type[Self], key: str) -> Self:
return cls(key)


Entity: typing.TypeAlias = typing.Annotated[Any, ModelType["std::Entity"]]
"""
Alias used to treat std::Entity as an object in Python for type verification
"""


def parse_dsl_type(dsl_type: str, location: Range, resolver: Namespace) -> inmanta_type.Type:
locatable_type: LocatableString = LocatableString(dsl_type, location, 0, resolver)
return inmanta_type.resolve_type(locatable_type, resolver)


def to_dsl_type(python_type: type[object], location: Range, resolver: Namespace) -> inmanta_type.Type:
"""
Convert a python type annotation to an Inmanta DSL type annotation.

:param python_type: The evaluated python type as provided in the Python type annotation.
:param location: The location of this evaluation on the model
:param resolver: The namespace that can be used to resolve the type annotation of this argument.
"""
# Resolve aliases
if isinstance(python_type, typing.TypeAliasType):
return to_dsl_type(python_type.__value__, location, resolver)

# Any to any
if python_type is typing.Any:
if python_type is Any:
return inmanta_type.Type()

# None to None
Expand All @@ -266,16 +292,16 @@ def to_dsl_type(python_type: type[object]) -> inmanta_type.Type:
# Optional type
bases: Sequence[inmanta_type.Type]
if typing_inspect.is_optional_type(python_type):
other_types = [tt for tt in typing.get_args(python_type) if not typing_inspect.is_optional_type(tt)]
other_types = [tt for tt in typing_inspect.get_args(python_type) if not typing_inspect.is_optional_type(tt)]
if len(other_types) == 0:
# Probably not possible
return Null()
if len(other_types) == 1:
return inmanta_type.NullableType(to_dsl_type(other_types[0]))
bases = [to_dsl_type(arg) for arg in other_types]
return inmanta_type.NullableType(to_dsl_type(other_types[0], location, resolver))
bases = [to_dsl_type(arg, location, resolver) for arg in other_types]
return inmanta_type.NullableType(inmanta_type.Union(bases))
else:
bases = [to_dsl_type(arg) for arg in typing.get_args(python_type)]
bases = [to_dsl_type(arg, location, resolver) for arg in typing.get_args(python_type)]
return inmanta_type.Union(bases)

# Lists and dicts
Expand All @@ -284,7 +310,7 @@ def to_dsl_type(python_type: type[object]) -> inmanta_type.Type:

# dict
if issubclass(origin, Mapping):
if origin in [collections.abc.Mapping, dict, typing.Mapping]:
if origin in [collections.abc.Mapping, dict, Mapping]:
args = typing_inspect.get_args(python_type)
if not args:
return inmanta_type.TypedDict(inmanta_type.Type())
Expand All @@ -297,37 +323,32 @@ def to_dsl_type(python_type: type[object]) -> inmanta_type.Type:
if len(args) == 1:
return inmanta_type.TypedDict(inmanta_type.Type())

return inmanta_type.TypedDict(to_dsl_type(args[1]))
return inmanta_type.TypedDict(to_dsl_type(args[1], location, resolver))
else:
raise TypingException(None, f"invalid type {python_type}, dictionary types should be Mapping or dict")

# List
if issubclass(origin, Sequence):
if origin in [collections.abc.Sequence, list, typing.Sequence]:
if origin in [collections.abc.Sequence, list, Sequence]:
args = typing.get_args(python_type)
if not args:
return inmanta_type.List()
return inmanta_type.TypedList(to_dsl_type(args[0]))
return inmanta_type.TypedList(to_dsl_type(args[0], location, resolver))
else:
raise TypingException(None, f"invalid type {python_type}, list types should be Sequence or list")

# Set
if issubclass(origin, collections.abc.Set):
raise TypingException(None, f"invalid type {python_type}, set is not supported on the plugin boundary")

# TODO annotated types
# if typing.get_origin(t) is typing.Annotated:
# args: Sequence[object] = typing.get_args(python_type)
# inmanta_types: Sequence[plugin_typing.InmantaType] =
# [arg if isinstance(arg, plugin_typing.InmantaType) for arg in args]
# if inmanta_types:
# if len(inmanta_types) > 1:
# # TODO
# raise Exception()
# # TODO
# return parse_dsl_type(inmanta_types[0].dsl_type)
# # the annotation doesn't concern us => use base type
# return to_dsl_type(args[0])
# Annotated
if origin is typing.Annotated:
for meta in python_type.__metadata__: # type: ignore
if isinstance(meta, ModelType):
return parse_dsl_type(meta.model_type, location, resolver)
# the annotation doesn't concern us => use base type
return to_dsl_type(typing.get_args(python_type)[0], location, resolver)

if python_type in python_to_model:
return python_to_model[python_type]

Expand Down Expand Up @@ -385,24 +406,24 @@ def resolve_type(self, plugin: "Plugin", resolver: Namespace) -> inmanta_type.Ty
:param resolver: The namespace that can be used to resolve the type annotation of this
argument.
"""
if self.type_expression in PLUGIN_TYPES:
if isinstance(self.type_expression, collections.abc.Hashable) and self.type_expression in PLUGIN_TYPES:
self._resolved_type = PLUGIN_TYPES[self.type_expression]
return self._resolved_type

plugin_line: Range = Range(plugin.location.file, plugin.location.lnr, 1, plugin.location.lnr + 1, 1)
if not isinstance(self.type_expression, str):
if typing_inspect.is_union_type(self.type_expression) and not typing.get_args(self.type_expression):
# If typing.Union is not subscripted, isinstance(self.type_expression, type) evaluates to False.
raise InvalidTypeAnnotation(stmt=None, msg=f"Union type must be subscripted, got {self.type_expression}")
if isinstance(self.type_expression, type) or typing.get_origin(self.type_expression) is not None:
self._resolved_type = to_dsl_type(self.type_expression)
self._resolved_type = to_dsl_type(self.type_expression, plugin_line, resolver)
else:
raise InvalidTypeAnnotation(
stmt=None,
msg="Bad annotation in plugin %s for %s, expected str or python type but got %s (%s)"
% (plugin.get_full_name(), self.VALUE_NAME, type(self.type_expression).__name__, self.type_expression),
)
else:
plugin_line: Range = Range(plugin.location.file, plugin.location.lnr, 1, plugin.location.lnr + 1, 1)
locatable_type: LocatableString = LocatableString(self.type_expression, plugin_line, 0, resolver)
self._resolved_type = inmanta_type.resolve_type(locatable_type, resolver)
return self._resolved_type
Expand Down
96 changes: 64 additions & 32 deletions tests/compiler/test_plugin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,64 +17,96 @@
"""

import collections.abc
from typing import Any, Mapping, Optional, Sequence, Union
from typing import Annotated, Any, Mapping, Optional, Sequence, Union

import pytest

import inmanta.ast.type as inmanta_type
from inmanta.ast import RuntimeException
from inmanta.plugins import Null, to_dsl_type
from inmanta.ast import Namespace, Range, RuntimeException
from inmanta.plugins import ModelType, Null, to_dsl_type


def test_conversion(caplog):
"""
Test behaviour of to_dsl_type function.
"""
assert inmanta_type.Integer() == to_dsl_type(int)
assert inmanta_type.Float() == to_dsl_type(float)
assert inmanta_type.NullableType(inmanta_type.Float()) == to_dsl_type(float | None)
assert inmanta_type.List() == to_dsl_type(list)
assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type(list[str])
assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type(Sequence[str])
assert inmanta_type.List() == to_dsl_type(Sequence)
assert inmanta_type.List() == to_dsl_type(collections.abc.Sequence)
assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type(collections.abc.Sequence[str])
assert inmanta_type.TypedDict(inmanta_type.Type()) == to_dsl_type(dict)
assert inmanta_type.TypedDict(inmanta_type.Type()) == to_dsl_type(Mapping)
assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type(dict[str, str])
assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type(Mapping[str, str])

assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type(collections.abc.Mapping[str, str])
namespace = Namespace("dummy-namespace")
namespace.primitives = inmanta_type.TYPES

location: Range = Range("test", 1, 1, 2, 1)

def to_dsl_type_simple(python_type: type[object]) -> inmanta_type.Type:
return to_dsl_type(python_type, location, namespace)

assert inmanta_type.NullableType(inmanta_type.Integer()) == to_dsl_type_simple(Annotated[int | None, "something"])

assert inmanta_type.TypedDict(inmanta_type.Type()) == to_dsl_type_simple(Annotated[dict[str, int], ModelType["dict"]])
assert inmanta_type.Integer() == to_dsl_type_simple(int)
assert inmanta_type.Float() == to_dsl_type_simple(float)
assert inmanta_type.NullableType(inmanta_type.Float()) == to_dsl_type_simple(float | None)
assert inmanta_type.List() == to_dsl_type_simple(list)
assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type_simple(list[str])
assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type_simple(Sequence[str])
assert inmanta_type.List() == to_dsl_type_simple(Sequence)
assert inmanta_type.List() == to_dsl_type_simple(collections.abc.Sequence)
assert inmanta_type.TypedList(inmanta_type.String()) == to_dsl_type_simple(collections.abc.Sequence[str])
assert inmanta_type.TypedDict(inmanta_type.Type()) == to_dsl_type_simple(dict)
assert inmanta_type.TypedDict(inmanta_type.Type()) == to_dsl_type_simple(Mapping)
assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type_simple(dict[str, str])
assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type_simple(Mapping[str, str])

# Union types
assert inmanta_type.Integer() == to_dsl_type(Union[int])
assert inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()]) == to_dsl_type(Union[int, str])
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type(
assert inmanta_type.Integer() == to_dsl_type_simple(Union[int])
assert inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()]) == to_dsl_type_simple(Union[int, str])
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type_simple(
Union[None, int, str]
)
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type(
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type_simple(
Optional[Union[int, str]]
)
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type(
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type_simple(
Union[int, str] | None
)
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type(
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type_simple(
None | Union[int, str]
)
# verify that nested unions are flattened and nested None values are considered for NullableType
assert inmanta_type.NullableType(
inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String(), inmanta_type.Float()])
) == to_dsl_type(Union[int, Union[str, Union[float, None]]])
) == to_dsl_type_simple(Union[int, Union[str, Union[float, None]]])

assert Null() == to_dsl_type(Union[None])
# Union types
assert inmanta_type.Integer() == to_dsl_type_simple(Union[int])
assert inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()]) == to_dsl_type_simple(Union[int, str])
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type_simple(
Union[None, int, str]
)
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type_simple(
Optional[Union[int, str]]
)
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type_simple(
Union[int, str] | None
)
assert inmanta_type.NullableType(inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String()])) == to_dsl_type_simple(
None | Union[int, str]
)
# verify that nested unions are flattened and nested None values are considered for NullableType
assert inmanta_type.NullableType(
inmanta_type.Union([inmanta_type.Integer(), inmanta_type.String(), inmanta_type.Float()])
) == to_dsl_type_simple(Union[int, Union[str, Union[float, None]]])

assert Null() == to_dsl_type_simple(Union[None])
assert inmanta_type.TypedDict(inmanta_type.String()) == to_dsl_type_simple(collections.abc.Mapping[str, str])

assert Null() == to_dsl_type_simple(Union[None])

assert isinstance(to_dsl_type(Any), inmanta_type.Type)
assert isinstance(to_dsl_type_simple(Any), inmanta_type.Type)

with pytest.raises(RuntimeException):
to_dsl_type(dict[int, int])
to_dsl_type_simple(dict[int, int])

with pytest.raises(RuntimeException):
to_dsl_type(set[str])
to_dsl_type_simple(set[str])

class CustomList[T](list[T]):
pass
Expand All @@ -83,14 +115,14 @@ class CustomDict[K, V](Mapping[K, V]):
pass

with pytest.raises(RuntimeException):
to_dsl_type(CustomList[str])
to_dsl_type_simple(CustomList[str])

with pytest.raises(RuntimeException):
to_dsl_type(CustomDict[str, str])
to_dsl_type_simple(CustomDict[str, str])

# Check that a warning is produced when implicit cast to 'Any'
caplog.clear()
to_dsl_type(complex)
to_dsl_type_simple(complex)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one sparked a laugh

warning_message = (
"InmantaWarning: Python type <class 'complex'> was implicitly cast to 'Any' because no matching type "
"was found in the Inmanta DSL. Please refer to the documentation for an overview of supported types at the "
Expand Down
22 changes: 18 additions & 4 deletions tests/compiler/test_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def test_native_types(snippetcompiler: "SnippetCompilationTest") -> None:
"""
import plugin_native_types

test_entity = plugin_native_types::TestEntity(value=2)
a = "b"
a = plugin_native_types::get_from_dict({"a":"b"}, "a")

Expand Down Expand Up @@ -470,7 +471,17 @@ def test_native_types(snippetcompiler: "SnippetCompilationTest") -> None:
plugin_native_types::union_return_optional_3(value=val) # type return value: Union[int, str] | None
plugin_native_types::union_return_optional_4(value=val) # type return value: None | Union[int, str]
end
"""

# Annotated types
plugin_native_types::annotated_arg_entity(test_entity) # type value: Annotated[MyEntity, InmantaType("TestEntity")]
plugin_native_types::annotated_return_entity(test_entity) # type return value: Annotated[MyEntity, InmantaType("TestEntity")]

for val in ["yes", "no"]:
plugin_native_types::annotated_arg_literal(val) # type value: Annotated[Literal["yes", "no"], InmantaType("response")
plugin_native_types::annotated_return_literal(val) # type value: Annotated[Literal["yes", "no"], InmantaType("response")
end
""",
ministd=True,
)
compiler.do_compile()

Expand Down Expand Up @@ -517,7 +528,8 @@ def test_native_types(snippetcompiler: "SnippetCompilationTest") -> None:
f"""
import plugin_native_types
plugin_native_types::{plugin_name}(value={plugin_value})
"""
""",
ministd=True,
)
with pytest.raises(RuntimeException) as exc_info:
compiler.do_compile()
Expand Down Expand Up @@ -573,7 +585,8 @@ def test_native_types(snippetcompiler: "SnippetCompilationTest") -> None:
f"""
import plugin_native_types
plugin_native_types::{plugin_name}(value={plugin_value})
"""
""",
ministd=True,
)
with pytest.raises(WrappingRuntimeException) as exc_info:
compiler.do_compile()
Expand All @@ -582,7 +595,8 @@ def test_native_types(snippetcompiler: "SnippetCompilationTest") -> None:
snippetcompiler.setup_for_snippet(
"""
import plugin_invalid_union_type
"""
""",
ministd=True,
)
with pytest.raises(InvalidTypeAnnotation) as exc_info:
compiler.do_compile()
Expand Down
7 changes: 7 additions & 0 deletions tests/data/modules/plugin_native_types/model/_init.cf
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
entity TestEntity:
int value
end

typedef response as string matching self in ["yes", "no"]

implement TestEntity using std::none
Loading