Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ repos:
files: '(newtype)/.*\.py$' # Single quote critical due to escape character '\' used in RegEx search string (see YAML - 7.3 Flow Scalar Styles)
args: [--config-file=./pyproject.toml, --ignore-missing-imports, --scripts-are-modules]
exclude: '(docs|tests)/.*\.py$'
additional_dependencies:
- "mypy>=1.0.0"
- "." # Install the current package

- repo: local
hooks:
Expand Down
12 changes: 6 additions & 6 deletions docs/development/building.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,17 @@ poetry self add poetry-dynamic-versioning
poetry version patch
```

2. Commit your changes and push to the repository.
2. Set the tag to the new version:
```
git tag v0.1.3
```

3. Commit your changes and push to the repository.

```bash
git commit -m "Release version 0.1.3"
```

3. Set the tag to the new version:
```
git tag v0.1.3
```

4. Push the changes to the repository:
```
git push --tags
Expand Down
2 changes: 1 addition & 1 deletion examples/email_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from newtype import NewType, newtype_exclude


class EmailStr(NewType(str)):
class EmailStr(NewType(str)): # type: ignore[misc]
# you can define `__slots__` to save space
__slots__ = (
'_local_part',
Expand Down
52 changes: 32 additions & 20 deletions examples/newtype_enums.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from enum import Enum
from typing import Optional, Type

import pytest
from newtype import NewType

from newtype import NewType, newtype_exclude


class ENV(NewType(str), Enum):
class ENV(NewType(str), Enum): # type: ignore[misc]

LOCAL = "LOCAL"
DEV = "DEV"
Expand All @@ -23,26 +23,33 @@ class RegularENV(str, Enum):
PREPROD = "PREPROD"
PROD = "PROD"

RollYourOwnNewTypeEnum: "Optional[Type[RollYourOwnNewTypeEnum]]" = None

class ENVVariant(str):

__VALID_MEMBERS__ = ["LOCAL", "DEV", "SIT", "UAT", "PREPROD", "PROD"]

def __new__(cls, value: str):
def __new__(cls, value: str) -> "ENVVariant":
members = ENVVariant.__VALID_MEMBERS__
# if isinstance(value, RollYourOwnNewTypeEnum):
# value_as_str = str(value.value)
# else:
value_as_str = str(value)
value_as_str = str(value.value if hasattr(value, "value") else value)
if value_as_str not in members:
raise ValueError(f"`value` = {value} must be one of `{members}`; `value_as_str` = {value_as_str}")
return super().__new__(cls, value_as_str)

# why not i write my own `.replace(..)`
# yes, you can but how?
def my_replace(self, old: "ENVVariant", new: "ENVVariant", count: int=-1):
return ENVVariant(str(self).replace(str(old), str(new), count))
def my_replace(self, old: "ENVVariant", new: "ENVVariant", count: int=-1) -> "ENVVariant":
# Convert both old and new to their string values
old_str = str(old.value if hasattr(old, "value") else old)
new_str = str(new.value if hasattr(new, "value") else new)
# Do the replacement on string values
result = str(self.value if hasattr(self, "value") else self).replace(old_str, new_str, count)
# For enums, we need to look up the enum member by value
if issubclass(type(self), Enum):
return type(self)(result) # This will find the enum member

# For non-enum types, create new instance directly
return type(self)(result)

class RollYourOwnNewTypeEnum(ENVVariant, Enum):
class RollYourOwnNewTypeEnum(ENVVariant, Enum): # type: ignore[no-redef]

LOCAL = "LOCAL"
DEV = "DEV"
Expand All @@ -51,8 +58,8 @@ class RollYourOwnNewTypeEnum(ENVVariant, Enum):
PREPROD = "PREPROD"
PROD = "PROD"


def test_nt_env_replace():
# mypy doesn't raise errors here
def test_nt_env_replace() -> None:

env = ENV.LOCAL

Expand All @@ -63,6 +70,7 @@ def test_nt_env_replace():
# let's say now we want to replace the environment
# nevermind about the reason why we want to do so
env = env.replace(ENV.LOCAL, ENV.DEV)
# reveal_type(env) # Revealed type is "newtype_enums.ENV"

# replacement is successful
assert env is ENV.DEV
Expand All @@ -76,11 +84,13 @@ def test_nt_env_replace():
# cannot replace with something that is not a `ENV`
env = env.replace(ENV.DEV, "NotAnEnv")

# reveal_type(env) # Revealed type is "newtype_enums.ENV"

with pytest.raises(ValueError):
# cannot even make 'DEV' -> 'dev'
env = env.lower()

def test_reg_env_replace():
def test_reg_env_replace() -> None:

env = RegularENV.LOCAL

Expand All @@ -90,15 +100,17 @@ def test_reg_env_replace():
assert isinstance(env, RegularENV) # pass

# now we try to replace
env = env.replace(RegularENV.LOCAL, RegularENV.DEV)
env = env.replace("LOCAL", "DEV")

# we are hoping that it will continue to be a `RegularENV.DEV` but it is not
assert env is not RegularENV.DEV # pass, no longer a `RegularENV`
assert env is not RegularENV.LOCAL # pass, no longer a `RegularENV`
assert not isinstance(env, RegularENV)
assert isinstance(env, str) # 'downcast' (?) to `str`

def test_ryont_env_replace():
def test_ryont_env_replace() -> None:

assert RollYourOwnNewTypeEnum is not None

env = RollYourOwnNewTypeEnum.LOCAL

Expand Down Expand Up @@ -130,8 +142,8 @@ def test_ryont_env_replace():

env = RollYourOwnNewTypeEnum.LOCAL

# env = env.my_replace(RollYourOwnNewTypeEnum.LOCAL, RollYourOwnNewTypeEnum.PREPROD)
env = env.my_replace(RollYourOwnNewTypeEnum.LOCAL, RollYourOwnNewTypeEnum.PREPROD)

assert isinstance(env, str)
assert env is not RollYourOwnNewTypeEnum.PREPROD
assert env is RollYourOwnNewTypeEnum.PREPROD
assert isinstance(env, RollYourOwnNewTypeEnum)
9 changes: 8 additions & 1 deletion newtype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,11 @@ def __init__(self, val: int) -> None:


__version__ = "0.0.0" # Don't manually change, let poetry-dynamic-versioning handle it
__all__ = ["NewType", "newtype_exclude", "func_is_excluded", "NewTypeInit", "NewTypeMethod"]
__all__ = [
"NewType",
"newtype_exclude",
"func_is_excluded",
"NewTypeInit",
"NewTypeMethod",
"mypy_plugin",
]
169 changes: 169 additions & 0 deletions newtype/mypy_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import logging
import os
from typing import Any, Callable, List, Optional, Type, Union

from mypy.nodes import Argument, FuncDef, RefExpr, SymbolTableNode, TypeInfo, Var
from mypy.plugin import ClassDefContext, Plugin
from mypy.plugins.common import add_method
from mypy.types import AnyType, CallableType, Instance, TypeOfAny, UnionType
from mypy.types import Type as MypyType


# Set up logging
logger = logging.getLogger("newtype.mypy_plugin")
# Remove any existing handlers to prevent duplicates
for handler in logger.handlers[:]:
logger.removeHandler(handler)

# Only enable logging if __PYNT_DEBUG__ is set to "true"
if os.environ.get("__PYNT_DEBUG__", "").lower() == "true":
# Create a file handler
file_handler = logging.FileHandler("mypy_plugin.log")
file_handler.setFormatter(
logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(file_handler)
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.WARNING)


def convert_union_type(typ: MypyType) -> MypyType:
"""Convert a type to use UnionType instead of | operator."""
if isinstance(typ, UnionType):
# If it's already a UnionType, convert its items
return UnionType([convert_union_type(t) for t in typ.items])
elif isinstance(typ, Instance) and typ.args:
return typ.copy_modified(args=[convert_union_type(arg) for arg in typ.args])
return typ


class NewTypePlugin(Plugin):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
logger.info("Initializing NewTypePlugin")

def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
logger.debug(f"get_base_class_hook called with fullname: {fullname}")
if "newtype.NewType" in fullname:
logger.info(f"Found NewType class: {fullname}")
return handle_newtype_class
logger.debug(f"No hook for {fullname}")
return None


def handle_newtype_class(ctx: ClassDefContext) -> None: # noqa: C901
logger.info(f"Processing NewType class: {ctx.cls.fullname}")

if not hasattr(ctx.reason, "args") or not ctx.reason.args:
logger.warning("No arguments provided to NewType")
return

# Get base type from NewType argument
base_type_expr = ctx.reason.args[0]
logger.debug(f"Base type expression: {base_type_expr}")

if not isinstance(base_type_expr, RefExpr):
logger.warning(f"Base type expression is not a RefExpr: {type(base_type_expr)}")
return

base_type: Optional[SymbolTableNode]

# Handle built-in types specially
if base_type_expr.fullname and base_type_expr.fullname.startswith("builtins."):
logger.debug(f"Looking up built-in type: {base_type_expr.fullname}")
base_type = ctx.api.lookup_fully_qualified(base_type_expr.fullname)
else:
logger.debug(f"Looking up qualified type: {base_type_expr.fullname}")
base_type = ctx.api.lookup_qualified(base_type_expr.fullname, ctx.cls)

if not base_type:
logger.warning(f"Could not find base type: {base_type_expr.fullname}")
return
if not isinstance(base_type.node, TypeInfo):
logger.warning(f"Base type node is not a TypeInfo: {type(base_type.node)}")
return

# Set up inheritance
logger.info(f"Setting up inheritance for {ctx.cls.fullname} from {base_type.node.fullname}")
base_instance = Instance(base_type.node, [])
info = ctx.cls.info
info.bases = [base_instance]
info.mro = [info, base_type.node] + base_type.node.mro[1:]
logger.debug(f"MRO: {[t.fullname for t in info.mro]}")

# Copy all methods from base type
logger.info(f"Processing methods from base type {base_type.node.fullname}")
for name, node in base_type.node.names.items():
if isinstance(node.node, FuncDef) and isinstance(node.node.type, CallableType):
logger.debug(f"Processing method: {name}")
method_type = node.node.type

# Convert return type to subtype if it matches base type
ret_type = convert_union_type(method_type.ret_type)
logger.debug(f"Original return type for {name}: {ret_type}")

if isinstance(ret_type, Instance) and ret_type.type == base_type.node:
logger.debug(f"Converting return type for {name} to {info.fullname}")
ret_type = Instance(info, [])
elif isinstance(ret_type, UnionType):
logger.debug(f"Processing union return type for {name}: {ret_type}")
items: List[Union[MypyType, Instance]] = []
for item in ret_type.items:
if isinstance(item, Instance) and item.type == base_type.node:
logger.debug(f"Converting union item from {item} to {info.fullname}")
items.append(Instance(info, []))
else:
items.append(item)
ret_type = UnionType(items)
logger.debug(f"Final union return type for {name}: {ret_type}")

# Create arguments list, preserving original argument types
arguments = []
if method_type.arg_types:
logger.debug(f"Processing arguments for method {name}")
# Skip first argument (self)
for i, (arg_type, arg_kind, arg_name) in enumerate(
zip(
method_type.arg_types[1:],
method_type.arg_kinds[1:],
method_type.arg_names[1:] or [""] * len(method_type.arg_types[1:]),
),
start=1,
):
logger.debug(
f"Processing argument {i} for {name}: \
{arg_name or f'arg{i}'} of type {arg_type}"
)

# Special handling for __contains__ method
if name == "__contains__" and i == 1:
logger.debug(
"Using Any type for __contains__ argument to satisfy Container protocol"
)
arg_type = AnyType(TypeOfAny.special_form)
else:
# Convert any union types in arguments
arg_type = convert_union_type(arg_type)

# Create a new variable for the argument
var = Var(arg_name or f"arg{i}", arg_type)
var.is_ready = True

# Create the argument
arg = Argument(
variable=var,
type_annotation=arg_type,
initializer=None,
kind=arg_kind,
)
arguments.append(arg)

# Add method to class
logger.info(f"Adding method {name} to {ctx.cls.fullname} with return type {ret_type}")
add_method(ctx, name, arguments, ret_type)


def plugin(version: str) -> Type[Plugin]:
logger.info(f"Initializing plugin for mypy version: {version}")
return NewTypePlugin
Loading
Loading