Source code for torchdrug.data.graph

import math
import warnings
from functools import reduce
from collections import defaultdict

import networkx as nx

from matplotlib import pyplot as plt
import torch
from torch_scatter import scatter_add, scatter_min

from torchdrug import core, utils
from torchdrug.data import Dictionary
from torchdrug.utils import pretty

# Switch matplotlib to the "agg" backend as soon as this module is imported,
# so every figure created by Graph.visualize() is rendered with that backend.
plt.switch_backend("agg")


class Graph(core._MetaContainer):
    r"""
    Basic container for sparse graphs.

    To batch graphs with variadic sizes, use :meth:`data.Graph.pack <torchdrug.data.Graph.pack>`.
    This will return a PackedGraph object with the following block diagonal adjacency matrix.

    .. math::

        \begin{bmatrix}
            A_1    & \cdots & 0      \\
            \vdots & \ddots & \vdots \\
            0      & \cdots & A_n
        \end{bmatrix}

    where :math:`A_i` is the adjacency of :math:`i`-th graph.

    You may register dynamic attributes for each graph.
    The registered attributes will be automatically processed during packing.

    .. warning::

        This class doesn't enforce any order on the edges.

    Example::

        >>> graph = data.Graph(torch.randint(10, (30, 2)))
        >>> with graph.node():
        ...     graph.my_node_attr = torch.rand(10, 5, 5)

    Parameters:
        edge_list (array_like, optional): list of edges of shape :math:`(|E|, 2)` or :math:`(|E|, 3)`.
            Each tuple is (node_in, node_out) or (node_in, node_out, relation).
        edge_weight (array_like, optional): edge weights of shape :math:`(|E|,)`
        num_node (int, optional): number of nodes.
            By default, it will be inferred from the largest id in `edge_list`
        num_relation (int, optional): number of relations
        node_feature (array_like, optional): node features of shape :math:`(|V|, ...)`
        edge_feature (array_like, optional): edge features of shape :math:`(|E|, ...)`
        graph_feature (array_like, optional): graph feature of any shape
    """

    _meta_types = {"node", "edge", "graph", "node reference", "edge reference", "graph reference"}

    def __init__(self, edge_list=None, edge_weight=None, num_node=None, num_relation=None,
                 node_feature=None, edge_feature=None, graph_feature=None, **kwargs):
        super(Graph, self).__init__(**kwargs)

        # Normalise the raw inputs first; each edge row is [h, t] or [h, t, r].
        edges, edge_count = self._standarize_edge_list(edge_list, num_relation)
        weights = self._standarize_edge_weight(edge_weight, edges)
        node_count = self._standarize_num_node(num_node, edges)
        relation_count = self._standarize_num_relation(num_relation, edges)

        self._edge_list = edges
        self._edge_weight = weights
        self.num_node = node_count
        self.num_edge = edge_count
        self.num_relation = relation_count

        # Register the optional features inside their matching context so that
        # each one is recorded with the right meta type.
        optional_features = (
            (self.node, "node_feature", node_feature),
            (self.edge, "edge_feature", edge_feature),
            (self.graph, "graph_feature", graph_feature),
        )
        for open_context, attribute, raw in optional_features:
            if raw is None:
                continue
            with open_context():
                setattr(self, attribute, torch.as_tensor(raw, device=self.device))
def node(self):
    """
    Open the context in which newly assigned attributes are node attributes.
    """
    meta_type = "node"
    return self.context(meta_type)
def edge(self):
    """
    Open the context in which newly assigned attributes are edge attributes.
    """
    meta_type = "edge"
    return self.context(meta_type)
def graph(self):
    """
    Open the context in which newly assigned attributes are graph attributes.
    """
    meta_type = "graph"
    return self.context(meta_type)
def node_reference(self):
    """
    Open the context in which newly assigned attributes are node references.
    """
    meta_type = "node reference"
    return self.context(meta_type)
def edge_reference(self):
    """
    Open the context in which newly assigned attributes are edge references.
    """
    meta_type = "edge reference"
    return self.context(meta_type)
def graph_reference(self):
    """
    Open the context in which newly assigned attributes are graph references.
    """
    meta_type = "graph reference"
    return self.context(meta_type)
def _check_attribute(self, key, value):
    """
    Validate ``value`` against every meta type in ``self._meta_contexts``.

    Parameters:
        key (str): attribute name, only used in error messages
        value (Tensor): value about to be assigned

    Raises:
        TypeError: a reference attribute is not a long tensor
        ValueError: the leading dimension does not match the number of nodes / edges,
            or a reference lies outside ``[-1, count)``
    """
    for meta_type in self._meta_contexts:
        if "reference" in meta_type:
            if value.dtype != torch.long:
                raise TypeError("Tensors used as reference must be long tensors")
        if meta_type == "node":
            if len(value) != self.num_node:
                raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" %
                                 (key, self.num_node, value.shape))
        elif meta_type == "edge":
            if len(value) != self.num_edge:
                raise ValueError("Expect edge attribute `%s` to have shape (%d, *), but found %s" %
                                 (key, self.num_edge, value.shape))
        elif meta_type == "node reference":
            # -1 is accepted as a valid reference value
            is_valid = (value >= -1) & (value < self.num_node)
            if not is_valid.all():
                error_value = value[~is_valid]
                raise ValueError("Expect node reference in [-1, %d), but found %d" %
                                 (self.num_node, error_value[0]))
        elif meta_type == "edge reference":
            is_valid = (value >= -1) & (value < self.num_edge)
            if not is_valid.all():
                error_value = value[~is_valid]
                raise ValueError("Expect edge reference in [-1, %d), but found %d" %
                                 (self.num_edge, error_value[0]))
        elif meta_type == "graph reference":
            is_valid = (value >= -1) & (value < self.batch_size)
            if not is_valid.all():
                error_value = value[~is_valid]
                raise ValueError("Expect graph reference in [-1, %d), but found %d" %
                                 (self.batch_size, error_value[0]))


def __setattr__(self, key, value):
    """
    Assign an attribute, validating it first once ``meta_dict`` exists on the instance.
    """
    if hasattr(self, "meta_dict"):
        self._check_attribute(key, value)
    super(Graph, self).__setattr__(key, value)


def _standarize_edge_list(self, edge_list, num_relation):
    """
    Convert ``edge_list`` to a long tensor and count its edges.

    An empty or missing edge list becomes a ``(0, 2)`` tensor, or ``(0, 3)`` when
    ``num_relation`` is given.

    Returns:
        (LongTensor, Tensor): the edge list and a scalar tensor holding the number of edges

    Raises:
        TypeError: a non-long tensor can't be converted by ``torch.LongTensor``
        ValueError: the edge list contains a negative index
    """
    if edge_list is not None and len(edge_list):
        if isinstance(edge_list, torch.Tensor) and edge_list.dtype != torch.long:
            try:
                edge_list = torch.LongTensor(edge_list)
            except TypeError:
                raise TypeError("Can't convert `edge_list` to torch.long")
        else:
            edge_list = torch.as_tensor(edge_list, dtype=torch.long)
    else:
        num_element = 2 if num_relation is None else 3
        if isinstance(edge_list, torch.Tensor):
            device = edge_list.device
        else:
            device = "cpu"
        edge_list = torch.zeros(0, num_element, dtype=torch.long, device=device)
    if (edge_list < 0).any():
        raise ValueError("`edge_list` should only contain non-negative indexes")
    num_edge = torch.tensor(len(edge_list), device=edge_list.device)
    return edge_list, num_edge


def _standarize_edge_weight(self, edge_weight, edge_list):
    """
    Convert ``edge_weight`` to a float tensor on the device of ``edge_list``.

    A missing weight defaults to one per edge.

    Raises:
        ValueError: ``edge_weight`` and ``edge_list`` differ in length
    """
    if edge_weight is not None:
        edge_weight = torch.as_tensor(edge_weight, dtype=torch.float, device=edge_list.device)
        if len(edge_list) != len(edge_weight):
            raise ValueError("`edge_list` and `edge_weight` should be the same size, but found %d and %d"
                             % (len(edge_list), len(edge_weight)))
    else:
        edge_weight = torch.ones(len(edge_list), device=edge_list.device)
    return edge_weight


def _standarize_num_node(self, num_node, edge_list):
    """
    Return ``num_node`` as a tensor, inferring it from ``edge_list`` when it is ``None``.

    Raises:
        ValueError: some node id in ``edge_list`` is not smaller than ``num_node``
    """
    if num_node is None:
        num_node = self._maybe_num_node(edge_list)
    num_node = torch.as_tensor(num_node, device=edge_list.device)
    if (edge_list[:, :2] >= num_node).any():
        raise ValueError("`num_node` is %d, but found node %d in `edge_list`" % (num_node, edge_list[:, :2].max()))
    return num_node


def _standarize_num_relation(self, num_relation, edge_list):
    """
    Return ``num_relation`` as a tensor, or ``None`` for graphs without a relation column.

    The value is inferred from ``edge_list`` when it is ``None`` and the edge list has
    more than two columns.

    Raises:
        ValueError: ``num_relation`` is given but ``edge_list`` has no relation column,
            or some relation id is not smaller than ``num_relation``
    """
    if num_relation is None and edge_list.shape[1] > 2:
        num_relation = self._maybe_num_relation(edge_list)
    if num_relation is not None:
        num_relation = torch.as_tensor(num_relation, device=edge_list.device)
        if edge_list.shape[1] <= 2:
            raise ValueError("`num_relation` is provided, but the number of dims of `edge_list` is less than 3.")
        elif (edge_list[:, 2] >= num_relation).any():
            raise ValueError("`num_relation` is %d, but found relation %d in `edge_list`" % (
                num_relation, edge_list[:, 2].max()))
    return num_relation


def _maybe_num_node(self, edge_list):
    """
    Guess the number of nodes as the largest node id plus one (0 for an empty edge list).

    Always emits a warning, since isolated trailing nodes can't be detected this way.
    """
    warnings.warn("_maybe_num_node() is used to determine the number of nodes. "
                  "This may underestimate the count if there are isolated nodes.")
    if len(edge_list):
        return edge_list[:, :2].max().item() + 1
    else:
        return 0


def _maybe_num_relation(self, edge_list):
    """
    Guess the number of relations as the largest relation id plus one.

    Always emits a warning, since relations absent from the edge list can't be detected.
    """
    warnings.warn("_maybe_num_relation() is used to determine the number of relations. "
                  "This may underestimate the count if there are unseen relations.")
    return edge_list[:, 2].max().item() + 1


