Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove macOS stderr suppression #76

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
11 changes: 9 additions & 2 deletions .github/workflows/test-torchfix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ jobs:
test-torchfix:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
os: [ubuntu-latest, macos-latest, macos-latest-large, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Show CPU architecture
if: matrix.os == 'macos-latest-large' || matrix.os == 'macos-latest'
run: |
uname -m
- uses: actions/setup-python@v5
with:
python-version: '3.10'
Expand All @@ -23,9 +27,12 @@ jobs:
- name: Install TorchFix
run: |
pip3 install ".[dev]"
- name: Run torchfix CLI
run: |
torchfix --help
- name: Run pytest
run: |
pytest tests
pytest -vv tests
- name: Run flake8
run: |
flake8
Expand Down
36 changes: 36 additions & 0 deletions tests/test_torchfix.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import subprocess
from pathlib import Path
from torchfix.torchfix import (
TorchChecker,
Expand Down Expand Up @@ -92,3 +94,37 @@ def test_errorcodes_distinct():

def test_parse_error_code_str(case, expected):
assert process_error_code_str(case) == expected


def test_stderr_suppression(tmp_path):
data = f"import torchvision.datasets as datasets{os.linesep}"
data_path = tmp_path / "fixable.py"
data_path.write_text(data)
result = subprocess.run(
["torchfix", "--select", "TOR203", "--fix", str(data_path)],
stderr=subprocess.PIPE,
text=True,
check=False,
)
assert (
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should also check the stdout here for the macOS issue.

result.stderr == "Finished checking 1 files.\n"
"Transformed 1 files successfully.\n"
)

data = f"import torchvision.datasets as datasets{os.linesep}"
data_path = tmp_path / "fixable.py"
data_path.write_text(data)
result = subprocess.run(
["torchfix", "--select", "TOR203", "--show-stderr", "--fix", str(data_path)],
stderr=subprocess.PIPE,
text=True,
check=False,
)
expected = result.stderr.replace("\\\\", "\\")
assert (
expected == f"Executing codemod...\n"
f"Failed to determine module name for {data_path}: '{data_path}' is not in the "
f"subpath of '' OR one path is relative and the other is absolute.\n"
f"Finished checking 1 files.\n"
f"Transformed 1 files successfully.\n"
)
31 changes: 6 additions & 25 deletions torchfix/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import libcst.codemod as codemod

import contextlib
import ctypes
import sys
import io

Expand All @@ -17,29 +16,6 @@
from .common import CYAN, ENDC


# Should get rid of this code eventually.
@contextlib.contextmanager
def StderrSilencer(redirect: bool = True):
if not redirect:
yield
elif sys.platform != "darwin":
with contextlib.redirect_stderr(io.StringIO()):
yield
else:
# redirect_stderr does not work for some reason
# Workaround it by using good old dup2 to redirect
# stderr to /dev/null
libc = ctypes.CDLL("libc.dylib")
orig_stderr = libc.dup(2)
with open("/dev/null", "w") as f:
libc.dup2(f.fileno(), 2)
try:
yield
finally:
libc.dup2(orig_stderr, 2)
libc.close(orig_stderr)


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -102,7 +78,12 @@ def main() -> None:
command_instance = TorchCodemod(codemod.CodemodContext(), config)
DIFF_CONTEXT = 5
try:
with StderrSilencer(not args.show_stderr):
supress_stderr = (
contextlib.redirect_stderr(io.StringIO())
if not args.show_stderr
else contextlib.nullcontext()
)
with supress_stderr:
result = codemod.parallel_exec_transform_with_prettyprint(
command_instance,
torch_files,
Expand Down