Models

This module implements ready-to-use models from recent literature.

GCN

spektral.models.gcn.GCN(n_labels, channels=16, activation='relu', output_activation='softmax', use_bias=False, dropout_rate=0.5, l2_reg=0.00025, n_input_channels=None)

This model, with its default hyperparameters, implements the architecture from the paper:

Semi-Supervised Classification with Graph Convolutional Networks
Thomas N. Kipf and Max Welling

Mode: single, disjoint, mixed, batch.

Input

  • Node features of shape ([batch], n_nodes, n_node_features)
  • Weighted adjacency matrix of shape ([batch], n_nodes, n_nodes)

Output

  • Softmax predictions with shape ([batch], n_nodes, n_labels).

Arguments

  • n_labels: number of channels in output;
  • channels: number of channels in first GCNConv layer;
  • activation: activation of the first GCNConv layer;
  • output_activation: activation of the second GCNConv layer;
  • use_bias: whether to add a learnable bias to the two GCNConv layers;
  • dropout_rate: rate used in Dropout layers;
  • l2_reg: l2 regularization strength;
  • n_input_channels: number of input channels, required for TensorFlow 2.1;
  • **kwargs: passed to Model.__init__.
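To make the input and output shapes above concrete, here is a minimal NumPy sketch of a two-layer GCN forward pass in single mode, following the propagation rule from the Kipf and Welling paper. It is not Spektral's implementation: the helper names (normalize_adjacency, gcn_forward), the random weights, and the small path graph are assumptions made for illustration, bias is left out to match use_bias=False, and dropout and L2 regularization are omitted because they only matter during training. The adjacency matrix is normalized inside the function purely to keep the example self-contained.

```python
import numpy as np

def normalize_adjacency(a):
    """Return D^{-1/2} (A + I) D^{-1/2} for a non-negative adjacency matrix."""
    a_hat = a + np.eye(a.shape[0])           # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def softmax(z):
    """Row-wise softmax, shifted for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def gcn_forward(x, a, w1, w2):
    """Two graph convolutions: ReLU after the first, softmax after the second."""
    a_norm = normalize_adjacency(a)
    h = np.maximum(a_norm @ x @ w1, 0.0)     # first layer, `channels` outputs
    return softmax(a_norm @ h @ w2)          # second layer, `n_labels` outputs

# Illustrative sizes; channels=16 mirrors the documented default.
n_nodes, n_node_features, channels, n_labels = 5, 3, 16, 4
rng = np.random.default_rng(0)

x = rng.normal(size=(n_nodes, n_node_features))   # node features
a = np.diag(np.ones(n_nodes - 1), 1)              # path graph 0-1-2-3-4
a = a + a.T                                       # make it symmetric
w1 = rng.normal(size=(n_node_features, channels))
w2 = rng.normal(size=(channels, n_labels))

y = gcn_forward(x, a, w1, w2)
print(y.shape)  # (n_nodes, n_labels)
```

Each row of y is a probability distribution over the labels for one node, which is what the softmax output described above refers to.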

GeneralGNN

spektral.models.general_gnn.GeneralGNN(output, activation=None, hidden=256, message_passing=4, pre_process=2, post_process=2, connectivity='cat', batch_norm=True, dropout=0.0, aggregate='sum', hidden_activation='prelu', pool='sum')

This model implements the GNN architecture from the paper:

Design Space for Graph Neural Networks
Jiaxuan You, Rex Ying, Jure Leskovec

The default parameters of the model are selected according to the best results obtained in the paper, and should perform well on many node-level and graph-level tasks without modification. The defaults are as follows:

  • 256 hidden channels
  • 4 message passing layers
  • 2 pre-processing layers
  • 2 post-processing layers
  • Skip connections with concatenation
  • Batch normalization
  • No dropout
  • PReLU activations
  • Sum aggregation in the message-passing layers
  • Global sum pooling (not from the paper)

The GNN uses the GeneralConv layer for message passing, and has a pre- and a post-processing MLP for the node features. Message-passing layers also have optional skip connections, which can be implemented as sum or concatenation.
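The difference between the two skip-connection types is easiest to see in terms of shapes. The following NumPy sketch uses a hypothetical helper, apply_skip, that is not part of Spektral; it assumes that x holds the features entering a message-passing layer and h the features that layer produced.

```python
import numpy as np

def apply_skip(x, h, connectivity):
    """Combine a layer's input x with its output h (illustrative helper)."""
    if connectivity is None:
        return h                                  # no skip connection
    if connectivity == "sum":
        return x + h                              # needs matching shapes
    if connectivity == "cat":
        return np.concatenate([x, h], axis=-1)    # widens the feature axis
    raise ValueError(f"Unknown connectivity: {connectivity!r}")

x = np.ones((4, 8))        # 4 nodes, 8 features going into the layer
h = 2.0 * np.ones((4, 8))  # 4 nodes, 8 features coming out of the layer

print(apply_skip(x, h, None).shape)   # (4, 8)
print(apply_skip(x, h, "sum").shape)  # (4, 8)
print(apply_skip(x, h, "cat").shape)  # (4, 16)
```

Summation keeps the feature dimension fixed, while concatenation grows it, so whatever consumes the concatenated features has to accept the wider input.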

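The dense layers of the pre-processing and post-processing MLPs compute the following update of the node features:

$$ \mathbf{h}_i = \mathrm{Act}\left( \mathrm{Dropout}\left( \mathrm{BN}\left( \mathbf{x}_i \mathbf{W} + \mathbf{b} \right) \right) \right) $$

Message-passing layers compute:

$$ \mathbf{h}_i = \mathrm{Agg}\left( \left\{ \mathrm{Act}\left( \mathrm{Dropout}\left( \mathrm{BN}\left( \mathbf{x}_j \mathbf{W} + \mathbf{b} \right) \right) \right),\ j \in \mathcal{N}(i) \right\} \right) $$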
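As a rough illustration of these two updates, the NumPy sketch below assumes that each dense layer applies a linear map followed by batch normalization, dropout, and the activation, in that order, and that a message-passing layer applies the same transformation to every node before aggregating over its neighbours. All function names are hypothetical, batch normalization is reduced to per-feature standardization without learned scale and shift, dropout is omitted because the default rate is 0.0, and the PReLU slope is fixed at an arbitrary 0.25 rather than learned.

```python
import numpy as np

def prelu(z, alpha=0.25):
    """PReLU with a fixed slope; in a real layer alpha would be learned."""
    return np.where(z > 0, z, alpha * z)

def batch_norm(z, eps=1e-3):
    """Standardize each feature across nodes (no learned scale or shift)."""
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

def dense_update(x, w, b):
    """Linear map, then batch norm, then activation (dropout rate 0.0)."""
    return prelu(batch_norm(x @ w + b))

def message_passing(x, a, w, b):
    """Transform every node, then sum the messages of each node's neighbours."""
    messages = dense_update(x, w, b)
    return a @ messages        # row i adds up messages[j] where a[i, j] == 1

# A ring of 4 nodes: node i is connected to i - 1 and i + 1 (mod 4).
n_nodes, n_features, hidden = 4, 3, 8
a = np.zeros((n_nodes, n_nodes))
for i in range(n_nodes):
    a[i, (i + 1) % n_nodes] = 1.0
    a[i, (i - 1) % n_nodes] = 1.0

rng = np.random.default_rng(0)
x = rng.normal(size=(n_nodes, n_features))
w = rng.normal(size=(n_features, hidden))
b = np.zeros(hidden)

msgs = dense_update(x, w, b)       # per-node update, shape (4, 8)
out = message_passing(x, a, w, b)  # aggregated update, shape (4, 8)
```

With sum aggregation and a binary adjacency matrix, the aggregation reduces to a single matrix product; a 'mean' or 'max' aggregation would need a different reduction over each neighbourhood.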

Arguments

  • output: int, the number of output units;
  • activation: the activation function of the output layer;
  • hidden: int, the number of hidden units for all layers except the output one;
  • message_passing: int, the number of message-passing layers;
  • pre_process: int, the number of layers in the pre-processing MLP;
  • post_process: int, the number of layers in the post-processing MLP;
  • connectivity: the type of skip connection. Can be: None, 'sum' or 'cat';
  • batch_norm: bool, whether to use batch normalization;
  • dropout: float, dropout rate;
  • aggregate: string or callable, an aggregation function. Supported aggregations: 'sum', 'mean', 'max', 'min', 'prod'.
  • hidden_activation: activation function in the hidden layers. The PReLU activation can be used by passing hidden_activation='prelu'.
  • pool: string or None, the global pooling function. If None, no global pooling is applied (e.g., for node-level learning). Supported pooling methods: 'sum', 'avg', 'max', 'attn', 'attn_sum', 'sort' (see spektral.layers.pooling.global_pool).
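The effect of the pool argument on the output shape can be sketched in NumPy for the three parameter-free reductions. The helper pool_nodes below is hypothetical and operates on a single graph; the 'attn', 'attn_sum', and 'sort' options are left out because they involve learned weights or node sorting that a short example cannot reproduce faithfully.

```python
import numpy as np

def pool_nodes(h, pool):
    """Reduce node features of shape (n_nodes, n_features) to one graph vector."""
    if pool is None:
        return h                  # node-level output, nothing is reduced
    if pool == "sum":
        return h.sum(axis=0)
    if pool == "avg":
        return h.mean(axis=0)
    if pool == "max":
        return h.max(axis=0)
    raise ValueError(f"Unsupported pooling in this sketch: {pool!r}")

h = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 0.0]])        # 3 nodes, 2 features

print(pool_nodes(h, "sum"))       # [9. 6.]
print(pool_nodes(h, "avg"))       # [3. 2.]
print(pool_nodes(h, "max"))       # [5. 4.]
print(pool_nodes(h, None).shape)  # (3, 2)
```

For graph-level tasks the pooled vector has one entry per feature regardless of the number of nodes, whereas pool=None keeps one row per node, which is what node-level tasks need.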