Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Dec 27, 2024
1 parent 095bad4 commit 1f06050
Show file tree
Hide file tree
Showing 13 changed files with 37 additions and 42 deletions.
4 changes: 2 additions & 2 deletions onnxscript/_thirdparty/asciichartpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def plot(series, *, bin_edges=None, cfg=None):
height = cfg.get("height", interval)
ratio = height / interval if interval > 0 else 1

min2 = int(floor(minimum * ratio))
max2 = int(ceil(maximum * ratio))
min2 = floor(minimum * ratio)
max2 = ceil(maximum * ratio)

def clamp(n):
return min(max(n, minimum), maximum)
Expand Down
16 changes: 8 additions & 8 deletions onnxscript/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,14 +1239,14 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
if i != len(loop_stmt.body) - 1:
self.fail(s, "Instruction break must be the last one of the loop.")

_current_scope = self._current_scope()
if s.test.id not in _current_scope:
current_scope = self._current_scope()
if s.test.id not in current_scope:
self.fail(
loop_stmt,
f"Unable to find condition variable {s.test.id!r} in known "
f"variables {list(_current_scope)!r}.",
f"variables {list(current_scope)!r}.",
)
condition_name = _current_scope[s.test.id].value
condition_name = current_scope[s.test.id].value
operator_name = "Not"
continue
self._translate_stmt(s)
Expand All @@ -1255,14 +1255,14 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:

if cond_while is not None:
# Loop while
_current_scope = self._current_scope()
if cond_while not in _current_scope:
current_scope = self._current_scope()
if cond_while not in current_scope:
self.fail(
loop_stmt,
f"Unable to find condition variable {cond_while!r} in known "
f"variables {list(_current_scope)!r}.",
f"variables {list(current_scope)!r}.",
)
o_cond_var = _current_scope[cond_while].value
o_cond_var = current_scope[cond_while].value

