import copy
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter_min
from torchdrug import core, tasks, layers
from torchdrug.data import constant
from torchdrug.layers import functional
from torchdrug.core import Registry as R
[docs]@R.register("tasks.EdgePrediction")
class EdgePrediction(tasks.Task, core.Configurable):
    """
    Edge prediction task proposed in `Inductive Representation Learning on Large Graphs`_.
    .. _Inductive Representation Learning on Large Graphs:
        https://arxiv.org/abs/1706.02216
    Parameters:
        model (nn.Module): node representation model
    """
    def __init__(self, model):
        super(EdgePrediction, self).__init__()
        self.model = model

    def _get_directed(self, graph):
        mask = graph.edge_list[:, 0] < graph.edge_list[:, 1]
        graph = graph.edge_mask(mask)
        return graph

    def predict(self, batch, all_loss=None, metric=None):
        graph = batch["graph"]
        output = self.model(graph, graph.node_feature.float(), all_loss, metric)
        node_feature = output["node_feature"]
        graph = self._get_directed(graph)
        node_in, node_out = graph.edge_list.t()[:2]
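        # negative sampling: for every positive edge, draw a random node pair
        # from the same graph (head and tail sampled independently)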
        neg_index = (torch.rand(2, graph.num_edge, device=self.device) * graph.num_nodes[graph.edge2graph]).long()
        neg_index = neg_index + (graph.num_cum_nodes - graph.num_nodes)[graph.edge2graph]
        node_in = torch.cat([node_in, neg_index[0]])
        node_out = torch.cat([node_out, neg_index[1]])
        pred = torch.einsum("bd, bd -> b", node_feature[node_in], node_feature[node_out])
        return pred

    def target(self, batch):
        graph = batch["graph"]
        target = torch.ones(graph.num_edge, device=self.device)
        target[len(target) // 2:] = 0
        return target

    def evaluate(self, pred, target):
        metric = {}
        accuracy = ((pred > 0) == (target > 0.5)).float().mean()
        name = tasks._get_metric_name("acc")
        metric[name] = accuracy
        return metric

    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}
        pred = self.predict(batch, all_loss, metric)
        target = self.target(batch)
        loss = F.binary_cross_entropy_with_logits(pred, target)
        name = tasks._get_criterion_name("bce")
        metric[name] = loss
        metric.update(self.evaluate(pred, target))
        all_loss += loss
        return all_loss, metric 
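
# A minimal usage sketch (hypothetical variable names; any torchdrug node
# representation model with an ``output_dim``, e.g. ``models.GIN``, should fit):
#
#   from torchdrug import models
#   model = models.GIN(input_dim=dataset.node_feature_dim, hidden_dims=[256, 256])
#   task = EdgePrediction(model)
#   loss, metric = task({"graph": graphs})  # graphs: a packed batch of graphs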
[docs]@R.register("tasks.AttributeMasking")
class AttributeMasking(tasks.Task, core.Configurable):
    """
    Attribute masking proposed in `Strategies for Pre-training Graph Neural Networks`_.
    .. _Strategies for Pre-training Graph Neural Networks:
        https://arxiv.org/abs/1905.12265
    Parameters:
        model (nn.Module): node representation model
        mask_rate (float, optional): rate of masked nodes
        num_mlp_layer (int, optional): number of MLP layers
    """
    def __init__(self, model, mask_rate=0.15, num_mlp_layer=2, graph_construction_model=None):
        super(AttributeMasking, self).__init__()
        self.model = model
        self.mask_rate = mask_rate
        self.num_mlp_layer = num_mlp_layer
        self.graph_construction_model = graph_construction_model

    def preprocess(self, train_set, valid_set, test_set):
        data = train_set[0]
        self.view = getattr(data["graph"], "view", "atom")
        if hasattr(self.model, "node_output_dim"):
            model_output_dim = self.model.node_output_dim
        else:
            model_output_dim = self.model.output_dim
        if self.view == "atom":
            num_label = constant.NUM_ATOM
        else:
            num_label = constant.NUM_AMINO_ACID
        self.mlp = layers.MLP(model_output_dim, [model_output_dim] * (self.num_mlp_layer - 1) + [num_label])

    def predict_and_target(self, batch, all_loss=None, metric=None):
        graph = batch["graph"]
        if self.graph_construction_model:
            graph = self.graph_construction_model.apply_node_layer(graph)
        num_nodes = graph.num_nodes if self.view in ["atom", "node"] else graph.num_residues
        num_cum_nodes = num_nodes.cumsum(0)
        num_samples = (num_nodes * self.mask_rate).long().clamp(1)
        num_sample = num_samples.sum()
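        # map each masked slot to its graph, then draw a random local node
        # index and offset it into the packed node numbering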
        sample2graph = torch.repeat_interleave(num_samples)
        node_index = (torch.rand(num_sample, device=self.device) * num_nodes[sample2graph]).long()
        node_index = node_index + (num_cum_nodes - num_nodes)[sample2graph]
        if self.view == "atom":
            target = graph.atom_type[node_index]
            input = graph.node_feature.float()
            input[node_index] = 0
        else:
            target = graph.residue_type[node_index]
            with graph.residue():
                graph.residue_feature[node_index] = 0
                graph.residue_type[node_index] = 0
            # Generate masked edge features. Any better implementation?
            if self.graph_construction_model:
                graph = self.graph_construction_model.apply_edge_layer(graph)
            input = graph.residue_feature.float()
        output = self.model(graph, input, all_loss, metric)
        if self.view in ["node", "atom"]:
            node_feature = output["node_feature"]
        else:
            node_feature = output.get("residue_feature", output.get("node_feature"))
        node_feature = node_feature[node_index]
        pred = self.mlp(node_feature)
        return pred, target

    def evaluate(self, pred, target):
        metric = {}
        accuracy = (pred.argmax(dim=-1) == target).float().mean()
        name = tasks._get_metric_name("acc")
        metric[name] = accuracy
        return metric

    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}
        pred, target = self.predict_and_target(batch, all_loss, metric)
        metric.update(self.evaluate(pred, target))
        loss = F.cross_entropy(pred, target)
        name = tasks._get_criterion_name("ce")
        metric[name] = loss
        all_loss += loss
        return all_loss, metric 
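
# A minimal usage sketch (hypothetical names; ``preprocess`` must run before
# ``forward`` so the prediction head matches the label space, which the
# torchdrug ``core.Engine`` normally does for you):
#
#   task = AttributeMasking(model, mask_rate=0.15, num_mlp_layer=2)
#   task.preprocess(train_set, valid_set, test_set)
#   loss, metric = task({"graph": graphs})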
[docs]@R.register("tasks.ContextPrediction")
class ContextPrediction(tasks.Task, core.Configurable):
    """
    Context prediction task proposed in `Strategies for Pre-training Graph Neural Networks`_.
    .. _Strategies for Pre-training Graph Neural Networks:
        https://arxiv.org/abs/1905.12265
    For a given center node, the subgraph is defined as a k-hop neighborhood (inclusive) around the selected node.
    The context graph is defined as the surrounding graph structure between r1- (exclusive) and r2-hop (inclusive)
    from the center node. Nodes between k- and r1-hop are picked as anchor nodes for the context representation.
    Parameters:
        model (nn.Module): node representation model for subgraphs.
        context_model (nn.Module, optional): node representation model for context graphs.
            By default, use the same architecture as ``model`` without parameter sharing.
        k (int, optional): radius for subgraphs
        r1 (int, optional): inner radius for context graphs
        r2 (int, optional): outer radius for context graphs
        readout (nn.Module, optional): readout function over context anchor nodes
        num_negative (int, optional): number of negative samples per positive sample
    """
    def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1):
        super(ContextPrediction, self).__init__()
        self.model = model
        self.k = k
        self.r1 = r1
        self.r2 = r2
        self.num_negative = num_negative
        assert r1 < k < r2
        if context_model is None:
            self.context_model = copy.deepcopy(model)
        else:
            self.context_model = context_model
        if readout == "sum":
            self.readout = layers.SumReadout()
        elif readout == "mean":
            self.readout = layers.MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def substruct_and_context(self, graph):
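        # sample one center node per graph, then offset the local index into
        # the packed node numbering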
        center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long()
        center_index = center_index + graph.num_cum_nodes - graph.num_nodes
        dist = torch.full((graph.num_node,), self.r2 + 1, dtype=torch.long, device=self.device)
        dist[center_index] = 0
        # single-source shortest paths from the center nodes, computed in
        # parallel by iterative edge relaxation
        node_in, node_out = graph.edge_list.t()[:2]
        for i in range(self.r2):
            new_dist = scatter_min(dist[node_in], node_out, dim_size=graph.num_node)[0] + 1
            dist = torch.min(dist, new_dist)
        substruct_mask = dist <= self.k
        context_mask = (dist > self.r1) & (dist <= self.r2)
        is_center_node = functional.as_mask(center_index, graph.num_node)
        is_anchor_node = (dist > self.r1) & (dist <= self.k)
        substruct = graph.clone()
        context = graph.clone()
        with substruct.node():
            substruct.is_center_node = is_center_node
        with context.node():
            context.is_anchor_node = is_anchor_node
        substruct = substruct.subgraph(substruct_mask)
        context = context.subgraph(context_mask)
        valid = context.num_nodes > 0
        substruct = substruct[valid]
        context = context[valid]
        return substruct, context

    def predict_and_target(self, batch, all_loss=None, metric=None):
        graph = batch["graph"]
        substruct, context = self.substruct_and_context(graph)
        anchor = context.subgraph(context.is_anchor_node)
        substruct_output = self.model(substruct, substruct.node_feature.float(), all_loss, metric)
        substruct_feature = substruct_output["node_feature"][substruct.is_center_node]
        context_output = self.context_model(context, context.node_feature.float(), all_loss, metric)
        anchor_feature = context_output["node_feature"][context.is_anchor_node]
        context_feature = self.readout(anchor, anchor_feature)
        shift = torch.arange(self.num_negative, device=self.device) + 1
        neg_index = (torch.arange(len(context), device=self.device).unsqueeze(-1) + shift) % len(context) # (batch_size, num_negative)
        context_feature = torch.cat([context_feature.unsqueeze(1), context_feature[neg_index]], dim=1)
        substruct_feature = substruct_feature.unsqueeze(1).expand_as(context_feature)
        pred = torch.einsum("bnd, bnd -> bn", substruct_feature, context_feature)
        target = torch.zeros_like(pred)
        target[:, 0] = 1
        return pred, target

    def evaluate(self, pred, target):
        metric = {}
        accuracy = ((pred > 0) == (target > 0.5)).float().mean()
        name = tasks._get_metric_name("acc")
        metric[name] = accuracy
        return metric

    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}
        pred, target = self.predict_and_target(batch, all_loss, metric)
        metric.update(self.evaluate(pred, target))
        loss = F.binary_cross_entropy_with_logits(pred, target)
        name = tasks._get_criterion_name("bce")
        metric[name] = loss
        all_loss += loss
        return all_loss, metric 
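
# A minimal usage sketch (hypothetical names; ``context_model`` defaults to a
# deep copy of ``model``, so a single encoder definition suffices):
#
#   task = ContextPrediction(model, k=5, r1=4, r2=7, num_negative=1)
#   loss, metric = task({"graph": graphs})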
[docs]@R.register("tasks.DistancePrediction")
class DistancePrediction(tasks.Task, core.Configurable):
    """
    Pairwise spatial distance prediction task proposed in
    `Protein Representation Learning by Geometric Structure Pretraining`_.
    .. _Protein Representation Learning by Geometric Structure Pretraining:
        https://arxiv.org/pdf/2203.06125.pdf
    Randomly select some edges and predict the lengths of the edges using the representations of two nodes.
    The selected edges are removed from the input graph to prevent trivial solutions.
    Parameters:
        model (nn.Module): node representation model
        num_sample (int, optional): number of edges selected from each graph
        num_mlp_layer (int, optional): number of MLP layers in distance predictor
        graph_construction_model (nn.Module, optional): graph construction model
    """
    def __init__(self, model, num_sample=256, num_mlp_layer=2, graph_construction_model=None):
        super(DistancePrediction, self).__init__()
        self.model = model
        self.num_sample = num_sample
        self.num_mlp_layer = num_mlp_layer
        self.graph_construction_model = graph_construction_model
        self.mlp = layers.MLP(2 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [1])

    def predict_and_target(self, batch, all_loss=None, metric=None):
        graph = batch["graph"]
        if self.graph_construction_model:
            graph = self.graph_construction_model(graph)
        node_in, node_out = graph.edge_list[:, :2].t()
        indices = torch.arange(graph.num_edge, device=self.device)
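        # draw ``num_sample`` edge indices per graph (with replacement) and
        # flatten the per-graph samples into one index tensor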
        indices = functional.variadic_sample(indices, graph.num_edges, self.num_sample).flatten(-2, -1)
        node_i = node_in[indices]
        node_j = node_out[indices]
        graph = graph.edge_mask(~functional.as_mask(indices, graph.num_edge))
        # Calculate distance
        target = (graph.node_position[node_i] - graph.node_position[node_j]).norm(p=2, dim=-1)
        output = self.model(graph, graph.node_feature.float(), all_loss, metric)["node_feature"]
        node_feature = torch.cat([output[node_i], output[node_j]], dim=-1)
        pred = self.mlp(node_feature).squeeze(-1)
        return pred, target

    def evaluate(self, pred, target):
        metric = {}
        mse = F.mse_loss(pred, target)
        name = tasks._get_metric_name("mse")
        metric[name] = mse
        return metric

    def forward(self, batch):
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}
        pred, target = self.predict_and_target(batch, all_loss, metric)
        metric.update(self.evaluate(pred, target))
        loss = F.mse_loss(pred, target)
        name = tasks._get_criterion_name("mse")
        metric[name] = loss
        all_loss += loss
        return all_loss, metric  
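
# A minimal usage sketch (hypothetical names; the batch must carry 3D
# coordinates in ``node_position``, e.g. proteins, and the optional graph
# construction model can rebuild spatial edges on the fly):
#
#   task = DistancePrediction(model, num_sample=256, num_mlp_layer=2)
#   loss, metric = task({"graph": proteins})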
[docs]@R.register("tasks.AnglePrediction")
class AnglePrediction(tasks.Task, core.Configurable):
    """
    Angle prediction task proposed in `Protein Representation Learning by Geometric Structure Pretraining`_.
    .. _Protein Representation Learning by Geometric Structure Pretraining:
        https://arxiv.org/pdf/2203.06125.pdf
    Randomly select pairs of adjacent edges and predict the angles between them using the representations of three
    nodes. The selected edges are removed from the input graph to prevent trivial solutions.
    Parameters:
        model (nn.Module): node representation model
        num_sample (int, optional): number of edge pairs selected from each graph
        num_class (int, optional): number of classes to discretize the angles
        num_mlp_layer (int, optional): number of MLP layers in angle predictor
        graph_construction_model (nn.Module, optional): graph construction model
    """
    def __init__(self, model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None):
        super(AnglePrediction, self).__init__()
        self.model = model
        self.num_sample = num_sample
        self.num_mlp_layer = num_mlp_layer
        self.graph_construction_model = graph_construction_model
        boundary = torch.arange(0, math.pi, math.pi / num_class)
        self.register_buffer("boundary", boundary)
        self.mlp = layers.MLP(3 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [num_class])

    def predict_and_target(self, batch, all_loss=None, metric=None):
        graph = batch["graph"]
        if self.graph_construction_model:
            graph = self.graph_construction_model(graph)
        node_in, node_out = graph.edge_list[:, :2].t()
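        # Nodes of the line graph are edges of the original graph; a line-graph
        # edge (e1, e2) means e1 ends where e2 starts. Remove pairs that reuse
        # the same edge or immediately backtrack, so every remaining pair
        # (k->j, j->i) forms a proper angle at node j.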
        line_graph = graph.line_graph()
        edge_in, edge_out = line_graph.edge_list[:, :2].t()
        is_self_loop1 = (edge_in == edge_out)
        is_self_loop2 = (node_in[edge_in] == node_out[edge_out])
        is_remove = is_self_loop1 | is_self_loop2
        line_graph = line_graph.edge_mask(~is_remove)
        edge_in, edge_out = line_graph.edge_list[:, :2].t()
        # (k->j) - (j->i)
        node_i = node_out[edge_out]
        node_j = node_in[edge_out]
        node_k = node_in[edge_in]
        indices = torch.arange(line_graph.num_edge, device=self.device)
        indices = functional.variadic_sample(indices, line_graph.num_edges, self.num_sample).flatten(-2, -1)
        node_i = node_i[indices]
        node_j = node_j[indices]
        node_k = node_k[indices]
        mask = torch.ones((graph.num_edge,), device=graph.device, dtype=torch.bool)
        mask[edge_out[indices]] = 0
        mask[edge_in[indices]] = 0
        graph = graph.edge_mask(mask)
        # Calculate angles
        vector1 = graph.node_position[node_i] - graph.node_position[node_j]
        vector2 = graph.node_position[node_k] - graph.node_position[node_j]
        x = (vector1 * vector2).sum(dim=-1)
        y = torch.cross(vector1, vector2, dim=-1).norm(dim=-1)
        angle = torch.atan2(y, x)
        target = torch.bucketize(angle, self.boundary, right=True) - 1
        output = self.model(graph, graph.node_feature.float(), all_loss, metric)["node_feature"]
        node_feature = torch.cat([output[node_i], output[node_j], output[node_k]], dim=-1)
        pred = self.mlp(node_feature)
        return pred, target

    def evaluate(self, pred, target):
        metric = {}
        accuracy = (pred.argmax(dim=-1) == target).float().mean()
        name = tasks._get_metric_name("acc")
        metric[name] = accuracy
        return metric

    def forward(self, batch):
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}
        pred, target = self.predict_and_target(batch, all_loss, metric)
        metric.update(self.evaluate(pred, target))
        loss = F.cross_entropy(pred, target)
        name = tasks._get_criterion_name("ce")
        metric[name] = loss
        all_loss += loss
        return all_loss, metric  
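
# A minimal usage sketch (hypothetical names; as above, the graphs need
# ``node_position``, and ``num_class`` controls how finely angles are binned):
#
#   task = AnglePrediction(model, num_sample=256, num_class=8)
#   loss, metric = task({"graph": proteins})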
[docs]@R.register("tasks.DihedralPrediction")
class DihedralPrediction(tasks.Task, core.Configurable):
    """
    Dihedral prediction task proposed in `Protein Representation Learning by Geometric Structure Pretraining`_.
    .. _Protein Representation Learning by Geometric Structure Pretraining:
        https://arxiv.org/pdf/2203.06125.pdf
    Randomly select three consecutive edges and predict the dihedrals among them using the representations of four
    nodes. The selected edges are removed from the input graph to prevent trivial solutions.
    Parameters:
        model (nn.Module): node representation model
        num_sample (int, optional): number of edge triplets selected from each graph
        num_class (int, optional): number of classes for discretizing the dihedrals
        num_mlp_layer (int, optional): number of MLP layers in dihedral angle predictor
        graph_construction_model (nn.Module, optional): graph construction model
    """
    def __init__(self, model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None):
        super(DihedralPrediction, self).__init__()
        self.model = model
        self.num_sample = num_sample
        self.num_mlp_layer = num_mlp_layer
        self.graph_construction_model = graph_construction_model
        boundary = torch.arange(0, math.pi, math.pi / num_class)
        self.register_buffer("boundary", boundary)
        self.mlp = layers.MLP(4 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [num_class])

    def predict_and_target(self, batch, all_loss=None, metric=None):
        graph = batch["graph"]
        if self.graph_construction_model:
            graph = self.graph_construction_model(graph)
        node_in, node_out = graph.edge_list[:, :2].t()
        line_graph = graph.line_graph()
        edge_in, edge_out = line_graph.edge_list[:, :2].t()
        is_self_loop1 = (edge_in == edge_out)
        is_self_loop2 = (node_in[edge_in] == node_out[edge_out])
        is_remove = is_self_loop1 | is_self_loop2
        line_graph = line_graph.edge_mask(~is_remove)
        edge_in, edge_out = line_graph.edge_list[:, :2].t()
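        # Apply the construction once more: nodes of the second-order line
        # graph are the edge pairs above, and its edges correspond to paths of
        # three consecutive edges (k->t, t->j, j->i) in the original graph.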
        line2_graph = line_graph.line_graph()
        edge2_in, edge2_out = line2_graph.edge_list[:, :2].t()
        is_self_loop1 = (edge2_in == edge2_out)
        is_self_loop2 = (edge_in[edge2_in] == edge_out[edge2_out])
        is_remove = is_self_loop1 | is_self_loop2
        line2_graph = line2_graph.edge_mask(~is_remove)
        edge2_in, edge2_out = line2_graph.edge_list[:, :2].t()
        # (k->t->j) - (t->j->i)
        node_i = node_out[edge_out[edge2_out]]
        node_j = node_in[edge_out[edge2_out]]
        node_t = node_in[edge_out[edge2_in]]
        node_k = node_in[edge_in[edge2_in]]
        indices = torch.arange(line2_graph.num_edge, device=self.device)
        indices = functional.variadic_sample(indices, line2_graph.num_edges, self.num_sample).flatten(-2, -1)
        node_i = node_i[indices]
        node_j = node_j[indices]
        node_t = node_t[indices]
        node_k = node_k[indices]
        mask = torch.ones((graph.num_edge,), device=graph.device, dtype=torch.bool)
        mask[edge_out[edge2_out[indices]]] = 0
        mask[edge_out[edge2_in[indices]]] = 0
        mask[edge_in[edge2_in[indices]]] = 0
        graph = graph.edge_mask(mask)
        v_ctr = graph.node_position[node_t] - graph.node_position[node_j]   # (A, 3)
        v1 = graph.node_position[node_i] - graph.node_position[node_j]
        v2 = graph.node_position[node_k] - graph.node_position[node_t]
        n1 = torch.cross(v_ctr, v1, dim=-1) # Normal vectors of the two planes
        n2 = torch.cross(v_ctr, v2, dim=-1)
        a = (n1 * n2).sum(dim=-1)
        b = torch.cross(n1, n2, dim=-1).norm(dim=-1)
        dihedral = torch.atan2(b, a)
        target = torch.bucketize(dihedral, self.boundary, right=True) - 1
        output = self.model(graph, graph.node_feature.float(), all_loss, metric)["node_feature"]
        node_feature = torch.cat([output[node_i], output[node_j], output[node_k], output[node_t]], dim=-1)
        pred = self.mlp(node_feature)
        return pred, target

    def evaluate(self, pred, target):
        metric = {}
        accuracy = (pred.argmax(dim=-1) == target).float().mean()
        name = tasks._get_metric_name("acc")
        metric[name] = accuracy
        return metric

    def forward(self, batch):
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}
        pred, target = self.predict_and_target(batch, all_loss, metric)
        metric.update(self.evaluate(pred, target))
        loss = F.cross_entropy(pred, target)
        name = tasks._get_criterion_name("ce")
        metric[name] = loss
        all_loss += loss
        return all_loss, metric
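
# A minimal usage sketch (hypothetical names; mirrors the angle task, but bins
# dihedral angles computed over edge triplets):
#
#   task = DihedralPrediction(model, num_sample=256, num_class=8)
#   loss, metric = task({"graph": proteins})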