alpha-CROWN in multiprocessing #58

Closed · cherrywoods opened this issue Dec 1, 2023 · 6 comments

Comments

@cherrywoods

I am trying to compute bounds on multiple models in parallel using the multiprocessing library. This works fine when using IBP or CROWN, but with alpha-CROWN the worker process fails with uninformative (fatal) errors.

Reproduce

The following Python script runs to completion but never prints the bounds, indicating that the subprocess computing the bounds crashed silently. When I replace "alpha-CROWN" with "IBP" or "CROWN" in the compute_bounds call, the script runs fine and prints the bounds to the console.

import multiprocessing as mp

import torch
from torch import nn
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm


def compute_bounds_worker(network, lb, ub):
    # The second argument of BoundedModule is a sample input used for tracing the network.
    bounded_network = BoundedModule(
        network,
        lb,
    )
    perturbation = PerturbationLpNorm(x_L=lb, x_U=ub)
    midpoint = (ub + lb) / 2
    input_bounded = BoundedTensor(midpoint, ptb=perturbation)
    print("Compute Bounds")
    lb, ub = bounded_network.compute_bounds(x=(input_bounded,), method="alpha-CROWN")  # "IBP"
    print("Computation Finished")
    print(lb)
    print(ub)


if __name__ == "__main__":
    network = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
    lb = torch.zeros(1, 10)
    ub = torch.ones(1, 10)

    worker = mp.Process(
        target=compute_bounds_worker,
        kwargs={"network": network, "lb": lb, "ub": ub},
    )
    worker.start()
    worker.join()

Output:

/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/utils/cpp_extension.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
Compute Bounds

Output with IBP:

/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/utils/cpp_extension.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
Compute Bounds
Computation Finished
tensor([[-0.1279]], grad_fn=<AddBackward0>)
tensor([[0.7965]], grad_fn=<AddBackward0>)

System configuration:

  • OS: Ubuntu 22.04.3
  • Python version: 3.10
  • PyTorch version: 1.12.1
  • Hardware: 11th Gen Intel® Core™ i7
  • Have you tried to reproduce the problem in a cleanly created conda/virtualenv environment using official installation instructions and the latest code on the main branch?: Yes
    • The error is not present when installing the CPU-only version of PyTorch 1.12.1.
    • It is present when installing PyTorch with CUDA 11.6: conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge

Additional Context

In my actual project, I get an error message on the console (included below). I could not reproduce this exact error message with the minimal script above, but I suspect the underlying issue is the same. The error message might only appear in my actual project because pytest invokes the code there, because the main process and the subprocesses communicate via an mp.SimpleQueue, or because the subprocess obtains the bounds from a generator; a rough sketch of that queue-based setup follows the trace below.

Fatal Python error: Aborted

Current thread 0x00007f019ebdd640 (most recent call first):
  <no Python frame>

Thread 0x00007f02385eb740 (most recent call first):
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/operators/clampmult.py", line 107 in backward
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/autograd/function.py", line 253 in apply
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173 in backward
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/torch/_tensor.py", line 396 in backward
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/optimized_bounds.py", line 843 in get_optimized_bounds
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 1188 in compute_bounds
  ...
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/home/dboetius/.miniconda3/envs/env/lib/python3.10/site-packages/_pytest/config/__init__.py", line 166 in main

Extension modules: mkl._mklinit, mkl._py_mkl_service, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, scipy._lib._ccallback_c, numpy.linalg.lapack_lite, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._isolve._iterative, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.linalg._flinalg, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize.__nnls, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats.beta_ufunc, scipy.stats._boost.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.binom_ufunc, scipy.stats.nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, scipy.stats.hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, scipy.stats.ncf_ufunc, scipy.stats._boost.ncf_ufunc, scipy.stats.ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.stats.nct_ufunc, scipy.stats._boost.nct_ufunc, scipy.stats.skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, scipy.stats.invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._statlib, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont, torch._C, torch._C._fft, torch._C._linalg, torch._C._nn, torch._C._sparse, torch._C._special (total: 123)
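
For concreteness, here is a minimal sketch of the queue-based setup mentioned above, not the actual project code; the worker and variable names are made up, and the worker pushes the computed bounds into an mp.SimpleQueue instead of printing them.

import multiprocessing as mp

import torch
from torch import nn
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm


def queue_worker(network, lb, ub, results):
    bounded_network = BoundedModule(network, lb)
    perturbation = PerturbationLpNorm(x_L=lb, x_U=ub)
    input_bounded = BoundedTensor((ub + lb) / 2, ptb=perturbation)
    lb_out, ub_out = bounded_network.compute_bounds(x=(input_bounded,), method="alpha-CROWN")
    # Detach before sending: non-leaf tensors that require grad cannot be pickled.
    results.put((lb_out.detach(), ub_out.detach()))


if __name__ == "__main__":
    network = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
    lb, ub = torch.zeros(1, 10), torch.ones(1, 10)

    results = mp.SimpleQueue()
    worker = mp.Process(target=queue_worker, args=(network, lb, ub, results))
    worker.start()
    bounds = results.get()  # blocks until the worker sends the bounds
    worker.join()
    print(bounds)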
@cherrywoods (Author)

Using torch.multiprocessing instead of multiprocessing does not resolve the issue.

@shizhouxing (Member)

Hi @cherrywoods, I debugged a little and found that it crashes when loss.backward() is called in alpha-CROWN. I guess it's probably an issue with the multiprocessing library itself when loss.backward() is called (IBP and CROWN don't call loss.backward()).

@cherrywoods (Author)

cherrywoods commented Dec 7, 2023

Hi, thanks for looking into this! I also investigated whether the issue is with .backward(), but training in a separate process works fine. Also, the following simple example does not crash the subprocess for me:

import multiprocessing as mp

import torch
from torch import nn


def worker(network, x):
    x.requires_grad = True
    print("Start")
    output = network(x)
    output.backward()
    print("Finished")
    print(x.grad)


if __name__ == "__main__":
    network = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
    x = torch.zeros(1, 10)

    # Use a separate name for the process so it does not shadow the worker function.
    process = mp.Process(
        target=worker,
        kwargs={"network": network, "x": x},
    )
    process.start()
    process.join()

@cherrywoods (Author)

The crash stack trace in the issue description (which I have not managed to reproduce again yet) confirms that the crash happens during loss.backward, but it also points more concretely to line 107 of auto_LiRPA/operators/clampmult.py, which contains an assertion. Unfortunately, I don't really know how to debug past the backward call, because it calls into a C++ backend which then (apparently) calls back into auto_LiRPA/operators/clampmult.py...
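
One thing that might help reproduce the dump outside of pytest (a sketch based on the repro script above, not something I have verified): pytest enables the standard faulthandler module by default, which appears to be what produces the dump above, so enabling it manually in the worker may surface the same traceback from the standalone script.

import faulthandler


def compute_bounds_worker(network, lb, ub):
    # Print a Python-level traceback (like the dump above) if the process aborts.
    faulthandler.enable()
    ...  # rest of the worker as in the repro script above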

@shizhouxing (Member)

shizhouxing commented Jan 16, 2025

It looks like the default start method for multiprocessing on Linux is fork (https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods), while PyTorch with CUDA doesn't support that: https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing.

@shizhouxing (Member)

As stated in PyTorch's documentation, the start method has to be spawn or forkserver. I tried adding mp.set_start_method('spawn') and it seems to resolve the issue.
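
A minimal sketch of how this applies to the repro script from the issue description (assuming the same compute_bounds_worker, which must stay defined at module level so the spawned child process can import it):

import multiprocessing as mp

import torch
from torch import nn

# compute_bounds_worker defined at module level, exactly as in the repro script above.

if __name__ == "__main__":
    # The default "fork" start method on Linux is incompatible with CUDA;
    # use "spawn" (or "forkserver") instead.
    mp.set_start_method("spawn")

    network = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
    lb = torch.zeros(1, 10)
    ub = torch.ones(1, 10)

    worker = mp.Process(
        target=compute_bounds_worker,
        kwargs={"network": network, "lb": lb, "ub": ub},
    )
    worker.start()
    worker.join()

Alternatively, mp.get_context("spawn").Process(...) creates the process with the spawn method without changing the global default.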
