BERT, when fine-tuned, isn’t just learning a task; it’s adapting its entire internal model of language to a new input distribution and a new set of desired outputs.
Let’s see this in action. Imagine we have a small dataset of customer reviews, and we want to classify them as positive or negative. We’ll start with a pre-trained BERT model (like bert-base-uncased) and add a classification layer on top.
import tensorflow as tf
from transformers import TFBertModel, BertTokenizer
# Load pre-trained model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = TFBertModel.from_pretrained(model_name)
# Define input shape for BERT
input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='input_ids')
attention_mask = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name='attention_mask')
# Pass inputs through BERT model
outputs = bert_model(input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output # Use the pooled output for classification
# Add a classification layer
num_classes = 2
dropout = tf.keras.layers.Dropout(0.1)(pooled_output)
logits = tf.keras.layers.Dense(num_classes, name='outputs')(dropout) # raw logits; softmax is folded into the loss
# Create the fine-tuned model
model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=logits)
# Compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True) # expects raw logits; more numerically stable than softmax + default loss
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
model.summary()
This sets up the architecture: the TFBertModel handles the heavy lifting of understanding language, and we’ve attached a simple Dense layer to map BERT’s rich internal representations to our specific classification problem. The pooled_output is BERT’s summary of the entire input sequence: the [CLS] token’s final hidden state passed through a dense layer with tanh activation, which is designed to be a good sequence-level representation for classification tasks.
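To make the pooling step concrete, here is a minimal numpy sketch of what BERT’s pooler computes: it takes the final hidden state of the first ([CLS]) token and applies a dense layer with tanh. The shapes below are shrunk for readability (bert-base actually uses hidden_size=768), and the weight values are random stand-ins, not real pre-trained weights.

```python
import numpy as np

# Illustrative shapes only; bert-base-uncased uses hidden_size=768.
batch_size, seq_len, hidden_size = 2, 8, 4
rng = np.random.default_rng(0)

last_hidden_state = rng.normal(size=(batch_size, seq_len, hidden_size))
W = rng.normal(size=(hidden_size, hidden_size))  # pooler dense weights (random stand-ins)
b = np.zeros(hidden_size)

cls_hidden = last_hidden_state[:, 0, :]      # final hidden state of the [CLS] token
pooler_output = np.tanh(cls_hidden @ W + b)  # dense layer + tanh, as in BERT's pooler

print(pooler_output.shape)  # (2, 4): one summary vector per sequence
```

Because of the tanh, every component of pooler_output lies in (-1, 1); the classification head then maps this single vector per example to class logits.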
Now, let’s prepare our data. Suppose we have texts = ["This movie was amazing!", "Terrible acting, waste of time."] and labels = [[0, 1], [1, 0]] (one-hot encoded for positive/negative).
# Example data (defined in the prose above)
texts = ["This movie was amazing!", "Terrible acting, waste of time."]
labels = [[0, 1], [1, 0]] # one-hot: index 1 = positive
# Tokenize and encode the texts
max_len = 128 # Maximum sequence length
encoded_inputs = tokenizer(texts, padding='max_length', truncation=True, max_length=max_len, return_tensors='tf')
# Prepare TensorFlow dataset
dataset = tf.data.Dataset.from_tensor_slices((
{'input_ids': encoded_inputs['input_ids'], 'attention_mask': encoded_inputs['attention_mask']},
labels
))
dataset = dataset.shuffle(buffer_size=1024).batch(2) # Batch size of 2
# Train the model (example with dummy data)
# In a real scenario, you'd use your actual training data and epochs
# model.fit(dataset, epochs=3)
The tokenizer converts our raw text into numerical IDs that BERT understands, padding shorter sequences and truncating longer ones to a fixed max_len. The attention_mask is crucial: it marks which positions hold real tokens and which are padding, so the model doesn’t attend to the padded positions. We then create a tf.data.Dataset for efficient training.
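The padding/mask relationship is easy to see in isolation. The sketch below re-implements just that step in plain Python (a toy stand-in for what the real BertTokenizer does internally; the ids mirror bert-base-uncased’s conventions of [CLS]=101, [SEP]=102, [PAD]=0):

```python
PAD_ID = 0  # bert-base-uncased's [PAD] token id

def pad_batch(sequences, max_len):
    """Pad each token-id sequence to max_len and build the matching attention mask."""
    input_ids, attention_mask = [], []
    for seq in sequences:
        seq = seq[:max_len]                      # truncate sequences that are too long
        n_pad = max_len - len(seq)
        input_ids.append(seq + [PAD_ID] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = real token, 0 = padding
    return input_ids, attention_mask

# Two toy sequences: [CLS] ... [SEP], of different lengths
ids, mask = pad_batch([[101, 2023, 3185, 102], [101, 6659, 102]], max_len=6)
print(ids)   # [[101, 2023, 3185, 102, 0, 0], [101, 6659, 102, 0, 0, 0]]
print(mask)  # [[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]]
```

Every 0 in the mask lines up with a PAD_ID in the ids, which is exactly what lets BERT ignore those positions during attention.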
The core of fine-tuning BERT lies in its pre-trained weights. BERT has already learned a sophisticated understanding of grammar, syntax, and semantics from massive text corpora like Wikipedia and BookCorpus. When you fine-tune, you’re not teaching it language from scratch. Instead, you’re gently nudging these learned representations to become more specialized for your particular task and domain. The gradients flowing back during training adjust the weights of the BERT layers (and the new classification layer) to better map the input text to the desired output labels. The learning_rate=5e-5 is typically small for fine-tuning because we want to preserve most of the pre-trained knowledge.
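A back-of-the-envelope way to see why the learning rate matters: a single SGD step moves each weight by lr times its gradient. The toy numbers below are made up purely to show the scale of the effect, comparing the fine-tuning rate 5e-5 against a from-scratch-style rate of 1e-2 on the same gradient:

```python
import numpy as np

# One SGD step: w_new = w - lr * grad. Values are illustrative stand-ins,
# not real BERT weights or gradients.
w = np.full(4, 0.5)     # stand-in for a pre-trained weight vector
grad = np.full(4, 1.0)  # stand-in gradient from the downstream task

for lr in (5e-5, 1e-2):
    w_new = w - lr * grad
    rel_change = np.abs(w_new - w).max() / np.abs(w).max()
    print(f"lr={lr:g}: max relative weight change = {rel_change:.2%}")
```

At 5e-5 the pre-trained weights move by a hundredth of a percent per step, gently specializing them; at 1e-2 the same gradient shifts them two hundred times as far, which risks wiping out the pre-trained knowledge (sometimes called catastrophic forgetting).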
The most surprising thing about fine-tuning BERT is how little data you often need. Because BERT has already learned general language features, it can often achieve high performance on downstream tasks with just a few hundred or a few thousand labeled examples, whereas training a model from scratch would require orders of magnitude more. This is because the pre-trained weights act as an incredibly powerful initialization, providing a strong inductive bias for language understanding.
Consider the outputs.pooler_output versus outputs.last_hidden_state. While pooler_output is a common choice for classification due to its intended purpose as a sequence summary, last_hidden_state provides the hidden states for each token in the input sequence. For tasks like Named Entity Recognition (NER) or Question Answering, where you need token-level predictions, you would use last_hidden_state and often a different head layer (e.g., a token classification layer with a Dense layer applied to each token’s representation). The choice of what to take from BERT’s output is directly tied to the granularity of your task.
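The difference in granularity is clearest as a shape exercise. This numpy sketch (illustrative small shapes and random weights, not bert-base’s 768-dimensional states) shows a token-classification head applying one shared Dense weight matrix at every position, versus a sequence-classification head mapping a single summary vector:

```python
import numpy as np

# Illustrative shapes; bert-base uses hidden_size=768.
batch_size, seq_len, hidden_size, num_labels = 2, 8, 4, 3
rng = np.random.default_rng(1)

last_hidden_state = rng.normal(size=(batch_size, seq_len, hidden_size))

# Token-level head (e.g. NER): the same Dense weights applied at every position.
W_token = rng.normal(size=(hidden_size, num_labels))
token_logits = last_hidden_state @ W_token
print(token_logits.shape)  # (2, 8, 3): one label distribution per token

# Sequence-level head (e.g. sentiment): one summary vector per example.
cls_vector = last_hidden_state[:, 0, :]  # e.g. the [CLS] representation
W_seq = rng.normal(size=(hidden_size, num_labels))
seq_logits = cls_vector @ W_seq
print(seq_logits.shape)    # (2, 3): one label distribution per sequence
```

The per-token head keeps the seq_len dimension in its output, which is precisely what NER or extractive QA needs; the per-sequence head collapses it, which is what our sentiment classifier above does.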
Once you’ve mastered fine-tuning for classification, you’ll likely explore adapting BERT for sequence labeling tasks, which involves different output heads and a slightly different approach to data preparation.