def _standarize_index(self, index, count):
    """
    Convert a slice, boolean mask, scalar or integer sequence into a 1-D long index tensor.

    Parameters:
        index (slice or array_like): index into an axis of length ``count``
        count (int or Tensor): length of the indexed axis

    Returns:
        LongTensor: explicit indexes

    Raises:
        IndexError: a boolean mask doesn't have shape ``(count,)``,
            or some index is not smaller than ``count``
    """
    if isinstance(index, slice):
        start = index.start or 0
        if start < 0:
            start += count
        # Only a missing stop means "up to the end". An explicit stop of 0 must stay 0,
        # otherwise ``[:0]`` would select everything instead of nothing.
        stop = count if index.stop is None else index.stop
        if stop < 0:
            stop += count
        step = index.step or 1
        index = torch.arange(start, stop, step, device=self.device)
    else:
        index = torch.as_tensor(index, device=self.device)
        if index.ndim == 0:
            index = index.unsqueeze(0)
        if index.dtype == torch.bool:
            if index.shape != (count,):
                raise IndexError("Invalid mask. Expect mask to have shape %s, but found %s" %
                                 ((int(count),), tuple(index.shape)))
            index = index.nonzero().squeeze(-1)
        else:
            index = index.long()
            max_index = -1 if len(index) == 0 else index.max().item()
            if max_index >= count:
                raise IndexError("Invalid index. Expect index smaller than %d, but found %d" % (count, max_index))
    return index


def _get_mapping(self, index, count):
    """
    Build a lookup table from old ids to their position in ``index``.

    The table has ``count + 1`` entries initialised to -1, so ids missing from ``index``
    map to -1 and the id -1 (which addresses the extra last slot) maps to -1 as well.

    Raises:
        ValueError: ``index`` contains duplicates
    """
    index = self._standarize_index(index, count)
    if (index.bincount() > 1).any():
        raise ValueError("Can't create mapping for duplicate index")
    mapping = -torch.ones(count + 1, dtype=torch.long, device=self.device)
    mapping[index] = torch.arange(len(index), device=self.device)
    return mapping


def _get_repeat_pack_offsets(self, num_xs, repeats):
    """
    Return the cumulative sum of ``num_xs`` repeated ``repeats`` times,
    where the first copy of every entry has its own size subtracted beforehand.

    NOTE(review): presumably used to offset ids when each graph of a batch is
    repeated a different number of times -- confirm against the callers.
    """
    new_num_xs = num_xs.repeat_interleave(repeats)
    # position of the first copy of each entry in the repeated sequence
    cum_repeats_shifted = repeats.cumsum(0) - repeats
    new_num_xs[cum_repeats_shifted] -= num_xs
    offsets = new_num_xs.cumsum(0)
    return offsets
@classmethod
def from_dense(cls, adjacency, node_feature=None, edge_feature=None):
    """
    Build a sparse graph out of a dense adjacency matrix.

    Only the non-zero entries of the adjacency matrix become edges,
    and edge features at zero entries are dropped.

    Parameters:
        adjacency (array_like): adjacency matrix of shape :math:`(|V|, |V|)` or :math:`(|V|, |V|, |R|)`
        node_feature (array_like): node features of shape :math:`(|V|, ...)`
        edge_feature (array_like): edge features of shape :math:`(|V|, |V|, ...)` or
            :math:`(|V|, |V|, |R|, ...)`
    """
    dense = torch.as_tensor(adjacency)
    if dense.shape[0] != dense.shape[1]:
        raise ValueError("`adjacency` should be a square matrix, but found %d and %d" % dense.shape[:2])

    nonzero = dense.nonzero()
    # one coordinate tensor per axis, usable to gather values at the non-zero entries
    coordinates = tuple(nonzero.t())
    weights = dense[coordinates]
    relation_count = dense.shape[2] if dense.ndim > 2 else None
    if edge_feature is not None:
        edge_feature = torch.as_tensor(edge_feature)[coordinates]

    return cls(nonzero, weights, dense.shape[0], relation_count, node_feature, edge_feature)
def connected_components(self):
    """
    Split this graph into connected components.

    Returns:
        (PackedGraph, LongTensor): connected components, number of connected components per graph
    """
    heads, tails = self.edge_list.t()[:2]
    node_ids = torch.arange(self.num_node, device=self.device)
    # treat every edge as bidirectional and give every node a self loop
    sources = torch.cat([heads, tails, node_ids])
    targets = torch.cat([tails, heads, node_ids])

    # Label propagation: every node repeatedly takes the smallest label among its
    # neighbours until nothing changes. O(|E|d), d is the diameter of the graph.
    labels = torch.arange(self.num_node, device=self.device)
    previous = torch.zeros_like(labels)
    while not torch.equal(labels, previous):
        previous = labels
        labels, _ = scatter_min(labels[targets], sources, dim_size=self.num_node)

    # one distinct label per component
    anchors = torch.unique(labels)
    num_cc = self.node2graph[anchors].bincount(minlength=self.batch_size)
    return self.split(labels), num_cc
def split(self, node2graph):
    """
    Split a graph into multiple disconnected graphs.

    Nodes are regrouped so that nodes of the same new graph become consecutive.
    Edges whose two end points fall into different new graphs are dropped.

    Parameters:
        node2graph (array_like): ID of the graph each node belongs to

    Returns:
        PackedGraph
    """
    node2graph = torch.as_tensor(node2graph, dtype=torch.long, device=self.device)
    # coalesce arbitrary graph IDs to [0, n)
    _, node2graph = torch.unique(node2graph, return_inverse=True)
    num_graph = node2graph.max() + 1
    # index: old node ids sorted by new graph id; mapping: old node id -> new node id
    index = node2graph.argsort()
    mapping = torch.zeros_like(index)
    mapping[index] = torch.arange(len(index), device=self.device)

    node_in, node_out = self.edge_list.t()[:2]
    # keep only the edges that stay inside a single new graph
    edge_mask = node2graph[node_in] == node2graph[node_out]
    edge2graph = node2graph[node_in]
    # sort edges by new graph id, then drop the masked-out ones
    edge_index = edge2graph.argsort()
    edge_index = edge_index[edge_mask[edge_index]]

    # the first node of every new graph selects which original graph-level data it inherits
    prepend = torch.tensor([-1], device=self.device)
    is_first_node = torch.diff(node2graph[index], prepend=prepend) > 0
    graph_index = self.node2graph[index[is_first_node]]

    # relabel both end points with the new node ids
    edge_list = self.edge_list.clone()
    edge_list[:, :2] = mapping[edge_list[:, :2]]

    num_nodes = node2graph.bincount(minlength=num_graph)
    num_edges = edge2graph[edge_index].bincount(minlength=num_graph)

    # per-edge offset = number of nodes in all preceding new graphs
    num_cum_nodes = num_nodes.cumsum(0)
    offsets = (num_cum_nodes - num_nodes)[edge2graph[edge_index]]

    data_dict, meta_dict = self.data_mask(index, edge_index, graph_index=graph_index, exclude="graph reference")

    return self.packed_type(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes,
                            num_edges=num_edges, num_relation=self.num_relation, offsets=offsets,
                            meta_dict=meta_dict, **data_dict)
@classmethod
def pack(cls, graphs):
    """
    Pack a list of graphs into a PackedGraph object.

    Attributes are concatenated along the first axis. Graph attributes get a new
    leading axis, and node / edge / graph references are shifted by the number of
    nodes / edges / graphs packed before them. The meta types of the first graph
    are used for all graphs.

    Parameters:
        graphs (list of Graph): list of graphs

    Returns:
        PackedGraph

    Raises:
        IndexError: ``graphs`` is empty
        ValueError: the graphs disagree on ``num_relation``
    """
    def describe(count):
        # ``%d`` can't format None, which is a legal ``num_relation``
        return "None" if count is None else "%d" % count

    edge_list = []
    edge_weight = []
    num_nodes = []
    num_edges = []
    num_relation = -1  # sentinel: no graph seen yet
    num_cum_node = 0
    num_cum_edge = 0
    num_graph = 0
    data_dict = defaultdict(list)
    meta_dict = graphs[0].meta_dict
    for graph in graphs:
        edge_list.append(graph.edge_list)
        edge_weight.append(graph.edge_weight)
        num_nodes.append(graph.num_node)
        num_edges.append(graph.num_edge)
        for k, v in graph.data_dict.items():
            for meta_type in meta_dict[k]:
                if meta_type == "graph":
                    v = v.unsqueeze(0)
                elif meta_type == "node reference":
                    v = v + num_cum_node
                elif meta_type == "edge reference":
                    v = v + num_cum_edge
                elif meta_type == "graph reference":
                    v = v + num_graph
            data_dict[k].append(v)
        if num_relation == -1:
            num_relation = graph.num_relation
        elif num_relation != graph.num_relation:
            raise ValueError("Inconsistent `num_relation` in graphs. Expect %s but got %s."
                             % (describe(num_relation), describe(graph.num_relation)))
        num_cum_node += graph.num_node
        num_cum_edge += graph.num_edge
        num_graph += 1

    edge_list = torch.cat(edge_list)
    edge_weight = torch.cat(edge_weight)
    data_dict = {k: torch.cat(v) for k, v in data_dict.items()}

    return cls.packed_type(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges,
                           num_relation=num_relation, meta_dict=meta_dict, **data_dict)
def repeat(self, count):
    """
    Repeat this graph.

    Every attribute is tiled ``count`` times along its first axis (graph attributes
    get a new leading axis first), and references are shifted so that each copy
    points into itself.

    Parameters:
        count (int): number of repetitions

    Returns:
        PackedGraph
    """
    tiled_edges = self.edge_list.repeat(count, 1)
    tiled_weights = self.edge_weight.repeat(count)
    node_counts = [self.num_node] * count
    edge_counts = [self.num_edge] * count

    copy_ids = torch.arange(count, device=self.device)
    # how far a reference moves from one copy to the next
    reference_stride = {
        "node reference": self.num_node,
        "edge reference": self.num_edge,
        "graph reference": 1,
    }

    tiled_data = {}
    for name, value in self.data_dict.items():
        meta_types = self.meta_dict[name]
        if "graph" in meta_types:
            value = value.unsqueeze(0)
        length = len(value)
        value = value.repeat([count] + [1] * (value.ndim - 1))
        for meta_type in meta_types:
            stride = reference_stride.get(meta_type)
            if stride is not None:
                value = value + (copy_ids * stride).repeat_interleave(length)
        tiled_data[name] = value

    return self.packed_type(tiled_edges, edge_weight=tiled_weights, num_nodes=node_counts,
                            num_edges=edge_counts, num_relation=self.num_relation,
                            meta_dict=self.meta_dict, **tiled_data)
