import os
import torch
from torch import autograd
from torchdrug import utils
backend = "fast"
path = os.path.join(os.path.dirname(__file__), "extension")
embedding = utils.load_extension("embedding",
[os.path.join(path, "embedding.cpp"), os.path.join(path, "embedding.cu")])
class TransEFunction(autograd.Function):
@staticmethod
def forward(ctx, entity, relation, h_index, t_index, r_index):
if entity.device.type == "cuda":
forward = embedding.transe_forward_cuda
else:
forward = embedding.transe_forward_cpu
score = forward(entity, relation, h_index, t_index, r_index)
ctx.save_for_backward(entity, relation, h_index, t_index, r_index)
return score
@staticmethod
def backward(ctx, score_grad):
if score_grad.device.type == "cuda":
backward = embedding.transe_backward_cuda
else:
backward = embedding.transe_backward_cpu
entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad)
return entity_grad, relation_grad, None, None, None
class DistMultFunction(autograd.Function):
@staticmethod
def forward(ctx, entity, relation, h_index, t_index, r_index):
if entity.device.type == "cuda":
forward = embedding.distmult_forward_cuda
else:
forward = embedding.distmult_forward_cpu
score = forward(entity, relation, h_index, t_index, r_index)
ctx.save_for_backward(entity, relation, h_index, t_index, r_index)
return score
@staticmethod
def backward(ctx, score_grad):
if score_grad.device.type == "cuda":
backward = embedding.distmult_backward_cuda
else:
backward = embedding.distmult_backward_cpu
entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad)
return entity_grad, relation_grad, None, None, None
class ComplExFunction(autograd.Function):
@staticmethod
def forward(ctx, entity, relation, h_index, t_index, r_index):
if entity.device.type == "cuda":
forward = embedding.complex_forward_cuda
else:
forward = embedding.complex_forward_cpu
score = forward(entity, relation, h_index, t_index, r_index)
ctx.save_for_backward(entity, relation, h_index, t_index, r_index)
return score
@staticmethod
def backward(ctx, score_grad):
if score_grad.device.type == "cuda":
backward = embedding.complex_backward_cuda
else:
backward = embedding.complex_backward_cpu
entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad)
return entity_grad, relation_grad, None, None, None
class SimplEFunction(autograd.Function):
@staticmethod
def forward(ctx, entity, relation, h_index, t_index, r_index):
if entity.device.type == "cuda":
forward = embedding.simple_forward_cuda
else:
forward = embedding.simple_forward_cpu
score = forward(entity, relation, h_index, t_index, r_index)
ctx.save_for_backward(entity, relation, h_index, t_index, r_index)
return score
@staticmethod
def backward(ctx, score_grad):
if score_grad.device.type == "cuda":
backward = embedding.simple_backward_cuda
else:
backward = embedding.simple_backward_cpu
entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad)
return entity_grad, relation_grad, None, None, None
class RotatEFunction(autograd.Function):
@staticmethod
def forward(ctx, entity, relation, h_index, t_index, r_index):
if entity.device.type == "cuda":
forward = embedding.rotate_forward_cuda
else:
forward = embedding.rotate_forward_cpu
score = forward(entity, relation, h_index, t_index, r_index)
ctx.save_for_backward(entity, relation, h_index, t_index, r_index)
return score
@staticmethod
def backward(ctx, score_grad):
if score_grad.device.type == "cuda":
backward = embedding.rotate_backward_cuda
else:
backward = embedding.rotate_backward_cpu
entity_grad, relation_grad = backward(*ctx.saved_tensors, score_grad)
return entity_grad, relation_grad, None, None, None
[docs]def transe_score(entity, relation, h_index, t_index, r_index):
"""
TransE score function from `Translating Embeddings for Modeling Multi-relational Data`_.
.. _Translating Embeddings for Modeling Multi-relational Data:
https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf
Parameters:
entity (Tensor): entity embeddings of shape :math:`(|V|, d)`
relation (Tensor): relation embeddings of shape :math:`(|R|, d)`
h_index (LongTensor): index of head entities
t_index (LongTensor): index of tail entities
r_index (LongTensor): index of relations
"""
if backend == "native":
h = entity[h_index]
r = relation[r_index]
t = entity[t_index]
score = (h + r - t).norm(p=1, dim=-1)
elif backend == "fast":
score = TransEFunction.apply(entity, relation, h_index, t_index, r_index)
else:
raise ValueError("Unknown embedding backend `%s`" % backend)
return score
[docs]def distmult_score(entity, relation, h_index, t_index, r_index):
"""
DistMult score function from `Embedding Entities and Relations for Learning and Inference in Knowledge Bases`_.
.. _Embedding Entities and Relations for Learning and Inference in Knowledge Bases:
https://arxiv.org/pdf/1412.6575.pdf
Parameters:
entity (Tensor): entity embeddings of shape :math:`(|V|, d)`
relation (Tensor): relation embeddings of shape :math:`(|R|, d)`
h_index (LongTensor): index of head entities
t_index (LongTensor): index of tail entities
r_index (LongTensor): index of relations
"""
if backend == "native":
h = entity[h_index]
r = relation[r_index]
t = entity[t_index]
score = (h * r * t).sum(dim=-1)
elif backend == "fast":
score = DistMultFunction.apply(entity, relation, h_index, t_index, r_index)
else:
raise ValueError("Unknown embedding backend `%s`" % backend)
return score
[docs]def complex_score(entity, relation, h_index, t_index, r_index):
"""
ComplEx score function from `Complex Embeddings for Simple Link Prediction`_.
.. _Complex Embeddings for Simple Link Prediction:
http://proceedings.mlr.press/v48/trouillon16.pdf
Parameters:
entity (Tensor): entity embeddings of shape :math:`(|V|, 2d)`
relation (Tensor): relation embeddings of shape :math:`(|R|, 2d)`
h_index (LongTensor): index of head entities
t_index (LongTensor): index of tail entities
r_index (LongTensor): index of relations
"""
if backend == "native":
h = entity[h_index]
r = relation[r_index]
t = entity[t_index]
h_re, h_im = h.chunk(2, dim=-1)
r_re, r_im = r.chunk(2, dim=-1)
t_re, t_im = t.chunk(2, dim=-1)
x_re = h_re * r_re - h_im * r_im
x_im = h_re * r_im + h_im * r_re
x = x_re * t_re + x_im * t_im
score = x.sum(dim=-1)
elif backend == "fast":
score = ComplExFunction.apply(entity, relation, h_index, t_index, r_index)
else:
raise ValueError("Unknown embedding backend `%s`" % backend)
return score
[docs]def simple_score(entity, relation, h_index, t_index, r_index):
"""
SimplE score function from `SimplE Embedding for Link Prediction in Knowledge Graphs`_.
.. _SimplE Embedding for Link Prediction in Knowledge Graphs:
https://papers.nips.cc/paper/2018/file/b2ab001909a8a6f04b51920306046ce5-Paper.pdf
Parameters:
entity (Tensor): entity embeddings of shape :math:`(|V|, 2d)`
relation (Tensor): relation embeddings of shape :math:`(|R|, d)`
h_index (LongTensor): index of head entities
t_index (LongTensor): index of tail entities
r_index (LongTensor): index of relations
"""
if backend == "native":
h = entity[h_index]
r = relation[r_index]
t = entity[t_index]
t_flipped = torch.cat(t.chunk(2, dim=-1)[::-1], dim=-1)
score = (h * r * t_flipped).sum(dim=-1)
elif backend == "fast":
score = SimplEFunction.apply(entity, relation, h_index, t_index, r_index)
else:
raise ValueError("Unknown embedding backend `%s`" % backend)
return score
[docs]def rotate_score(entity, relation, h_index, t_index, r_index):
"""
RotatE score function from `RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space`_.
.. _RotatE\: Knowledge Graph Embedding by Relational Rotation in Complex Space:
https://arxiv.org/pdf/1902.10197.pdf
Parameters:
entity (Tensor): entity embeddings of shape :math:`(|V|, 2d)`
relation (Tensor): relation embeddings of shape :math:`(|R|, d)`
h_index (LongTensor): index of head entities
t_index (LongTensor): index of tail entities
r_index (LongTensor): index of relations
"""
if backend == "native":
h = entity[h_index]
r = relation[r_index]
t = entity[t_index]
h_re, h_im = h.chunk(2, dim=-1)
r_re, r_im = torch.cos(r), torch.sin(r)
t_re, t_im = t.chunk(2, dim=-1)
x_re = h_re * r_re - h_im * r_im - t_re
x_im = h_re * r_im + h_im * r_re - t_im
x = torch.stack([x_re, x_im], dim=-1)
score = x.norm(p=2, dim=-1).sum(dim=-1)
elif backend == "fast":
score = RotatEFunction.apply(entity, relation, h_index, t_index, r_index)
else:
raise ValueError("Unknown embedding backend `%s`" % backend)
return score