
Distal Adversarial Examples Against Neural Networks in PyTorch

Out-of-distribution examples are images that are clearly irrelevant to the task at hand. Unfortunately, deep neural networks frequently assign random labels with high confidence to such examples. In this article, I want to discuss an adversarial way of computing high-confidence out-of-distribution examples, so-called distal adversarial examples, and how confidence-calibrated adversarial training handles them.

Introduction

Confidence-calibrated adversarial training, as introduced in this previous article, not only generalizes robustness to unseen types of adversarial examples but also assigns lower confidence to out-of-distribution examples. In this article, I want to illustrate this on a particular type of out-of-distribution example: distal adversarial examples. These are essentially random noise images that are optimized to maximize confidence within a small $L_\infty$ ball [].

This article is part of a series of articles:

The code for this article can be found on GitHub:

Code on GitHub

Distal Adversarial Examples

In [], distal adversarial examples were introduced as a means to adversarially train models that are aware of arbitrary out-of-distribution examples. However, they can also be used to check that models do not assign high confidence to random images, even if the models were to pick up on some (spurious) correlations. Given a noise image $x$, the goal is to find a perturbation $\delta$ such that $x + \delta$ maximizes confidence, that is, to solve

$$\max_{\|\delta\|_\infty \leq \epsilon} \max_k \log f_k(x + \delta) = \max_{\|\delta\|_\infty \leq \epsilon} \mathcal{L}(f(x + \delta)) \tag{1}$$

where $f_k$ denotes the $k$-th predicted probability (softmax output) and $\mathcal{L}$ denotes the maximum log-probability over classes. This can be implemented using projected gradient descent (PGD), as commonly used for adversarial examples [] and described in this article, except that the objective is maximized. While [] smooth the random noise $x$ because this tends to yield high confidence, even simple uniform noise usually leads to high confidence for normally trained models.
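For reference, a common signed-gradient form of PGD applied to Equation (1) is sketched below. It uses a step size $\alpha$, iterates $\delta^{(t)}$ starting from $\delta^{(0)} = 0$, and a projection $\Pi$ that clips each entry to the $L_\infty$ ball. As an additional assumption not stated above, $\Pi$ also keeps $x + \delta$ in the image range $[0, 1]$:

```latex
\delta^{(t+1)} = \Pi\Big(\delta^{(t)} + \alpha \, \mathrm{sign}\big(\nabla_\delta \mathcal{L}(f(x + \delta^{(t)}))\big)\Big),
\qquad
\Pi(\delta)_i = \min\Big(\max\big(\delta_i, \max(-\epsilon, -x_i)\big), \min(\epsilon, 1 - x_i)\Big)
```

Because the objective is maximized, the step adds the signed gradient rather than subtracting it.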

PyTorch Implementation

Starting from a standard PGD implementation for adversarial examples, we only need to adapt the initialization of the noise image as well as the optimized objective to reflect Equation (1). Uniform random initialization is straightforward to implement, so Listing 1 only highlights the smoothing applied on top of it. The noise is filtered with a Gaussian kernel of random standard deviation $\sigma \in [1, 2]$ and then passed through a sigmoid of random steepness $\gamma \in [5, 30]$ centered at $0.5$:

Listing 1: Smoothing applied by [] on top of uniform random initialization to find distal adversarial examples.

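import numpy
import torch

import common.torch  # helpers from the article's GitHub repository (GaussianLayer, is_cuda)


# Initialization: base class of the repository's attack initializations.
class SmoothInitialization(Initialization):
    """
    Gaussian smoothing as initialization; can be used after any random initialization.
    """

    def __call__(self, images, perturbations):
        # Randomly sample smoothing strength (sigma) and contrast (gamma) for this batch.
        sigma = numpy.random.uniform(1, 2)
        gamma = numpy.random.uniform(5, 30)
        # Per-channel Gaussian filter, moved to the GPU if the perturbations live there.
        gaussian_smoothing = common.torch.GaussianLayer(sigma=sigma, channels=perturbations.size(1))
        if common.torch.is_cuda(perturbations):
            gaussian_smoothing = gaussian_smoothing.cuda()
        # Smooth, then apply a steep sigmoid centered at 0.5 that pushes values towards 0 or 1.
        perturbations.data = torch.sigmoid(gamma * (gaussian_smoothing.forward(perturbations) - 0.5))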
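Listing 1 depends on helpers from the repository (`Initialization`, `common.torch.GaussianLayer`). For experimenting outside it, the following is a self-contained approximation of uniform initialization followed by the same smoothing. The hypothetical names `gaussian_kernel2d` and `smooth_uniform_init`, the kernel radius of $3\sigma$, reflect padding at the borders, and the use of a `torch.Generator` instead of `numpy.random` are my assumptions. They are not necessarily what `GaussianLayer` does.

```python
import math

import torch
import torch.nn.functional as F


def gaussian_kernel2d(sigma, radius):
    """Normalized (2 * radius + 1) x (2 * radius + 1) Gaussian kernel."""
    coords = torch.arange(-radius, radius + 1, dtype=torch.float32)
    kernel1d = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    kernel1d = kernel1d / kernel1d.sum()
    return torch.outer(kernel1d, kernel1d)


def smooth_uniform_init(batch_size, channels, height, width, generator=None):
    """Hypothetical stand-in for uniform initialization followed by Listing 1."""
    # Uniform noise in [0, 1).
    noise = torch.rand(batch_size, channels, height, width, generator=generator)
    # Random smoothing width in [1, 2) and sigmoid steepness in [5, 30), as in Listing 1.
    sigma = 1 + torch.rand(1, generator=generator).item()
    gamma = 5 + 25 * torch.rand(1, generator=generator).item()
    # Depthwise Gaussian blur: one copy of the kernel per channel, reflect padding at borders.
    radius = int(math.ceil(3 * sigma))
    kernel = gaussian_kernel2d(sigma, radius).to(noise)
    weight = kernel.expand(channels, 1, -1, -1).contiguous()
    padded = F.pad(noise, (radius, radius, radius, radius), mode="reflect")
    smoothed = F.conv2d(padded, weight, groups=channels)
    # Steep sigmoid centered at 0.5 pushes values towards 0 or 1.
    return torch.sigmoid(gamma * (smoothed - 0.5))


# Usage: a batch of four smoothed 32x32 RGB noise images.
noise = smooth_uniform_init(4, 3, 32, 32, generator=torch.Generator().manual_seed(0))
```

Passing an explicit generator makes the sampled noise, $\sigma$, and $\gamma$ reproducible, which can help when comparing models on the same distal starting points.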

The following listing shows an implementation of the objective as a loss that is maximized to find distal adversarial examples. Like other loss implementations in PyTorch, it optionally averages or sums over the batch dimension:

Listing 2: PyTorch implementation of Equation (1) to be used during PGD.

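import torch


def max_log_loss(logits, targets=None, reduction='mean'):
    # targets is unused; the argument presumably keeps the signature compatible with (logits, targets) losses.
    # Per-example maximum log-probability, i.e., max_k log f_k(x + delta) from Equation (1).
    max_log = torch.max(torch.nn.functional.log_softmax(logits, dim=1), dim=1)[0]
    if reduction == 'mean':
        return torch.mean(max_log)
    elif reduction == 'sum':
        return torch.sum(max_log)
    else:
        # Any other reduction (e.g., 'none') returns the per-example values.
        return max_log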
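Since $\log f_k = z_k - \log \sum_j \exp(z_j)$ for logits $z$, the per-example value in Listing 2 equals the largest logit minus the log-sum-exp of all logits. Exponentiating it recovers the confidence, i.e., the largest softmax probability. The quick check below confirms both on random stand-in logits (not data from the article):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(8, 10)  # stand-in batch: 8 examples, 10 classes

# Per-example objective as in Listing 2 with reduction='none'.
max_log = torch.max(torch.nn.functional.log_softmax(logits, dim=1), dim=1)[0]

# Equivalent closed form: largest logit minus log-sum-exp over all logits.
closed_form = logits.max(dim=1).values - torch.logsumexp(logits, dim=1)

# Confidence, i.e., the largest softmax probability.
confidence = torch.softmax(logits, dim=1).max(dim=1).values
```

Because the log-sum-exp term couples all classes, increasing the top logit alone does not suffice. The other logits must also shrink relative to it for the objective to approach its maximum of $0$.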

Initialization and loss can then be plugged into the PGD implementation from this previous article. Usage is also similar, except that the attack is run on uniform random images rather than actual test images.
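The PGD implementation from the previous article is not reproduced here. As a rough, self-contained stand-in, the sketch below runs signed gradient ascent on the per-example maximum log-probability. It projects onto the $L_\infty$ ball and onto $[0, 1]$ and keeps the best iterate per example. The function name `distal_pgd`, the default $\epsilon$, step size, and iteration count, and the tiny linear model in the usage are all hypothetical choices for illustration. They are not the repository's code or the settings used for the results below.

```python
import torch
import torch.nn.functional as F


def distal_pgd(model, images, epsilon=0.3, step_size=0.01, iterations=40):
    """Hypothetical sketch: maximize max_k log f_k(x + delta) s.t. ||delta||_inf <= epsilon.

    Assumes `model` is in eval mode (no batch-dependent layers), so the gradient of the
    summed objective w.r.t. delta gives each example its own gradient.
    """
    images = images.detach()
    delta = torch.zeros_like(images, requires_grad=True)
    best_delta = torch.zeros_like(images)
    best_objective = torch.full((images.size(0),), -float("inf"), device=images.device)

    for i in range(iterations + 1):
        # Per-example objective of Equation (1), like Listing 2 with reduction='none'.
        objective = F.log_softmax(model(images + delta), dim=1).max(dim=1).values

        # Keep the best perturbation found so far for each example.
        improved = objective.detach() > best_objective
        best_objective = torch.where(improved, objective.detach(), best_objective)
        best_delta[improved] = delta.detach()[improved]
        if i == iterations:
            break

        # Signed gradient ascent step, followed by projection.
        grad, = torch.autograd.grad(objective.sum(), delta)
        with torch.no_grad():
            delta += step_size * grad.sign()
            delta.clamp_(-epsilon, epsilon)                     # L_inf ball
            delta.copy_((images + delta).clamp(0, 1) - images)  # keep x + delta in [0, 1]

    return (images + best_delta).clamp(0, 1), best_objective


# Usage on a tiny stand-in classifier and uniform noise images.
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10)).eval()
noise = torch.rand(16, 3, 8, 8)
distal, objective = distal_pgd(model, noise, epsilon=0.1, step_size=0.02, iterations=10)
with torch.no_grad():
    before = F.log_softmax(model(noise), dim=1).max(dim=1).values
```

Tracking the best iterate per example, rather than returning the last one, guarantees that the returned confidence is never lower than at the starting noise. Plain signed-gradient steps do not guarantee monotone improvement on their own.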

Results

I generated 1000 random images of size 32x32x3 to evaluate two models that I adversarially trained on CIFAR10: one using standard adversarial training (AT) [] and one using confidence-calibrated adversarial training (CCAT) []. Both models use the WRN-18-10 [] architecture. For evaluation, I fixed a confidence threshold such that 99% of the correctly classified test examples pass it (i.e., have higher confidence). This is similar to the confidence-thresholded robust test error evaluation in this article. As with normal training, we would not expect AT to assign particularly low confidence to random images, especially after explicitly maximizing confidence as done for distal adversarial examples. Accordingly, AT yields a 100% false positive rate: all of the 1000 distal adversarial examples obtain confidence above the threshold. CCAT, in contrast, assigns low confidence to all of these distal adversarial examples (0% false positive rate). This indicates that assigning low confidence to the adversarial examples seen during training extrapolates to distal adversarial examples.
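The evaluation protocol above can be sketched as follows. The threshold is chosen so that roughly 99% of correctly classified test examples have confidence (largest softmax probability) at or above it. The false positive rate is then the fraction of distal adversarial examples whose confidence reaches the threshold. The helper names and the synthetic confidences in the usage are hypothetical. They do not reproduce the article's numbers, and the exact quantile convention in the repository may differ.

```python
import torch


def confidence_threshold(test_confidences, test_correct, tpr=0.99):
    """Hypothetical helper: threshold passed by roughly a fraction `tpr` of correct test examples."""
    correct_confidences = test_confidences[test_correct]
    # The (1 - tpr)-quantile of the correct examples' confidences (linear interpolation).
    return torch.quantile(correct_confidences, 1 - tpr).item()


def false_positive_rate(distal_confidences, threshold):
    """Hypothetical helper: fraction of distal adversarial examples passing the threshold."""
    return (distal_confidences >= threshold).float().mean().item()


# Usage with synthetic confidences (not results from the article).
test_confidences = torch.linspace(0.5, 1.0, 1001)
test_correct = torch.ones(1001, dtype=torch.bool)
threshold = confidence_threshold(test_confidences, test_correct)
fpr_low = false_positive_rate(torch.full((1000,), 0.3), threshold)   # all below the threshold
fpr_high = false_positive_rate(torch.full((1000,), 0.99), threshold)  # all above the threshold
```

Fixing the threshold on correctly classified test examples, rather than on the distal examples, means that the false positive rate measures how the model separates out-of-distribution inputs at an operating point that retains almost all correct test predictions.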

Conclusion

In this article, I experimented with distal adversarial examples as introduced in []. These are random images perturbed with the intent to maximize confidence. Essentially, distal adversarial examples correspond to hard or adversarial out-of-distribution examples to which models should not assign high confidence; however, most models do. I showed that this is the case for standard adversarial training, while confidence-calibrated adversarial training avoids this problem.

  • [] David Stutz, Matthias Hein, Bernt Schiele. Confidence-Calibrated Adversarial Training: Generalizing to Unseen Attacks. ICML 2020: 9155-9166
  • [] Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, Dimitris Tsipras, Adrian Vladu. Towards Deep Learning Models Resistant to Adversarial Attacks. ICLR (Poster) 2018.
  • [] Matthias Hein, Maksym Andriushchenko, Julian Bitterwolf. Why ReLU Networks Yield High-Confidence Predictions Far Away From the Training Data and How to Mitigate the Problem. CVPR 2019: 41-50
  • [] Sergey Zagoruyko, Nikos Komodakis. Wide Residual Networks. BMVC 2016
What is your opinion on this article? Let me know your thoughts on Twitter @davidstutz92 or LinkedIn in/davidstutz92.