def get_edge(self, edge):
    """
    Get the weight of an edge.

    The weights of all edges matching ``edge`` are summed.

    Parameters:
        edge (array_like): index of shape :math:`(2,)` or :math:`(3,)`

    Returns:
        Tensor: weight of the edge
    """
    num_axis = self.edge_list.shape[1]
    if len(edge) != num_axis:
        raise ValueError("Incorrect edge index. Expect %d axes but got %d axes"
                         % (num_axis, len(edge)))

    matched, _ = self.match(edge)
    return self.edge_weight[matched].sum()
def match(self, pattern):
    """
    Return all matched indexes for each pattern. Support patterns with ``-1`` as the wildcard.

    Inverted indexes are built lazily, one per wildcard layout, and cached
    in ``self.edge_inverted_index``.

    Parameters:
        pattern (array_like): index of shape :math:`(N, 2)` or :math:`(N, 3)`

    Returns:
        (LongTensor, LongTensor): matched indexes, number of matches per edge

    Examples::

        >>> graph = data.Graph([[0, 1], [1, 0], [1, 2], [2, 1], [2, 0], [0, 2]])
        >>> index, num_match = graph.match([[0, -1], [1, 2]])
        >>> assert (index == torch.tensor([0, 5, 2])).all()
        >>> assert (num_match == torch.tensor([2, 1])).all()
    """
    if len(pattern) == 0:
        index = num_match = torch.zeros(0, dtype=torch.long, device=self.device)
        return index, num_match

    if not hasattr(self, "edge_inverted_index"):
        self.edge_inverted_index = {}
    pattern = torch.as_tensor(pattern, dtype=torch.long, device=self.device)
    if pattern.ndim == 1:
        pattern = pattern.unsqueeze(0)
    # True where a column is specified, False where it is the wildcard
    mask = pattern != -1
    # encode each row's wildcard layout as a bit field so that queries can be grouped by layout
    scale = 2 ** torch.arange(pattern.shape[-1], device=self.device)
    query_type = (mask * scale).sum(dim=-1)
    query_index = query_type.argsort()
    num_query = query_type.unique(return_counts=True)[1]
    query_ends = num_query.cumsum(0)
    query_starts = query_ends - num_query
    # one representative mask per distinct layout
    mask_set = mask[query_index[query_starts]].tolist()

    type_ranges = []
    type_orders = []
    # get matched range for each query type
    # (`mask` and `query_type` are re-bound inside the loop to the per-layout values)
    for i, mask in enumerate(mask_set):
        query_type = tuple(mask)
        type_index = query_index[query_starts[i]: query_ends[i]]
        type_edge = pattern[type_index][:, mask]
        if query_type not in self.edge_inverted_index:
            self.edge_inverted_index[query_type] = self._build_edge_inverted_index(mask)
        inverted_range, order = self.edge_inverted_index[query_type]
        ranges = inverted_range.get(type_edge, default=0)
        type_ranges.append(ranges)
        type_orders.append(order)
    ranges = torch.cat(type_ranges)
    orders = torch.stack(type_orders)
    types = torch.arange(len(mask_set), device=self.device)
    types = types.repeat_interleave(num_query)

    # reorder matched ranges according to the query order
    ranges = scatter_add(ranges, query_index, dim=0, dim_size=len(pattern))
    types = scatter_add(types, query_index, dim_size=len(pattern))
    # convert range to indexes
    starts, ends = ranges.t()
    num_match = ends - starts
    offsets = num_match.cumsum(0) - num_match
    types = types.repeat_interleave(num_match)
    ranges = torch.arange(num_match.sum(), device=self.device)
    ranges = ranges + (starts - offsets).repeat_interleave(num_match)
    index = orders[types, ranges]

    return index, num_match
def _build_edge_inverted_index(self, mask):
    """
    Build a sorting-based inverted index over the edge columns selected by ``mask``.

    Returns:
        (Dictionary, LongTensor): mapping from each distinct key to its ``[start, end)``
        range in the sorted order, and the edge order obtained by sorting the keys
    """
    selected = self.edge_list[:, mask]
    radix = torch.tensor(self.shape, device=self.device)[mask]
    key_space = reduce(int.__mul__, radix.tolist())
    if key_space > torch.iinfo(torch.int64).max:
        raise ValueError("Fail to build an inverted index table based on sorting. "
                         "The graph is too large.")

    # mixed-radix weights, most significant column first
    running = radix.cumprod(0)
    weights = torch.div(running[-1], running, rounding_mode="floor")
    flat_key = (selected * weights).sum(dim=-1)
    order = flat_key.argsort()

    counts = flat_key.unique(return_counts=True)[1]
    range_ends = counts.cumsum(0)
    range_starts = range_ends - counts
    ranges = torch.stack([range_starts, range_ends], dim=-1)
    distinct_keys = selected[order[range_starts]]
    return Dictionary(distinct_keys, ranges), order


def __getitem__(self, index):
    """
    Index the graph along its two node axes.

    Two plain integers return the weight of that edge. Anything else returns
    the graph masked to the edges whose end points are both selected.
    """
    # ``x[0, 1]`` arrives as the tuple (0, 1), while ``x[[0, 1]]`` arrives as the
    # list [0, 1]; only the tuple form carries one entry per axis.
    axes = list(index) if isinstance(index, tuple) else [index]
    axes += [slice(None)] * (2 - len(axes))
    if len(axes) > 2:
        raise ValueError("Graph has only 2 axis, but %d axis is indexed" % len(axes))

    if all(isinstance(axis, int) for axis in axes):
        return self.get_edge(axes)

    remapped = self.edge_list.clone()
    for column, axis in enumerate(axes):
        selected = self._standarize_index(axis, self.num_node)
        # nodes that are not selected are marked with -1
        lookup = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
        lookup[selected] = selected
        remapped[:, column] = lookup[remapped[:, column]]
    keep = (remapped >= 0).all(dim=-1)

    return self.edge_mask(keep)


def __len__(self):
    """A single graph always has length 1."""
    return 1


@property
def batch_size(self):
    """Batch size."""
    return 1
def subgraph(self, index):
    """
    Return a subgraph based on the specified nodes.
    Equivalent to :meth:`node_mask(index, compact=True) <node_mask>`.

    Parameters:
        index (array_like): node index

    Returns:
        Graph

    See also:
        :meth:`Graph.node_mask`
    """
    compacted = self.node_mask(index, compact=True)
    return compacted
def data_mask(self, node_index=None, edge_index=None, graph_index=None, include=None, exclude=None):
    """
    Apply node / edge / graph indexes to the registered attributes.

    Plain attributes are indexed directly (graph attributes get a leading axis first).
    Reference attributes are translated through an old-id to new-id mapping that is
    built on first use. An index that is ``None`` leaves its attributes untouched.

    Parameters:
        node_index (array_like, optional): node index
        edge_index (array_like, optional): edge index
        graph_index (array_like, optional): graph index
        include: forwarded to ``data_by_meta``
        exclude: forwarded to ``data_by_meta``

    Returns:
        (dict, dict): masked data dict and the corresponding meta dict
    """
    data_dict, meta_dict = self.data_by_meta(include, exclude)
    node_lookup = None
    edge_lookup = None
    graph_lookup = None
    for name in list(data_dict):
        value = data_dict[name]
        for meta_type in meta_dict[name]:
            if meta_type == "node":
                if node_index is not None:
                    value = value[node_index]
            elif meta_type == "edge":
                if edge_index is not None:
                    value = value[edge_index]
            elif meta_type == "graph":
                if graph_index is not None:
                    value = value.unsqueeze(0)[graph_index]
            elif meta_type == "node reference":
                if node_index is not None:
                    if node_lookup is None:
                        node_lookup = self._get_mapping(node_index, self.num_node)
                    value = node_lookup[value]
            elif meta_type == "edge reference":
                if edge_index is not None:
                    if edge_lookup is None:
                        edge_lookup = self._get_mapping(edge_index, self.num_edge)
                    value = edge_lookup[value]
            elif meta_type == "graph reference":
                if graph_index is not None:
                    if graph_lookup is None:
                        graph_lookup = self._get_mapping(graph_index, self.batch_size)
                    value = graph_lookup[value]
        data_dict[name] = value

    return data_dict, meta_dict
def node_mask(self, index, compact=False):
    """
    Return a masked graph based on the specified nodes.

    This function can also be used to re-order the nodes.

    Parameters:
        index (array_like): node index
        compact (bool, optional): compact node ids or not

    Returns:
        Graph

    Examples::

        >>> graph = data.Graph.from_dense(torch.eye(3))
        >>> assert graph.node_mask([1, 2]).adjacency.shape == (3, 3)
        >>> assert graph.node_mask([1, 2], compact=True).adjacency.shape == (2, 2)
    """
    index = self._standarize_index(index, self.num_node)
    # old node id -> new node id, with -1 for nodes that are masked out
    lookup = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
    if compact:
        lookup[index] = torch.arange(len(index), device=self.device)
        node_count = len(index)
    else:
        lookup[index] = index
        node_count = self.num_node

    relabeled = self.edge_list.clone()
    relabeled[:, :2] = lookup[relabeled[:, :2]]
    # an edge survives only if both of its end points survive
    keep = (relabeled[:, :2] >= 0).all(dim=-1)

    # node attributes only need masking when the node ids are compacted
    masked = self.data_mask(index, keep) if compact else self.data_mask(edge_index=keep)
    data_dict, meta_dict = masked

    return type(self)(relabeled[keep], edge_weight=self.edge_weight[keep], num_node=node_count,
                      num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)
def compact(self):
    """
    Remove isolated nodes and compact node ids.

    Returns:
        Graph
    """
    total_degree = self.degree_out + self.degree_in
    return self.subgraph(total_degree > 0)
