remove unused type: ignore directives (pytorch#60006)
Summary:
During development it is common practice to put `type: ignore` comments on lines that are correct but that `mypy` fails to recognize as such. This often stems from the fact that the `mypy` version in use wasn't able to handle the pattern in question.

With every new release `mypy` gets better at handling complex code. In addition to fixing the previously accepted but now failing patterns, we should also revisit all `type: ignore` comments to see whether they are still needed. Fortunately, we don't need to do this manually: by adding `warn_unused_ignores = True` to the configuration, `mypy` will error out whenever it encounters a `type: ignore` that is no longer needed.
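For illustration, here is a minimal sketch of what the flag does; the file and its contents are hypothetical, not part of this change. With `warn_unused_ignores = True` set under the `[mypy]` section, a suppression that no longer suppresses anything becomes an error itself:

```python
# demo.py (hypothetical): int() accepts a str, so the call below
# type-checks cleanly and the suppression on it is stale.
x: str = "42"
y: int = int(x)  # type: ignore[arg-type]

# With warn_unused_ignores = True, `mypy demo.py` reports something like:
#     demo.py:4: error: Unused "type: ignore" comment
```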

Pull Request resolved: pytorch#60006

Reviewed By: jbschlosser, malfet

Differential Revision: D29133237

Pulled By: albanD

fbshipit-source-id: 41e82edc5cd5affa7ccedad044b59b94dad4425a
pmeier authored and facebook-github-bot committed Jun 18, 2021
1 parent 7c29ca7 commit d5988c5
Showing 37 changed files with 108 additions and 93 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/lint_native_functions.py
@@ -26,7 +26,7 @@ def fn(base: str) -> str:
with open(Path(__file__).parent.parent.parent / fn('.'), "r") as f:
contents = f.read()

-yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
+yaml = ruamel.yaml.YAML()
yaml.preserve_quotes = True
yaml.width = 1000
yaml.boolean_representation = ['False', 'True']
2 changes: 1 addition & 1 deletion caffe2/contrib/aten/gen_op.py
@@ -39,7 +39,7 @@
sys.path.insert(0, os.path.join(args.aten_root, '..'))
from tools.codegen.code_template import CodeTemplate as CT
else:
-from tools.codegen.code_template import CodeTemplate as CT # type: ignore[import,no-redef]
+from tools.codegen.code_template import CodeTemplate as CT

OP_TEMPLATE = CT.from_file(
os.path.join(args.template_dir, 'aten_op_template.h'))
2 changes: 1 addition & 1 deletion caffe2/distributed/file_store_handler_op_test.py
@@ -8,7 +8,7 @@
import tempfile
import shutil

-from caffe2.distributed.python import StoreHandlerTimeoutError # type: ignore[import]
+from caffe2.distributed.python import StoreHandlerTimeoutError
from caffe2.distributed.store_ops_test_util import StoreOpsTests
from caffe2.python import core, workspace, dyndep
from caffe2.python.test_util import TestCase
2 changes: 1 addition & 1 deletion caffe2/distributed/redis_store_handler_op_test.py
@@ -6,7 +6,7 @@
import os
import uuid

-from caffe2.distributed.python import StoreHandlerTimeoutError # type: ignore[import]
+from caffe2.distributed.python import StoreHandlerTimeoutError
from caffe2.distributed.store_ops_test_util import StoreOpsTests
from caffe2.python import core, workspace, dyndep
from caffe2.python.test_util import TestCase
1 change: 1 addition & 0 deletions mypy-strict.ini
@@ -19,6 +19,7 @@ disallow_any_unimported = True
# Across versions of mypy, the flags toggled by --strict vary. To ensure
# we have reproducible type check, we instead manually specify the flags
warn_unused_configs = True
+warn_unused_ignores = True
disallow_any_generics = True
disallow_subclassing_any = True
disallow_untyped_calls = True
14 changes: 14 additions & 0 deletions mypy.ini
@@ -6,6 +6,7 @@ plugins = mypy_plugins/check_mypy_version.py

cache_dir = .mypy_cache/normal
warn_unused_configs = True
+warn_unused_ignores = True
warn_redundant_casts = True
show_error_codes = True
show_column_numbers = True
@@ -95,6 +96,19 @@ ignore_errors = True
[mypy-torch.overrides]
ignore_errors = True

+#
+# Files with 'type: ignore' comments that are needed if checked with mypy-strict.ini
+#
+
+[mypy-tools.render_junit]
+warn_unused_ignores = False
+
+[mypy-tools.generate_torch_version]
+warn_unused_ignores = False
+
+[mypy-tools.stats_utils.s3_stat_parser]
+warn_unused_ignores = False
+
#
# Adding type annotations to caffe2 is probably not worth the effort
# only work on this if you have a specific reason for it, otherwise
2 changes: 1 addition & 1 deletion test/test_futures.py
@@ -30,7 +30,7 @@ def test_set_exception(self) -> None:
f = Future()
f.set_exception(value_error)
with self.assertRaisesRegex(ValueError, "Intentional"):
-f.value() # type: ignore[attr-defined]
+f.value()

def cb(fut):
fut.value()
2 changes: 1 addition & 1 deletion test/test_utils.py
@@ -743,7 +743,7 @@ def forward(self, x):
# data can be passed without errors
x = torch.randn(4, 4).fill_(1.0)
ms(x)
-with self.assertRaisesRegex(torch.jit.Error, "foo"): # type: ignore[type-var]
+with self.assertRaisesRegex(torch.jit.Error, "foo"):
ms(torch.tensor([False], dtype=torch.bool))


12 changes: 9 additions & 3 deletions tools/pyi/gen_pyi.py
@@ -126,7 +126,10 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool:
'iadd', 'iand', 'idiv', 'ilshift', 'imul',
'ior', 'irshift', 'isub', 'ixor', 'ifloordiv', 'imod', # inplace ops
)
-comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le')
+symmetric_comparison_ops = ('eq', 'ne')
+asymmetric_comparison_ops = ('ge', 'gt', 'lt', 'le')
+comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops

unary_ops = ('neg', 'abs', 'invert')
to_py_type_ops = ('bool', 'float', 'complex', 'long', 'index', 'int', 'nonzero')
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops
@@ -145,8 +148,11 @@ def sig_for_ops(opname: str) -> List[str]:
if name in binary_ops:
return ['def {}(self, other: Any) -> Tensor: ...'.format(opname)]
elif name in comparison_ops:
-# unsafe override https://github.com/python/mypy/issues/5704
-return ['def {}(self, other: Any) -> Tensor: ... # type: ignore[override]'.format(opname)]
+sig = 'def {}(self, other: Any) -> Tensor: ...'.format(opname)
+if name in symmetric_comparison_ops:
+# unsafe override https://github.com/python/mypy/issues/5704
+sig += ' # type: ignore[override]'
+return [sig]
elif name in unary_ops:
return ['def {}(self) -> Tensor: ...'.format(opname)]
elif name in to_py_type_ops:
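An aside on why the rewritten `sig_for_ops` suppresses the override error only for the symmetric comparisons (a sketch; the class below is hypothetical, not PyTorch code): `object` already declares `__eq__` and `__ne__` as returning `bool`, so a stub that widens their return type to `Tensor` is an unsafe override in mypy's eyes (python/mypy#5704), whereas `__lt__`, `__le__`, `__gt__`, and `__ge__` have no base-class signature to conflict with:

```python
# Standalone sketch of the override asymmetry behind this change.
class Tensor:
    # `object.__eq__` returns `bool`, so widening the return type to
    # `Tensor` trips mypy's [override] check unless it is suppressed:
    def __eq__(self, other: object) -> "Tensor":  # type: ignore[override]
        raise NotImplementedError

    # `__lt__` is not defined on `object`, so there is nothing to
    # override and no suppression is needed:
    def __lt__(self, other: object) -> "Tensor":
        raise NotImplementedError
```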
2 changes: 1 addition & 1 deletion tools/render_junit.py
@@ -12,7 +12,7 @@
)

try:
-import rich # type: ignore[import]
+import rich
except ImportError:
print("rich not found, for color output use 'pip install rich'")

2 changes: 1 addition & 1 deletion torch/distributed/_sharded_tensor/api.py
@@ -309,7 +309,7 @@ def _init_enumerable(

def _parse_and_validate_remote_device(self, device):

-on, local_device = _parse_remote_device(device) # type: ignore[arg-type]
+on, local_device = _parse_remote_device(device)

# Validate rank.
if isinstance(on, int) and (on < 0 or on >= dist.get_world_size(self._process_group)):
6 changes: 3 additions & 3 deletions torch/distributed/distributed_c10d.py
@@ -1591,7 +1591,7 @@ def all_gather_object(object_list, obj, group=None):
all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
-tensor = tensor.type(torch.uint8) # type:ignore[call-overload]
+tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
@@ -1695,7 +1695,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
if my_rank != dst:
return
for i, tensor in enumerate(output_tensors):
-tensor = tensor.type(torch.uint8) # type: ignore[call-overload]
+tensor = tensor.type(torch.uint8)
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)

@@ -1790,7 +1790,7 @@ def broadcast_object_list(object_list, src=0, group=None):
if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset : offset + obj_size]
-obj_view = obj_view.type(torch.uint8) # type: ignore[call-overload]
+obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
offset += obj_size
@@ -143,7 +143,7 @@ def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
# see the explanation in the except clause below.
for is_server in [is_host, False]:
try:
-store = TCPStore( # type: ignore[call-arg]
+store = TCPStore(
host, port, is_master=is_server, timeout=timedelta(seconds=read_timeout)
)

@@ -44,7 +44,7 @@ def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):

self.start_index = start_index
self.num_samples = int(
-math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas) # type: ignore[arg-type]
+math.ceil(float(len(self.dataset) - self.start_index) / self.num_replicas)
)
self.total_size = self.num_samples * self.num_replicas

@@ -53,7 +53,7 @@ def __iter__(self):
g = torch.Generator()
g.manual_seed(self.epoch)
indices = (
-torch.randperm(len(self.dataset) - self.start_index, generator=g) # type: ignore[arg-type]
+torch.randperm(len(self.dataset) - self.start_index, generator=g)
.add(self.start_index)
.tolist()
)
4 changes: 2 additions & 2 deletions torch/distributed/launcher/api.py
@@ -12,8 +12,8 @@

import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
-from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState # type: ignore[import]
-from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent # type: ignore[import]
+from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState
+from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record
from torch.distributed.elastic.rendezvous import RendezvousParameters
14 changes: 7 additions & 7 deletions torch/distributed/nn/api/remote_module.py
@@ -395,10 +395,10 @@ def named_modules(
):
_raise_not_supported(self.named_modules.__name__)

-def train(self: T, mode: bool = True) -> T: # type: ignore[return]
+def train(self: T, mode: bool = True) -> T:
return self.module_rref.rpc_sync().train() # type: ignore[operator, union-attr]

-def eval(self: T) -> T: # type: ignore[return]
+def eval(self: T) -> T:
return self.module_rref.rpc_sync().eval() # type: ignore[operator, union-attr]

def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
@@ -413,7 +413,7 @@ def share_memory(self: T) -> T: # type: ignore[return]
def extra_repr(self) -> str: # type: ignore[return]
_raise_not_supported(self.extra_repr.__name__)

-def _prepare_init(self, remote_device: str) -> bool: # type: ignore[return]
+def _prepare_init(self, remote_device: str) -> bool:
"""
Prepares the initializaiton and returns whether to enable automatically moving CPU tensors to CUDA devices.
"""
@@ -639,7 +639,7 @@ def __init__(
args: Tuple = None,
kwargs: Dict[str, Any] = None,
):
-super().__init__(remote_device, module_cls, args, kwargs) # type: ignore[arg-type]
+super().__init__(remote_device, module_cls, args, kwargs)


def _remote_module_receiver(
@@ -651,7 +651,7 @@ def _remote_module_receiver(
serialized_remote_module = _SerializedRemoteModule._make(
remote_module_pickled_attrs
)
-m = object.__new__(RemoteModule) # type: ignore[attr-defined]
+m = object.__new__(RemoteModule)
m.__dict__.update(serialized_remote_module._asdict())

# Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
@@ -675,10 +675,10 @@ def _remote_module_reducer(remote_module):
# Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method.
if k == "module_rref":
pickled_attrs[k] = v._serialize()
-elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES: # type: ignore[attr-defined]
+elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES:
pickled_attrs[k] = v
# Check if unpickled attributes are all in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
-elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING: # type: ignore[attr-defined]
+elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:
print(
"The new attribute ``{}`` of RemoteModule is ignored during RPC pickling. "
"To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
3 changes: 1 addition & 2 deletions torch/distributed/rendezvous.py
@@ -60,8 +60,7 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
result = urlparse(url)
if rank != -1 or world_size != -1:
query_dict: Dict[str, Union[int, str]] = dict(
-# mypy doesn't allow dict() to accept List of values (#257)
-pair.split("=") for pair in filter(None, result.query.split("&")) # type: ignore[arg-type, misc]
+pair.split("=") for pair in filter(None, result.query.split("&"))
)
assert (
"rank" not in query_dict and "world_size" not in query_dict
2 changes: 1 addition & 1 deletion torch/distributions/utils.py
@@ -101,7 +101,7 @@ class lazy_property(object):
"""
def __init__(self, wrapped):
self.wrapped = wrapped
-update_wrapper(self, wrapped) # type: ignore[arg-type]
+update_wrapper(self, wrapped)

def __get__(self, instance, obj_type=None):
if instance is None:
4 changes: 2 additions & 2 deletions torch/fx/experimental/normalize.py
@@ -74,7 +74,7 @@ def call_function(
args, # type: ignore[arg-type]
kwargs,
arg_types, # type: ignore[arg-type]
-kwarg_types, # type: ignore[arg-type]
+kwarg_types,
self.normalize_to_only_use_kwargs,
)
if new_args_and_kwargs:
@@ -93,7 +93,7 @@ def call_module(
self.module,
target,
args, # type: ignore[arg-type]
-kwargs, # type: ignore[arg-type]
+kwargs,
self.normalize_to_only_use_kwargs,
)
if new_args_and_kwargs:
2 changes: 1 addition & 1 deletion torch/fx/passes/net_min_base.py
@@ -256,7 +256,7 @@ def _tag_nodes(self, selected_nodes: NodeSet):
if node in selected_nodes:
node.tag = "minimize"
elif any(
-n.tag in {"minimize", "main_1"} # type: ignore[attr-defined]
+n.tag in {"minimize", "main_1"}
for n in node.all_input_nodes
if n.op in CALLABLE_NODE_OPS
):
34 changes: 17 additions & 17 deletions torch/fx/passes/split_module.py
@@ -10,8 +10,8 @@ def __init__(self, name: str):
self.outputs: Dict[str, None] = {}
self.partitions_dependent_on: Dict[str, None] = {}
self.partition_dependents: Dict[str, None] = {}
-self.graph : torch.fx.graph.Graph = torch.fx.graph.Graph() # type: ignore[attr-defined, name-defined]
-self.environment : Dict[torch.fx.node.Node, torch.fx.node.Node] = {} # type: ignore[name-defined]
+self.graph : torch.fx.graph.Graph = torch.fx.graph.Graph()
+self.environment : Dict[torch.fx.node.Node, torch.fx.node.Node] = {}
self.targets : Dict[str, Any] = {}

def __repr__(self) -> str:
@@ -26,12 +26,12 @@ def __repr__(self) -> str:
def split_module(
m: GraphModule,
root_m: torch.nn.Module,
-split_callback: Callable[[torch.fx.node.Node], int], # type: ignore[name-defined]
+split_callback: Callable[[torch.fx.node.Node], int],
):
partitions: Dict[str, Partition] = {}
-orig_nodes: Dict[str, torch.fx.node.Node] = {} # type: ignore[name-defined]
+orig_nodes: Dict[str, torch.fx.node.Node] = {}

-def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optional[torch.fx.node.Node]): # type: ignore[name-defined] # noqa: B950
+def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optional[torch.fx.node.Node]): # noqa: B950
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name:
@@ -56,7 +56,7 @@ def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optiona
if node.op in ["placeholder", "get_attr"]:
continue
if node.op == 'output':
-torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None)) # type: ignore[attr-defined]
+torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
continue
partition_name = str(split_callback(node))

Expand All @@ -68,8 +68,8 @@ def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optiona
partition.node_names.append(node.name)
node._fx_partition = partition_name

-torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node)) # type: ignore[attr-defined]
-torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # type: ignore[attr-defined] # noqa: B950
+torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
+torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)) # noqa: B950

# find partitions with no dependencies
root_partitions : List[str] = []
@@ -104,8 +104,8 @@ def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optiona

# swap out old graph nodes in kw/args with references to new nodes in this submodule
environment = partition.environment
-gathered_args = torch.fx.graph.map_arg(node.args, lambda n : environment[n]) # type: ignore[attr-defined]
-gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n : environment[n]) # type: ignore[attr-defined]
+gathered_args = torch.fx.graph.map_arg(node.args, lambda n : environment[n])
+gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n : environment[n])

if node.op not in ['call_module', 'get_attr']:
target = node.target
@@ -128,9 +128,9 @@ def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optiona
partition.environment[node] = new_node

# Set up values to construct base module
-base_mod_env : Dict[str, torch.fx.node.Node] = {} # type: ignore[name-defined]
-base_mod_graph : torch.fx.graph.Graph = torch.fx.graph.Graph() # type: ignore[attr-defined, name-defined]
-base_mod_attrs : Dict[str, torch.fx.graph_module.GraphModule] = {} # type: ignore[name-defined]
+base_mod_env : Dict[str, torch.fx.node.Node] = {}
+base_mod_graph : torch.fx.graph.Graph = torch.fx.graph.Graph()
+base_mod_attrs : Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
base_mod_env[node.name] = base_mod_graph.placeholder(node.name)
@@ -159,21 +159,21 @@ def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optiona

# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
-base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, partition.graph) # type: ignore[attr-defined] # noqa: B950
+base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, partition.graph) # noqa: B950

# Emit call in base graph to this submodule

output_val = base_mod_graph.call_module(submod_name, tuple(base_mod_env[name] for name in partition.inputs))
if len(partition.outputs) > 1:
# Unpack multiple return values from submodule
-output_val_proxy = torch.fx.proxy.Proxy(output_val) # type: ignore[attr-defined]
+output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index]
else:
base_mod_env[list(partition.outputs)[0]] = output_val

for node in m.graph.nodes:
if node.op == 'output':
-base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n : base_mod_env[n.name])) # type: ignore[attr-defined] # noqa: B950
+base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n : base_mod_env[n.name])) # noqa: B950

-return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) # type: ignore[attr-defined]
+return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
