import re
import inspect
import warnings
import functools
from decorator import decorator
import torch
from torch import nn
from torchdrug import data
def copy_args(obj, args=None, ignore=None):
    r"""
    Copy argument documentation from another function to fill the document of \*\*kwargs in this function.
    This function should be applied as a decorator.

    Parameters of ``obj`` without a default replace ``*args`` of the decorated function,
    and parameters with a default replace ``**kwargs``. Both the signature
    (``__signature__``) and the ``Parameters:`` section of the decorated function are rewritten.

    Parameters:
        obj (object): object to copy document from
        args (tuple of str, optional): arguments to copy.
            By default, it copies all argument documentation from ``obj``,
            except those that already exist in the current function.
        ignore (tuple of str, optional): arguments to ignore

    Raises:
        ValueError: if the number of documented parameters does not match the signature
    """
    def split_signature(target):
        # signature of ``target`` plus its parameter list without the bound ``cls`` / ``self``,
        # which is never documented
        sig = get_signature(target)
        parameters = list(sig.parameters.values())
        if parameters and parameters[0].name in ("cls", "self"):
            parameters.pop(0)
        return sig, parameters

    def wrapper(target):
        sig, parameters = split_signature(target)
        docs = get_param_docs(target)
        if len(docs) != len(parameters):
            raise ValueError("Fail to parse the docstring of `%s`. "
                             "Inconsistent number of parameters in signature and docstring." % target.__name__)
        param_names = {p.name for p in parameters}
        new_params = []
        new_docs = []
        for param, doc in zip(parameters, docs):
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                copied = from_args
            elif param.kind == inspect.Parameter.VAR_KEYWORD:
                copied = from_kwargs
            else:
                new_params.append(param)
                new_docs.append(doc)
                continue
            # the variadic parameter itself is replaced by the copied parameters,
            # skipping any name the decorated function already defines
            for source in copied:
                if source.name in param_names:
                    continue
                new_params.append(source)
                new_docs.append(from_docs[source.name])
        set_signature(target, sig.replace(parameters=new_params))
        set_param_docs(target, new_docs)
        return target

    from_obj = obj
    if args is not None:
        args = set(args)
    if ignore is not None:
        ignore = set(ignore)
    _, parameters = split_signature(from_obj)
    from_args = []
    from_kwargs = []
    for param in parameters:
        if args is not None and param.name not in args:
            continue
        if ignore is not None and param.name in ignore:
            continue
        # ``*args`` / ``**kwargs`` of the source cannot be expanded into named parameters
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        # identity check: ``==`` on array-like defaults would be ambiguous
        if param.default is inspect.Parameter.empty:
            from_args.append(param)
        else:
            from_kwargs.append(param)
    from_docs = get_param_docs(from_obj, as_dict=True)
    if len(from_docs) != len(parameters):
        raise ValueError("Fail to parse the docstring of `%s`. "
                         "Inconsistent number of parameters in signature and docstring." % from_obj.__name__)
    return wrapper