def edge_mask(self, index):
    """
    Return a masked graph based on the specified edges.

    This function can also be used to re-order the edges.

    Parameters:
        index (array_like): edge index

    Returns:
        Graph
    """
    kept = self._standarize_index(index, self.num_edge)
    data_dict, meta_dict = self.data_mask(edge_index=kept)
    edges = self.edge_list[kept]
    weights = self.edge_weight[kept]

    return type(self)(edges, edge_weight=weights, num_node=self.num_node,
                      num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)
def line_graph(self):
    """
    Construct a line graph of this graph.
    The node feature of the line graph is inherited from the edge feature of the original graph.

    In the line graph, each node corresponds to an edge in the original graph.
    For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph,
    there is a directed edge (a, b) -> (b, c) in the line graph.

    Returns:
        Graph
    """
    node_in, node_out = self.edge_list.t()[:2]
    edge_index = torch.arange(self.num_edge, device=self.device)
    # edge ids grouped by the node they end at / start from
    edge_in = edge_index[node_out.argsort()]
    edge_out = edge_index[node_in.argsort()]

    # NOTE: unweighted counts computed locally; these are not the weighted
    # `degree_in` / `degree_out` properties of the class
    degree_in = node_in.bincount(minlength=self.num_node)
    degree_out = node_out.bincount(minlength=self.num_node)
    size = degree_out * degree_in
    starts = (size.cumsum(0) - size).repeat_interleave(size)
    range = torch.arange(size.sum(), device=self.device)
    # each node u has degree_out[u] * degree_in[u] local edges
    local_index = range - starts
    local_inner_size = degree_in.repeat_interleave(size)
    edge_in_offset = (degree_out.cumsum(0) - degree_out).repeat_interleave(size)
    edge_out_offset = (degree_in.cumsum(0) - degree_in).repeat_interleave(size)
    # split the local index into (incoming edge, outgoing edge) positions around node u
    edge_in_index = torch.div(local_index, local_inner_size, rounding_mode="floor") + edge_in_offset
    edge_out_index = local_index % local_inner_size + edge_out_offset

    edge_in = edge_in[edge_in_index]
    edge_out = edge_out[edge_out_index]
    edge_list = torch.stack([edge_in, edge_out], dim=-1)
    node_feature = getattr(self, "edge_feature", None)
    num_node = self.num_edge
    num_edge = size.sum()

    # NOTE(review): `num_edge` is not a named parameter of Graph.__init__, so it travels
    # through **kwargs to the base class -- confirm the base class accepts it.
    return Graph(edge_list, num_node=num_node, num_edge=num_edge, node_feature=node_feature)
def full(self):
    """
    Return a fully connected graph over the nodes.

    Every ordered pair of nodes (self loops included) is connected, once per relation
    when the graph has relations. All edge weights are 1. Edge attributes are dropped,
    other attributes are carried over.

    Returns:
        Graph
    """
    index = torch.arange(self.num_node, device=self.device)
    axes = [index, index]
    if self.num_relation:
        axes.append(torch.arange(self.num_relation, device=self.device))
    grids = torch.meshgrid(*axes)
    # Put the coordinates on the last axis so the result has one row per edge,
    # i.e. shape (|E|, 2) or (|E|, 3) as the constructor expects.
    edge_list = torch.stack(grids, dim=-1).reshape(-1, len(axes))
    edge_weight = torch.ones(len(edge_list), device=self.device)

    data_dict, meta_dict = self.data_by_meta(exclude="edge")

    return type(self)(edge_list, edge_weight=edge_weight, num_node=self.num_node,
                      num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)
def directed(self, order=None):
    """
    Mask the edges to create a directed graph.
    Edges that go from a node index to a larger or equal node index will be kept.

    When ``order`` is given, the comparison is made on ``order[node]`` instead of the node index.

    Parameters:
        order (Tensor, optional): topological order of the nodes
    """
    source, target = self.edge_list.t()[:2]
    if order is None:
        source_rank, target_rank = source, target
    else:
        source_rank, target_rank = order[source], order[target]
    keep = source_rank <= target_rank

    return self.edge_mask(keep)
def undirected(self, add_inverse=False):
    """
    Flip all the edges to create an undirected graph.

    For knowledge graphs, the flipped edges can either have the original relation or an inverse relation.
    The inverse relation for relation :math:`r` is defined as :math:`|R| + r`.

    Each flipped edge is placed right after its original edge and shares its weight and attributes.

    Parameters:
        add_inverse (bool, optional): whether to use inverse relations for flipped edges
    """
    flipped = self.edge_list.clone()
    flipped[:, [0, 1]] = self.edge_list[:, [1, 0]]

    relation_count = self.num_relation
    if relation_count and add_inverse:
        flipped[:, 2] += relation_count
        relation_count = relation_count * 2

    # interleave: original edge 0, flipped edge 0, original edge 1, flipped edge 1, ...
    paired = torch.stack([self.edge_list, flipped], dim=1)
    doubled_edges = paired.flatten(0, 1)
    source_edge = torch.arange(self.num_edge, device=self.device).repeat_interleave(2)
    data_dict, meta_dict = self.data_mask(edge_index=source_edge)

    return type(self)(doubled_edges, edge_weight=self.edge_weight[source_edge], num_node=self.num_node,
                      num_relation=relation_count, meta_dict=meta_dict, **data_dict)
@utils.cached_property
def adjacency(self):
    r"""
    Adjacency matrix of this graph.

    If :attr:`num_relation` is specified, a sparse tensor of shape :math:`(|V|, |V|, num\_relation)` will be
    returned.
    Otherwise, a sparse tensor of shape :math:`(|V|, |V|)` will be returned.
    """
    coordinates = self.edge_list.t()
    return utils.sparse_coo_tensor(coordinates, self.edge_weight, self.shape)


# order of the tensors produced by to_tensors() and consumed by from_tensors()
_tensor_names = ["edge_list", "edge_weight", "num_node", "num_relation", "edge_feature"]


def to_tensors(self):
    """
    Flatten this graph into a tuple of tensors, in the order of ``_tensor_names``.

    A missing ``edge_feature`` is encoded as a 0-dimensional tensor.
    """
    placeholder = torch.tensor(0, device=self.device)
    edge_feature = getattr(self, "edge_feature", placeholder)
    return self.edge_list, self.edge_weight, self.num_node, self.num_relation, edge_feature


@classmethod
def from_tensors(cls, tensors):
    """
    Rebuild a graph from the tuple produced by :meth:`to_tensors`.
    """
    edge_list, edge_weight, num_node, num_relation, edge_feature = tensors
    # a 0-dimensional tensor marks a missing edge feature
    has_feature = edge_feature.ndim != 0
    return cls(edge_list, edge_weight, num_node, num_relation,
               edge_feature=edge_feature if has_feature else None)


@property
def node2graph(self):
    """Node id to graph id mapping."""
    return torch.zeros(self.num_node, dtype=torch.long, device=self.device)


@property
def edge2graph(self):
    """Edge id to graph id mapping."""
    return torch.zeros(self.num_edge, dtype=torch.long, device=self.device)


@utils.cached_property
def degree_out(self):
    """
    Weighted number of edges containing each node as output.

    Note this is the **in-degree** in graph theory.
    """
    targets = self.edge_list[:, 1]
    return scatter_add(self.edge_weight, targets, dim_size=self.num_node)


@utils.cached_property
def degree_in(self):
    """
    Weighted number of edges containing each node as input.

    Note this is the **out-degree** in graph theory.
    """
    sources = self.edge_list[:, 0]
    return scatter_add(self.edge_weight, sources, dim_size=self.num_node)


@property
def edge_list(self):
    """List of edges."""
    return self._edge_list


@property
def edge_weight(self):
    """Edge weights."""
    return self._edge_weight


@property
def device(self):
    """Device."""
    return self.edge_list.device


@property
def requires_grad(self):
    """Whether the edge weights require gradient."""
    return self.edge_weight.requires_grad


@property
def grad(self):
    """Gradient of the edge weights."""
    return self.edge_weight.grad


@property
def data(self):
    """The graph itself."""
    return self


def requires_grad_(self):
    """Enable gradient on the edge weights in place and return ``self``."""
    self.edge_weight.requires_grad_()
    return self


def size(self, dim=None):
    """
    Size of the adjacency matrix, or of a single axis when ``dim`` is given.

    The relation axis is only present when ``num_relation`` is truthy.
    """
    extents = [self.num_node, self.num_node]
    if self.num_relation:
        extents.append(self.num_relation)
    full_size = torch.Size(tuple(extents))
    return full_size if dim is None else full_size[dim]


@property
def shape(self):
    """Same as :meth:`size` without arguments."""
    return self.size()
[docs] def copy_(self, src): """ Copy data from ``src`` into ``self`` and return ``self``. The ``src`` graph must have the same set of attributes as ``self``. """ self.edge_list.copy_(src.edge_list) self.edge_weight.copy_(src.edge_weight) self.num_node.copy_(src.num_node) self.num_edge.copy_(src.num_edge) if self.num_relation is not None: self.num_relation.copy_(src.num_relation) keys = set(self.data_dict.keys()) src_keys = set(src.data_dict.keys()) if keys != src_keys: raise RuntimeError("Attributes mismatch. Trying to assign attributes %s, " "but current graph has attributes %s" % (src_keys, keys)) for k, v in self.data_dict.items(): v.copy_(src.data_dict[k]) return self
def detach(self):
    """
    Detach this graph.

    Returns a new graph of the same type built from the detached edge list,
    edge weights and attributes.
    """
    edges = self.edge_list.detach()
    weights = self.edge_weight.detach()
    return type(self)(edges, edge_weight=weights, num_node=self.num_node,
                      num_relation=self.num_relation, meta_dict=self.meta_dict,
                      **utils.detach(self.data_dict))
def clone(self):
    """
    Clone this graph.

    Returns a new graph of the same type built from cloned copies of the edge list,
    edge weights and attributes.
    """
    edges = self.edge_list.clone()
    weights = self.edge_weight.clone()
    return type(self)(edges, edge_weight=weights, num_node=self.num_node,
                      num_relation=self.num_relation, meta_dict=self.meta_dict,
                      **utils.clone(self.data_dict))
def cuda(self, *args, **kwargs):
    """
    Return a copy of this graph in CUDA memory.

    This is a no-op if the graph is already on the correct device.
    All arguments are forwarded to ``Tensor.cuda``.
    """
    moved = self.edge_list.cuda(*args, **kwargs)
    # when nothing has to move, the very same tensor object comes back
    if moved is self.edge_list:
        return self

    return type(self)(moved, edge_weight=self.edge_weight, num_node=self.num_node,
                      num_relation=self.num_relation, meta_dict=self.meta_dict,
                      **utils.cuda(self.data_dict, *args, **kwargs))
