Skip to content

Commit 6c2a8b6

Browse files
Skylion007pytorchmergebot
authored andcommitted
[Ez][BE]: Enable new stable ruff rules (#129825)
Applies a bunch of new ruff lint rules that are now stable. Some of these improve efficiency or readability. Since I already did passes on the codebase for these when they were in preview, there should be relatively few changes to the codebase. This is just more for future hardening of it. Pull Request resolved: pytorch/pytorch#129825 Approved by: https://github.com/XuehaiPan, https://github.com/jansel, https://github.com/malfet
1 parent 2926655 commit 6c2a8b6

File tree

16 files changed

+45
-38
lines changed

16 files changed

+45
-38
lines changed

benchmarks/distributed/ddp/benchmark.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -226,17 +226,17 @@ def main():
226226
print("-----------------------------------")
227227
print("PyTorch distributed benchmark suite")
228228
print("-----------------------------------")
229-
print("")
229+
print()
230230
print(f"* PyTorch version: {torch.__version__}")
231231
print(f"* CUDA version: {torch.version.cuda}")
232232
print(f"* Distributed backend: {args.distributed_backend}")
233233
print(f"* Maximum bucket size: {args.bucket_size}MB")
234-
print("")
234+
print()
235235
print("--- nvidia-smi topo -m ---")
236-
print("")
236+
print()
237237
print(output[0])
238238
print("--------------------------")
239-
print("")
239+
print()
240240

241241
torch.cuda.set_device(dist.get_rank() % 8)
242242
device = torch.device("cuda:%d" % (dist.get_rank() % 8))

benchmarks/distributed/ddp/diff.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def main():
3232
va = str(ja.get(key, "-"))
3333
vb = str(jb.get(key, "-"))
3434
print(f"{key + ':':20s} {va:>20s} vs {vb:>20s}")
35-
print("")
35+
print()
3636

3737
ba = ja["benchmark_results"]
3838
bb = jb["benchmark_results"]
@@ -48,13 +48,13 @@ def main():
4848
print(f"Benchmark: {name}")
4949

5050
# Print header
51-
print("")
51+
print()
5252
print(f"{'':>10s}", end="") # noqa: E999
5353
for _ in [75, 95]:
5454
print(
5555
f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end=""
5656
) # noqa: E999
57-
print("")
57+
print()
5858

5959
# Print measurements
6060
for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
@@ -78,8 +78,8 @@ def main():
7878
f" p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%",
7979
end="",
8080
) # noqa: E999
81-
print("")
82-
print("")
81+
print()
82+
print()
8383

8484

8585
if __name__ == "__main__":

benchmarks/dynamo/timm_models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ def iter_model_names(self, args):
318318
if index < start or index >= end:
319319
continue
320320
if (
321-
not re.search("|".join(args.filter), model_name, re.I)
322-
or re.search("|".join(args.exclude), model_name, re.I)
321+
not re.search("|".join(args.filter), model_name, re.IGNORECASE)
322+
or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
323323
or model_name in args.exclude_exact
324324
or model_name in self.skip_models
325325
):

benchmarks/dynamo/torchbench.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -399,8 +399,8 @@ def iter_model_names(self, args):
399399

400400
model_name = os.path.basename(model_path)
401401
if (
402-
not re.search("|".join(args.filter), model_name, re.I)
403-
or re.search("|".join(args.exclude), model_name, re.I)
402+
not re.search("|".join(args.filter), model_name, re.IGNORECASE)
403+
or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
404404
or model_name in args.exclude_exact
405405
or model_name in self.skip_models
406406
):

