|
| 1 | +import logging |
| 2 | +import os |
| 3 | +from typing import Any, Callable, List, Optional, Type, Union |
| 4 | + |
| 5 | +from mypy.nodes import Argument, FuncDef, RefExpr, SymbolTableNode, TypeInfo, Var |
| 6 | +from mypy.plugin import ClassDefContext, Plugin |
| 7 | +from mypy.plugins.common import add_method |
| 8 | +from mypy.types import AnyType, CallableType, Instance, TypeOfAny, UnionType |
| 9 | +from mypy.types import Type as MypyType |
| 10 | + |
| 11 | + |
| 12 | +# Set up logging |
| 13 | +logger = logging.getLogger("newtype.mypy_plugin") |
| 14 | +# Remove any existing handlers to prevent duplicates |
| 15 | +for handler in logger.handlers[:]: |
| 16 | + logger.removeHandler(handler) |
| 17 | + |
| 18 | +# Only enable logging if __PYNT_DEBUG__ is set to "true" |
| 19 | +if os.environ.get("__PYNT_DEBUG__", "").lower() == "true": |
| 20 | + # Create a file handler |
| 21 | + file_handler = logging.FileHandler("mypy_plugin.log") |
| 22 | + file_handler.setFormatter( |
| 23 | + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") |
| 24 | + ) |
| 25 | + logger.addHandler(file_handler) |
| 26 | + logger.setLevel(logging.DEBUG) |
| 27 | +else: |
| 28 | + logger.setLevel(logging.WARNING) |
| 29 | + |
| 30 | + |
| 31 | +def convert_union_type(typ: MypyType) -> MypyType: |
| 32 | + """Convert a type to use UnionType instead of | operator.""" |
| 33 | + if isinstance(typ, UnionType): |
| 34 | + # If it's already a UnionType, convert its items |
| 35 | + return UnionType([convert_union_type(t) for t in typ.items]) |
| 36 | + elif isinstance(typ, Instance) and typ.args: |
| 37 | + return typ.copy_modified(args=[convert_union_type(arg) for arg in typ.args]) |
| 38 | + return typ |
| 39 | + |
| 40 | + |
| 41 | +class NewTypePlugin(Plugin): |
| 42 | + def __init__(self, *args: Any, **kwargs: Any) -> None: |
| 43 | + super().__init__(*args, **kwargs) |
| 44 | + logger.info("Initializing NewTypePlugin") |
| 45 | + |
| 46 | + def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: |
| 47 | + logger.debug(f"get_base_class_hook called with fullname: {fullname}") |
| 48 | + if "newtype.NewType" in fullname: |
| 49 | + logger.info(f"Found NewType class: {fullname}") |
| 50 | + return handle_newtype_class |
| 51 | + logger.debug(f"No hook for {fullname}") |
| 52 | + return None |
| 53 | + |
| 54 | + |
| 55 | +def handle_newtype_class(ctx: ClassDefContext) -> None: # noqa: C901 |
| 56 | + logger.info(f"Processing NewType class: {ctx.cls.fullname}") |
| 57 | + |
| 58 | + if not hasattr(ctx.reason, "args") or not ctx.reason.args: |
| 59 | + logger.warning("No arguments provided to NewType") |
| 60 | + return |
| 61 | + |
| 62 | + # Get base type from NewType argument |
| 63 | + base_type_expr = ctx.reason.args[0] |
| 64 | + logger.debug(f"Base type expression: {base_type_expr}") |
| 65 | + |
| 66 | + if not isinstance(base_type_expr, RefExpr): |
| 67 | + logger.warning(f"Base type expression is not a RefExpr: {type(base_type_expr)}") |
| 68 | + return |
| 69 | + |
| 70 | + base_type: Optional[SymbolTableNode] |
| 71 | + |
| 72 | + # Handle built-in types specially |
| 73 | + if base_type_expr.fullname and base_type_expr.fullname.startswith("builtins."): |
| 74 | + logger.debug(f"Looking up built-in type: {base_type_expr.fullname}") |
| 75 | + base_type = ctx.api.lookup_fully_qualified(base_type_expr.fullname) |
| 76 | + else: |
| 77 | + logger.debug(f"Looking up qualified type: {base_type_expr.fullname}") |
| 78 | + base_type = ctx.api.lookup_qualified(base_type_expr.fullname, ctx.cls) |
| 79 | + |
| 80 | + if not base_type: |
| 81 | + logger.warning(f"Could not find base type: {base_type_expr.fullname}") |
| 82 | + return |
| 83 | + if not isinstance(base_type.node, TypeInfo): |
| 84 | + logger.warning(f"Base type node is not a TypeInfo: {type(base_type.node)}") |
| 85 | + return |
| 86 | + |
| 87 | + # Set up inheritance |
| 88 | + logger.info(f"Setting up inheritance for {ctx.cls.fullname} from {base_type.node.fullname}") |
| 89 | + base_instance = Instance(base_type.node, []) |
| 90 | + info = ctx.cls.info |
| 91 | + info.bases = [base_instance] |
| 92 | + info.mro = [info, base_type.node] + base_type.node.mro[1:] |
| 93 | + logger.debug(f"MRO: {[t.fullname for t in info.mro]}") |
| 94 | + |
| 95 | + # Copy all methods from base type |
| 96 | + logger.info(f"Processing methods from base type {base_type.node.fullname}") |
| 97 | + for name, node in base_type.node.names.items(): |
| 98 | + if isinstance(node.node, FuncDef) and isinstance(node.node.type, CallableType): |
| 99 | + logger.debug(f"Processing method: {name}") |
| 100 | + method_type = node.node.type |
| 101 | + |
| 102 | + # Convert return type to subtype if it matches base type |
| 103 | + ret_type = convert_union_type(method_type.ret_type) |
| 104 | + logger.debug(f"Original return type for {name}: {ret_type}") |
| 105 | + |
| 106 | + if isinstance(ret_type, Instance) and ret_type.type == base_type.node: |
| 107 | + logger.debug(f"Converting return type for {name} to {info.fullname}") |
| 108 | + ret_type = Instance(info, []) |
| 109 | + elif isinstance(ret_type, UnionType): |
| 110 | + logger.debug(f"Processing union return type for {name}: {ret_type}") |
| 111 | + items: List[Union[MypyType, Instance]] = [] |
| 112 | + for item in ret_type.items: |
| 113 | + if isinstance(item, Instance) and item.type == base_type.node: |
| 114 | + logger.debug(f"Converting union item from {item} to {info.fullname}") |
| 115 | + items.append(Instance(info, [])) |
| 116 | + else: |
| 117 | + items.append(item) |
| 118 | + ret_type = UnionType(items) |
| 119 | + logger.debug(f"Final union return type for {name}: {ret_type}") |
| 120 | + |
| 121 | + # Create arguments list, preserving original argument types |
| 122 | + arguments = [] |
| 123 | + if method_type.arg_types: |
| 124 | + logger.debug(f"Processing arguments for method {name}") |
| 125 | + # Skip first argument (self) |
| 126 | + for i, (arg_type, arg_kind, arg_name) in enumerate( |
| 127 | + zip( |
| 128 | + method_type.arg_types[1:], |
| 129 | + method_type.arg_kinds[1:], |
| 130 | + method_type.arg_names[1:] or [""] * len(method_type.arg_types[1:]), |
| 131 | + ), |
| 132 | + start=1, |
| 133 | + ): |
| 134 | + logger.debug( |
| 135 | + f"Processing argument {i} for {name}: \ |
| 136 | + {arg_name or f'arg{i}'} of type {arg_type}" |
| 137 | + ) |
| 138 | + |
| 139 | + # Special handling for __contains__ method |
| 140 | + if name == "__contains__" and i == 1: |
| 141 | + logger.debug( |
| 142 | + "Using Any type for __contains__ argument to satisfy Container protocol" |
| 143 | + ) |
| 144 | + arg_type = AnyType(TypeOfAny.special_form) |
| 145 | + else: |
| 146 | + # Convert any union types in arguments |
| 147 | + arg_type = convert_union_type(arg_type) |
| 148 | + |
| 149 | + # Create a new variable for the argument |
| 150 | + var = Var(arg_name or f"arg{i}", arg_type) |
| 151 | + var.is_ready = True |
| 152 | + |
| 153 | + # Create the argument |
| 154 | + arg = Argument( |
| 155 | + variable=var, |
| 156 | + type_annotation=arg_type, |
| 157 | + initializer=None, |
| 158 | + kind=arg_kind, |
| 159 | + ) |
| 160 | + arguments.append(arg) |
| 161 | + |
| 162 | + # Add method to class |
| 163 | + logger.info(f"Adding method {name} to {ctx.cls.fullname} with return type {ret_type}") |
| 164 | + add_method(ctx, name, arguments, ret_type) |
| 165 | + |
| 166 | + |
| 167 | +def plugin(version: str) -> Type[Plugin]: |
| 168 | + logger.info(f"Initializing plugin for mypy version: {version}") |
| 169 | + return NewTypePlugin |
0 commit comments