def cpu(self):
    """
    Return a copy of this graph in CPU memory.

    This is a no-op if the graph is already in CPU memory.
    """
    moved = self.edge_list.cpu()
    # when nothing has to move, the very same tensor object comes back
    if moved is self.edge_list:
        return self

    return type(self)(moved, edge_weight=self.edge_weight, num_node=self.num_node,
                      num_relation=self.num_relation, meta_dict=self.meta_dict,
                      **utils.cpu(self.data_dict))
def to(self, device, *args, **kwargs):
    """
    Return a copy of this graph on the given device.

    Parameters:
        device (str or torch.device): target device
        *args, **kwargs: forwarded to :meth:`cuda` for non-CPU devices.
            They are ignored for CPU targets, because :meth:`cpu` takes no arguments.
    """
    device = torch.device(device)
    if device.type == "cpu":
        # cpu() accepts no extra arguments; forwarding them would raise TypeError
        return self.cpu()
    else:
        return self.cuda(device, *args, **kwargs)
def __repr__(self):
    """
    Summarise the graph as ``ClassName(num_node=..., num_edge=...)``.

    ``num_relation`` is listed only when it is set,
    and ``device`` only when the graph is not on CPU.
    """
    summary = ["num_node=%d" % self.num_node]
    summary += ["num_edge=%d" % self.num_edge]
    if self.num_relation is not None:
        summary += ["num_relation=%d" % self.num_relation]
    if self.device.type != "cpu":
        summary += ["device='%s'" % self.device]
    body = ", ".join(summary)
    return "%s(%s)" % (self.__class__.__name__, body)
def visualize(self, title=None, save_file=None, figure_size=(3, 3), ax=None, layout="spring"):
    """
    Visualize this graph with matplotlib.

    Parameters:
        title (str, optional): title for this graph
        save_file (str, optional): ``png`` or ``pdf`` file to save visualization.
            If not provided, show the figure in window.
        figure_size (tuple of int, optional): width and height of the figure
        ax (matplotlib.axes.Axes, optional): axis to plot the figure
        layout (str, optional): graph layout

    See also:
        `NetworkX graph layout`_

    .. _NetworkX graph layout:
        https://networkx.github.io/documentation/stable/reference/drawing.html#module-networkx.drawing.layout
    """
    # only a figure created here is saved / shown at the end
    owns_figure = ax is None
    if owns_figure:
        fig = plt.figure(figsize=figure_size)
        ax = plt.gca() if title is not None else fig.add_axes([0, 0, 1, 1])
    if title is not None:
        ax.set_title(title)

    node_pairs = self.edge_list[:, :2].tolist()
    G = nx.DiGraph(node_pairs)
    # make sure nodes without any edge are drawn as well
    G.add_nodes_from(range(self.num_node))

    layout_name = "%s_layout" % layout
    if not hasattr(nx, layout_name):
        raise ValueError("Unknown networkx layout `%s`" % layout)
    layout_func = getattr(nx, layout_name)
    if layout in ("spring", "random"):
        # these layouts are randomised; a fixed seed keeps the picture reproducible
        pos = layout_func(G, seed=0)
    else:
        pos = layout_func(G)
    nx.draw_networkx(G, pos, ax=ax)

    if self.num_relation:
        relations = self.edge_list[:, 2].tolist()
        edge_labels = {tuple(pair): relation for pair, relation in zip(node_pairs, relations)}
        nx.draw_networkx_edge_labels(G, pos, edge_labels, ax=ax)
    ax.set_frame_on(False)

    if owns_figure:
        if save_file:
            fig.savefig(save_file)
        else:
            fig.show()
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    # always decline: this class provides no override for torch functions
    return NotImplemented


def __getstate__(self):
    """
    Return the instance dictionary used for pickling.

    Entries whose name is a ``property`` on the class are left out.
    """
    owner = self.__class__
    # do not pickle property / cached property
    return {
        name: value
        for name, value in self.__dict__.items()
        if not (hasattr(owner, name) and isinstance(getattr(owner, name), property))
    }
class PackedGraph(Graph):
    """
    Container for sparse graphs with variadic sizes.

    To create a PackedGraph from Graph objects

        >>> batch = data.Graph.pack(graphs)

    To retrieve Graph objects from a PackedGraph

        >>> graphs = batch.unpack()

    .. warning::

        Edges of the same graph are guaranteed to be consecutive in the edge list.
        However, this class doesn't enforce any order on the edges.

    Parameters:
        edge_list (array_like, optional): list of edges of shape :math:`(|E|, 2)` or :math:`(|E|, 3)`.
            Each tuple is (node_in, node_out) or (node_in, node_out, relation).
        edge_weight (array_like, optional): edge weights of shape :math:`(|E|,)`
        num_nodes (array_like, optional): number of nodes in each graph
            By default, it will be inferred from the largest id in `edge_list`
        num_edges (array_like, optional): number of edges in each graph
        num_relation (int, optional): number of relations
        node_feature (array_like, optional): node features of shape :math:`(|V|, ...)`
        edge_feature (array_like, optional): edge features of shape :math:`(|E|, ...)`
        offsets (array_like, optional): node id offsets of shape :math:`(|E|,)`.
            If not provided, nodes in `edge_list` should be relative index, i.e., the index in each graph.
            If provided, nodes in `edge_list` should be absolute index, i.e., the index in the packed graph.
    """

    unpacked_type = Graph

    def __init__(self, edge_list=None, edge_weight=None, num_nodes=None, num_edges=None, num_relation=None,
                 offsets=None, **kwargs):
        edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets = \
            self._get_cumulative(edge_list, num_nodes, num_edges, offsets)

        if offsets is None:
            # relative node ids were given: shift them to absolute ids on a private copy
            offsets = self._get_offsets(num_nodes, num_edges, num_cum_nodes)
            edge_list = edge_list.clone()
            edge_list[:, :2] += offsets.unsqueeze(-1)

        total_node = num_nodes.sum()
        endpoints = edge_list[:, :2]
        if (endpoints >= total_node).any():
            raise ValueError("Sum of `num_nodes` is %d, but found %d in `edge_list`" %
                             (total_node, endpoints.max()))

        # per-graph bookkeeping is stored before the base constructor runs
        self._offsets = offsets
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        self.num_cum_nodes = num_cum_nodes
        self.num_cum_edges = num_cum_edges

        super(PackedGraph, self).__init__(edge_list, edge_weight=edge_weight, num_node=total_node,
                                          num_relation=num_relation, **kwargs)

    def _get_offsets(self, num_nodes=None, num_edges=None, num_cum_nodes=None, num_cum_edges=None):
        """
        Compute the node id offset of every edge.

        The offset of an edge is the number of nodes in all graphs that precede the graph of the edge.
        Sizes that are not given are recovered from the corresponding cumulative sizes.

        Returns:
            Tensor: offsets with one entry per edge
        """
        if num_nodes is None:
            zero = torch.tensor([0], device=self.device)
            num_nodes = torch.diff(num_cum_nodes, prepend=zero)
        if num_edges is None:
            zero = torch.tensor([0], device=self.device)
            num_edges = torch.diff(num_cum_edges, prepend=zero)
        if num_cum_nodes is None:
            num_cum_nodes = num_nodes.cumsum(0)
        # id of the first node of each graph, repeated once per edge of that graph
        first_node = num_cum_nodes - num_nodes
        return first_node.repeat_interleave(num_edges)
def merge(self, graph2graph):
    """
    Merge multiple graphs into a single graph.

    Graphs that share the same ID in `graph2graph` become one graph in the result.
    Attributes are collected with ``data_mask(exclude="graph")``.

    Parameters:
        graph2graph (array_like): ID of the new graph each graph belongs to

    Returns:
        PackedGraph
    """
    group = torch.as_tensor(graph2graph, dtype=torch.long, device=self.device)
    # coalesce arbitrary graph IDs to [0, n)
    group = torch.unique(group, return_inverse=True)[1]

    # sort graphs by (new id, original position) so that members of a merged graph are consecutive
    position = torch.arange(self.batch_size, device=self.device)
    order = (group * self.batch_size + position).argsort()
    graph = self.subbatch(order)
    group = group[order]

    # ids are sorted at this point, so the last id is the largest one
    num_graph = group[-1] + 1
    num_nodes = scatter_add(graph.num_nodes, group, dim_size=num_graph)
    num_edges = scatter_add(graph.num_edges, group, dim_size=num_graph)
    offsets = self._get_offsets(num_nodes, num_edges)

    data_dict, meta_dict = graph.data_mask(exclude="graph")

    return type(self)(graph.edge_list, edge_weight=graph.edge_weight, num_nodes=num_nodes,
                      num_edges=num_edges, num_relation=graph.num_relation, offsets=offsets,
                      meta_dict=meta_dict, **data_dict)
def unpack(self):
    """
    Unpack this packed graph into a list of graphs.

    Returns:
        list of Graph
    """
    return [self.get_item(i) for i in range(self.batch_size)]
def __iter__(self):
    # the cursor is stored on the instance; starting an iteration rewinds it to the first graph
    self._iter_index = 0
    return self


def __next__(self):
    if self._iter_index >= self.batch_size:
        raise StopIteration
    current = self[self._iter_index]
    self._iter_index += 1
    return current


