TensorFlow custom layers and training loops are often seen as advanced topics, but they’re really just about giving the framework explicit instructions for operations it doesn’t have built-in.
Let’s see a simple custom layer in action. Imagine we want a layer that just multiplies its input by a learned scalar weight.
import tensorflow as tf
class ScalarMultiply(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(ScalarMultiply, self).__init__(**kwargs)

    def build(self, input_shape):
        self.scalar = self.add_weight(
            shape=(),
            initializer='random_normal',
            trainable=True,
            name='scalar_weight'
        )
        super(ScalarMultiply, self).build(input_shape)

    def call(self, inputs):
        return inputs * self.scalar

    def get_config(self):
        config = super(ScalarMultiply, self).get_config()
        return config
# Instantiate and test
layer = ScalarMultiply()
input_data = tf.constant([1.0, 2.0, 3.0])
output_data = layer(input_data)
print(output_data.numpy())
print(layer.get_weights())
This code defines ScalarMultiply, a layer that has one trainable weight, self.scalar. The build method is where weights are typically created. add_weight is the key function here, defining the shape (scalar, hence ()), initializer, and whether it’s trainable. The call method defines the forward pass: it takes the input and multiplies it by the learned scalar.
Now, let’s integrate this into a custom training loop. Standard model.fit abstracts away a lot of detail, but a custom loop gives us fine-grained control. We’ll need to define our model, optimizer, loss function, and then manually iterate through the data, compute gradients, and update weights.
# Assume ScalarMultiply layer is defined above
# 1. Define the model
inputs = tf.keras.Input(shape=(10,))
x = tf.keras.layers.Dense(20, activation='relu')(inputs)
outputs = ScalarMultiply()(x) # Use our custom layer
model = tf.keras.Model(inputs=inputs, outputs=outputs)
# 2. Set up the optimizer and loss function manually (no model.compile needed)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_fn = tf.keras.losses.MeanSquaredError()
# Dummy data
X_train = tf.random.normal((100, 10))
y_train = tf.random.normal((100, 20)) # Target shape matches output of our layer
# 3. Custom training loop
epochs = 5
batch_size = 32
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    epoch_loss_avg = tf.keras.metrics.Mean()
    # Iterate over batches
    for i in range(0, X_train.shape[0], batch_size):
        x_batch = X_train[i:i+batch_size]
        y_batch = y_train[i:i+batch_size]
        with tf.GradientTape() as tape:
            predictions = model(x_batch, training=True)  # Forward pass
            loss = loss_fn(y_batch, predictions)  # Compute loss
        # Compute gradients
        gradients = tape.gradient(loss, model.trainable_variables)
        # Apply gradients
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Track loss
        epoch_loss_avg.update_state(loss)
    print(f"  Loss: {epoch_loss_avg.result().numpy():.4f}")
# Check the trained scalar weight
scalar_layer = model.layers[-1] # Our ScalarMultiply layer is the last one
print(f"Trained scalar weight: {scalar_layer.scalar.numpy()}")
The core of the custom training loop is the tf.GradientTape. It records operations performed during the forward pass. When tape.gradient() is called, it computes the gradients of the loss with respect to the model.trainable_variables. The optimizer.apply_gradients() then uses these gradients to update the model’s weights. This manual process is what model.fit automates.
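To see the tape mechanism in isolation, here is a minimal sketch stripped of any model: a single variable, one recorded operation, and one gradient query.

```python
import tensorflow as tf

# GradientTape records operations on watched variables during the forward pass
x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = x * x  # recorded: y = x^2

# dy/dx = 2x, evaluated at x = 3.0
grad = tape.gradient(y, x)
print(grad.numpy())  # 6.0
```

Trainable `tf.Variable`s are watched automatically; plain tensors would need an explicit `tape.watch()` call.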
The build method is called the first time the layer is used with an input. This is crucial because the layer might not know the shape of its inputs until then. add_weight is the standard way to define trainable parameters within a layer. If you have weights that shouldn’t be trained, you’d set trainable=False.
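As a sketch of the non-trainable case, here is a hypothetical `ScaleAndShift` layer with one trainable and one frozen weight. Only the trainable one shows up in `trainable_weights`, so `tape.gradient` and the optimizer leave the frozen weight untouched.

```python
import tensorflow as tf

class ScaleAndShift(tf.keras.layers.Layer):
    # Hypothetical layer: a trainable scale and a frozen (non-trainable) shift
    def build(self, input_shape):
        self.scale = self.add_weight(
            shape=(), initializer='ones', trainable=True, name='scale')
        self.shift = self.add_weight(
            shape=(), initializer='zeros', trainable=False, name='shift')

    def call(self, inputs):
        return inputs * self.scale + self.shift

layer = ScaleAndShift()
layer(tf.constant([1.0, 2.0]))  # first call triggers build()
print(len(layer.trainable_weights))      # 1 (scale)
print(len(layer.non_trainable_weights))  # 1 (shift)
```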
The get_config method is important for saving and loading custom models. It allows TensorFlow to reconstruct the layer’s configuration. For a simple layer like ScalarMultiply with no specific initialization parameters, it just returns the base config. For layers with arguments in their __init__ method, you’d add those to the config dictionary.
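To illustrate the case with constructor arguments, here is a hypothetical `PowerLayer` whose `__init__` takes a `power` argument. `get_config` records it so the layer can be rebuilt from its config:

```python
import tensorflow as tf

class PowerLayer(tf.keras.layers.Layer):
    # Hypothetical layer: raises its inputs to a configurable power
    def __init__(self, power=2, **kwargs):
        super().__init__(**kwargs)
        self.power = power

    def call(self, inputs):
        return inputs ** self.power

    def get_config(self):
        config = super().get_config()
        config.update({'power': self.power})  # record the __init__ argument
        return config

layer = PowerLayer(power=3)
restored = PowerLayer.from_config(layer.get_config())
print(restored.power)  # 3
```

Without the `config.update` line, a reloaded model would silently fall back to the default `power=2`.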
The most surprising thing about custom training loops is how much explicit control they offer over the optimization process, allowing you to implement things like custom learning rate schedules on a per-batch basis or even complex reinforcement learning updates that don’t fit the standard gradient descent paradigm.
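As a minimal sketch of that per-batch control (with a hypothetical exponential decay schedule), you can reassign the optimizer's learning rate inside the loop before each update:

```python
import tensorflow as tf

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
initial_lr = 0.1
decay = 0.5  # hypothetical per-step decay factor

for step in range(3):
    # Set the learning rate for this batch before applying gradients
    optimizer.learning_rate = initial_lr * (decay ** step)
    # ... tape.gradient(...) and optimizer.apply_gradients(...) would go here ...

print(float(optimizer.learning_rate))  # 0.1 * 0.5**2 = 0.025
```

`model.fit` can only do this through callbacks or `LearningRateSchedule` objects; in a custom loop it is a plain assignment.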
What most people don’t realize is that model.compile is essentially a convenience wrapper that sets up an optimizer, a loss function, and metrics. When you use a custom loop, you’re just doing those setup steps yourself and then managing the actual gradient computation and application.
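A quick sketch makes the equivalence concrete: after `compile`, the optimizer and loss are just stored on the model, ready for `fit` to use, the same objects a custom loop creates by hand.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

# compile resolved the string 'adam' into a real optimizer instance
print(type(model.optimizer).__name__)  # Adam
```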
The next thing you’ll likely want to explore is how to handle custom metrics within this loop, or perhaps how to implement more complex layer interactions like recurrent connections or attention mechanisms.