Customize Models & Tasks
========================

TorchDrug provides many popular model architectures for graph representation learning.
However, you may still find yourself in need of some more customized architectures.

Here we illustrate the steps for writing customized models based on the example of
`variational graph auto encoder`_ (VGAE). VGAE learns latent node representations
with a graph convolutional network (GCN) encoder and an inner product decoder.
The two are trained jointly with a reconstruction loss and evaluated on the link
prediction task.

.. _variational graph auto encoder:
    https://arxiv.org/pdf/1611.07308.pdf

As a convention, we separate representation models and task-specific designs for
better reusability.

Node Representation Model
-------------------------

In VGAE, the node representation model is a variational graph convolutional network
(VGCN). This can be implemented via standard graph convolution layers, plus a
variational regularization loss. We define our model as a subclass of `nn.Module`
and :class:`core.Configurable`.

.. code:: python

    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch.utils import data as torch_data

    # tasks is needed later for the tasks.Task base class
    from torchdrug import core, layers, datasets, metrics, tasks
    from torchdrug.core import Registry as R

    @R.register("models.VGCN")
    class VariationalGraphConvolutionalNetwork(nn.Module, core.Configurable):

        def __init__(self, input_dim, hidden_dims, beta=0, batch_norm=False,
                     activation="relu"):
            super(VariationalGraphConvolutionalNetwork, self).__init__()
            self.input_dim = input_dim
            self.output_dim = hidden_dims[-1]
            self.dims = [input_dim] + list(hidden_dims)
            self.beta = beta

            self.layers = nn.ModuleList()
            # all hidden layers except the last one
            for i in range(len(self.dims) - 2):
                self.layers.append(
                    layers.GraphConv(self.dims[i], self.dims[i + 1], None,
                                     batch_norm, activation)
                )
            # the last layer outputs mu and log_sigma, hence twice the width
            self.layers.append(
                layers.GraphConv(self.dims[-2], self.dims[-1] * 2, None, False, None)
            )

The definition is similar to most other ``torch`` models, except for two points.
First, the decoration line ``@R.register("models.VGCN")`` registers the model in
the library with the name ``models.VGCN``. This enables the model to be dumped into
string format and reconstructed later. Second, ``self.input_dim`` and
``self.output_dim`` are set to inform other models that connect to it.

Then we implement the forward function. The forward function takes 4 arguments:
graph(s), node input feature(s), the global loss and the global metric.
The advantage of these global variables is that they let every module add its own
loss terms where they are computed, i.e. losses are implemented in a distributed,
module-centric manner.

We compute the variational regularization loss, and add it to the global loss
and the global metric.

.. code:: python

    # methods of VariationalGraphConvolutionalNetwork

    def reparameterize(self, mu, log_sigma):
        if self.training:
            # eps ~ N(0, 1), so randn_like (not the uniform rand_like) is needed
            z = mu + torch.randn_like(mu) * log_sigma.exp()
        else:
            z = mu
        return z

    def forward(self, graph, input, all_loss=None, metric=None):
        x = input
        for layer in self.layers:
            x = layer(graph, x)
        # split the last layer's output into mean and log standard deviation
        mu, log_sigma = x.chunk(2, dim=-1)
        node_feature = self.reparameterize(mu, log_sigma)

        if all_loss is not None and self.beta > 0:
            # KL divergence between N(mu, sigma^2) and N(0, 1)
            loss = 0.5 * (mu ** 2 + log_sigma.exp() ** 2 - 2 * log_sigma - 1)
            loss = loss.sum(dim=-1).mean()
            all_loss += loss * self.beta
            metric["variational regularization loss"] = loss

        return {
            "node_feature": node_feature
        }

Here we explicitly return a dict to indicate the type of our representations. The
dict may also contain other representations, such as edge representations or graph
representations.

Link Prediction Task
--------------------

Here we show how to implement the link prediction task for VGAE.

Generally, a task in TorchDrug contains 4 functions: ``predict()``, ``target()``,
``forward()`` and ``evaluate()``. Such interfaces empower us to seamlessly switch
between different devices, such as CPUs, GPUs or even the distributed setting.

Among the above functions, ``predict()`` and ``target()`` compute the prediction and
the ground truth for a batch respectively.
``forward()`` computes the training loss, while ``evaluate()`` computes the
evaluation metrics. Optionally, one can also implement a ``preprocess()`` function,
which performs arbitrary operations based on the dataset.

In the case of VGAE, we first compute the undirected training graph in
``preprocess()``. In ``predict()``, we perform negative sampling, and predict the
logits for both positive and negative edges. In ``target()``, we return the ground
truth label for edges. ``evaluate()`` computes the area under the ROC curve for the
predictions.

.. code:: python

    @R.register("tasks.LinkPrediction")
    class LinkPrediction(tasks.Task, core.Configurable):

        def __init__(self, model):
            super(LinkPrediction, self).__init__()
            self.model = model

        def preprocess(self, train_set, valid_set, test_set):
            dataset = train_set.dataset
            graph = dataset.graph
            train_graph = graph.edge_mask(train_set.indices)

            # flip the edges to make the graph undirected
            edge_list = train_graph.edge_list.repeat(2, 1)
            edge_list[train_graph.num_edge:, :2] = edge_list[train_graph.num_edge:, :2] \
                .flip(1)
            # edge_list stacks the original edges followed by their flipped copies,
            # so the attribute index is [0, ..., E-1, 0, ..., E-1]
            index = torch.arange(train_graph.num_edge, device=self.device).repeat(2)
            data_dict, meta_dict = train_graph.data_mask(edge_index=index)
            train_graph = type(train_graph)(
                edge_list, edge_weight=train_graph.edge_weight[index],
                num_node=train_graph.num_node, num_edge=train_graph.num_edge * 2,
                meta_dict=meta_dict, **data_dict
            )

            self.register_buffer("train_graph", train_graph)
            self.num_node = dataset.num_node

        def forward(self, batch):
            all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
            metric = {}

            # the model adds its regularization loss to all_loss and metric
            pred = self.predict(batch, all_loss, metric)
            target = self.target(batch)
            metric.update(self.evaluate(pred, target))

            loss = F.binary_cross_entropy_with_logits(pred, target)
            metric["bce loss"] = loss

            all_loss += loss

            return all_loss, metric

        def predict(self, batch, all_loss=None, metric=None):
            # negative sampling: one random node pair per positive edge
            neg_batch = torch.randint(self.num_node, batch.shape, device=self.device)
            batch = torch.cat([batch, neg_batch])
            node_in, node_out = batch.t()

            output = self.model(self.train_graph, self.train_graph.node_feature.float(),
                                all_loss, metric)
            node_feature = output["node_feature"]
            # inner product decoder: one logit per (node_in, node_out) pair
            pred = torch.einsum("bd, bd -> b",
                                node_feature[node_in], node_feature[node_out])
            return pred

        def target(self, batch):
            # first half of the predictions are positive edges, second half negative
            batch_size = len(batch)
            target = torch.zeros(batch_size * 2, device=self.device)
            target[:batch_size] = 1
            return target

        def evaluate(self, pred, target):
            roc = metrics.area_under_roc(pred, target)
            return {
                "AUROC": roc
            }

Put Them Together
-----------------

Let's put all the ingredients together. Since the original Cora is a node
classification dataset, we apply a wrapper to make it compatible with link
prediction.

.. code:: python

    class CoraLinkPrediction(datasets.Cora):

        def __getitem__(self, index):
            # each sample is one edge of the graph
            return self.graph.edge_list[index]

        def __len__(self):
            return self.graph.num_edge

    dataset = CoraLinkPrediction("~/node-datasets/")
    # 80% / 10% / 10% split over edges
    lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
    lengths += [len(dataset) - sum(lengths)]
    train_set, valid_set, test_set = torch_data.random_split(dataset, lengths)

    model = VariationalGraphConvolutionalNetwork(dataset.node_feature_dim, [128, 16],
                                                 beta=1e-3, batch_norm=True)
    task = LinkPrediction(model)

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-2)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer, gpus=[0],
                         batch_size=len(train_set))
    solver.train(num_epoch=200)
    solver.evaluate("valid")

The result may look like the following.

.. code:: bash

    AUROC: 0.898589
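
To make the arithmetic behind the VGCN easier to check by hand, the following is a
minimal plain-Python sketch of the three per-node computations used above: the
reparameterized sample, the variational regularization term, and the inner product
logit. It is not part of TorchDrug. The function names
(``kl_to_standard_normal``, ``reparameterize``, ``inner_product_score``) are
hypothetical, and it assumes that ``log_sigma`` is the log standard deviation, as in
the ``forward()`` code of the model.

.. code:: python

    import math
    import random

    # Hypothetical helper names for illustration only; none of these are TorchDrug APIs.

    def kl_to_standard_normal(mu, log_sigma):
        """KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dimensions."""
        return sum(
            0.5 * (m ** 2 + math.exp(s) ** 2 - 2 * s - 1)
            for m, s in zip(mu, log_sigma)
        )

    def reparameterize(mu, log_sigma, training=True, rng=random):
        """Return mu + eps * sigma with eps ~ N(0, 1); return mu when not training."""
        if not training:
            return list(mu)
        return [m + rng.gauss(0.0, 1.0) * math.exp(s) for m, s in zip(mu, log_sigma)]

    def inner_product_score(z_u, z_v):
        """Inner product decoder: logit for an edge between nodes u and v."""
        return sum(a * b for a, b in zip(z_u, z_v))

For instance, ``kl_to_standard_normal([1.0], [0.0])`` evaluates to ``0.5``, and the
term vanishes when ``mu`` and ``log_sigma`` are both zero, i.e. when the posterior
already equals the standard normal prior.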