Contrastive learning aims to learn representations by pulling similar samples closer together and pushing dissimilar samples apart in an embedding space; in practice, the "dissimilar" part is often handled implicitly by treating the other samples in the batch as negatives.

Let’s build a SimCLR implementation from the ground up. We’ll use TensorFlow and imagine we’re training a model to distinguish between different augmented views of the same image versus views of different images.

First, the core idea: take an image, create two different random augmentations of it. These two augmented images form a "positive pair." Any other augmented image in the batch is a "negative sample" relative to our original pair. The goal is to make the model’s embeddings for the positive pair very similar and embeddings for the negative samples very different.
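To make the pairing concrete, here is a tiny plain-Python sketch (the batch size is an arbitrary example) of how views are indexed once every image in a batch has been augmented twice:

```python
# For a batch of N images, two augmentations per image give 2N views.
# Order them as [a_0, ..., a_{N-1}, b_0, ..., b_{N-1}]: view i's positive
# partner then sits at (i + N) % (2 * N); every other view is a negative.
N = 4  # arbitrary example batch size
positives = {i: (i + N) % (2 * N) for i in range(2 * N)}
negatives = {i: [k for k in range(2 * N) if k not in (i, positives[i])]
             for i in range(2 * N)}

print(positives[0])       # -> 4: view a_0 pairs with b_0
print(len(negatives[0]))  # -> 6: each view has 2N - 2 negatives
```

This indexing scheme is exactly what the contrastive loss later relies on.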

Here’s a simplified dataset setup. We’ll assume you have a directory of images.

import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

# Load a dataset (e.g., CIFAR-10)
(ds_train, ds_info) = tfds.load(
    'cifar10',
    split='train',
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

IMG_SIZE = 32
BATCH_SIZE = 128

def preprocess(image, label):
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    return image, label

ds_train = ds_train.map(preprocess).cache().shuffle(buffer_size=10000)

Now, the crucial part: augmentations. SimCLR uses a pipeline of multiple augmentations. A common set includes random cropping, random horizontal flipping, color jittering (brightness, contrast, saturation, hue), and Gaussian blur.

def augment(image, label):
    # Random crop a smaller region, then resize back to the original size
    image = tf.image.random_crop(image, size=[24, 24, 3])
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = tf.image.random_flip_left_right(image)

    # Color jitter
    image = tf.image.random_brightness(image, max_delta=0.3)
    image = tf.image.random_contrast(image, lower=0.7, upper=1.3)
    image = tf.image.random_saturation(image, lower=0.7, upper=1.3)
    image = tf.image.random_hue(image, max_delta=0.1)
    image = tf.clip_by_value(image, 0.0, 1.0) # Ensure values stay in [0, 1]

    # Optional: Gaussian blur (can be complex to implement simply)
    # For simplicity, we'll omit explicit Gaussian blur here, but it's part of full SimCLR.

    return image, label

# We need to apply augmentations *twice* to the same image to get a positive pair
def create_augment_pair(image, label):
    aug1, _ = augment(image, label)
    aug2, _ = augment(image, label)
    return aug1, aug2

# Augment each image independently first, then batch, so every batch holds
# BATCH_SIZE positive pairs; drop_remainder keeps the batch size fixed,
# which the contrastive loss below relies on
ds_augment_pairs = (
    ds_train
    .map(create_augment_pair, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)
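If you do want Gaussian blur, one way to add it (a sketch of my own, not the reference SimCLR implementation) is a separable blur built from two depthwise convolutions:

```python
import numpy as np
import tensorflow as tf

def gaussian_blur(image, kernel_size=5, sigma=1.0):
    # Build a normalized 1-D Gaussian kernel
    ax = np.arange(kernel_size, dtype=np.float32) - (kernel_size - 1) / 2.0
    k = np.exp(-(ax ** 2) / (2.0 * sigma ** 2))
    k = k / k.sum()
    # Replicate across the 3 channels; depthwise_conv2d expects a filter of
    # shape [height, width, in_channels, channel_multiplier]
    kx = tf.constant(k.reshape(1, kernel_size, 1, 1)) * tf.ones([1, 1, 3, 1])
    ky = tf.constant(k.reshape(kernel_size, 1, 1, 1)) * tf.ones([1, 1, 3, 1])
    # Accept a single image (H, W, 3) or a batch (B, H, W, 3)
    x = image[tf.newaxis] if image.shape.rank == 3 else image
    x = tf.nn.depthwise_conv2d(x, kx, strides=[1, 1, 1, 1], padding='SAME')
    x = tf.nn.depthwise_conv2d(x, ky, strides=[1, 1, 1, 1], padding='SAME')
    return x[0] if image.shape.rank == 3 else x
```

You could call this from inside `augment` (optionally behind a random coin flip, as the SimCLR paper does) before the final clipping step.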

Next, the encoder network. This is a standard CNN (like ResNet or a simple ConvNet) that maps an image to an embedding vector. After the encoder, SimCLR adds a small MLP "projection head" which maps the encoder’s output to a space where the contrastive loss is applied. The loss is not computed on the encoder’s output directly; it’s computed on the projection head’s output.

def build_encoder(input_shape=(IMG_SIZE, IMG_SIZE, 3)):
    inputs = tf.keras.Input(shape=input_shape)
    x = inputs
    # Simple ConvNet example
    x = layers.Conv2D(32, 3, activation='relu', padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(64, 3, activation='relu', padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Conv2D(128, 3, activation='relu', padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation='relu')(x)
    encoder_output = layers.Dense(128)(x) # This is the embedding we want to learn
    return tf.keras.Model(inputs, encoder_output)

def build_projection_head(encoder_output_dim=128, projection_dim=64):
    inputs = tf.keras.Input(shape=(encoder_output_dim,))
    x = inputs
    x = layers.Dense(128, activation='relu')(x)
    projection_output = layers.Dense(projection_dim)(x) # This is where contrastive loss is applied
    return tf.keras.Model(inputs, projection_output)

# Instantiate models
encoder = build_encoder()
projection_head = build_projection_head()

# Combine them for the forward pass
def create_embeddings(image):
    encoder_out = encoder(image)
    projection_out = projection_head(encoder_out)
    return projection_out

# We need a way to get embeddings for a batch of images
# Let's create a dummy batch to see shapes
dummy_images, _ = next(iter(ds_augment_pairs))
dummy_embeddings = create_embeddings(dummy_images)
print("Embedding shape:", dummy_embeddings.shape) # Should be (BATCH_SIZE, projection_dim)

The contrastive loss is the heart of SimCLR. It’s typically the NT-Xent (Normalized Temperature-scaled Cross-Entropy) loss. For a batch of $N$ images, we generate $2N$ augmented views. Each augmented view $z_i$ is compared against $2N-2$ negative samples (all other augmented views in the batch except its positive pair $z_j$). The loss for a pair $(z_i, z_j)$ is:

$L_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp(\text{sim}(z_i, z_k) / \tau)}$

where $\text{sim}(u, v) = u \cdot v / (\|u\| \|v\|)$ is cosine similarity, and $\tau$ is a temperature hyperparameter.
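Before wiring this into TensorFlow, it helps to evaluate the formula once by hand. A minimal NumPy sketch (function name and toy values are my own) that computes NT-Xent exactly as written above:

```python
import numpy as np

def nt_xent_numpy(z, tau=0.5):
    """NT-Xent over 2N views ordered as [a_0..a_{N-1}, b_0..b_{N-1}]."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # so dot products are cosine sim
    sim = z @ z.T / tau
    two_n = z.shape[0]
    np.fill_diagonal(sim, -np.inf)                    # the indicator 1_{k != i}
    pos = np.concatenate([np.arange(two_n // 2, two_n), np.arange(two_n // 2)])
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(two_n), pos].mean()

# Two images, identical positives and orthogonal negatives:
z = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
print(nt_xent_numpy(z))  # log(2 + e^2) - 2, approximately 0.2396
```

Even with perfect positives the loss is not zero, because the softmax still assigns some probability mass to the negatives; only lowering the temperature or pushing negatives further apart drives it toward zero.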

class NTXentLoss(tf.keras.losses.Loss):
    def __init__(self, temperature=0.1, **kwargs):
        super().__init__(**kwargs)
        self.temperature = temperature

    def call(self, embeddings_1, embeddings_2):
        # embeddings_1 and embeddings_2 are the two augmented views of the
        # same batch, each of shape (BATCH_SIZE, projection_dim)

        # Stack the views: rows 0..B-1 are view 1, rows B..2B-1 are view 2
        embeddings = tf.concat([embeddings_1, embeddings_2], axis=0) # Shape: (2*BATCH_SIZE, projection_dim)

        # L2-normalize so the matmul below computes cosine similarity
        embeddings = tf.math.l2_normalize(embeddings, axis=1)

        # Cosine similarity matrix, shape (2*BATCH_SIZE, 2*BATCH_SIZE)
        sim_matrix = tf.matmul(embeddings, embeddings, transpose_b=True)

        # Apply temperature scaling
        logits = sim_matrix / self.temperature

        # Mask out self-similarity: set the diagonal to a large negative
        # number so it contributes nothing to the softmax
        N = tf.shape(embeddings)[0] # N = 2 * batch size
        mask = tf.cast(1.0 - tf.eye(N), tf.bool)
        logits = tf.where(mask, logits, -1e9)

        # For row i, the positive pair sits at i + B (if i < B) or i - B
        # (if i >= B), i.e. the label vector is [B..2B-1, 0..B-1].
        # Deriving B from N (rather than the global BATCH_SIZE) keeps the
        # loss correct for any batch size.
        batch_size = N // 2
        labels = tf.concat([tf.range(batch_size, N), tf.range(0, batch_size)], axis=0)

        # Softmax cross-entropy per row, with the positive's index as target
        loss = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

        # Average over all 2B views (each positive pair is counted twice)
        return tf.reduce_mean(loss)

# Instantiate the loss
contrastive_loss = NTXentLoss(temperature=0.1)

Training involves passing pairs of augmented images through the encoder and projection head, then calculating the contrastive loss. We’ll use an Adam optimizer.

# Need to create a custom training loop
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

@tf.function
def train_step(images_1, images_2):
    with tf.GradientTape() as tape:
        # Get embeddings for both augmented views
        embeddings_1 = create_embeddings(images_1)
        embeddings_2 = create_embeddings(images_2)

        # Calculate the contrastive loss
        loss = contrastive_loss(embeddings_1, embeddings_2)

    # Compute gradients and update weights
    gradients = tape.gradient(loss, encoder.trainable_variables + projection_head.trainable_variables)
    optimizer.apply_gradients(zip(gradients, encoder.trainable_variables + projection_head.trainable_variables))
    return loss

epochs = 10 # Typically requires many more epochs and larger batch sizes
for epoch in range(epochs):
    total_loss = 0
    num_batches = 0
    for batch_images_1, batch_images_2 in ds_augment_pairs:
        loss = train_step(batch_images_1, batch_images_2)
        total_loss += loss
        num_batches += 1
    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss.numpy():.4f}")

After training, the encoder model can be used to extract features. For downstream tasks (like classification), you’d typically freeze the encoder weights and train a new linear classifier on top of its embeddings. The projection head is usually discarded after pre-training.
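A minimal linear-evaluation sketch of that setup (using a tiny stand-in encoder here so the snippet is self-contained; in practice you would reuse the pre-trained `encoder` from above):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Stand-in for the pre-trained encoder (swap in the real one from above)
encoder = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    layers.Conv2D(8, 3, padding='same', activation='relu'),
    layers.GlobalAveragePooling2D(),
    layers.Dense(128),
])
encoder.trainable = False  # freeze the pre-trained weights

# Linear probe: a single dense classifier on top of the frozen embeddings
linear_probe = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    encoder,
    layers.Dense(10, activation='softmax'),  # 10 CIFAR-10 classes
])
linear_probe.compile(optimizer='adam',
                     loss='sparse_categorical_crossentropy',
                     metrics=['accuracy'])
# linear_probe.fit(labeled_train_ds, epochs=...)  # only the Dense head trains
```

Because the encoder is frozen, only the final dense layer's weights update during fitting; linear-probe accuracy is the standard way to measure how good the learned representations are.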

The surprising thing is that even with a simple encoder and a relatively small projection head, this method can learn remarkably rich representations. The power comes from the pretext task itself: it forces the model to learn features that are invariant to common visual transformations.

The "dissimilar" part of contrastive learning is crucial. If you only had positive pairs, the model would collapse to outputting the same embedding for everything. The act of pushing dissimilar samples apart, even if those dissimilar samples are just other random augmentations from the batch, forces the model to learn discriminative features.

The next step is often to fine-tune this pre-trained encoder on a specific downstream task, like image classification, using a small amount of labeled data.
