Skip to content

Commit 1fa8b61

Browse files
authored
Merge pull request #27 from jymchng/feat/mypy-plugin
mypy plugin seems to be working well
2 parents 8f72d48 + c1e4f02 commit 1fa8b61

File tree

9 files changed

+263
-63
lines changed

9 files changed

+263
-63
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ repos:
8888
files: '(newtype)/.*\.py$' # Single quote critical due to escape character '\' used in RegEx search string (see YAML - 7.3 Flow Scalar Styles)
8989
args: [--config-file=./pyproject.toml, --ignore-missing-imports, --scripts-are-modules]
9090
exclude: '(docs|tests)/.*\.py$'
91+
additional_dependencies:
92+
- "mypy>=1.0.0"
93+
- "." # Install the current package
9194

9295
- repo: local
9396
hooks:

docs/development/building.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,17 +197,17 @@ poetry self add poetry-dynamic-versioning
197197
poetry version patch
198198
```
199199

200-
2. Commit your changes and push to the repository.
200+
2. Set the tag to the new version:
201+
```
202+
git tag v0.1.3
203+
```
204+
205+
3. Commit your changes and push to the repository.
201206

202207
```bash
203208
git commit -m "Release version 0.1.3"
204209
```
205210

206-
3. Set the tag to the new version:
207-
```
208-
git tag v0.1.3
209-
```
210-
211211
4. Push the changes to the repository:
212212
```
213213
git push --tags

examples/email_str.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from newtype import NewType, newtype_exclude
88

99

