ConvTranspose CROWN Bounds #61

Closed
cherrywoods opened this issue Dec 8, 2023 · 12 comments
@cherrywoods

Describe the bug
I was delighted to see that auto_LiRPA can bound ConvTranspose layers out of the box, but, unfortunately, CROWN in batch mode doesn't seem to work.

To Reproduce
Code to reproduce with the attached network (zipped): mnist_conv_generator.zip

>>> import torch
>>> from auto_LiRPA import PerturbationLpNorm, BoundedModule, BoundedTensor
/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/torch/utils/cpp_extension.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
No CUDA runtime is found, using CUDA_HOME='/usr'
>>> net = torch.load("mnist_conv_generator.pyt")
>>> net = BoundedModule(net, torch.zeros(1, 4, 1, 1))
/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/torch/nn/functional.py:2403: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if size_prods == 1:
Warning: ONNX Preprocess - Removing mutation from node aten::add_ on block input: '6'. This changes graph semantics.
Warning: ONNX Preprocess - Removing mutation from node aten::add_ on block input: '12'. This changes graph semantics.
>>> ptb = PerturbationLpNorm(x_L=torch.zeros(1, 4, 1, 1), x_U=torch.ones(1, 4, 1, 1))
>>> tensor = BoundedTensor(torch.zeros(10, 4, 1, 1), ptb)
>>> net.compute_bounds(x=(tensor,), method="ibp")  # works fine, output omitted
>>> net.compute_bounds(x=(tensor,), method="crown")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 1206, in compute_bounds
    return self._compute_bounds_main(C=C,
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 1303, in _compute_bounds_main
    self.check_prior_bounds(final)
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  [Previous line repeated 2 more times]
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 804, in check_prior_bounds
    self.compute_intermediate_bounds(
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/bound_general.py", line 915, in compute_intermediate_bounds
    self.restore_sparse_bounds(
  File "/home/david/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.4.0-py3.10.egg/auto_LiRPA/backward_bound.py", line 575, in restore_sparse_bounds
    lower[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] = new_lower
RuntimeError: shape mismatch: value tensor of shape [10, 703] cannot be broadcast to indexing result of shape [1, 703]

System configuration:

  • OS: Ubuntu 22.04.3 LTS
  • Python version: 3.10
  • PyTorch version: 1.12.1
  • Hardware: CPU only (also verified on CUDA: GeForce GT 1030)
  • Have you tried to reproduce the problem in a cleanly created conda/virtualenv environment using official installation instructions and the latest code on the main branch?: Yes
@cherrywoods
Author

CROWN with a single input (e.g. torch.zeros(1, 4, 1, 1) instead of torch.zeros(10, 4, 1, 1)) works fine.
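
For reference, a minimal sketch of the single-input variant that completes without the error (only the batch size differs from the reproduction above):

>>> ptb = PerturbationLpNorm(x_L=torch.zeros(1, 4, 1, 1), x_U=torch.ones(1, 4, 1, 1))
>>> tensor = BoundedTensor(torch.zeros(1, 4, 1, 1), ptb)
>>> net.compute_bounds(x=(tensor,), method="crown")  # no shape error with batch size 1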

@shizhouxing
Member

Hi @cherrywoods, could you please share the code for the model definition?

@cherrywoods
Author

Sure, sorry for not including it right away:

generator = nn.Sequential(
    nn.ConvTranspose2d(4, 49, kernel_size=4, stride=1, bias=False),  # 49 x 4 x 4
    nn.BatchNorm2d(49, affine=True),
    nn.LeakyReLU(negative_slope=0.2),
    nn.ConvTranspose2d(49, 12, kernel_size=4, stride=4, bias=False),  # 12 x 16 x 16
    nn.BatchNorm2d(12, affine=True),
    nn.LeakyReLU(negative_slope=0.2),
    nn.ConvTranspose2d(12, 1, kernel_size=13, stride=1, bias=False),  # 1 x 28 x 28
    nn.Sigmoid(),
)

@shizhouxing
Member

Hi @cherrywoods, the issue is that you need to update ptb as well to use a batch size of 10.
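
For example, something like (matching the batch size of the bounded tensor):

>>> ptb = PerturbationLpNorm(x_L=torch.zeros(10, 4, 1, 1), x_U=torch.ones(10, 4, 1, 1))
>>> tensor = BoundedTensor(torch.zeros(10, 4, 1, 1), ptb)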

@cherrywoods
Author

cherrywoods commented Dec 9, 2023

Hi @shizhouxing, the incorrect batch dimension was indeed a problem in the code I posted; however, a very similar error persists even with the batch dimensions fixed:

import torch
from auto_LiRPA import PerturbationLpNorm, BoundedModule, BoundedTensor
net = torch.load("mnist_conv_generator.pyt")
net = BoundedModule(net, torch.zeros(1, 4, 1, 1))
ptb = PerturbationLpNorm(x_L=torch.zeros(10, 4, 1, 1), x_U=torch.ones(10, 4, 1, 1))
tensor = BoundedTensor(torch.zeros(10, 4, 1, 1), ptb)
net.compute_bounds(x=(tensor,), method="ibp")  # works fine, output omitted
net.compute_bounds(x=(tensor,), method="crown")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 1339, in compute_bounds
    self.check_prior_bounds(final)
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 883, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 883, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 883, in check_prior_bounds
    self.check_prior_bounds(n)
  [Previous line repeated 2 more times]
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 885, in check_prior_bounds
    self.compute_intermediate_bounds(
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/bound_general.py", line 983, in compute_intermediate_bounds
    node.lower, node.upper = self.backward_general(
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/backward_bound.py", line 212, in backward_general
    lb, ub = concretize(
  File "/home/dboetius/.miniconda3/envs/auto_LiRPA/lib/python3.10/site-packages/auto_LiRPA-0.3.1-py3.10.egg/auto_LiRPA/backward_bound.py", line 532, in concretize
    lb = lb + root[i].perturbation.concretize(
RuntimeError: The size of tensor a (4) must match the size of tensor b (784) at non-singleton dimension 3

@shizhouxing
Member

Hi @cherrywoods,

You'll need to modify both x_L and x_U:

ptb = PerturbationLpNorm(x_L=torch.zeros(10, 4, 1, 1), x_U=torch.ones(1, 4, 1, 1))

@cherrywoods
Author

cherrywoods commented Dec 9, 2023

Hi @shizhouxing, this was only a typo. I updated the code above. The error remains the same.

@shizhouxing
Member

Hi @cherrywoods, I tried your code and it worked fine on my side.

I see your output contains auto_LiRPA-0.3.1. Are you using the latest version of auto_LiRPA? The latest version should have a version number of 0.4.
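
You can check the installed version with:

>>> import auto_LiRPA
>>> auto_LiRPA.__version__  # should report 0.4 for the latest release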

@cherrywoods
Author

That indeed seemed to be the issue. I somehow messed up pulling the latest release from GitHub. Thanks for your patience and sorry for the inconvenience. I'm happy that I can now use ConvTranspose layers :)

@cherrywoods
Author

I'm reopening this because I keep getting errors in the actual code I'm using, which of course uses bounds other than 0.0 and 1.0. I debugged this for the past hour and couldn't find anything like the errors we discussed above. To be on the safe side this time, I made a Docker container that reproduces the issue: conv_transpose_issue.zip

The container creates a conda environment, downloads and installs the latest auto_LiRPA commit, and then runs the following script:

import torch
from torch import nn
import auto_LiRPA
from auto_LiRPA import PerturbationLpNorm, BoundedModule, BoundedTensor

print(auto_LiRPA.__version__)

torch.manual_seed(0)
net = nn.Sequential(
    nn.ConvTranspose2d(4, 49, kernel_size=4, stride=1, bias=False),  # 49 x 4 x 4
    nn.BatchNorm2d(49, affine=True),
    nn.LeakyReLU(negative_slope=0.2),
    nn.ConvTranspose2d(49, 12, kernel_size=4, stride=4, bias=False),  # 12 x 16 x 16
    nn.BatchNorm2d(12, affine=True),
    nn.LeakyReLU(negative_slope=0.2),
    nn.ConvTranspose2d(12, 1, kernel_size=13, stride=1, bias=False),  # 1 x 28 x 28
    nn.Sigmoid(),
)
net = BoundedModule(net, torch.empty(1, 4, 1, 1))

lb = torch.zeros(1, 4, 1, 1)
ub = torch.ones(1, 4, 1, 1)
ptb = PerturbationLpNorm(x_L=lb, x_U=ub)
tensor = BoundedTensor(lb, ptb)
print(lb.shape, ub.shape, tensor.shape)
print(lb, ub, tensor)
bounds = net.compute_bounds(x=(tensor,), method="crown")  # works fine
print(bounds)

lb = lb.clone() - 1.0
ptb = PerturbationLpNorm(x_L=lb, x_U=ub)
tensor = BoundedTensor(lb, ptb)
print(lb.shape, ub.shape, tensor.shape)
print(lb, ub, tensor)
bounds = net.compute_bounds(x=(tensor,), method="crown")  # fails
print(bounds)

When I run this using:

docker build . -t auto_lirpa
docker run -t auto_lirpa

I get this output:

/opt/conda/envs/auto_LiRPA/lib/python3.10/site-packages/torch/utils/cpp_extension.py:25: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  from pkg_resources import packaging  # type: ignore[attr-defined]
0.4.0
/opt/conda/envs/auto_LiRPA/lib/python3.10/site-packages/torch/nn/functional.py:2403: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if size_prods == 1:
Warning: ONNX Preprocess - Removing mutation from node aten::add_ on block input: '6'. This changes graph semantics.
Warning: ONNX Preprocess - Removing mutation from node aten::add_ on block input: '12'. This changes graph semantics.
torch.Size([1, 4, 1, 1]) torch.Size([1, 4, 1, 1]) (1, 4, 1, 1)
tensor([[[[0.]],

         [[0.]],

         [[0.]],

         [[0.]]]]) tensor([[[[1.]],

         [[1.]],

         [[1.]],

         [[1.]]]]) <BoundedTensor: BoundedTensor([[[[0.]],

                [[0.]],

                [[0.]],

                [[0.]]]]), PerturbationLpNorm(norm=inf, eps=0, x_L=tensor([[[[0.]],

         [[0.]],

         [[0.]],

         [[0.]]]]), x_U=tensor([[[[1.]],

         [[1.]],

         [[1.]],

         [[1.]]]]))>
(tensor([[[[0.4989, 0.4986, 0.4980, 0.4974, 0.4972, 0.4965, 0.4962, 0.4952,
           0.4960, 0.4955, 0.4947, 0.4937, 0.4937, 0.4942, 0.4927, 0.4931,
           0.4949, 0.4950, 0.4949, 0.4951, 0.4974, 0.4967, 0.4970, 0.4978,
           0.4982, 0.4984, 0.4992, 0.4991],
          [0.4988, 0.4978, 0.4970, 0.4960, 0.4946, 0.4945, 0.4925, 0.4929,
           0.4919, 0.4920, 0.4908, 0.4898, 0.4899, 0.4897, 0.4895, 0.4895,
           0.4908, 0.4920, 0.4924, 0.4935, 0.4945, 0.4948, 0.4960, 0.4963,
           0.4970, 0.4968, 0.4985, 0.4991],
          [0.4979, 0.4969, 0.4955, 0.4937, 0.4921, 0.4922, 0.4906, 0.4904,
           0.4886, 0.4863, 0.4868, 0.4859, 0.4847, 0.4852, 0.4834, 0.4849,
           0.4853, 0.4862, 0.4894, 0.4895, 0.4915, 0.4921, 0.4932, 0.4933,
           0.4948, 0.4965, 0.4969, 0.4981],
          [0.4980, 0.4958, 0.4950, 0.4936, 0.4930, 0.4902, 0.4890, 0.4883,
           0.4864, 0.4847, 0.4837, 0.4822, 0.4816, 0.4810, 0.4812, 0.4808,
           0.4838, 0.4839, 0.4856, 0.4872, 0.4884, 0.4903, 0.4919, 0.4926,
           0.4932, 0.4949, 0.4968, 0.4986],
          [0.4972, 0.4944, 0.4926, 0.4917, 0.4899, 0.4879, 0.4849, 0.4841,
           0.4831, 0.4800, 0.4791, 0.4771, 0.4773, 0.4755, 0.4749, 0.4753,
           0.4782, 0.4796, 0.4813, 0.4834, 0.4866, 0.4876, 0.4888, 0.4905,
           0.4929, 0.4935, 0.4953, 0.4978],
          [0.4966, 0.4946, 0.4924, 0.4890, 0.4878, 0.4853, 0.4827, 0.4808,
           0.4784, 0.4778, 0.4756, 0.4726, 0.4713, 0.4724, 0.4703, 0.4710,
           0.4733, 0.4761, 0.4784, 0.4807, 0.4831, 0.4860, 0.4872, 0.4885,
           0.4904, 0.4931, 0.4945, 0.4974],
          [0.4967, 0.4941, 0.4914, 0.4891, 0.4847, 0.4837, 0.4812, 0.4793,
           0.4758, 0.4712, 0.4710, 0.4680, 0.4634, 0.4664, 0.4658, 0.4656,
           0.4682, 0.4713, 0.4742, 0.4776, 0.4787, 0.4813, 0.4848, 0.4868,
           0.4888, 0.4919, 0.4946, 0.4970],
          [0.4957, 0.4937, 0.4898, 0.4873, 0.4840, 0.4801, 0.4780, 0.4752,
           0.4733, 0.4695, 0.4679, 0.4649, 0.4632, 0.4615, 0.4618, 0.4622,
           0.4654, 0.4679, 0.4705, 0.4751, 0.4779, 0.4804, 0.4826, 0.4848,
           0.4868, 0.4911, 0.4937, 0.4970],
          [0.4960, 0.4917, 0.4870, 0.4846, 0.4805, 0.4777, 0.4753, 0.4731,
           0.4690, 0.4654, 0.4628, 0.4602, 0.4559, 0.4551, 0.4559, 0.4562,
           0.4601, 0.4631, 0.4674, 0.4710, 0.4756, 0.4768, 0.4803, 0.4826,
           0.4871, 0.4893, 0.4928, 0.4963],
          [0.4952, 0.4918, 0.4874, 0.4831, 0.4779, 0.4769, 0.4736, 0.4697,
           0.4658, 0.4622, 0.4592, 0.4563, 0.4500, 0.4524, 0.4527, 0.4506,
           0.4557, 0.4603, 0.4641, 0.4678, 0.4725, 0.4752, 0.4782, 0.4809,
           0.4847, 0.4887, 0.4923, 0.4958],
          [0.4953, 0.4901, 0.4872, 0.4829, 0.4787, 0.4749, 0.4712, 0.4677,
           0.4632, 0.4582, 0.4555, 0.4524, 0.4484, 0.4472, 0.4479, 0.4479,
           0.4521, 0.4565, 0.4608, 0.4645, 0.4686, 0.4728, 0.4759, 0.4796,
           0.4830, 0.4863, 0.4917, 0.4955],
          [0.4949, 0.4898, 0.4857, 0.4820, 0.4769, 0.4712, 0.4692, 0.4641,
           0.4607, 0.4569, 0.4527, 0.4478, 0.4450, 0.4428, 0.4435, 0.4430,
           0.4489, 0.4526, 0.4569, 0.4615, 0.4675, 0.4697, 0.4744, 0.4778,
           0.4819, 0.4843, 0.4912, 0.4953],
          [0.4943, 0.4882, 0.4841, 0.4808, 0.4739, 0.4696, 0.4659, 0.4614,
           0.4552, 0.4526, 0.4486, 0.4449, 0.4390, 0.4368, 0.4373, 0.4373,
           0.4441, 0.4469, 0.4523, 0.4572, 0.4636, 0.4666, 0.4704, 0.4750,
           0.4804, 0.4850, 0.4884, 0.4928],
          [0.4936, 0.4884, 0.4845, 0.4803, 0.4742, 0.4681, 0.4656, 0.4614,
           0.4557, 0.4509, 0.4485, 0.4396, 0.4384, 0.4378, 0.4364, 0.4367,
           0.4424, 0.4479, 0.4525, 0.4561, 0.4633, 0.4677, 0.4707, 0.4752,
           0.4794, 0.4846, 0.4898, 0.4944],
          [0.4935, 0.4893, 0.4848, 0.4804, 0.4746, 0.4701, 0.4669, 0.4623,
           0.4571, 0.4509, 0.4480, 0.4437, 0.4388, 0.4381, 0.4387, 0.4384,
           0.4429, 0.4482, 0.4523, 0.4569, 0.4616, 0.4673, 0.4707, 0.4750,
           0.4791, 0.4847, 0.4894, 0.4941],
          [0.4942, 0.4888, 0.4855, 0.4810, 0.4756, 0.4702, 0.4668, 0.4617,
           0.4572, 0.4532, 0.4489, 0.4444, 0.4398, 0.4388, 0.4389, 0.4361,
           0.4439, 0.4489, 0.4535, 0.4575, 0.4627, 0.4667, 0.4706, 0.4753,
           0.4802, 0.4845, 0.4899, 0.4947],
          [0.4953, 0.4903, 0.4865, 0.4830, 0.4786, 0.4732, 0.4701, 0.4660,
           0.4613, 0.4570, 0.4531, 0.4504, 0.4453, 0.4426, 0.4439, 0.4434,
           0.4478, 0.4525, 0.4580, 0.4615, 0.4663, 0.4693, 0.4735, 0.4774,
           0.4829, 0.4848, 0.4893, 0.4947],
          [0.4947, 0.4915, 0.4872, 0.4836, 0.4781, 0.4752, 0.4727, 0.4680,
           0.4640, 0.4603, 0.4577, 0.4527, 0.4488, 0.4470, 0.4481, 0.4476,
           0.4520, 0.4570, 0.4598, 0.4624, 0.4687, 0.4718, 0.4731, 0.4789,
           0.4813, 0.4874, 0.4905, 0.4952],
          [0.4964, 0.4926, 0.4888, 0.4856, 0.4815, 0.4786, 0.4748, 0.4711,
           0.4684, 0.4646, 0.4617, 0.4579, 0.4537, 0.4531, 0.4533, 0.4528,
           0.4566, 0.4604, 0.4632, 0.4673, 0.4712, 0.4747, 0.4776, 0.4811,
           0.4844, 0.4883, 0.4918, 0.4959],
          [0.4962, 0.4923, 0.4897, 0.4862, 0.4825, 0.4803, 0.4776, 0.4715,
           0.4707, 0.4683, 0.4652, 0.4618, 0.4579, 0.4565, 0.4573, 0.4566,
           0.4620, 0.4640, 0.4677, 0.4699, 0.4749, 0.4756, 0.4799, 0.4827,
           0.4857, 0.4893, 0.4933, 0.4962],
          [0.4968, 0.4936, 0.4910, 0.4885, 0.4847, 0.4817, 0.4796, 0.4777,
           0.4731, 0.4713, 0.4695, 0.4670, 0.4624, 0.4620, 0.4623, 0.4621,
           0.4650, 0.4669, 0.4713, 0.4743, 0.4754, 0.4798, 0.4820, 0.4828,
           0.4884, 0.4905, 0.4932, 0.4962],
          [0.4958, 0.4934, 0.4917, 0.4896, 0.4863, 0.4839, 0.4809, 0.4795,
           0.4761, 0.4750, 0.4724, 0.4700, 0.4679, 0.4656, 0.4665, 0.4649,
           0.4689, 0.4707, 0.4747, 0.4770, 0.4801, 0.4806, 0.4831, 0.4871,
           0.4893, 0.4920, 0.4941, 0.4966],
          [0.4977, 0.4953, 0.4931, 0.4912, 0.4893, 0.4868, 0.4846, 0.4821,
           0.4806, 0.4787, 0.4760, 0.4750, 0.4719, 0.4721, 0.4710, 0.4709,
           0.4736, 0.4759, 0.4771, 0.4793, 0.4815, 0.4827, 0.4859, 0.4884,
           0.4891, 0.4919, 0.4948, 0.4975],
          [0.4976, 0.4958, 0.4939, 0.4926, 0.4904, 0.4893, 0.4868, 0.4849,
           0.4830, 0.4822, 0.4804, 0.4789, 0.4762, 0.4753, 0.4757, 0.4758,
           0.4772, 0.4802, 0.4815, 0.4820, 0.4846, 0.4864, 0.4883, 0.4892,
           0.4917, 0.4933, 0.4946, 0.4975],
          [0.4978, 0.4969, 0.4950, 0.4953, 0.4937, 0.4911, 0.4896, 0.4891,
           0.4872, 0.4861, 0.4852, 0.4842, 0.4821, 0.4798, 0.4817, 0.4815,
           0.4834, 0.4844, 0.4856, 0.4867, 0.4868, 0.4890, 0.4906, 0.4925,
           0.4933, 0.4941, 0.4954, 0.4980],
          [0.4985, 0.4967, 0.4963, 0.4960, 0.4948, 0.4932, 0.4925, 0.4916,
           0.4899, 0.4890, 0.4877, 0.4865, 0.4855, 0.4851, 0.4836, 0.4860,
           0.4866, 0.4872, 0.4881, 0.4892, 0.4910, 0.4901, 0.4919, 0.4943,
           0.4953, 0.4961, 0.4972, 0.4987],
          [0.4988, 0.4982, 0.4975, 0.4968, 0.4963, 0.4956, 0.4941, 0.4936,
           0.4926, 0.4925, 0.4916, 0.4900, 0.4899, 0.4899, 0.4894, 0.4887,
           0.4900, 0.4909, 0.4917, 0.4924, 0.4931, 0.4939, 0.4946, 0.4958,
           0.4970, 0.4976, 0.4977, 0.4990],
          [0.4992, 0.4991, 0.4987, 0.4983, 0.4977, 0.4972, 0.4969, 0.4960,
           0.4961, 0.4953, 0.4954, 0.4944, 0.4944, 0.4943, 0.4938, 0.4942,
           0.4943, 0.4951, 0.4955, 0.4958, 0.4961, 0.4968, 0.4970, 0.4980,
           0.4977, 0.4986, 0.4986, 0.4992]]]], grad_fn=<ViewBackward0>), tensor([[[[0.5008, 0.5009, 0.5015, 0.5019, 0.5022, 0.5023, 0.5029, 0.5033,
           0.5039, 0.5049, 0.5048, 0.5050, 0.5060, 0.5064, 0.5057, 0.5054,
           0.5054, 0.5060, 0.5048, 0.5048, 0.5041, 0.5039, 0.5032, 0.5027,
           0.5020, 0.5027, 0.5015, 0.5011],
          [0.5012, 0.5017, 0.5018, 0.5030, 0.5038, 0.5050, 0.5054, 0.5062,
           0.5070, 0.5081, 0.5082, 0.5088, 0.5104, 0.5102, 0.5110, 0.5100,
           0.5097, 0.5092, 0.5081, 0.5083, 0.5064, 0.5058, 0.5055, 0.5044,
           0.5032, 0.5038, 0.5023, 0.5016],
          [0.5014, 0.5027, 0.5038, 0.5048, 0.5065, 0.5078, 0.5094, 0.5097,
           0.5095, 0.5121, 0.5125, 0.5132, 0.5140, 0.5152, 0.5156, 0.5144,
           0.5125, 0.5133, 0.5128, 0.5106, 0.5088, 0.5090, 0.5088, 0.5059,
           0.5050, 0.5047, 0.5031, 0.5016],
          [0.5017, 0.5035, 0.5040, 0.5057, 0.5077, 0.5097, 0.5103, 0.5125,
           0.5140, 0.5142, 0.5167, 0.5178, 0.5195, 0.5193, 0.5199, 0.5190,
           0.5177, 0.5162, 0.5152, 0.5145, 0.5117, 0.5104, 0.5099, 0.5077,
           0.5078, 0.5060, 0.5047, 0.5025],
          [0.5022, 0.5048, 0.5058, 0.5077, 0.5104, 0.5123, 0.5137, 0.5166,
           0.5190, 0.5201, 0.5216, 0.5229, 0.5264, 0.5275, 0.5257, 0.5255,
           0.5246, 0.5216, 0.5203, 0.5198, 0.5159, 0.5136, 0.5132, 0.5105,
           0.5086, 0.5076, 0.5061, 0.5031],
          [0.5031, 0.5056, 0.5074, 0.5090, 0.5117, 0.5158, 0.5184, 0.5193,
           0.5217, 0.5267, 0.5264, 0.5268, 0.5311, 0.5323, 0.5299, 0.5301,
           0.5289, 0.5286, 0.5241, 0.5220, 0.5192, 0.5171, 0.5147, 0.5123,
           0.5096, 0.5081, 0.5060, 0.5034],
          [0.5037, 0.5062, 0.5086, 0.5097, 0.5139, 0.5167, 0.5196, 0.5223,
           0.5252, 0.5300, 0.5293, 0.5315, 0.5346, 0.5349, 0.5373, 0.5342,
           0.5324, 0.5310, 0.5278, 0.5252, 0.5217, 0.5184, 0.5178, 0.5136,
           0.5110, 0.5099, 0.5070, 0.5042],
          [0.5035, 0.5070, 0.5091, 0.5117, 0.5166, 0.5184, 0.5234, 0.5264,
           0.5279, 0.5297, 0.5338, 0.5364, 0.5399, 0.5396, 0.5395, 0.5409,
           0.5361, 0.5323, 0.5319, 0.5282, 0.5258, 0.5209, 0.5202, 0.5160,
           0.5129, 0.5113, 0.5084, 0.5040],
          [0.5048, 0.5074, 0.5100, 0.5134, 0.5172, 0.5210, 0.5242, 0.5290,
           0.5319, 0.5347, 0.5379, 0.5411, 0.5484, 0.5445, 0.5443, 0.5442,
           0.5407, 0.5373, 0.5344, 0.5312, 0.5274, 0.5259, 0.5216, 0.5177,
           0.5144, 0.5113, 0.5084, 0.5048],
          [0.5051, 0.5083, 0.5110, 0.5138, 0.5184, 0.5240, 0.5285, 0.5298,
           0.5342, 0.5385, 0.5417, 0.5433, 0.5484, 0.5491, 0.5492, 0.5486,
           0.5444, 0.5427, 0.5379, 0.5337, 0.5312, 0.5264, 0.5238, 0.5191,
           0.5152, 0.5119, 0.5092, 0.5052],
          [0.5049, 0.5102, 0.5127, 0.5153, 0.5213, 0.5271, 0.5303, 0.5329,
           0.5376, 0.5417, 0.5439, 0.5466, 0.5519, 0.5542, 0.5535, 0.5526,
           0.5479, 0.5453, 0.5408, 0.5373, 0.5316, 0.5288, 0.5266, 0.5206,
           0.5163, 0.5134, 0.5096, 0.5057],
          [0.5054, 0.5102, 0.5140, 0.5174, 0.5236, 0.5275, 0.5317, 0.5359,
           0.5408, 0.5448, 0.5471, 0.5524, 0.5584, 0.5583, 0.5577, 0.5592,
           0.5541, 0.5509, 0.5453, 0.5430, 0.5363, 0.5312, 0.5272, 0.5234,
           0.5203, 0.5151, 0.5116, 0.5063],
          [0.5067, 0.5103, 0.5155, 0.5196, 0.5271, 0.5296, 0.5346, 0.5396,
           0.5461, 0.5484, 0.5517, 0.5568, 0.5640, 0.5631, 0.5642, 0.5643,
           0.5596, 0.5536, 0.5501, 0.5447, 0.5400, 0.5331, 0.5298, 0.5257,
           0.5204, 0.5171, 0.5121, 0.5067],
          [0.5055, 0.5113, 0.5148, 0.5190, 0.5242, 0.5304, 0.5354, 0.5395,
           0.5440, 0.5492, 0.5527, 0.5561, 0.5630, 0.5643, 0.5652, 0.5627,
           0.5584, 0.5548, 0.5523, 0.5439, 0.5392, 0.5345, 0.5302, 0.5252,
           0.5214, 0.5163, 0.5128, 0.5067],
          [0.5056, 0.5112, 0.5151, 0.5184, 0.5250, 0.5296, 0.5356, 0.5383,
           0.5453, 0.5491, 0.5522, 0.5553, 0.5627, 0.5649, 0.5649, 0.5638,
           0.5582, 0.5543, 0.5506, 0.5447, 0.5394, 0.5363, 0.5302, 0.5254,
           0.5195, 0.5179, 0.5122, 0.5066],
          [0.5060, 0.5101, 0.5151, 0.5191, 0.5254, 0.5296, 0.5344, 0.5388,
           0.5439, 0.5480, 0.5514, 0.5564, 0.5628, 0.5630, 0.5642, 0.5639,
           0.5576, 0.5549, 0.5503, 0.5453, 0.5392, 0.5346, 0.5319, 0.5258,
           0.5210, 0.5188, 0.5124, 0.5068],
          [0.5052, 0.5097, 0.5139, 0.5177, 0.5237, 0.5268, 0.5314, 0.5356,
           0.5405, 0.5440, 0.5483, 0.5514, 0.5571, 0.5567, 0.5577, 0.5580,
           0.5549, 0.5527, 0.5456, 0.5406, 0.5379, 0.5310, 0.5297, 0.5229,
           0.5186, 0.5140, 0.5110, 0.5064],
          [0.5048, 0.5090, 0.5137, 0.5158, 0.5208, 0.5258, 0.5305, 0.5320,
           0.5367, 0.5403, 0.5451, 0.5461, 0.5528, 0.5526, 0.5539, 0.5523,
           0.5490, 0.5463, 0.5413, 0.5360, 0.5328, 0.5284, 0.5261, 0.5204,
           0.5178, 0.5127, 0.5097, 0.5052],
          [0.5051, 0.5088, 0.5117, 0.5147, 0.5196, 0.5231, 0.5262, 0.5285,
           0.5342, 0.5365, 0.5397, 0.5421, 0.5477, 0.5485, 0.5488, 0.5478,
           0.5439, 0.5416, 0.5382, 0.5338, 0.5292, 0.5264, 0.5230, 0.5191,
           0.5145, 0.5125, 0.5099, 0.5059],
          [0.5042, 0.5076, 0.5126, 0.5132, 0.5172, 0.5212, 0.5247, 0.5264,
           0.5315, 0.5334, 0.5369, 0.5389, 0.5433, 0.5455, 0.5435, 0.5436,
           0.5411, 0.5383, 0.5338, 0.5313, 0.5279, 0.5237, 0.5209, 0.5200,
           0.5139, 0.5118, 0.5086, 0.5043],
          [0.5037, 0.5067, 0.5093, 0.5124, 0.5148, 0.5172, 0.5201, 0.5238,
           0.5262, 0.5288, 0.5306, 0.5330, 0.5368, 0.5372, 0.5377, 0.5376,
           0.5347, 0.5320, 0.5293, 0.5260, 0.5234, 0.5203, 0.5174, 0.5148,
           0.5120, 0.5094, 0.5076, 0.5041],
          [0.5027, 0.5052, 0.5080, 0.5102, 0.5125, 0.5161, 0.5179, 0.5199,
           0.5224, 0.5248, 0.5270, 0.5292, 0.5329, 0.5347, 0.5345, 0.5344,
           0.5304, 0.5298, 0.5261, 0.5229, 0.5209, 0.5181, 0.5161, 0.5136,
           0.5106, 0.5088, 0.5087, 0.5038],
          [0.5031, 0.5046, 0.5079, 0.5091, 0.5125, 0.5146, 0.5158, 0.5173,
           0.5218, 0.5223, 0.5230, 0.5250, 0.5288, 0.5297, 0.5306, 0.5282,
           0.5266, 0.5265, 0.5221, 0.5196, 0.5180, 0.5169, 0.5140, 0.5119,
           0.5096, 0.5074, 0.5060, 0.5030],
          [0.5029, 0.5038, 0.5061, 0.5073, 0.5121, 0.5134, 0.5130, 0.5144,
           0.5175, 0.5185, 0.5191, 0.5212, 0.5246, 0.5248, 0.5243, 0.5244,
           0.5223, 0.5217, 0.5184, 0.5161, 0.5147, 0.5134, 0.5111, 0.5100,
           0.5084, 0.5060, 0.5060, 0.5032],
          [0.5019, 0.5037, 0.5051, 0.5063, 0.5087, 0.5087, 0.5109, 0.5114,
           0.5141, 0.5142, 0.5156, 0.5170, 0.5196, 0.5188, 0.5198, 0.5195,
           0.5183, 0.5175, 0.5159, 0.5144, 0.5119, 0.5104, 0.5086, 0.5080,
           0.5072, 0.5054, 0.5046, 0.5020],
          [0.5018, 0.5027, 0.5041, 0.5051, 0.5063, 0.5084, 0.5088, 0.5093,
           0.5109, 0.5128, 0.5127, 0.5130, 0.5145, 0.5147, 0.5154, 0.5151,
           0.5136, 0.5127, 0.5113, 0.5107, 0.5089, 0.5084, 0.5068, 0.5058,
           0.5056, 0.5044, 0.5033, 0.5021],
          [0.5008, 0.5020, 0.5032, 0.5035, 0.5045, 0.5055, 0.5059, 0.5066,
           0.5074, 0.5080, 0.5085, 0.5091, 0.5111, 0.5107, 0.5111, 0.5108,
           0.5108, 0.5089, 0.5081, 0.5079, 0.5073, 0.5067, 0.5051, 0.5046,
           0.5040, 0.5039, 0.5025, 0.5018],
          [0.5005, 0.5009, 0.5013, 0.5027, 0.5026, 0.5030, 0.5038, 0.5034,
           0.5044, 0.5042, 0.5045, 0.5048, 0.5058, 0.5061, 0.5065, 0.5062,
           0.5062, 0.5051, 0.5048, 0.5043, 0.5036, 0.5037, 0.5028, 0.5024,
           0.5022, 0.5019, 0.5013, 0.5014]]]], grad_fn=<ViewBackward0>))
torch.Size([1, 4, 1, 1]) torch.Size([1, 4, 1, 1]) (1, 4, 1, 1)
tensor([[[[-1.]],

         [[-1.]],

         [[-1.]],

         [[-1.]]]]) tensor([[[[1.]],

         [[1.]],

         [[1.]],

         [[1.]]]]) <BoundedTensor: BoundedTensor([[[[-1.]],

                [[-1.]],

                [[-1.]],

                [[-1.]]]]), PerturbationLpNorm(norm=inf, eps=0, x_L=tensor([[[[-1.]],

         [[-1.]],

         [[-1.]],

         [[-1.]]]]), x_U=tensor([[[[1.]],

         [[1.]],

         [[1.]],

         [[1.]]]]))>
Traceback (most recent call last):
  File "/auto_LiRPA/script.py", line 35, in <module>
    bounds = net.compute_bounds(x=(tensor,), method="crown")  # fails
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 1206, in compute_bounds
    return self._compute_bounds_main(C=C,
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 1303, in _compute_bounds_main
    self.check_prior_bounds(final)
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  [Previous line repeated 2 more times]
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 804, in check_prior_bounds
    self.compute_intermediate_bounds(
  File "/auto_LiRPA/auto_LiRPA/bound_general.py", line 910, in compute_intermediate_bounds
    node.lower, node.upper = self.backward_general(
  File "/auto_LiRPA/auto_LiRPA/backward_bound.py", line 324, in backward_general
    lb, ub = concretize(self, batch_size, output_dim, lb, ub,
  File "/auto_LiRPA/auto_LiRPA/backward_bound.py", line 684, in concretize
    lb = lb + roots[i].perturbation.concretize(
RuntimeError: The size of tensor a (4) must match the size of tensor b (784) at non-singleton dimension 3
ERROR conda.cli.main_run:execute(49): `conda run python script.py` failed. (See above for error)

I know this behaviour is extremely strange, but since I am only subtracting 1.0 from a lower bound for which CROWN already works, I don't think it's a shape issue this time.

cherrywoods reopened this Dec 10, 2023
@cherrywoods
Author

I also confirmed that the error persists when I use a batch size of 10 for lb and ub.
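
That is, the same RuntimeError is raised with a batched input specification (a sketch of the variant described above, replacing the bounds in the script):

lb = -torch.ones(10, 4, 1, 1)
ub = torch.ones(10, 4, 1, 1)
ptb = PerturbationLpNorm(x_L=lb, x_U=ub)
tensor = BoundedTensor(lb, ptb)
bounds = net.compute_bounds(x=(tensor,), method="crown")  # fails with the same error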

@shizhouxing
Member

Thanks for reporting the issue, and sorry for the delayed response. We have fixed it internally and will push the fix in an upcoming release soon.
