From be6b32c2cefd6a303f395ed66be0f475bfcaf1a7 Mon Sep 17 00:00:00 2001 From: N0I0C0K Date: Sun, 20 Oct 2024 17:18:57 +0800 Subject: [PATCH] Use specific enum type --- src/thriftpyi/entities.py | 5 +++-- src/thriftpyi/proxies.py | 16 ++++++++++++++-- src/thriftpyi/stubs.py | 2 +- src/thriftpyi/utils.py | 3 ++- tests/stubs/expected/async/todo.pyi | 6 +++--- tests/stubs/expected/optional/todo.pyi | 6 +++--- tests/stubs/expected/sync/todo.pyi | 6 +++--- 7 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/thriftpyi/entities.py b/src/thriftpyi/entities.py index 245f99b..4168271 100644 --- a/src/thriftpyi/entities.py +++ b/src/thriftpyi/entities.py @@ -2,7 +2,7 @@ import ast from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Sequence, Type, Union, Optional +from typing import TYPE_CHECKING, List, Optional, Sequence, Type, Union if TYPE_CHECKING: AnyFunctionDef = Union[ast.AsyncFunctionDef, ast.FunctionDef] @@ -161,7 +161,8 @@ def as_ast(self) -> Union[ast.AnnAssign, ast.Assign]: if self.type is None: return ast.Assign( targets=[ast.Name(id=self.name, ctx=ast.Store())], - value=ast.Constant(value=self.value), + value=ast.Constant(value=self.value, kind=None), + lineno=0, ) if not self.required: diff --git a/src/thriftpyi/proxies.py b/src/thriftpyi/proxies.py index 42ab567..466ebd1 100644 --- a/src/thriftpyi/proxies.py +++ b/src/thriftpyi/proxies.py @@ -168,11 +168,23 @@ def get_fields(self, *, ignore_type: bool = False) -> List[Field]: for item in self.thrift_spec ] - def _get_python_type(self, item: TSpecItemProxy) -> str: - pytype = get_python_type(item.ttype, meta=item.meta) + def _remove_self_module(self, pytype: str) -> str: + left_type, sep, right_type = pytype.partition(",") + # Due to complex type, such as Dict[some_module.TypeA, some_module.TypeB] + # recursively deal with the first and second parts + if right_type != "": + return ( + self._remove_self_module(left_type) + + sep + + self._remove_self_module(right_type) + ) start, _, end = pytype.rpartition(f"{self.module_name}.") return start + end + def _get_python_type(self, item: TSpecItemProxy) -> str: + pytype = get_python_type(item.ttype, meta=item.meta) + return self._remove_self_module(pytype) + def _get_default_value(self, item: TSpecItemProxy) -> FieldValue: default_value = self.default_spec.get(item.name) return cast(FieldValue, default_value) diff --git a/src/thriftpyi/stubs.py b/src/thriftpyi/stubs.py index 20156af..b6f67f6 100644 --- a/src/thriftpyi/stubs.py +++ b/src/thriftpyi/stubs.py @@ -61,7 +61,7 @@ def _make_relative_import(names: Iterable[str]) -> ast.ImportFrom: ) -def _make_consts(interface: TModuleProxy) -> List[ast.AnnAssign]: +def _make_consts(interface: TModuleProxy) -> List[ast.stmt]: return [item.as_ast() for item in interface.get_consts()] diff --git a/src/thriftpyi/utils.py b/src/thriftpyi/utils.py index d6520af..aa3a823 100644 --- a/src/thriftpyi/utils.py +++ b/src/thriftpyi/utils.py @@ -71,7 +71,8 @@ def _get_i16(meta: List) -> str: def _get_i32(meta: List) -> str: - del meta + if meta and meta[0] is not None: + return f"{meta[0].__module__}.{meta[0].__name__}" return "int" diff --git a/tests/stubs/expected/async/todo.pyi b/tests/stubs/expected/async/todo.pyi index 2f1883a..9881da9 100644 --- a/tests/stubs/expected/async/todo.pyi +++ b/tests/stubs/expected/async/todo.pyi @@ -13,19 +13,19 @@ class TodoType(IntEnum): class TodoItem: id: int text: str - type: int + type: TodoType created: dates.DateTime is_deleted: bool picture: Optional[bytes] = None is_favorite: bool = False class Todo: - async def create(self, text: str, type: int) -> int: ... + async def create(self, text: str, type: TodoType) -> int: ... async def update(self, item: TodoItem) -> None: ... async def get(self, id: int) -> TodoItem: ... async def all(self, pager: shared.LimitOffset) -> List[TodoItem]: ... async def filter(self, ids: List[int]) -> List[TodoItem]: ... async def stats(self) -> Dict[int, float]: ... async def types(self) -> Set[int]: ... - async def groupby(self) -> Dict[int, List[TodoItem]]: ... + async def groupby(self) -> Dict[TodoType, List[TodoItem]]: ... async def ping(self) -> str: ... diff --git a/tests/stubs/expected/optional/todo.pyi b/tests/stubs/expected/optional/todo.pyi index d887e19..b0c0346 100644 --- a/tests/stubs/expected/optional/todo.pyi +++ b/tests/stubs/expected/optional/todo.pyi @@ -13,19 +13,19 @@ class TodoType(IntEnum): class TodoItem: id: Optional[int] = None text: Optional[str] = None - type: Optional[int] = None + type: Optional[TodoType] = None created: Optional[dates.DateTime] = None is_deleted: Optional[bool] = None picture: Optional[bytes] = None is_favorite: Optional[bool] = False class Todo: - def create(self, text: str, type: int) -> int: ... + def create(self, text: str, type: TodoType) -> int: ... def update(self, item: TodoItem) -> None: ... def get(self, id: int) -> TodoItem: ... def all(self, pager: shared.LimitOffset) -> List[TodoItem]: ... def filter(self, ids: List[int]) -> List[TodoItem]: ... def stats(self) -> Dict[int, float]: ... def types(self) -> Set[int]: ... - def groupby(self) -> Dict[int, List[TodoItem]]: ... + def groupby(self) -> Dict[TodoType, List[TodoItem]]: ... def ping(self) -> str: ... diff --git a/tests/stubs/expected/sync/todo.pyi b/tests/stubs/expected/sync/todo.pyi index 0ec5810..b907bee 100644 --- a/tests/stubs/expected/sync/todo.pyi +++ b/tests/stubs/expected/sync/todo.pyi @@ -13,19 +13,19 @@ class TodoType(IntEnum): class TodoItem: id: int text: str - type: int + type: TodoType created: dates.DateTime is_deleted: bool picture: Optional[bytes] = None is_favorite: bool = False class Todo: - def create(self, text: str, type: int) -> int: ... + def create(self, text: str, type: TodoType) -> int: ... def update(self, item: TodoItem) -> None: ... def get(self, id: int) -> TodoItem: ... def all(self, pager: shared.LimitOffset) -> List[TodoItem]: ... def filter(self, ids: List[int]) -> List[TodoItem]: ... def stats(self) -> Dict[int, float]: ... def types(self) -> Set[int]: ... - def groupby(self) -> Dict[int, List[TodoItem]]: ... + def groupby(self) -> Dict[TodoType, List[TodoItem]]: ... def ping(self) -> str: ...