Try out my SAGAN implementation

This article discusses self-attention generative adversarial networks, a new technique that improves the stability of GANs and the quality of their samples. An orthogonal technique called spectral normalization is explained in this blog post.

Since generative adversarial networks were conceived four years ago, startling improvements in the quality and diversity of samples have been achieved. Much of this progress has come from stabilizing the adversarial game between the discriminator and the generator, which can collapse if the discriminator can differentiate real and fake examples too easily. Some notable papers about GAN stabilization include:

Many of the innovations above are agnostic to the domain where GANs are typically applied: image generation. In this setting, the generator is a deconvolutional neural network and the discriminator is a convolutional neural network. Typically, the discriminator is a series of local convolutional filters followed by nonlinearities, which successively downsample the image. The downsampling process terminates in a global average pooling layer, which compresses the activations into a single numeric output.
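To make this concrete, here is a minimal sketch of such a discriminator in PyTorch (the channel counts and image size are illustrative placeholders, not taken from any particular paper):

```python
import torch
import torch.nn as nn

# A typical convolutional GAN discriminator: local filters and
# nonlinearities successively downsample the image, and a global
# average pooling compresses the activations to a single score.
discriminator = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),     # 64x64 -> 32x32
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),   # 32x32 -> 16x16
    nn.LeakyReLU(0.2),
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
    nn.LeakyReLU(0.2),
    nn.Conv2d(256, 1, kernel_size=1),                         # per-position score
    nn.AdaptiveAvgPool2d(1),                                  # global average pooling
    nn.Flatten(),                                             # -> one number per image
)

score = discriminator(torch.randn(1, 3, 64, 64))  # shape (1, 1)
```

Note that each output position of the last convolution sees only a local patch of the input; the final pooling simply averages these local judgments.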

This average pooling treats distant points in the image as independent, which is certainly not the case in most natural images. For instance, suppose there is a real image in our training set

real sample

and our generator produces an output

generated sample

In this case, the generator matches the true data exactly except at a small spatial position where the generator mistakenly produces a red box. If the discriminator is using global average pooling, then the outputs of distant receptive fields—which cannot differentiate the two images—will be averaged with the outputs of the few receptive fields that do contain the red box.

generated sample

This averaging destroys the signal that the discriminator needs to differentiate the two images, making it difficult for the generator to improve further.

The independence assumption of global pooling (and of linear layers applied directly to flattened activations) makes it difficult for the discriminator to detect errors in the global configuration of features—this is one reason why it has been difficult to generate high-quality ImageNet samples.

Non-local Neural Networks (Wang et al.) address the problem of CNNs processing information too locally by introducing a self-attention mechanism, where each output activation is modulated by a subset of the other activations. This helps capture dependencies between distant parts of the image and allows the CNN to attend to smaller parts of the image if necessary.

Self-Attention Generative Adversarial Networks (SAGAN) apply this mechanism to the GAN discriminator, mitigating the problem above.

In a self-attention block, we temporarily ignore all spatial structure by flattening each channel into a single vector, after computing a local linear transformation (a 1×1 convolution) on each pixel:

\begin{equation*} o_\text{self-attention} = \underbrace{\text{softmax}(\overbrace{x^T W_f^T}^{HW\times C}\,\,\overbrace{W_g x}^{C\times HW})}_{HW \times HW} \underbrace{x^T W_h^T}_{HW \times C} \end{equation*}

where \(H\) is the height, \(W\) is the width, and \(C\) is the number of channels in the self-attention layer.
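To make the shapes concrete, here is a small sketch that evaluates this formula directly in PyTorch for a single image, with \(W_f\), \(W_g\), and \(W_h\) as plain matrices (in practice they are implemented as 1×1 convolutions; the sizes are illustrative):

```python
import torch

C, H, W = 64, 16, 16
x = torch.randn(C, H * W)  # activations flattened: one column per spatial position

# Per-pixel linear transformations (plain matrices here for clarity).
W_f, W_g, W_h = (torch.randn(C, C) for _ in range(3))

attn = torch.softmax((x.T @ W_f.T) @ (W_g @ x), dim=-1)  # (HW, HW) attention map
o = attn @ (x.T @ W_h.T)                                 # (HW, C) self-attention output

print(attn.shape, o.shape)  # torch.Size([256, 256]) torch.Size([256, 64])
```

Each row of the \(HW \times HW\) attention map is a distribution over all spatial positions, so every output position can draw on information from anywhere in the image.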

The softmax attention map weights the influence of outputs at distant positions on the output at each position in the discriminator. This allows the discriminator to explicitly compare outputs at separated points in the image and, if necessary, to focus all of its attention on a specific location.

The self-attention output is multiplied by a learnable weight parameter and added to the original output. At the start of training, this weight is zero; as the generated samples become more difficult to distinguish from the real samples, the network can increase this weight so that the discriminator makes use of self-attention.
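Putting the pieces together, a minimal PyTorch sketch of such a self-attention layer might look like this (the \(C/8\) channel reduction for the \(f\) and \(g\) projections follows the SAGAN paper; the variable names are my own):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # 1x1 convolutions implement the per-pixel linear maps W_f, W_g, W_h.
        self.f = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.g = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.h = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Learnable weight on the attention output, initialized to zero.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.shape
        # Flatten spatial dimensions: each projection becomes (B, C', H*W).
        f = self.f(x).view(B, -1, H * W)
        g = self.g(x).view(B, -1, H * W)
        h = self.h(x).view(B, -1, H * W)
        # (HW x HW) map of pairwise similarities between spatial positions.
        attn = F.softmax(torch.bmm(f.transpose(1, 2), g), dim=-1)
        # Attention-weighted sum of h outputs, restored to spatial layout.
        o = torch.bmm(h, attn.transpose(1, 2)).view(B, C, H, W)
        # Residual connection scaled by gamma, which is zero at the start
        # of training, so the layer initially behaves as the identity.
        return self.gamma * o + x
```

In the paper, this layer is inserted at intermediate feature maps of both the generator and the discriminator.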

Self-attention (taken from Zhang et al.)

The above self-attention visualizations from the SAGAN paper demonstrate how the attention units capture object shape. In most cases, the attention maps capture local feature similarities (second from left, top) as a convolutional network would. However, in some cases (first and second from right, top and bottom), the self-attention layer detects correlations between features in distant regions of the image. This explains why the global structure of SAGAN samples, seen above, looks so convincing. Below is an example of the self-attention maps from my implementation of SAGAN for the top-left-hand corner of the image.

self-attention visualization

Introducing the self-attention mechanism to GANs improves the quality of samples considerably: the authors of SAGAN report an Inception score of 52.52, an enormous improvement on the already impressive result of 36.8 achieved by spectral normalization a few months ago. This improvement in quality and diversity metrics is reflected directly in the quality of SAGAN ImageNet samples; many are indistinguishable from real images.

In my opinion, self-attention GANs reveal the potential to improve GANs by incorporating better priors into our models (in this case, the prior is to explicitly build modeling of non-local feature dependencies into some layers of the discriminator). This approach stands in contrast to the more explored approach thus far: changing the GAN training objective or regularizing the discriminator. We have seen before with LapGAN and Progressive Growing GAN that carefully designed discriminator architectures (i.e., model priors) can lead to improved sample quality and diversity. With the breakthrough of self-attention GANs, we see this line of work confirmed yet again. It is only a matter of time before further improvements based on this strategy arrive and this already impressive state of the art is beaten once again.