AdvProp implementation #68

Open · michaelklachko opened this issue Jan 3, 2020 · 10 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments


michaelklachko commented Jan 3, 2020

https://arxiv.org/abs/1911.09665

In the paper, they propose calculating two losses: one for the forward pass with "clean" BN params, and another for the forward pass with adversarial BN params. Then they combine these two losses, and backprop through both BN paths at the same time (joint optimization).

Does the following look correct to you:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, stride=2)
        self.bnC = nn.BatchNorm2d(32)  # "clean" BN
        self.bnA = nn.BatchNorm2d(32)  # auxiliary/adversarial BN
        self.relu = nn.ReLU()
        self.linear = nn.Linear(32*14*14, 10)  # assumes 32x32 inputs

    def forward(self, x, clean=True):
        x = self.conv1(x)
        # route through the clean or the auxiliary BN depending on the batch type
        if clean:
            x = self.bnC(x)
        else:
            x = self.bnA(x)
        x = self.relu(x)
        x = x.flatten(1)  # flatten before the classifier
        x = self.linear(x)
        return x

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for i in range(1000):
    # get_clean_batch / get_adv_batch are placeholders for the data pipeline
    batchC, targetC = get_clean_batch()
    batchA, targetA = get_adv_batch()

    # one forward pass per BN path
    outputC = model(batchC, clean=True)
    outputA = model(batchA, clean=False)

    # joint optimization: sum the two losses and backprop through both BN paths
    lossC = loss_fn(outputC, targetC)
    lossA = loss_fn(outputA, targetA)
    loss = lossC + lossA

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

If so, how would you propagate the clean argument to all the blocks, especially the ones that use nn.Sequential containers?

Is there some existing AdvProp code to look at?

rwightman (Collaborator) commented Jan 3, 2020

@michaelklachko That's my understanding of how it'd work (with a simplified model like that). You'd then always be using clean=True in your example for validation and production.

As you say though, it's less than ideal to pass that through in more complex networks. I had two approaches to try.

  1. Make a duplicate 'aux' model that references the same conv/linear layers as the original but replaces the BN layers with new ones. I think the gradient tracking would still work, but I'm not 100% sure. The model for validation and saving remains exactly as usual. Requires n passes for n BN sets.

  2. Make a custom BN layer that splits the batch along a predefined boundary, maintains child BN layers, and routes perturbed/adversarial samples through the auxiliary ones (a rough sketch of this idea follows below). This would be done by iterating over the model and replacing the BN layers, like the SyncBN helpers (https://github.com/NVIDIA/apex/blob/e6cb749b52f44de76cac564b20e3a66b6a837424/apex/parallel/__init__.py#L21). The saving/validation code would have to be set up so it uses bn[0] and ends up appearing like an unmodified model. Requires one pass, just a larger batch size.
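For illustration, a minimal sketch of that routing idea, assuming the clean samples occupy the first chunk of the batch and the batch size divides evenly across splits. The class name SplitBN2d and the num_splits argument are placeholders here, not the eventual timm implementation:

import torch
import torch.nn as nn

class SplitBN2d(nn.BatchNorm2d):
    """Sketch only: chunk 0 of the batch uses this (primary) BN,
    remaining chunks are routed to auxiliary BNs."""
    def __init__(self, num_features, num_splits=2, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.aux_bn = nn.ModuleList(
            nn.BatchNorm2d(num_features, **kwargs) for _ in range(num_splits - 1))

    def forward(self, x):
        if self.training:
            chunks = x.chunk(self.num_splits, dim=0)    # assumes an evenly divisible batch
            out = [super().forward(chunks[0])]          # clean samples -> primary BN
            out += [bn(c) for bn, c in zip(self.aux_bn, chunks[1:])]
            return torch.cat(out, dim=0)
        return super().forward(x)                       # eval: behave like a plain BN

To use it, one would walk the model and swap each nn.BatchNorm2d for this layer, in the same spirit as the SyncBN conversion helper linked above.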

rwightman (Collaborator) commented Jan 5, 2020

I did a first pass implementing option 2. It was minimal code and seems to be working; I measured somewhere between a 15-20% throughput hit on my older Pascal (2 x 1080ti) machine. I think that'll be much better than two separate forward passes with the need to either maintain two models or recursively change a state variable per batch.

michaelklachko (Author) commented Jan 5, 2020

Nice! Looks like a framework limitation forced the more efficient solution, because if I could propagate the flag I'd have stopped there :)

rwightman (Collaborator) commented

@michaelklachko SplitBatchNorm is in master now. I've done limited experiments with it at this stage. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/split_batchnorm.py
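For reference, a minimal usage sketch, assuming the helper in that file is convert_splitbn_model (check the linked source for the exact name and signature):

import timm
from timm.models.split_batchnorm import convert_splitbn_model

model = timm.create_model('efficientnet_b0')
# Replace every BatchNorm2d with a split variant holding one auxiliary BN (num_splits=2);
# at eval time the layer falls back to the primary (clean) BN, per the design described above.
model = convert_splitbn_model(model, num_splits=2)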

sab148 commented Apr 5, 2020

I have some questions, if you could please answer them:

  1. In section 4.3 of the AdvProp paper they say:

     "For each clean mini-batch, we first attack the network using the auxiliary BNs to generate its adversarial counterpart;"

     Using your code, do we have to generate the adversarial examples ourselves by adding noise to the clean images? It seems that in the paper the adversarial examples are generated by the CNN.

  2. Also, about the calculation of the loss function: it seems like you calculate only one loss for both clean and adversarial examples, rather than calculating one for each and summing them up as in the paper. Are these two ways equivalent?

Thank you for your time

rwightman (Collaborator) commented

@sab148 The only part of AdvProp I have implemented so far is the SplitBatchNorm.

So yes, you must generate adversarial examples with a CNN, possibly using a variety of attacks, but I believe the AdvProp paper used one type of PGD attack. You might use a snapshot of the same model that you are training, or a variety of other pretrained models, or a combination for generating the adv examples. It will likely require another set of GPUs separate from the training ones. I have not integrated any of this but hacked together something similar in the past (out of date code, old pytorch, https://github.com/rwightman/pytorch-nips2017-adversarial/blob/master/python/train_adversarial_defense.py).
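For a rough idea only, a minimal untargeted PGD sketch is below; eps/step/iters are arbitrary placeholders rather than the paper's settings, inputs are assumed to be in [0, 1], and for AdvProp proper the attack forward pass would go through the auxiliary BNs:

import torch
import torch.nn.functional as F

def pgd_attack(model, x, y, eps=4/255, step=1/255, iters=5):
    # Minimal untargeted PGD sketch. In practice you may want model.eval()
    # here so the attack doesn't update BN running stats.
    x_adv = x.clone().detach()
    for _ in range(iters):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() + step * grad.sign()   # ascend the loss
        x_adv = x + (x_adv - x).clamp(-eps, eps)      # project back into the eps-ball
        x_adv = x_adv.clamp(0, 1)                     # keep a valid image range (assumes [0, 1] inputs)
    return x_adv.detach()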

I believe the two losses are an artifact of the way the paper's network was implemented; I believe they had to do two separate passes because of how they implemented their BN. I replace all BN layers with one that splits the batch and passes samples through different BNs based on batch index ranges, so everything happens in one pass with the same batch. You just need to make sure you build the batch properly, and if you need to weight different examples differently in the loss, you'd need to keep track of that.

sab148 commented Apr 7, 2020

@rwightman Thank you so much for your help :)

rwightman added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Nov 2, 2020
amaarora (Contributor) commented Jan 25, 2021

"I replace all BN layers with one that splits the batch and passes samples through different BNs based on batch index ranges, so everything happens in one pass with the same batch. You just need to make sure you build the batch properly, and if you need to weight different examples differently in the loss, you'd need to keep track of that."

Hey @rwightman - so if I were to write a tutorial or train a model on ImageNet using AdvProp (https://arxiv.org/abs/1911.09665) and your SplitBatchNorm2d, is the below the correct way?

Pseudo Code:

  1. Get a batch of size 32 with imgs as X and labels as y from ImageNet, called batch(X, y)
  2. Generate a new auxiliary batch for these 32 imgs and labels, where e is the adversarial noise, so we get a new batch called batch(X+e, y)
  3. Create a batch of size 64 where the first 32 images are clean and the next 32 images are adversarial examples
  4. Convert the model using the split_batchnorm function so that all BatchNorm2d layers become SplitBatchNorm2d layers, with the number of splits equal to 2

And now, the standard training steps (a rough sketch of the full loop follows below):
5. Calculate the loss using nn.CrossEntropyLoss
6. Do loss.backward()
7. Do optimizer.step()
8. Do scheduler.step() if a scheduler is present
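For illustration, a minimal sketch of those steps, assuming the convert_splitbn_model helper from timm/models/split_batchnorm.py (check the file for the exact API), the pgd_attack sketch from earlier in this thread, and a placeholder ImageNet DataLoader called loader:

import torch
import torch.nn as nn
import timm
from timm.models.split_batchnorm import convert_splitbn_model  # check the file for the exact API

model = timm.create_model('resnet50')                 # placeholder architecture
model = convert_splitbn_model(model, num_splits=2)    # step 4: clean BN + one auxiliary BN
model = model.cuda().train()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

for x, y in loader:                          # step 1: clean batch of size 32
    x, y = x.cuda(), y.cuda()
    x_adv = pgd_attack(model, x, y)          # step 2: adversarial counterparts (the paper attacks
                                             # through the auxiliary BNs; that detail is glossed over here)
    batch = torch.cat([x, x_adv], dim=0)     # step 3: clean samples first, adversarial second
    target = torch.cat([y, y], dim=0)

    loss = loss_fn(model(batch), target)     # step 5: one loss over the combined batch
    optimizer.zero_grad()
    loss.backward()                          # step 6
    optimizer.step()                         # step 7
    # scheduler.step()                       # step 8, if a scheduler is present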

rwightman (Collaborator) commented

@amaarora likely some missing details (as there always are), but that is the gist of it, yes. Similar to the AugMix training in terms of batch construction and possible splitbn use, BUT generating adversarial samples on the fly isn't trivial. For PGD, and even basic iterative adversarial example generation, you'd likely need to dedicate a 1:1 or 1:2 ratio of GPUs and may want some queuing to allow those examples to be generated in processes separate from the training processes.

amaarora (Contributor) commented

Thank you for your help :)
