The Transformer’s self-attention mechanism allows it to weigh the importance of different input tokens regardless of their distance, fundamentally changing how sequence models process information compared to RNNs.
Let’s see this in action. Imagine we’re building a simple English-to-French translator. Our input sentence is "Hello world", and we want to translate it to "Bonjour le monde".
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Dense, MultiHeadAttention, LayerNormalization, Dropout
from tensorflow.keras.models import Model
# --- Model Parameters ---
vocab_size_en = 10000 # Size of the English (source) vocabulary
vocab_size_fr = 8000 # Size of the French (target) vocabulary
d_model = 128 # Shared width of embeddings and all sub-layer outputs
num_heads = 8 # Number of parallel attention heads (must divide d_model)
dff = 512 # Hidden width of the position-wise feed-forward network
max_seq_len = 100 # Maximum sequence length
dropout_rate = 0.1 # Dropout applied after embeddings and inside each sub-layer
# --- Positional Encoding ---
def positional_encoding(position, d_model):
    """Build a (1, position, d_model) sinusoidal positional-encoding tensor.

    The first half of the channels holds sin of the even-index angle rates
    and the second half holds cos of the odd-index rates (block layout,
    not interleaved).
    """
    angles = get_angles(
        np.arange(position)[:, np.newaxis],
        np.arange(d_model)[np.newaxis, :],
        d_model,
    )
    # sin over even columns, cos over odd columns, stacked side by side.
    encoding = np.concatenate(
        [np.sin(angles[:, 0::2]), np.cos(angles[:, 1::2])], axis=-1
    )
    # Leading batch axis so the table broadcasts over a batch of sequences.
    return tf.cast(encoding[np.newaxis, ...], dtype=tf.float32)
def get_angles(pos, i, d_model):
    """Per-(position, dimension) angles for the sinusoidal encoding.

    Implements pos / 10000^(2*floor(i/2)/d_model) from "Attention Is All
    You Need"; `pos` and `i` broadcast against each other.
    """
    inv_freq = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * inv_freq