self.emit(
[o_cond_out],
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,16 @@ def eval_function(
has_array = False
for arg, param_schema in tagged_args:
if param_schema.is_input:
adapted_arg, _has_array = _adapt_to_eager_mode(arg)
has_array = has_array or _has_array
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
has_array = has_array or has_array_
adapted_args.append(adapted_arg)
else:
adapted_args.append(arg)

for key, (arg, param_schema) in tagged_kwargs.items():
if param_schema.is_input:
adapted_arg, _has_array = _adapt_to_eager_mode(arg)
has_array = has_array or _has_array
adapted_arg, has_array_ = _adapt_to_eager_mode(arg)
has_array = has_array or has_array_

Check warning on line 302 in onnxscript/evaluator.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/evaluator.py#L301-L302

Added lines #L301 - L302 were not covered by tests
adapted_kwargs[key] = adapted_arg
else:
adapted_kwargs[key] = arg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,15 +210,15 @@ def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConst
)

# Rename type constraints to T0, T1, T2, ...
_seen_type_constraints: Set[TypeConstraint] = set()
seen_type_constraints: Set[TypeConstraint] = set()
for type_constraint in (
*input_type_constraints.values(),
*output_type_constraints.values(),
*intermediate_type_constraints.values(),
):
if type_constraint is not None and type_constraint not in _seen_type_constraints:
type_constraint.name = f"T{len(_seen_type_constraints)}"
_seen_type_constraints.add(type_constraint)
if type_constraint is not None and type_constraint not in seen_type_constraints:
type_constraint.name = f"T{len(seen_type_constraints)}"
seen_type_constraints.add(type_constraint)

return OnnxFunctionTypeConstraints(
input_type_constraints, output_type_constraints, intermediate_type_constraints
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def main(args: argparse.Namespace) -> None:
functions[module_name] = {}
op_name = get_op_name(func)
if op_name in functions[module_name]:
logging.warning(
logging.warning( # noqa: LOG015

Check warning on line 286 in onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py#L286

Added line #L286 was not covered by tests
"Duplicated function: %s, overload: %s", op_name, func.func.name.overload_name
)
continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _get_func_schema_in_namespace(namespaces: List[_OpNamespace]) -> Dict[str, F
# to "resize(Tensor a, SymInt[] shape) -> Tensor"
if "!" in op_overload_packet.schema:
op_overload_packet.schema = re.sub( # type: ignore[attr-defined]
"[(][A-Za-z]![)]", "", op_overload_packet.schema
r"[(][A-Za-z]![)]", "", op_overload_packet.schema
)

# FIXME: remove below code if the issue below is fixed.
Expand All @@ -283,7 +283,7 @@ def main(args: argparse.Namespace) -> None:
if module_name not in functions:
functions[module_name] = {}
if op_name in functions[module_name]:
logging.warning(
logging.warning( # noqa: LOG015

Check warning on line 286 in onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py#L286

Added line #L286 was not covered by tests
"Duplicated function: %s, overload: %s",
op_name,
func_schema.name.overload_name,
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,7 @@ def format_name(value_name: str) -> str:

for input in function.inputs:
if not input.name:
logging.warning(
logger.warning(

Check warning on line 1074 in onnxscript/ir/serde.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/serde.py#L1074

Added line #L1074 was not covered by tests
"Function '%s': Value name not set for function input: %s",
function_qualified_name,
input,
Expand All @@ -1084,7 +1084,7 @@ def format_name(value_name: str) -> str:
for node in function:
for node_output in node.outputs:
if not node_output.name:
logging.warning(
logger.warning(
"Function '%s': Value name not set for node output: %s",
function_qualified_name,
node_output,
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def main_graph(
val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3])
slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3])
unsqueeze_6 = opset18.Unsqueeze(input2, 1)
_to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
to_copy_1 = opset18.Cast(unsqueeze_6, to=1)

Check warning on line 74 in onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py#L74

Added line #L74 was not covered by tests
view_1 = opset18.Constant(
value=make_tensor(
"value",
Expand Down Expand Up @@ -113,7 +113,7 @@ def main_graph(
],
)
)
view_2 = opset18.Reshape(_to_copy_1, [1, 1, 10], allowzero=0)
view_2 = opset18.Reshape(to_copy_1, [1, 1, 10], allowzero=0)

Check warning on line 116 in onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py#L116

Added line #L116 was not covered by tests
bmm = view_1 @ view_2
view_3 = opset18.Reshape(bmm, [1, 32, 10], allowzero=0)
transpose = opset18.Transpose(view_3, perm=[0, 2, 1])
Expand Down Expand Up @@ -199,8 +199,8 @@ def main_graph(
mul_13 = model_norm_weight * mul_12
t_7 = opset18.Transpose(lm_head_weight, perm=[1, 0])
view_23 = mul_13 @ t_7
_to_copy_12 = opset18.Identity(view_23)
return _to_copy_12, add_3, transpose_3
to_copy_12 = opset18.Identity(view_23)
return to_copy_12, add_3, transpose_3

Check warning on line 203 in onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/_smollm_1layer.py#L202-L203

Added lines #L202 - L203 were not covered by tests

model = main_graph.to_model_proto()
return model
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def clone(self, node_map: dict[NodePattern, NodePattern]) -> ValuePattern:
def name(self) -> str | None:
return self._name

def producer(self) -> None | NodePattern:
def producer(self) -> NodePattern | None:
return None

def uses(self) -> Sequence[tuple[NodePattern, int]]:
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/tools/benchmark/benchmark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _cmd_line(script_name: str, **kwargs: dict[str, Any]) -> list[str]:


def _extract_metrics(text: str) -> dict[str, str]:
reg = re.compile(":(.*?),(.*.?);")
reg = re.compile(r":(.*?),(.*.?);")

Check warning on line 111 in onnxscript/tools/benchmark/benchmark_helpers.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/benchmark/benchmark_helpers.py#L111

Added line #L111 was not covered by tests
res = reg.findall(text)
if len(res) == 0:
return {}
Expand Down
2 changes: 1 addition & 1 deletion onnxscript/tools/benchmark/benchmark_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _cmd_line(script_name: str, **kwargs: dict[str, str | int | float]) -> list[


def _extract_metrics(text: str) -> dict[str, str]:
reg = re.compile(":(.*?),(.*.?);")
reg = re.compile(r":(.*?),(.*.?);")

Check warning on line 48 in onnxscript/tools/benchmark/benchmark_run.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/tools/benchmark/benchmark_run.py#L48

Added line #L48 was not covered by tests
res = reg.findall(text)
if len(res) == 0:
return {}
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ ignore = [
"PYI041", # int | float is more clear
"RUF022", # We don't need to sort __all__ for elements to be grouped
"RUF031", # Parentheses for tuple in subscripts is more readable
"RUF052", # Variables with `_` prefix may not be dummy variables in all cases
"SIM102", # Collapible if statements are not always more readable
"SIM108", # We don't always encourage ternary operators
"SIM114", # Don't always combine if branches for debugability
Expand Down
18 changes: 6 additions & 12 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,19 +254,16 @@ def _embedding_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Remove arguments not present in the aten op signature."""
if "max_norm" in kwargs:
del kwargs["max_norm"]
if "norm_type" in kwargs:
del kwargs["norm_type"]
kwargs.pop("max_norm", None)
kwargs.pop("norm_type", None)
return args, kwargs


def _empty_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
"""Remove arguments not present in the aten op signature."""
if "requires_grad" in kwargs:
del kwargs["requires_grad"]
kwargs.pop("requires_grad", None)
return args, kwargs


Expand Down Expand Up @@ -325,8 +322,7 @@ def _max_pool_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Remove return_indices argument because this op doesn't accept it
if "return_indices" in kwargs:
del kwargs["return_indices"]
kwargs.pop("return_indices", None)
return args, kwargs


Expand Down Expand Up @@ -364,8 +360,7 @@ def _nll_loss_input_wrangler(
def _nonzero_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "as_tuple" in kwargs:
del kwargs["as_tuple"]
kwargs.pop("as_tuple", None)
return args, kwargs


Expand Down Expand Up @@ -421,8 +416,7 @@ def _roll_input_wrangler(
def _scalar_tensor_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
if "requires_grad" in kwargs:
del kwargs["requires_grad"]
kwargs.pop("requires_grad", None)
return args, kwargs


Expand Down

0 comments on commit 1f06050

Please sign in to comment.