@@ -254,19 +254,16 @@ def _embedding_input_wrangler(
254
254
args : list [Any ], kwargs : dict [str , Any ]
255
255
) -> tuple [list [Any ], dict [str , Any ]]:
256
256
"""Remove arguments not present in the aten op signature."""
257
- if "max_norm" in kwargs :
258
- del kwargs ["max_norm" ]
259
- if "norm_type" in kwargs :
260
- del kwargs ["norm_type" ]
257
+ kwargs .pop ("max_norm" , None )
258
+ kwargs .pop ("norm_type" , None )
261
259
return args , kwargs
262
260
263
261
264
262
def _empty_input_wrangler (
265
263
args : list [Any ], kwargs : dict [str , Any ]
266
264
) -> tuple [list [Any ], dict [str , Any ]]:
267
265
"""Remove arguments not present in the aten op signature."""
268
- if "requires_grad" in kwargs :
269
- del kwargs ["requires_grad" ]
266
+ kwargs .pop ("requires_grad" , None )
270
267
return args , kwargs
271
268
272
269
@@ -325,8 +322,7 @@ def _max_pool_input_wrangler(
325
322
args : list [Any ], kwargs : dict [str , Any ]
326
323
) -> tuple [list [Any ], dict [str , Any ]]:
327
324
# Remove return_indices argument because this op doesn't accept it
328
- if "return_indices" in kwargs :
329
- del kwargs ["return_indices" ]
325
+ kwargs .pop ("return_indices" , None )
330
326
return args , kwargs
331
327
332
328
@@ -364,8 +360,7 @@ def _nll_loss_input_wrangler(
364
360
def _nonzero_input_wrangler (
365
361
args : list [Any ], kwargs : dict [str , Any ]
366
362
) -> tuple [list [Any ], dict [str , Any ]]:
367
- if "as_tuple" in kwargs :
368
- del kwargs ["as_tuple" ]
363
+ kwargs .pop ("as_tuple" , None )
369
364
return args , kwargs
370
365
371
366
@@ -421,8 +416,7 @@ def _roll_input_wrangler(
421
416
def _scalar_tensor_input_wrangler (
422
417
args : list [Any ], kwargs : dict [str , Any ]
423
418
) -> tuple [list [Any ], dict [str , Any ]]:
424
- if "requires_grad" in kwargs :
425
- del kwargs ["requires_grad" ]
419
+ kwargs .pop ("requires_grad" , None )
426
420
return args , kwargs
427
421
428
422
0 commit comments