def _check_attribute(self, key, value):
    """
    Validate an attribute against every meta type in ``self._meta_contexts``.

    Parameters:
        key (str): name of the attribute, only used in error messages
        value (Tensor): value of the attribute

    Raises:
        TypeError: if a reference attribute is not a long tensor
        ValueError: if the first dimension doesn't match the number of nodes / edges / graphs,
            or a reference lies outside ``[-1, size)``
    """
    for meta_type in self._meta_contexts:
        if "reference" in meta_type and value.dtype != torch.long:
            raise TypeError("Tensors used as reference must be long tensors")
        if meta_type == "node":
            if len(value) != self.num_node:
                raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" %
                                 (key, self.num_node, value.shape))
        elif meta_type == "edge":
            if len(value) != self.num_edge:
                raise ValueError("Expect edge attribute `%s` to have shape (%d, *), but found %s" %
                                 (key, self.num_edge, value.shape))
        elif meta_type == "graph":
            if len(value) != self.batch_size:
                raise ValueError("Expect graph attribute `%s` to have shape (%d, *), but found %s" %
                                 (key, self.batch_size, value.shape))
        elif meta_type == "node reference":
            # -1 is accepted in addition to valid ids; the first offending value is reported
            in_range = (value >= -1) & (value < self.num_node)
            if not in_range.all():
                raise ValueError("Expect node reference in [-1, %d), but found %d" %
                                 (self.num_node, value[~in_range][0]))
        elif meta_type == "edge reference":
            in_range = (value >= -1) & (value < self.num_edge)
            if not in_range.all():
                raise ValueError("Expect edge reference in [-1, %d), but found %d" %
                                 (self.num_edge, value[~in_range][0]))
        elif meta_type == "graph reference":
            in_range = (value >= -1) & (value < self.batch_size)
            if not in_range.all():
                raise ValueError("Expect graph reference in [-1, %d), but found %d" %
                                 (self.batch_size, value[~in_range][0]))
def unpack_data(self, data, type="auto"):
    """
    Unpack node or edge data according to the packed graph.

    Parameters:
        data (Tensor): data to unpack
        type (str, optional): data type. Can be ``auto``, ``node``, or ``edge``.

    Returns:
        list of Tensor

    Raises:
        ValueError: if `type` is ``auto`` and the type can't be inferred from the length of `data`,
            or if `type` is not one of ``auto``, ``node`` and ``edge``
    """
    if type == "auto":
        # the length of `data` can only identify the type when node and edge counts differ
        if self.num_node == self.num_edge:
            raise ValueError("Ambiguous type. Please specify either `node` or `edge`")
        if len(data) == self.num_node:
            type = "node"
        elif len(data) == self.num_edge:
            type = "edge"
        else:
            raise ValueError("Graph has %d nodes and %d edges, but data has %d entries" %
                             (self.num_node, self.num_edge, len(data)))

    if type == "node":
        num_xs, num_cum_xs = self.num_nodes, self.num_cum_nodes
    elif type == "edge":
        num_xs, num_cum_xs = self.num_edges, self.num_cum_edges
    else:
        # an unrecognized type used to fall through and silently produce an empty list
        raise ValueError("Unknown data type `%s`. Expect `auto`, `node` or `edge`" % type)

    # convert the slice bounds to Python ints once instead of slicing with 0-dim tensors per graph
    starts = (num_cum_xs - num_xs).tolist()
    ends = num_cum_xs.tolist()
    return [data[start:end] for start, end in zip(starts, ends)]
def repeat(self, count):
    """
    Repeat this packed graph. This function behaves similarly to `torch.Tensor.repeat`_.

    .. _torch.Tensor.repeat:
        https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html

    Parameters:
        count (int): number of repetitions

    Returns:
        PackedGraph
    """
    num_nodes = self.num_nodes.repeat(count)
    num_edges = self.num_edges.repeat(count)
    offsets = self._get_offsets(num_nodes, num_edges)
    edge_list = self.edge_list.repeat(count, 1)
    # move the node ids of every copy from the original offsets to the offsets of that copy
    edge_list[:, :2] += (offsets - self._offsets.repeat(count)).unsqueeze(-1)

    # attribute holding the size of the index space that each reference type points into
    reference_size = {
        "node reference": "num_node",
        "edge reference": "num_edge",
        "graph reference": "batch_size",
    }
    copy_id = torch.arange(count, device=self.device)

    data_dict = {}
    for name, value in self.data_dict.items():
        length = len(value)
        # tile along the first dimension only
        tiles = [count] + [1] * (value.ndim - 1)
        value = value.repeat(tiles)
        for _type in self.meta_dict[name]:
            if _type not in reference_size:
                continue
            # references in the i-th copy are shifted by i times the size of the original index space
            pack_offsets = copy_id * getattr(self, reference_size[_type])
            value = value + pack_offsets.repeat_interleave(length)
        data_dict[name] = value

    return type(self)(edge_list, edge_weight=self.edge_weight.repeat(count),
                      num_nodes=num_nodes, num_edges=num_edges, num_relation=self.num_relation,
                      offsets=offsets, meta_dict=self.meta_dict, **data_dict)
def repeat_interleave(self, repeats):
    """
    Repeat this packed graph. This function behaves similarly to `torch.repeat_interleave`_.

    .. _torch.repeat_interleave:
        https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html

    Parameters:
        repeats (Tensor or int): number of repetitions for each graph

    Returns:
        PackedGraph
    """
    repeats = torch.as_tensor(repeats, dtype=torch.long, device=self.device)
    if repeats.numel() == 1:
        # broadcast a single repeat count to every graph
        repeats = repeats * torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    num_nodes = self.num_nodes.repeat_interleave(repeats)
    num_edges = self.num_edges.repeat_interleave(repeats)
    num_node = num_nodes.sum()
    num_edge = num_edges.sum()
    batch_size = repeats.sum()
    # NOTE(review): no dtype is given, so this uses the default (floating point) dtype, while it is
    # later passed as `repeats` to `Tensor.repeat_interleave`, which expects integers. Confirm
    # whether graph-level reference attributes are supported here.
    num_graphs = torch.ones(batch_size, device=self.device)

    # special case 1: graphs[i] may have no node or no edge
    # special case 2: repeats[i] may be 0
    cum_repeats_shifted = repeats.cumsum(0) - repeats
    graph_mask = cum_repeats_shifted < batch_size
    cum_repeats_shifted = cum_repeats_shifted[graph_mask]

    def source_index(new_sizes, old_sizes, total):
        # Build the index into the original nodes / edges for every node / edge of the result.
        # Per-position increments are scattered at the first position of each output graph,
        # then accumulated with a cumulative sum.
        start = new_sizes.cumsum(0) - new_sizes
        start = torch.cat([start, start[cum_repeats_shifted]])
        delta = torch.cat([-new_sizes, old_sizes[graph_mask]])
        # positions equal to `total` belong to trailing empty graphs and must be dropped
        valid = start < total
        jumps = scatter_add(delta[valid], start[valid], dim_size=total)
        return (jumps + 1).cumsum(0) - 1

    node_index = source_index(num_nodes, self.num_nodes, num_node)
    edge_index = source_index(num_edges, self.num_edges, num_edge)
    graph_index = torch.repeat_interleave(repeats)

    offsets = self._get_offsets(num_nodes, num_edges)
    edge_list = self.edge_list[edge_index]
    edge_list[:, :2] += (offsets - self._offsets[edge_index]).unsqueeze(-1)

    # reference offsets are computed lazily and shared by all attributes of the same reference type
    node_offsets = None
    edge_offsets = None
    graph_offsets = None
    data_dict = {}
    for name, value in self.data_dict.items():
        num_xs = None
        pack_offsets = None
        for _type in self.meta_dict[name]:
            if _type == "node":
                value = value[node_index]
                num_xs = num_nodes
            elif _type == "edge":
                value = value[edge_index]
                num_xs = num_edges
            elif _type == "graph":
                value = value[graph_index]
                num_xs = num_graphs
            elif _type == "node reference":
                if node_offsets is None:
                    node_offsets = self._get_repeat_pack_offsets(self.num_nodes, repeats)
                pack_offsets = node_offsets
            elif _type == "edge reference":
                if edge_offsets is None:
                    edge_offsets = self._get_repeat_pack_offsets(self.num_edges, repeats)
                pack_offsets = edge_offsets
            elif _type == "graph reference":
                if graph_offsets is None:
                    graph_offsets = self._get_repeat_pack_offsets(num_graphs, repeats)
                pack_offsets = graph_offsets
        # add offsets to make references point to indexes in their own graph
        if num_xs is not None and pack_offsets is not None:
            value = value + pack_offsets.repeat_interleave(num_xs)
        data_dict[name] = value

    return type(self)(edge_list, edge_weight=self.edge_weight[edge_index],
                      num_nodes=num_nodes, num_edges=num_edges, num_relation=self.num_relation,
                      offsets=offsets, meta_dict=self.meta_dict, **data_dict)
def get_item(self, index):
    """
    Get the i-th graph from this packed graph.

    Parameters:
        index (int): graph index

    Returns:
        Graph
    """
    # nodes / edges of a graph occupy the range [cumulative size - size, cumulative size)
    node_end = self.num_cum_nodes[index]
    edge_end = self.num_cum_edges[index]
    node_index = torch.arange(node_end - self.num_nodes[index], node_end, device=self.device)
    edge_index = torch.arange(edge_end - self.num_edges[index], edge_end, device=self.device)

    edge_list = self.edge_list[edge_index].clone()
    # convert absolute node ids back to ids relative to the selected graph
    edge_list[:, :2] -= self._offsets[edge_index].unsqueeze(-1)
    data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=index)

    return self.unpacked_type(edge_list, edge_weight=self.edge_weight[edge_index],
                              num_node=self.num_nodes[index], num_relation=self.num_relation,
                              meta_dict=meta_dict, **data_dict)
def _get_cumulative(self, edge_list, num_nodes, num_edges, offsets):
    """
    Convert the constructor inputs to tensors and compute cumulative graph sizes.

    When `num_nodes` is not given, it is inferred per graph with ``self._maybe_num_node``
    from the edges of that graph, expressed in relative node ids.

    Returns:
        tuple: ``(edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets)``

    Raises:
        ValueError: if `edge_list` or `num_edges` is missing,
            or the sum of `num_edges` differs from the number of edges
    """
    if edge_list is None:
        raise ValueError("`edge_list` should be provided")
    if num_edges is None:
        raise ValueError("`num_edges` should be provided")

    edge_list = torch.as_tensor(edge_list)
    device = edge_list.device
    num_edges = torch.as_tensor(num_edges, device=device)
    total_edge = num_edges.sum()
    if total_edge != len(edge_list):
        raise ValueError("Sum of `num_edges` is %d, but found %d edges in `edge_list`" %
                         (total_edge, len(edge_list)))
    num_cum_edges = num_edges.cumsum(0)

    if offsets is None:
        relative_edge_list = edge_list
    else:
        # node ids are absolute: subtract the offsets on a copy to recover relative ids
        offsets = torch.as_tensor(offsets, device=device)
        relative_edge_list = edge_list.clone()
        relative_edge_list[:, :2] -= offsets.unsqueeze(-1)

    if num_nodes is None:
        num_nodes = [
            self._maybe_num_node(relative_edge_list[end - size: end])
            for size, end in zip(num_edges, num_cum_edges)
        ]
    num_nodes = torch.as_tensor(num_nodes, device=device)
    num_cum_nodes = num_nodes.cumsum(0)

    return edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets


