Creating a Message-Passing Layer
In this tutorial we go over the MessagePassing interface for creating GNN layers.
This is a very flexible class based on three main functions: message, aggregate and update. By overriding these methods, you can define the behaviour of your own layers.
Essential information
The MessagePassing layer can be subclassed to create layers that work in single and disjoint mode using sparse adjacency matrices. This ensures that your layers will work for both node-level and graph-level learning while being computationally efficient.
The functionality of these layers is defined by the message, aggregate and update methods, and is summarized as follows:
x_out[i] = update(x[i], aggregate([message(x[j]) for j in neighbours(i)]))
The message function computes a transformation of the neighbours of each node. The aggregate function combines the messages in a way that is independent of the order in which they are processed (such as a sum, an average, or a maximum). The update function takes the aggregated messages of the neighbours and decides how to transform each node.
This message-passing scheme is computed by calling the propagate method of the class, which returns the updated node features (x_out).
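To make the scheme concrete, here is a minimal skeleton of a subclass (a sketch of our own; MyLayer is a placeholder name, and index_targets and n_nodes are attributes set up by propagate, as discussed further below):
from spektral.layers import MessagePassing
from spektral.layers.ops import scatter_sum

class MyLayer(MessagePassing):
    def message(self, x):
        # x[j]: features of the nodes sending a message
        return self.get_sources(x)

    def aggregate(self, messages):
        # Permutation-invariant reduction over each neighbourhood
        return scatter_sum(messages, self.index_targets, self.n_nodes)

    def update(self, embeddings):
        # Decide how each node transforms its aggregated messages
        return embeddings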
Example
In this example we will implement a graph convolutional network (Kipf & Welling, 2016) using the MessagePassing interface.
First, let's add some trainable parameters when creating the layer:
class GCN(MessagePassing):
    def __init__(self, n_out, activation):
        super().__init__(activation=activation)
        self.n_out = n_out

    def build(self, input_shape):
        n_in = input_shape[0][-1]
        # We call the weight `kernel` because `self.weights` is a
        # read-only property of Keras layers and cannot be assigned to
        self.kernel = self.add_weight(shape=(n_in, self.n_out))
Note that the Keras keyword activation was passed to the constructor of the superclass. This can be done with any Keras keyword (like regularizers, constraints, etc.) and the layer will process them automatically.
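For instance, with the constructor above we can pass the activation as a string and let Keras deserialize it into a callable (a usage sketch; the output size of 32 is an arbitrary choice):
# `activation` is deserialized by the superclass and
# becomes available as self.activation
layer = GCN(32, activation='relu')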
By default, the call method of MessagePassing layers only calls propagate. We modify it so that it also transforms the node features before starting the propagation:
def call(self, inputs):
    x, a = inputs
    # Transform the node features with the trainable weights
    x = tf.matmul(x, self.kernel)
    return self.propagate(x=x, a=a)
Then, we implement the message function. The built-in get_sources and get_targets methods can be used to automatically retrieve the attributes of the nodes that are sending (sources) or receiving (targets) a message.
For instance, we can use get_sources to access the node features x[j] of all neighbours j.
If you need direct access to the edge indices, you can use the index_sources and index_targets attributes.
In this case, we only need to get the neighbours' features and return them:
def message(self, x):
    # Get the node features of all neighbours
    return self.get_sources(x)
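Conceptually, get_sources amounts to gathering the node features along the source index of each edge. A tiny standalone illustration of that semantics (our own example data, not Spektral's internals):
import tensorflow as tf

x = tf.constant([[1.0], [2.0], [3.0]])  # features of 3 nodes
index_sources = tf.constant([0, 2, 2])  # sending node of each edge
print(tf.gather(x, index_sources))      # [[1.], [3.], [3.]]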
Then, we define an aggregation function for the messages. We can use a simple average over each node's neighbourhood:
from spektral.layers.ops import scatter_mean

def aggregate(self, messages):
    return scatter_mean(messages, self.index_targets, self.n_nodes)
Note: n_nodes is computed dynamically at the start of propagation, exactly like index_targets.
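The semantics of scatter_mean are similar to an unsorted segment mean: messages that share a target index are averaged together. A small standalone illustration (our own example data, not Spektral's implementation):
import tensorflow as tf

messages = tf.constant([[1.0], [3.0], [10.0]])  # one message per edge
index_targets = tf.constant([0, 0, 1])          # receiving node of each edge
# Node 0 receives (1 + 3) / 2 = 2; node 1 receives 10
print(tf.math.unsorted_segment_mean(messages, index_targets, 2))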
Since there are a few aggregation functions that are commonly used in the literature, you can also skip implementing this method and simply pass a special keyword to the __init__() method of the superclass:
def __init__(self):
    # Equivalent to the above implementation of aggregate
    super().__init__(aggregate='mean')
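Other common reductions can be selected the same way, for instance a sum (a sketch; check Spektral's API reference for the full list of supported keywords):
def __init__(self):
    # A sum aggregation instead of a mean
    super().__init__(aggregate='sum')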
Finally, we can use the update method to apply the activation function:
def update(self, embeddings):
    return self.activation(embeddings)
This is enough to get started with building your own layers in Spektral.
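Putting the pieces together, here is the full layer and a quick smoke test on a small graph (a sketch: the layer matches the snippets above, while the test data and shapes are our own choices):
import numpy as np
import tensorflow as tf
from spektral.layers import MessagePassing

class GCN(MessagePassing):
    def __init__(self, n_out, activation):
        super().__init__(aggregate='mean', activation=activation)
        self.n_out = n_out

    def build(self, input_shape):
        n_in = input_shape[0][-1]
        self.kernel = self.add_weight(shape=(n_in, self.n_out))

    def call(self, inputs):
        x, a = inputs
        x = tf.matmul(x, self.kernel)
        return self.propagate(x=x, a=a)

    def message(self, x):
        return self.get_sources(x)

    def update(self, embeddings):
        return self.activation(embeddings)

# Smoke test: a fully connected graph with 5 nodes and 4 features,
# so that no node has an empty neighbourhood
x = np.random.rand(5, 4).astype('f4')
a = tf.sparse.from_dense(np.ones((5, 5), dtype='f4'))
print(GCN(8, activation='relu')([x, a]).shape)  # (5, 8)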
Notes
An important feature of the MessagePassing class is that any extra keyword argument given to propagate will be compared to the signatures of message, aggregate and update, and forwarded to those functions if a match is found.
For example, we can call:
propagate(x=x, a=a, extra_tensor=extra_tensor)
and define the message function as:
def message(self, x, extra_tensor=None):
    ...  # Do something with extra_tensor
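A typical use of this mechanism is forwarding per-edge attributes to message. A hedged sketch, where e is a tensor of scalar edge weights that we supply ourselves:
def call(self, inputs):
    x, a, e = inputs  # e: scalar edge weights, shape (n_edges, 1)
    return self.propagate(x=x, a=a, e=e)

def message(self, x, e=None):
    # Weight each neighbour's features by its edge weight
    return self.get_sources(x) * e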
Finally, as already noted, MessagePassing layers only support single and disjoint mode, and they require the adjacency matrix to be a SparseTensor.
If you need more control over your layers, you can have a look at spektral.layers.Conv, a stripped-down class that performs no checks on the inputs and only implements essential features like keyword parsing.
For example, spektral.layers.GCNConv implements the same GCN layer that we just saw, using the Conv class so that it can also support batch and mixed mode, as well as dense adjacency matrices.
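For comparison, the built-in layer can be used directly, reusing the x and a from the smoke test above (a usage sketch; in practice GCNConv expects a normalized adjacency matrix, see its documentation):
from spektral.layers import GCNConv

# Supports single, disjoint, batch, and mixed mode, dense or sparse a
print(GCNConv(8, activation='relu')([x, a]).shape)  # (5, 8)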