Skip to content

Commit

Permalink
Use specific enum type
Browse files Browse the repository at this point in the history
  • Loading branch information
N0I0C0K authored and unmade committed Oct 20, 2024
1 parent 0747642 commit be6b32c
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 15 deletions.
5 changes: 3 additions & 2 deletions src/thriftpyi/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import ast
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Sequence, Type, Union, Optional
from typing import TYPE_CHECKING, List, Optional, Sequence, Type, Union

if TYPE_CHECKING:
AnyFunctionDef = Union[ast.AsyncFunctionDef, ast.FunctionDef]
Expand Down Expand Up @@ -161,7 +161,8 @@ def as_ast(self) -> Union[ast.AnnAssign, ast.Assign]:
if self.type is None:
return ast.Assign(
targets=[ast.Name(id=self.name, ctx=ast.Store())],
value=ast.Constant(value=self.value),
value=ast.Constant(value=self.value, kind=None),
lineno=0,
)

if not self.required:
Expand Down
16 changes: 14 additions & 2 deletions src/thriftpyi/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,23 @@ def get_fields(self, *, ignore_type: bool = False) -> List[Field]:
for item in self.thrift_spec
]

def _get_python_type(self, item: TSpecItemProxy) -> str:
pytype = get_python_type(item.ttype, meta=item.meta)
def _remove_self_module(self, pytype: str) -> str:
left_type, sep, right_type = pytype.partition(",")
# Due to complex type, such as Dict[some_module.TypeA, some_module.TypeB]
# recursively deal with the first and second parts
if right_type != "":
return (
self._remove_self_module(left_type)
+ sep
+ self._remove_self_module(right_type)
)
start, _, end = pytype.rpartition(f"{self.module_name}.")
return start + end

def _get_python_type(self, item: TSpecItemProxy) -> str:
pytype = get_python_type(item.ttype, meta=item.meta)
return self._remove_self_module(pytype)

def _get_default_value(self, item: TSpecItemProxy) -> FieldValue:
default_value = self.default_spec.get(item.name)
return cast(FieldValue, default_value)
2 changes: 1 addition & 1 deletion src/thriftpyi/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _make_relative_import(names: Iterable[str]) -> ast.ImportFrom:
)


def _make_consts(interface: TModuleProxy) -> List[ast.AnnAssign]:
def _make_consts(interface: TModuleProxy) -> List[ast.stmt]:
return [item.as_ast() for item in interface.get_consts()]


Expand Down
3 changes: 2 additions & 1 deletion src/thriftpyi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def _get_i16(meta: List) -> str:


def _get_i32(meta: List) -> str:
del meta
if meta and meta[0] is not None:
return f"{meta[0].__module__}.{meta[0].__name__}"
return "int"


Expand Down
6 changes: 3 additions & 3 deletions tests/stubs/expected/async/todo.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ class TodoType(IntEnum):
class TodoItem:
id: int
text: str
type: int
type: TodoType
created: dates.DateTime
is_deleted: bool
picture: Optional[bytes] = None
is_favorite: bool = False

class Todo:
async def create(self, text: str, type: int) -> int: ...
async def create(self, text: str, type: TodoType) -> int: ...
async def update(self, item: TodoItem) -> None: ...
async def get(self, id: int) -> TodoItem: ...
async def all(self, pager: shared.LimitOffset) -> List[TodoItem]: ...
async def filter(self, ids: List[int]) -> List[TodoItem]: ...
async def stats(self) -> Dict[int, float]: ...
async def types(self) -> Set[int]: ...
async def groupby(self) -> Dict[int, List[TodoItem]]: ...
async def groupby(self) -> Dict[TodoType, List[TodoItem]]: ...
async def ping(self) -> str: ...
6 changes: 3 additions & 3 deletions tests/stubs/expected/optional/todo.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ class TodoType(IntEnum):
class TodoItem:
id: Optional[int] = None
text: Optional[str] = None
type: Optional[int] = None
type: Optional[TodoType] = None
created: Optional[dates.DateTime] = None
is_deleted: Optional[bool] = None
picture: Optional[bytes] = None
is_favorite: Optional[bool] = False

class Todo:
def create(self, text: str, type: int) -> int: ...
def create(self, text: str, type: TodoType) -> int: ...
def update(self, item: TodoItem) -> None: ...
def get(self, id: int) -> TodoItem: ...
def all(self, pager: shared.LimitOffset) -> List[TodoItem]: ...
def filter(self, ids: List[int]) -> List[TodoItem]: ...
def stats(self) -> Dict[int, float]: ...
def types(self) -> Set[int]: ...
def groupby(self) -> Dict[int, List[TodoItem]]: ...
def groupby(self) -> Dict[TodoType, List[TodoItem]]: ...
def ping(self) -> str: ...
6 changes: 3 additions & 3 deletions tests/stubs/expected/sync/todo.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ class TodoType(IntEnum):
class TodoItem:
id: int
text: str
type: int
type: TodoType
created: dates.DateTime
is_deleted: bool
picture: Optional[bytes] = None
is_favorite: bool = False

class Todo:
def create(self, text: str, type: int) -> int: ...
def create(self, text: str, type: TodoType) -> int: ...
def update(self, item: TodoItem) -> None: ...
def get(self, id: int) -> TodoItem: ...
def all(self, pager: shared.LimitOffset) -> List[TodoItem]: ...
def filter(self, ids: List[int]) -> List[TodoItem]: ...
def stats(self) -> Dict[int, float]: ...
def types(self) -> Set[int]: ...
def groupby(self) -> Dict[int, List[TodoItem]]: ...
def groupby(self) -> Dict[TodoType, List[TodoItem]]: ...
def ping(self) -> str: ...

0 comments on commit be6b32c

Please sign in to comment.