# Source code for torchdrug.models.infograph

import copy

import torch
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R


@R.register("models.InfoGraph")
class InfoGraph(nn.Module, core.Configurable):
    r"""
    InfoGraph proposed in `InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual
    Information Maximization`_.

    .. _InfoGraph\: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization:
        https://arxiv.org/pdf/1908.01000.pdf

    Parameters:
        model (nn.Module): node & graph representation model
        num_mlp_layer (int, optional): number of MLP layers in mutual information estimators
        activation (str or function, optional): activation function
        loss_weight (float, optional): weight of both unsupervised & transfer losses
        separate_model (bool, optional): separate supervised and unsupervised encoders.
            If true, the unsupervised loss will be applied on a separate encoder,
            and a transfer loss is applied between the two encoders.
    """

    def __init__(self, model, num_mlp_layer=2, activation="relu", loss_weight=1, separate_model=False):
        super(InfoGraph, self).__init__()
        self.model = model
        self.separate_model = separate_model
        self.loss_weight = loss_weight
        self.output_dim = self.model.output_dim

        if separate_model:
            # Independent copy of the encoder that receives the unsupervised loss; `transfer_mi`
            # links its graph features to those of the supervised encoder `self.model`.
            self.unsupervised_model = copy.deepcopy(model)
            self.transfer_mi = layers.MutualInformation(model.output_dim, num_mlp_layer, activation)
        else:
            # Shared encoder: the unsupervised loss is applied directly to `self.model`.
            self.unsupervised_model = model
        # Estimator for the graph-node mutual information term.
        self.unsupervised_mi = layers.MutualInformation(model.output_dim, num_mlp_layer, activation)

    def forward(self, graph, input, all_loss=None, metric=None):
        """
        Compute the node representations and the graph representation(s).
        Add the mutual information between graph and nodes to the loss.

        The loss is only touched when ``all_loss`` is given: each weighted mutual information estimate
        is subtracted from it in place (so minimizing the loss maximizes mutual information), and only
        when ``loss_weight > 0``. Metrics are recorded only when ``metric`` is given.

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            dict with ``node_feature`` and ``graph_feature`` fields:
                node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)`
        """
        output = self.model(graph, input)

        if all_loss is not None:
            if self.separate_model:
                unsupervised_output = self.unsupervised_model(graph, input)
                # Transfer term: mutual information between the graph features of the two encoders.
                mutual_info = self.transfer_mi(output["graph_feature"], unsupervised_output["graph_feature"])

                # `metric` is optional; writing into None would raise a TypeError.
                if metric is not None:
                    metric["distillation mutual information"] = mutual_info
                if self.loss_weight > 0:
                    # In-place update so the caller's loss tensor sees the change.
                    all_loss -= mutual_info * self.loss_weight
            else:
                unsupervised_output = output

            # One (graph index, node index) row per node, pairing every node with the graph it
            # belongs to (presumably node2graph has length num_node — matches the arange below).
            graph_index = graph.node2graph
            node_index = torch.arange(graph.num_node, device=graph.device)
            pair_index = torch.stack([graph_index, node_index], dim=-1)

            mutual_info = self.unsupervised_mi(unsupervised_output["graph_feature"],
                                               unsupervised_output["node_feature"], pair_index)

            if metric is not None:
                metric["graph-node mutual information"] = mutual_info
            if self.loss_weight > 0:
                all_loss -= mutual_info * self.loss_weight

        return output