Source code for torchdrug.layers.sampler

from torch import nn
from torch_scatter import scatter_add

from torchdrug.layers import functional


class NodeSampler(nn.Module):
    r"""
    Node sampler from `GraphSAINT: Graph Sampling Based Inductive Learning Method`_.

    .. _GraphSAINT\: Graph Sampling Based Inductive Learning Method:
        https://arxiv.org/pdf/1907.04931.pdf

    At least one of ``budget`` and ``ratio`` must be provided. When both are set,
    the smaller of the two resulting sample sizes is used. A falsy value
    (``None`` or ``0``) leaves the corresponding limit unapplied.

    Parameters:
        budget (int, optional): maximum number of nodes to keep
        ratio (float, optional): maximum fraction of nodes to keep
    """

    def __init__(self, budget=None, ratio=None):
        super().__init__()
        if budget is None and ratio is None:
            raise ValueError("At least one of `budget` and `ratio` should be provided")
        self.budget = budget
        self.ratio = ratio

    def forward(self, graph):
        """
        Sample a subgraph from the graph.

        Each node is weighted by the sum of squared weights of the edges whose
        second endpoint (``edge_list[:, 1]``) is that node. The weights of the
        edges kept in the subgraph are divided by
        ``num_sample * prob / num_node`` to compensate for the sampling.

        Parameters:
            graph (Graph): graph(s)

        Returns:
            Graph: the result of ``graph.node_mask`` on the sampled node indices,
            with rescaled edge weights
        """
        # this is exact for a single graph
        # but approximate for packed graphs
        num_sample = graph.num_node
        if self.budget:
            num_sample = min(num_sample, self.budget)
        if self.ratio:
            num_sample = min(num_sample, int(self.ratio * graph.num_node))

        # Per-node sampling weight, normalized so that its mean is 1.
        prob = scatter_add(graph.edge_weight ** 2, graph.edge_list[:, 1], dim_size=graph.num_node)
        prob = prob / prob.mean()
        # presumably draws `num_sample` distinct node indices weighted by `prob`
        # -- confirm against functional.multinomial
        index = functional.multinomial(prob, num_sample)
        new_graph = graph.node_mask(index)
        # `prob` is indexed with the subgraph's node ids, so this assumes
        # node_mask keeps the original node numbering -- TODO confirm
        node_out = new_graph.edge_list[:, 1]
        new_graph._edge_weight /= num_sample * prob[node_out] / graph.num_node

        return new_graph
class EdgeSampler(nn.Module):
    r"""
    Edge sampler from `GraphSAINT: Graph Sampling Based Inductive Learning Method`_.

    .. _GraphSAINT\: Graph Sampling Based Inductive Learning Method:
        https://arxiv.org/pdf/1907.04931.pdf

    At least one of ``budget`` and ``ratio`` must be provided. When both are set,
    the smaller of the two resulting sample sizes is used. A falsy value
    (``None`` or ``0``) leaves the corresponding limit unapplied.

    Parameters:
        budget (int, optional): maximum number of edges to keep
        ratio (float, optional): maximum fraction of edges to keep
    """

    def __init__(self, budget=None, ratio=None):
        super().__init__()
        if budget is None and ratio is None:
            raise ValueError("At least one of `budget` and `ratio` should be provided")
        self.budget = budget
        self.ratio = ratio

    def forward(self, graph):
        """
        Sample a subgraph from the graph.

        Each edge is weighted by the sum of the reciprocal degrees of its two
        endpoints. The weights of the sampled edges are divided by
        ``num_sample * prob / num_edge`` to compensate for the sampling.

        Parameters:
            graph (Graph): graph(s)

        Returns:
            Graph: the result of ``graph.edge_mask`` on the sampled edge indices,
            with rescaled edge weights
        """
        # this is exact for a single graph
        # but approximate for packed graphs
        node_in, node_out = graph.edge_list.t()[:2]
        num_sample = graph.num_edge
        if self.budget:
            num_sample = min(num_sample, self.budget)
        if self.ratio:
            num_sample = min(num_sample, int(self.ratio * graph.num_edge))

        # NOTE(review): if `degree_out` / `degree_in` count edges by their first /
        # second endpoint respectively, these lookups can hit zero on directed
        # graphs and yield infinite weights -- confirm the intended pairing.
        prob = 1 / graph.degree_out[node_out] + 1 / graph.degree_in[node_in]
        # Normalize so that the mean edge weight is 1.
        prob = prob / prob.mean()
        # presumably draws `num_sample` distinct edge indices weighted by `prob`
        # -- confirm against functional.multinomial
        index = functional.multinomial(prob, num_sample)
        new_graph = graph.edge_mask(index)
        # assumes edge_mask keeps the edges in the order given by `index`, so
        # that prob[index] lines up with the new edge weights -- TODO confirm
        new_graph._edge_weight /= num_sample * prob[index] / graph.num_edge

        return new_graph