def _get_num_xs(self, index, num_cum_xs):
    """
    Count how many of the selected items fall into each graph.

    Parameters:
        index (Tensor): index or boolean mask of the selected items
        num_cum_xs (Tensor): cumulative number of items in each graph

    Returns:
        Tensor: number of selected items in each graph
    """
    # mark the selected items, then read the running count at every graph boundary
    selected = torch.zeros(num_cum_xs[-1], dtype=torch.long, device=self.device)
    selected[index] = 1
    zero = torch.zeros(1, dtype=torch.long, device=self.device)
    running = torch.cat([zero, selected.cumsum(0)])
    new_num_cum_xs = running[num_cum_xs]
    return torch.diff(new_num_cum_xs, prepend=torch.zeros(1, dtype=torch.long, device=self.device))


def data_mask(self, node_index=None, edge_index=None, graph_index=None, include=None, exclude=None):
    """
    Select attributes with ``data_by_meta(include, exclude)`` and apply the given indexes to them.

    Node / edge / graph attributes are indexed by the corresponding index.
    Reference attributes are translated through the mapping returned by ``_get_mapping``.
    An index left as ``None`` leaves the attributes of that type untouched.

    Returns:
        tuple: ``(data_dict, meta_dict)``
    """
    data_dict, meta_dict = self.data_by_meta(include, exclude)
    # mappings are built on first use and shared by all attributes
    node_mapping = None
    edge_mapping = None
    graph_mapping = None
    for name, value in data_dict.items():
        for meta_type in meta_dict[name]:
            if meta_type == "node" and node_index is not None:
                value = value[node_index]
            elif meta_type == "edge" and edge_index is not None:
                value = value[edge_index]
            elif meta_type == "graph" and graph_index is not None:
                value = value[graph_index]
            elif meta_type == "node reference" and node_index is not None:
                if node_mapping is None:
                    node_mapping = self._get_mapping(node_index, self.num_node)
                value = node_mapping[value]
            elif meta_type == "edge reference" and edge_index is not None:
                if edge_mapping is None:
                    edge_mapping = self._get_mapping(edge_index, self.num_edge)
                value = edge_mapping[value]
            elif meta_type == "graph reference" and graph_index is not None:
                if graph_mapping is None:
                    graph_mapping = self._get_mapping(graph_index, self.batch_size)
                value = graph_mapping[value]
        data_dict[name] = value

    return data_dict, meta_dict


def __getitem__(self, index):
    # why do we check tuple?
    # case 1: x[0, 1] is parsed as (0, 1)
    # case 2: x[[0, 1]] is parsed as [0, 1]
    if not isinstance(index, tuple):
        index = (index,)

    first = index[0]
    if isinstance(first, int):
        # an integer selects a single graph; any remaining indexes are applied to that graph
        item = self.get_item(first)
        if len(index) > 1:
            item = item[index[1:]]
        return item
    if len(index) > 1:
        raise ValueError("Complex indexing is not supported for PackedGraph")

    index = self._standarize_index(first, self.batch_size)
    count = index.bincount(minlength=self.batch_size)
    if self.batch_size > 0 and count.max() > 1:
        # some graph is requested more than once: duplicate graphs first,
        # then reorder the duplicated batch to the requested order
        graph = self.repeat_interleave(count)
        index_order = index.argsort()
        order = torch.zeros_like(index)
        order[index_order] = torch.arange(len(index), dtype=torch.long, device=self.device)
        return graph.subbatch(order)

    return self.subbatch(index)


def __len__(self):
    # number of graphs in the batch
    return len(self.num_nodes)
def full(self):
    """
    Return a pack of fully connected graphs.

    This is useful for computing node-pair-wise features.
    The computation can be implemented as message passing over a fully connected graph.

    Returns:
        PackedGraph
    """
    # TODO: more efficient implementation?
    full_graphs = [graph.full() for graph in self.unpack()]
    # packing is delegated to the first graph, which assumes a non-empty batch
    return full_graphs[0].pack(full_graphs)
@utils.cached_property
def node2graph(self):
    """Node id to graph id mapping."""
    # graph id i is repeated num_nodes[i] times
    return torch.repeat_interleave(self.num_nodes)


@utils.cached_property
def edge2graph(self):
    """Edge id to graph id mapping."""
    # graph id i is repeated num_edges[i] times
    return torch.repeat_interleave(self.num_edges)


@property
def batch_size(self):
    """Batch size."""
    graph_sizes = self.num_nodes
    return len(graph_sizes)
def node_mask(self, index, compact=False):
    """
    Return a masked packed graph based on the specified nodes.

    Note the compact option is only applied to node ids but not graph ids.
    To generate compact graph ids, use :meth:`subbatch`.

    Parameters:
        index (array_like): node index
        compact (bool, optional): compact node ids or not

    Returns:
        PackedGraph
    """
    index = self._standarize_index(index, self.num_node)
    # old node id -> new node id, with -1 for nodes that are not selected
    mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
    if compact:
        mapping[index] = torch.arange(len(index), device=self.device)
        num_nodes = self._get_num_xs(index, self.num_cum_nodes)
        offsets = self._get_offsets(num_nodes, self.num_edges)
    else:
        mapping[index] = index
        num_nodes = self.num_nodes
        offsets = self._offsets

    edge_list = self.edge_list.clone()
    edge_list[:, :2] = mapping[edge_list[:, :2]]
    # keep an edge only when both of its endpoints are selected
    kept_edge = (edge_list[:, :2] >= 0).all(dim=-1)
    num_edges = self._get_num_xs(kept_edge, self.num_cum_edges)

    if compact:
        data_dict, meta_dict = self.data_mask(index, kept_edge)
    else:
        # node ids are unchanged, so only edge attributes need masking
        data_dict, meta_dict = self.data_mask(edge_index=kept_edge)

    return type(self)(edge_list[kept_edge], edge_weight=self.edge_weight[kept_edge], num_nodes=num_nodes,
                      num_edges=num_edges, num_relation=self.num_relation, offsets=offsets[kept_edge],
                      meta_dict=meta_dict, **data_dict)
def edge_mask(self, index):
    """
    Return a masked packed graph based on the specified edges.

    Nodes are left untouched; only the edges and their attributes are selected.

    Parameters:
        index (array_like): edge index

    Returns:
        PackedGraph
    """
    kept_edge = self._standarize_index(index, self.num_edge)
    data_dict, meta_dict = self.data_mask(edge_index=kept_edge)
    # recount the edges that remain in each graph
    num_edges = self._get_num_xs(kept_edge, self.num_cum_edges)

    return type(self)(self.edge_list[kept_edge], edge_weight=self.edge_weight[kept_edge],
                      num_nodes=self.num_nodes, num_edges=num_edges, num_relation=self.num_relation,
                      offsets=self._offsets[kept_edge], meta_dict=meta_dict, **data_dict)
def graph_mask(self, index, compact=False):
    """
    Return a masked packed graph based on the specified graphs.

    This function can also be used to re-order the graphs.

    Parameters:
        index (array_like): graph index
        compact (bool, optional): compact graph ids or not

    Returns:
        PackedGraph
    """
    index = self._standarize_index(index, self.batch_size)
    # old graph id -> position in `index`, with -1 for graphs that are not selected
    graph_mapping = -torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    graph_mapping[index] = torch.arange(len(index), device=self.device)

    def in_new_graph_order(item_index, item2graph, num_item):
        # sort items by (new graph id, old item id) so they follow the order of the selected graphs
        key = graph_mapping[item2graph[item_index]] * num_item + item_index
        return item_index[key.argsort()]

    node_index = graph_mapping[self.node2graph] >= 0
    node_index = self._standarize_index(node_index, self.num_node)
    # old node id -> new node id, with -1 for removed nodes
    mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
    if compact:
        node_index = in_new_graph_order(node_index, self.node2graph, self.num_node)
        mapping[node_index] = torch.arange(len(node_index), device=self.device)
        num_nodes = self.num_nodes[index]
    else:
        mapping[node_index] = node_index
        # graphs that are not selected stay in the batch as empty graphs
        num_nodes = torch.zeros_like(self.num_nodes)
        num_nodes[index] = self.num_nodes[index]

    edge_list = self.edge_list.clone()
    edge_list[:, :2] = mapping[edge_list[:, :2]]
    edge_index = (edge_list[:, :2] >= 0).all(dim=-1)
    edge_index = self._standarize_index(edge_index, self.num_edge)
    if compact:
        edge_index = in_new_graph_order(edge_index, self.edge2graph, self.num_edge)
        num_edges = self.num_edges[index]
    else:
        num_edges = torch.zeros_like(self.num_edges)
        num_edges[index] = self.num_edges[index]
    offsets = self._get_offsets(num_nodes, num_edges)

    if compact:
        data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=index)
    else:
        data_dict, meta_dict = self.data_mask(edge_index=edge_index)

    return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes,
                      num_edges=num_edges, num_relation=self.num_relation, offsets=offsets,
                      meta_dict=meta_dict, **data_dict)
def subbatch(self, index):
    """
    Return a subbatch based on the specified graphs.
    Equivalent to :meth:`graph_mask(index, compact=True) <graph_mask>`.

    Parameters:
        index (array_like): graph index

    Returns:
        PackedGraph

    See also:
        :meth:`PackedGraph.graph_mask`
    """
    selected = self.graph_mask(index, compact=True)
    return selected
