A Variational Autoencoder (VAE) doesn’t learn a direct mapping from input to output, but rather learns a probabilistic mapping to a distribution in the latent space, from which the output can be reconstructed.

Let’s see this in action. Imagine we have a dataset of handwritten digits (MNIST). A standard autoencoder would try to compress an image into a single point in latent space and then reconstruct it. A VAE, however, compresses the image into the parameters of a probability distribution (typically a Gaussian, defined by its mean and variance) in that latent space. Then, it samples a point from this distribution and decodes that sample to reconstruct the original image.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# Define the VAE model
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return reconstruction

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Sum the per-pixel cross-entropy over each image, then average
            # over the batch, so it is on the same scale as the summed KL term.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            kl_loss = -0.5 * tf.reduce_sum(
                1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1
            )
            total_loss = reconstruction_loss + tf.reduce_mean(kl_loss)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

# Load and preprocess MNIST data
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

# Define sampling layer (reparameterization trick)
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample a latent vector z."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# Define encoder
latent_dim = 2
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
# The encoder must output all three tensors that train_step unpacks.
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")

# Define decoder
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")

# Instantiate VAE
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())

# Train VAE
vae.fit(mnist_digits, epochs=10, batch_size=128)

# Plotting latent space and reconstructions
encoder_output_mean = encoder.predict(mnist_digits[:1000])[0]
decoder_output = decoder.predict(np.random.normal(size=(100, latent_dim)))

plt.figure(figsize=(10, 10))
plt.scatter(encoder_output_mean[:, 0], encoder_output_mean[:, 1], alpha=0.5)
plt.title("Latent Space Visualization")
plt.xlabel("z_mean_dim_1")
plt.ylabel("z_mean_dim_2")
plt.show()

n = 10  # Digits to generate
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
grid_x = np.linspace(-4, 4, n)
grid_y = np.linspace(-4, 4, n)

for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size : (i + 1) * digit_size, j * digit_size : (j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='gray')
plt.title("Generated Digits from Latent Space Walk")
plt.show()

The core problem a VAE solves is learning a meaningful, continuous, and structured latent representation of data. Unlike a standard autoencoder that might produce a very jagged and discontinuous latent space (where small changes in the latent code lead to large, nonsensical changes in the output), VAEs enforce a prior distribution (usually a standard normal distribution) on the latent space. This regularization encourages similar inputs to map to nearby points in the latent space and allows for smooth interpolation between different data points.
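The smooth-interpolation property can be made concrete with a minimal sketch, independent of the model above: walking linearly between two latent codes and decoding each intermediate point would morph one digit into another. The vectors `z_a` and `z_b` here are illustrative stand-ins for encoded digits, not outputs of the trained encoder.

```python
import numpy as np

def interpolate_latent(z_a, z_b, steps=5):
    """Linearly interpolate between two latent vectors z_a and z_b."""
    ts = np.linspace(0.0, 1.0, steps)
    # Each row is (1 - t) * z_a + t * z_b for one value of t.
    return np.stack([(1 - t) * z_a + t * z_b for t in ts])

z_a = np.array([-2.0, 0.5])   # hypothetical latent code of one digit
z_b = np.array([1.0, -1.5])   # hypothetical latent code of another
path = interpolate_latent(z_a, z_b, steps=5)
print(path.shape)  # (5, 2); the first and last rows equal z_a and z_b
```

With a trained decoder, `decoder.predict(path)` would produce the intermediate images; in a well-regularized VAE latent space each step changes the output only slightly.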

Internally, the VAE has two main parts: the encoder and the decoder. The encoder takes an input (such as an image) and outputs two vectors: z_mean (the mean of the distribution) and z_log_var (the log variance of the distribution). The Sampling layer then uses these parameters to draw a latent vector z. This sampling step is what makes the model generative; on its own it would block backpropagation, which is why the reparameterization trick (covered below) is needed to keep it differentiable. The decoder then takes the sampled z and attempts to reconstruct the original input.

The VAE is trained by minimizing a combined loss function:

  1. Reconstruction Loss: This measures how well the decoder reconstructs the input from the sampled latent vector. It’s typically a pixel-wise loss like Binary Cross-Entropy (for binary images) or Mean Squared Error (for continuous-valued images).
  2. KL Divergence Loss: This is the regularization term. It measures the difference between the distribution learned by the encoder (z_mean, z_log_var) and the prior distribution (usually a standard normal distribution N(0, 1)). By minimizing this, we force the encoder to produce distributions that are close to the prior, effectively "filling" the latent space and making it continuous. The formula for KL divergence between a Gaussian N(μ, σ^2) and N(0, 1) is -0.5 * sum(1 + log(σ^2) - μ^2 - σ^2).
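That closed-form KL expression can be checked numerically: when the encoder's distribution matches the prior exactly (μ = 0, log σ² = 0) the divergence is zero, and it grows as the distribution drifts away. A small NumPy sketch (variable names are illustrative):

```python
import numpy as np

def kl_to_standard_normal(z_mean, z_log_var):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dims."""
    return -0.5 * np.sum(1 + z_log_var - np.square(z_mean) - np.exp(z_log_var))

# Matching the prior exactly (mu = 0, log variance = 0) gives zero divergence.
assert kl_to_standard_normal(np.zeros(2), np.zeros(2)) == 0.0

# Shifting one mean component to 1 adds a penalty of 0.5 * mu^2 = 0.5.
assert abs(kl_to_standard_normal(np.array([1.0, 0.0]), np.zeros(2)) - 0.5) < 1e-12
```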

You control the VAE’s behavior through several levers:

  • latent_dim: The dimensionality of the latent space. Higher dimensions can capture more complex variations but require more data and can be harder to interpret.
  • Network Architecture: The complexity of the encoder and decoder networks (number of layers, filter sizes, activation functions) determines how much information can be compressed and generated.
  • Loss Weights: While not explicitly shown in this basic example, you can sometimes add weights to the reconstruction loss and KL divergence loss to prioritize one over the other.
  • Prior Distribution: While usually a standard normal, other priors can be used for specific tasks.
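The loss-weighting lever mentioned above is commonly implemented as a single scalar β multiplying the KL term (as popularized by the β-VAE line of work). This is a minimal sketch of that idea, not part of the training code above; `beta` is a hypothetical hyperparameter:

```python
def weighted_vae_loss(reconstruction_loss, kl_loss, beta=1.0):
    """Combine the two VAE loss terms. beta > 1 pushes the latent space
    closer to the prior (smoother, more disentangled); beta < 1 favors
    sharper reconstructions at the cost of latent structure."""
    return reconstruction_loss + beta * kl_loss

# With beta = 1 this reduces to the unweighted sum used in train_step above.
print(weighted_vae_loss(150.0, 3.0))            # 153.0
print(weighted_vae_loss(150.0, 3.0, beta=4.0))  # 162.0
```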

The "reparameterization trick" is how VAEs achieve differentiability through the sampling process. Instead of sampling z directly from N(z_mean, exp(z_log_var)), we sample epsilon from a standard normal distribution N(0, 1) and then compute z = z_mean + exp(0.5 * z_log_var) * epsilon. This way, the randomness is externalized, and the gradient can flow through z_mean and z_log_var back to the encoder.
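The trick is easy to verify numerically: drawing epsilon from N(0, 1) and applying z = z_mean + exp(0.5 * z_log_var) * epsilon produces samples whose mean and standard deviation match the target Gaussian, while the randomness lives entirely in epsilon. A small NumPy sketch with illustrative values:

```python
import numpy as np

rng = np.random.default_rng(0)
z_mean, z_log_var = 2.0, np.log(0.25)  # target: N(2, 0.25), i.e. sigma = 0.5

# Reparameterization: the randomness comes from epsilon, which involves no
# learned parameters, so gradients can flow through z_mean and z_log_var.
epsilon = rng.standard_normal(100_000)
z = z_mean + np.exp(0.5 * z_log_var) * epsilon

print(round(z.mean(), 2))  # ~2.0, the target mean
print(round(z.std(), 2))   # ~0.5, the target standard deviation
```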

The next challenge is understanding how to generate entirely new data that resembles the training set, beyond simple interpolation.

Want structured learning?

Take the full TensorFlow course →