import math
import warnings
from functools import reduce
from collections import defaultdict
import networkx as nx
from matplotlib import pyplot as plt
import torch
from torch_scatter import scatter_add, scatter_min
from torchdrug import core, utils
from torchdrug.data import Dictionary
from torchdrug.utils import pretty
plt.switch_backend("agg")
class Graph(core._MetaContainer):
r"""
Basic container for sparse graphs.
To batch graphs with variadic sizes, use :meth:`data.Graph.pack <torchdrug.data.Graph.pack>`.
This will return a PackedGraph object with the following block diagonal adjacency matrix.
.. math::
\begin{bmatrix}
A_1 & \cdots & 0 \\
\vdots & \ddots & \vdots \\
0 & \cdots & A_n
\end{bmatrix}
where :math:`A_i` is the adjacency matrix of the :math:`i`-th graph.
You may register dynamic attributes for each graph.
The registered attributes will be automatically processed during packing.
Example::
>>> graph = data.Graph(torch.randint(10, (30, 2)))
>>> with graph.node():
>>> graph.my_node_attr = torch.rand(10, 5, 5)
Parameters:
edge_list (array_like, optional): list of edges of shape :math:`(|E|, 2)` or :math:`(|E|, 3)`.
Each tuple is (node_in, node_out) or (node_in, node_out, relation).
edge_weight (array_like, optional): edge weights of shape :math:`(|E|,)`
num_node (int, optional): number of nodes.
By default, it will be inferred from the largest id in `edge_list`
num_relation (int, optional): number of relations
node_feature (array_like, optional): node features of shape :math:`(|V|, ...)`
edge_feature (array_like, optional): edge features of shape :math:`(|E|, ...)`
graph_feature (array_like, optional): graph feature of any shape
"""
_meta_types = {"node", "edge", "graph", "node reference", "edge reference", "graph reference"}
def __init__(self, edge_list=None, edge_weight=None, num_node=None, num_relation=None,
node_feature=None, edge_feature=None, graph_feature=None, **kwargs):
super(Graph, self).__init__(**kwargs)
# edge_list: N * [h, t] or N * [h, t, r]
edge_list, num_edge = self._standarize_edge_list(edge_list, num_relation)
edge_weight = self._standarize_edge_weight(edge_weight, edge_list)
num_node = self._standarize_num_node(num_node, edge_list)
num_relation = self._standarize_num_relation(num_relation, edge_list)
self._edge_list = edge_list
self._edge_weight = edge_weight
self.num_node = num_node
self.num_edge = num_edge
self.num_relation = num_relation
if node_feature is not None:
with self.node():
self.node_feature = torch.as_tensor(node_feature, device=self.device)
if edge_feature is not None:
with self.edge():
self.edge_feature = torch.as_tensor(edge_feature, device=self.device)
if graph_feature is not None:
with self.graph():
self.graph_feature = torch.as_tensor(graph_feature, device=self.device)
def node(self):
"""
Context manager for node attributes.
"""
return self.context("node")
def edge(self):
"""
Context manager for edge attributes.
"""
return self.context("edge")
def graph(self):
"""
Context manager for graph attributes.
"""
return self.context("graph")
def node_reference(self):
"""
Context manager for node references.
"""
return self.context("node reference")
def edge_reference(self):
"""
Context manager for edge references.
"""
return self.context("edge reference")
def graph_reference(self):
"""
Context manager for graph references.
"""
return self.context("graph reference")
def _check_attribute(self, key, value):
for type in self._meta_contexts:
if "reference" in type:
if value.dtype != torch.long:
raise TypeError("Tensors used as reference must be long tensors")
if type == "node":
if len(value) != self.num_node:
raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" %
(key, self.num_node, value.shape))
elif type == "edge":
if len(value) != self.num_edge:
raise ValueError("Expect edge attribute `%s` to have shape (%d, *), but found %s" %
(key, self.num_edge, value.shape))
elif type == "node reference":
is_valid = (value >= -1) & (value < self.num_node)
if not is_valid.all():
error_value = value[~is_valid]
raise ValueError("Expect node reference in [-1, %d), but found %d" %
(self.num_node, error_value[0]))
elif type == "edge reference":
is_valid = (value >= -1) & (value < self.num_edge)
if not is_valid.all():
error_value = value[~is_valid]
raise ValueError("Expect edge reference in [-1, %d), but found %d" %
(self.num_edge, error_value[0]))
elif type == "graph reference":
is_valid = (value >= -1) & (value < self.batch_size)
if not is_valid.all():
error_value = value[~is_valid]
raise ValueError("Expect graph reference in [-1, %d), but found %d" %
(self.batch_size, error_value[0]))
def __setattr__(self, key, value):
if hasattr(self, "meta_dict"):
self._check_attribute(key, value)
super(Graph, self).__setattr__(key, value)
def _standarize_edge_list(self, edge_list, num_relation):
if edge_list is not None and len(edge_list):
if isinstance(edge_list, torch.Tensor) and edge_list.dtype != torch.long:
try:
edge_list = torch.LongTensor(edge_list)
except TypeError:
raise TypeError("Can't convert `edge_list` to torch.long")
else:
edge_list = torch.as_tensor(edge_list, dtype=torch.long)
else:
num_element = 2 if num_relation is None else 3
if isinstance(edge_list, torch.Tensor):
device = edge_list.device
else:
device = "cpu"
edge_list = torch.zeros(0, num_element, dtype=torch.long, device=device)
if (edge_list < 0).any():
raise ValueError("`edge_list` should only contain non-negative indexes")
num_edge = torch.tensor(len(edge_list), device=edge_list.device)
return edge_list, num_edge
def _standarize_edge_weight(self, edge_weight, edge_list):
if edge_weight is not None:
edge_weight = torch.as_tensor(edge_weight, dtype=torch.float, device=edge_list.device)
if len(edge_list) != len(edge_weight):
raise ValueError("`edge_list` and `edge_weight` should be the same size, but found %d and %d"
% (len(edge_list), len(edge_weight)))
else:
edge_weight = torch.ones(len(edge_list), device=edge_list.device)
return edge_weight
def _standarize_num_node(self, num_node, edge_list):
if num_node is None:
num_node = self._maybe_num_node(edge_list)
num_node = torch.as_tensor(num_node, device=edge_list.device)
if (edge_list[:, :2] >= num_node).any():
raise ValueError("`num_node` is %d, but found node %d in `edge_list`" % (num_node, edge_list[:, :2].max()))
return num_node
def _standarize_num_relation(self, num_relation, edge_list):
if num_relation is None and edge_list.shape[1] > 2:
num_relation = self._maybe_num_relation(edge_list)
if num_relation is not None:
num_relation = torch.as_tensor(num_relation, device=edge_list.device)
return num_relation
def _maybe_num_node(self, edge_list):
warnings.warn("_maybe_num_node() is used to determine the number of nodes. "
"This may underestimate the count if there are isolated nodes.")
if len(edge_list):
return edge_list[:, :2].max().item() + 1
else:
return 0
def _maybe_num_relation(self, edge_list):
warnings.warn("_maybe_num_relation() is used to determine the number of relations. "
"This may underestimate the count if there are unseen relations.")
return edge_list[:, 2].max().item() + 1
def _standarize_index(self, index, count):
if isinstance(index, slice):
start = index.start or 0
if start < 0:
start += count
stop = index.stop or count
if stop < 0:
stop += count
step = index.step or 1
index = torch.arange(start, stop, step, device=self.device)
else:
index = torch.as_tensor(index, device=self.device)
if index.ndim == 0:
index = index.unsqueeze(0)
if index.dtype == torch.bool:
index = index.nonzero().squeeze(-1)
else:
index = index.long()
max_index = -1 if len(index) == 0 else index.max().item()
if max_index >= count:
raise ValueError("Invalid index. Expect index smaller than %d, but found %d" % (count, max_index))
return index
def _get_mapping(self, index, count):
index = self._standarize_index(index, count)
if (index.bincount() > 1).any():
raise ValueError("Can't create mapping for duplicate index")
mapping = -torch.ones(count + 1, dtype=torch.long, device=self.device)
mapping[index] = torch.arange(len(index), device=self.device)
return mapping
@classmethod
def from_dense(cls, adjacency, node_feature=None, edge_feature=None):
"""
Create a sparse graph from a dense adjacency matrix.
For zero entries in the adjacency matrix, their edge features will be ignored.
Parameters:
adjacency (array_like): adjacency matrix of shape :math:`(|V|, |V|)` or :math:`(|V|, |V|, |R|)`
node_feature (array_like): node features of shape :math:`(|V|, ...)`
edge_feature (array_like): edge features of shape :math:`(|V|, |V|, ...)` or :math:`(|V|, |V|, |R|, ...)`
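Example (an illustrative sketch, assuming ``import torch`` and ``from torchdrug import data``)::
>>> graph = data.Graph.from_dense(torch.eye(3))
>>> assert graph.num_node == 3 and graph.num_edge == 3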
"""
adjacency = torch.as_tensor(adjacency)
if adjacency.shape[0] != adjacency.shape[1]:
raise ValueError("`adjacency` should be a square matrix, but found %d and %d" % adjacency.shape[:2])
edge_list = adjacency.nonzero()
edge_weight = adjacency[tuple(edge_list.t())]
num_node = adjacency.shape[0]
num_relation = adjacency.shape[2] if adjacency.ndim > 2 else None
if edge_feature is not None:
edge_feature = torch.as_tensor(edge_feature)
edge_feature = edge_feature[tuple(edge_list.t())]
return cls(edge_list, edge_weight, num_node, num_relation, node_feature, edge_feature)
def connected_components(self):
"""
Split this graph into connected components.
Returns:
(PackedGraph, LongTensor): connected components, number of connected components per graph
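Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 1], [2, 3]], num_node=5)
>>> components, num_cc = graph.connected_components()
>>> assert num_cc == 3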
"""
node_in, node_out = self.edge_list.t()[:2]
range = torch.arange(self.num_node, device=self.device)
node_in, node_out = torch.cat([node_in, node_out, range]), torch.cat([node_out, node_in, range])
# find connected component
# O(d|E|), d is the diameter of the graph
min_neighbor = torch.arange(self.num_node, device=self.device)
last = torch.zeros_like(min_neighbor)
while not torch.equal(min_neighbor, last):
last = min_neighbor
min_neighbor = scatter_min(min_neighbor[node_out], node_in, dim_size=self.num_node)[0]
anchor = torch.unique(min_neighbor)
num_cc = self.node2graph[anchor].bincount(minlength=self.batch_size)
return self.split(min_neighbor), num_cc
def split(self, node2graph):
"""
Split a graph into multiple disconnected graphs.
Parameters:
node2graph (array_like): ID of the graph each node belongs to
Returns:
PackedGraph
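Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 1], [2, 3]], num_node=4)
>>> packed = graph.split([0, 0, 1, 1])
>>> assert packed.batch_size == 2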
"""
node2graph = torch.as_tensor(node2graph, dtype=torch.long, device=self.device)
# coalesce arbitrary graph IDs to [0, n)
_, node2graph = torch.unique(node2graph, return_inverse=True)
num_graph = node2graph.max() + 1
index = node2graph.argsort()
mapping = torch.zeros_like(index)
mapping[index] = torch.arange(len(index), device=self.device)
node_in, node_out = self.edge_list.t()[:2]
edge_mask = node2graph[node_in] == node2graph[node_out]
edge2graph = node2graph[node_in]
edge_index = edge2graph.argsort()
edge_index = edge_index[edge_mask[edge_index]]
is_first_node = torch.ones(self.num_node, dtype=torch.bool, device=self.device)
is_first_node[1:] = node2graph[index[1:]] != node2graph[index[:-1]]
graph_index = self.node2graph[is_first_node]
edge_list = self.edge_list.clone()
edge_list[:, :2] = mapping[edge_list[:, :2]]
num_nodes = node2graph.bincount(minlength=num_graph)
num_edges = edge2graph[edge_index].bincount(minlength=num_graph)
num_cum_nodes = num_nodes.cumsum(0)
offsets = (num_cum_nodes - num_nodes)[edge2graph[edge_index]]
data_dict, meta_dict = self.data_mask(index, edge_index, graph_index=graph_index, exclude="graph reference")
return self.packed_type(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes,
num_edges=num_edges, num_relation=self.num_relation, offsets=offsets,
meta_dict=meta_dict, **data_dict)
@classmethod
def pack(cls, graphs):
"""
Pack a list of graphs into a PackedGraph object.
Parameters:
graphs (list of Graph): list of graphs
Returns:
PackedGraph
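Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graphs = [data.Graph([[0, 1]], num_node=3), data.Graph([[1, 2]], num_node=4)]
>>> batch = data.Graph.pack(graphs)
>>> assert batch.batch_size == 2 and batch.num_node == 7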
"""
edge_list = []
edge_weight = []
num_nodes = []
num_edges = []
num_relation = -1
num_cum_node = 0
num_cum_edge = 0
num_graph = 0
data_dict = defaultdict(list)
meta_dict = graphs[0].meta_dict
for graph in graphs:
edge_list.append(graph.edge_list)
edge_weight.append(graph.edge_weight)
num_nodes.append(graph.num_node)
num_edges.append(graph.num_edge)
for k, v in graph.data_dict.items():
for type in meta_dict[k]:
if type == "graph":
v = v.unsqueeze(0)
elif type == "node reference":
v = v + num_cum_node
elif type == "edge reference":
v = v + num_cum_edge
elif type == "graph reference":
v = v + num_graph
data_dict[k].append(v)
if num_relation == -1:
num_relation = graph.num_relation
elif num_relation != graph.num_relation:
raise ValueError("Inconsistent `num_relation` in graphs. Expect %d but got %d."
% (num_relation, graph.num_relation))
num_cum_node += graph.num_node
num_cum_edge += graph.num_edge
num_graph += 1
edge_list = torch.cat(edge_list)
edge_weight = torch.cat(edge_weight)
data_dict = {k: torch.cat(v) for k, v in data_dict.items()}
return cls.packed_type(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges,
num_relation=num_relation, meta_dict=meta_dict, **data_dict)
def repeat(self, count):
"""
Repeat this graph.
Parameters:
count (int): number of repetitions
Returns:
PackedGraph
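Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 1], [1, 2]], num_node=3)
>>> batch = graph.repeat(4)
>>> assert batch.batch_size == 4 and batch.num_edge == 8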
"""
edge_list = self.edge_list.repeat(count, 1)
edge_weight = self.edge_weight.repeat(count)
num_nodes = [self.num_node] * count
num_edges = [self.num_edge] * count
num_relation = self.num_relation
data_dict = {}
for k, v in self.data_dict.items():
if "graph" in self.meta_dict[k]:
v = v.unsqueeze(0)
shape = [1] * v.ndim
shape[0] = count
data_dict[k] = v.repeat(shape)
return self.packed_type(edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges,
num_relation=num_relation, meta_dict=self.meta_dict, **data_dict)
def get_edge(self, edge):
"""
Get the weight of an edge.
Parameters:
edge (array_like): index of shape :math:`(2,)` or :math:`(3,)`
Returns:
Tensor: weight of the edge
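Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 1], [1, 2], [0, 1]], num_node=3)
>>> assert graph.get_edge([0, 1]) == 2
>>> assert graph[0, 1] == graph.get_edge([0, 1])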
"""
if len(edge) != self.edge_list.shape[1]:
raise ValueError("Incorrect edge index. Expect %d axes but got %d axes"
% (self.edge_list.shape[1], len(edge)))
edge_index, num_match = self.match(edge)
return self.edge_weight[edge_index].sum()
def match(self, pattern):
"""
Return all matched indexes for each pattern. Support patterns with ``-1`` as the wildcard.
Parameters:
pattern (array_like): index of shape :math:`(N, 2)` or :math:`(N, 3)`
Returns:
(LongTensor, LongTensor): matched indexes, number of matches per edge
Examples::
>>> graph = data.Graph([[0, 1], [1, 0], [1, 2], [2, 1], [2, 0], [0, 2]])
>>> index, num_match = graph.match([[0, -1], [1, 2]])
>>> assert (index == torch.tensor([0, 5, 2])).all()
>>> assert (num_match == torch.tensor([2, 1])).all()
"""
if len(pattern) == 0:
index = num_match = torch.zeros(0, dtype=torch.long, device=self.device)
return index, num_match
if not hasattr(self, "inverted_index"):
self.inverted_index = {}
pattern = torch.as_tensor(pattern, dtype=torch.long, device=self.device)
if pattern.ndim == 1:
pattern = pattern.unsqueeze(0)
mask = pattern != -1
scale = 2 ** torch.arange(pattern.shape[-1], device=self.device)
query_type = (mask * scale).sum(dim=-1)
query_index = query_type.argsort()
num_query = query_type.unique(return_counts=True)[1]
query_ends = num_query.cumsum(0)
query_starts = query_ends - num_query
mask_set = mask[query_index[query_starts]].tolist()
type_ranges = []
type_orders = []
# get matched range for each query type
for i, mask in enumerate(mask_set):
query_type = tuple(mask)
type_index = query_index[query_starts[i]: query_ends[i]]
type_edge = pattern[type_index][:, mask]
if query_type not in self.inverted_index:
self.inverted_index[query_type] = self._build_inverted_index(mask)
inverted_range, order = self.inverted_index[query_type]
ranges = inverted_range.get(type_edge, default=0)
type_ranges.append(ranges)
type_orders.append(order)
ranges = torch.cat(type_ranges)
orders = torch.stack(type_orders)
types = torch.arange(len(mask_set), device=self.device)
types = types.repeat_interleave(num_query)
# reorder matched ranges according to the query order
ranges = scatter_add(ranges, query_index, dim=0, dim_size=len(pattern))
types = scatter_add(types, query_index, dim_size=len(pattern))
# convert range to indexes
starts, ends = ranges.t()
num_match = ends - starts
offsets = num_match.cumsum(0) - num_match
types = types.repeat_interleave(num_match)
ranges = torch.arange(num_match.sum(), device=self.device)
ranges = ranges + (starts - offsets).repeat_interleave(num_match)
index = orders[types, ranges]
return index, num_match
def _build_inverted_index(self, mask):
keys = self.edge_list[:, mask]
base = torch.tensor(self.shape, device=self.device)
base = base[mask]
max = reduce(int.__mul__, base.tolist())
if max > torch.iinfo(torch.int64).max:
raise ValueError("Fail to build an inverted index table based on sorting. "
"The graph is too large.")
scale = base.cumprod(0)
scale = scale[-1] // scale
key = (keys * scale).sum(dim=-1)
order = key.argsort()
num_keys = key.unique(return_counts=True)[1]
ends = num_keys.cumsum(0)
starts = ends - num_keys
ranges = torch.stack([starts, ends], dim=-1)
keys_set = keys[order[starts]]
inverted_range = Dictionary(keys_set, ranges)
return inverted_range, order
def __getitem__(self, index):
# why do we check tuple?
# case 1: x[0, 1] is parsed as (0, 1)
# case 2: x[[0, 1]] is parsed as [0, 1]
if not isinstance(index, tuple):
index = (index,)
index = list(index)
while len(index) < 2:
index.append(slice(None))
if len(index) > 2:
raise ValueError("Graph has only 2 axis, but %d axis is indexed" % len(index))
if all([isinstance(axis_index, int) for axis_index in index]):
return self.get_edge(index)
edge_list = self.edge_list.clone()
for i, axis_index in enumerate(index):
axis_index = self._standarize_index(axis_index, self.num_node)
mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
mapping[axis_index] = axis_index
edge_list[:, i] = mapping[edge_list[:, i]]
edge_index = (edge_list >= 0).all(dim=-1)
return self.edge_mask(edge_index)
def __len__(self):
return 1
@property
def batch_size(self):
"""Batch size."""
return 1
def subgraph(self, index):
"""
Return a subgraph based on the specified nodes.
Equivalent to :meth:`node_mask(index, compact=True) <node_mask>`.
Parameters:
index (array_like): node index
Returns:
Graph
See also:
:meth:`Graph.node_mask`
"""
return self.node_mask(index, compact=True)
def data_mask(self, node_index=None, edge_index=None, graph_index=None, include=None, exclude=None):
data_dict, meta_dict = self.data_by_meta(include, exclude)
node_mapping = None
edge_mapping = None
graph_mapping = None
for k, v in data_dict.items():
for type in meta_dict[k]:
if type == "node" and node_index is not None:
v = v[node_index]
elif type == "edge" and edge_index is not None:
v = v[edge_index]
elif type == "graph" and graph_index is not None:
v = v.unsqueeze(0)[graph_index]
elif type == "node reference" and node_index is not None:
if node_mapping is None:
node_mapping = self._get_mapping(node_index, self.num_node)
v = node_mapping[v]
elif type == "edge reference" and edge_index is not None:
if edge_mapping is None:
edge_mapping = self._get_mapping(edge_index, self.num_edge)
v = edge_mapping[v]
elif type == "graph reference" and graph_index is not None:
if graph_mapping is None:
graph_mapping = self._get_mapping(graph_index, self.batch_size)
v = graph_mapping[v]
data_dict[k] = v
return data_dict, meta_dict
def node_mask(self, index, compact=False):
"""
Return a masked graph based on the specified nodes.
This function can also be used to re-order the nodes.
Parameters:
index (array_like): node index
compact (bool, optional): compact node ids or not
Returns:
Graph
Examples::
>>> graph = data.Graph.from_dense(torch.eye(3))
>>> assert graph.node_mask([1, 2]).adjacency.shape == (3, 3)
>>> assert graph.node_mask([1, 2], compact=True).adjacency.shape == (2, 2)
"""
index = self._standarize_index(index, self.num_node)
mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
if compact:
mapping[index] = torch.arange(len(index), device=self.device)
num_node = len(index)
else:
mapping[index] = index
num_node = self.num_node
edge_list = self.edge_list.clone()
edge_list[:, :2] = mapping[edge_list[:, :2]]
edge_index = (edge_list[:, :2] >= 0).all(dim=-1)
if compact:
data_dict, meta_dict = self.data_mask(index, edge_index)
else:
data_dict, meta_dict = self.data_mask(edge_index=edge_index)
return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_node=num_node,
num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)
def compact(self):
"""
Remove isolated nodes and compact node ids.
Returns:
Graph
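Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 2]], num_node=4)
>>> assert graph.compact().num_node == 2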
"""
index = self.degree_out + self.degree_in > 0
return self.subgraph(index)
def edge_mask(self, index):
"""
Return a masked graph based on the specified edges.
This function can also be used to re-order the edges.
Parameters:
index (array_like): edge index
Returns:
Graph
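Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 1], [1, 2], [2, 0]], num_node=3)
>>> assert graph.edge_mask([0, 2]).num_edge == 2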
"""
index = self._standarize_index(index, self.num_edge)
data_dict, meta_dict = self.data_mask(edge_index=index)
return type(self)(self.edge_list[index], edge_weight=self.edge_weight[index], num_node=self.num_node,
num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)
def full(self):
"""
Return a fully connected graph over the nodes.
Returns:
Graph
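Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 1]], num_node=3)
>>> assert graph.full().num_edge == 9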
"""
index = torch.arange(self.num_node, device=self.device)
if self.num_relation:
edge_list = torch.meshgrid(index, index, torch.arange(self.num_relation, device=self.device))
else:
edge_list = torch.meshgrid(index, index)
edge_list = torch.stack(edge_list, dim=-1).flatten(0, -2)
edge_weight = torch.ones(len(edge_list))
data_dict, meta_dict = self.data_by_meta(exclude="edge")
return type(self)(edge_list, edge_weight=edge_weight, num_node=self.num_node, num_relation=self.num_relation,
meta_dict=meta_dict, **data_dict)
def directed(self, order=None):
"""
Mask the edges to create a directed graph.
Edges that go from a node index to a larger or equal node index will be kept.
Parameters:
order (Tensor, optional): topological order of the nodes
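Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 1], [1, 0]], num_node=2)
>>> assert graph.directed().num_edge == 1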
"""
node_in, node_out = self.edge_list.t()[:2]
if order is not None:
edge_index = order[node_in] <= order[node_out]
else:
edge_index = node_in <= node_out
return self.edge_mask(edge_index)
def undirected(self, add_inverse=False):
"""
Flip all the edges to create an undirected graph.
For knowledge graphs, the flipped edges can either have the original relation or an inverse relation.
The inverse relation for relation :math:`r` is defined as :math:`|R| + r`.
Parameters:
add_inverse (bool, optional): whether to use inverse relations for flipped edges
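Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> graph = data.Graph([[0, 1], [1, 2]], num_node=3)
>>> assert graph.undirected().num_edge == 4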
"""
edge_list = self.edge_list.clone()
edge_list[:, :2] = edge_list[:, :2].flip(1)
num_relation = self.num_relation
if num_relation and add_inverse:
edge_list[:, 2] += num_relation
num_relation = num_relation * 2
edge_list = torch.stack([self.edge_list, edge_list], dim=1).flatten(0, 1)
index = torch.arange(self.num_edge, device=self.device).unsqueeze(-1).expand(-1, 2).flatten()
data_dict, meta_dict = self.data_mask(edge_index=index)
return type(self)(edge_list, edge_weight=self.edge_weight[index], num_node=self.num_node,
num_relation=num_relation, meta_dict=meta_dict, **data_dict)
@utils.cached_property
def adjacency(self):
"""
Adjacency matrix of this graph.
If :attr:`num_relation` is specified, a sparse tensor of shape :math:`(|V|, |V|, num\_relation)` will be
returned.
Otherwise, a sparse tensor of shape :math:`(|V|, |V|)` will be returned.
"""
return utils.sparse_coo_tensor(self.edge_list.t(), self.edge_weight, self.shape)
_tensor_names = ["edge_list", "edge_weight", "num_node", "num_relation", "edge_feature"]
def to_tensors(self):
edge_feature = getattr(self, "edge_feature", torch.tensor(0, device=self.device))
return self.edge_list, self.edge_weight, self.num_node, self.num_relation, edge_feature
@classmethod
def from_tensors(cls, tensors):
edge_list, edge_weight, num_node, num_relation, edge_feature = tensors
if edge_feature.ndim == 0:
edge_feature = None
return cls(edge_list, edge_weight, num_node, num_relation, edge_feature=edge_feature)
@property
def node2graph(self):
"""Node id to graph id mapping."""
return torch.zeros(self.num_node, dtype=torch.long, device=self.device)
@property
def edge2graph(self):
"""Edge id to graph id mapping."""
return torch.zeros(self.num_edge, dtype=torch.long, device=self.device)
@utils.cached_property
def degree_out(self):
"""
Weighted number of edges containing each node as output.
Note this is the **in-degree** in graph theory.
"""
return scatter_add(self.edge_weight, self.edge_list[:, 1], dim_size=self.num_node)
@utils.cached_property
def degree_in(self):
"""
Weighted number of edges containing each node as input.
Note this is the **out-degree** in graph theory.
"""
return scatter_add(self.edge_weight, self.edge_list[:, 0], dim_size=self.num_node)
@property
def edge_list(self):
"""List of edges."""
return self._edge_list
@property
def edge_weight(self):
"""Edge weights."""
return self._edge_weight
@property
def device(self):
"""Device."""
return self.edge_list.device
@property
def requires_grad(self):
return self.edge_weight.requires_grad
@property
def grad(self):
return self.edge_weight.grad
@property
def data(self):
return self
def requires_grad_(self):
self.edge_weight.requires_grad_()
return self
def size(self, dim=None):
if self.num_relation:
size = torch.Size((self.num_node, self.num_node, self.num_relation))
else:
size = torch.Size((self.num_node, self.num_node))
if dim is None:
return size
return size[dim]
@property
def shape(self):
return self.size()
def copy_(self, src):
"""
Copy data from ``src`` into ``self`` and return ``self``.
The ``src`` graph must have the same set of attributes as ``self``.
"""
self.edge_list.copy_(src.edge_list)
self.edge_weight.copy_(src.edge_weight)
self.num_node.copy_(src.num_node)
self.num_edge.copy_(src.num_edge)
if self.num_relation is not None:
self.num_relation.copy_(src.num_relation)
keys = set(self.data_dict.keys())
src_keys = set(src.data_dict.keys())
if keys != src_keys:
raise RuntimeError("Attributes mismatch. Trying to assign attributes %s, "
"but current graph has attributes %s" % (src_keys, keys))
for k, v in self.data_dict.items():
v.copy_(src.data_dict[k])
return self
def detach(self):
"""
Detach this graph.
"""
return type(self)(self.edge_list.detach(), edge_weight=self.edge_weight.detach(),
num_node=self.num_node, num_relation=self.num_relation,
meta_dict=self.meta_dict, **utils.detach(self.data_dict))
def clone(self):
"""
Clone this graph.
"""
return type(self)(self.edge_list.clone(), edge_weight=self.edge_weight.clone(),
num_node=self.num_node, num_relation=self.num_relation,
meta_dict=self.meta_dict, **utils.clone(self.data_dict))
def cuda(self, *args, **kwargs):
"""
Return a copy of this graph in CUDA memory.
This is a no-op if the graph is already on the correct device.
"""
edge_list = self.edge_list.cuda(*args, **kwargs)
if edge_list is self.edge_list:
return self
else:
return type(self)(edge_list, edge_weight=self.edge_weight,
num_node=self.num_node, num_relation=self.num_relation,
meta_dict=self.meta_dict, **utils.cuda(self.data_dict, *args, **kwargs))
def cpu(self):
"""
Return a copy of this graph in CPU memory.
This is a no-op if the graph is already in CPU memory.
"""
edge_list = self.edge_list.cpu()
if edge_list is self.edge_list:
return self
else:
return type(self)(edge_list, edge_weight=self.edge_weight, num_node=self.num_node,
num_relation=self.num_relation, meta_dict=self.meta_dict, **utils.cpu(self.data_dict))
def to(self, device, *args, **kwargs):
"""
Return a copy of this graph on the given device.
"""
device = torch.device(device)
if device.type == "cpu":
return self.cpu(*args, **kwargs)
else:
return self.cuda(device, *args, **kwargs)
def __repr__(self):
fields = ["num_node=%d" % self.num_node, "num_edge=%d" % self.num_edge]
if self.num_relation is not None:
fields.append("num_relation=%d" % self.num_relation)
if self.device.type != "cpu":
fields.append("device='%s'" % self.device)
return "%s(%s)" % (self.__class__.__name__, ", ".join(fields))
def visualize(self, title=None, save_file=None, figure_size=(3, 3), ax=None, layout="spring"):
"""
Visualize this graph with matplotlib.
Parameters:
title (str, optional): title for this graph
save_file (str, optional): ``png`` or ``pdf`` file to save visualization.
If not provided, show the figure in window.
figure_size (tuple of int, optional): width and height of the figure
ax (matplotlib.axes.Axes, optional): axis to plot the figure
layout (str, optional): graph layout
See also:
`NetworkX graph layout`_
.. _NetworkX graph layout:
https://networkx.github.io/documentation/stable/reference/drawing.html#module-networkx.drawing.layout
"""
is_root = ax is None
if ax is None:
fig = plt.figure(figsize=figure_size)
if title is not None:
ax = plt.gca()
else:
ax = fig.add_axes([0, 0, 1, 1])
if title is not None:
ax.set_title(title)
edge_list = self.edge_list[:, :2].tolist()
G = nx.DiGraph(edge_list)
G.add_nodes_from(range(self.num_node))
if hasattr(nx, "%s_layout" % layout):
func = getattr(nx, "%s_layout" % layout)
else:
raise ValueError("Unknown networkx layout `%s`" % layout)
if layout == "spring" or layout == "random":
pos = func(G, seed=0)
else:
pos = func(G)
nx.draw_networkx(G, pos, ax=ax)
if self.num_relation:
edge_labels = self.edge_list[:, 2].tolist()
edge_labels = {tuple(e): l for e, l in zip(edge_list, edge_labels)}
nx.draw_networkx_edge_labels(G, pos, edge_labels, ax=ax)
ax.set_frame_on(False)
if is_root:
if save_file:
fig.savefig(save_file)
else:
fig.show()
def __getstate__(self):
state = {}
cls = self.__class__
for k, v in self.__dict__.items():
# do not pickle property / cached property
if hasattr(cls, k) and isinstance(getattr(cls, k), property):
continue
state[k] = v
return state
class PackedGraph(Graph):
"""
Container for sparse graphs with variadic sizes.
To create a PackedGraph from Graph objects
>>> batch = data.Graph.pack(graphs)
To retrieve Graph objects from a PackedGraph
>>> graphs = batch.unpack()
Parameters:
edge_list (array_like, optional): list of edges of shape :math:`(|E|, 2)` or :math:`(|E|, 3)`.
Each tuple is (node_in, node_out) or (node_in, node_out, relation).
edge_weight (array_like, optional): edge weights of shape :math:`(|E|,)`
num_nodes (array_like, optional): number of nodes in each graph
By default, it will be inferred from the largest id in `edge_list`
num_edges (array_like, optional): number of edges in each graph
num_relation (int, optional): number of relations
node_feature (array_like, optional): node features of shape :math:`(|V|, ...)`
edge_feature (array_like, optional): edge features of shape :math:`(|E|, ...)`
offsets (array_like, optional): node id offsets of shape :math:`(|E|,)`.
If not provided, nodes in `edge_list` should be relative indices, i.e., indices within each graph.
If provided, nodes in `edge_list` should be absolute indices, i.e., indices in the packed graph.
"""
unpacked_type = Graph
def __init__(self, edge_list=None, edge_weight=None, num_nodes=None, num_edges=None, num_relation=None,
offsets=None, **kwargs):
edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets = \
self._get_cumulative(edge_list, num_nodes, num_edges, offsets)
if offsets is None:
offsets = self._get_offsets(num_nodes, num_edges, num_cum_nodes)
edge_list = edge_list.clone()
edge_list[:, :2] += offsets.unsqueeze(-1)
num_node = num_nodes.sum()
if (edge_list[:, :2] >= num_node).any():
raise ValueError("Sum of `num_nodes` is %d, but found %d in `edge_list`" %
(num_node, edge_list[:, :2].max()))
self._offsets = offsets
self.num_nodes = num_nodes
self.num_edges = num_edges
self.num_cum_nodes = num_cum_nodes
self.num_cum_edges = num_cum_edges
super(PackedGraph, self).__init__(edge_list, edge_weight=edge_weight, num_node=num_node,
num_relation=num_relation, **kwargs)
def _get_offsets(self, num_nodes=None, num_edges=None, num_cum_nodes=None, num_cum_edges=None):
if num_nodes is None:
num_cum_nodes_shifted = torch.cat(
[torch.zeros(1, dtype=torch.long, device=self.device), num_cum_nodes[:-1]])
num_nodes = num_cum_nodes - num_cum_nodes_shifted
if num_edges is None:
num_cum_edges_shifted = torch.cat(
[torch.zeros(1, dtype=torch.long, device=self.device), num_cum_edges[:-1]])
num_edges = num_cum_edges - num_cum_edges_shifted
if num_cum_nodes is None:
num_cum_nodes = num_nodes.cumsum(0)
return (num_cum_nodes - num_nodes).repeat_interleave(num_edges)
def merge(self, graph2graph):
"""
Merge multiple graphs into a single graph.
Parameters:
graph2graph (array_like): ID of the new graph each graph belongs to
"""
graph2graph = torch.as_tensor(graph2graph, dtype=torch.long, device=self.device)
# coalesce arbitrary graph IDs to [0, n)
_, graph2graph = torch.unique(graph2graph, return_inverse=True)
graph_key = graph2graph * self.batch_size + torch.arange(self.batch_size, device=self.device)
graph_index = graph_key.argsort()
graph = self.subbatch(graph_index)
graph2graph = graph2graph[graph_index]
num_graph = graph2graph[-1] + 1
num_nodes = scatter_add(graph.num_nodes, graph2graph, dim_size=num_graph)
num_edges = scatter_add(graph.num_edges, graph2graph, dim_size=num_graph)
offsets = self._get_offsets(num_nodes, num_edges)
data_dict, meta_dict = graph.data_mask(exclude="graph")
return type(self)(graph.edge_list, edge_weight=graph.edge_weight, num_nodes=num_nodes,
num_edges=num_edges, num_relation=graph.num_relation, offsets=offsets,
meta_dict=meta_dict, **data_dict)
def unpack(self):
"""
Unpack this packed graph into a list of graphs.
Returns:
list of Graph
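Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> batch = data.Graph.pack([data.Graph([[0, 1]], num_node=2)] * 3)
>>> graphs = batch.unpack()
>>> assert len(graphs) == 3 and graphs[0].num_node == 2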
"""
graphs = []
for i in range(self.batch_size):
graphs.append(self.get_item(i))
return graphs
def __iter__(self):
self._iter_index = 0
return self
def __next__(self):
if self._iter_index < self.batch_size:
item = self[self._iter_index]
self._iter_index += 1
return item
raise StopIteration
def _check_attribute(self, key, value):
for type in self._meta_contexts:
if "reference" in type:
if value.dtype != torch.long:
raise TypeError("Tensors used as reference must be long tensors")
if type == "node":
if len(value) != self.num_node:
raise ValueError("Expect node attribute `%s` to have shape (%d, *), but found %s" %
(key, self.num_node, value.shape))
elif type == "edge":
if len(value) != self.num_edge:
raise ValueError("Expect edge attribute `%s` to have shape (%d, *), but found %s" %
(key, self.num_edge, value.shape))
elif type == "graph":
if len(value) != self.batch_size:
raise ValueError("Expect graph attribute `%s` to have shape (%d, *), but found %s" %
(key, self.batch_size, value.shape))
elif type == "node reference":
is_valid = (value >= -1) & (value < self.num_node)
if not is_valid.all():
error_value = value[~is_valid]
raise ValueError("Expect node reference in [-1, %d), but found %d" %
(self.num_node, error_value[0]))
elif type == "edge reference":
is_valid = (value >= -1) & (value < self.num_edge)
if not is_valid.all():
error_value = value[~is_valid]
raise ValueError("Expect edge reference in [-1, %d), but found %d" %
(self.num_edge, error_value[0]))
elif type == "graph reference":
is_valid = (value >= -1) & (value < self.batch_size)
if not is_valid.all():
error_value = value[~is_valid]
raise ValueError("Expect graph reference in [-1, %d), but found %d" %
(self.batch_size, error_value[0]))
[docs] def unpack_data(self, data, type="auto"):
"""
Unpack node or edge data according to the packed graph.
Parameters:
data (Tensor): data to unpack
type (str, optional): data type. Can be ``auto``, ``node``, or ``edge``.
Returns:
list of Tensor
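Example (an illustrative sketch, assuming ``import torch`` and ``from torchdrug import data``)::
>>> batch = data.Graph.pack([data.Graph([[0, 1]], num_node=2), data.Graph([[0, 1]], num_node=3)])
>>> node_data = torch.randn(5, 8)  # 5 == batch.num_node
>>> chunks = batch.unpack_data(node_data)
>>> assert len(chunks) == 2 and len(chunks[0]) == 2 and len(chunks[1]) == 3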
"""
if type == "auto":
if self.num_node == self.num_edge:
raise ValueError("Ambiguous type. Please specify either `node` or `edge`")
if len(data) == self.num_node:
type = "node"
elif len(data) == self.num_edge:
type = "edge"
else:
raise ValueError("Graph has %d nodes and %d edges, but data has %d entries" %
(self.num_node, self.num_edge, len(data)))
data_list = []
if type == "node":
for i in range(self.batch_size):
data_list.append(data[self.num_cum_nodes[i] - self.num_nodes[i]: self.num_cum_nodes[i]])
elif type == "edge":
for i in range(self.batch_size):
data_list.append(data[self.num_cum_edges[i] - self.num_edges[i]: self.num_cum_edges[i]])
return data_list
def repeat(self, count):
"""
Repeat this packed graph. This function behaves similarly to `torch.Tensor.repeat`_.
.. _torch.Tensor.repeat:
https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html
Parameters:
count (int): number of repetitions
Returns:
PackedGraph
"""
pack_offsets = torch.arange(count, device=self.device) * self.num_node
pack_offsets = pack_offsets.unsqueeze(-1).expand(-1, self.num_edge).flatten()
edge_list = self.edge_list.repeat(count, 1)
edge_list[:, :2] += pack_offsets.unsqueeze(-1)
offsets = self._offsets.repeat(count) + pack_offsets
# TODO: handle node/edge/graph reference
data_dict = {}
for k, v in self.data_dict.items():
shape = [1] * v.ndim
shape[0] = count
data_dict[k] = v.repeat(shape)
return type(self)(edge_list, edge_weight=self.edge_weight.repeat(count),
num_nodes=self.num_nodes.repeat(count), num_edges=self.num_edges.repeat(count),
num_relation=self.num_relation, offsets=offsets,
meta_dict=self.meta_dict, **data_dict)
def repeat_interleave(self, repeats):
"""
Repeat this packed graph. This function behaves similarly to `torch.repeat_interleave`_.
.. _torch.repeat_interleave:
https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html
Parameters:
repeats (Tensor or int): number of repetitions for each graph
Returns:
PackedGraph
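Example (an illustrative sketch, assuming ``import torch`` and ``from torchdrug import data``)::
>>> batch = data.Graph.pack([data.Graph([[0, 1]], num_node=2), data.Graph([[0, 1]], num_node=3)])
>>> repeated = batch.repeat_interleave(torch.tensor([2, 1]))
>>> assert repeated.batch_size == 3
>>> assert (repeated.num_nodes == torch.tensor([2, 2, 3])).all()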
"""
repeats = torch.as_tensor(repeats, dtype=torch.long, device=self.device)
if repeats.numel() == 1:
repeats = repeats * torch.ones(self.batch_size, dtype=torch.long, device=self.device)
num_nodes = self.num_nodes.repeat_interleave(repeats)
num_edges = self.num_edges.repeat_interleave(repeats)
num_cum_nodes = num_nodes.cumsum(0)
num_cum_edges = num_edges.cumsum(0)
num_node = num_nodes.sum()
num_edge = num_edges.sum()
batch_size = repeats.sum()
# special case 1: graphs[i] may have no node or no edge
# special case 2: repeats[i] may be 0
cum_repeats_shifted = repeats.cumsum(0) - repeats
graph_mask = cum_repeats_shifted < batch_size
cum_repeats_shifted = cum_repeats_shifted[graph_mask]
index = num_cum_nodes - num_nodes
index = torch.cat([index, index[cum_repeats_shifted]])
value = torch.cat([-num_nodes, self.num_nodes[graph_mask]])
mask = index < num_node
node_index = scatter_add(value[mask], index[mask], dim_size=num_node)
node_index = (node_index + 1).cumsum(0) - 1
index = num_cum_edges - num_edges
index = torch.cat([index, index[cum_repeats_shifted]])
value = torch.cat([-num_edges, self.num_edges[graph_mask]])
mask = index < num_edge
edge_index = scatter_add(value[mask], index[mask], dim_size=num_edge)
edge_index = (edge_index + 1).cumsum(0) - 1
graph_index = torch.repeat_interleave(repeats)
index = num_cum_edges - num_edges
index = torch.cat([index, index[cum_repeats_shifted]])
value = torch.cat([num_nodes, -self.num_nodes[graph_mask]])
mask = index < num_edge
pack_offsets = scatter_add(value[mask], index[mask], dim_size=num_edge)
pack_offsets = pack_offsets.cumsum(0)
edge_list = self.edge_list[edge_index]
edge_list[:, :2] += pack_offsets.unsqueeze(-1)
offsets = self._offsets[edge_index] + pack_offsets
# TODO: handle reference in repeat_interleave
data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=graph_index,
exclude=("node reference", "edge reference", "graph reference"))
return type(self)(edge_list, edge_weight=self.edge_weight[edge_index],
num_nodes=num_nodes, num_edges=num_edges, num_relation=self.num_relation, offsets=offsets,
meta_dict=meta_dict, **data_dict)
def get_item(self, index):
"""
Get the i-th graph from this packed graph.
Parameters:
index (int): graph index
Returns:
Graph
"""
node_index = torch.arange(self.num_cum_nodes[index] - self.num_nodes[index], self.num_cum_nodes[index],
device=self.device)
edge_index = torch.arange(self.num_cum_edges[index] - self.num_edges[index], self.num_cum_edges[index],
device=self.device)
graph_index = index
edge_list = self.edge_list[edge_index].clone()
edge_list[:, :2] -= self._offsets[edge_index].unsqueeze(-1)
data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=graph_index)
return self.unpacked_type(edge_list, edge_weight=self.edge_weight[edge_index], num_node=self.num_nodes[index],
num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)
def _get_cumulative(self, edge_list, num_nodes, num_edges, offsets):
if edge_list is None:
raise ValueError("`edge_list` should be provided")
if num_edges is None:
raise ValueError("`num_edges` should be provided")
edge_list = torch.as_tensor(edge_list)
num_edges = torch.as_tensor(num_edges, device=edge_list.device)
num_edge = num_edges.sum()
if num_edge != len(edge_list):
raise ValueError("Sum of `num_edges` is %d, but found %d edges in `edge_list`" % (num_edge, len(edge_list)))
num_cum_edges = num_edges.cumsum(0)
if offsets is None:
_edge_list = edge_list
else:
offsets = torch.as_tensor(offsets, device=edge_list.device)
_edge_list = edge_list.clone()
_edge_list[:, :2] -= offsets.unsqueeze(-1)
if num_nodes is None:
num_nodes = []
for num_edge, num_cum_edge in zip(num_edges, num_cum_edges):
num_nodes.append(self._maybe_num_node(_edge_list[num_cum_edge - num_edge: num_cum_edge]))
num_nodes = torch.as_tensor(num_nodes, device=edge_list.device)
num_cum_nodes = num_nodes.cumsum(0)
return edge_list, num_nodes, num_edges, num_cum_nodes, num_cum_edges, offsets
def _get_num_xs(self, index, num_cum_xs):
x = torch.zeros(num_cum_xs[-1], dtype=torch.long, device=self.device)
x[index] = 1
num_cum_indexes = x.cumsum(0)
num_cum_indexes = torch.cat([torch.zeros(1, dtype=torch.long, device=self.device), num_cum_indexes])
new_num_cum_xs = num_cum_indexes[num_cum_xs]
new_num_cum_xs_shifted = torch.cat([torch.zeros(1, dtype=torch.long, device=self.device), new_num_cum_xs[:-1]])
new_num_xs = new_num_cum_xs - new_num_cum_xs_shifted
return new_num_xs
def data_mask(self, node_index=None, edge_index=None, graph_index=None, include=None, exclude=None):
data_dict, meta_dict = self.data_by_meta(include, exclude)
node_mapping = None
edge_mapping = None
graph_mapping = None
for k, v in data_dict.items():
for type in meta_dict[k]:
if type == "node" and node_index is not None:
v = v[node_index]
elif type == "edge" and edge_index is not None:
v = v[edge_index]
elif type == "graph" and graph_index is not None:
v = v[graph_index]
elif type == "node reference" and node_index is not None:
if node_mapping is None:
node_mapping = self._get_mapping(node_index, self.num_node)
v = node_mapping[v]
elif type == "edge reference" and edge_index is not None:
if edge_mapping is None:
edge_mapping = self._get_mapping(edge_index, self.num_edge)
v = edge_mapping[v]
elif type == "graph reference" and graph_index is not None:
if graph_mapping is None:
graph_mapping = self._get_mapping(graph_index, self.batch_size)
v = graph_mapping[v]
data_dict[k] = v
return data_dict, meta_dict
def __getitem__(self, index):
# why do we check tuple?
# case 1: x[0, 1] is parsed as (0, 1)
# case 2: x[[0, 1]] is parsed as [0, 1]
if not isinstance(index, tuple):
index = (index,)
if isinstance(index[0], int):
item = self.get_item(index[0])
if len(index) > 1:
item = item[index[1:]]
return item
if len(index) > 1:
raise ValueError("Complex indexing is not supported for PackedGraph")
index = self._standarize_index(index[0], self.batch_size)
count = index.bincount(minlength=self.batch_size)
if count.max() > 1:
graph = self.repeat_interleave(count)
index_order = index.argsort()
order = torch.zeros_like(index)
order[index_order] = torch.arange(len(index), dtype=torch.long, device=self.device)
return graph.subbatch(order)
return self.subbatch(index)
def __len__(self):
return len(self.num_nodes)
def full(self):
"""
Return a pack of fully connected graphs.
This is useful for computing node-pair-wise features.
The computation can be implemented as message passing over a fully connected graph.
Returns:
PackedGraph
"""
# TODO: more efficient implementation?
graphs = self.unpack()
graphs = [graph.full() for graph in graphs]
return graphs[0].pack(graphs)
@utils.cached_property
def node2graph(self):
"""Node id to graph id mapping."""
range = torch.arange(self.batch_size, device=self.device)
node2graph = range.repeat_interleave(self.num_nodes)
return node2graph
@utils.cached_property
def edge2graph(self):
"""Edge id to graph id mapping."""
range = torch.arange(self.batch_size, device=self.device)
edge2graph = range.repeat_interleave(self.num_edges)
return edge2graph
@property
def batch_size(self):
"""Batch size."""
return len(self.num_nodes)
def node_mask(self, index, compact=False):
"""
Return a masked packed graph based on the specified nodes.
Note the compact option is only applied to node ids but not graph ids.
To generate compact graph ids, use :meth:`subbatch`.
Parameters:
index (array_like): node index
compact (bool, optional): compact node ids or not
Returns:
PackedGraph
"""
index = self._standarize_index(index, self.num_node)
mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
if compact:
mapping[index] = torch.arange(len(index), device=self.device)
num_nodes = self._get_num_xs(index, self.num_cum_nodes)
offsets = self._get_offsets(num_nodes, self.num_edges)
else:
mapping[index] = index
num_nodes = self.num_nodes
offsets = self._offsets
edge_list = self.edge_list.clone()
edge_list[:, :2] = mapping[edge_list[:, :2]]
edge_index = (edge_list[:, :2] >= 0).all(dim=-1)
num_edges = self._get_num_xs(edge_index, self.num_cum_edges)
if compact:
data_dict, meta_dict = self.data_mask(index, edge_index)
else:
data_dict, meta_dict = self.data_mask(edge_index=edge_index)
return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes,
num_edges=num_edges, num_relation=self.num_relation, offsets=offsets[edge_index],
meta_dict=meta_dict, **data_dict)
def edge_mask(self, index):
"""
Return a masked packed graph based on the specified edges.
Parameters:
index (array_like): edge index
Returns:
PackedGraph
"""
index = self._standarize_index(index, self.num_edge)
data_dict, meta_dict = self.data_mask(edge_index=index)
num_edges = self._get_num_xs(index, self.num_cum_edges)
return type(self)(self.edge_list[index], edge_weight=self.edge_weight[index], num_nodes=self.num_nodes,
num_edges=num_edges, num_relation=self.num_relation, offsets=self._offsets[index],
meta_dict=meta_dict, **data_dict)
def graph_mask(self, index, compact=False):
"""
Return a masked packed graph based on the specified graphs.
This function can also be used to re-order the graphs.
Parameters:
index (array_like): graph index
compact (bool, optional): compact graph ids or not
Returns:
PackedGraph
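Example (an illustrative sketch, assuming ``from torchdrug import data``)::
>>> batch = data.Graph.pack([data.Graph([[0, 1]], num_node=2), data.Graph([[0, 1]], num_node=3)])
>>> masked = batch.graph_mask([1], compact=True)
>>> assert masked.batch_size == 1 and masked.num_node == 3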
"""
index = self._standarize_index(index, self.batch_size)
graph_mapping = -torch.ones(self.batch_size, dtype=torch.long, device=self.device)
graph_mapping[index] = torch.arange(len(index), device=self.device)
node_index = graph_mapping[self.node2graph] >= 0
node_index = self._standarize_index(node_index, self.num_node)
mapping = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
if compact:
key = graph_mapping[self.node2graph[node_index]] * self.num_node + node_index
order = key.argsort()
node_index = node_index[order]
mapping[node_index] = torch.arange(len(node_index), device=self.device)
num_nodes = self.num_nodes[index]
else:
mapping[node_index] = node_index
num_nodes = torch.zeros_like(self.num_nodes)
num_nodes[index] = self.num_nodes[index]
edge_list = self.edge_list.clone()
edge_list[:, :2] = mapping[edge_list[:, :2]]
edge_index = (edge_list[:, :2] >= 0).all(dim=-1)
edge_index = self._standarize_index(edge_index, self.num_edge)
if compact:
key = graph_mapping[self.edge2graph[edge_index]] * self.num_edge + edge_index
order = key.argsort()
edge_index = edge_index[order]
num_edges = self.num_edges[index]
else:
num_edges = torch.zeros_like(self.num_edges)
num_edges[index] = self.num_edges[index]
offsets = self._get_offsets(num_nodes, num_edges)
if compact:
data_dict, meta_dict = self.data_mask(node_index, edge_index, graph_index=index)
else:
data_dict, meta_dict = self.data_mask(edge_index=edge_index)
return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_nodes=num_nodes,
num_edges=num_edges, num_relation=self.num_relation, offsets=offsets,
meta_dict=meta_dict, **data_dict)
def subbatch(self, index):
"""
Return a subbatch based on the specified graphs.
Equivalent to :meth:`graph_mask(index, compact=True) <graph_mask>`.
Parameters:
index (array_like): graph index
Returns:
PackedGraph
See also:
:meth:`PackedGraph.graph_mask`
"""
return self.graph_mask(index, compact=True)
def undirected(self, add_inverse=False):
"""
Flip all the edges to create undirected graphs.
For knowledge graphs, the flipped edges can either have the original relation or an inverse relation.
The inverse relation for relation :math:`r` is defined as :math:`|R| + r`.
Parameters:
add_inverse (bool, optional): whether to use inverse relations for flipped edges
"""
edge_list = self.edge_list.clone()
edge_list[:, :2] = edge_list[:, :2].flip(1)
num_relation = self.num_relation
if num_relation and add_inverse:
edge_list[:, 2] += num_relation
num_relation = num_relation * 2
edge_list = torch.stack([self.edge_list, edge_list], dim=1).flatten(0, 1)
offsets = self._offsets.unsqueeze(-1).expand(-1, 2).flatten()
index = torch.arange(self.num_edge, device=self.device).unsqueeze(-1).expand(-1, 2).flatten()
data_dict, meta_dict = self.data_mask(edge_index=index, exclude="edge reference")
return type(self)(edge_list, edge_weight=self.edge_weight[index], num_nodes=self.num_nodes,
num_edges=self.num_edges * 2, num_relation=num_relation, offsets=offsets,
meta_dict=meta_dict, **data_dict)
def detach(self):
"""
Detach this packed graph.
"""
return type(self)(self.edge_list.detach(), edge_weight=self.edge_weight.detach(),
num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
offsets=self._offsets, meta_dict=self.meta_dict, **utils.detach(self.data_dict))
def clone(self):
"""
Clone this packed graph.
"""
return type(self)(self.edge_list.clone(), edge_weight=self.edge_weight.clone(),
num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
offsets=self._offsets, meta_dict=self.meta_dict, **utils.clone(self.data_dict))
def cuda(self, *args, **kwargs):
"""
Return a copy of this packed graph in CUDA memory.
This is a no-op if the graph is already on the correct device.
"""
edge_list = self.edge_list.cuda(*args, **kwargs)
if edge_list is self.edge_list:
return self
else:
return type(self)(edge_list, edge_weight=self.edge_weight,
num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
offsets=self._offsets, meta_dict=self.meta_dict,
**utils.cuda(self.data_dict, *args, **kwargs))
def cpu(self):
"""
Return a copy of this packed graph in CPU memory.
This is a no-op if the graph is already in CPU memory.
"""
edge_list = self.edge_list.cpu()
if edge_list is self.edge_list:
return self
else:
return type(self)(edge_list, edge_weight=self.edge_weight,
num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
offsets=self._offsets, meta_dict=self.meta_dict, **utils.cpu(self.data_dict))
def __repr__(self):
fields = ["batch_size=%d" % self.batch_size,
"num_nodes=%s" % pretty.long_array(self.num_nodes.tolist()),
"num_edges=%s" % pretty.long_array(self.num_edges.tolist())]
if self.num_relation is not None:
fields.append("num_relation=%d" % self.num_relation)
if self.device.type != "cpu":
fields.append("device='%s'" % self.device)
return "%s(%s)" % (self.__class__.__name__, ", ".join(fields))
def visualize(self, titles=None, save_file=None, figure_size=(3, 3), layout="spring", num_row=None, num_col=None):
"""
Visualize the packed graphs with matplotlib.
Parameters:
titles (list of str, optional): title for each graph. Default is the ID of each graph.
save_file (str, optional): ``png`` or ``pdf`` file to save visualization.
If not provided, show the figure in window.
figure_size (tuple of int, optional): width and height of the figure
layout (str, optional): graph layout
num_row (int, optional): number of rows in the figure
num_col (int, optional): number of columns in the figure
See also:
`NetworkX graph layout`_
.. _NetworkX graph layout:
https://networkx.github.io/documentation/stable/reference/drawing.html#module-networkx.drawing.layout
"""
if titles is None:
graph = self.get_item(0)
titles = ["%s %d" % (type(graph).__name__, i) for i in range(self.batch_size)]
if num_col is None:
if num_row is None:
num_col = math.ceil(self.batch_size ** 0.5)
else:
num_col = math.ceil(self.batch_size / num_row)
if num_row is None:
num_row = math.ceil(self.batch_size / num_col)
figure_size = (num_col * figure_size[0], num_row * figure_size[1])
fig = plt.figure(figsize=figure_size)
for i in range(self.batch_size):
graph = self.get_item(i)
ax = fig.add_subplot(num_row, num_col, i + 1)
graph.visualize(title=titles[i], ax=ax, layout=layout)
# remove the space of axis labels
fig.tight_layout()
if save_file:
fig.savefig(save_file)
else:
fig.show()
Graph.packed_type = PackedGraph
def cat(graphs):
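"""Concatenate a list of graphs and/or packed graphs into a single packed graph, keeping only the attributes shared by all inputs (module-level helper)."""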
for i, graph in enumerate(graphs):
if not isinstance(graph, PackedGraph):
graphs[i] = graph.pack([graph])
edge_list = torch.cat([graph.edge_list for graph in graphs])
pack_num_nodes = torch.stack([graph.num_node for graph in graphs])
pack_num_edges = torch.stack([graph.num_edge for graph in graphs])
pack_num_cum_edges = pack_num_edges.cumsum(0)
graph_index = pack_num_cum_edges < len(edge_list)
pack_offsets = scatter_add(pack_num_nodes[graph_index], pack_num_cum_edges[graph_index],
dim_size=len(edge_list))
pack_offsets = pack_offsets.cumsum(0)
edge_list[:, :2] += pack_offsets.unsqueeze(-1)
offsets = torch.cat([graph._offsets for graph in graphs]) + pack_offsets
edge_weight = torch.cat([graph.edge_weight for graph in graphs])
num_nodes = torch.cat([graph.num_nodes for graph in graphs])
num_edges = torch.cat([graph.num_edges for graph in graphs])
num_relation = graphs[0].num_relation
assert all(graph.num_relation == num_relation for graph in graphs)
# only keep attributes that exist in all graphs
# TODO: this interface is not safe. re-design the interface
keys = set(graphs[0].meta_dict.keys())
for graph in graphs:
keys = keys.intersection(graph.meta_dict.keys())
meta_dict = {k: graphs[0].meta_dict[k] for k in keys}
data_dict = {}
for k in keys:
data_dict[k] = torch.cat([graph.data_dict[k] for graph in graphs])
return type(graphs[0])(edge_list, edge_weight=edge_weight,
num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation, offsets=offsets,
meta_dict=meta_dict, **data_dict)