Loaders

Loader

spektral.data.loaders.Loader(dataset, batch_size=1, epochs=None, shuffle=True)

Parent class for data loaders. The role of a Loader is to iterate over a Dataset and yield batches of graphs to feed your Keras Models.

This is achieved by having a generator object that produces lists of Graphs, which are then collated together and returned as Tensors.

The core of a Loader is the collate(batch) method. This takes as input a list of Graph objects and returns a list of Tensors, np.arrays, or SparseTensors.

For instance, if all graphs have the same number of nodes and size of the attributes, a simple collation function can be:

import numpy as np

def collate(self, batch):
    # Stack node features and (dense) adjacency matrices along a new batch axis
    x = np.array([g.x for g in batch])
    a = np.array([g.a for g in batch])
    return x, a
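To make the generator/collate split concrete, here is a minimal, framework-free sketch of the mechanism described above. ToyGraph and MiniLoader are hypothetical stand-ins written for illustration (they are not Spektral classes), and the sketch assumes that every graph exposes x and a as dense NumPy arrays of identical shapes.

```python
import numpy as np
from collections import namedtuple

# Hypothetical stand-in for spektral.data.Graph: only x and a are used here.
ToyGraph = namedtuple("ToyGraph", ["x", "a"])


class MiniLoader:
    """Illustrative loader (not Spektral's): shuffle, split into batches, collate."""

    def __init__(self, dataset, batch_size=1, epochs=None, shuffle=True, seed=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)

    def collate(self, batch):
        # Same collation as above: stack equally-sized graphs along a new axis.
        x = np.array([g.x for g in batch])
        a = np.array([g.a for g in batch])
        return x, a

    def generator(self):
        # Yields lists of graphs; epochs=None means it never stops.
        epoch = 0
        while self.epochs is None or epoch < self.epochs:
            order = np.arange(len(self.dataset))
            if self.shuffle:
                self.rng.shuffle(order)  # new order at the start of each epoch
            for start in range(0, len(order), self.batch_size):
                idx = order[start:start + self.batch_size]
                yield [self.dataset[i] for i in idx]
            epoch += 1

    def __iter__(self):
        # Each list of graphs is collated into arrays before being returned.
        for batch in self.generator():
            yield self.collate(batch)


graphs = [ToyGraph(x=np.ones((4, 3)), a=np.eye(4)) for _ in range(5)]
loader = MiniLoader(graphs, batch_size=2, epochs=1)
shapes = [(x.shape, a.shape) for x, a in loader]
print(shapes)
```

With 5 graphs and batch_size=2, one epoch yields two full batches and a final batch containing a single graph.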

The load() method of a Loader returns an object that can be passed to a Keras model when using the fit, predict and evaluate functions. You can use it as follows:

model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch)

The steps_per_epoch property is the number of batches in an epoch. It must be passed as a keyword argument when calling fit, predict or evaluate with a Loader, because the loader may iterate indefinitely and Keras cannot otherwise tell where an epoch ends.
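As a rough illustration of the arithmetic, assuming that the last batch of an epoch is still yielded even when it holds fewer than batch_size graphs, the number of steps is the ceiling of the dataset size divided by the batch size. The helper below is hypothetical and only mirrors that calculation; it is not Spektral's implementation.

```python
import math


def steps_per_epoch(n_graphs, batch_size):
    # Assumption: a final, smaller batch is yielded, so round up.
    return math.ceil(n_graphs / batch_size)


print(steps_per_epoch(1000, 32))  # 31 full batches plus one batch of 8 graphs
```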

If you are using a custom training function, you can specify the input signature of your batches with the tf.TypeSpec system to avoid unnecessary re-tracings. The signature is computed automatically by calling loader.tf_signature().

For example, a simple training step can be written as:

@tf.function(input_signature=loader.tf_signature())  # Specify signature here
def train_step(inputs, target):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(target, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

We can then train our model in a loop as follows:

for batch in loader:
    train_step(*batch)

Arguments

  • dataset: a spektral.data.Dataset object;
  • batch_size: int, size of the mini-batches;
  • epochs: int, number of epochs to iterate over the dataset. By default (None) iterates indefinitely;
  • shuffle: bool, whether to shuffle the dataset at the start of each epoch.

SingleLoader

spektral.data.loaders.SingleLoader(dataset, epochs=None, sample_weights=None)

A Loader for single mode.

This loader produces Tensors representing a single graph. As such, it can only be used with Datasets of length 1 and the batch_size cannot be set.

The loader supports sample weights through the sample_weights argument. If given, then each batch will be a tuple (inputs, labels, sample_weights).

Arguments

  • dataset: a spektral.data.Dataset object with only one graph;
  • epochs: int, number of epochs to iterate over the dataset. By default (None) iterates indefinitely;
  • sample_weights: NumPy array, will be appended to the output automatically.

Output

Returns a tuple (inputs, labels) or (inputs, labels, sample_weights).

inputs is a tuple containing the data matrices of the graph, only if they are not None:

  • x: same as dataset[0].x;
  • a: same as dataset[0].a (scipy sparse matrices are converted to SparseTensors);
  • e: same as dataset[0].e;

labels is the same as dataset[0].y.

sample_weights is the same array passed when creating the loader.
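For instance, in node classification on a single graph the sample weights can act as a mask that selects the training nodes. The sketch below reproduces only the output structure described above with plain NumPy; single_mode_output is a hypothetical helper, and the conversion of scipy sparse matrices to SparseTensors is omitted.

```python
import numpy as np


def single_mode_output(x, a, e, y, sample_weights=None):
    # Keep only the data matrices that are not None, in the order (x, a, e).
    inputs = tuple(m for m in (x, a, e) if m is not None)
    if sample_weights is None:
        return inputs, y
    return inputs, y, sample_weights


n_nodes = 5
x = np.random.rand(n_nodes, 8)
a = np.eye(n_nodes)
y = np.eye(3)[[0, 1, 2, 0, 1]]  # one-hot node labels, shape (5, 3)
train_mask = np.array([1, 1, 0, 0, 0], dtype=float)  # weight 1 only on training nodes

inputs, labels, weights = single_mode_output(x, a, None, y, train_mask)
print(len(inputs), labels.shape, weights)  # e is None, so inputs holds only (x, a)
```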


DisjointLoader

spektral.data.loaders.DisjointLoader(dataset, node_level=False, batch_size=1, epochs=None, shuffle=True)

A Loader for disjoint mode.

This loader represents a batch of graphs via their disjoint union.

The loader automatically computes a batch index tensor, containing integer indices that map each node to its corresponding graph in the batch.

The adjacency matrix is returned as a SparseTensor, regardless of the input.

If node_level=False, the labels are interpreted as graph-level labels and are stacked along an additional dimension. If node_level=True, then the labels are stacked vertically.

Note: TensorFlow 2.4 or above is required to use this Loader's load() method in a Keras training loop.

Arguments

  • dataset: a graph Dataset;
  • node_level: bool, if True stack the labels vertically for node-level prediction;
  • batch_size: int, size of the mini-batches;
  • epochs: int, number of epochs to iterate over the dataset. By default (None) iterates indefinitely;
  • shuffle: bool, whether to shuffle the data at the start of each epoch.

Output

For each batch, returns a tuple (inputs, labels).

inputs is a tuple containing (here n_nodes and n_edges count all nodes and edges in the batch):

  • x: node attributes of shape [n_nodes, n_node_features];
  • a: block-diagonal adjacency matrix of shape [n_nodes, n_nodes];
  • e: edge attributes of shape [n_edges, n_edge_features];
  • i: batch index of shape [n_nodes].

labels have shape [batch, n_labels] if node_level=False or [n_nodes, n_labels] otherwise.
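The disjoint union can be sketched in NumPy as follows: node features (and node-level labels) are concatenated, the adjacency matrices are placed on the diagonal of one larger matrix, and the batch index records which graph each node came from. disjoint_collate is a hypothetical helper; for readability it builds a dense adjacency matrix, whereas the loader described above returns a SparseTensor, and edge attributes are left out.

```python
import numpy as np


def disjoint_collate(xs, adjs, ys, node_level=False):
    n_per_graph = [x.shape[0] for x in xs]
    n_total = sum(n_per_graph)

    # Node features of all graphs stacked vertically: [n_nodes, n_node_features].
    x = np.concatenate(xs, axis=0)

    # Block-diagonal adjacency: graph k occupies rows/cols offset:offset + n_k.
    a = np.zeros((n_total, n_total))
    offset = 0
    for adj, n in zip(adjs, n_per_graph):
        a[offset:offset + n, offset:offset + n] = adj
        offset += n

    # Batch index: node j belongs to graph i[j].
    i = np.repeat(np.arange(len(xs)), n_per_graph)

    # Graph-level labels get a batch axis; node-level labels are stacked vertically.
    y = np.concatenate(ys, axis=0) if node_level else np.stack(ys)
    return (x, a, i), y


xs = [np.ones((2, 4)), np.ones((3, 4))]
adjs = [np.ones((2, 2)), np.ones((3, 3))]
ys = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
(x, a, i), y = disjoint_collate(xs, adjs, ys)
print(x.shape, a.shape, i, y.shape)
```

Because the off-diagonal blocks are zero, message passing over a never mixes nodes from different graphs.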


BatchLoader

spektral.data.loaders.BatchLoader(dataset, mask=False, batch_size=1, epochs=None, shuffle=True, node_level=False)

A Loader for batch mode.

This loader returns batches of graphs stacked along an extra dimension, with all "node" dimensions padded to be equal among all graphs.

If n_max is the number of nodes of the biggest graph in the batch, then the padding consists of adding zeros to the node features, adjacency matrix, and edge attributes of each graph so that they have shapes [n_max, n_node_features], [n_max, n_max], and [n_max, n_max, n_edge_features] respectively.

The zero-padding is done batch-wise, which saves memory at the cost of more computation. If latency is an issue but memory isn't, or if the dataset has graphs with a similar number of nodes, you can use the PackedBatchLoader, which zero-pads the whole dataset once and then iterates over it.

Note that the adjacency matrix and edge attributes are returned as dense arrays.

If mask=True, node attributes will be extended with a binary mask that indicates valid nodes (the last feature of each node will be 1 if the node was originally in the graph and 0 if it is a fake node added by zero-padding).

Use this flag in conjunction with layers.base.GraphMasking to start the propagation of masks in a model (necessary for node-level prediction and models that use a dense pooling layer like DiffPool or MinCutPool).

If node_level=False, the labels are interpreted as graph-level labels and are returned as an array of shape [batch, n_labels]. If node_level=True, then the labels are padded along the node dimension and are returned as an array of shape [batch, n_max, n_labels].
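The batch-wise padding and the optional mask feature can be sketched in NumPy as below. batch_collate is a hypothetical helper written for illustration; it handles only node features and adjacency matrices (edge attributes and labels are omitted) and computes n_max over the current batch, as described above.

```python
import numpy as np


def batch_collate(xs, adjs, mask=False):
    # n_max is computed over this batch only.
    n_max = max(x.shape[0] for x in xs)
    batch, n_feat = len(xs), xs[0].shape[1]

    # One extra feature column is reserved for the mask when mask=True.
    x_out = np.zeros((batch, n_max, n_feat + int(mask)))
    a_out = np.zeros((batch, n_max, n_max))
    for k, (x, adj) in enumerate(zip(xs, adjs)):
        n = x.shape[0]
        x_out[k, :n, :n_feat] = x
        a_out[k, :n, :n] = adj
        if mask:
            # Last feature is 1 for real nodes and stays 0 for padding nodes.
            x_out[k, :n, -1] = 1.0
    return x_out, a_out


xs = [np.ones((2, 3)), np.ones((4, 3))]
adjs = [np.eye(2), np.eye(4)]
x, a = batch_collate(xs, adjs, mask=True)
print(x.shape, a.shape)   # (2, 4, 4) (2, 4, 4)
print(x[0, :, -1])        # [1. 1. 0. 0.]
```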

Arguments

  • dataset: a graph Dataset;
  • mask: bool, whether to add a mask to the node features;
  • batch_size: int, size of the mini-batches;
  • epochs: int, number of epochs to iterate over the dataset. By default (None) iterates indefinitely;
  • shuffle: bool, whether to shuffle the data at the start of each epoch;
  • node_level: bool, if True pad the labels along the node dimension.

Output

For each batch, returns a tuple (inputs, labels).

inputs is a tuple containing:

  • x: node attributes of shape [batch, n_max, n_node_features];
  • a: adjacency matrices of shape [batch, n_max, n_max];
  • e: edge attributes of shape [batch, n_max, n_max, n_edge_features].

labels have shape [batch, n_labels] if node_level=False or [batch, n_max, n_labels] otherwise.


PackedBatchLoader

spektral.data.loaders.PackedBatchLoader(dataset, mask=False, batch_size=1, epochs=None, shuffle=True, node_level=False)

A BatchLoader that zero-pads the graphs before iterating over the dataset. This means that n_max is computed over the whole dataset and not just a single batch.

While using more memory than BatchLoader, this loader should reduce the computational overhead of padding each batch independently.

Use this loader if:

  • memory usage isn't an issue and you want to produce the batches as fast as possible;
  • the graphs in the dataset have similar sizes and there are no outliers in the dataset (i.e., anomalous graphs with many more nodes than the dataset average).
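The effect of outliers can be seen by counting padded entries. The hypothetical helper below sums the entries of the zero-padded [batch, n_max, n_max] adjacency tensors over one epoch, assuming batches are taken in order without shuffling. For the packed case, the total equals the size of the padded dataset that is kept in memory.

```python
def padded_adjacency_entries(sizes, batch_size, packed):
    """Total entries in the [batch, n_max, n_max] adjacency tensors of one epoch."""
    total = 0
    global_n_max = max(sizes)
    for start in range(0, len(sizes), batch_size):
        batch = sizes[start:start + batch_size]
        # Packed: pad to the largest graph in the dataset; otherwise per batch.
        n_max = global_n_max if packed else max(batch)
        total += len(batch) * n_max * n_max
    return total


sizes = [10] * 7 + [100]  # seven small graphs and one outlier
print(padded_adjacency_entries(sizes, batch_size=4, packed=False))  # 40400
print(padded_adjacency_entries(sizes, batch_size=4, packed=True))   # 80000
```

The single 100-node outlier forces every graph to be padded to 100 nodes when packing, while batch-wise padding only inflates the batch that contains it.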

Arguments

  • dataset: a graph Dataset;
  • mask: bool, whether to add a mask to the node features;
  • batch_size: int, size of the mini-batches;
  • epochs: int, number of epochs to iterate over the dataset. By default (None) iterates indefinitely;
  • shuffle: bool, whether to shuffle the data at the start of each epoch;
  • node_level: bool, if True pad the labels along the node dimension.

Output

For each batch, returns a tuple (inputs, labels).

inputs is a tuple containing:

  • x: node attributes of shape [batch, n_max, n_node_features];
  • a: adjacency matrices of shape [batch, n_max, n_max];
  • e: edge attributes of shape [batch, n_max, n_max, n_edge_features].

labels have shape [batch, n_labels] if node_level=False or [batch, n_max, n_labels] otherwise.


MixedLoader

spektral.data.loaders.MixedLoader(dataset, batch_size=1, epochs=None, shuffle=True)

A Loader for mixed mode.

This loader returns batches where the node and edge attributes are stacked along an extra dimension, but the adjacency matrix is shared by all graphs.

The loader expects all graphs to have the same number of nodes and edges, since they share a single adjacency matrix. The dataset is pre-packed, as in a PackedBatchLoader.
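The shapes in mixed mode fit together through broadcasting. In the NumPy sketch below (an illustration, not Spektral code), a batch of node features shares one adjacency matrix, and np.matmul broadcasts the [n_nodes, n_nodes] matrix over the batch axis, the kind of propagation step a mixed-mode layer performs (here simplified, with no trainable weights).

```python
import numpy as np

batch, n_nodes, n_feat = 3, 4, 2
# Node features differ per graph and are stacked along an extra dimension...
x = np.random.rand(batch, n_nodes, n_feat)
# ...while a single adjacency matrix is shared by all graphs.
a = np.eye(n_nodes) + np.eye(n_nodes, k=1)

# [n_nodes, n_nodes] @ [batch, n_nodes, n_feat] broadcasts over the batch axis.
propagated = a @ x
print(propagated.shape)  # (3, 4, 2)
```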

Arguments

  • dataset: a graph Dataset;
  • batch_size: int, size of the mini-batches;
  • epochs: int, number of epochs to iterate over the dataset. By default (None) iterates indefinitely;
  • shuffle: bool, whether to shuffle the data at the start of each epoch.

Output

For each batch, returns a tuple (inputs, labels).

inputs is a tuple containing:

  • x: node attributes of shape [batch, n_nodes, n_node_features];
  • a: adjacency matrix of shape [n_nodes, n_nodes];
  • e: edge attributes of shape [batch, n_edges, n_edge_features].

labels have shape [batch, ..., n_labels].