Customize Models & Tasks

TorchDrug provides many popular model architectures for graph representation learning. However, you may still need a more customized architecture.

Here we illustrate the steps for writing a customized model, using the variational graph auto-encoder (VGAE) as an example. VGAE learns latent node representations with a graph convolutional network (GCN) encoder and an inner product decoder. The encoder and decoder are trained jointly with a reconstruction loss and evaluated on the link prediction task.
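To make the three ingredients concrete before turning to TorchDrug, here is a minimal pure-Python sketch on a toy graph: one GCN-style propagation step (self-loops plus symmetric degree normalization, the common GCN rule), an inner product decoder sigmoid(z_u · z_v), and a binary cross-entropy reconstruction loss over observed edges and sampled non-edges. All function names are hypothetical, and the uniform negative sampling is an illustrative assumption, not necessarily what TorchDrug's link prediction task does.

```python
import math
import random


def gcn_layer(edges, num_node, x, weight):
    """One GCN-style step: add self-loops, scale each message by
    1 / sqrt(deg_u * deg_v), sum over neighbors, then multiply by `weight`.
    No activation is applied in this sketch."""
    neighbors = [{v} for v in range(num_node)]  # self-loops
    for u, v in edges:  # treat the edge list as undirected
        neighbors[u].add(v)
        neighbors[v].add(u)
    degree = [len(nbrs) for nbrs in neighbors]
    in_dim, out_dim = len(weight), len(weight[0])
    out = []
    for v in range(num_node):
        agg = [0.0] * in_dim
        for u in neighbors[v]:
            norm = 1.0 / math.sqrt(degree[u] * degree[v])
            for k in range(in_dim):
                agg[k] += norm * x[u][k]
        # Linear projection: agg (1 x in_dim) times weight (in_dim x out_dim).
        out.append([sum(agg[k] * weight[k][j] for k in range(in_dim))
                    for j in range(out_dim)])
    return out


def inner_product_decoder(z, u, v):
    """Predicted probability of edge (u, v): sigmoid of the dot product."""
    score = sum(a * b for a, b in zip(z[u], z[v]))
    return 1.0 / (1.0 + math.exp(-score))


def reconstruction_loss(z, pos_edges, neg_edges):
    """Binary cross-entropy: observed edges -> 1, sampled non-edges -> 0."""
    eps = 1e-12
    loss = 0.0
    for u, v in pos_edges:
        loss -= math.log(inner_product_decoder(z, u, v) + eps)
    for u, v in neg_edges:
        loss -= math.log(1.0 - inner_product_decoder(z, u, v) + eps)
    return loss / (len(pos_edges) + len(neg_edges))


# Toy path graph 0-1-2-3 with 2-dimensional input features.
edges = [(0, 1), (1, 2), (2, 3)]
num_node = 4
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
weight = [[0.3, -0.2], [0.1, 0.4]]
z = gcn_layer(edges, num_node, x, weight)

# Hypothetical negative sampling: one random non-edge per positive edge.
rng = random.Random(0)
edge_set = set(edges) | {(v, u) for u, v in edges}
neg = []
while len(neg) < len(edges):
    u, v = rng.randrange(num_node), rng.randrange(num_node)
    if u != v and (u, v) not in edge_set:
        neg.append((u, v))

print(reconstruction_loss(z, edges, neg))
```

The decoder has no parameters of its own; all learning happens in the encoder, which is why the representation model can be reused by other tasks.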

As a convention, we separate representation models and task-specific designs for better reusability.

Node Representation Model

In VGAE, the node representation model is a variational graph convolutional network (VGCN). This can be implemented via standard graph convolution layers, plus a variational regularization loss. We define our model as a subclass of nn.Module and core.Configurable.

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data as torch_data

from torchdrug import core, layers, datasets, metrics
from torchdrug.core import Registry as R

@R.register("models.VGCN")
class VariationalGraphConvolutionalNetwork(nn.Module, core.Configurable):

    def __init__(self, input_dim, hidden_dims, beta=0, batch_norm=False,
                 activation="relu"):
        super(VariationalGraphConvolutionalNetwork, self).__init__()
        self.input_dim = input_dim
        self.output_dim = hidden_dims[-1]
        self.dims = [input_dim] + list(hidden_dims)
        self.beta = beta

        self.layers = nn.ModuleList()
        # Hidden layers: dims[i] -> dims[i + 1], with optional batch norm and activation.
        for i in range(len(self.dims) - 2):
            self.layers.append(
                layers.GraphConv(self.dims[i], self.dims[i + 1], None,
                                 batch_norm, activation))
        # Output layer: 2 * output_dim channels, later split into mu and log_sigma.
        self.layers.append(
            layers.GraphConv(self.dims[-2], self.dims[-1] * 2, None, False, None))

The definition is similar to that of most other torch models, except for two points. First, the decorator line @R.register("models.VGCN") registers the model in the library under the name models.VGCN. This enables the model to be dumped into string format and reconstructed later. Second, self.input_dim and self.output_dim are set to inform other models that connect to it.
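The registry idea can be illustrated in a few lines of plain Python. This is a hypothetical, simplified sketch, not TorchDrug's actual Registry or Configurable code: a decorator records each class under a string name, each object keeps its constructor arguments, and a JSON-serializable config dict is then enough to rebuild the object. The names REGISTRY, register, ConfigurableSketch, to_config, from_config and ToyVGCN are all made up for illustration.

```python
import json

REGISTRY = {}  # hypothetical global table: registered name -> class


def register(name):
    """Hypothetical decorator: record `cls` under `name` in REGISTRY."""
    def wrapper(cls):
        REGISTRY[name] = cls
        cls.registry_name = name
        return cls
    return wrapper


class ConfigurableSketch:
    """Hypothetical mixin: rebuild an object from its class name + kwargs."""

    def to_config(self):
        return {"class": self.registry_name, **self.init_kwargs}

    @staticmethod
    def from_config(config):
        config = dict(config)  # don't mutate the caller's dict
        cls = REGISTRY[config.pop("class")]
        return cls(**config)


@register("models.ToyVGCN")
class ToyVGCN(ConfigurableSketch):
    def __init__(self, input_dim, hidden_dims, beta=0.0):
        # Store the constructor arguments so to_config() can dump them.
        self.init_kwargs = {"input_dim": input_dim,
                            "hidden_dims": list(hidden_dims), "beta": beta}
        self.input_dim = input_dim
        self.output_dim = hidden_dims[-1]
        self.beta = beta


model = ToyVGCN(32, [128, 16], beta=1e-3)
text = json.dumps(model.to_config())                        # dump to a string
rebuilt = ConfigurableSketch.from_config(json.loads(text))  # reconstruct
print(text)
```

Because only the registered name and the constructor arguments are stored, the string stays valid even if the object's learned weights are saved separately.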

Then we implement the forward function. It takes 4 arguments: the graph(s), the node input feature(s), the global loss, and the global metric. The advantage of these global variables is that they allow losses to be implemented in a distributed, module-centric manner: each module can add its own loss terms and log its own metrics.
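Why does `all_loss += ...` inside a module reach the caller? When `all_loss` is a tensor, `+=` updates it in place, so every module that receives the same object contributes to one shared total. The pure-Python sketch below mimics this with a hypothetical mutable `LossAccumulator`; the module functions and their loss values are made up. The last lines show why a plain float would not work: `+=` on a float only rebinds a local name.

```python
class LossAccumulator:
    """Hypothetical stand-in for a scalar loss tensor that supports in-place +=."""

    def __init__(self):
        self.value = 0.0

    def __iadd__(self, other):
        self.value += other  # mutate the shared object...
        return self          # ...so every holder of the reference sees it


def encoder_module(all_loss, metric):
    reg = 0.5  # made-up regularization value
    all_loss += 0.1 * reg  # weighted contribution, like `loss * self.beta`
    metric["regularization loss"] = reg


def decoder_module(all_loss, metric):
    rec = 2.0  # made-up reconstruction value
    all_loss += rec
    metric["reconstruction loss"] = rec


all_loss, metric = LossAccumulator(), {}
for module in (encoder_module, decoder_module):
    module(all_loss, metric)
print(all_loss.value, metric)


def add_to_float(x):
    x += 1.0  # rebinds the local name only; the caller's float is unchanged


total = 0.0
add_to_float(total)  # total is still 0.0
```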

We compute the variational regularization loss, add it (scaled by beta) to the global loss, and record it in the global metric.
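For reference, the regularization term in forward is, per latent dimension, the closed-form KL divergence between the posterior q = N(μ, σ²) and a standard normal prior p = N(0, 1), where σ = exp(log_sigma). The code then sums it over latent dimensions, averages it over nodes, and scales it by beta. Sampling z = μ + εσ with ε ~ N(0, I) (the reparameterization trick) keeps z differentiable with respect to μ and σ.

```latex
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big)
  = \mathbb{E}_{z \sim q}\big[\log q(z) - \log p(z)\big]
  = -\log\sigma - \tfrac{1}{2} + \tfrac{1}{2}\left(\mu^2 + \sigma^2\right)
  = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - 2\log\sigma - 1\right),
\qquad \sigma = \exp(\log\sigma)

z = \mu + \epsilon \odot \sigma, \qquad \epsilon \sim \mathcal{N}(0, I)
```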

    def reparameterize(self, mu, log_sigma):
        if self.training:
            # Reparameterization trick: z = mu + eps * sigma with eps ~ N(0, I).
            z = mu + torch.randn_like(mu) * log_sigma.exp()
        else:
            # Use the posterior mean at evaluation time.
            z = mu
        return z

    def forward(self, graph, input, all_loss=None, metric=None):
        x = input
        for layer in self.layers:
            x = layer(graph, x)
        # The last layer outputs 2 * output_dim channels: mean and log std.
        mu, log_sigma = x.chunk(2, dim=-1)
        node_feature = self.reparameterize(mu, log_sigma)

        if all_loss is not None and self.beta > 0:
            # KL divergence to N(0, I), summed over dims and averaged over nodes.
            loss = 0.5 * (mu ** 2 + log_sigma.exp() ** 2 - 2 * log_sigma - 1)
            loss = loss.sum(dim=-1).mean()
            all_loss += loss * self.beta
            metric["variational regularization loss"] = loss

        return {
            "node_feature": node_feature
        }

Here we explicitly return a dict to indicate the type of our representations. The dict may also contain other representations, such as edge representations or graph representations.
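For instance, a model that also produces a graph-level representation could pool its node features and return both entries. The sketch below uses mean pooling in plain Python; the key name `graph_feature` is an assumption (it is the name TorchDrug's built-in encoders use for graph-level output), and the helpers `readout_mean` and `forward_sketch` are hypothetical.

```python
def readout_mean(node_feature, node2graph, num_graph):
    """Average node features per graph; node2graph[i] is node i's graph id."""
    dim = len(node_feature[0])
    sums = [[0.0] * dim for _ in range(num_graph)]
    counts = [0] * num_graph
    for feat, g in zip(node_feature, node2graph):
        counts[g] += 1
        for k in range(dim):
            sums[g][k] += feat[k]
    # max(c, 1) avoids dividing by zero for a graph with no nodes.
    return [[s / max(c, 1) for s in row] for row, c in zip(sums, counts)]


def forward_sketch(node_feature, node2graph, num_graph):
    # Hypothetical forward output carrying node- and graph-level representations.
    return {
        "node_feature": node_feature,
        "graph_feature": readout_mean(node_feature, node2graph, num_graph),
    }


# Three nodes: the first two belong to graph 0, the third to graph 1.
out = forward_sketch([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0, 0, 1], 2)
print(out["graph_feature"])
```

A downstream task can then pick whichever entry it needs by key, without knowing how the model computed it.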

Put Them Together

Let’s put all the ingredients together. Since the original Cora is a node classification dataset, we apply a wrapper to make it compatible with link prediction.
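The wrapper only changes what counts as one sample: a single edge instead of the whole graph. Here is a hypothetical toy version without TorchDrug, together with the 80/10/10 split that the code below performs. The shuffled-index split imitates the behaviour of torch.utils.data.random_split but is not its implementation; ToyEdgeDataset, split_lengths and random_split_sketch are made-up names.

```python
import random


class ToyEdgeDataset:
    """Hypothetical wrapper: each item is one (head, tail) edge of a graph."""

    def __init__(self, edge_list):
        self.edge_list = list(edge_list)

    def __getitem__(self, index):
        return self.edge_list[index]

    def __len__(self):
        return len(self.edge_list)


def split_lengths(n, fractions=(0.8, 0.1)):
    lengths = [int(f * n) for f in fractions]
    lengths.append(n - sum(lengths))  # the remainder goes to the last split
    return lengths


def random_split_sketch(dataset, lengths, seed=0):
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)
    splits, start = [], 0
    for length in lengths:
        splits.append([dataset[i] for i in indices[start:start + length]])
        start += length
    return splits


# A 25-node cycle: edges (0, 1), (1, 2), ..., (24, 0).
toy_dataset = ToyEdgeDataset([(i, (i + 1) % 25) for i in range(25)])
lengths = split_lengths(len(toy_dataset))
train, valid, test = random_split_sketch(toy_dataset, lengths)
print(lengths)
```

Putting the rounding remainder into the last split guarantees the three lengths always add up to the dataset size, which random_split requires.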

class CoraLinkPrediction(datasets.Cora):

    def __getitem__(self, index):
        return self.graph.edge_list[index]

    def __len__(self):
        return self.graph.num_edge

dataset = CoraLinkPrediction("~/node-datasets/")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch_data.random_split(dataset, lengths)

model = VariationalGraphConvolutionalNetwork(dataset.node_feature_dim, [128, 16],
                                             beta=1e-3, batch_norm=True)
task = LinkPrediction(model)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-2)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer, gpus=[0],

The result may look like

AUROC: 0.898589
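AUROC summarizes how well the predicted scores rank held-out edges above non-edges: it equals the probability that a randomly chosen positive pair outscores a randomly chosen negative pair, with ties counting one half. The pairwise sketch below follows that definition directly, using made-up scores. It is quadratic in the number of pairs and only meant to show what the number means, not how TorchDrug computes it.

```python
def auroc(pos_scores, neg_scores):
    """Probability that a positive outscores a negative; ties count 0.5."""
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


# Made-up decoder probabilities for held-out edges and sampled non-edges.
pos = [0.9, 0.8, 0.4]
neg = [0.7, 0.3, 0.2, 0.4]
print(auroc(pos, neg))
```

Read this way, an AUROC of about 0.9 means a true edge is scored above a non-edge roughly nine times out of ten.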