Source code for torchdrug.layers.pool

import torch
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter_add, scatter_mean

from torchdrug import data


class DiffPool(nn.Module):
    """
    Differentiable pooling operator from `Hierarchical Graph Representation Learning with Differentiable Pooling`_

    .. _Hierarchical Graph Representation Learning with Differentiable Pooling:
        https://papers.nips.cc/paper/7729-hierarchical-graph-representation-learning-with-differentiable-pooling.pdf

    Parameters:
        input_dim (int): input dimension
        output_node (int): number of nodes after pooling
        feature_layer (Module, optional): graph convolution layer for embedding
        pool_layer (Module, optional): graph convolution layer for pooling assignment
        loss_weight (float, optional): weight of entropy regularization
        zero_diagonal (bool, optional): remove self loops in the pooled graph or not
        sparse (bool, optional): use sparse assignment or not
    """

    tau = 1
    eps = 1e-10

    def __init__(self, input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1,
                 zero_diagonal=False, sparse=False):
        super(DiffPool, self).__init__()
        self.input_dim = input_dim
        # fall back to the input dimension when no feature layer is given
        self.output_dim = feature_layer.output_dim if feature_layer is not None else input_dim
        self.output_node = output_node
        self.feature_layer = feature_layer
        self.pool_layer = pool_layer
        self.loss_weight = loss_weight
        self.zero_diagonal = zero_diagonal
        self.sparse = sparse

        if pool_layer is not None:
            self.linear = nn.Linear(pool_layer.output_dim, output_node)
        else:
            self.linear = nn.Linear(input_dim, output_node)
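    # In matrix form (dense case), with soft assignment S = softmax(GNN_pool(A, X))
    # of shape (num_node, output_node), DiffPool computes
    #     Z  = GNN_embed(A, X)    # node embeddings
    #     X' = S^T Z              # pooled node representations
    #     A' = S^T A S            # pooled weighted adjacency
    # dense_pool below evaluates these products per graph via scatter operations;
    # sparse_pool instead hard-assigns each node to one cluster via Gumbel-softmax.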
    def forward(self, graph, input, all_loss=None, metric=None):
        """
        Compute the node cluster assignment and pool the nodes.

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            (PackedGraph, Tensor, Tensor): pooled graph, output node representations,
                node-to-cluster assignment
        """
        feature = input
        if self.feature_layer:
            feature = self.feature_layer(graph, feature)

        x = input
        if self.pool_layer:
            x = self.pool_layer(graph, x)
        x = self.linear(x)
        if self.sparse:
            assignment = F.gumbel_softmax(x, hard=True, tau=self.tau, dim=-1)
            new_graph, output = self.sparse_pool(graph, feature, assignment)
        else:
            assignment = F.softmax(x, dim=-1)
            new_graph, output = self.dense_pool(graph, feature, assignment)

        if all_loss is not None:
            # encourage the assignment distribution to be sharp per graph
            prob = scatter_mean(assignment, graph.node2graph, dim=0, dim_size=graph.batch_size)
            entropy = -(prob * (prob + self.eps).log()).sum(dim=-1)
            entropy = entropy.mean()
            metric["assignment entropy"] = entropy
            if self.loss_weight > 0:
                all_loss -= entropy * self.loss_weight

        if self.zero_diagonal:
            edge_list = new_graph.edge_list[:, :2]
            is_diagonal = edge_list[:, 0] == edge_list[:, 1]
            new_graph = new_graph.edge_mask(~is_diagonal)

        return new_graph, output, assignment
    def dense_pool(self, graph, input, assignment):
        node_in, node_out = graph.edge_list.t()[:2]
        # S^T A S, O(|V|k^2 + |E|k)
        x = graph.edge_weight.unsqueeze(-1) * assignment[node_out]
        x = scatter_add(x, node_in, dim=0, dim_size=graph.num_node)
        x = torch.einsum("np, nq -> npq", assignment, x)
        adjacency = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size)
        # S^T X
        x = torch.einsum("na, nd -> nad", assignment, input)
        output = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size).flatten(0, 1)

        # the pooled graph is dense: every pair of clusters is connected
        index = torch.arange(self.output_node, device=graph.device).expand(len(graph), self.output_node, -1)
        edge_list = torch.stack([index.transpose(-1, -2), index], dim=-1).flatten(0, -2)
        edge_weight = adjacency.flatten()
        if isinstance(graph, data.PackedGraph):
            num_nodes = torch.ones(len(graph), dtype=torch.long, device=input.device) * self.output_node
            num_edges = torch.ones(len(graph), dtype=torch.long, device=input.device) * self.output_node ** 2
            graph = data.PackedGraph(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges)
        else:
            graph = data.Graph(edge_list, edge_weight=edge_weight, num_node=self.output_node)
        return graph, output

    def sparse_pool(self, graph, input, assignment):
        # hard assignment: map every node to its most likely cluster
        assignment = assignment.argmax(dim=-1)
        edge_list = graph.edge_list[:, :2]
        edge_list = assignment[edge_list]
        pooled_node = graph.node2graph * self.output_node + assignment
        output = scatter_add(input, pooled_node, dim=0, dim_size=graph.batch_size * self.output_node)

        edge_weight = graph.edge_weight
        if isinstance(graph, data.PackedGraph):
            num_nodes = torch.ones(len(graph), dtype=torch.long, device=input.device) * self.output_node
            num_edges = graph.num_edges
            graph = data.PackedGraph(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges)
        else:
            graph = data.Graph(edge_list, edge_weight=edge_weight, num_node=self.output_node)
        return graph, output
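# --- Usage sketch (illustrative, not part of the library) ---
# A minimal example of pooling a packed batch with DiffPool. The dimensions and
# the GraphConv layers below are arbitrary choices for illustration:
#
#   import torch
#   from torchdrug import data, layers
#
#   feature_layer = layers.GraphConv(input_dim=16, output_dim=32)
#   pool_layer = layers.GraphConv(input_dim=16, output_dim=32)
#   pool = layers.DiffPool(input_dim=16, output_node=4,
#                          feature_layer=feature_layer, pool_layer=pool_layer)
#
#   graph = data.Graph(edge_list=[[0, 1], [1, 2], [2, 0]], num_node=3)
#   batch = data.Graph.pack([graph, graph])
#   x = torch.randn(batch.num_node, 16)
#   new_graph, output, assignment = pool(batch, x)
#   # new_graph packs two 4-node graphs; output: (2 * 4, 32); assignment: (6, 4)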
class MinCutPool(DiffPool):
    """
    Min cut pooling operator from `Spectral Clustering with Graph Neural Networks for Graph Pooling`_

    .. _Spectral Clustering with Graph Neural Networks for Graph Pooling:
        http://proceedings.mlr.press/v119/bianchi20a/bianchi20a.pdf

    Parameters:
        input_dim (int): input dimension
        output_node (int): number of nodes after pooling
        feature_layer (Module, optional): graph convolution layer for embedding
        pool_layer (Module, optional): graph convolution layer for pooling assignment
        loss_weight (float, optional): weight of the cut loss and orthogonality regularization
        zero_diagonal (bool, optional): remove self loops in the pooled graph or not
        sparse (bool, optional): use sparse assignment or not
    """

    eps = 1e-10

    def __init__(self, input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1,
                 zero_diagonal=True, sparse=False):
        super(MinCutPool, self).__init__(input_dim, output_node, feature_layer, pool_layer, loss_weight,
                                         zero_diagonal, sparse)
    def forward(self, graph, input, all_loss=None, metric=None):
        """
        Compute the node cluster assignment and pool the nodes.

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            (PackedGraph, Tensor, Tensor): pooled graph, output node representations,
                node-to-cluster assignment
        """
        feature = input
        if self.feature_layer:
            feature = self.feature_layer(graph, feature)

        x = input
        if self.pool_layer:
            x = self.pool_layer(graph, x)
        x = self.linear(x)
        if self.sparse:
            assignment = F.gumbel_softmax(x, hard=True, tau=self.tau, dim=-1)
            new_graph, output = self.sparse_pool(graph, feature, assignment)
        else:
            assignment = F.softmax(x, dim=-1)
            new_graph, output = self.dense_pool(graph, feature, assignment)

        if all_loss is not None:
            # normalized cut loss: maximize intra-cluster edge weight
            # relative to the total degree-weighted volume
            edge_list = new_graph.edge_list
            is_diagonal = edge_list[:, 0] == edge_list[:, 1]
            num_intra = scatter_add(new_graph.edge_weight[is_diagonal], new_graph.edge2graph[is_diagonal],
                                    dim=0, dim_size=new_graph.batch_size)
            x = torch.einsum("na, n, nc -> nac", assignment, graph.degree_in, assignment)
            x = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size)
            num_all = torch.einsum("baa -> b", x)
            cut_loss = (1 - num_intra / (num_all + self.eps)).mean()
            metric["normalized cut loss"] = cut_loss

            # orthogonality regularization: push S^T S towards a scaled identity
            x = torch.einsum("na, nc -> nac", assignment, assignment)
            x = scatter_add(x, graph.node2graph, dim=0, dim_size=graph.batch_size)
            x = x / x.flatten(-2).norm(dim=-1, keepdim=True).unsqueeze(-1)
            x = x - torch.eye(self.output_node, device=x.device) / (self.output_node ** 0.5)
            regularization = x.flatten(-2).norm(dim=-1).mean()
            metric["orthogonal regularization"] = regularization
            if self.loss_weight > 0:
                all_loss += (cut_loss + regularization) * self.loss_weight

        if self.zero_diagonal:
            edge_list = new_graph.edge_list[:, :2]
            is_diagonal = edge_list[:, 0] == edge_list[:, 1]
            new_graph = new_graph.edge_mask(~is_diagonal)

        return new_graph, output, assignment
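# --- Usage sketch (illustrative, not part of the library) ---
# MinCutPool is called like DiffPool; passing `all_loss` and `metric` makes it
# accumulate the normalized cut loss and the orthogonality regularization.
# The dimensions below are arbitrary choices for illustration:
#
#   import torch
#   from torchdrug import data, layers
#
#   pool = layers.MinCutPool(input_dim=16, output_node=4,
#                            feature_layer=layers.GraphConv(16, 32),
#                            pool_layer=layers.GraphConv(16, 32))
#   graph = data.Graph(edge_list=[[0, 1], [1, 2], [2, 0]], num_node=3)
#   batch = data.Graph.pack([graph, graph])
#   x = torch.randn(batch.num_node, 16)
#
#   all_loss = torch.tensor(0, dtype=torch.float32)
#   metric = {}
#   new_graph, output, assignment = pool(batch, x, all_loss=all_loss, metric=metric)
#   # all_loss now holds (cut_loss + regularization) * loss_weight;
#   # metric: {"normalized cut loss": ..., "orthogonal regularization": ...}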