Models
This module implements ready-to-use models from recent literature.
GCN
spektral.models.gcn.GCN(n_labels, channels=16, activation='relu', output_activation='softmax', use_bias=False, dropout_rate=0.5, l2_reg=0.00025)
This model, with its default hyperparameters, implements the architecture from the paper:
Semi-Supervised Classification with Graph Convolutional Networks
Thomas N. Kipf and Max Welling
Mode: single, disjoint, mixed, batch.
Input
- Node features of shape ([batch], n_nodes, n_node_features);
- Weighted adjacency matrix of shape ([batch], n_nodes, n_nodes).
Output
- Softmax predictions with shape ([batch], n_nodes, n_labels).
Arguments
- n_labels: number of channels in output;
- channels: number of channels in the first GCNConv layer;
- activation: activation of the first GCNConv layer;
- output_activation: activation of the second GCNConv layer;
- use_bias: whether to add a learnable bias to the two GCNConv layers;
- dropout_rate: rate used in Dropout layers;
- l2_reg: L2 regularization strength;
- **kwargs: passed to Model.__init__.
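To make the input and output shapes above concrete, here is a minimal NumPy sketch of the two-layer forward pass described in the Kipf and Welling paper, in single mode and at inference time (dropout disabled, no bias, matching use_bias=False). It is not Spektral's implementation: the helper names (normalize_adjacency, gcn_forward) and the random weights are assumptions made for illustration, and the adjacency normalization is folded into the function only to keep the example self-contained.

```python
import numpy as np


def normalize_adjacency(a):
    """Return D^-1/2 (A + I) D^-1/2, the renormalized adjacency from the GCN paper."""
    a_hat = a + np.eye(a.shape[0])          # add self-loops
    deg = a_hat.sum(axis=1)                 # degrees are >= 1 thanks to the self-loops
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def softmax(z):
    """Row-wise softmax, shifted for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def gcn_forward(x, a, w0, w1):
    """Hypothetical two-layer GCN forward pass (no dropout, no bias)."""
    a_norm = normalize_adjacency(a)
    h = np.maximum(a_norm @ x @ w0, 0.0)    # first layer with ReLU
    return softmax(a_norm @ h @ w1)         # second layer with softmax


rng = np.random.default_rng(0)
n_nodes, n_node_features, channels, n_labels = 5, 8, 16, 3

x = rng.normal(size=(n_nodes, n_node_features))   # node features
a = np.array(                                     # adjacency of a 5-node ring
    [[0, 1, 0, 0, 1],
     [1, 0, 1, 0, 0],
     [0, 1, 0, 1, 0],
     [0, 0, 1, 0, 1],
     [1, 0, 0, 1, 0]],
    dtype=float,
)
w0 = rng.normal(size=(n_node_features, channels))  # random stand-in weights
w1 = rng.normal(size=(channels, n_labels))

y = gcn_forward(x, a, w0, w1)
print(y.shape)  # (5, 3), i.e. (n_nodes, n_labels)
```

Each row of y is a probability distribution over the n_labels classes for one node, which is what the "Softmax predictions" output refers to.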
GeneralGNN
spektral.models.general_gnn.GeneralGNN(output, activation=None, hidden=256, message_passing=4, pre_process=2, post_process=2, connectivity='cat', batch_norm=True, dropout=0.0, aggregate='sum', hidden_activation='prelu', pool='sum')
This model implements the GNN architecture from the paper:
Design Space for Graph Neural Networks
Jiaxuan You, Rex Ying, Jure Leskovec
Mode: single, disjoint, mixed.
The default parameters of the model are selected according to the best results reported in the paper and should perform well on many node-level and graph-level tasks without modification. The defaults are as follows:
- 256 hidden channels
- 4 message passing layers
- 2 pre-processing layers
- 2 post-processing layers
- Skip connections with concatenation
- Batch normalization
- No dropout
- PReLU activations
- Sum aggregation in the message-passing layers
- Global sum pooling (not from the paper)
The GNN uses the GeneralConv layer for message passing, and has a pre- and a post-processing MLP for the node features.
Message-passing layers also have optional skip connections, which can be implemented as sum or concatenation.
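The two skip-connection variants can be sketched in a few lines of NumPy. This is an illustration of the general idea only: the function name skip_connection is hypothetical, and how the model wires the connection internally is not shown here.

```python
import numpy as np


def skip_connection(x_in, x_out, connectivity):
    """Combine a layer's input and output according to the connectivity type."""
    if connectivity is None:
        return x_out                                    # no skip connection
    if connectivity == "sum":
        return x_in + x_out                             # requires equal feature widths
    if connectivity == "cat":
        return np.concatenate([x_in, x_out], axis=-1)   # feature widths add up
    raise ValueError(f"Unknown connectivity: {connectivity!r}")


x_in = np.ones((4, 3))         # 4 nodes, 3 features entering the layer
x_out = 2.0 * np.ones((4, 3))  # 4 nodes, 3 features leaving the layer

print(skip_connection(x_in, x_out, "sum").shape)  # (4, 3)
print(skip_connection(x_in, x_out, "cat").shape)  # (4, 6)
```

With 'sum' the feature width stays the same, while with 'cat' each layer's output is widened by the width of its input.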
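The dense layers of the pre-processing and post-processing MLPs compute the following update of the node features:
$$
\mathbf{h}_i = \mathrm{Act} \left( \mathrm{Dropout} \left( \mathrm{BN} \left( \mathbf{x}_i \mathbf{W} + \mathbf{b} \right) \right) \right)
$$
Message-passing layers compute:
$$
\mathbf{h}_i = \mathrm{Agg} \left( \left\{ \mathrm{Act} \left( \mathrm{Dropout} \left( \mathrm{BN} \left( \mathbf{x}_j \mathbf{W} + \mathbf{b} \right) \right) \right), j \in \mathcal{N}(i) \right\} \right)
$$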
Arguments
- output: int, the number of output units;
- activation: the activation function of the output layer;
- hidden: int, the number of hidden units for all layers except the output one;
- message_passing: int, the number of message-passing layers;
- pre_process: int, the number of layers in the pre-processing MLP;
- post_process: int, the number of layers in the post-processing MLP;
- connectivity: the type of skip connection. Can be: None, 'sum' or 'cat';
- batch_norm: bool, whether to use batch normalization;
- dropout: float, dropout rate;
- aggregate: string or callable, an aggregation function. Supported aggregations: 'sum', 'mean', 'max', 'min', 'prod';
- hidden_activation: activation function in the hidden layers. The PReLU activation can be used by passing hidden_activation='prelu';
- pool: string or None, the global pooling function. If None, no global pooling is applied (e.g., for node-level learning). Supported pooling methods: 'sum', 'avg', 'max', 'attn', 'attn_sum', 'sort' (see spektral.layers.pooling.global_pool).
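To illustrate what the pool argument changes, the following NumPy sketch reduces node features to one vector per graph when several graphs are stacked together, as in disjoint mode. The function global_pool below and the graph_index vector are assumptions made for this example (they are not the pooling layers referenced above), and only the 'sum' and 'avg' methods are covered.

```python
import numpy as np


def global_pool(x, graph_index, n_graphs, method="sum"):
    """Reduce node features of shape (n_nodes, n_features) to (n_graphs, n_features)."""
    out = np.zeros((n_graphs, x.shape[1]))
    np.add.at(out, graph_index, x)          # scatter-add each node into its graph's row
    if method == "sum":
        return out
    if method == "avg":
        counts = np.bincount(graph_index, minlength=n_graphs)
        return out / np.maximum(counts, 1)[:, None]   # guard against empty graphs
    raise ValueError(f"Unsupported method in this sketch: {method!r}")


# Three nodes in total: the first two belong to graph 0, the last to graph 1.
x = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])
graph_index = np.array([0, 0, 1])

pooled_sum = global_pool(x, graph_index, n_graphs=2, method="sum")
pooled_avg = global_pool(x, graph_index, n_graphs=2, method="avg")
print(pooled_sum)  # rows: [4, 6] and [5, 6]
print(pooled_avg)  # rows: [2, 3] and [5, 6]
```

With pool=None this reduction step is skipped and the model keeps one output row per node, which is the setting suited to node-level tasks.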
GNNExplainer
spektral.models.gnn_explainer.GNNExplainer(model, n_hops=None, preprocess=None, graph_level=False, verbose=False, learning_rate=0.01, a_size_coef=0.0005, x_size_coef=0.1, a_entropy_coef=0.1, x_entropy_coef=0.1, laplacian_coef=0.0)
The GNNExplainer model from the paper:
GNNExplainer: Generating Explanations for Graph Neural Networks
Rex Ying, Dylan Bourgeois, Jiaxuan You, Marinka Zitnik and Jure Leskovec.
The model can be used to explain the predictions for a single node or for an entire graph. In both cases, it returns the subgraph that contributes most to the prediction.
Arguments
- model: tf.keras.Model to explain;
- n_hops: number of hops from which the GNN aggregates info. If None, then the number is inferred from the Conv and MessagePassing layers in the model;
- preprocess: a preprocessing function to transform the adjacency matrix before giving it as input to the GNN; this is usually the same preprocess function of the Conv or MessagePassing layers used in the GNN (e.g., GCNConv.preprocess);
- graph_level: if True, the GNN is assumed to be for graph-level prediction and the explanation is computed for the whole graph (and not just a node);
- verbose: if True, print info during training;
- learning_rate: learning rate when training the model;
- a_size_coef: coefficient to control the number of edges of the subgraph that contributes to the prediction;
- x_size_coef: coefficient to control the number of features of the subgraph that contributes to the prediction;
- a_entropy_coef: coefficient to control the discretization of the adjacency mask;
- x_entropy_coef: coefficient to control the discretization of the features mask;
- laplacian_coef: coefficient to control the graph Laplacian loss.
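The size and entropy coefficients can be read as weights on two regularizers of a soft mask with values in [0, 1], in the spirit of the GNNExplainer paper: a size term that penalizes how much of the graph the mask keeps, and an element-wise entropy term that pushes mask values towards 0 or 1. The NumPy sketch below shows one plausible form of these terms. The function name mask_regularizers and the exact formulas are assumptions for illustration and may differ from what this class computes.

```python
import numpy as np


def mask_regularizers(mask, size_coef, entropy_coef, eps=1e-8):
    """Return (size_loss, entropy_loss) for a soft mask with values in [0, 1]."""
    size_loss = size_coef * mask.sum()                    # fewer kept entries -> smaller loss
    entropy = (-mask * np.log(mask + eps)
               - (1.0 - mask) * np.log(1.0 - mask + eps))  # binary entropy per entry
    entropy_loss = entropy_coef * entropy.mean()           # near 0 when the mask is near-binary
    return size_loss, entropy_loss


# Two masks over 10 edges that keep the same total "mass" of 5.
soft_mask = np.full(10, 0.5)                   # maximally undecided
hard_mask = np.array([0.0] * 5 + [1.0] * 5)    # fully discrete

soft_size, soft_entropy = mask_regularizers(soft_mask, size_coef=0.0005, entropy_coef=0.1)
hard_size, hard_entropy = mask_regularizers(hard_mask, size_coef=0.0005, entropy_coef=0.1)

print(soft_size, hard_size)          # identical size penalties
print(soft_entropy > hard_entropy)   # True: the undecided mask is penalized more
```

Under this reading, raising a_size_coef or x_size_coef favors smaller explanations, while raising a_entropy_coef or x_entropy_coef favors masks that are closer to a hard selection of edges or features.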