# Source code for torchdrug.utils.decorator

import re
import inspect
import warnings
import functools

from decorator import decorator

import torch
from torch import nn

from torchdrug import data


def copy_args(obj, args=None, ignore=None):
    r"""
    Copy argument documentation from another function to fill the document of \*\*kwargs in this function.

    This function should be applied as a decorator.

    Parameters:
        obj (object): object to copy document from
        args (tuple of str, optional): arguments to copy.
            By default, it copies all argument documentation from ``obj``,
            except those already exist in the current function.
        ignore (tuple of str, optional): arguments to ignore
    """
    def wrapper(obj):
        sig = get_signature(obj)
        parameters = list(sig.parameters.values())
        # The implicit receiver is not part of the documented interface.
        if parameters and parameters[0].name in ("cls", "self"):
            parameters.pop(0)
        docs = get_param_docs(obj)
        if len(docs) != len(parameters):
            raise ValueError("Fail to parse the docstring of `%s`. "
                             "Inconsistent number of parameters in signature and docstring." % obj.__name__)

        # Arguments the decorated function already declares take priority over copied ones.
        param_names = {p.name for p in parameters}
        new_params = []
        new_docs = []
        for param, doc in zip(parameters, docs):
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                copied = from_args
            elif param.kind == inspect.Parameter.VAR_KEYWORD:
                copied = from_kwargs
            else:
                new_params.append(param)
                new_docs.append(doc)
                continue
            # Expand *args / **kwargs into the corresponding arguments of the source object.
            for from_param in copied:
                if from_param.name in param_names:
                    continue
                new_params.append(from_param)
                new_docs.append(from_docs[from_param.name])

        new_sig = sig.replace(parameters=new_params)
        set_signature(obj, new_sig)
        set_param_docs(obj, new_docs)

        return obj

    from_obj = obj
    if args is not None:
        args = set(args)
    if ignore is not None:
        ignore = set(ignore)

    sig = get_signature(from_obj)
    parameters = list(sig.parameters.values())
    if parameters and parameters[0].name in ("cls", "self"):
        parameters.pop(0)
    # Arguments without a default fill *args; arguments with a default fill **kwargs.
    from_args = []
    from_kwargs = []
    for param in parameters:
        if (args is None or param.name in args) and (ignore is None or param.name not in ignore):
            # Identity check: defaults such as arrays overload ``==`` elementwise.
            if param.default is inspect.Parameter.empty:
                from_args.append(param)
            else:
                from_kwargs.append(param)

    from_docs = get_param_docs(from_obj, as_dict=True)
    if len(from_docs) != len(parameters):
        raise ValueError("Fail to parse the docstring of `%s`. "
                         "Inconsistent number of parameters in signature and docstring." % from_obj.__name__)

    return wrapper
