Graph Neural Networks (GNNs) in TensorFlow can model relationships between discrete entities, not just grid-like data.

Imagine you’re trying to predict which users on a social network are likely to churn. You have a graph where nodes are users and edges represent friendships. A standard neural network struggles with this because the number of friends (the "neighborhood" of a node) varies wildly, and the order of friends shouldn’t matter. GNNs excel here by explicitly operating on graph structures.

Let’s implement a simple GNN using TensorFlow. We’ll use the tensorflow_gnn library, which is built for this purpose. First, we need to represent our graph data.

import tensorflow as tf
import tensorflow_gnn as tgn

# Define a simple graph: 3 nodes, 2 edges
# Node 0 connected to Node 1
# Node 1 connected to Node 2

# Node features: each node has a 2-dimensional feature vector
node_features = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=tf.float32)

# Create a GraphTensor, TensorFlow GNN's primary data structure.
# Node sets and edge sets are identified by name; we call ours
# "nodes" and "edges", and store the feature tensor under the
# feature name "features".
graph_tensor = tgn.GraphTensor.from_pieces(
    node_sets={
        "nodes": tgn.NodeSet.from_fields(
            sizes=tf.constant([3]),  # Number of nodes in this set
            features={"features": node_features},
        )
    },
    edge_sets={
        "edges": tgn.EdgeSet.from_fields(
            sizes=tf.constant([2]),  # Number of edges in this set
            # The adjacency names the node set each endpoint belongs to
            # and lists the node indices for every edge. Here both
            # endpoints live in the "nodes" set, and the edges run
            # 0 -> 1 and 1 -> 2.
            adjacency=tgn.Adjacency.from_indices(
                source=("nodes", tf.constant([0, 1])),
                target=("nodes", tf.constant([1, 2])),
            ),
        )
    },
)

print(graph_tensor)

This GraphTensor is our fundamental building block. It encapsulates nodes, edges, and their features. Now, let’s build a simple GNN layer. The core idea of a GNN layer is to aggregate information from a node’s neighbors and combine it with its own features.

# Define a simple GNN message-passing layer
class SimpleGNN(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # A dense layer to transform the combined node features
        self.dense = tf.keras.layers.Dense(units, activation=tf.nn.relu)

    def call(self, graph_tensor):
        # Get the node features from the "nodes" set
        node_features = graph_tensor.node_sets["nodes"]["features"]

        # 1. "Send messages": broadcast each edge's source-node features
        #    onto the edge. With edges 0 -> 1 and 1 -> 2, this gathers
        #    the features of nodes 0 and 1, one row per edge.
        messages = tgn.broadcast_node_to_edges(
            graph_tensor, "edges", tgn.SOURCE, feature_name="features"
        )

        # 2. Aggregate the messages at each edge's target node with a sum.
        #    Node 0 has no incoming edges, so it receives zeros; node 1
        #    receives node 0's features; node 2 receives node 1's.
        aggregated_neighbor_features = tgn.pool_edges_to_node(
            graph_tensor, "edges", tgn.TARGET, reduce_type="sum",
            feature_value=messages,
        )

        # 3. Combine the aggregated neighbor features with each node's own
        #    features, then transform them.
        combined_features = tf.concat(
            [node_features, aggregated_neighbor_features], axis=-1
        )
        return self.dense(combined_features)

# Instantiate the GNN layer
gnn_layer = SimpleGNN(units=8)

# Call the layer with our graph tensor
output_node_features = gnn_layer(graph_tensor)

print("\nOutput node features after one GNN layer:")
print(output_node_features)

In this SimpleGNN layer, tgn.broadcast_node_to_edges and tgn.pool_edges_to_node do the heavy lifting. The broadcast step copies each edge's source-node "features" onto the edge, which is the message-sending half of message passing; the pool step then aggregates those per-edge messages at each target node, here with a sum. Under the hood this is the same gather-then-segment-sum pattern you could write by hand with tf.gather and tf.math.unsorted_segment_sum.

These two primitives are a powerful shortcut. You name the edge set that connects senders to receivers ("edges"), the endpoint that supplies or receives the values (tgn.SOURCE or tgn.TARGET), and the reduction to apply ("sum", "mean", "max"); TensorFlow GNN handles the indexing.
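For intuition, here is that gather-and-aggregate step written out with plain TensorFlow ops, using the same toy graph (edges 0 -> 1 and 1 -> 2). This is a sketch of the mechanics, not the library's internal implementation:

```python
import tensorflow as tf

node_features = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
source = tf.constant([0, 1])  # edge sources: 0 -> 1, 1 -> 2
target = tf.constant([1, 2])  # edge targets

# "Send": gather each edge's source-node features onto the edge
messages = tf.gather(node_features, source)

# "Aggregate": sum the messages arriving at each target node
aggregated = tf.math.unsorted_segment_sum(
    messages, target, num_segments=tf.shape(node_features)[0]
)
# Node 0 receives nothing (zeros); node 1 gets node 0's features;
# node 2 gets node 1's features.
```

Note that unsorted_segment_sum fills nodes with no incoming edges with zeros, which matches the sum-aggregation convention used above.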

To build a deeper GNN, you would stack these layers. Each layer would refine the node representations by incorporating information from increasingly distant neighbors.

# Example of stacking GNN layers
class GNNModel(tf.keras.Model):
    def __init__(self, hidden_units, output_units):
        super().__init__()
        self.gnn1 = SimpleGNN(units=hidden_units)
        self.gnn2 = SimpleGNN(units=hidden_units)
        # A final layer to produce output predictions (e.g., for classification)
        self.output_layer = tf.keras.layers.Dense(output_units)

    def call(self, graph_tensor):
        # Each GNN layer returns a plain tensor of updated node features.
        # To feed them into the next layer, write them back into the graph:
        # replace_features returns a new GraphTensor with the same structure
        # (sizes, adjacency) but the new feature values.
        x = self.gnn1(graph_tensor)
        graph_tensor = graph_tensor.replace_features(
            node_sets={"nodes": {"features": x}}
        )

        x = self.gnn2(graph_tensor)

        # For node-level prediction, use the final node features directly.
        # (For a graph-level prediction you would first pool them into a
        # single vector per graph.)
        return self.output_layer(x)

# Instantiate and train the model (training loop omitted for brevity)
model = GNNModel(hidden_units=16, output_units=1) # Example: outputting a single value per node

# To actually train, you'd need labels and a loss function.
# For node classification:
# `model(graph_tensor)` would output predictions for each node.
# You'd then compute a loss against true labels for each node.

# To make predictions on the graph:
predictions = model(graph_tensor)
print("\nPredictions for each node:")
print(predictions)
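To make the omitted training loop concrete, here is one minimal supervised step with tf.GradientTape. It uses a stand-in Dense model over raw node features and hypothetical per-node labels so the sketch is self-contained; in practice you would call your GNN model on the graph_tensor instead:

```python
import tensorflow as tf

# Stand-in for the GNN: any model mapping node features to per-node logits
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

node_features = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
node_labels = tf.constant([[0.0], [1.0], [1.0]])  # hypothetical churn labels

loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# One training step: forward pass, loss, gradients, parameter update
with tf.GradientTape() as tape:
    logits = model(node_features)
    loss = loss_fn(node_labels, logits)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```

Wrapping this step in a loop over epochs (and over batches of graphs) gives the full training procedure.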

The core mental model for GNNs is "message passing." Each node sends a "message" (its features, transformed) to its neighbors. These messages are then aggregated at the receiving nodes, and combined with the node’s own current features to produce its new representation. This process repeats across multiple layers, allowing information to propagate across the graph.

A crucial detail often overlooked is how tensorflow_gnn handles different types of graphs and features. You can have multiple node sets (e.g., "users" and "items") and multiple edge sets (e.g., "user_friends," "user_buys_item"). The GraphTensor structure is designed to accommodate this complexity, and the message passing mechanisms can be configured to operate between specific node and edge sets. For instance, you could have a GNN layer that only propagates messages from "user" nodes to "item" nodes via "user_buys_item" edges.
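As an illustration of that heterogeneous case, here is message passing from "users" to "items" over hypothetical "user_buys_item" edges, sketched with plain TensorFlow ops rather than the GraphTensor API (the data values are made up):

```python
import tensorflow as tf

user_features = tf.constant([[1.0], [2.0], [3.0]])  # 3 user nodes
item_features = tf.constant([[0.0], [0.0]])         # 2 item nodes

# user_buys_item edges as (user index, item index) pairs:
# user 0 buys item 0, user 1 buys item 0, user 2 buys item 1
buy_src_users = tf.constant([0, 1, 2])
buy_dst_items = tf.constant([0, 0, 1])

# Messages flow users -> items: gather from the user feature table...
messages = tf.gather(user_features, buy_src_users)

# ...and aggregate into the item feature table
item_updates = tf.math.unsorted_segment_sum(
    messages, buy_dst_items, num_segments=tf.shape(item_features)[0]
)
# Item 0 aggregates users 0 and 1 (1.0 + 2.0); item 1 aggregates user 2.
```

The point is that nothing forces senders and receivers to be the same node set; each edge set simply names which table it gathers from and which it pools into.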

The next step is to explore how to handle different aggregation functions (mean, max) and how to perform graph-level predictions by pooling node features.
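As a preview, both ideas can be sketched with the same segment ops on the toy graph from earlier (this is a plain-TensorFlow illustration; in TF-GNN the pooling ops accept reduce_type values such as "mean" and "max" for the same effect):

```python
import tensorflow as tf

node_features = tf.constant([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
source = tf.constant([0, 1])  # edges 0 -> 1 and 1 -> 2
target = tf.constant([1, 2])
num_nodes = tf.shape(node_features)[0]

messages = tf.gather(node_features, source)

# Mean aggregation: average of incoming messages per node
mean_agg = tf.math.unsorted_segment_mean(messages, target, num_nodes)

# Max aggregation: elementwise max per node. Beware: nodes with no
# incoming edges get the dtype's lowest value, not zero.
max_agg = tf.math.unsorted_segment_max(messages, target, num_nodes)

# Graph-level readout: pool every node's features into one graph vector
graph_embedding = tf.reduce_mean(node_features, axis=0)
```

Mean aggregation makes node updates insensitive to degree, max picks out the strongest incoming signal, and the readout turns per-node representations into a single vector for graph-level prediction.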

Want structured learning?

Take the full TensorFlow course →