TFRecord is a file format that TensorFlow uses to store data efficiently for large-scale machine learning training.

Let’s see TFRecord in action with a simple example. Imagine we have image data and corresponding labels we want to feed into a TensorFlow model.

First, we need to create a TFRecord file. This involves serializing our data into tf.train.Example protocol buffers, which are then written to a .tfrecord file.

import os

import numpy as np  # needed below for the uint8 conversion
import tensorflow as tf

# For demonstration, create dummy data: random float images with values in [0, 255)
# and integer labels in [0, 10). Real image data would be loaded from disk instead.
num_samples = 100
image_height, image_width, channels = 32, 32, 3
dummy_images = tf.random.uniform(shape=[num_samples, image_height, image_width, channels], minval=0, maxval=255, dtype=tf.float32)
dummy_labels = tf.random.uniform(shape=[num_samples], minval=0, maxval=10, dtype=tf.int32)

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # Bytes out of tf.string
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_list_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_list_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def serialize_example(image_data, label):
  """Creates a serialized tf.train.Example message ready to be written to a TFRecord file."""
  # Convert image to bytes
  image_bytes = image_data.numpy().astype(np.uint8).tobytes()

  feature = {
      'image': _bytes_feature(image_bytes),
      'label': _int64_list_feature([label]),
  }

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

# Write to TFRecord file
output_dir = './tfrecord_data'
os.makedirs(output_dir, exist_ok=True)
output_filename = os.path.join(output_dir, 'my_dataset.tfrecord')

with tf.io.TFRecordWriter(output_filename) as writer:
  for i in range(num_samples):
    # Reuse serialize_example instead of rebuilding the feature dict here.
    # int() converts the NumPy scalar label into a plain Python int.
    label = int(dummy_labels[i].numpy())
    writer.write(serialize_example(dummy_images[i], label))

print(f"TFRecord file created at: {output_filename}")
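
Before building a full input pipeline, it can help to check what actually landed in the file. The following is a minimal sketch, not part of the original example: it writes one record to a scratch file (the path and the label value 7 are assumptions made for illustration), reads the raw record back, and decodes it with tf.train.Example.FromString, which needs no feature description.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical scratch file so this sketch does not depend on earlier code.
path = os.path.join(tempfile.mkdtemp(), "inspect_demo.tfrecord")

# One all-zero 32x32x3 uint8 image and an arbitrary label (7), for illustration only.
image = np.zeros((32, 32, 3), dtype=np.uint8)
example = tf.train.Example(features=tf.train.Features(feature={
    "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[7])),
}))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(example.SerializeToString())

# Read the first raw record (a scalar string tensor) and decode the protobuf.
raw_record = next(iter(tf.data.TFRecordDataset(path)))
decoded = tf.train.Example.FromString(raw_record.numpy())

label_value = decoded.features.feature["label"].int64_list.value[0]
image_num_bytes = len(decoded.features.feature["image"].bytes_list.value[0])
print(label_value, image_num_bytes)  # 7 3072 (32 * 32 * 3 bytes)
```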

The core problem TFRecord solves is efficient I/O for massive datasets. Reading individual files (like JPEGs, PNGs, or CSVs) for millions of training examples can become a significant bottleneck, especially when dealing with distributed training. TFRecord consolidates data into a single (or a few) large binary files, reducing the overhead of file system operations and allowing TensorFlow to read data in large, contiguous chunks. This makes it ideal for scenarios where your dataset size dwarfs your available RAM.

Once the data is in TFRecord format, we can read it back into a TensorFlow Dataset object. This is where the real power for large-scale training comes into play.

import os  # used for os.path.join below; output_dir comes from the writing step above

import tensorflow as tf

def parse_tfrecord_fn(example_proto):
  """Parses a single tf.train.Example proto."""
  feature_description = {
      'image': tf.io.FixedLenFeature([], tf.string),
      'label': tf.io.FixedLenFeature([], tf.int64),
  }
  example = tf.io.parse_single_example(example_proto, feature_description)
  
  # Decode image
  image = tf.io.decode_raw(example['image'], tf.uint8)
  image = tf.reshape(image, [32, 32, 3]) # Reshape to original image dimensions
  image = tf.cast(image, tf.float32) # Cast to float for model input
  
  label = tf.cast(example['label'], tf.int32)
  
  return image, label

# Create a dataset from the TFRecord file
tfrecord_files = [os.path.join(output_dir, 'my_dataset.tfrecord')]
raw_dataset = tf.data.TFRecordDataset(tfrecord_files)

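# Parse the dataset, letting tf.data choose how many examples to parse in parallel
parsed_dataset = raw_dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)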

# Now you can use this parsed_dataset for training
# For example, batching and prefetching
batch_size = 32
buffer_size = 1000 # For shuffling

final_dataset = parsed_dataset.cache() # Cache data if it fits in memory, or skip for very large datasets
final_dataset = final_dataset.shuffle(buffer_size)
final_dataset = final_dataset.batch(batch_size)
final_dataset = final_dataset.prefetch(tf.data.AUTOTUNE) # Prefetch for performance

# Iterate through a few batches to see it working
print("\nIterating through a few batches:")
for images, labels in final_dataset.take(2):
  print(f"  Image batch shape: {images.shape}")
  print(f"  Label batch shape: {labels.shape}")

tf.data.TFRecordDataset is built for this kind of streaming read. Given a list of files, it can read several of them in parallel when you set num_parallel_reads, and the file list can be split across workers with Dataset.shard() for distributed training. The .map() operation applies our parsing function to each example, while .shuffle(), .batch(), and .prefetch() turn the parsed examples into a training-ready stream. .prefetch(tf.data.AUTOTUNE) lets the data pipeline prepare the next batch in the background while the model processes the current one, hiding much of the data loading latency.
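
As a rough illustration of parallel reads and per-worker file sharding, here is a small sketch. The shard file names, the two-worker setup, and the parse_label helper are assumptions made for the example, not something the walkthrough above defines.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical layout: 4 shard files with 5 records each, in a scratch directory.
shard_dir = tempfile.mkdtemp()
num_files, records_per_file = 4, 5

for shard in range(num_files):
    shard_path = os.path.join(shard_dir, f"demo-{shard:05d}-of-{num_files:05d}.tfrecord")
    with tf.io.TFRecordWriter(shard_path) as writer:
        for j in range(records_per_file):
            label = shard * records_per_file + j  # unique label per record
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            }))
            writer.write(example.SerializeToString())

files = sorted(tf.io.gfile.glob(os.path.join(shard_dir, "demo-*.tfrecord")))

# Read all files, up to 4 at a time.
all_records = tf.data.TFRecordDataset(files, num_parallel_reads=4)
total_records = sum(1 for _ in all_records)

def parse_label(serialized):
    """Hypothetical parser that extracts only the label feature."""
    parsed = tf.io.parse_single_example(
        serialized, {"label": tf.io.FixedLenFeature([], tf.int64)})
    return parsed["label"]

# Worker 0 of 2 keeps every second file (files 0 and 2), then reads them interleaved.
worker_files = tf.data.Dataset.from_tensor_slices(files).shard(num_shards=2, index=0)
worker_dataset = worker_files.interleave(
    tf.data.TFRecordDataset, cycle_length=2, num_parallel_calls=tf.data.AUTOTUNE)
worker_labels = sorted(int(x) for x in worker_dataset.map(parse_label).as_numpy_iterator())

print(total_records)  # 20
print(worker_labels)  # [0, 1, 2, 3, 4, 10, 11, 12, 13, 14]
```

Sharding the file list rather than the individual records means each worker opens only its own files, which is usually the cheaper option when the number of files is at least the number of workers.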

The mental model for TFRecord in large-scale training is a tiered pipeline: raw data -> TFRecord files -> tf.data.TFRecordDataset -> parsed tf.data.Dataset -> preprocessed & batched tf.data.Dataset -> model training. Each step is designed to be as efficient as possible, with TFRecord being the foundational layer for efficient storage and retrieval of massive data.

When the dataset is too large to fit in memory even for caching, skip the .cache() operation. Instead, rely on TFRecordDataset to stream records from disk and on prefetching to keep the GPU fed. Streaming through the input pipeline, rather than loading everything up front, is what makes it possible to train on terabytes of data.
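
One way to make caching optional is to wrap the pipeline in a small helper. The sketch below is only an illustration under stated assumptions: make_training_dataset, its use_cache flag, and the 10-record scratch file are hypothetical names and data, while the parsing logic mirrors the 32x32x3 uint8 layout used earlier.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def parse_example(serialized):
    """Parses one serialized example holding a 32x32x3 uint8 image and a label."""
    parsed = tf.io.parse_single_example(serialized, {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    })
    image = tf.reshape(tf.io.decode_raw(parsed["image"], tf.uint8), [32, 32, 3])
    return tf.cast(image, tf.float32), tf.cast(parsed["label"], tf.int32)

def make_training_dataset(filenames, batch_size=32, shuffle_buffer=1000, use_cache=False):
    """Hypothetical helper: streams TFRecords from disk, caching only on request."""
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    if use_cache:
        dataset = dataset.cache()  # only sensible when the parsed data fits in memory
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)

# Tiny scratch file (10 random images) so the sketch runs on its own.
path = os.path.join(tempfile.mkdtemp(), "stream_demo.tfrecord")
rng = np.random.default_rng(0)
with tf.io.TFRecordWriter(path) as writer:
    for i in range(10):
        image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
        example = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[i % 10])),
        }))
        writer.write(example.SerializeToString())

# Stream without caching: 10 records in batches of 4 gives batches of 4, 4, and 2.
streamed = make_training_dataset([path], batch_size=4, shuffle_buffer=10, use_cache=False)
batch_shapes = [images.shape.as_list() for images, _ in streamed]
print(batch_shapes)  # [[4, 32, 32, 3], [4, 32, 32, 3], [2, 32, 32, 3]]
```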

The next step you’ll encounter is managing multiple TFRecord files for sharded datasets in distributed training.
