Generative Adversarial Networks (GANs) are a fascinating class of machine learning models that learn to generate new data that resembles a given training dataset. They operate through a clever game-theoretic approach involving two neural networks: a Generator and a Discriminator.
Let’s see this in action. Imagine we want to generate realistic-looking images of handwritten digits, like those in the MNIST dataset.
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
# --- Generator Network ---
def build_generator(latent_dim):
    model = tf.keras.Sequential([
        layers.Dense(7 * 7 * 128, input_shape=(latent_dim,), activation='relu'),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', activation='relu'),  # 7x7 -> 14x14
        layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', activation='relu'),   # 14x14 -> 28x28
        layers.Conv2D(1, kernel_size=4, padding='same', activation='tanh')  # Output image is 28x28x1, in [-1, 1]
    ])
    return model
# --- Discriminator Network ---
def build_discriminator(img_shape):
    model = tf.keras.Sequential([
        layers.Conv2D(64, kernel_size=3, strides=2, padding='same', input_shape=img_shape),  # no built-in activation; LeakyReLU follows
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, kernel_size=3, strides=2, padding='same'),
        layers.LeakyReLU(alpha=0.2),
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid')  # Output is a probability (real/fake)
    ])
    return model
# --- Setup ---
latent_dim = 100
img_shape = (28, 28, 1)
generator = build_generator(latent_dim)
discriminator = build_discriminator(img_shape)
# Compile discriminator
discriminator.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
                      loss='binary_crossentropy',
                      metrics=['accuracy'])
# --- GAN Model (combines generator and discriminator) ---
# Freeze the discriminator when it is trained *through* the combined model.
# Because it was already compiled above, it still updates when trained directly.
discriminator.trainable = False
gan_input = tf.keras.Input(shape=(latent_dim,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)
gan = tf.keras.Model(gan_input, gan_output)
# Compile GAN
gan.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
            loss='binary_crossentropy')
# --- Load MNIST Data ---
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 127.5 - 1.0 # Normalize to [-1, 1]
x_train = np.expand_dims(x_train, axis=-1)
# --- Training Loop (Simplified for illustration) ---
epochs = 10 # Very few epochs for quick demo
batch_size = 32
for epoch in range(epochs):
    for _ in range(len(x_train) // batch_size):
        # --- Train Discriminator ---
        # Get a batch of real images
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        real_images = x_train[idx]
        real_labels = np.ones((batch_size, 1))  # Labels are 1 for real images

        # Generate a batch of fake images
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        fake_images = generator.predict(noise, verbose=0)
        fake_labels = np.zeros((batch_size, 1))  # Labels are 0 for fake images

        # Train discriminator on real and fake images
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        discriminator_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Train Generator ---
        # Train generator to fool discriminator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        misleading_labels = np.ones((batch_size, 1))  # Generator wants discriminator to call its output real
        generator_loss = gan.train_on_batch(noise, misleading_labels)
    print(f"Epoch {epoch+1}/{epochs} - D Loss: {discriminator_loss[0]:.4f}, G Loss: {generator_loss:.4f}")
# --- Generate and Display Samples ---
print("\nGenerating sample images after training:")
noise = np.random.normal(0, 1, (10, latent_dim))
generated_images = generator.predict(noise, verbose=0)
generated_images = 0.5 * generated_images + 0.5  # Denormalize from [-1, 1] to [0, 1]
plt.figure(figsize=(10, 4))
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.imshow(generated_images[i, :, :, 0], cmap='gray')
    plt.axis('off')
plt.show()
The core idea is a zero-sum game. The Generator takes random noise as input and tries to produce data (e.g., images) that look like they came from the real dataset. The Discriminator, on the other hand, acts as a detective. It’s trained to distinguish between real data from the training set and fake data produced by the Generator.
During training, these two networks are pitted against each other. The Generator gets better at fooling the Discriminator, and the Discriminator gets better at catching the Generator. This adversarial process continues until the Generator is producing data that is indistinguishable from the real data to the Discriminator.
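To make the two competing objectives concrete, here is a minimal numpy sketch of the binary cross-entropy each network minimizes. The probability values are made up purely for illustration:

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    # Binary cross-entropy, the loss both networks use
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))

# Suppose the discriminator outputs these probabilities:
d_real = np.array([0.9, 0.8, 0.95])   # D(x) on real images
d_fake = np.array([0.1, 0.2, 0.05])   # D(G(z)) on generated images

# Discriminator objective: push D(x) -> 1 and D(G(z)) -> 0
d_loss = 0.5 * (bce(np.ones(3), d_real) + bce(np.zeros(3), d_fake))

# Generator objective: push D(G(z)) -> 1 (the "misleading" all-ones labels)
g_loss = bce(np.ones(3), d_fake)

print(f"D loss: {d_loss:.3f}, G loss: {g_loss:.3f}")  # D loss: 0.127, G loss: 2.303
```

With a confident discriminator like this one, the discriminator loss is low and the generator loss is high; as the generator improves, the balance shifts the other way.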
The problem GANs solve is generating new, realistic data samples without explicit rule-based generation. Instead of defining "what makes a digit look like a 7," the GAN learns these characteristics implicitly from examples. This is incredibly powerful for tasks like image synthesis, style transfer, and data augmentation.
The training loop is where the magic happens. We alternate between training the Discriminator and the Generator. For the Discriminator, we feed it a mix of real images (labeled as "real") and fake images generated by the Generator (labeled as "fake"). It learns to output a high probability for real images and a low probability for fake ones. For the Generator, we feed it random noise and train it to produce outputs that the Discriminator classifies as "real." Crucially, during Generator training, the Discriminator’s weights are frozen; we only update the Generator’s weights to make its output more convincing.
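The freezing behavior described above can be verified directly with a pair of toy Keras models. The tiny layer sizes here are arbitrary stand-ins, not the networks from this post:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Tiny stand-in networks (shapes are arbitrary, just to show the pattern)
g = tf.keras.Sequential([layers.Dense(4, input_shape=(2,))])
d = tf.keras.Sequential([layers.Dense(1, activation='sigmoid', input_shape=(4,))])
d.compile(optimizer='adam', loss='binary_crossentropy')  # D still trains directly via this

d.trainable = False                      # freeze D *before* compiling the combined model
combined = tf.keras.Sequential([g, d])
combined.compile(optimizer='adam', loss='binary_crossentropy')

d_before = [w.copy() for w in d.get_weights()]
g_before = [w.copy() for w in g.get_weights()]
combined.train_on_batch(np.random.normal(size=(8, 2)).astype('float32'),
                        np.ones((8, 1), dtype='float32'))

# The combined-model step moved G's weights but left D's untouched
d_frozen = all(np.array_equal(a, b) for a, b in zip(d_before, d.get_weights()))
g_moved = not all(np.array_equal(a, b) for a, b in zip(g_before, g.get_weights()))
print(d_frozen, g_moved)
```

Because the trainable flag is captured when a model is compiled, the discriminator's own compiled copy keeps updating during its direct training steps even though it is frozen inside the combined model.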
The gan model itself is simply the Generator connected to the Discriminator. When we call gan.train_on_batch(noise, misleading_labels), the noise is passed through the Generator, then the output of the Generator is passed through the Discriminator. The misleading_labels are all 1s because the Generator’s objective is to make the Discriminator output 1 (i.e., classify the generated image as real).
One of the most subtle yet critical aspects of GAN training is optimizer configuration. Using Adam with beta_1=0.5 is a common heuristic, popularized by the DCGAN paper, that often trains more stably than the default beta_1=0.9. The lower beta_1 shortens the gradient history that the momentum term averages over, which matters here because each network's loss surface keeps shifting as its opponent updates: stale gradient directions from a few steps ago can point the wrong way, causing overshooting and oscillation.
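To see the effect of beta_1 in isolation, here is a small self-contained Adam implementation run on a toy quadratic. This is a deliberately simplified, non-GAN objective, used only to compare the two momentum settings:

```python
import numpy as np

def adam_minimize(beta_1, steps=200, lr=0.05, w=2.0, beta_2=0.999, eps=1e-8):
    # Plain Adam on the toy objective f(w) = w^2 (gradient 2w)
    m = v = 0.0
    path = []
    for t in range(1, steps + 1):
        g = 2.0 * w
        m = beta_1 * m + (1 - beta_1) * g          # first-moment (momentum) estimate
        v = beta_2 * v + (1 - beta_2) * g * g      # second-moment estimate
        m_hat = m / (1 - beta_1 ** t)              # bias correction
        v_hat = v / (1 - beta_2 ** t)
        w -= lr * m_hat / (np.sqrt(v_hat) + eps)
        path.append(w)
    return np.array(path)

path_default = adam_minimize(beta_1=0.9)   # Keras default
path_gan = adam_minimize(beta_1=0.5)       # common GAN setting

# Both reach the minimum; the lower-momentum run carries less "memory"
# of old gradients once the sign of the gradient flips near the optimum.
print(abs(path_default[-1]), abs(path_gan[-1]))
```

On a fixed objective like this both settings converge; the difference shows up in how long each run keeps moving in a stale direction after the gradient reverses, which is exactly the situation a GAN creates on every alternation.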
The next hurdle you’ll likely encounter is mode collapse, where the Generator starts producing only a limited variety of outputs, failing to capture the full diversity of the training data.
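A quick-and-dirty way to spot mode collapse is to measure how different the samples in a generated batch are from one another. The diversity score below is an illustrative proxy, not a standard metric, and the "collapsed" batch is simulated rather than produced by a real generator:

```python
import numpy as np

def batch_diversity(images):
    # Mean pairwise L2 distance across a batch: a crude proxy for output variety
    flat = images.reshape(len(images), -1)
    dists = np.linalg.norm(flat[:, None, :] - flat[None, :, :], axis=-1)
    n = len(images)
    return dists.sum() / (n * (n - 1))  # diagonal is zero, so this averages the pairs

rng = np.random.default_rng(42)
healthy = rng.uniform(-1, 1, size=(16, 28, 28, 1))       # varied outputs
collapsed = np.tile(healthy[:1], (16, 1, 1, 1))          # one mode, repeated
collapsed += rng.normal(0, 0.01, size=collapsed.shape)   # tiny jitter

print(batch_diversity(healthy) > 10 * batch_diversity(collapsed))  # True
```

Tracking a score like this during training (alongside the losses) makes a sudden drop in sample variety visible long before you notice it by eyeballing the generated digits.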