# --- Multi-Head Attention ---
class MultiHeadAttentionLayer(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention.

    Projects queries/keys/values to d_model, splits them across num_heads
    heads of size d_model // num_heads, attends per head, then recombines
    and applies a final linear projection.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        # d_model must split evenly across the heads.
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        # Learned projections for queries, keys, values, and the output.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, depth)."""
        reshaped = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(reshaped, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend q over k/v; returns (output, attention_weights)."""
        batch_size = tf.shape(q)[0]
        # Project each input, then fan out across heads.
        queries = self.split_heads(self.wq(q), batch_size)
        keys = self.split_heads(self.wk(k), batch_size)
        values = self.split_heads(self.wv(v), batch_size)
        attended, attention_weights = scaled_dot_product_attention(
            queries, keys, values, mask)
        # Undo the head split: (batch, seq, heads, depth) -> (batch, seq, d_model).
        attended = tf.transpose(attended, perm=[0, 2, 1, 3])
        merged = tf.reshape(attended, (batch_size, -1, self.d_model))
        return self.dense(merged), attention_weights
def scaled_dot_product_attention(q, k, v, mask):
    """Core attention: softmax(q·kᵀ / sqrt(dk) + mask·(-1e9)) · v.

    `mask` (if given) holds 1.0 at positions to block and must broadcast
    against the logits. Returns (output, attention_weights).
    """
    logits = tf.matmul(q, k, transpose_b=True)
    # Scale by sqrt of the key dimension so logit magnitudes stay stable.
    logits = logits / tf.math.sqrt(tf.cast(tf.shape(k)[-1], tf.float32))
    if mask is not None:
        # Push masked positions toward -inf so softmax zeroes them out.
        logits += mask * -1e9
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, v), weights
# --- Feed Forward Network ---
def point_wise_feed_forward_network(d_model, dff):
    """Two-layer position-wise MLP: d_model -> dff (ReLU) -> d_model."""
    expand = tf.keras.layers.Dense(dff, activation='relu')
    project = tf.keras.layers.Dense(d_model)
    return tf.keras.Sequential([expand, project])
# --- Encoder Layer ---
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder block: self-attention then FFN, each wrapped in
    dropout, a residual connection, and LayerNormalization."""

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()
        self.mha = MultiHeadAttentionLayer(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        # Self-attention sub-layer with residual connection.
        attended, _ = self.mha(x, x, x, mask)
        normed = self.layernorm1(x + self.dropout1(attended, training=training))
        # Position-wise feed-forward sub-layer with residual connection.
        transformed = self.dropout2(self.ffn(normed), training=training)
        return self.layernorm2(normed + transformed)
# --- Decoder Layer ---
class DecoderLayer(tf.keras.layers.Layer):
    """One decoder block: masked self-attention, encoder-decoder
    attention, then FFN — each with dropout, residual, and LayerNorm."""

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()
        self.mha1 = MultiHeadAttentionLayer(d_model, num_heads)  # masked self-attention
        self.mha2 = MultiHeadAttentionLayer(d_model, num_heads)  # attention over encoder output
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        """Returns (output, self-attention weights, cross-attention weights)."""
        # 1) Masked self-attention over the target sequence.
        self_attended, self_weights = self.mha1(x, x, x, look_ahead_mask)
        out1 = self.layernorm1(self.dropout1(self_attended, training=training) + x)
        # 2) Attend over the encoder output (queries come from out1).
        cross_attended, cross_weights = self.mha2(
            enc_output, enc_output, out1, padding_mask)
        out2 = self.layernorm2(self.dropout2(cross_attended, training=training) + out1)
        # 3) Position-wise feed-forward.
        transformed = self.dropout3(self.ffn(out2), training=training)
        return self.layernorm3(transformed + out2), self_weights, cross_weights
# --- Encoder ---
class Encoder(tf.keras.layers.Layer):
    """Embedding + positional encoding followed by a stack of EncoderLayers."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.enc_layers = [
            EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]
        # Embed, rescale by sqrt(d_model), then inject position information.
        x = self.embedding(x) * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for layer in self.enc_layers:
            x = layer(x, training, mask)
        return x
# --- Decoder ---
class Decoder(tf.keras.layers.Layer):
    """Embedding + positional encoding followed by a stack of DecoderLayers."""

    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                 maximum_position_encoding, rate=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.dec_layers = [
            DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        # Embed, rescale by sqrt(d_model), then inject position information.
        x = self.embedding(x) * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        # Per-layer attention weights are discarded here; only features flow on.
        for layer in self.dec_layers:
            x, _, _ = layer(x, enc_output, training, look_ahead_mask, padding_mask)
        return x
# --- Transformer ---
class Transformer(tf.keras.Model):
    """Full encoder-decoder Transformer with a vocabulary projection head."""

    def __init__(self, num_encoder_layers, num_decoder_layers, d_model, num_heads, dff,
                 input_vocab_size, target_vocab_size, pe_input, pe_target, rate=0.1):
        super().__init__()
        self.encoder = Encoder(num_encoder_layers, d_model, num_heads, dff,
                               input_vocab_size, pe_input, rate)
        self.decoder = Decoder(num_decoder_layers, d_model, num_heads, dff,
                               target_vocab_size, pe_target, rate)
        # Projects decoder features to unnormalized vocabulary logits.
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        """Returns logits of shape (batch, target_seq_len, target_vocab_size)."""
        encoded = self.encoder(inp, training, enc_padding_mask)
        decoded = self.decoder(tar, encoded, training, look_ahead_mask, dec_padding_mask)
        return self.final_layer(decoded)
# --- Masking ---
def create_padding_mask(seq):
    """Mask with 1.0 wherever seq holds padding id 0, shaped
    (batch, 1, 1, seq_len) so it broadcasts over heads and query positions."""
    is_pad = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return is_pad[:, tf.newaxis, tf.newaxis, :]
def create_look_ahead_mask(size):
    """Strictly-upper-triangular (size, size) mask of 1s that blocks
    each position from attending to later positions."""
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle
def create_masks(inp, tar):
    """Build (encoder padding mask, combined decoder mask, decoder padding mask)."""
    # Encoder self-attention: hide padded source positions.
    enc_padding_mask = create_padding_mask(inp)
    # Decoder cross-attention: hide the same padded source positions
    # when attending over the encoder output.
    dec_padding_mask = create_padding_mask(inp)
    # Decoder self-attention: block future target positions AND padded
    # target positions (element-wise max merges the two masks).
    look_ahead = create_look_ahead_mask(tf.shape(tar)[1])
    target_padding = create_padding_mask(tar)
    combined_mask = tf.maximum(target_padding, look_ahead)
    return enc_padding_mask, combined_mask, dec_padding_mask
# --- Example Usage ---
# Reuse the module-level hyperparameters defined at the top of the file
# (d_model, num_heads, dff, vocab sizes, dropout_rate, max_seq_len) instead
# of repeating literal values that could silently drift out of sync.
num_encoder_layers = 4
num_decoder_layers = 4
transformer = Transformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=vocab_size_en,
    target_vocab_size=vocab_size_fr,
    pe_input=max_seq_len,
    pe_target=max_seq_len,
    rate=dropout_rate)
# Dummy data for demonstration (batch of one, padded with 0s to length 8).
encoder_input = tf.constant([[1, 2, 3, 4, 0, 0, 0, 0]], dtype=tf.int64) # "Hello world" tokenized, padded
decoder_input = tf.constant([[10, 5, 0, 0, 0, 0, 0, 0]], dtype=tf.int64) # "<start> Bonjour" tokenized, padded
enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, decoder_input)
# Forward pass (training=False disables dropout for a deterministic demo).
predictions = transformer(encoder_input, decoder_input,
                          training=False,
                          enc_padding_mask=enc_padding_mask,
                          look_ahead_mask=combined_mask,
                          dec_padding_mask=dec_padding_mask)
print("Shape of predictions:", predictions.shape)
# Expected output shape: (batch_size, target_seq_len, target_vocab_size)
# e.g., (1, 8, 8000)
The core of the Transformer is its ability to perform "self-attention." Instead of processing a sequence word-by-word like an RNN, self-attention allows each word in the input sequence to "look at" every other word in the same sequence. It calculates an "attention score" between each pair of words, indicating how relevant they are to each other. This score is then used to create a weighted sum of the word representations, effectively enriching each word’s meaning with context from the entire sequence.
Consider the MultiHeadAttentionLayer. It’s not just one attention calculation; it’s multiple "heads" running in parallel. Each head learns to focus on different aspects of the relationships between words. For example, one head might focus on grammatical dependencies, while another might focus on semantic similarities. The outputs from all heads are concatenated and linearly transformed, giving the model a richer understanding of the input.
The Encoder stack processes the input sequence. Each EncoderLayer consists of a multi-head self-attention sub-layer followed by a position-wise feed-forward network. Crucially, positional encodings are added to the input embeddings. Since the self-attention mechanism itself doesn’t inherently know the order of words, these encodings inject information about the relative or absolute position of tokens in the sequence.
The Decoder stack is similar but has an extra multi-head attention sub-layer. This layer attends to the output of the encoder, allowing the decoder to "focus" on relevant parts of the input sequence when generating the output. The decoder also uses masked self-attention to ensure that during training, it can only attend to previously generated tokens, preventing it from "cheating" by looking at future words in the target sequence. The look_ahead_mask is critical here: it adds a large negative value (−1e9 in this implementation, standing in for negative infinity) to the attention scores for future positions before the softmax, effectively zeroing out their attention weights.
The Transformer model orchestrates these components. It takes tokenized input and target sequences, applies masks to handle padding and prevent look-ahead, and passes them through the encoder and decoder stacks. The final Dense layer projects the decoder’s output into the vocabulary space, producing probability distributions over the next possible token.
The most surprising thing is how effectively the positional encoding, combined with the attention mechanism, can capture long-range dependencies without any recurrence. Unlike RNNs that struggle to remember information from far back in a sequence due to vanishing gradients, the Transformer’s attention can directly link any two words, no matter how far apart. This is achieved by calculating attention weights between every pair of tokens, making the path length between any two tokens constant (effectively 1 through the attention calculation), regardless of their distance in the sequence.
The create_masks function is key to managing the different attention mechanisms. It generates masks for padding (to ignore positions whose token id is 0, the padding id) and a look-ahead mask for the decoder’s self-attention (to prevent attending to future tokens). These masks are passed through the model to ensure correct behavior during training and inference.
The next concept to explore is the training process, specifically the loss function, optimizer, and how the model is trained autoregressively during inference.