Source code for torchdrug.layers.functional.spmm

import os
import sys

import torch
from torch import autograd

from torchdrug import utils

module = sys.modules[__name__]

path = os.path.join(os.path.dirname(__file__), "extension")
spmm = utils.load_extension("spmm", [os.path.join(path, "spmm.cpp"), os.path.join(path, "rspmm.cpp"),
                                     os.path.join(path, "spmm.cu"), os.path.join(path, "rspmm.cu")])


class SPMMAddMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.spmm_add_mul_forward_cuda
        else:
            forward = spmm.spmm_add_mul_forward_cpu
        output = forward(sparse, input)
        ctx.save_for_backward(sparse, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.spmm_add_mul_backward_cuda
        else:
            backward = spmm.spmm_add_mul_backward_cpu
        sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, input_grad


class SPMMMinMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.spmm_min_mul_forward_cuda
        else:
            forward = spmm.spmm_min_mul_forward_cpu
        output = forward(sparse, input)
        ctx.save_for_backward(sparse, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.spmm_min_mul_backward_cuda
        else:
            backward = spmm.spmm_min_mul_backward_cpu
        sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, input_grad


class SPMMMaxMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.spmm_max_mul_forward_cuda
        else:
            forward = spmm.spmm_max_mul_forward_cpu
        output = forward(sparse, input)
        ctx.save_for_backward(sparse, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.spmm_max_mul_backward_cuda
        else:
            backward = spmm.spmm_max_mul_backward_cpu
        sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, input_grad


class SPMMAddAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.spmm_add_add_forward_cuda
        else:
            forward = spmm.spmm_add_add_forward_cpu
        output = forward(sparse, input)
        ctx.save_for_backward(sparse, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.spmm_add_add_backward_cuda
        else:
            backward = spmm.spmm_add_add_backward_cpu
        sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, input_grad


class SPMMMinAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.spmm_min_add_forward_cuda
        else:
            forward = spmm.spmm_min_add_forward_cpu
        output = forward(sparse, input)
        ctx.save_for_backward(sparse, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.spmm_min_add_backward_cuda
        else:
            backward = spmm.spmm_min_add_backward_cpu
        sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, input_grad


class SPMMMaxAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.spmm_max_add_forward_cuda
        else:
            forward = spmm.spmm_max_add_forward_cpu
        output = forward(sparse, input)
        ctx.save_for_backward(sparse, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.spmm_max_add_backward_cuda
        else:
            backward = spmm.spmm_max_add_backward_cpu
        sparse_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, input_grad


class RSPMMAddMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, relation, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.rspmm_add_mul_forward_cuda
        else:
            forward = spmm.rspmm_add_mul_forward_cpu
        output = forward(sparse, relation, input)
        ctx.save_for_backward(sparse, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.rspmm_add_mul_backward_cuda
        else:
            backward = spmm.rspmm_add_mul_backward_cpu
        sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, relation_grad, input_grad


class RSPMMMinMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, relation, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.rspmm_min_mul_forward_cuda
        else:
            forward = spmm.rspmm_min_mul_forward_cpu
        output = forward(sparse, relation, input)
        ctx.save_for_backward(sparse, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.rspmm_min_mul_backward_cuda
        else:
            backward = spmm.rspmm_min_mul_backward_cpu
        sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, relation_grad, input_grad


class RSPMMMaxMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, relation, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.rspmm_max_mul_forward_cuda
        else:
            forward = spmm.rspmm_max_mul_forward_cpu
        output = forward(sparse, relation, input)
        ctx.save_for_backward(sparse, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.rspmm_max_mul_backward_cuda
        else:
            backward = spmm.rspmm_max_mul_backward_cpu
        sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, relation_grad, input_grad


class RSPMMAddAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, relation, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.rspmm_add_add_forward_cuda
        else:
            forward = spmm.rspmm_add_add_forward_cpu
        output = forward(sparse, relation, input)
        ctx.save_for_backward(sparse, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.rspmm_add_add_backward_cuda
        else:
            backward = spmm.rspmm_add_add_backward_cpu
        sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, relation_grad, input_grad


class RSPMMMinAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, relation, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.rspmm_min_add_forward_cuda
        else:
            forward = spmm.rspmm_min_add_forward_cpu
        output = forward(sparse, relation, input)
        ctx.save_for_backward(sparse, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.rspmm_min_add_backward_cuda
        else:
            backward = spmm.rspmm_min_add_backward_cpu
        sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, relation_grad, input_grad


class RSPMMMaxAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, sparse, relation, input):
        assert sparse.is_coalesced()
        if input.device.type == "cuda":
            forward = spmm.rspmm_max_add_forward_cuda
        else:
            forward = spmm.rspmm_max_add_forward_cpu
        output = forward(sparse, relation, input)
        ctx.save_for_backward(sparse, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = spmm.rspmm_max_add_backward_cuda
        else:
            backward = spmm.rspmm_max_add_backward_cpu
        sparse_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        if not ctx.saved_tensors[0].requires_grad:
            sparse_grad = None
        return sparse_grad, relation_grad, input_grad


[docs]def generalized_spmm(sparse, input, sum="add", mul="mul"): r""" Generalized sparse-dense matrix multiplication. This function computes the matrix multiplication of a sparse matrix and a dense input matrix. The output dense matrix satisfies .. math:: output_{i,k} = \bigoplus_{j: sparse_{i,j} \neq 0} sparse_{i,j} \otimes input_{j,k} where :math:`\oplus` and :math:`\otimes` are the summation and the multiplication operators respectively. .. warning:: Gradient w.r.t. the sparse matrix is only computed for non-zero entries of the sparse matrix. This behaves differently from dense-dense matrix multiplication with zero entries. Parameters: sparse (SparseTensor): 2D sparse tensor input (Tensor): 2D dense tensor sum (str, optional): generalized summation operator. Available operators are ``add``, ``min`` and ``max``. mul (str, optional): generalized multiplication operator. Available operators are ``add`` and ``mul``. """ name = "SPMM%s%sFunction" % (sum.capitalize(), mul.capitalize()) if not hasattr(module, name): raise ValueError("No generalized spmm implementation found for summation `%s` and multiplication `%s`" % (sum, mul)) Function = getattr(module, name) return Function.apply(sparse.coalesce(), input)
[docs]def generalized_rspmm(sparse, relation, input, sum="add", mul="mul"): r""" Generalized relational sparse-dense matrix multiplication. This function computes the matrix multiplication of a sparse matrix, a dense relation matrix and a dense input matrix. The output dense matrix satisfies .. math:: output_{i,l} = \bigoplus_{j,k: sparse_{i,j,k} \neq 0} sparse_{i, j, k} \times (relation_{k,l} \otimes input_{j,l}) where :math:`\oplus` and :math:`\otimes` are the summation and the multiplication operators respectively. .. warning:: Gradient w.r.t. the sparse matrix is only computed for non-zero entries of the sparse matrix. This behaves differently from dense-dense matrix multiplication with zero entries. Parameters: sparse (SparseTensor): 3D sparse tensor relation (Tensor): 2D dense tensor input (Tensor): 2D dense tensor sum (str, optional): generalized summation operator. Available operators are ``add``, ``min`` and ``max``. mul (str, optional): generalized multiplication operator. Available operators are ``add`` and ``mul``. """ name = "RSPMM%s%sFunction" % (sum.capitalize(), mul.capitalize()) if not hasattr(module, name): raise ValueError("No generalized rspmm implementation found for summation `%s` and multiplication `%s`" % (sum, mul)) Function = getattr(module, name) return Function.apply(sparse.coalesce(), relation, input)