Simple Adversarial Transformations in PyTorch

Adversarial transformations such as small crops, rotations and translations are another alternative to regular $L_p$-constrained adversarial examples, and they are less visible than adversarial patches or frames. Similar to $L_p$ adversarial examples, adversarial transformations are often hard to spot unless the original image is available for direct comparison. In this article, I include a PyTorch implementation and some results against adversarial training.

Introduction

Adversarial transformations are maliciously transformed images that cause mis-classification, often using basic affine transformations such as rotation, translation, shear or scale. Compared to standard $L_p$-constrained adversarial examples (as obtained by PGD []), they offer fewer degrees of freedom and are considered a more realistic threat model, because basic transformations such as rotation or zoom are readily available in most image processing tools. In contrast to adversarial patches, stickers or frames [][], they are less visible, especially if the original image is not available for comparison. This article presents a simple PGD-based PyTorch implementation of adversarial transformations, limited to affine transformations, that turns out to be very effective even against adversarially trained models.

This article is part of a series:

The code for this article can be found on GitHub:

Code on GitHub

Adversarial Transformations

Standard adversarial examples, computed using PGD, maximize the cross-entropy loss $\mathcal{L}$ between the prediction for an input example $x$ and its true label $y$, allowing for a small additive perturbation $x + \delta$:

$\max_{\|\delta\|_p \leq \epsilon} \mathcal{L}(f(x + \delta), y)$

Here, $f$ is the classifier to be fooled by the adversarial example $x + \delta$, and $\delta$ is supposed to be small in its $L_p$ norm so that the change remains imperceptible to the human eye. Maximizing the cross-entropy loss will eventually cause mis-classification. Instead of an additive perturbation, we can also use a general transformation $T$:

$\max_T \mathcal{L}(f(T(x)), y)$

Of course, we still want the transformation $T$ to be rather small such that $T(x)$ remains hard to distinguish from the original $x$. While this formulation generally allows more degrees of freedom (we could set $T(x) = \gamma x + \delta$ with both $\gamma$ and $\delta$ to be chosen), it is common [] to use rather simple transformations. For example, rotation, translation, zoom, etc. are easy to apply to images using common image processing tools. In this article, we assume that $T$ is an affine transformation. The advantage is that there exist differentiable components for applying affine transformations, such as spatial transformer networks []. This allows us to differentiate through $T(x)$ in order to maximize the cross-entropy loss using standard PGD. Here, we assemble translation $[t_1, t_2]$, shear $[\lambda_1, \lambda_2]$, scale $s$ and rotation $r$ into an affine transformation matrix

$\left(\begin{matrix}\cos(r) s - \sin(r) s \lambda_1 & -\sin(r) s + \cos(r) s \lambda_1 & t_1\\ \cos(r) s \lambda_2 + \sin(r) s & -\sin(r) s \lambda_2 + \cos(r) s & t_2\end{matrix}\right) \qquad (1)$

which results in $6$ parameters to be optimized. To keep the transformation small, we can simply restrict each parameter to an interval, for example, only allowing rotations between $-5$ and $5$ degrees.
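To make Equation (1) and the interval constraints concrete, here is a minimal sketch in plain Python. The function names and the bounds are hypothetical and chosen only for illustration; they are not taken from the accompanying code. Note that the rotation enters through sine and cosine, so a limit given in degrees has to be converted to radians first.

```python
import math


def affine_matrix(t1, t2, lam1, lam2, s, r):
    """Assemble the 2 x 3 matrix of Equation (1); r is given in radians."""
    c, n = math.cos(r), math.sin(r)
    return [
        [c * s - n * s * lam1, -n * s + c * s * lam1, t1],
        [c * s * lam2 + n * s, -n * s * lam2 + c * s, t2],
    ]


def clip_to_intervals(params, bounds):
    """Clip each parameter to its allowed interval [low, high]."""
    return [min(max(p, low), high) for p, (low, high) in zip(params, bounds)]


# Hypothetical bounds: translation and shear in [-0.1, 0.1], scale in [0.9, 1.1],
# rotation in [-5, 5] degrees (converted to radians).
max_rotation = math.radians(5)
bounds = [(-0.1, 0.1)] * 4 + [(0.9, 1.1), (-max_rotation, max_rotation)]

# A too-large translation and rotation are clipped back to the allowed range.
params = clip_to_intervals([0.3, 0.0, 0.0, 0.0, 1.0, math.radians(20)], bounds)
matrix = affine_matrix(*params)

# No translation, shear or rotation and unit scale give the identity transformation.
identity = affine_matrix(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)
```

In this sketch, the scale $s$ is passed directly, so the identity corresponds to $s = 1$ with all other parameters set to zero.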

PyTorch Implementation

The core ingredient for implementing adversarial transformations is the spatial transformer network, which can easily be implemented using PyTorch's nn.functional.affine_grid and nn.functional.grid_sample. For simplicity, we implement the full transformation matrix of Equation (1). However, restricting the transformation to specific parts, e.g., only translation or rotation, is easy to achieve from this implementation:

Listing 1: Spatial transformer network in PyTorch, implementing an affine transformation using Equation (1).

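import torch


class STNDecoder(torch.nn.Module):
    def __init__(self, interpolation_mode, padding_mode):
        super(STNDecoder, self).__init__()
        # ...
        # Attributes read in forward(); the images are set later via set_images().
        self.interpolation_mode = interpolation_mode
        self.padding_mode = padding_mode
        self.images = None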

    def set_images(self, images):
        self.images = images

    def forward(self, theta):
        # 1. This is easily adapted to allow more constrained transformations by only considering
        # translations, shear, scale or rotation individually. Here, we stick to all 6 degrees of freedom.
        assert theta.size(1) == 6
        translation_x = theta[:, 0]
        translation_y = theta[:, 1]
        shear_x = theta[:, 2]
        shear_y = theta[:, 3]
        scales = 1 + theta[:, 4]
        rotation = theta[:, 5]

        # Allocate the flattened 2 x 3 matrices with the same device and dtype as theta.
        transformation = theta.new_zeros(theta.size(0), 6)

        # 2. Set up the transformation matrix as above.
        transformation[:, 0] = torch.cos(rotation) * scales - torch.sin(rotation) * scales * shear_x
        transformation[:, 1] = -torch.sin(rotation) * scales + torch.cos(rotation) * scales * shear_x
        transformation[:, 2] = translation_x
        transformation[:, 3] = torch.cos(rotation) * scales * shear_y + torch.sin(rotation) * scales
        transformation[:, 4] = -torch.sin(rotation) * scales * shear_y + torch.cos(rotation) * scales
        transformation[:, 5] = translation_y
        transformation = transformation.view(-1, 2, 3)

        # 3. Apply the spatial transformer network.
        grid = torch.nn.functional.affine_grid(transformation, self.images.size())
        output = torch.nn.functional.grid_sample(self.images, grid, mode=self.interpolation_mode, padding_mode=self.padding_mode)
        output = torch.clamp(output, min=0, max=1)
        return output
  1. For a batch of images, we assume a parameter tensor of shape batch_size x num_transformation_parameters. For simplicity, we use the full transformation parameters with num_transformation_parameters = 6 parameters. For readability, these parameters are named individually.
  2. Given the parameters, we set up the 2 x 3 transformation matrix for each image. Note that the input to forward is the parameter tensor, while the images are set as an attribute. This is because we intend to optimize over theta, not the images.
  3. The spatial transformer network is applied using the combination of nn.functional.affine_grid and nn.functional.grid_sample: The former computes a flow field corresponding to the affine transformation; the latter applies this flow field to the images. The images are then clipped to $[0, 1]$.
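To illustrate what the combination of grid generation and sampling does, the following is a simplified pure-Python analogue for a single-channel image. It is not PyTorch's implementation: it assumes nearest-neighbour interpolation, zero padding and normalized pixel-centre coordinates in $[-1, 1]$, which to my understanding corresponds to align_corners=False. The function name warp_nearest is hypothetical.

```python
def warp_nearest(image, matrix):
    """Warp a 2D list `image` with a 2 x 3 affine `matrix` (nearest neighbour, zero padding)."""
    height, width = len(image), len(image[0])
    output = [[0.0] * width for _ in range(height)]
    for i in range(height):
        for j in range(width):
            # Normalized coordinates of the output pixel centre in [-1, 1].
            x = (2 * j + 1) / width - 1
            y = (2 * i + 1) / height - 1
            # The matrix maps output coordinates to the input location to sample from.
            src_x = matrix[0][0] * x + matrix[0][1] * y + matrix[0][2]
            src_y = matrix[1][0] * x + matrix[1][1] * y + matrix[1][2]
            # Convert back to pixel indices and round to the nearest pixel.
            src_j = round(((src_x + 1) * width - 1) / 2)
            src_i = round(((src_y + 1) * height - 1) / 2)
            # Locations outside the image are filled with zeros.
            if 0 <= src_i < height and 0 <= src_j < width:
                output[i][j] = image[src_i][src_j]
    return output


image = [[float(4 * i + j) for j in range(4)] for i in range(4)]

# The identity matrix leaves the image unchanged.
unchanged = warp_nearest(image, [[1, 0, 0], [0, 1, 0]])

# A translation of 0.5 in normalized coordinates equals one pixel for width 4:
# every output pixel samples from its right neighbour, so the content moves left.
shifted = warp_nearest(image, [[1, 0, 0.5], [0, 1, 0]])
```

The important detail is the direction of the mapping: the matrix tells each output pixel where to read from in the input, so a positive translation parameter moves the image content in the opposite direction.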

Again, we follow this previous article to implement adversarial transformations. However, instead of a manually implemented PGD, this implementation allows the use of an arbitrary PyTorch optimizer:

Listing 2: Adversarial transformation implementation.

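import numpy
import torch

# Attack, common and models are assumed to come from the code accompanying this article.


class BatchAffine(Attack):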
    def __init__(self, optimizer, **kwargs):
        super(BatchAffine, self).__init__()
        self.optimizer = optimizer
        self.kwargs = kwargs
        self.thetas = None
        # More hyper-parameters ...

    def run(self, model, images, objective, writer=common.summary.SummaryWriter(), prefix=''):
        super(BatchAffine, self).run(model, images, objective, writer, prefix)
        self.thetas = torch.from_numpy(numpy.zeros((images.size(0), 6), dtype=numpy.float32))
        # Optionally initialize thetas differently.
        if common.torch.is_cuda(model):
            self.thetas = self.thetas.cuda()

        batch_size = self.thetas.size()[0]
        success_errors = numpy.ones((batch_size), dtype=numpy.float32)*1e12
        success_perturbations = numpy.zeros(images.size(), dtype=numpy.float32)

        # 1. Set up
        self.thetas = torch.autograd.Variable(self.thetas, requires_grad=True)
        optimizer = self.optimizer([self.thetas], **self.kwargs)
        decoder = models.STNDecoder(interpolation_mode=self.interpolation_mode, padding_mode=self.padding_mode)
        decoder.set_images(images)

        for i in range(self.max_iterations + 1):
            optimizer.zero_grad()

            # 2. Projection to make sure transformation is as small as we want it.
            if self.projection is not None:
                self.projection(torch.zeros_like(self.thetas), self.thetas)

            # 3. Apply spatial transformer network to obtain transformed images and classifier on top to compute the error.
            perturbed_images = decoder(self.thetas)
            output_logits = model(perturbed_images)
            error = self.c*self.norm(self.thetas) + objective(output_logits, self.thetas)

            # 4. Check if the last iteration improved the error.
            for b in range(batch_size):
                if error[b].item() < success_errors[b]:
                    success_errors[b] = error[b].cpu().item()
                    success_perturbations[b] = (perturbed_images[b] - images[b]).detach().cpu().numpy()
            if i == self.max_iterations:
                break

            # 5. Backward pass and optimizer step.
            loss = torch.sum(error)
            loss.backward()
            optimizer.step()

        return success_perturbations, success_errors
  1. The run method for computing adversarial transformations first sets up the transformation parameters theta to optimize as well as the optimizer and spatial decoder network, see Listing 1.
  2. In each iteration, we first project the transformation parameters onto the set of allowed transformations; this could be an $L_p$ projection as in this previous article.
  3. Then, we apply the spatial transformer network (using the transformation parameters as arguments) and forward the transformed images through the model/classifier to compute the error. The error could simply be the cross-entropy loss, but for modularity this is encapsulated in objective.
  4. After computing the error, we check whether we found a new optimum; if that is the case, we save the corresponding perturbation. Note that, to simplify visualization and evaluation, we return the perturbations in image space (transformed minus original images) and not the transformation parameters.
  5. Finally, if this was not the last iteration, we perform a backward pass and optimizer step to update the transformation parameters.
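The structure of this loop can be sketched on a toy problem without any model. The sketch below is not the attack itself: it swaps the PyTorch optimizer for plain gradient descent, uses a hypothetical box projection onto per-parameter intervals, and minimizes a made-up quadratic error whose unconstrained minimum lies outside the allowed box. All names are assumptions for illustration.

```python
def project(theta, bounds):
    """Clip every parameter to its allowed interval (a simple box projection)."""
    return [min(max(t, low), high) for t, (low, high) in zip(theta, bounds)]


def attack(error_and_grad, bounds, steps=50, lr=0.1):
    """Projected gradient descent that keeps track of the best parameters seen."""
    theta = [0.0] * len(bounds)
    best_error, best_theta = float("inf"), list(theta)
    for step in range(steps + 1):
        # Projection first, so the evaluated parameters are always allowed.
        theta = project(theta, bounds)
        error, grad = error_and_grad(theta)
        # Remember the best (lowest error) parameters found so far.
        if error < best_error:
            best_error, best_theta = error, list(theta)
        # The last iteration only evaluates and does not update.
        if step == steps:
            break
        theta = [t - lr * g for t, g in zip(theta, grad)]
    return best_theta, best_error


def toy_error(theta):
    """Quadratic toy error with its minimum at (1, -1), outside the box below."""
    targets = [1.0, -1.0]
    error = sum((t - c) ** 2 for t, c in zip(theta, targets))
    grad = [2 * (t - c) for t, c in zip(theta, targets)]
    return error, grad


best_theta, best_error = attack(toy_error, [(-0.5, 0.5)] * 2)
```

Because the projection happens before the forward pass, the best parameters returned always lie inside the allowed set; here they end up on the boundary of the box.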

Results

Figure 1: Qualitative examples of adversarial transformations on CIFAR10.

Table 1: Clean test error (TE) and robust test error (RTE) for different models against adversarial transformations, see text for details.

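| Model  | TE (Clean) | RTE (Transl.) | RTE (T+Shear) | RTE (Full) |
|--------|------------|---------------|---------------|------------|
| Normal | 2.5%       | 6.8%          | 24.9%         | 69.1%      |
| AT     | 9.1%       | 19.1%         | 56.9%         | 88%        |
| CCAT   | 4.5%       | 12.7%         | 41.5%         | 77.9%      |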

For evaluation, I consider three WRN-28-10 [] models trained using normal training, adversarial training (AT) and confidence-calibrated adversarial training (CCAT). As attack, I ran three different settings on the first 1000 test examples: translation only, translation + shear, and full affine transformations. Figure 1 shows some qualitative examples for the latter setting on CIFAR10. Clearly, some of these adversarial transformations are rather strong, but the main object is often still recognizable (subject to CIFAR's low resolution).

Quantitatively, Table 1 reports robust test error (RTE), i.e., the test error on adversarial transformations. As expected, more degrees of freedom make adversarial transformations stronger, leading to higher RTE (and thereby more mis-classifications). In contrast to $L_p$-constrained adversarial examples, however, RTE is not close to 100% against a normally trained model. More surprisingly, adversarially trained models (AT and CCAT) are not more robust; on the contrary, adversarial transformations are more successful against these models. This might partly be due to their higher clean test error, but it also shows that robustness to adversarial transformations is somewhat orthogonal to robustness to $L_p$ adversarial examples. However, with 100 iterations per attack and 5 restarts per example, these results might be biased by attack hyper-parameters that were not optimally tuned.
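As a rough sketch of how such a number can be computed from predictions, the snippet below counts an example as an error if it is misclassified on the clean image or under any of the attack restarts. This worst-case-over-restarts convention, and counting clean errors as robust errors, are assumptions of this sketch; the function name and the toy data are hypothetical.

```python
def robust_test_error(labels, clean_predictions, adversarial_predictions):
    """Fraction of examples misclassified on the clean input or under any restart."""
    errors = 0
    for label, clean, restarts in zip(labels, clean_predictions, adversarial_predictions):
        if clean != label or any(prediction != label for prediction in restarts):
            errors += 1
    return errors / len(labels)


# Toy data: four examples with two attack restarts each.
labels = [0, 1, 2, 3]
clean_predictions = [0, 1, 2, 0]  # the last example is already wrong without attack
adversarial_predictions = [[0, 0], [1, 2], [2, 2], [3, 3]]  # the second example is fooled once

rte = robust_test_error(labels, clean_predictions, adversarial_predictions)
```

Under this convention, the robust test error can never be lower than the clean test error, which is consistent with the ordering of the columns in Table 1.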

Conclusion

Overall, this article highlighted another variant of adversarial attacks, adversarial transformations, in addition to the $L_p$ adversarial examples and other adversarial examples discussed in previous articles. In practice, adversarial transformations are easily implemented using spatial transformer networks, which allow for transformations with varying degrees of freedom. I also showed that robustness against adversarial transformations may be somewhat orthogonal to "standard" adversarial robustness against $L_p$ adversarial examples.

  • [] Madry, Aleksander, et al. "Towards deep learning models resistant to adversarial attacks." ICLR (Poster) 2018.
  • [] Brown, Tom B., et al. "Adversarial patch." arXiv preprint arXiv:1712.09665 (2017).
  • [] Zajac, Michał, et al. "Adversarial framing for image and video classification." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 33. No. 01. 2019.
  • [] Engstrom, Logan, et al. "Exploring the landscape of spatial robustness." International conference on machine learning. PMLR, 2019.
  • [] Jaderberg, Max, Karen Simonyan, and Andrew Zisserman. "Spatial transformer networks." Advances in neural information processing systems 28 (2015).
  • [] Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).