benchmarks/fastrnns/test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_rnns(
7979

8080
if verbose:
8181
print(experim.forward.graph_for(*experim.inputs))
82-
print("")
82+
print()
8383

8484

8585
def test_vl_py(**test_args):
@@ -141,7 +141,7 @@ def test_vl_py(**test_args):
141141

142142
if test_args["verbose"]:
143143
print(experim.forward.graph_for(*experim.inputs))
144-
print("")
144+
print()
145145

146146

147147
if __name__ == "__main__":

pyproject.toml

+6
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ select = [
9797
"SIM1",
9898
"W",
9999
# Not included in flake8
100+
"FURB",
100101
"LOG",
101102
"NPY",
102103
"PERF",
@@ -113,10 +114,13 @@ select = [
113114
"PLR0133", # constant comparison
114115
"PLR0206", # property with params
115116
"PLR1722", # use sys exit
117+
"PLR1736", # unnecessary list index
116118
"PLW0129", # assert on string literal
119+
"PLW0133", # useless exception statement
117120
"PLW0406", # import self
118121
"PLW0711", # binary op exception
119122
"PLW1509", # preexec_fn not safe with threads
123+
"PLW2101", # useless lock statement
120124
"PLW3301", # nested min max
121125
"PT006", # TODO: enable more PT rules
122126
"PT022",
@@ -133,6 +137,8 @@ select = [
133137
"RUF016", # type error non-integer index
134138
"RUF017",
135139
"RUF018", # no assignment in assert
140+
"RUF024", # from keys mutable
141+
"RUF026", # default factory kwarg
136142
"TCH",
137143
"TRY002", # ban vanilla raise (todo fix NOQAs)
138144
"TRY302",

test/error_messages/storage.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def check_error(desc, fn, *required_substrings):
1010
print(desc)
1111
print("-" * 80)
1212
print(error_message)
13-
print("")
13+
print()
1414
for sub in required_substrings:
1515
assert sub in error_message
1616
return

test/onnx/model_defs/srresnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def forward(self, x):
5151
class SRResNet(nn.Module):
5252
def __init__(self, rescale_factor, n_filters, n_blocks):
5353
super().__init__()
54-
self.rescale_levels = int(math.log(rescale_factor, 2))
54+
self.rescale_levels = int(math.log(rescale_factor, 2)) # noqa: FURB163
5555
self.n_filters = n_filters
5656
self.n_blocks = n_blocks
5757

test/quantization/core/experimental/quantization_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def evaluate(model, criterion, data_loader):
8080
acc1, acc5 = accuracy(output, target, topk=(1, 5))
8181
top1.update(acc1[0], image.size(0))
8282
top5.update(acc5[0], image.size(0))
83-
print('')
83+
print()
8484

8585
return top1, top5
8686

test/quantization/core/test_quantized_op.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -1903,8 +1903,8 @@ def test_adaptive_avg_pool2d_nhwc(self):
19031903
X = np.array(X)
19041904
scale = 1
19051905
H, W = X.shape[-2:]
1906-
output_size_h = output_size_h if (output_size_h <= H) else H
1907-
output_size_w = output_size_w if (output_size_w <= W) else W
1906+
output_size_h = min(output_size_h, H)
1907+
output_size_w = min(output_size_w, W)
19081908
if output_size_h == output_size_w:
19091909
output_size = output_size_h
19101910
else:
@@ -1977,9 +1977,9 @@ def test_adaptive_avg_pool(self):
19771977
dim_to_check.append(3)
19781978

19791979
D, H, W = X.shape[-3:]
1980-
output_size_d = output_size_d if (output_size_d <= D) else D
1981-
output_size_h = output_size_h if (output_size_h <= H) else H
1982-
output_size_w = output_size_w if (output_size_w <= W) else W
1980+
output_size_d = min(output_size_d, D)
1981+
output_size_h = min(output_size_h, H)
1982+
output_size_w = min(output_size_w, W)
19831983

19841984
X = torch.from_numpy(X)
19851985
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
@@ -2049,9 +2049,9 @@ def test_adaptive_avg_pool3d_ndhwc(self):
20492049
X = np.array(X)
20502050
scale = 1
20512051
D, H, W = X.shape[-3:]
2052-
output_size_d = output_size_d if (output_size_d <= D) else D
2053-
output_size_h = output_size_h if (output_size_h <= H) else H
2054-
output_size_w = output_size_w if (output_size_w <= W) else W
2052+
output_size_d = min(output_size_d, D)
2053+
output_size_h = min(output_size_h, H)
2054+
output_size_w = min(output_size_w, W)
20552055
if output_size_d == output_size_h == output_size_w:
20562056
output_size = output_size_h
20572057
else:

test/test_autograd.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4661,7 +4661,7 @@ def test_profiler_function_event_avg(self):
46614661
self.assertEqual(avg.device_time_total, 0)
46624662

46634663
def test_profiler_shapes(self):
4664-
print("")
4664+
print()
46654665
layer1 = torch.nn.Linear(20, 30)
46664666
layer2 = torch.nn.Linear(30, 40)
46674667
input = torch.randn(128, 20)
@@ -4683,7 +4683,7 @@ def test_profiler_shapes(self):
46834683
self.assertEqual(len(found_indices), len(linear_expected_shapes))
46844684

46854685
def test_profiler_aggregation_lstm(self):
4686-
print("")
4686+
print()
46874687
rnn = torch.nn.LSTM(10, 20, 2)
46884688
total_time_s = 0
46894689
with profile(record_shapes=True, use_kineto=kineto_available()) as prof:

torch/_functorch/partitioners.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1632,7 +1632,7 @@ def get_saved_values_knapsack(memory_budget):
16321632
for i, txt in enumerate(x_values):
16331633
plt.annotate(
16341634
f"{txt:.2f}",
1635-
(x_values[i], y_values[i]),
1635+
(txt, y_values[i]),
16361636
textcoords="offset points",
16371637
xytext=(0, 10),
16381638
ha="center",

torch/_inductor/codegen/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,8 @@ def all_in_parens(string):
430430

431431
if (
432432
isinstance(string, CSEVariable)
433-
or re.match(r"^[a-z0-9_.]+$", string, re.I)
434-
or re.match(r"^\([^)]*\)$", string, re.I)
433+
or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
434+
or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
435435
or string == ""
436436
):
437437
return string

torch/_numpy/testing/utils.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1966,11 +1966,11 @@ def _filter(self, category=Warning, message="", module=None, record=False):
19661966
self._clear_registries()
19671967

19681968
self._tmp_suppressions.append(
1969-
(category, message, re.compile(message, re.I), module, record)
1969+
(category, message, re.compile(message, re.IGNORECASE), module, record)
19701970
)
19711971
else:
19721972
self._suppressions.append(
1973-
(category, message, re.compile(message, re.I), module, record)
1973+
(category, message, re.compile(message, re.IGNORECASE), module, record)
19741974
)
19751975

19761976
return record
@@ -2318,7 +2318,8 @@ def _parse_size(size_str):
23182318
}
23192319

23202320
size_re = re.compile(
2321-
r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), re.I
2321+
r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())),
2322+
re.IGNORECASE,
23222323
)
23232324

23242325
m = size_re.match(size_str.lower())

torch/distributed/benchmarks/benchmark_ddp_rpc.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,12 @@ def run_worker(rank, world_size):
297297
print("-------------------------------------------")
298298
print(" Info ")
299299
print("-------------------------------------------")
300-
print("")
300+
print()
301301
print(f"* PyTorch version: {torch.__version__}")
302302
print(f"* CUDA version: {torch.version.cuda}")
303-
print("")
303+
print()
304304
print("------------ nvidia-smi topo -m -----------")
305-
print("")
305+
print()
306306
print(output[0])
307307
print("-------------------------------------------")
308308
print("PyTorch Distributed Benchmark (DDP and RPC)")

torch/testing/_internal/common_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5018,5 +5018,5 @@ def repl_frame(m):
50185018
s = re.sub(r"\n*You can suppress this exception.+", "", s, flags=re.DOTALL)
50195019
if suppress_prefix:
50205020
s = re.sub(r"Cannot export model.+\n\n", "", s)
5021-
s = re.sub(r" +$", "", s, flags=re.M)
5021+
s = re.sub(r" +$", "", s, flags=re.MULTILINE)
50225022
return s

0 commit comments

Comments
 (0)