Customize Models & Tasks#
TorchDrug provides many popular model architectures for graph representation learning. However, you may still find yourself in need of more customized architectures.
Here we illustrate the steps for writing customized models, using the variational graph auto-encoder (VGAE) as an example. VGAE learns latent node representations with a graph convolutional network (GCN) encoder and an inner-product decoder. The two are jointly trained with a reconstruction loss and evaluated on the link prediction task.
As a convention, we separate representation models and task-specific designs for better reusability.
Node Representation Model#
In VGAE, the node representation model is a variational graph convolutional network (VGCN). This can be implemented via standard graph convolution layers, plus a variational regularization loss. We define our model as a subclass of nn.Module and core.Configurable.
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data as torch_data
from torchdrug import core, layers, datasets, tasks, metrics
from torchdrug.core import Registry as R
@R.register("models.VGCN")
class VariationalGraphConvolutionalNetwork(nn.Module, core.Configurable):
def __init__(self, input_dim, hidden_dims, beta=0, batch_norm=False,
activation="relu"):
super(VariationalGraphConvolutionalNetwork, self).__init__()
self.input_dim = input_dim
self.output_dim = hidden_dims[-1]
self.dims = [input_dim] + list(hidden_dims)
self.beta = beta
self.layers = nn.ModuleList()
for i in range(len(self.dims) - 2):
self.layers.append(
layers.GraphConv(self.dims[i], self.dims[i + 1], None,
batch_norm, activation)
)
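        # the final layer outputs 2 * hidden_dims[-1] channels, which will be split into
        # the mean and the log standard deviation of the latent node representation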
self.layers.append(
layers.GraphConv(self.dims[-2], self.dims[-1] * 2, None, False, None)
)
The definition is similar to most other torch models, except for two points. First, the decorator @R.register("models.VGCN") registers the model in the library under the name models.VGCN. This enables the model to be dumped into string format and reconstructed later. Second, self.input_dim and self.output_dim are set to inform other models that connect to it.
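For instance, a registered and configurable model can be serialized into a configuration dictionary and rebuilt from it later. Below is a minimal sketch of this round trip; it assumes the config_dict() and load_config_dict() interface provided by core.Configurable, and the constructor arguments are arbitrary.

model = VariationalGraphConvolutionalNetwork(input_dim=1433, hidden_dims=[128, 16])
config = model.config_dict()                         # plain dict with the registered name and constructor arguments
model2 = core.Configurable.load_config_dict(config)  # reconstruct an equivalent model from the dict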
Then we implement the forward function. The forward function takes 4 arguments: graph(s), node input feature(s), the global loss and the global metric. The advantage of these global variables is that they enable losses to be implemented in a distributed, module-centric manner.
We compute the variational regularization loss and add it to the global loss and the global metric.
def reparameterize(self, mu, log_sigma):
        if self.training:
            # reparameterization trick: z = mu + sigma * eps, where eps is standard normal noise
            z = mu + torch.randn_like(mu) * log_sigma.exp()
else:
z = mu
return z
def forward(self, graph, input, all_loss=None, metric=None):
x = input
for layer in self.layers:
x = layer(graph, x)
mu, log_sigma = x.chunk(2, dim=-1)
node_feature = self.reparameterize(mu, log_sigma)
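        # KL divergence between the approximate posterior N(mu, sigma^2) and the standard
        # normal prior, added to the global loss with weight beta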
if all_loss is not None and self.beta > 0:
loss = 0.5 * (mu ** 2 + log_sigma.exp() ** 2 - 2 * log_sigma - 1)
loss = loss.sum(dim=-1).mean()
all_loss += loss * self.beta
metric["variational regularization loss"] = loss
return {
"node_feature": node_feature
}
Here we explicitly return a dict to indicate the type of our representations. The dict may also contain other representations, such as edge representations or graph representations.
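To check the output format, we can run the model on a toy graph. The snippet below is a minimal sketch; the 3-node graph, the feature dimension and the hidden dimensions are made up for illustration.

from torchdrug import data

# a toy graph with 3 nodes and 2 directed edges, plus random 8-dimensional node features
graph = data.Graph([[0, 1], [1, 2]], num_node=3)
feature = torch.randn(3, 8)

model = VariationalGraphConvolutionalNetwork(input_dim=8, hidden_dims=[16, 4])
output = model(graph, feature)
print(output["node_feature"].shape)  # torch.Size([3, 4]): one latent vector per node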
Link Prediction Task#
Here we show how to implement the link prediction task for VGAE.
Generally, a task in TorchDrug contains 4 functions: predict(), target(), forward() and evaluate(). Such interfaces empower us to seamlessly switch between different devices, such as CPUs, GPUs or even the distributed setting. Among these functions, predict() and target() compute the prediction and the ground truth for a batch respectively. forward() computes the training loss, while evaluate() computes the evaluation metrics. Optionally, one can also implement a preprocess() function, which performs arbitrary operations based on the dataset.
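Before the concrete VGAE implementation, here is a minimal, hypothetical skeleton of that interface. The class name MyTask and its registration key are made up for illustration; only the method signatures matter.

@R.register("tasks.MyTask")
class MyTask(tasks.Task, core.Configurable):

    def preprocess(self, train_set, valid_set, test_set):
        # optional: precompute dataset-level statistics or buffers
        pass

    def predict(self, batch, all_loss=None, metric=None):
        # return predictions for a batch
        raise NotImplementedError

    def target(self, batch):
        # return the ground truth for a batch
        raise NotImplementedError

    def forward(self, batch):
        # return the training loss and a dict of metrics
        raise NotImplementedError

    def evaluate(self, pred, target):
        # return a dict of evaluation metrics
        raise NotImplementedError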
In the case of VGAE, we first compute the undirected training graph in preprocess(). In predict(), we perform negative sampling and predict the logits for both positive and negative edges. In target(), we return the ground truth labels for the edges. evaluate() computes the area under the ROC curve for the predictions.
@R.register("tasks.LinkPrediction")
class LinkPrediction(tasks.Task, core.Configurable):
def __init__(self, model):
super(LinkPrediction, self).__init__()
self.model = model
def preprocess(self, train_set, valid_set, test_set):
dataset = train_set.dataset
        graph = dataset.graph
        train_graph = graph.edge_mask(train_set.indices)
# flip the edges to make the graph undirected
edge_list = train_graph.edge_list.repeat(2, 1)
edge_list[train_graph.num_edge:, :2] = edge_list[train_graph.num_edge:, :2] \
.flip(1)
        # edge attributes follow the [original edges, flipped edges] order of edge_list
        index = torch.arange(train_graph.num_edge, device=self.device).repeat(2)
data_dict, meta_dict = train_graph.data_mask(edge_index=index)
train_graph = type(train_graph)(
edge_list, edge_weight=train_graph.edge_weight[index],
num_node=train_graph.num_node, num_edge=train_graph.num_edge * 2,
meta_dict=meta_dict, **data_dict
)
self.register_buffer("train_graph", train_graph)
self.num_node = dataset.num_node
def forward(self, batch):
all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
metric = {}
pred = self.predict(batch, all_loss, metric)
target = self.target(batch)
metric.update(self.evaluate(pred, target))
loss = F.binary_cross_entropy_with_logits(pred, target)
metric["bce loss"] = loss
all_loss += loss
return all_loss, metric
def predict(self, batch, all_loss=None, metric=None):
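        # sample one negative edge per positive edge by drawing both endpoints uniformly at random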
neg_batch = torch.randint(self.num_node, batch.shape, device=self.device)
batch = torch.cat([batch, neg_batch])
node_in, node_out = batch.t()
output = self.model(self.train_graph, self.train_graph.node_feature.float(),
all_loss, metric)
node_feature = output["node_feature"]
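        # inner-product decoder: score each candidate edge by the dot product of its endpoint representations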
pred = torch.einsum("bd, bd -> b",
node_feature[node_in], node_feature[node_out])
return pred
def target(self, batch):
batch_size = len(batch)
target = torch.zeros(batch_size * 2, device=self.device)
target[:batch_size] = 1
return target
def evaluate(self, pred, target):
roc = metrics.area_under_roc(pred, target)
return {
"AUROC": roc
}
Put Them Together#
Let’s put all the ingredients together. Since the original Cora is a node classification dataset, we apply a wrapper to make it compatible with link prediction.
class CoraLinkPrediction(datasets.Cora):
def __getitem__(self, index):
return self.graph.edge_list[index]
def __len__(self):
return self.graph.num_edge
dataset = CoraLinkPrediction("~/node-datasets/")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch_data.random_split(dataset, lengths)
model = VariationalGraphConvolutionalNetwork(dataset.node_feature_dim, [128, 16],
beta=1e-3, batch_norm=True)
task = LinkPrediction(model)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-2)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer, gpus=[0],
batch_size=len(train_set))
solver.train(num_epoch=200)
solver.evaluate("valid")
The result may look like
AUROC: 0.898589
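To report final numbers, we can additionally evaluate on the test split and save the trained weights for later reuse. This is a short sketch assuming the save() checkpoint interface of core.Engine; the file name is arbitrary.

solver.evaluate("test")       # evaluation on the held-out test edges
solver.save("vgae_cora.pth")  # hypothetical file name for a checkpoint of the trained model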