This post explores the theory and some follow-up papers of one of the most influential machine learning papers: Generative Adversarial Networks (GANs). Unlike many other deep learning models, generative models are, in my view, supported by rigorous mathematics that is still easily digestible.

As we understand the theory behind GANs (mostly through the paper by Arjovsky and Bottou, 2017), we will recognize its limitations and the reason behind its instability. This naturally leads us to Wasserstein Generative Adversarial Networks (WGANs), which apply useful concepts from Optimal Transport (OT).

Since these papers are quite “old” by machine learning standards, there are many blog posts that already discuss them. The following is my own understanding plus some parts I find missing in other posts. A Kaggle notebook is also provided to compare how the algorithms learn on toy examples: Kaggle Notebook Link

GANs

Cool History

The story is that GAN was conceived at a bar in Montréal (Les 3 Brasseurs) by Ian Goodfellow during his Ph.D. studies. At the bar, he proposed GAN to his friends, but quickly faced skepticism.

The idea of jointly training a pair of networks against each other seemed too difficult when training just one network was difficult enough. Perhaps due to the influence of alcohol, he was still confident, so he headed home, coded up GAN, and produced amazing results on MNIST (a handwritten digit recognition dataset).

As of today, the seminal paper, published at NeurIPS, has over 75,000 citations and has become a cornerstone of unsupervised learning.

GAN Architecture (Goodfellow et al., 2014)

GANs learn the latent distribution of the unlabeled data through a pair of networks (the generator and the discriminator) that are competing against each other. An analogy is used in the original paper: the generator is a team of counterfeiters that produces fake currency, while the discriminator is the police trying to detect forgery. When both networks are optimal, the generator is able to produce counterfeit currency that is indistinguishable from the real ones.

Figure 1: A GAN consists of two components: the discriminator $D$ outputs the probability that a given sample is real, and the generator $G$ produces synthetic samples given a latent variable $z$ sampled from the base distribution, e.g. $z \sim \mathcal{N}(0,1)$. The discriminator parameters $\theta_d$ are updated to assign high probability to real samples and low probability to synthetic samples, while the generator parameters $\theta_g$ are updated to "fool" the discriminator into assigning high probabilities to the synthetic samples.

First, let’s define $p_z, p_g, p_r$ as the probability distributions of the latent variable $z$, the generator, and real samples respectively.

The discriminator $D(x,\theta_d)\to [0, 1]$ outputs a scalar that represents the probability that $x$ came from $p_r$ rather than $p_g$. The generator $G(z, \theta_g)$ then learns a mapping from $z$ to the data space; in other words, it induces a distribution $p_g$ over the data $x$ that should approximate $p_r$.

$D$ is trained in a binary classification fashion. It maximizes the probability of assigning the correct label to both real samples $x$ and fake samples $x’ = G(z)$. This is equivalent to maximizing $$\max_D \mathbb{E}_{x\sim p_r(x)}\left[\log (D(x))\right] +\mathbb{E}_{z\sim p_z(z)}\left[\log(1- D(G(z)))\right].$$

The generator then minimizes the probability of the discriminator assigning the correct label to the fake samples: $$\min_G \mathbb{E}_{z\sim p_z(z)}\left[\log(1- D(G(z)))\right].$$ In other words, they are playing a minimax game with the following loss function: $$ \begin{equation}\label{eq: gan_loss} \min_{G}\max_{D} L(G, D) = \mathbb{E}_{x\sim p_r(x)}\left[\log (D(x))\right] + \mathbb{E}_{z\sim p_z(z)}\left[\log(1- D(G(z)))\right] \end{equation} $$

How to train GANs

Since the discriminator and the generator are playing a two-player non-cooperative game, training them can be difficult and is akin to finding a Nash equilibrium (Salimans et al., 2016).

In practice, this is traditionally done with the following pipeline: in each training loop, fix the generator and update the discriminator for $k$ steps, then fix the discriminator and update the generator for 1 step.

To update the discriminator, we fix the generator and isolate the $\max_{D} L(G, D)$ part of Equation $\eqref{eq: gan_loss}$: $$ \begin{aligned} L(\theta_d) & = \mathbb{E}_{x\sim p_r(x)}\left[\log (D(x, \theta_d))\right] + \mathbb{E}_{z\sim p_z(z)}\left[\log(1- D(G^*(z), \theta_d))\right] \\ & \approx \frac{1}{m}\sum_{i=1}^m\left[\log (D(x_i, \theta_d))\right] + \frac{1}{n}\sum_{j=1}^n\left[ \log( 1 - D ( G^* (z_j), \theta_d)) \right], \end{aligned} $$ where $m$ represents the number of real examples, $n$ represents the number of fake examples, and $G^*(z)$ represents the fixed generator.
To update the generator, we use the following loss function: $$ \begin{aligned} L(\theta_g) & = \left(\mathbb{E}_{x\sim p_r(x)}\left[\log (D^*(x))\right] + \mathbb{E}_{z\sim p_z(z)}\left[\log(1- D^*(G(z, \theta_g)))\right]\right) \\ & \approx \frac{1}{n}\sum_{j=1}^n\left[ \log ( 1 - D^* (G(z_j, \theta_g)))\right]. \end{aligned} $$ Note that we drop the first term since it does not depend on $\theta_g$.

The Optimal Discriminator and Generator

It is intuitive that, in theory, the perfect generator replicates the data distribution, in other words, $p_g = p_r$ (see proof below). It is natural then to think that the perfect discriminator can always differentiate between real and synthetic samples; in theory, however, the best discriminator facing the perfect generator can do no better than random guessing.

Lil’Log provides a great breakdown of the math behind this. The general idea is that if we fix the generator, take the derivative of the discriminator’s loss function, and set it to zero, we get $D^*(x) = \frac{p_r(x)}{p_r(x)+p_g(x)}$. And since the perfect generator yields $p_g = p_r$, $D^*(x) = \frac{1}{2}$.
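To spell this out: with $G$ fixed (and writing the expectation over $z\sim p_z$ as one over $x \sim p_g$), the loss becomes an integral over $x$, $$ L(G, D) = \int_x \left[p_r(x)\log (D(x)) + p_g(x)\log(1-D(x))\right]dx, $$ which we can maximize pointwise. Setting the derivative of the integrand with respect to $D(x)$ to zero gives $$ \frac{p_r(x)}{D(x)} - \frac{p_g(x)}{1-D(x)} = 0 \iff D^*(x) = \frac{p_r(x)}{p_r(x)+p_g(x)}. $$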

Given $G^*(z, \theta_g)$ such that $p_r = p_g$, and $D^*(x, \theta_d)$ such that $D^*(x) = \frac{1}{2}$, we can derive the global minimum of the loss function: $$ \begin{aligned} L(G^*, D^*) &= \mathbb{E}_{x\sim p_r(x)}\left[-\log2\right] +\mathbb{E}_{x\sim p_g(x)}\left[-\log2\right]\\ &= -2\log2 \end{aligned} $$ Thus, the best possible value of $L(G, D)$ is $-2\log2$.

Tangent on KL & JS Divergence

Kullback-Leibler (KL) divergence is one of the most commonly used measures in machine learning for comparing two probability distributions. Given distributions $p, q$, it measures how different $q$ is from $p$ with the following formulation: $$ KL(p\Vert q) = \int_{x}p(x)\log\frac{p(x)}{q(x)}dx. $$ Usually, $p$ is the true distribution, and $q$ is the model. The interpretation from information theory is that KL divergence (also called relative entropy) measures the extra information needed to encode samples from $p$ using a code optimized for $q$.

There are many limitations to KL divergence, one of the most notable being its asymmetry: in general, $KL(p\Vert q)\neq KL(q\Vert p)$. If we explore the Wikipedia page for KL divergence, we will find that the Jensen-Shannon (JS) divergence addresses this issue by averaging the KL divergence of each distribution from their mixture: $$ JSD(p\Vert q) = \frac{1}{2}KL(p\Vert\frac{p+q}{2}) + \frac{1}{2}KL(q\Vert\frac{p+q}{2}). $$
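As a quick sanity check on these definitions, here is a minimal NumPy sketch (my own illustration, not from the notebook) that computes both divergences for discrete distributions; it confirms that KL is asymmetric while JSD is symmetric and bounded by $\log2$:

import numpy as np

def kl(p, q):
    # KL(p || q) for discrete distributions; only sums where p > 0
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

def jsd(p, q):
    m = (p + q) / 2  # mixture distribution
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = np.array([0.9, 0.05, 0.05])
q = np.array([1 / 3, 1 / 3, 1 / 3])
print(kl(p, q), kl(q, p))    # asymmetric: the two values differ
print(jsd(p, q), jsd(q, p))  # symmetric, and never exceeds log(2)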

With this tool in hand, we can understand the loss function from a different perspective and prove that the perfect generator $p_g$ replicates the data distribution $p_r$. Recall that the optimal discriminator is $D^*(x) = \frac{p_r(x)}{p_r(x)+p_g(x)}$; if we plug this into Equation $\eqref{eq: gan_loss}$, we get: $$ \begin{aligned} L(G, D^{*}) &= \mathbb{E}_{x\sim p_r(x)}\left[\log (D^*(x))\right] + \mathbb{E}_{x\sim p_g(x)}\left[\log(1- D^*(x))\right]\\ &= \int_{x}p_r(x)\log\left(\frac{2\cdot p_r(x)}{2\cdot(p_r(x)+p_g(x))}\right)dx\\ &\phantom{=} + \int_{x}p_g(x)\log\left(\frac{2\cdot p_g(x)}{2\cdot(p_r(x)+p_g(x))}\right)dx\\ &= -2\log2 + \int_{x}p_r(x)\log\left(\frac{p_r(x)}{\frac{p_r(x)+p_g(x)}{2}}\right)dx\\ &\phantom{=} + \int_{x}p_g(x)\log\left(\frac{p_g(x)}{\frac{p_r(x)+p_g(x)}{2}}\right)dx\\ &= -2\log2 + KL(p_r\Vert\frac{p_r+p_g}{2}) + KL(p_g\Vert\frac{p_r+p_g}{2})\\ &= -2\log2 + 2JSD(p_r\Vert p_g). \end{aligned} $$ Thus, when the discriminator is optimal, the loss function is equivalent to measuring the difference between $p_g$ and $p_r$ with JS divergence. Also, when the generator is optimal, $p_r = p_g$, so $2JSD(p_r\Vert p_g)=0$ and $L(G^*, D^*) = -2\log2$, which is the global optimum.

The Real Cost Function (Arjovsky & Bottou, 2017)

Figure 2: Low dimensional manifolds in a high dimensional space can hardly overlap.

We have proved that with the optimal discriminator, the cost is $-2\log2+2JSD(p_r\Vert p_g)$, and with the optimal generator, this cost drops to $-2\log 2$. However, in practice, the discriminator’s training loss approaches $0$. This implies that $JSD(p_r\Vert p_g)=\log2$, i.e. the JS divergence between the two distributions is maxed out.

This is because the supports of $p_r$ and $p_g$ lie on low dimensional manifolds. To provide some intuition, consider the data space $\mathcal{X}$ which contains all possible real-world images. Although the dimension of $\mathcal{X}$ can be artificially high (e.g. $256\times256\times3$), there are pre-existing restrictions placed by nature. For instance, a mammal face usually contains eyes, ears, a nose, a mouth, and a jaw. Thus, the support of $p_r$ is concentrated on a low dimensional manifold. Similarly, since the generator $G(z, \theta_g): \mathcal{Z} \to \mathcal{X}$ estimates the data distribution by sampling from a prior $p_z$ of much lower dimension (e.g. 100) than $\mathcal{X}$, the support of $p_g$ also lies on a low dimensional manifold.

Recall that if $X: \Omega\to\mathbb{R}^n$ is a random variable with density $p$, then the support of $X$ is defined as the closure of the set $\mathrm{Supp}(X) = \{x\in\mathbb{R}^n: p(x)>0\}$. For instance, in our GAN setting, the support of $p_g$ has to be contained in the range of $G(z, \theta_g)$. In other words, the support of $p_g$ is the set of points in the data space that the generator can produce with non-zero probability.

Denote $\mathcal{M}$ as the manifold where the support of $p_r$ lies in, and $\mathcal{P}$ the manifold where the support of $p_g$ lies in. It can be shown that:

  1. If $\mathcal{M}$ and $\mathcal{P}$ don’t perfectly align and don’t have full dimension, then there exists a perfect discriminator $D^{**}: \mathcal{X}\to[0,1]$ that takes the value $1$ on a set containing the support of $p_r$ and the value $0$ on a set containing the support of $p_g$. Moreover, $\nabla_xD^{**}(x)=0$ on these sets.
  2. If $\mathcal{M}$ and $\mathcal{P}$ don’t perfectly align and don’t have full dimension, then \begin{aligned} JSD(p_r\Vert p_g) &= \log2\\ KL(p_r\Vert p_g) &= +\infty\\ KL(p_g\Vert p_r) &= +\infty. \end{aligned}
  3. Let $J_{\theta_g} G(z,\theta_g)$ denote the Jacobian of $G$ with respect to $\theta_g$. If the perfect discriminator $D^{**}$ from 1 exists, $\lVert D-D^{**}\rVert<\epsilon$, and $\mathbb{E}_{z\sim p(z)}\left[\lVert J_{\theta_g} G(z,\theta_g)\rVert^2_2\right]\leq C^2$, then $$ \left\lVert\nabla_{\theta_g}\mathbb{E}_{z\sim p(z)}\left[\log(1-D(G(z,\theta_g)))\right]\right\rVert_2 < C\frac{\epsilon}{1-\epsilon}. $$

We will go through each item in the following paragraphs.
Note that the perfect discriminator $D^{**}$ is different from the optimal discriminator discussed above, which assigns equal probability to real and fake samples. In fact, this difference is precisely the reason why, in practice, the cost of the discriminator approaches $0$. In conjunction with the result in 2, we understand that although the generator produces realistic-looking results, the two underlying distributions $p_r, p_g$ are actually very different and easy for the discriminator to distinguish.

To explain the idea in 2 in simpler terms, consider the following four cases between $p_r$ and $p_g$ when calculating the $JSD$: \begin{aligned} p_r(x) &= 0, p_g(x) = 0\\ p_r(x) &\neq 0, p_g(x) = 0 \\ p_r(x) &= 0, p_g(x) \neq 0\\ p_r(x) &\neq 0, p_g(x) \neq 0 \end{aligned} The first case does not contribute to the $JSD$. The second case contributes $\frac{1}{2}\log\left(\frac{p_r}{\frac{p_r+0}{2}}\right)=\frac{1}{2}\log2$; similarly, the third case contributes $\frac{1}{2}\log2$.
For the fourth case, since $\mathcal{M}$ and $\mathcal{P}$ don’t perfectly align and don’t have full dimension, the set where $p_r(x)$ and $p_g(x)$ overlap has measure zero, so the resulting integral is negligible and it also does not contribute to the $JSD$. To provide further intuition, if the data space is $\mathbb{R}^3$ and the supports are planes, then it is very unlikely for them to be perfectly aligned; their intersection is most likely a line, which does not contribute to the measure (see Figure 2).
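Reusing the kl and jsd helpers from the earlier NumPy sketch, a two-line check (again, just an illustration) shows that distributions with disjoint supports saturate the JS divergence while their KL divergence blows up:

# disjoint supports: p lives on the first two outcomes, q on the last two
p = np.array([0.5, 0.5, 0.0, 0.0])
q = np.array([0.0, 0.0, 0.5, 0.5])
print(jsd(p, q), np.log(2))  # both ~0.6931: the JSD is maxed out
# kl(p, q) diverges to +inf, since q has zero mass wherever p has mass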

Finally, the result in 3 bounds the generator’s gradient when the discriminator is close to the perfect one. This is the well-known vanishing gradient problem, and it explains the instability of GAN training. In fact, as the discriminator approaches perfection, the gradient provided to the generator approaches $0$: $$\lim_{\lVert D-D^{**}\rVert\to0}\nabla_{\theta_g}\mathbb{E}_{z\sim p(z)}\left[\log(1-D(G(z,\theta_g)))\right] = 0.$$ Intuitively, a near-perfect discriminator saturates and is flat around the generated samples, so it carries no signal for the generator to follow.
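To make this concrete, here is a tiny TensorFlow illustration (my own toy construction, not from the paper): the more confidently a sigmoid discriminator rejects a fake sample, the smaller the gradient that flows back to it.

import tensorflow as tf

confidence = tf.constant([1.0, 5.0, 20.0])  # increasingly confident discriminators
x_fake = tf.Variable([1.0, 1.0, 1.0])       # stand-in for generated samples G(z)
with tf.GradientTape() as tape:
    d_out = tf.sigmoid(-confidence * x_fake)  # D(G(z)) -> 0 as confidence grows
    gen_loss = tf.reduce_sum(tf.math.log(1.0 - d_out))
print(tape.gradient(gen_loss, x_fake))  # entries shrink toward 0 as D saturates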

Overall, this subsection shows that under the original loss function, there exists a perfect discriminator which easily distinguishes $p_g$ from $p_r$ and provides little signal for the generator to learn from.

WGANs (Arjovsky, Chintala, & Bottou, 2017)

After providing a mathematical framework to dissect GANs and formally understanding their limitations, Arjovsky and Bottou, this time joined by Chintala, borrowed ideas from optimal transport and proposed a new generative model that is easier to train.

As someone without a strong mathematical background, I used the following resources to understand the theory behind optimal transport and measure theory, you may find them helpful too:

The Wasserstein Distance

One important feature of optimal transport is that it can be used to define a continuous metric on a space of probability measures. If we compare distributions $p_r$ and $p_g$ using the minimum cost of a transport plan that shifts the mass of $p_r$ onto $p_g$, this quantity is called the Wasserstein distance $W(p_r, p_g)$, and it is defined as: $$ W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x, y) \sim \gamma} [\lVert x - y \rVert], $$ where $\Pi(p_r, p_g)$ is the set of all joint probability distributions with marginals $p_r$ and $p_g$, i.e. all possible transport plans (Alex Williams’ post contains a step-by-step example on calculating the transport plan costs).
By using the Kantorovich-Rubinstein duality, it can be reformulated as: $$ W(p_r, p_g) = \sup_{\lVert h \rVert_L \leq 1} \mathbb{E}_{x \sim p_r} [h(x)] - \mathbb{E}_{x \sim p_g} [h(x)], $$ where the supremum is taken over all $1$-Lipschitz functions. Intuitively, the duality expresses the Wasserstein distance between $p_r$ and $p_g$ as the largest difference in expectations of certain functions $h$ that satisfy the Lipschitz constraint. The dual formulation is often easier to compute since it only involves optimizing over a set of functions rather than a set of joint probability measures.
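To see why this metric is friendlier than the JS divergence, consider a 1D analogue of the parallel-lines example from the WGAN paper: two point masses separated by a distance $\theta$. Their supports are disjoint for any $\theta > 0$, so the JSD is stuck at $\log2$, while the Wasserstein distance shrinks smoothly as they approach. A quick empirical check (assuming SciPy is available):

import numpy as np
from scipy.stats import wasserstein_distance

# two point masses separated by theta: supports are disjoint for any theta > 0
for theta in [10.0, 1.0, 0.1]:
    p_samples = np.zeros(1000)         # all mass at 0
    q_samples = np.full(1000, theta)   # all mass at theta
    print(theta, wasserstein_distance(p_samples, q_samples))  # ~theta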

The idea behind WGANs is very simple: optimize the parameters by minimizing $W(p_r, p_g)$ instead. Doing so avoids the vanishing gradient we mentioned above, even when the supports of $p_r$ and $p_g$ lie on low dimensional manifolds.

WGAN Implementation

Concretely, the discriminator $D$ is replaced by a critic $C$ that estimates the Wasserstein distance between the real data distribution and the fake one: $$ \max_{C\in \mathcal{C}} \mathbb{E}_{x\sim p_r(x)}[C(x)] - \mathbb{E}_{x\sim p_g(x)}[C(x)], $$ where $\mathcal{C}$ is the set of $1$-Lipschitz functions.
The generator still tries to fool the critic by minimizing the estimated Wasserstein distance, which amounts to maximizing the critic score on its samples: $$ \min_G -\mathbb{E}_{z\sim p_z(z)}[C(G(z))]. $$ The overall loss function is now: $$ \begin{equation}\label{eq: wgan_loss} \min_{G}\max_{C\in \mathcal{C}} \mathbb{E}_{x\sim p_r(x)}[C(x)] - \mathbb{E}_{z\sim p_z(z)}[C(G(z))]. \end{equation} $$

Compared to the original GAN loss (equation $\eqref{eq: gan_loss}$), this loss function seems simpler, and it actually is! To implement WGAN, we simply remove the final activation function from the discriminator: instead of outputting a probability, the final linear layer directly outputs an unbounded scalar score, and the difference of these scores in expectation estimates the distribution distance. In other words, the only architectural difference between GANs and WGANs is the lack of a sigmoid activation function in the critic.

Of course, we also need to enforce the Lipschitz constraint. The original WGAN paper proposes to address this through weight clipping: simply clip the weights of the critic to a compact set $[−c, c]$ after each update. Another prominent work in this direction is the gradient penalty of Gulrajani et al., 2017, which penalizes deviations of the norm of the critic’s gradient (with respect to its input) from $1$.
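For reference, here is a minimal sketch of what such a penalty term could look like in TensorFlow. The function name and the penalty weight of $10$ are my own illustrative choices (the weight follows the paper’s default), and the notebook below sticks to weight clipping:

import tensorflow as tf

def gradient_penalty(critic, real, fake, gp_weight=10.0):
    # penalize deviations of the critic's gradient norm from 1 on random
    # interpolates between real and fake samples (Gulrajani et al., 2017)
    eps = tf.random.uniform([tf.shape(real)[0], 1], 0.0, 1.0)
    interpolated = eps * real + (1.0 - eps) * fake
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        critic_out = critic(interpolated, training=True)
    grads = tape.gradient(critic_out, interpolated)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-12)
    return gp_weight * tf.reduce_mean(tf.square(norm - 1.0))

This term would be added to the critic loss, with the weight clipping removed.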

Kaggle Notebook on Toy Examples (Link)

I wanted to play with GANs and WGANs on some toy examples. In this notebook, I train and compare each algorithm on 2D Gaussian and geometric shape datasets. The implementation is done in TensorFlow as practice for me, so feel free to comment and leave suggestions.

The crux of the implementation is very simple:

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.constraints import MaxNorm

cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)
# GAN losses
def discriminator_loss_fn(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def gan_generator_loss(fake_output):
    # non-saturating heuristic from the original paper: maximize log(D(G(z)))
    # instead of minimizing log(1 - D(G(z))), for stronger early gradients
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# WGAN losses
def critic_loss_fn(real_output, fake_output):
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)  # equation 2

def wgan_generator_loss(fake_output):
    return -tf.reduce_mean(fake_output)  # generator update does not depend on the first term in equation 2

# Model Builders, simple feedforward networks with 2 hidden layers
def build_generator(hidden_dim, latent_dim, output_dim=2):
    model = Sequential([
        Input(shape=(latent_dim,)),
        Dense(hidden_dim, activation='relu'),
        Dense(hidden_dim, activation='relu'),
        Dense(output_dim, activation='linear')
    ])
    return model

def build_discriminator(hidden_dim, input_dim=2):
    model = Sequential([
        Input(shape=(input_dim,)),
        Dense(hidden_dim, activation='relu'),
        Dense(hidden_dim, activation='relu'),
        Dense(1)  # using cross entropy from logits, no activation
    ])
    return model

# notice how the implementation is the same for the discriminator and the critic, except the critic constrains its weights
def build_critic(hidden_dim, input_dim=2, clip_value=0.01):
    const = MaxNorm(clip_value)  # caps each weight vector's norm, approximating the element-wise [-c, c] clipping of the original paper
    model = Sequential([
        Input(shape=(input_dim,)),
        Dense(hidden_dim, activation='relu', kernel_constraint=const),
        Dense(hidden_dim, activation='relu', kernel_constraint=const),
        Dense(1, kernel_constraint=const)
    ])
    return model

# per-iteration/batch training function for both GAN and WGAN
@tf.function
def train_gan_step(
        model,
        X_train,
        dis_iter=5,
        batch_size=64,
        latent_dim=2,
):
    dis, generator = model['dis'], model['gen']
    dis_loss, generator_loss = model['dis_loss'], model['gen_loss']
    dis_opt, generator_opt = model['dis_opt'], model['gen_opt']
    
    # Loop over discriminator/critic training steps (dis_iter times)
    for _ in range(dis_iter):
        noise = tf.random.normal([batch_size, latent_dim])
        batch_idx = tf.random.shuffle(tf.range(X_train.shape[0]))[:batch_size]
        sample_batch_inner = tf.gather(X_train, batch_idx)

        with tf.GradientTape() as disc_tape:
            generated_samples = generator(noise, training=False)
            real_output = dis(sample_batch_inner, training=True)
            fake_output_gan = dis(generated_samples, training=True)
            disc_loss = dis_loss(real_output, fake_output_gan)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, dis.trainable_variables)
        dis_opt.apply_gradients(zip(gradients_of_discriminator, dis.trainable_variables))

    # Train generator
    noise = tf.random.normal([batch_size, latent_dim])
    with tf.GradientTape() as gen_tape:
        generated_samples = generator(noise, training=True)
        fake_output = dis(generated_samples, training=False)
        gen_loss = generator_loss(fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_opt.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    return
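As a usage sketch, the pieces wire together like this; the optimizer, learning rate, and sizes here are illustrative stand-ins (RMSprop with a small learning rate follows the WGAN paper), not necessarily the notebook’s exact settings:

latent_dim, hidden_dim = 2, 64
wgan = {
    'gen': build_generator(hidden_dim, latent_dim),
    'dis': build_critic(hidden_dim, clip_value=0.5),
    'gen_loss': wgan_generator_loss,
    'dis_loss': critic_loss_fn,
    'gen_opt': keras.optimizers.RMSprop(5e-5),
    'dis_opt': keras.optimizers.RMSprop(5e-5),
}
X_train = tf.random.normal([1024, 2])  # placeholder for a 2D toy dataset
for step in range(1000):
    train_gan_step(wgan, X_train, dis_iter=5, latent_dim=latent_dim)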

During training, I expected WGANs to be less sensitive to hyperparameter choices than GANs, especially on these simple 2D examples. However, the weight clipping value $c$ had a significant impact on the learning process.
The original WGAN paper used $c=0.01$, but this made the model update steps far too small and resulted in slower convergence than expected. Instead, I found that $c=0.5$ worked best. I have included a hyperparameter study pipeline in the notebook for readers to explore.
Having multiple discriminator updates per generator update also seems beneficial for GAN training.

Figure 3: Learning on an 8-Gaussian dataset. It is cool how the model warps the cloud of points into the circular shape.

Figure 4: Learning on a crescent moon dataset. WGAN is visibly more stable during training.

Figure 5: Losses and gradients on the crescent moon dataset. The WGAN gradient (bottom right) is much more stable compared to the oscillating GAN gradient (bottom left).

Thank you for reading! If you found this post helpful in your own writing or research, please cite it as:

@article{bai2025gan,
  title={The Story behind WGAN},
  author={Bai, Mark},
  journal={rdh1115.github.io},
  year={2025},
  url={https://rdh1115.github.io/posts/literature/gan-wgan/}
}