47.9% Robust Test Error on CIFAR10 with Adversarial Training and PyTorch

Having seen how to compute adversarial examples in a previous article, the natural next step is to train models for which such adversarial examples do not exist. This is the goal of adversarially robust training procedures. In this article, I want to describe a particularly popular approach called adversarial training, whose idea is to train on adversarial examples computed on the fly during training. I will also discuss a PyTorch implementation that obtains 47.9% robust test error (52.1% robust accuracy) on CIFAR10 using a WRN-28-10 architecture.

Introduction

Adversarial training [] has become the de-facto standard for training adversarially robust models. Consequently, many variants are published at top-tier conferences every year; see, for example, [][][]. At the same time, adversarial training has spurred research on stronger adversarial attacks to evaluate robust models appropriately. Despite the wide variety of adversarial training methods, implementing adversarial training correctly can be tricky, especially for researchers new to adversarial robustness. This is partly due to the additional hyper-parameters it introduces compared to normal training.

Therefore, this article intends to provide both a detailed description of adversarial training and a PyTorch implementation that obtains good adversarial robustness in practice. The implementation follows the original implementation of [] but builds on the projected gradient descent (PGD) implementation from this previous article. This means that the code can, technically, be used for general $L_p$ adversarial training; results will, however, be best for $L_\infty$ and $L_2$ adversarial robustness. Furthermore, the implementation supports several popular architectures, including ResNets [] and Wide ResNets [].
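In PGD, the choice of $L_p$ norm mainly enters through the normalization of the gradient step and through the projection back onto the $\epsilon$-ball after each step. As a minimal sketch (my own illustration, not code from the repository; the names project_linf and project_l2 are hypothetical), both the $L_\infty$ and the $L_2$ projection have simple closed forms for a batch of perturbations:

```python
import torch


def project_linf(delta: torch.Tensor, epsilon: float) -> torch.Tensor:
    # L_inf ball: clip every coordinate independently to [-epsilon, epsilon]
    return torch.clamp(delta, -epsilon, epsilon)


def project_l2(delta: torch.Tensor, epsilon: float) -> torch.Tensor:
    # L_2 ball: rescale each example's perturbation if its norm exceeds epsilon
    flat = delta.reshape(delta.size(0), -1)
    norms = flat.norm(p=2, dim=1, keepdim=True)
    factors = torch.clamp(epsilon / (norms + 1e-12), max=1.0)  # 1 if already inside the ball
    return (flat * factors).reshape(delta.shape)
```

For example, an $L_2$ perturbation $(3, 4)$ with $\epsilon = 1$ is rescaled to $(0.6, 0.8)$, while perturbations already inside the ball are left unchanged.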

This article is part of a series of articles:

This article combines the two previous ones: it shows how to train adversarially robust models on CIFAR10 by using the PGD attack presented previously within the training procedure, that is, by performing adversarial training [].

The code for this article can be found on GitHub:

Code on GitHub

Adversarial Training

The intuition behind adversarial training is to continuously generate adversarial examples during training and to learn to minimize the cross-entropy loss on those adversarial examples. In practice, this results in an alternating optimization procedure. Formally, however, adversarial training is usually written as a min-max optimization problem:

$\min_w \mathbb{E}_{(x,y)}\left[\max_{\|\delta\|_p \leq \epsilon} \mathcal{L}(f(x + \delta;w), y)\right]$ (1)

where $\mathcal{L}$ denotes the cross-entropy loss, computed between the model's prediction $f(x + \delta;w)$ on the adversarial example $\tilde{x} = x + \delta$ and the target $y$. In practice, this problem is solved using mini-batch stochastic gradient descent. This means that the adversarial examples are actually computed on a mini-batch $B$ of training examples:

$\sum_{i \in B} \max_{\|\delta_i\|_p \leq \epsilon} \mathcal{L}(f(x_i + \delta_i;w), y_i)$

The corresponding gradient is then used to update the weights $w$. Importantly, solving equation (1) exactly is generally difficult, which is why it is implemented in an alternating fashion, as suggested above. That is, we alternate between two steps:

  • Computing adversarial examples: for fixed weights $w$, solve the inner maximization problem $\max_{\|\delta_i\|_p \leq \epsilon} \mathcal{L}(f(x_i + \delta_i;w),y_i)$ for each training example $i$ in the mini-batch individually.
  • Updating weights: on the computed adversarial examples, perform one gradient step on the outer minimization problem $\min_w \sum_{i \in B} \mathcal{L}(f(x_i + \delta_i;w),y_i)$, with $\delta_i$ as computed in the first step.
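To make the two alternating steps concrete, here is a minimal PyTorch sketch of a single training step. It assumes $L_\infty$ perturbations, images in $[0, 1]$, and hypothetical default hyper-parameters (ε = 8/255, step size 2/255, 7 iterations) that are not taken from this article; the function name adversarial_training_step is my own, and this is not the AdversarialTraining class discussed below.

```python
import torch
import torch.nn.functional as F


def adversarial_training_step(model, optimizer, x, y, epsilon=8 / 255, step_size=2 / 255, iterations=7):
    """One alternating step: L_inf PGD for fixed weights, then one optimizer step on the result."""
    # Step 1: approximately solve the inner maximization for fixed weights w.
    model.eval()  # do not update batch normalization statistics during the attack
    delta = torch.empty_like(x).uniform_(-epsilon, epsilon)  # random start inside the ball
    delta = torch.clamp(x + delta, 0, 1) - x                 # keep x + delta a valid image
    for _ in range(iterations):
        delta.requires_grad_(True)
        attack_loss = F.cross_entropy(model(x + delta), y)
        grad, = torch.autograd.grad(attack_loss, delta)
        # gradient ascent on the loss using the sign of the gradient (L_inf geometry) ...
        delta = delta.detach() + step_size * grad.sign()
        # ... followed by projection onto the epsilon-ball and the image domain [0, 1]
        delta = torch.clamp(delta, -epsilon, epsilon)
        delta = torch.clamp(x + delta, 0, 1) - x
    delta = delta.detach()

    # Step 2: one gradient step on the outer minimization for the fixed perturbations.
    model.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x + delta), y)
    loss.backward()
    optimizer.step()
    return loss.item(), delta
```

Because the weights are fixed during the first step and the model is in evaluation mode, each example's loss depends only on its own $\delta_i$. A single gradient of the mini-batch loss therefore yields the per-example gradients needed to solve all inner problems in parallel; the $1/|B|$ factor from the mean reduction does not change the gradient sign.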

Besides this high-level outline, plenty of details need to be fixed in order to actually implement adversarial training, and some of them are still being studied in recent papers [][]. Most importantly, we need to decide how to solve the inner maximization problem. In [], the projected gradient descent (PGD) method was proposed for computing adversarial examples. The exact algorithm, however, is less important than the insight that the inner problem should be solved using multiple iterations. Previous work [] proposed approaches similar to adversarial training but generally used only a single iteration to compute adversarial examples. Unfortunately, this was shown not to result in adversarially robust models in practice.
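For contrast, a single-iteration attack in the spirit of the fast gradient sign method of Goodfellow et al. takes one signed-gradient step of size ε instead of several smaller projected steps. The following is a minimal sketch, assuming images in $[0, 1]$ and $L_\infty$ perturbations; the function name fgsm is my own:

```python
import torch
import torch.nn.functional as F


def fgsm(model, x, y, epsilon):
    """Single-step L_inf attack: one signed-gradient step of size epsilon, no iterations."""
    x_var = x.clone().requires_grad_(True)
    loss = F.cross_entropy(model(x_var), y)
    grad, = torch.autograd.grad(loss, x_var)
    # jump straight to a corner of the epsilon-ball, then clip to the image domain [0, 1]
    return torch.clamp(x + epsilon * grad.sign(), 0, 1).detach()
```

Roughly speaking, this corresponds to a single PGD iteration with step size ε and no random initialization. As discussed above, training only on such single-step adversarial examples was shown not to yield robust models in practice.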

When going through my implementation in the following section, we will also have to decide how to handle batch normalization [] while computing adversarial examples, and whether to compute adversarial examples for all or only part of the examples in a mini-batch.
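Why batch normalization matters here: in training mode, every forward pass through a batch normalization layer updates its running mean and variance, so the many forward passes of an attack would leak into these statistics. The following minimal check with torch.nn.BatchNorm2d (toy inputs chosen for illustration) shows that evaluation mode leaves the running statistics untouched:

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm2d(3)
x = torch.randn(8, 3, 4, 4) + 5.0  # shifted inputs so that an update would be clearly visible

bn.eval()
before = bn.running_mean.clone()
_ = bn(x)  # evaluation mode: normalizes with the running statistics, does not update them
unchanged = torch.equal(bn.running_mean, before)

bn.train()
_ = bn(x)  # training mode: uses batch statistics and updates running_mean / running_var
changed = not torch.equal(bn.running_mean, before)

print(unchanged, changed)  # True True
```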

PyTorch Implementation

My adversarial training implementation follows the training interface from this article and the PGD attack from this article. As these articles include detailed descriptions of the code, I will only recap the corresponding interfaces:

The adversarial training routine inherits from the following NormalTraining class and overrides its train and test methods:

Listing 1: Interface for normal training procedure that our adversarial training implementation will inherit.

import common.torch
import common.summary


class NormalTraining:
    def __init__(self, model, trainset, testset, optimizer, scheduler, augmentation=None, loss=common.torch.classification_loss, writer=common.summary.SummaryWriter(), cuda=False):
        ...  # store model, data, optimizer, scheduler, augmentation, loss and writer ...

    def train(self, epoch):
        ...  # perform one epoch of training ...

    def test(self, epoch):
        ...  # test the current model ...

For computing adversarial examples, my PGD implementation follows the interface in Listing 2. Remember that the attack is run on a mini-batch of images given an objective to optimize; for adversarial training, this objective is to maximize the cross-entropy loss.

Listing 2: Interface for computing adversarial examples on a given mini-batch of images.

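class Attack:
    def run(self, model, images, objective, writer=common.summary.SummaryWriter(), prefix=''):
        # compute perturbations for the given mini-batch of images;
        # returns the perturbations and the corresponding objective values (cf. Listing 4)
        raise NotImplementedError()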
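As a rough illustration of what such an objective could look like for adversarial training, the following sketch computes the per-example cross-entropy with respect to the true labels, which the attack then maximizes. The class name CrossEntropyObjective and its set and __call__ methods are hypothetical and only mirror how Listing 4 uses self.objective.set(...); the actual objective classes in the repository may differ:

```python
import torch
import torch.nn.functional as F


class CrossEntropyObjective:
    """Hypothetical untargeted objective: per-example cross-entropy on the true labels."""

    def __init__(self):
        self.true_classes = None

    def set(self, true_classes):
        # called once per mini-batch with the labels of the examples to attack
        self.true_classes = true_classes

    def __call__(self, logits):
        # reduction='none' keeps one value per example, so every example
        # in the mini-batch gets its own inner maximization problem
        return F.cross_entropy(logits, self.true_classes, reduction='none')
```

Returning one value per example keeps the inner maximization problems of the individual examples separate, which matches the per-example maximization in the formulation of adversarial training.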

Adversarial training takes three additional arguments: the attack to be used during training, the objective to optimize when computing adversarial examples, and the fraction of adversarial examples to compute on each mini-batch. Note that [] originally trains on 100% adversarial examples in each mini-batch. These additional hyper-parameters are passed to the following constructor:

Listing 3: Constructor for the adversarial training routine with the additional hyper-parameters attack, objective and fraction.

class AdversarialTraining(NormalTraining):
    def __init__(self, model, trainset, testset, optimizer, scheduler, attack, objective, fraction=0.5, augmentation=None, loss=common.torch.classification_loss, writer=common.summary.SummaryWriter(), cuda=False):
        assert fraction > 0
        assert fraction <= 1
        assert isinstance(attack, attacks.Attack)
        assert isinstance(objective, attacks.objectives.Objective)
        assert getattr(attack, 'norm', None) is not None

        super(AdversarialTraining, self).__init__(model, trainset, testset, optimizer, scheduler, augmentation, loss, writer, cuda)

        # 1. Attack and objective to be used for adversarial training
        self.attack = attack
        self.objective = objective
        # 2. For simplicity, we want the fraction to be the fraction of clean examples.
        self.fraction = 1 - fraction
        # number of test batches to evaluate adversarial robustness on
        self.max_batches = 10

        self.writer.add_text('config/attack', self.attack.__class__.__name__)
        self.writer.add_text('config/objective', self.objective.__class__.__name__)
        self.writer.add_text('config/fraction', str(fraction))
        # some more logging of hyper-parameters using self.writer ...

  1. Compared to normal training, adversarial training requires an attack and an attack objective, as described in detail in this article.
  2. fraction represents the fraction of adversarial examples to use in each mini-batch during training. For easier implementation, however, self.fraction = 1 - fraction denotes the fraction of clean examples in each mini-batch. For example, [] use 50%/50% clean/adversarial examples per mini-batch, while [] use 100% adversarial examples.
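The interplay between fraction and self.fraction is easiest to see with a few numbers. This plain-Python sketch (the helper split_batch is hypothetical) mirrors the splitting logic of Listings 3 and 4:

```python
def split_batch(batch_size: int, adversarial_fraction: float):
    """Return (#clean, #adversarial, effective clean fraction) as computed in Listings 3 and 4."""
    assert 0 < adversarial_fraction <= 1
    clean_fraction = 1 - adversarial_fraction             # self.fraction in Listing 3
    split = int(clean_fraction * batch_size)              # the first `split` examples stay clean
    effective_clean_fraction = split / float(batch_size)  # recomputed as in Listing 4
    return split, batch_size - split, effective_clean_fraction
```

For a batch of 128 and fraction=0.5, this gives 64 clean and 64 adversarial examples. With an odd batch size such as 7, int() rounds the number of clean examples down, giving 3 clean and 4 adversarial examples; this is why Listing 4 recomputes fraction from the actual split before weighting the losses.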

The code for one training epoch is listed below:

Listing 4: One epoch of adversarial training.

def train(self, epoch):
    for b, (inputs, targets) in enumerate(self.trainset):
        inputs = common.torch.as_variable(inputs, self.cuda)
        targets = common.torch.as_variable(targets, self.cuda)

        # 1. Split the batch into clean and adversarial examples
        # according to self.fraction = fraction of clean examples.
        fraction = self.fraction
        split = int(fraction*inputs.size(0))
        # update fraction for correct loss computation
        fraction = split / float(inputs.size(0))

        clean_inputs = inputs[:split]
        adversarial_inputs = inputs[split:]
        clean_targets = targets[:split]
        adversarial_targets = targets[split:]

        # 2. Compute adversarial examples in evaluation mode.
        # Note that the true labels are specified using the attack objective (e.g., maximize cross-entropy loss).
        self.model.eval()
        self.objective.set(adversarial_targets)
        adversarial_perturbations, adversarial_objectives = self.attack.run(self.model, adversarial_inputs, self.objective)
        adversarial_perturbations = common.torch.as_variable(adversarial_perturbations, self.cuda)
        adversarial_inputs = adversarial_inputs + adversarial_perturbations

        # 3. Clean and adversarial inputs are combined in one batch.
        # This is important when training with batch normalization.
        if adversarial_inputs.shape[0] < inputs.shape[0]: # fraction is not 1
            inputs = torch.cat((clean_inputs, adversarial_inputs), dim=0)
        else:
            inputs = adversarial_inputs
            # targets remain unchanged

        # 4. We switch back to training mode and perform the forward pass.
        self.model.train()
        self.optimizer.zero_grad()

        logits = self.model(inputs)
        clean_logits = logits[:split]
        adversarial_logits = logits[split:]

        # 5. Loss computation is split into clean and adversarial examples.
        # This is done mainly for monitoring them separately.
        # For the backward pass this does not make a difference.
        adversarial_loss = self.loss(adversarial_logits, adversarial_targets)
        adversarial_error = common.torch.classification_error(adversarial_logits, adversarial_targets)

        if adversarial_inputs.shape[0] < inputs.shape[0]:
            clean_loss = self.loss(clean_logits, clean_targets)
            clean_error = common.torch.classification_error(clean_logits, clean_targets)
            # fraction is the fraction of clean examples: weighting each mean loss
            # by its share of the batch gives the mean loss over the whole mini-batch
            loss = fraction * clean_loss + (1 - fraction) * adversarial_loss
        else:
            clean_loss = torch.zeros(1)
            clean_error = torch.zeros(1)
            loss = adversarial_loss

        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        # optional logging ...

  1. In each iteration, we split inputs and targets according to self.fraction.
  2. Adversarial examples are then computed on the second part of the inputs. To this end, the model is put into evaluation mode; this is mainly important so as not to accumulate batch normalization statistics during the attack. Note that the targets are set using the attack's objective. Generally, this setup allows training with various (untargeted and targeted) attacks.
  3. Clean and adversarial examples are put in a single batch. This is also important when training with batch normalization because when training with 50%/50% clean/adversarial examples, for example, the batch normalization statistics are updated based on both clean and adversarial examples in each iteration. Note that the targets remain unchanged.
  4. For the forward pass, the model is switched to training mode again. Afterwards, the predicted logits are split again into those on clean and adversarial examples for loss computation.
  5. Loss and error are computed on clean and adversarial examples separately. This is mainly done for proper monitoring since performance on clean and adversarial examples can be very different throughout training. For the backward pass, however, the overall loss on the whole mini-batch is considered.
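Since the loss uses mean reduction by default, weighting the clean and adversarial mean losses by their respective shares of the mini-batch recovers exactly the mean loss over the whole mini-batch. A quick numerical check with random logits (a toy setup, not from the repository), where fraction again denotes the fraction of clean examples:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(7, 10)
targets = torch.randint(0, 10, (7,))
split = 3  # first 3 examples clean, last 4 adversarial
fraction = split / float(logits.size(0))  # fraction of clean examples

clean_loss = F.cross_entropy(logits[:split], targets[:split])
adversarial_loss = F.cross_entropy(logits[split:], targets[split:])
combined = fraction * clean_loss + (1 - fraction) * adversarial_loss

# the weighted combination equals the mean cross-entropy over the whole mini-batch
full = F.cross_entropy(logits, targets)
print(torch.allclose(combined, full))  # True
```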

Loss and error computation are summarized below for clarity. Note that both functions handle the multi-class as well as the binary case:

import torch


def classification_loss(logits, targets, reduction='mean'):
    if logits.size(1) > 1:
        # multi-class case: softmax cross-entropy on integer class targets
        return torch.nn.functional.cross_entropy(logits, targets, reduction=reduction)
    else:
        # binary case with a single logit: probability 1 is class 1, probability 0 is class 0;
        # binary_cross_entropy_with_logits applies the sigmoid internally (numerically more stable)
        return torch.nn.functional.binary_cross_entropy_with_logits(logits.view(-1), targets.float(), reduction=reduction)


def classification_error(logits, targets, reduction='mean'):
    if logits.size(1) > 1:
        # predicted class = largest logit (softmax does not change the arg max)
        indices = torch.argmax(logits, dim=1)
    else:
        # binary case: round the sigmoid probability to class 0 or 1
        indices = torch.round(torch.sigmoid(logits)).view(-1)

    # 1 for each misclassified example, 0 otherwise
    errors = torch.clamp(torch.abs(indices.long() - targets.long()), max=1)
    if reduction == 'mean':
        return torch.mean(errors.float())
    elif reduction == 'sum':
        return torch.sum(errors.float())
    else:
        return errors

Results

Table 1: Results of adversarial training for different architectures. I report clean test error and robust test error, evaluated against PGD with 10 iterations as well as AutoAttack []. A later article will focus on proper robustness evaluation.

Model | Clean Test Error (%) | PGD Robust Test Error (%) | AA Robust Test Error (%)
WRN-28-10 | 9 | 43.8 | 47.9
SimpleNet | 14.5 | 46.7 | 51.4
ResNet-18 | 15.6 | 47.3 | 52.7

In Table 1, I summarize some results obtained with different architectures: a WRN-28-10, a ResNet-18, and the lesser known but easy-to-train SimpleNet. Besides the test error on the first 1000 clean test examples, I also report the robust test error against PGD with 10 iterations (as used during training) and AutoAttack. The latter has become the de-facto standard benchmark to evaluate robustness properly. However, I will discuss robustness evaluation in more detail in a future article.

First of all, clean test error increases significantly: the WRN-28-10 obtains 9% clean test error, whereas without adversarial training a test error below 3% is possible. The other two architectures are even worse, with 14.5% and 15.6%, respectively. On the other hand, robustness is considerably better than for a normally trained model, which usually yields close to 100% robust test error. Against the PGD attack used during training, the WRN-28-10 obtains a robust test error of 43.8%. Against AutoAttack, a more sophisticated benchmark for adversarial robustness, robust test error increases slightly to 47.9%. Compared to recent literature [][], these results look quite good given that no early stopping or additional regularizers are employed.
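The 52.1% robust accuracy quoted in the introduction is simply 100% minus the AutoAttack robust test error, and the difference between the AutoAttack and PGD columns shows how much PGD with 10 iterations alone underestimates the robust test error. A small computation on the numbers from Table 1:

```python
# Test errors in percent, copied from Table 1.
results = {
    'WRN-28-10': {'clean': 9.0, 'pgd': 43.8, 'aa': 47.9},
    'SimpleNet': {'clean': 14.5, 'pgd': 46.7, 'aa': 51.4},
    'ResNet-18': {'clean': 15.6, 'pgd': 47.3, 'aa': 52.7},
}

for model, errors in results.items():
    robust_accuracy = 100.0 - errors['aa']  # robust accuracy = 100% - robust test error
    gap = errors['aa'] - errors['pgd']      # additional errors found by AutoAttack over PGD-10
    print(f"{model}: AA robust accuracy {robust_accuracy:.1f}%, AA - PGD gap {gap:.1f} points")
```

For the WRN-28-10, this gives 52.1% robust accuracy and a gap of 4.1 percentage points; the gap is 4.7 and 5.4 points for SimpleNet and ResNet-18, respectively.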

Conclusion

Overall, adversarial training is a simple but effective strategy to improve robustness against adversarial examples. With this article, I want to provide a simple implementation that obtains good results in practice, because properly implementing adversarial training for deep networks can be rather tricky. In follow-up articles, I will discuss a variant, confidence-calibrated adversarial training [], and how to properly evaluate such models.

  • [] Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, Adrian Vladu. Towards Deep Learning Models Resistant to Adversarial Attacks. ICLR (Poster) 2018.
  • [] Sven Gowal, Chongli Qin, Jonathan Uesato, Timothy A. Mann, Pushmeet Kohli. Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples. CoRR abs/2010.03593 (2020).
  • [] Tianyu Pang, Xiao Yang, Yinpeng Dong, Hang Su, Jun Zhu. Bag of Tricks for Adversarial Training. CoRR abs/2010.00467 (2020).
  • [] David Stutz, Matthias Hein, Bernt Schiele. Relating Adversarially Robust Generalization to Flat Minima. CoRR abs/2104.04448 (2021).
  • [] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. Deep Residual Learning for Image Recognition. CVPR 2016: 770-778.
  • [] Sergey Zagoruyko, Nikos Komodakis. Wide Residual Networks. BMVC 2016.
  • [] Sergey Ioffe, Christian Szegedy. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. ICML 2015: 448-456.
  • [] Ian J. Goodfellow, Jonathon Shlens, Christian Szegedy. Explaining and Harnessing Adversarial Examples. ICLR (Poster) 2015.
  • [] David Stutz, Matthias Hein, Bernt Schiele. Disentangling Adversarial Robustness and Generalization. CVPR 2019: 6976-6987.
  • [] Francesco Croce, Matthias Hein. Reliable evaluation of adversarial robustness with an ensemble of diverse parameter-free attacks. ICML 2020: 2206-2216.
  • [] Leslie Rice, Eric Wong, J. Zico Kolter. Overfitting in adversarially robust deep learning. ICML 2020: 8093-8104.
  • [] Dongxian Wu, Shu-Tao Xia, Yisen Wang. Adversarial Weight Perturbation Helps Robust Generalization. NeurIPS 2020.
  • [] David Stutz, Matthias Hein, Bernt Schiele. Confidence-Calibrated Adversarial Training: Generalizing to Unseen Attacks. ICML 2020: 9155-9166.
What is your opinion on this article? Let me know your thoughts on Twitter @davidstutz92 or LinkedIn in/davidstutz92.