diff --git a/src/shed/_codemods.py b/src/shed/_codemods.py index 6e0219e..09cd198 100644 --- a/src/shed/_codemods.py +++ b/src/shed/_codemods.py @@ -128,12 +128,7 @@ def convert_optional_literal_to_literal_none(self, _, updated_node): expr = updated_node.slice[0].slice.value args = list(expr.slice) args[-1] = args[-1].with_changes(comma=cst.Comma()) - args.append( - cst.SubscriptElement( - slice=cst.Index(value=cst.Name(value="None")), - comma=cst.MaybeSentinel.DEFAULT, - ) - ) + args.append(cst.SubscriptElement(slice=cst.Index(value=cst.Name(value="None")))) expr = expr.with_changes(slice=tuple(args)) return expr @@ -305,9 +300,6 @@ def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node): def reorder_union_literal_contents_none_last(self, _, updated_node): subscript_slice = list(updated_node.slice) subscript_slice.sort(key=lambda elt: elt.slice.value.value == "None") - subscript_slice[-1] = subscript_slice[-1].with_changes( - comma=cst.MaybeSentinel.DEFAULT - ) return updated_node.with_changes(slice=subscript_slice) @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation())) @@ -332,3 +324,48 @@ def _has_none(node): return updated_node.with_changes(left=updated_node.right, right=node_left) else: return updated_node + + @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation())) + @m.leave( + m.BinaryOperation( + left=m.Subscript(value=m.Name(value="Literal")), + operator=m.BitOr(), + right=m.Name("None"), + ) + ) + def flatten_literal_op(self, _, updated_node): + literal = updated_node.left + args = list(literal.slice) + for item in args: + if m.matches(item, m.SubscriptElement(m.Index(m.Name("None")))): + return literal # Already has "None" + args.append( + cst.SubscriptElement( + slice=cst.Index(value=cst.Name(value="None")), + ) + ) + return literal.with_changes(slice=tuple(args)) + + @m.leave(m.Subscript(value=m.Name(value="Union") | m.Name(value="Literal"))) + def flatten_union_literal_subscript(self, _, updated_node): + new_slice = [] + has_none = False + for item in updated_node.slice: + if m.matches(item.slice.value, m.Subscript(m.Name("Optional"))): + new_slice += item.slice.value.slice # peel off "Optional" + has_none = True + elif m.matches( + item.slice.value, m.Subscript(m.Name("Union") | m.Name("Literal")) + ) and m.matches(updated_node.value, item.slice.value.value): + new_slice += item.slice.value.slice # peel off "Union" or "Literal" + elif m.matches(item.slice.value, m.Name("None")): + has_none = True + else: + new_slice.append(item) + if has_none: + new_slice.append( + cst.SubscriptElement( + slice=cst.Index(value=cst.Name(value="None")), + ) + ) + return updated_node.with_changes(slice=new_slice) diff --git a/tests/recorded/flatten_literal.txt b/tests/recorded/flatten_literal.txt new file mode 100644 index 0000000..0a41277 --- /dev/null +++ b/tests/recorded/flatten_literal.txt @@ -0,0 +1,15 @@ +Literal[1, 2] | None # this should not change +var: Literal[1, 2] | None +var2: Literal[1, Literal[2, 3]] +var3: Literal[None, 1, 2] | None +Literal[1, 2, Union[bool, str]] # this should not change +Literal[1, 2, Union[bool, str], Optional[int]] + +================================================================================ + +Literal[1, 2] | None # this should not change +var: Literal[1, 2, None] +var2: Literal[1, 2, 3] +var3: Literal[1, 2, None] +Literal[1, 2, Union[bool, str]] # this should not change +Literal[1, 2, Union[bool, str], int, None] diff --git a/tests/recorded/flatten_union.txt b/tests/recorded/flatten_union.txt new file mode 100644 index 0000000..bd21a0e --- /dev/null +++ b/tests/recorded/flatten_union.txt @@ -0,0 +1,13 @@ +Union[int, Optional[str], bool] +Union[None, int, Optional[str]] +Union[Union[int, float], str] +Union[int, Literal[1, 2]] # this should not change +Union[int, Literal[1, 2], Optional[str]] + +================================================================================ + +Union[int, str, bool, None] +Union[int, str, None] +Union[int, float, str] +Union[int, Literal[1, 2]] # this should not change +Union[int, Literal[1, 2], str, None]