Creating a Message-Passing Layer
In this tutorial we go over the MessagePassing interface for creating GNN layers.
This is a very flexible class that is based on three main functions: message, aggregate and update. By overriding these methods, you can define the behaviour of your own layers.
The MessagePassing layer can be subclassed to create layers that work in single and disjoint mode using sparse adjacency matrices. This ensures that your layers will work for both node-level and graph-level learning while remaining computationally efficient.
The functionality of these layers is defined by the message, aggregate and update methods, and is summarized as follows:
x_out[i] = update(x[i], aggregate([message(x[j]) for j in neighbours(i)]))
The message function computes a transformation of the neighbours of each node. The aggregate function aggregates the messages in a way that is independent of the order in which the messages are processed (like a sum, an average, the maximum, etc.). The update function takes the aggregated messages from the neighbours and decides how to transform each node.
This message-passing scheme is computed by calling the propagate method of the class, which will return the updated node features (x_out).
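To make the scheme concrete, here is a minimal pure-Python sketch of the formula above. It is only an illustration, not Spektral's implementation: the free function propagate, the neighbours dictionary, and the scalar node features are all assumptions made for this example.

```python
def propagate(x, neighbours, message, aggregate, update):
    """Return updated features for every node (illustrative sketch only).

    x: list of node features (plain floats here, for simplicity).
    neighbours: dict mapping each node index i to the list of its neighbours j.
    """
    x_out = []
    for i in range(len(x)):
        # One message per neighbour j of node i
        messages = [message(x[j]) for j in neighbours[i]]
        # Combine the messages, then decide how to transform node i
        x_out.append(update(x[i], aggregate(messages)))
    return x_out


# Toy path graph 0 - 1 - 2 with scalar features
x = [1.0, 2.0, 3.0]
neighbours = {0: [1], 1: [0, 2], 2: [1]}

x_out = propagate(
    x,
    neighbours,
    message=lambda x_j: x_j,            # identity message
    aggregate=sum,                      # order-independent aggregation
    update=lambda x_i, agg: x_i + agg,  # add aggregated messages to the node
)
print(x_out)  # [3.0, 6.0, 5.0]
```

Swapping sum for max or an average changes the aggregation without touching the rest of the loop, which is the flexibility that the three-function design is meant to provide.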
In this example we will implement a graph convolutional network (Kipf & Welling, 2016) using the MessagePassing interface.
First, let's add some trainable parameters when creating the layer:
    import tensorflow as tf
    from spektral.layers import MessagePassing

    class GCN(MessagePassing):
        def __init__(self, n_out, activation):
            super().__init__(activation=activation)
            self.n_out = n_out

        def build(self, input_shape):
            # The layer is called on [x, a], so input_shape[0] is the shape of x
            n_in = input_shape[0][-1]
            # Named `kernel` because `weights` is a read-only property of Keras layers
            self.kernel = self.add_weight(shape=(n_in, self.n_out))

Note that the Keras keyword activation was passed to the constructor of the superclass. This can be done with any Keras keyword (like regularizers, constraints, etc.) and the layer will process them automatically.
By default, the call method of MessagePassing layers will only call propagate. We modify it so that it also transforms the node features before starting the propagation:

    def call(self, inputs):
        x, a = inputs

        # Update node features
        x = tf.matmul(x, self.kernel)

        return self.propagate(x=x, a=a)
Then, we implement the message function. The get_sources and get_targets built-in methods can be used to automatically retrieve the node attributes of nodes that are sending (sources) or receiving (targets) a message.
For instance, we can use get_sources to access the node features x[j] of all neighbors j.
If you need direct access to the edge indices, you can use the index_sources and index_targets attributes.
In this case, we only need to get the neighbors' features and return them:
    def message(self, x):
        # Get the node features of all neighbors
        return self.get_sources(x)
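Conceptually, retrieving the source features is a gather over the edge list: one row of x is picked per edge. The sketch below shows this with plain Python lists. The (target, source) edge layout and the helper name gather_sources are assumptions made for illustration, not part of Spektral's API.

```python
# Hypothetical edge list: each pair is (target, source), i.e. source -> target
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
index_targets = [t for t, _ in edges]
index_sources = [s for _, s in edges]

# One feature vector per node
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]


def gather_sources(x, index_sources):
    """Return one feature vector per edge, taken from the sending node."""
    return [x[j] for j in index_sources]


messages = gather_sources(x, index_sources)
print(messages)  # [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0], [0.0, 1.0]]
```

The result has one entry per edge rather than one per node, which is why a separate aggregation step is needed afterwards.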
Then, we define an aggregation function for the messages. We can use a simple average of the nodes:
    from spektral.layers.ops import scatter_mean

    def aggregate(self, messages):
        return scatter_mean(messages, self.index_targets, self.n_nodes)
Note that n_nodes is computed dynamically at the start of propagation, exactly like index_targets.
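As a rough picture of what a mean aggregation over target indices does, the following pure-Python sketch averages the per-edge messages that share the same target node. The function scatter_mean_sketch is a hypothetical stand-in written for this example and is not the library's scatter_mean. Giving a zero vector to nodes that receive no messages is a choice made in this sketch.

```python
def scatter_mean_sketch(messages, index_targets, n_nodes):
    """Average the messages that share the same target node (illustrative only)."""
    n_features = len(messages[0])
    sums = [[0.0] * n_features for _ in range(n_nodes)]
    counts = [0] * n_nodes
    for msg, i in zip(messages, index_targets):
        counts[i] += 1
        for k, value in enumerate(msg):
            sums[i][k] += value
    # Nodes that received no message keep a zero vector in this sketch
    return [
        [s / c if c else 0.0 for s in row]
        for row, c in zip(sums, counts)
    ]


# Four per-edge messages, delivered to nodes 0, 1, 1 and 2
messages = [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0], [0.0, 1.0]]
index_targets = [0, 1, 1, 2]

aggregated = scatter_mean_sketch(messages, index_targets, n_nodes=3)
print(aggregated)  # [[0.0, 1.0], [1.5, 1.0], [0.0, 1.0]]
```

Because the result depends only on which messages reach each node, shuffling the edges (and their indices together) leaves the output unchanged, which is the order-independence required of an aggregation function.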
Since there are a few common aggregation functions that are often used in the literature, you can also skip the implementation of this method and simply pass a special keyword to the __init__() method of the superclass:
    def __init__(self):
        # Equivalent to the above implementation of aggregate
        super().__init__(aggregate='mean')
Finally, we can use the update method to apply the activation function:
    def update(self, embeddings):
        return self.activation(embeddings)
This is enough to get started with building your own layers in Spektral.
An important feature of the MessagePassing class is that any extra keyword argument given to propagate will be compared to the signatures of message, aggregate and update, and forwarded to those functions if a match is found.
For example, we can call:
propagate(x=x, a=a, extra_tensor=extra_tensor)
and define the message function as:
    def message(self, x, extra_tensor=None):
        ...  # Do something with extra_tensor
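One way to picture this signature matching is with the standard inspect module. The sketch below is an assumption about how such forwarding could work, not a copy of Spektral's code: forward_kwargs is a hypothetical helper, and the message and update functions are simplified free functions.

```python
import inspect


def forward_kwargs(fn, kwargs):
    """Keep only the keyword arguments that appear in fn's signature."""
    accepted = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in accepted}


def message(x, extra_tensor=None):
    # Scale each feature by the extra tensor when one is provided
    if extra_tensor is None:
        return x
    return [xi * w for xi, w in zip(x, extra_tensor)]


def update(embeddings):
    return embeddings


# Everything that was passed to the (hypothetical) propagate call
kwargs = {"x": [1.0, 2.0], "a": "adjacency placeholder", "extra_tensor": [10.0, 0.5]}

message_kwargs = forward_kwargs(message, kwargs)  # keeps x and extra_tensor
update_kwargs = forward_kwargs(update, kwargs)    # nothing matches, so empty

result = message(**message_kwargs)
print(result)  # [10.0, 1.0]
```

With this kind of matching, a function only receives the arguments it declares, so adding a new keyword to propagate does not break methods that do not use it.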
Finally, we already noted that MessagePassing layers only support single and disjoint mode, and they also require that the adjacency matrix is a SparseTensor.
If you need more control over your layers, you can have a look at spektral.layers.Conv for a stripped-down class that performs no checks on the inputs and only implements some essential features like keyword parsing.
spektral.layers.GCNConv implements the same GCN layer that we just saw, using the Conv class so that it can provide support for batch and mixed mode, as well as dense adjacency matrices.