Generalizing Adversarial Robustness with Confidence-Calibrated Adversarial Training in PyTorch

Taking adversarial training from this previous article as baseline, this article introduces a new, confidence-calibrated variant of adversarial training that addresses two significant flaws: First, trained with $L_\infty$ adversarial examples, adversarial training is not robust against $L_2$ adversarial examples. Second, it incurs a significant increase in (clean) test error. Confidence-calibrated adversarial training addresses these problems by encouraging lower confidence on adversarial examples and subsequently rejecting them.

Introduction

Since adversarial training [1] is the de-facto standard in obtaining adversarially robust models, there is an increased interest in addressing some of its limitations. I think the two most severe ones can be summarized as obtaining "narrow" adversarial robustness and inducing a robustness-accuracy trade-off. The former leads to robustness that does not generalize beyond the adversarial example type seen during training. Thus, $L_\infty$ adversarial training is commonly not very robust against other $L_p$ adversarial examples. The latter describes the observation that adversarial training based methods increase the (clean) test error quite significantly compared to normal training. For example, in this previous article, clean test error increases from roughly 2.5% to 9% for a WRN-28-10.

This article introduces a promising solution to both problems, termed confidence-calibrated adversarial training, that I worked on for this ICML'20 paper [2]. It makes a simple adjustment to the adversarial training scheme: instead of minimizing cross-entropy loss on adversarial examples, encouraging high-confidence correct predictions, the model is encouraged to lower its confidence on adversarial examples. The idea is that low confidence is easier and far more meaningful to extrapolate beyond the adversarial examples seen during training. At test time, adversarial examples are then rejected (that is, detected) based on their confidence.

This article is part of a series of articles.

We build upon the implementation of adversarial training from the previous article to perform confidence-calibrated adversarial training. Evaluation of the detection scheme, however, will follow in a later article.

The code for this article can be found on GitHub:

Code on GitHub

Confidence-Calibrated Adversarial Training

Remember that standard adversarial training solves a complex min-max optimization problem,

$\min_w \mathbb{E}_{(x,y)}[\max_{\|\delta\|_p \leq \epsilon} \mathcal{L}(f(x + \delta;w), y)]$(1)

where adversarial examples are generated on-the-fly to train on. In practice, this is done by using the projected gradient descent (PGD) attack for the inner maximization. Note that Equation (1) trains only on adversarial examples. Alternatively, to balance adversarial robustness and clean performance, this can be replaced by:

$\min_w \mathbb{E}_{(x,y)}[\mathcal{L}(f(x;w), y)] + \mathbb{E}_{(x,y)}[\max_{\|\delta\|_p \leq \epsilon} \mathcal{L}(f(x + \delta;w), y)]$(2)

Essentially, this balances the clean and adversarial cross-entropy loss — in this case equally even though other weightings are possible. In practice this can be implemented by computing both losses or only computing adversarial examples for half of each mini-batch.
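As a rough sketch (not the implementation used later in this article), the balanced objective of Equation (2) could be computed per mini-batch as follows, assuming a hypothetical pgd_attack helper that returns adversarial perturbations for the given examples:

import torch
import torch.nn.functional as F

def balanced_adversarial_loss(model, pgd_attack, inputs, targets):
    # Split the mini-batch: the first half stays clean, adversarial examples
    # are computed for the second half only.
    split = inputs.size(0) // 2
    clean_inputs, clean_targets = inputs[:split], targets[:split]
    adversarial_inputs, adversarial_targets = inputs[split:], targets[split:]

    # pgd_attack is assumed to maximize the cross-entropy loss within the epsilon-ball
    # and to return the perturbations delta.
    delta = pgd_attack(model, adversarial_inputs, adversarial_targets)

    clean_loss = F.cross_entropy(model(clean_inputs), clean_targets)
    adversarial_loss = F.cross_entropy(model(adversarial_inputs + delta), adversarial_targets)

    # Equal weighting of clean and adversarial cross-entropy loss as in Equation (2).
    return 0.5 * clean_loss + 0.5 * adversarial_loss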

Figure 1: We plot confidence in the true (blue) and an adversarial (pink) class along an adversarial direction in input space. Left: Adversarial training enforces high-confidence predictions within the $L_\infty$ ball of size 0.03. Beyond an $L_\infty$ norm of 0.03, however, adversarial examples can easily be found. Right: Confidence-calibrated adversarial training encourages low-confidence predictions on adversarial examples, which extrapolates beyond the $L_\infty$ ball of size 0.03 such that larger perturbations can be rejected, as well.

Equation (1) minimizes the cross-entropy loss on adversarial examples, see Figure 1 above. This means that the confidence on adversarial examples is intended to be 1 (that is, 100%). The question is whether high-confidence predictions on adversarial examples are meaningful in practice. Especially when thinking about larger $L_p$ perturbations, it seems unreasonable to expect the model to extrapolate high-confidence predictions to arbitrarily "far away" adversarial examples. This holds in particular for the extreme case of out-of-distribution examples.

Instead of enforcing high confidence on adversarial examples (with correct predictions), the model could also be biased towards lower confidence (while preserving the bias towards correct predictions). This is exactly what confidence-calibrated adversarial training does. Instead of minimizing cross-entropy loss with the true (one-hot) label, a new target distribution

$\tilde{y} = \lambda(\delta) \text{one_hot}(y) + (1 - \lambda(\delta))\frac{1}{K}$(3)

is computed, where $K$ denotes the number of classes. The idea is to design $\lambda(\delta)$ in a way that $\tilde{y}$ becomes more and more uniform for large perturbations $\delta$. So, $\lambda(\delta)$ essentially models the transition from one-hot to uniform distribution depending on the perturbation $\delta$ of the computed adversarial example. There are various options for this, but for simplicity I will only highlight a power-like transition:

$\lambda(\delta) = \left(1-\min\left(1, \frac{\|\delta\|_\infty}{\epsilon}\right)\right)^\rho$.(4)

Figure 2: Illustration of the transition function $\lambda(\delta)$ from Equation (4). The true class is indicated in blue while the probabilities of all other classes are shown in gray.

Illustrated in Figure 2, this transition has the advantage that the model is still forced to predict with high confidence if there is no perturbation, $\delta = 0$, while resorting to a completely uniform distribution $\tilde{y} = \frac{1}{K}$ if $\|\delta\|_\infty = \epsilon$ where $\epsilon$ denotes the maximum allowed adversarial perturbation size used during training.
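To make this concrete, here is a small numerical example of Equations (3) and (4); the values for $K$, $\epsilon$ and $\rho$ are chosen for illustration only (the experiments below use $\rho = 12$):

import torch

K, epsilon, rho = 10, 0.03, 1  # illustrative values
y = 3  # true class

for norm in [0.0, 0.015, 0.03]:
    lam = (1 - min(1.0, norm / epsilon)) ** rho  # Equation (4)
    # Equation (3): interpolate between one-hot and uniform distribution.
    target = lam * torch.nn.functional.one_hot(torch.tensor(y), K).float() + (1 - lam) / K
    print(norm, lam, target[y].item())
    # norm = 0.0   -> lambda = 1.0, target[y] = 1.0  (one-hot)
    # norm = 0.015 -> lambda = 0.5, target[y] = 0.55
    # norm = 0.03  -> lambda = 0.0, target[y] = 0.1  (uniform)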

Following Equation (2), the learning problem can then be written as

$\min_w \mathbb{E}_{(x,y)}[\mathcal{L}(f(x;w), y)] + \mathbb{E}_{(x,y)}[\mathcal{L}(f(x + \delta;w), \tilde{y})]$.(5)

Here, I explicitly got rid of the inner maximization problem. This is because maximizing the cross-entropy loss to find adversarial examples is just not meaningful anymore. Maximizing the cross-entropy usually does not result in maximum confidence in the adversarial class; it just makes sure that the true class receives extremely low confidence. But because confidence-calibrated adversarial training encourages low confidence on adversarial examples, a natural attack against it is to maximize the (adversarial) confidence. A simple objective to do this can be written as

$\max_{\|\delta\|_\infty \leq \epsilon} \max_{j\neq y} f_j(x + \delta;w)$(6)

where $y$ denotes the true label of the example $x$. Essentially, this optimizes the maximum confidence in any other class.
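The following toy example illustrates why Equation (6) is the more natural attack objective here: both sets of logits below assign (almost) zero probability to the true class and thus have a high cross-entropy loss, but only one of them corresponds to a high-confidence adversarial example.

import torch
import torch.nn.functional as F

y = 0  # true class of a 4-class toy problem
logits_spread = torch.tensor([[-10.0, 1.0, 1.0, 1.0]])   # remaining probability mass spread over classes 1-3
logits_peaked = torch.tensor([[-10.0, 10.0, 1.0, 1.0]])  # remaining probability mass concentrated on class 1

for logits in [logits_spread, logits_peaked]:
    probabilities = F.softmax(logits, dim=1)
    cross_entropy = F.cross_entropy(logits, torch.tensor([y]))
    adversarial_confidence = probabilities[0, 1:].max()
    print(cross_entropy.item(), adversarial_confidence.item())
# Cross-entropy is high in both cases (roughly 12 and 20), but the maximum confidence
# in a wrong class, i.e., the objective from Equation (6), is only 0.33 vs. almost 1.0.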

Overall, confidence-calibrated adversarial training can be summarized as follows. In each iteration, let $x_b$ be an example of the current mini-batch with corresponding label $y_b$, then:

  • For the first half of the batch:
    • Find adversarial examples $x_b + \delta_b$ by maximizing confidence using Equation (6).
    • Compute the transition function $\lambda(\delta_b)$ using Equation (4).
    • Determine the target distribution $\tilde{y}_b$ using Equation (3).
  • Compute the loss $\mathcal{L}(f(x_b + \delta_b; w), \tilde{y}_b)$ on the first half of the batch and $\mathcal{L}(f(x_b;w),y_b)$ on the other half.
  • Update the weights $w$ according to the gradients.

Figure 1 (right) shows how the transition function $\lambda(\delta)$ with $\rho = 12$ is approximated in practice. Note that the transition is not as sharp, but does ultimately resort to uniform prediction for large $\|\delta\|_\infty$.

PyTorch Implementation

The following implementation builds on the adversarial training code from this previous article.

Compared to vanilla adversarial training, confidence-calibrated adversarial training takes two additional arguments: First, it requires a slightly different loss. This is because PyTorch's default cross-entropy loss expects class labels as targets and cannot work with a target distribution like $\tilde{y}$ in Equation (3). Second, we have to define the transition function $\lambda(\delta)$. Note that we already assume an adversarial training implementation that allows a parameter fraction which determines the part of each mini-batch to compute adversarial examples for, as described for Equation (2) above.

These additional arguments are reflected in the corresponding constructor:

Listing 1: Constructor of confidence-calibrated adversarial training, taking two additional arguments: loss and transition.

class ConfidenceCalibratedAdversarialTraining(AdversarialTraining):
    def __init__(self, model, trainset, testset, optimizer, scheduler, attack, objective, loss, transition, fraction=0.5, augmentation=None, writer=common.summary.SummaryWriter(), cuda=False):
        super(ConfidenceCalibratedAdversarialTraining, self).__init__(model, trainset, testset, optimizer, scheduler, attack, objective, fraction, augmentation, loss, writer, cuda)

        # 1. Loss function that computes a loss between a predicted and a target distribution.
        # The target distribution can be one-hot encoded but does not have to be.
        self.loss = loss
        """ (callable) Loss. """

        # 2. The transition function lambda(delta).
        self.transition = transition
        """ (callable) Transition. """

        self.writer.add_text('config/loss', self.loss.__name__)
        self.writer.add_text('config/transition', self.transition.__name__)
  1. The loss will be a callable that can compute a loss between the predicted logits and a target distribution (which can be one-hot).
  2. The transition, e.g., the power one from Equation (4), is also a callable taking in adversarial perturbations and computing, per example, the weight $\lambda$.

As loss, standard cross-entropy will be used. But, as mentioned above, it needs to be implemented separately to allow soft target distributions such as $\tilde{y}$ from Equation (3). The following implementation mirrors PyTorch's cross-entropy loss to keep it simple:

Listing 2: Cross-entropy divergence, which is used as loss between the predicted logits and the target distribution.

def cross_entropy_divergence(logits, targets, reduction='mean'):
    assert len(list(logits.size())) == len(list(targets.size()))
    assert logits.size()[0] == targets.size()[0]
    assert logits.size()[1] == targets.size()[1]
    assert logits.size()[1] > 1

    divergences = torch.sum(- targets * torch.nn.functional.log_softmax(logits, dim=1), dim=1)
    if reduction == 'mean':
        return torch.mean(divergences)
    elif reduction == 'sum':
        return torch.sum(divergences)
    else:
        return divergences

Note that the above implementation directly works on the predicted logits and uses PyTorch's log_softmax to avoid computing the log of values that are zero or close to zero. It also allows computing the loss per example.
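As a quick sanity check (a small sketch, not part of the training code), the divergence reduces to the expected value of $\log K$ when the predicted distribution is uniform, independent of the target distribution:

import torch

logits = torch.zeros(1, 10)  # uniform prediction over K = 10 classes

one_hot_target = torch.nn.functional.one_hot(torch.tensor([3]), 10).float()
uniform_target = torch.full((1, 10), 0.1)

# Both print log(10), roughly 2.3026.
print(cross_entropy_divergence(logits, one_hot_target).item())
print(cross_entropy_divergence(logits, uniform_target).item())

With the loss in place, we can next implement the transition function from Equation (4):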

Listing 3: The transition function from Equation (4). Note that the implementation returns $1 - \lambda(\delta)$, that is, the weight of the uniform distribution, together with the perturbation norms.

def power_transition(perturbations, norm, epsilon=0.3, rho=1):
    # norm is a callable computing per-example norms of the given perturbations.
    norms = norm(perturbations)
    # Return 1 - lambda(delta) from Equation (4), i.e., the weight of the uniform
    # distribution in Equation (3), together with the per-example norms.
    return 1 - torch.pow(1 - torch.min(torch.ones_like(norms), norms / epsilon), rho), norms
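A quick usage sketch shows what the function returns; the norm argument is a callable computing per-example $L_\infty$ norms (the helper below is an assumption, not taken from the repository), and the returned values correspond to $1 - \lambda(\delta)$:

import torch

def linf_norm(perturbations):
    # Per-example L_infinity norm of the perturbations.
    return torch.max(torch.abs(perturbations.view(perturbations.size(0), -1)), dim=1)[0]

perturbations = torch.zeros(3, 3, 32, 32)
perturbations[1] += 0.015
perturbations[2] += 0.03

gammas, norms = power_transition(perturbations, linf_norm, epsilon=0.03, rho=1)
print(norms)   # tensor([0.0000, 0.0150, 0.0300])
print(gammas)  # tensor([0.0000, 0.5000, 1.0000]), i.e., 1 - lambda(delta)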

Before having a look at the actual training loop, we first need to implement a new objective for the PGD attack used by vanilla adversarial training. This essentially follows Equation (6):

Listing 4: Objective used to compute high-confidence adversarial examples.

class UntargetedF7PObjective(Objective):
    def __call__(self, logits, perturbations=None):
        assert self.true_classes is not None
        if logits.size(1) > 1:
            current_probabilities = torch.nn.functional.softmax(logits, dim=1)
            current_probabilities = current_probabilities * (1 - common.torch.one_hot(self.true_classes, current_probabilities.size(1)))
            return - torch.max(current_probabilities, dim=1)[0]
        else:
            return self.true_classes.float()*(-1 + torch.nn.functional.sigmoid(logits.view(-1))) + (1 - self.true_classes.float())*(-torch.nn.functional.sigmoid(logits.view(-1)))

Note that this already supports both binary and multiclass problems. In order to compute the maximum confidence in a class different from the true one, the softmax probabilities are first multiplied by $1 - \text{one_hot}(y)$. This makes sure that the probability of the true class is not considered while keeping the objective differentiable. This objective can then be used with the PGD implementation from this previous article; note that this implementation already includes momentum and backtracking, which are quite useful for confidence-calibrated adversarial training.
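To see how the objective behaves, consider the following small sketch; it assumes the Objective interface (in particular its set method) and common.torch.one_hot from the previous article's code, and the sign convention suggests that the attack minimizes the returned objective in order to maximize the adversarial confidence:

import torch

objective = UntargetedF7PObjective()
objective.set(torch.tensor([0, 0]))  # true class 0 for both examples

logits = torch.tensor([[5.0, 1.0, 1.0],    # confident, correct prediction
                       [1.0, 5.0, 1.0]])   # confident, adversarial prediction
print(objective(logits))
# Roughly tensor([-0.0177, -0.9647]): the objective is the negative maximum confidence
# in any non-true class, so smaller values correspond to "better" adversarial examples.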

Now, the training loop is pretty basic when abstracting away any logging:

Listing 5: The main training loop of confidence-calibrated adversarial training.

def train(self, epoch):
    # In contrast to vanilla adversarial training, the following will not train
    # reasonable models when _only_ using adversarial examples, so we need self.fraction to be < 1.
    assert self.fraction < 1

    for b, (inputs, targets) in enumerate(self.trainset):
        if self.augmentation is not None:
            inputs = self.augmentation.augment_images(inputs.numpy())

        inputs = common.torch.as_variable(inputs, self.cuda)
        targets = common.torch.as_variable(targets, self.cuda)
        distributions = common.torch.one_hot(targets, self.model.N_class)

        split = int(self.fraction * inputs.size()[0])
        # update fraction for correct loss computation
        fraction = split / float(inputs.size(0))

        # We split the mini-batch into two parts for clean and adversarial loss computation.
        clean_inputs = inputs[:split]
        adversarial_inputs = inputs[split:]
        clean_targets = targets[:split]
        adversarial_targets = targets[split:]
        clean_distributions = distributions[:split]
        adversarial_distributions = distributions[split:]

        # 1. Computing adversarial examples is done the same way as for standard adversarial training.
        # The only difference will be the used objective.
        # Note that the model is set to evaluation mode if batch normalization is included.
        self.model.eval()
        self.objective.set(adversarial_targets)
        adversarial_perturbations, adversarial_objectives = self.attack.run(self.model, adversarial_inputs, self.objective)
        adversarial_perturbations = common.torch.as_variable(adversarial_perturbations, self.cuda)
        adversarial_inputs = adversarial_inputs + adversarial_perturbations

        # 2. This is the key element: computing the transition, i.e., the trade-off between one-hot and
        # uniform target distribution.
        # Note that gamma corresponds to 1 - lambda(delta) here, i.e., the weight of the uniform
        # distribution returned by the transition function (lambda is a reserved keyword in Python).
        gamma, adversarial_norms = self.transition(adversarial_perturbations)
        gamma = common.torch.expand_as(gamma, adversarial_distributions)

        # 3. Compute the target distribution accordingly.
        adversarial_distributions = adversarial_distributions*(1 - gamma)
        adversarial_distributions += gamma*torch.ones_like(adversarial_distributions)/self.model.N_class

        inputs = torch.cat((clean_inputs, adversarial_inputs), dim=0)

        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(inputs)
        clean_logits = logits[:split]
        adversarial_logits = logits[split:]

        # 4. Loss computation also follows standard adversarial training only that a 
        # different loss is used.
        adversarial_loss = self.loss(adversarial_logits, adversarial_distributions)
        adversarial_error = common.torch.classification_error(adversarial_logits, adversarial_targets)

        clean_loss = self.loss(clean_logits, clean_distributions)
        clean_error = common.torch.classification_error(clean_logits, clean_targets)
        loss = (1 - fraction) * clean_loss + fraction * adversarial_loss

        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        
        # optional logging ...

Note that most of the loop closely follows standard adversarial training:

  1. Adversarial examples are computed as before, only that a different objective is used, see Listing 4 above, where adversarial confidence is maximized instead of cross-entropy loss.
  2. Given the adversarial perturbations $\delta$, we first compute the transition for each adversarial example individually. Note that the result is called gamma in the code, as lambda is a reserved keyword, and corresponds to $1 - \lambda(\delta)$, that is, the weight of the uniform distribution.
  3. Given the value of gamma, the target distribution $\tilde{y}$ is computed by interpolating between a one-hot and a uniform distribution depending on gamma.
  4. Finally, loss computation also follows vanilla adversarial training, only that a different loss is used as discussed above.
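Putting everything together, training could be started roughly as follows; model, datasets, optimizer, scheduler and attack are assumed to be set up as in the previous article, and the hyper-parameters shown are examples rather than the exact values used for the results below:

import functools
import torch

def linf_norm(perturbations):
    # Per-example L_infinity norm of the perturbations.
    return torch.max(torch.abs(perturbations.view(perturbations.size(0), -1)), dim=1)[0]

transition = functools.partial(power_transition, norm=linf_norm, epsilon=0.03, rho=12)

trainer = ConfidenceCalibratedAdversarialTraining(
    model, trainset, testset, optimizer, scheduler, attack,
    objective=UntargetedF7PObjective(),
    loss=cross_entropy_divergence,
    transition=transition,
    fraction=0.5,
    cuda=True)

epochs = 100  # example value
for epoch in range(epochs):
    trainer.train(epoch)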

Results

Figure 3: Histogram plot of confidences for adversarial examples computed against the first 1000 test examples on CIFAR10.

Figure 3 shows what happens when using confidence-calibrated adversarial training: when computing adversarial examples, they predominantly achieve low confidence. Specifically, the above figure shows the confidence for 1000 attacked test examples on CIFAR10. As can be seen, most of them, roughly 800, obtain confidence significantly below 0.5. Nevertheless, several adversarial examples still obtain very high confidence, roughly 200 of them.

Note that this only considers the confidence, calculated as the maximum probability across all classes. These adversarial examples might still switch the label (that is, they are still "adversarial"). In fact, simply running standard PGD against a model trained using confidence-calibrated adversarial training usually obtains very high robust test errors. For example, a WRN-28-10 trained with 40 iterations of the attack during training yields a robust test error of 80.16%. So it seems that the model is not robust at all. But Figure 3 shows that most of these adversarial examples can easily be rejected based on their confidence; this is what confidence-calibrated adversarial training and the next articles in this series are all about.
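For reference, the confidences plotted in Figure 3 can be obtained along the following lines (a small sketch; the full evaluation and detection pipeline follows in a later article):

import torch

# adversarial_inputs are assumed to hold the attacked test examples, i.e., the clean
# inputs plus the perturbations returned by the attack.
model.eval()
with torch.no_grad():
    probabilities = torch.nn.functional.softmax(model(adversarial_inputs), dim=1)
    confidences = torch.max(probabilities, dim=1)[0]  # confidence = maximum class probability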

Besides assigning low confidence to adversarial examples, confidence-calibrated adversarial training also obtains very good clean performance: the model from above obtains a 4.56% clean test error. This is an increase of only roughly 2% compared to a normally trained model, while still being 4.6% lower than standard adversarial training.

Conclusion

Overall, confidence-calibrated adversarial training is a very simple and intuitive variant of adversarial training that aims to reduce a model's confidence on adversarial examples. While this does not prevent adversarial examples from being found, most of them can be rejected based on their confidence. In the next articles, we will see that this improves robustness against types of adversarial examples not seen during training. Moreover, using a WRN-28-10, the clean performance also improved significantly over standard adversarial training.

  • [1] Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, Adrian Vladu. Towards Deep Learning Models Resistant to Adversarial Attacks. ICLR (Poster) 2018.
  • [2] David Stutz, Matthias Hein, Bernt Schiele. Confidence-Calibrated Adversarial Training: Generalizing to Unseen Attacks. ICML 2020: 9155-9166.
What is your opinion on this article? Let me know your thoughts on Twitter @davidstutz92 or LinkedIn in/davidstutz92.