Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik committed Oct 13, 2023
1 parent 4fbc405 commit a714309
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
15 changes: 11 additions & 4 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,18 @@ def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
name = f"**{name}"

args.append(ArgSig(name, typename, default=bool(arg_.initializer)))
if o.name == "__init__" and is_dataclass_generated and "**" in args:

is_dataclass_generated = (
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
)
if o.name == "__init__" and is_dataclass_generated and "**" in [a.name for a in args]:
# The dataclass plugin generates invalid nameless "*" and "**" arguments
new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "")
args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique
args[args.index("**")] = f"**{new_name}__" # same here
new_name = "".join(a.name.strip("*") for a in args)
for arg in args:
if arg.name == "*":
arg.name = f"*{new_name}_" # this name is guaranteed to be unique
elif arg.name == "**":
arg.name = f"**{new_name}__" # same here
return args

def _get_func_return(self, o: FuncDef, ctx: FunctionContext) -> str | None:
Expand Down
7 changes: 5 additions & 2 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,10 @@ def import_lines(self) -> list[str]:
# be imported from it. the names can also be alias in the form 'original as alias'
module_map: Mapping[str, list[str]] = defaultdict(list)

for name in sorted(self.required_names):
for name in sorted(
self.required_names,
key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""),
):
# If we haven't seen this name in an import statement, ignore it
if name not in self.module_for:
continue
Expand All @@ -477,7 +480,7 @@ def import_lines(self) -> list[str]:
assert "." not in name # Because reexports only has nonqualified names
result.append(f"import {name} as {name}\n")
else:
result.append(f"import {self.direct_imports[name]}\n")
result.append(f"import {name}\n")

# Now generate all the from ... import ... lines collected in module_map
for module, names in sorted(module_map.items()):
Expand Down

0 comments on commit a714309

Please sign in to comment.