def line_graph(self):
    """
    Construct a packed line graph of this packed graph.
    The node features of the line graphs are inherited from the edge features of the original graphs.

    In the line graph, each node corresponds to an edge in the original graph.
    For a pair of edges (a, b) and (b, c) that share the same intermediate node in the original graph,
    there is a directed edge (a, b) -> (b, c) in the line graph.

    Returns:
        PackedGraph
    """
    node_in, node_out = self.edge_list.t()[:2]
    edge_index = torch.arange(self.num_edge, device=self.device)
    # edge ids grouped by the node in their second column / first column respectively
    edge_in = edge_index[node_out.argsort()]
    edge_out = edge_index[node_in.argsort()]

    degree_in = node_in.bincount(minlength=self.num_node)
    degree_out = node_out.bincount(minlength=self.num_node)
    size = degree_out * degree_in
    starts = (size.cumsum(0) - size).repeat_interleave(size)
    flat_index = torch.arange(size.sum(), device=self.device)
    # each node u has degree_out[u] * degree_in[u] local edges
    local_index = flat_index - starts
    local_inner_size = degree_in.repeat_interleave(size)
    edge_in_offset = (degree_out.cumsum(0) - degree_out).repeat_interleave(size)
    edge_out_offset = (degree_in.cumsum(0) - degree_in).repeat_interleave(size)
    # decompose the local index into a (row, column) pair over the two edge groups of the node
    edge_in_index = torch.div(local_index, local_inner_size, rounding_mode="floor") + edge_in_offset
    edge_out_index = local_index % local_inner_size + edge_out_offset

    line_edge_list = torch.stack([edge_in[edge_in_index], edge_out[edge_out_index]], dim=-1)
    # edge features, when present, become the node features of the line graph
    node_feature = getattr(self, "edge_feature", None)
    num_nodes = self.num_edges
    num_edges = scatter_add(size, self.node2graph, dim=0, dim_size=self.batch_size)
    offsets = self._get_offsets(num_nodes, num_edges)

    return PackedGraph(line_edge_list, num_nodes=num_nodes, num_edges=num_edges, offsets=offsets,
                       node_feature=node_feature)
def undirected(self, add_inverse=False):
    """
    Flip all the edges to create undirected graphs.

    For knowledge graphs, the flipped edges can either have the original relation or an inverse relation.
    The inverse relation for relation :math:`r` is defined as :math:`|R| + r`.

    Parameters:
        add_inverse (bool, optional): whether to use inverse relations for flipped edges

    Returns:
        PackedGraph: graphs in which every edge is immediately followed by its flipped copy
    """
    flipped = self.edge_list.clone()
    flipped[:, :2] = self.edge_list[:, :2].flip(1)
    num_relation = self.num_relation
    if num_relation and add_inverse:
        flipped[:, 2] += num_relation
        num_relation = num_relation * 2
    # interleave original and flipped edges: e0, e0', e1, e1', ...
    edge_list = torch.stack([self.edge_list, flipped], dim=1).flatten(0, 1)
    # every edge attribute, weight and offset is duplicated for the flipped copy
    offsets = self._offsets.repeat_interleave(2)
    index = torch.arange(self.num_edge, device=self.device).repeat_interleave(2)
    data_dict, meta_dict = self.data_mask(edge_index=index, exclude="edge reference")

    return type(self)(edge_list, edge_weight=self.edge_weight[index],
                      num_nodes=self.num_nodes, num_edges=self.num_edges * 2,
                      num_relation=num_relation, offsets=offsets, meta_dict=meta_dict, **data_dict)
def detach(self):
    """
    Detach this packed graph.

    Returns:
        PackedGraph: a new packed graph built from the detached edge list, edge weights and attributes
    """
    edge_list = self.edge_list.detach()
    edge_weight = self.edge_weight.detach()
    data_dict = utils.detach(self.data_dict)
    # sizes and offsets are passed to the new graph as they are
    return type(self)(edge_list, edge_weight=edge_weight,
                      num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
                      offsets=self._offsets, meta_dict=self.meta_dict, **data_dict)
def clone(self):
    """
    Clone this packed graph.

    Returns:
        PackedGraph: a new packed graph built from cloned edge list, edge weights and attributes
    """
    edge_list = self.edge_list.clone()
    edge_weight = self.edge_weight.clone()
    data_dict = utils.clone(self.data_dict)
    # sizes and offsets are passed to the new graph as they are
    return type(self)(edge_list, edge_weight=edge_weight,
                      num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
                      offsets=self._offsets, meta_dict=self.meta_dict, **data_dict)
def cuda(self, *args, **kwargs):
    """
    Return a copy of this packed graph in CUDA memory.
    This is a non-op if the graph is already on the correct device.

    Parameters:
        *args: positional arguments forwarded to ``Tensor.cuda`` and ``utils.cuda``
        **kwargs: keyword arguments forwarded to ``Tensor.cuda`` and ``utils.cuda``
    """
    edge_list = self.edge_list.cuda(*args, **kwargs)
    # the same tensor object coming back means nothing had to be moved
    if edge_list is self.edge_list:
        return self

    data_dict = utils.cuda(self.data_dict, *args, **kwargs)
    # NOTE(review): edge weights, sizes and offsets are passed without an explicit transfer;
    # presumably the constructor places them on the device of `edge_list` -- confirm.
    return type(self)(edge_list, edge_weight=self.edge_weight,
                      num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
                      offsets=self._offsets, meta_dict=self.meta_dict, **data_dict)
def cpu(self):
    """
    Return a copy of this packed graph in CPU memory.
    This is a non-op if the graph is already in CPU memory.
    """
    edge_list = self.edge_list.cpu()
    # the same tensor object coming back means nothing had to be moved
    if edge_list is self.edge_list:
        return self

    data_dict = utils.cpu(self.data_dict)
    # NOTE(review): edge weights, sizes and offsets are passed without an explicit transfer;
    # presumably the constructor places them on the device of `edge_list` -- confirm.
    return type(self)(edge_list, edge_weight=self.edge_weight,
                      num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
                      offsets=self._offsets, meta_dict=self.meta_dict, **data_dict)
def __repr__(self):
    """
    Return a short summary: class name, batch size, per-graph node / edge counts,
    and optionally the number of relations and the device.
    """
    node_counts = pretty.long_array(self.num_nodes.tolist())
    edge_counts = pretty.long_array(self.num_edges.tolist())
    summary = [
        "batch_size=%d" % self.batch_size,
        "num_nodes=%s" % node_counts,
        "num_edges=%s" % edge_counts,
    ]
    # relation count is shown only when it is defined
    if self.num_relation is not None:
        summary.append("num_relation=%d" % self.num_relation)
    # the device is shown only when the graph is not on CPU
    if self.device.type != "cpu":
        summary.append("device='%s'" % self.device)
    return "%s(%s)" % (self.__class__.__name__, ", ".join(summary))
def visualize(self, titles=None, save_file=None, figure_size=(3, 3), layout="spring", num_row=None, num_col=None):
    """
    Visualize the packed graphs with matplotlib.

    Parameters:
        titles (list of str, optional): title for each graph. Default is the ID of each graph.
        save_file (str, optional): ``png`` or ``pdf`` file to save visualization.
            If not provided, show the figure in window.
        figure_size (tuple of int, optional): width and height of the figure
        layout (str, optional): graph layout
        num_row (int, optional): number of rows in the figure
        num_col (int, optional): number of columns in the figure

    See also:
        `NetworkX graph layout`_

    .. _NetworkX graph layout:
        https://networkx.github.io/documentation/stable/reference/drawing.html#module-networkx.drawing.layout
    """
    batch_size = self.batch_size
    if titles is None:
        # default titles are "<type name of the unpacked graph> <graph id>"
        type_name = type(self.get_item(0)).__name__
        titles = ["%s %d" % (type_name, i) for i in range(batch_size)]

    # fill in the missing grid dimension(s); with neither given, the grid is roughly square
    if num_col is None:
        if num_row is None:
            num_col = math.ceil(batch_size ** 0.5)
        else:
            num_col = math.ceil(batch_size / num_row)
    if num_row is None:
        num_row = math.ceil(batch_size / num_col)

    # `figure_size` is the size of one cell, the whole figure scales with the grid
    cell_width, cell_height = figure_size[0], figure_size[1]
    fig = plt.figure(figsize=(num_col * cell_width, num_row * cell_height))

    for i, graph_title in zip(range(batch_size), titles):
        graph = self.get_item(i)
        ax = fig.add_subplot(num_row, num_col, i + 1)
        graph.visualize(title=graph_title, ax=ax, layout=layout)
    # remove the space of axis labels
    fig.tight_layout()

    if save_file:
        fig.savefig(save_file)
    else:
        fig.show()
Graph.packed_type = PackedGraph


def cat(graphs):
    """
    Concatenate graphs and packed graphs into a single packed graph.

    Inputs that are not packed graphs are first packed into a batch of size one.
    Only attributes that exist in all inputs are kept, and their meta types are taken from the first input.
    The input sequence itself is never modified.

    Parameters:
        graphs (iterable of Graph or PackedGraph): graphs to concatenate.
            All of them must have the same ``num_relation``.

    Returns:
        PackedGraph: an instance of the type of the first (packed) input
    """
    # Pack into a private list. Writing the packed graphs back into `graphs` would silently
    # change the caller's list, and would fail for tuples and other read-only sequences.
    graphs = [graph if isinstance(graph, PackedGraph) else graph.pack([graph]) for graph in graphs]

    edge_list = torch.cat([graph.edge_list for graph in graphs])
    pack_num_nodes = torch.stack([graph.num_node for graph in graphs])
    pack_num_edges = torch.stack([graph.num_edge for graph in graphs])
    pack_num_cum_edges = pack_num_edges.cumsum(0)
    # the node count of each input is added at the position of the first edge that follows it;
    # positions equal to the number of edges have no following edge and are dropped
    graph_index = pack_num_cum_edges < len(edge_list)
    pack_offsets = scatter_add(pack_num_nodes[graph_index], pack_num_cum_edges[graph_index],
                               dim_size=len(edge_list))
    pack_offsets = pack_offsets.cumsum(0)

    edge_list[:, :2] += pack_offsets.unsqueeze(-1)
    offsets = torch.cat([graph._offsets for graph in graphs]) + pack_offsets

    edge_weight = torch.cat([graph.edge_weight for graph in graphs])
    num_nodes = torch.cat([graph.num_nodes for graph in graphs])
    num_edges = torch.cat([graph.num_edges for graph in graphs])
    num_relation = graphs[0].num_relation
    assert all(graph.num_relation == num_relation for graph in graphs)

    # only keep attributes that exist in all graphs
    # TODO: this interface is not safe. re-design the interface
    keys = set(graphs[0].meta_dict.keys())
    for graph in graphs:
        keys = keys.intersection(graph.meta_dict.keys())

    meta_dict = {k: graphs[0].meta_dict[k] for k in keys}
    data_dict = {k: torch.cat([graph.data_dict[k] for graph in graphs]) for k in keys}

    return type(graphs[0])(edge_list, edge_weight=edge_weight,
                           num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation, offsets=offsets,
                           meta_dict=meta_dict, **data_dict)