class cached_property(property):
    """
    Cache the property once computed.

    On first access the wrapped function is called and its result is stored in the
    instance ``__dict__`` under the function's name. Later accesses return that stored
    value. Remove the entry (``obj.__dict__.pop(name, None)``) to force recomputation.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls=None):
        if obj is None:
            return self
        name = self.func.__name__
        cache = obj.__dict__
        # ``property`` is a data descriptor, so attribute lookup never falls back to the
        # instance ``__dict__`` by itself; consult the cache explicitly
        if name in cache:
            return cache[name]
        result = self.func(obj)
        cache[name] = result
        return result
def cached(forward=None, debug=False):
    """
    Cache the result of last function call.

    Intended for methods of an :class:`nn.Module` (it reads ``self.training``). In training
    mode the method is always called and the cache is neither read nor updated. Otherwise,
    if every argument equals the corresponding argument of the previous call, the previous
    result is returned without calling the method. Arguments and result are stored in
    ``self._forward_cache``.

    Tensors are compared element-wise. Graphs are compared only by ``num_node``, ``num_edge``,
    ``num_relation``, ``edge_list``, ``edge_weight`` and ``edge_feature``; node features
    are not compared.

    Can be applied as ``@cached`` or as ``@cached(debug=True)``.

    Parameters:
        forward (callable): method to decorate; its first argument is the module instance
        debug (bool, optional): print cache hit / miss information on each non-training call
    """
    if forward is None:
        # used as ``@cached(debug=...)``: return a decorator bound to ``debug``
        return functools.partial(cached, debug=debug)

    # the signature of the decorated method never changes, so resolve it once
    sig = inspect.signature(forward)

    def equal(x, y):
        # whether a cached argument ``x`` matches a new argument ``y``; may return a bool tensor
        if isinstance(x, nn.Parameter):
            x = x.data
        if isinstance(y, nn.Parameter):
            y = y.data
        if type(x) != type(y):
            return False
        if isinstance(x, torch.Tensor):
            # tensors on different devices cannot be compared element-wise; treat as a miss
            if x.device != y.device or x.shape != y.shape:
                return False
            return (x == y).all()
        elif isinstance(x, data.Graph):
            if x.device != y.device:
                return False
            if x.num_node != y.num_node or x.num_edge != y.num_edge or x.num_relation != y.num_relation:
                return False
            edge_feature = getattr(x, "edge_feature", torch.tensor(0, device=x.device))
            y_edge_feature = getattr(y, "edge_feature", torch.tensor(0, device=y.device))
            if edge_feature.shape != y_edge_feature.shape:
                return False
            return (x.edge_list == y.edge_list).all() and (x.edge_weight == y.edge_weight).all() \
                   and (edge_feature == y_edge_feature).all()
        else:
            return x == y

    @decorator
    def wrapper(func, self, *args, **kwargs):
        if self.training:
            return func(self, *args, **kwargs)
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        arguments = bound.arguments.copy()
        arguments.pop(next(iter(arguments)))  # drop ``self``

        if hasattr(self, "_forward_cache"):
            cache = self._forward_cache
            hit = True
            message = []
            for k, v in arguments.items():
                # a missing key (e.g. a cache written by another cached method) is a miss
                if k not in cache or not equal(cache[k], v):
                    hit = False
                    message.append("%s: miss" % k)
                    break
                message.append("%s: hit" % k)
            if debug:
                print("[cache] %s" % ", ".join(message))
        else:
            hit = False
            if debug:
                print("[cache] cold start")
        if hit:
            return self._forward_cache["result"]

        new_cache = {}
        for k, v in arguments.items():
            if isinstance(v, torch.Tensor) or isinstance(v, data.Graph):
                # keep detached references so the cache does not hold autograd history
                v = v.detach()
            new_cache[k] = v
        result = func(self, *args, **kwargs)
        new_cache["result"] = result
        # publish the cache only after the call succeeds, so an exception cannot leave
        # argument entries without a matching result
        self._forward_cache = new_cache
        return result

    return wrapper(forward)
def deprecated_alias(**alias):
    """
    Handle argument alias for a function and output deprecated warnings.

    Each keyword ``old="new"`` makes the decorated function accept ``old=...`` as a keyword.
    The value is forwarded as ``new=...`` and a warning is emitted. Passing both names raises
    :class:`TypeError`. The deprecated names are also added to the signature and the
    ``Parameters:`` section of the returned wrapper.

    Parameters:
        **alias: mapping from each deprecated argument name to the name of its replacement
    """
    def decorate(obj):
        @functools.wraps(obj)
        def wrapper(*args, **kwargs):
            for key, value in alias.items():
                if key in kwargs:
                    if value in kwargs:
                        raise TypeError("%s() got values for both `%s` and `%s`" % (obj.__name__, value, key))
                    # stacklevel=2 attributes the warning to the caller using the deprecated name
                    warnings.warn("%s(): argument `%s` is deprecated in favor of `%s`" % (obj.__name__, key, value),
                                  stacklevel=2)
                    kwargs[value] = kwargs.pop(key)
            return obj(*args, **kwargs)

        sig = get_signature(obj)
        parameters = list(sig.parameters.values())
        # ``dict`` also normalizes a possible ``[]`` returned for a docstring without ``Parameters:``
        param_docs = dict(get_param_docs(obj, as_dict=True))
        docs = list(param_docs.values())
        alias_params = []
        alias_docs = []
        for key, value in alias.items():
            param = inspect.Parameter(key, inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                      default=None, annotation=sig.parameters[value].annotation)
            alias_params.append(param)
            # reuse the ``(type)`` part of the replacement's documentation, if documented
            param_doc = param_docs.get(value, "")
            match = re.search(r" \(.*?\)", param_doc)
            if match:
                type_str = match.group()
            else:
                type_str = ""
            alias_docs.append("%s%s: deprecated alias of ``%s``" % (key, type_str, value))
        if parameters and parameters[-1].kind == inspect.Parameter.VAR_KEYWORD:
            # keep ``**kwargs`` last; assumes its documentation is also the last entry
            new_params = parameters[:-1] + alias_params + parameters[-1:]
            new_docs = docs[:-1] + alias_docs + docs[-1:]
        else:
            new_params = parameters + alias_params
            new_docs = docs + alias_docs
        new_sig = sig.replace(parameters=new_params)
        set_signature(wrapper, new_sig)
        set_param_docs(wrapper, new_docs)
        return wrapper
    return decorate
def get_param_docs(obj, as_dict=False):
    """
    Extract the entries of the ``Parameters:`` section of ``obj.__doc__``.

    The section starts after the ``Parameters:`` line and ends at the first line that is
    not indented at least as deep as its first entry. An entry is a line at that
    indentation plus any deeper-indented continuation lines. The section indentation is
    removed from every line of an entry, so continuation lines keep only their extra
    indentation.

    Parameters:
        obj (object): object whose ``__doc__`` is parsed
        as_dict (bool, optional): return a dict keyed by the first token of each entry
            (normally the parameter name) instead of a list

    Returns a list (or dict) of entries in docstring order. It is empty (``[]`` or ``{}``)
    when there is no such section.
    """
    doc = obj.__doc__ or ""
    match = re.search(r"Parameters:\n", doc)
    if not match:
        # keep the return type consistent with ``as_dict``
        return {} if as_dict else []
    begin = match.end()
    # indentation of the first entry defines the indentation of the section
    indent = re.match(r"[ \t]*", doc[begin:]).group()
    # first line that does not start with the section indentation ends the section
    match = re.search(r"^(?!%s)" % indent, doc[begin:], re.MULTILINE)
    if match:
        end = begin + match.start()
    else:
        end = None
    param_docs = []
    pattern = r"^%s\S.*(?:\n%s[ \t]+\S.*)*" % (indent, indent)
    for match in re.finditer(pattern, doc[begin:end], re.MULTILINE):
        # remove the section indentation from every line (``flags=`` - the 4th positional is ``count``)
        param_docs.append(re.sub("^%s" % indent, "", match.group(), flags=re.MULTILINE))
    if as_dict:
        return {re.search(r"\S+", entry).group(): entry for entry in param_docs}
    return param_docs


def set_param_docs(obj, param_docs):
    """
    Write parameter entries into the ``Parameters:`` section of ``obj.__doc__``.

    If the docstring has such a section, its existing entries are replaced and the
    surrounding text is kept. Otherwise a new section is appended, with the header at the
    docstring's base indentation and the entries four spaces deeper.

    Parameters:
        obj (object): object whose ``__doc__`` is rewritten in place
        param_docs (list of str or dict): entries without section indentation, in the
            format returned by ``get_param_docs``; for a dict, only the values are used
    """
    doc = obj.__doc__ or ""
    if isinstance(param_docs, dict):
        param_docs = param_docs.values()
    match = re.search(r"Parameters:\n", doc)
    if not match:
        # base indentation of the body; like ``inspect.cleandoc``, the first line is ignored
        indents = re.findall(r"^([ \t]*)\S", doc.partition("\n")[2], re.MULTILINE)
        indent = min(indents, key=len) if indents else ""
        entries = [re.sub("^", indent + "    ", entry, flags=re.MULTILINE) for entry in param_docs]
        section = "\n".join(["%sParameters:" % indent] + entries)
        body = doc.rstrip()
        doc = "%s\n\n%s\n" % (body, section) if body else "%s\n" % section
    else:
        begin = match.end()
        indent = re.match(r"[ \t]*", doc[begin:]).group()
        # existing entries: consecutive lines at ``indent`` and their deeper continuation lines,
        # anchored at the start of the section
        pattern = r"^%s\S.*(?:\n%s[ \t]+\S.*)*(?:\n%s\S.*(?:\n%s[ \t]+\S.*)*)*" % ((indent,) * 4)
        section = re.match(pattern, doc[begin:], re.MULTILINE)
        end = begin + (section.end() if section else 0)
        entries = [re.sub("^", indent, entry, flags=re.MULTILINE) for entry in param_docs]
        doc = "".join([doc[:begin], "\n".join(entries), doc[end:]])
    obj.__doc__ = doc
def get_signature(obj):
    """
    Return the signature of ``obj``.

    An explicitly attached ``__signature__`` attribute takes precedence. For a class, the
    signature of its ``__init__`` is returned, including ``self``.
    """
    try:
        # already overridden, e.g. by set_signature
        return obj.__signature__
    except AttributeError:
        pass
    target = obj.__init__ if inspect.isclass(obj) else obj
    return inspect.signature(target)
def set_signature(obj, sig):
    """
    Attach ``sig`` to ``obj`` and write it as the first line of ``obj.__doc__``.

    The first docstring line becomes ``name(...)``. Presumably this is so that documentation
    tools that read a signature from the first docstring line (e.g. Sphinx
    ``autodoc_docstring_signature``) show the rewritten signature. If the first line already
    is such a signature line for the same name (from a previous call), it is replaced
    instead of adding another line.

    Parameters:
        obj (object): function or class to update; ``obj.__name__`` is used in the line
        sig (inspect.Signature): signature to attach as ``obj.__signature__``
    """
    doc = obj.__doc__ or ""
    name = obj.__name__
    # a whole first line of the form ``name(...)``, optionally followed by ``-> annotation``
    pattern = r"[ \t]*%s\(.*\)(?:[ \t]*->.*)?$" % re.escape(name)
    match = re.match(pattern, doc, re.MULTILINE)
    if match:
        doc = "".join([name, str(sig), doc[match.end():]])
    else:
        doc = "%s%s\n%s" % (name, sig, doc)
    obj.__doc__ = doc
    obj.__signature__ = sig