class cached_property(property):
    """
    Cache the property once computed.

    The first access on an instance calls the wrapped function and stores the result in the
    instance ``__dict__`` under the function's name; later accesses on that instance return the
    stored value. The cached value is never invalidated.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls):
        if obj is None:
            return self
        # ``property`` is a data descriptor, so attribute lookup always calls ``__get__`` even
        # when the instance ``__dict__`` already holds the name. Consult the cache explicitly.
        name = self.func.__name__
        cache = obj.__dict__
        if name in cache:
            return cache[name]
        result = self.func(obj)
        cache[name] = result
        return result
def cached(forward, debug=False):
    """
    Cache the result of last function call.

    When ``self.training`` is false, ``forward`` is only re-run if its bound arguments differ
    from those of the previous call; otherwise the stored result is returned. When
    ``self.training`` is true the cache is bypassed.

    Parameters:
        forward (callable): method to decorate, called as ``forward(self, *args, **kwargs)``
        debug (bool, optional): print cache hit / miss information
    """

    def equal(x, y):
        # Parameters are compared through their underlying data.
        if isinstance(x, nn.Parameter):
            x = x.data
        if isinstance(y, nn.Parameter):
            y = y.data
        if type(x) != type(y):
            return False
        if isinstance(x, torch.Tensor):
            # Tensors on different devices cannot be compared elementwise; count it as a miss.
            if x.device != y.device or x.shape != y.shape:
                return False
            return bool((x == y).all())
        elif isinstance(x, data.Graph):
            if x.num_node != y.num_node or x.num_edge != y.num_edge or x.num_relation != y.num_relation:
                return False
            if x.device != y.device:
                return False
            edge_feature = getattr(x, "edge_feature", torch.tensor(0, device=x.device))
            y_edge_feature = getattr(y, "edge_feature", torch.tensor(0, device=y.device))
            if edge_feature.shape != y_edge_feature.shape:
                return False
            return bool((x.edge_list == y.edge_list).all() and (x.edge_weight == y.edge_weight).all()
                        and (edge_feature == y_edge_feature).all())
        else:
            return x == y

    @decorator
    def wrapper(forward, self, *args, **kwargs):
        if self.training:
            return forward(self, *args, **kwargs)

        sig = inspect.signature(forward)
        func = sig.bind(self, *args, **kwargs)
        func.apply_defaults()
        arguments = func.arguments.copy()
        # Drop the receiver; only the real inputs form the cache key.
        arguments.pop(next(iter(arguments)))

        if hasattr(self, "_forward_cache"):
            hit = True
            message = []
            for k, v in arguments.items():
                if not equal(self._forward_cache[k], v):
                    hit = False
                    message.append("%s: miss" % k)
                    break
                message.append("%s: hit" % k)
            if debug:
                print("[cache] %s" % ", ".join(message))
        else:
            hit = False
            if debug:
                print("[cache] cold start")
        if hit:
            return self._forward_cache["result"]

        cache = {}
        for k, v in arguments.items():
            if isinstance(v, (torch.Tensor, data.Graph)):
                v = v.detach()
            cache[k] = v
        result = forward(self, *args, **kwargs)
        cache["result"] = result
        # Publish the new cache only after ``forward`` succeeded, so a failed call can never
        # leave argument entries behind without a matching "result".
        self._forward_cache = cache
        return result

    return wrapper(forward)


def deprecated_alias(**alias):
    """
    Handle argument alias for a function and output deprecated warnings.

    Each keyword ``old="new"`` lets callers pass ``old`` as a keyword argument; it is renamed to
    ``new`` before the call and a warning is issued. The signature and parameter documentation
    of the decorated object are extended with the aliases.

    Parameters:
        **alias: mapping from deprecated argument names to their replacement names
    """

    def decorate(obj):

        @functools.wraps(obj)
        def wrapper(*args, **kwargs):
            for key, value in alias.items():
                if key in kwargs:
                    if value in kwargs:
                        raise TypeError("%s() got values for both `%s` and `%s`" % (obj.__name__, value, key))
                    # stacklevel=2 attributes the warning to the caller rather than this wrapper.
                    warnings.warn("%s(): argument `%s` is deprecated in favor of `%s`" % (obj.__name__, key, value),
                                  stacklevel=2)
                    kwargs[value] = kwargs.pop(key)

            return obj(*args, **kwargs)

        sig = get_signature(obj)
        parameters = list(sig.parameters.values())
        param_docs = get_param_docs(obj, as_dict=True)
        docs = list(param_docs.values())
        alias_params = []
        alias_docs = []
        for key, value in alias.items():
            # ``wrapper`` only recognizes aliases as keywords, hence KEYWORD_ONLY. This also keeps the
            # signature valid when ``obj`` has ``*args`` or keyword-only parameters.
            param = inspect.Parameter(key, inspect.Parameter.KEYWORD_ONLY, default=None,
                                      annotation=sig.parameters[value].annotation)
            alias_params.append(param)
            # Reuse the type description of the target entry, e.g. " (int, optional)", if documented.
            match = re.search(r" \(.*?\)", param_docs.get(value, ""))
            type_str = match.group() if match else ""
            alias_docs.append("%s%s: deprecated alias of ``%s``" % (key, type_str, value))

        if parameters and parameters[-1].kind == inspect.Parameter.VAR_KEYWORD:
            new_params = parameters[:-1] + alias_params + parameters[-1:]
            new_docs = docs[:-1] + alias_docs + docs[-1:]
        else:
            new_params = parameters + alias_params
            new_docs = docs + alias_docs
        new_sig = sig.replace(parameters=new_params)
        set_signature(wrapper, new_sig)
        set_param_docs(wrapper, new_docs)

        return wrapper

    return decorate


def get_param_docs(obj, as_dict=False):
    """
    Extract the entries of the ``Parameters:`` section from the docstring of ``obj``.

    The section indentation is removed from every line of each entry, so continuation lines
    keep only their extra, relative indentation.

    Parameters:
        obj (object): object whose ``__doc__`` is parsed
        as_dict (bool, optional): return a dict keyed by parameter name instead of a list

    Returns:
        list of str, or dict of str to str if ``as_dict`` is true
    """
    doc = obj.__doc__ or ""

    match = re.search(r"Parameters:\n", doc)
    indent_match = re.search(r"\s+", doc[match.end():]) if match else None
    if not indent_match:
        return {} if as_dict else []
    begin = match.end()
    indent = re.escape(indent_match.group())
    # The section ends at the first line not indented as deep as its entries.
    match = re.search(r"^(?!%s)" % indent, doc[begin:], re.MULTILINE)
    end = begin + match.start() if match else None

    param_docs = []
    # An entry is a line at the section indentation plus any deeper-indented continuation lines.
    pattern = r"^%s\S.*(?:\n%s[ \t]+\S.*)*" % (indent, indent)
    for match in re.finditer(pattern, doc[begin:end], re.MULTILINE):
        param_doc = re.sub(r"^%s" % indent, "", match.group(), flags=re.MULTILINE)  # remove indent
        param_docs.append(param_doc)
    if as_dict:
        # Key by the leading name, without the trailing ":" of untyped entries such as "x: ...".
        param_docs = {re.match(r"[^\s:]*", param_doc).group(): param_doc for param_doc in param_docs}

    return param_docs


def set_param_docs(obj, param_docs):
    """
    Write parameter entries into the ``Parameters:`` section of the docstring of ``obj``.

    If the docstring already has such a section, its contiguous block of entries is replaced.
    Otherwise a new section is appended (nothing is appended when there are no entries).

    Parameters:
        obj (object): object whose ``__doc__`` is rewritten
        param_docs (list of str or dict of str to str): entries without section indentation,
            in the format produced by :func:`get_param_docs`
    """
    doc = obj.__doc__ or ""
    if isinstance(param_docs, dict):
        param_docs = param_docs.values()
    param_docs = list(param_docs)

    def indent_entries(indent):
        # Indent every line of every entry, continuation lines included.
        return "\n".join(indent + param_doc.replace("\n", "\n" + indent) for param_doc in param_docs)

    match = re.search(r"Parameters:\n", doc)
    if not match:
        if param_docs:
            # Align the new header with the shallowest non-blank line after the first one
            # (the first line usually follows the opening quotes without indentation).
            indents = [re.match(r"[ \t]*", line).group() for line in doc.split("\n")[1:] if line.strip()]
            indent = min(indents, key=len) if indents else ""
            section = "%sParameters:\n%s" % (indent, indent_entries(indent + "    "))
            head = doc.rstrip()
            doc = "%s\n\n%s" % (head, section) if head else section
    else:
        begin = match.end()
        indent = re.search(r"\s*", doc[begin:]).group()
        escaped = re.escape(indent)
        # One or more consecutive entries, each optionally followed by deeper continuation lines.
        pattern = r"^%s\S.*(?:\n%s[ \t]+\S.*)*(?:\n%s\S.*(?:\n%s[ \t]+\S.*)*)*" % ((escaped,) * 4)
        entries = re.search(pattern, doc[begin:], re.MULTILINE)
        end = begin + entries.end() if entries else begin
        doc = "".join([doc[:begin], indent_entries(indent), doc[end:]])
    obj.__doc__ = doc


def get_signature(obj):
    """
    Get the signature of ``obj``.

    An explicit ``__signature__`` attribute (e.g. one attached by :func:`set_signature`) takes
    priority. For classes, the signature of ``__init__`` is returned.
    """
    if hasattr(obj, "__signature__"):  # already overridden
        return obj.__signature__
    if inspect.isclass(obj):
        return inspect.signature(obj.__init__)
    return inspect.signature(obj)


def set_signature(obj, sig):
    """
    Attach ``sig`` as the signature of ``obj`` and mirror it on the first line of its docstring.

    A leading ``name(...)`` signature line (e.g. written by a previous call) is replaced rather
    than duplicated; otherwise the signature line is prepended.
    """
    doc = obj.__doc__ or ""
    # Only a first line of the form ``name(...)``, optionally followed by ``-> ...``, is a signature.
    pattern = r"\s*%s\(.*\)(?: *->.*)?$" % re.escape(obj.__name__)
    match = re.match(pattern, doc, re.MULTILINE)
    if not match:
        doc = "%s%s\n%s" % (obj.__name__, sig, doc)
    else:
        begin, end = match.span()
        doc = "".join([doc[:begin], obj.__name__, str(sig), doc[end:]])
    obj.__doc__ = doc
    obj.__signature__ = sig