AdvProp implementation #68
@michaelklachko That's my understanding of how it'd work (with a simplified model like that). You'd then always be using […]. As you say though, it's less than ideal to pass that through in more complex networks. I had two approaches to try: (1) propagate a clean/adversarial flag through the network, or (2) a BN layer that splits the batch internally.
I did a first pass implementing option 2. It was minimal code and seems to be working, measured at somewhere between a 15-20% throughput hit on my older Pascal (2 x 1080ti) machine. I think that'll be much better than two separate forward passes with the need to either maintain two models or recursively change a state variable per batch.
Nice! Looks like a framework limitation forced the more efficient solution, because if I could propagate the flag I'd have stopped there :)
@michaelklachko SplitBatchNorm is in master now. I've done limited experiments with it at this stage. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/split_batchnorm.py
I have some questions, if you could answer them please:
Using your code, do we have to generate the adversarial examples ourselves by adding noise to the clean images? It seems that in the paper the adversarial examples are generated by the CNN.
Thank you for your time
@sab148 The only part of AdvProp I have so far is the SplitBatchNorm. So yes, you must generate the adversarial examples with a CNN, possibly using a variety of attacks, though I believe the AdvProp paper used one type of PGD attack. You might use a snapshot of the same model you are training, a variety of other pretrained models, or a combination, for generating the adversarial examples. It will likely require another set of GPUs separate from the training ones. I have not integrated any of this, but I hacked together something similar in the past (out-of-date code, old PyTorch: https://github.com/rwightman/pytorch-nips2017-adversarial/blob/master/python/train_adversarial_defense.py). I believe the two losses are an artifact of the way the paper's network was implemented; they had to do two separate passes because of how they implemented their BN. I replace all BN layers with one that splits and passes samples through different BNs based on batch index ranges, so everything happens in one pass over the same batch. You just need to make sure you build the batch properly, and if you need to weight different examples differently in the loss, you'd need to keep track of that.
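For reference, a condensed sketch of what such a split BN looks like (the real implementation is in the split_batchnorm.py file linked above; this version is illustrative and trimmed down):

```python
import torch
import torch.nn as nn

class SplitBatchNorm2d(nn.BatchNorm2d):
    """Routes contiguous index ranges of the batch through separate BN
    instances; the first range uses this module's own (primary) stats."""
    def __init__(self, num_features, num_splits=2, **kwargs):
        super().__init__(num_features, **kwargs)
        assert num_splits > 1
        self.num_splits = num_splits
        # auxiliary BNs for the remaining ranges (e.g. adversarial samples)
        self.aux_bn = nn.ModuleList(
            [nn.BatchNorm2d(num_features, **kwargs) for _ in range(num_splits - 1)])

    def forward(self, x):
        if self.training:
            # split the batch into equal contiguous chunks, one per BN
            # (batch size must be divisible by num_splits)
            chunks = x.chunk(self.num_splits, dim=0)
            out = [super().forward(chunks[0])]
            out += [bn(c) for bn, c in zip(self.aux_bn, chunks[1:])]
            return torch.cat(out, dim=0)
        # at eval time only the primary ("clean") statistics are used
        return super().forward(x)
```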
@rwightman Thank you so much for your help :)
Hey @rwightman - so if I were to write a tutorial or train a model using AdvProp (https://arxiv.org/abs/1911.09665) on ImageNet with your SplitBatchNorm2D, would the below be the correct way? Pseudo code:
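(The pseudocode from this comment didn't survive extraction. Based on the surrounding discussion, a reconstruction might look roughly like this; `convert_splitbn_model` is the helper from timm/models/split_batchnorm.py, while `pgd_attack` and `loader` are placeholders, with one possible attack sketched a few comments down.)

```python
import torch
import timm
from timm.models import convert_splitbn_model

model = timm.create_model('resnet50')
model = convert_splitbn_model(model, num_splits=2)  # clean + adversarial BN paths
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

for clean_images, targets in loader:
    # generate adversarial counterparts on the fly (placeholder helper)
    adv_images = pgd_attack(model, clean_images, targets)
    # batch construction matters: the first index range goes through the
    # primary ("clean") BN, the second range through the auxiliary BN
    inputs = torch.cat([clean_images, adv_images], dim=0)
    labels = torch.cat([targets, targets], dim=0)
```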
And now, the standard training steps:
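(These were also cut off; presumably the usual loop body, continuing the sketch above, with a single loss over the combined batch since SplitBatchNorm routes each half through its own statistics internally:)

```python
    # still inside the batch loop: one forward pass over the combined batch
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()
```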
@amaarora There are likely some missing details (as there always are), but that is the gist of it, yes. It's similar to the AugMix training in terms of batch construction and possible split-BN use, BUT generating adversarial samples on the fly isn't trivial. For PGD, and even basic iterative adversarial example generation, you likely need to dedicate a 1:1 or 1:2 ratio of GPUs, and you may want some queuing to allow those examples to be generated in processes separate from the training processes.
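To make the "not trivial" part concrete, here is a minimal sketch of basic L-inf PGD generation, assuming inputs normalized to [0, 1]; the names and hyperparameters are illustrative, not from this repo:

```python
import torch
import torch.nn.functional as F

def pgd_attack(model, images, targets, eps=4/255, alpha=1/255, steps=5):
    """Iteratively step along the sign of the loss gradient, projecting
    back into the eps-ball around the original images."""
    was_training = model.training
    model.eval()  # avoid polluting BN running stats during the attack
    adv = images.clone().detach()
    for _ in range(steps):
        adv.requires_grad_(True)
        loss = F.cross_entropy(model(adv), targets)
        grad, = torch.autograd.grad(loss, adv)
        with torch.no_grad():
            adv = adv + alpha * grad.sign()                  # ascend the loss
            adv = images + (adv - images).clamp(-eps, eps)   # project to eps-ball
            adv = adv.clamp(0.0, 1.0)                        # stay a valid image
    model.train(was_training)
    return adv.detach()
```

Running several such forward/backward passes per training batch is exactly why a dedicated pool of GPUs (and some queuing) helps.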
Thank you for your help :)
https://arxiv.org/abs/1911.09665
In the paper, they propose calculating two losses: one for the forward pass with "clean" BN params, and another for the forward pass with adversarial BN params. Then they combine these two losses, and backprop through both BN paths at the same time (joint optimization).
Does the following look correct to you:
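(The snippet from the original post is also missing; judging from the replies, it was roughly along these lines — an illustrative reconstruction, not the poster's actual code:)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    """Simplified block with separate BN parameters for the clean and
    adversarial paths, selected by a `clean` flag."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn_clean = nn.BatchNorm2d(channels)
        self.bn_adv = nn.BatchNorm2d(channels)

    def forward(self, x, clean=True):
        bn = self.bn_clean if clean else self.bn_adv
        return F.relu(bn(self.conv(x)))

# toy joint optimization: one loss per BN path, combined before backprop
block = Block(8)
clean_x, adv_x = torch.randn(4, 8, 16, 16), torch.randn(4, 8, 16, 16)
target = torch.randn(4, 8, 16, 16)
loss = F.mse_loss(block(clean_x, clean=True), target) \
     + F.mse_loss(block(adv_x, clean=False), target)
loss.backward()
```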
If so, how would you propagate the `clean` argument to all the blocks, especially the ones that use nn.Sequential lists? Is there some existing AdvProp code to look at?