10-
class EmailStr(NewType(str)):
10+
class EmailStr(NewType(str)): # type: ignore[misc]
1111
# you can define `__slots__` to save space
1212
__slots__ = (
1313
'_local_part',

examples/newtype_enums.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from enum import Enum
2+
from typing import Optional, Type
23

34
import pytest
5+
from newtype import NewType
46

5-
from newtype import NewType, newtype_exclude
67

7-
8-
class ENV(NewType(str), Enum):
8+
class ENV(NewType(str), Enum): # type: ignore[misc]
99

1010
LOCAL = "LOCAL"
1111
DEV = "DEV"
@@ -23,26 +23,33 @@ class RegularENV(str, Enum):
2323
PREPROD = "PREPROD"
2424
PROD = "PROD"
2525

26+
RollYourOwnNewTypeEnum: "Optional[Type[RollYourOwnNewTypeEnum]]" = None
27+
2628
class ENVVariant(str):
2729

2830
__VALID_MEMBERS__ = ["LOCAL", "DEV", "SIT", "UAT", "PREPROD", "PROD"]
2931

30-
def __new__(cls, value: str):
32+
def __new__(cls, value: str) -> "ENVVariant":
3133
members = ENVVariant.__VALID_MEMBERS__
32-
# if isinstance(value, RollYourOwnNewTypeEnum):
33-
# value_as_str = str(value.value)
34-
# else:
35-
value_as_str = str(value)
34+
value_as_str = str(value.value if hasattr(value, "value") else value)
3635
if value_as_str not in members:
3736
raise ValueError(f"`value` = {value} must be one of `{members}`; `value_as_str` = {value_as_str}")
3837
return super().__new__(cls, value_as_str)
3938

40-
# why not i write my own `.replace(..)`
41-
# yes, you can but how?
42-
def my_replace(self, old: "ENVVariant", new: "ENVVariant", count: int=-1):
43-
return ENVVariant(str(self).replace(str(old), str(new), count))
39+
def my_replace(self, old: "ENVVariant", new: "ENVVariant", count: int=-1) -> "ENVVariant":
40+
# Convert both old and new to their string values
41+
old_str = str(old.value if hasattr(old, "value") else old)
42+
new_str = str(new.value if hasattr(new, "value") else new)
43+
# Do the replacement on string values
44+
result = str(self.value if hasattr(self, "value") else self).replace(old_str, new_str, count)
45+
# For enums, we need to look up the enum member by value
46+
if issubclass(type(self), Enum):
47+
return type(self)(result) # This will find the enum member
48+
49+
# For non-enum types, create new instance directly
50+
return type(self)(result)
4451

45-
class RollYourOwnNewTypeEnum(ENVVariant, Enum):
52+
class RollYourOwnNewTypeEnum(ENVVariant, Enum): # type: ignore[no-redef]
4653

4754
LOCAL = "LOCAL"
4855
DEV = "DEV"
@@ -51,8 +58,8 @@ class RollYourOwnNewTypeEnum(ENVVariant, Enum):
5158
PREPROD = "PREPROD"
5259
PROD = "PROD"
5360

54-
55-
def test_nt_env_replace():
61+
# mypy doesn't raise errors here
62+
def test_nt_env_replace() -> None:
5663

5764
env = ENV.LOCAL
5865

@@ -63,6 +70,7 @@ def test_nt_env_replace():
6370
# let's say now we want to replace the environment
6471
# nevermind about the reason why we want to do so
6572
env = env.replace(ENV.LOCAL, ENV.DEV)
73+
# reveal_type(env) # Revealed type is "newtype_enums.ENV"
6674

6775
# replacement is successful
6876
assert env is ENV.DEV
@@ -76,11 +84,13 @@ def test_nt_env_replace():
7684
# cannot replace with something that is not a `ENV`
7785
env = env.replace(ENV.DEV, "NotAnEnv")
7886

87+
# reveal_type(env) # Revealed type is "newtype_enums.ENV"
88+
7989
with pytest.raises(ValueError):
8090
# cannot even make 'DEV' -> 'dev'
8191
env = env.lower()
8292

83-
def test_reg_env_replace():
93+
def test_reg_env_replace() -> None:
8494

8595
env = RegularENV.LOCAL
8696

@@ -90,15 +100,17 @@ def test_reg_env_replace():
90100
assert isinstance(env, RegularENV) # pass
91101

92102
# now we try to replace
93-
env = env.replace(RegularENV.LOCAL, RegularENV.DEV)
103+
env = env.replace("LOCAL", "DEV")
94104

95105
# we are hoping that it will continue to be a `RegularENV.DEV` but it is not
96106
assert env is not RegularENV.DEV # pass, no longer a `RegularENV`
97107
assert env is not RegularENV.LOCAL # pass, no longer a `RegularENV`
98108
assert not isinstance(env, RegularENV)
99109
assert isinstance(env, str) # 'downcast' (?) to `str`
100110

101-
def test_ryont_env_replace():
111+
def test_ryont_env_replace() -> None:
112+
113+
assert RollYourOwnNewTypeEnum is not None
102114

103115
env = RollYourOwnNewTypeEnum.LOCAL
104116

@@ -130,8 +142,8 @@ def test_ryont_env_replace():
130142

131143
env = RollYourOwnNewTypeEnum.LOCAL
132144

133-
# env = env.my_replace(RollYourOwnNewTypeEnum.LOCAL, RollYourOwnNewTypeEnum.PREPROD)
145+
env = env.my_replace(RollYourOwnNewTypeEnum.LOCAL, RollYourOwnNewTypeEnum.PREPROD)
134146

135147
assert isinstance(env, str)
136-
assert env is not RollYourOwnNewTypeEnum.PREPROD
148+
assert env is RollYourOwnNewTypeEnum.PREPROD
137149
assert isinstance(env, RollYourOwnNewTypeEnum)

newtype/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,11 @@ def __init__(self, val: int) -> None:
4242

4343

4444
__version__ = "0.0.0" # Don't manually change, let poetry-dynamic-versioning handle it
45-
__all__ = ["NewType", "newtype_exclude", "func_is_excluded", "NewTypeInit", "NewTypeMethod"]
45+
__all__ = [
46+
"NewType",
47+
"newtype_exclude",
48+
"func_is_excluded",
49+
"NewTypeInit",
50+
"NewTypeMethod",
51+
"mypy_plugin",
52+
]

newtype/mypy_plugin.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import logging
2+
import os
3+
from typing import Any, Callable, List, Optional, Type, Union
4+
5+
from mypy.nodes import Argument, FuncDef, RefExpr, SymbolTableNode, TypeInfo, Var
6+
from mypy.plugin import ClassDefContext, Plugin
7+
from mypy.plugins.common import add_method
8+
from mypy.types import AnyType, CallableType, Instance, TypeOfAny, UnionType
9+
from mypy.types import Type as MypyType
10+
11+
12+
# Set up logging
13+
logger = logging.getLogger("newtype.mypy_plugin")
14+
# Remove any existing handlers to prevent duplicates
15+
for handler in logger.handlers[:]:
16+
logger.removeHandler(handler)
17+
18+
# Only enable logging if __PYNT_DEBUG__ is set to "true"
19+
if os.environ.get("__PYNT_DEBUG__", "").lower() == "true":
20+
# Create a file handler
21+
file_handler = logging.FileHandler("mypy_plugin.log")
22+
file_handler.setFormatter(
23+
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
24+
)
25+
logger.addHandler(file_handler)
26+
logger.setLevel(logging.DEBUG)
27+
else:
28+
logger.setLevel(logging.WARNING)
29+
30+
31+
def convert_union_type(typ: MypyType) -> MypyType:
32+
"""Convert a type to use UnionType instead of | operator."""
33+
if isinstance(typ, UnionType):
34+
# If it's already a UnionType, convert its items
35+
return UnionType([convert_union_type(t) for t in typ.items])
36+
elif isinstance(typ, Instance) and typ.args:
37+
return typ.copy_modified(args=[convert_union_type(arg) for arg in typ.args])
38+
return typ
39+
40+
41+
class NewTypePlugin(Plugin):
42+
def __init__(self, *args: Any, **kwargs: Any) -> None:
43+
super().__init__(*args, **kwargs)
44+
logger.info("Initializing NewTypePlugin")
45+
46+
def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
47+
logger.debug(f"get_base_class_hook called with fullname: {fullname}")
48+
if "newtype.NewType" in fullname:
49+
logger.info(f"Found NewType class: {fullname}")
50+
return handle_newtype_class
51+
logger.debug(f"No hook for {fullname}")
52+
return None
53+
54+
55+
def handle_newtype_class(ctx: ClassDefContext) -> None: # noqa: C901
56+
logger.info(f"Processing NewType class: {ctx.cls.fullname}")
57+
58+
if not hasattr(ctx.reason, "args") or not ctx.reason.args:
59+
logger.warning("No arguments provided to NewType")
60+
return
61+
62+
# Get base type from NewType argument
63+
base_type_expr = ctx.reason.args[0]
64+
logger.debug(f"Base type expression: {base_type_expr}")
65+
66+
if not isinstance(base_type_expr, RefExpr):
67+
logger.warning(f"Base type expression is not a RefExpr: {type(base_type_expr)}")
68+
return
69+
70+
base_type: Optional[SymbolTableNode]
71+
72+
# Handle built-in types specially
73+
if base_type_expr.fullname and base_type_expr.fullname.startswith("builtins."):
74+
logger.debug(f"Looking up built-in type: {base_type_expr.fullname}")
75+
base_type = ctx.api.lookup_fully_qualified(base_type_expr.fullname)
76+
else:
77+
logger.debug(f"Looking up qualified type: {base_type_expr.fullname}")
78+
base_type = ctx.api.lookup_qualified(base_type_expr.fullname, ctx.cls)
79+
80+
if not base_type:
81+
logger.warning(f"Could not find base type: {base_type_expr.fullname}")
82+
return
83+
if not isinstance(base_type.node, TypeInfo):
84+
logger.warning(f"Base type node is not a TypeInfo: {type(base_type.node)}")
85+
return
86+
87+
# Set up inheritance
88+
logger.info(f"Setting up inheritance for {ctx.cls.fullname} from {base_type.node.fullname}")
89+
base_instance = Instance(base_type.node, [])
90+
info = ctx.cls.info
91+
info.bases = [base_instance]
92+
info.mro = [info, base_type.node] + base_type.node.mro[1:]
93+
logger.debug(f"MRO: {[t.fullname for t in info.mro]}")
94+
95+
# Copy all methods from base type
96+
logger.info(f"Processing methods from base type {base_type.node.fullname}")
97+
for name, node in base_type.node.names.items():
98+
if isinstance(node.node, FuncDef) and isinstance(node.node.type, CallableType):
99+
logger.debug(f"Processing method: {name}")
100+
method_type = node.node.type
101+
102+
# Convert return type to subtype if it matches base type
103+
ret_type = convert_union_type(method_type.ret_type)
104+
logger.debug(f"Original return type for {name}: {ret_type}")
105+
106+
if isinstance(ret_type, Instance) and ret_type.type == base_type.node:
107+
logger.debug(f"Converting return type for {name} to {info.fullname}")
108+
ret_type = Instance(info, [])
109+
elif isinstance(ret_type, UnionType):
110+
logger.debug(f"Processing union return type for {name}: {ret_type}")
111+
items: List[Union[MypyType, Instance]] = []
112+
for item in ret_type.items:
113+
if isinstance(item, Instance) and item.type == base_type.node:
114+
logger.debug(f"Converting union item from {item} to {info.fullname}")
115+
items.append(Instance(info, []))
116+
else:
117+
items.append(item)
118+
ret_type = UnionType(items)
119+
logger.debug(f"Final union return type for {name}: {ret_type}")
120+
121+
# Create arguments list, preserving original argument types
122+
arguments = []
123+
if method_type.arg_types:
124+
logger.debug(f"Processing arguments for method {name}")
125+
# Skip first argument (self)
126+
for i, (arg_type, arg_kind, arg_name) in enumerate(
127+
zip(
128+
method_type.arg_types[1:],
129+
method_type.arg_kinds[1:],
130+
method_type.arg_names[1:] or [""] * len(method_type.arg_types[1:]),
131+
),
132+
start=1,
133+
):
134+
logger.debug(
135+
f"Processing argument {i} for {name}: \
136+
{arg_name or f'arg{i}'} of type {arg_type}"
137+
)
138+
139+
# Special handling for __contains__ method
140+
if name == "__contains__" and i == 1:
141+
logger.debug(
142+
"Using Any type for __contains__ argument to satisfy Container protocol"
143+
)
144+
arg_type = AnyType(TypeOfAny.special_form)
145+
else:
146+
# Convert any union types in arguments
147+
arg_type = convert_union_type(arg_type)
148+
149+
# Create a new variable for the argument
150+
var = Var(arg_name or f"arg{i}", arg_type)
151+
var.is_ready = True
152+
153+
# Create the argument
154+
arg = Argument(
155+
variable=var,
156+
type_annotation=arg_type,
157+
initializer=None,
158+
kind=arg_kind,
159+
)
160+
arguments.append(arg)
161+
162+
# Add method to class
163+
logger.info(f"Adding method {name} to {ctx.cls.fullname} with return type {ret_type}")
164+
add_method(ctx, name, arguments, ret_type)
165+
166+
167+
def plugin(version: str) -> Type[Plugin]:
168+
logger.info(f"Initializing plugin for mypy version: {version}")
169+
return NewTypePlugin

0 commit comments

Comments
 (0)