torchdrug.core#
Configurable#
- class Configurable[source]#
Class for loading and saving configurations. It automatically records every argument passed to the __init__ function.
This class is inspired by state_dict() in PyTorch, but designed for hyperparameters.
Inherit this class to construct a configurable class.
>>> class MyClass(nn.Module, core.Configurable):
Note
Configurable only applies to the current class rather than any derived class. For example, the following definition only records the arguments of MyClass.
>>> class DerivedClass(MyClass):
In order to record the arguments of DerivedClass, explicitly specify the inheritance.
>>> class DerivedClass(MyClass, core.Configurable):
To get the configuration of an instance, use config_dict(), which returns a dict of argument names and values. If an argument is also an instance of Configurable, it is recursively expanded in the dict. The configuration dict can be passed to load_config_dict() to create a copy of the instance.
Classes already registered in Registry can be created directly from the Configurable class. This is convenient for building models from configuration files.
>>> config = models.GCN(128, [128]).config_dict()
>>> gcn = Configurable.load_config_dict(config)
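The record-and-rebuild behaviour described above can be illustrated without TorchDrug. The following is a minimal sketch, not the library's implementation: records_init_args, to_config, from_config and the _CLASSES table are hypothetical names, and a plain decorator stands in for inheriting from Configurable.

```python
import functools
import inspect

# Hypothetical name -> class table, standing in for a registry lookup.
_CLASSES = {}


def records_init_args(cls):
    """Wrap cls.__init__ so each instance remembers the arguments it was built with."""
    init = cls.__init__
    signature = inspect.signature(init)

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # Bind positional and keyword arguments to parameter names, filling in defaults.
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        arguments = dict(bound.arguments)
        arguments.pop("self")
        self._init_args = arguments
        init(self, *args, **kwargs)

    cls.__init__ = wrapper
    _CLASSES[cls.__name__] = cls
    return cls


def to_config(obj):
    """Expand recorded arguments into a plain dict, recursing into recorded members."""
    if hasattr(obj, "_init_args"):
        config = {"class": type(obj).__name__}
        for name, value in obj._init_args.items():
            config[name] = to_config(value)
        return config
    return obj


def from_config(config):
    """Rebuild an object from a dict produced by to_config."""
    if isinstance(config, dict) and "class" in config:
        kwargs = {k: from_config(v) for k, v in config.items() if k != "class"}
        return _CLASSES[config["class"]](**kwargs)
    return config


@records_init_args
class Layer:
    def __init__(self, input_dim, output_dim, activation="relu"):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.activation = activation


@records_init_args
class Model:
    def __init__(self, layer, num_layer=2):
        self.layer = layer
        self.num_layer = num_layer


model = Model(Layer(128, 64), num_layer=3)
config = to_config(model)   # nested dict of argument names and values
copy = from_config(config)  # a new, equivalent instance
```

The sketch deliberately avoids inheritance between decorated classes; how derived classes are handled in TorchDrug is covered by the note above.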
Engine#
- class Engine(task, train_set, valid_set, test_set, optimizer, scheduler=None, gpus=None, batch_size=1, gradient_interval=1, num_worker=0, logger='logging', log_interval=100)[source]#
General class that handles everything about training and test of a task.
This class can perform synchronous distributed parallel training over multiple CPUs or GPUs. To invoke parallel training, launch with one of the following commands.
Single-node multi-process case.
python -m torch.distributed.launch --nproc_per_node={number_of_gpus} {your_script.py} {your_arguments...}
Multi-node multi-process case.
python -m torch.distributed.launch --nnodes={number_of_nodes} --node_rank={rank_of_this_node} --nproc_per_node={number_of_gpus} {your_script.py} {your_arguments...}
If preprocess() is defined by the task, it will be applied to train_set, valid_set and test_set.
- Parameters
task (nn.Module) – task
train_set (data.Dataset) – training set
valid_set (data.Dataset) – validation set
test_set (data.Dataset) – test set
optimizer (optim.Optimizer) – optimizer
scheduler (lr_scheduler._LRScheduler, optional) – scheduler
gpus (list of int, optional) – GPU ids. By default, CPUs will be used. For multi-node multi-process case, repeat the GPU ids for each node.
batch_size (int, optional) – batch size of a single CPU / GPU
gradient_interval (int, optional) – perform a gradient update every n batches. This creates an equivalent batch size of batch_size * gradient_interval for optimization.
num_worker (int, optional) – number of CPU workers per GPU
logger (str or core.LoggerBase, optional) – logger type or logger instance. Available types are logging and wandb.
log_interval (int, optional) – log every n gradient updates
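To see why accumulating over gradient_interval batches behaves like a larger batch, the sketch below compares an accumulated gradient with the gradient of one combined batch on a toy one-parameter regression. It assumes each batch loss is a mean over its samples and that each batch gradient is scaled by 1 / gradient_interval before accumulation; whether Engine scales in exactly this way is an assumption, and mean_squared_error_grad is a hypothetical helper.

```python
def mean_squared_error_grad(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 4.5), (3.0, 5.5), (4.0, 8.5)]
batch_size = 2
gradient_interval = 2
w = 0.5

# Split the data into mini-batches of batch_size samples.
batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

# Accumulate scaled gradients over gradient_interval mini-batches before one update.
accumulated_grad = 0.0
for batch in batches[:gradient_interval]:
    accumulated_grad += mean_squared_error_grad(w, batch) / gradient_interval

# Gradient of a single batch holding batch_size * gradient_interval samples.
large_batch_grad = mean_squared_error_grad(w, data)
```

Under these assumptions the two gradients coincide, which is the sense in which the equivalent batch size is batch_size * gradient_interval.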
- evaluate(split, log=True)[source]#
Evaluate the model.
- Parameters
split (str) – split to evaluate. Can be train, valid or test.
log (bool, optional) – log metrics or not
- Returns
metrics
- Return type
dict
- load(checkpoint, load_optimizer=True, strict=True)[source]#
Load a checkpoint from file.
- Parameters
checkpoint (file-like) – checkpoint file
load_optimizer (bool, optional) – load optimizer state or not
strict (bool, optional) – whether to strictly check the checkpoint matches the model parameters
- save(checkpoint)[source]#
Save checkpoint to file.
- Parameters
checkpoint (file-like) – checkpoint file
- train(num_epoch=1, batch_per_epoch=None)[source]#
Train the model.
If batch_per_epoch is specified, randomly draw a subset of the training set for each epoch. Otherwise, the whole training set is used for each epoch.
- Parameters
num_epoch (int, optional) – number of epochs
batch_per_epoch (int, optional) – number of batches per epoch
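One way to picture the effect of batch_per_epoch is to shuffle the training set and stop after a fixed number of batches. The sketch below is illustrative only and is not how Engine is implemented; draw_epoch_batches and its seed argument are hypothetical.

```python
import random
from itertools import islice


def draw_epoch_batches(dataset, batch_size, batch_per_epoch=None, seed=None):
    """Shuffle the dataset and return at most batch_per_epoch batches.

    With batch_per_epoch=None, every sample is used once.
    """
    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)
    batches = (
        [dataset[i] for i in indices[start:start + batch_size]]
        for start in range(0, len(indices), batch_size)
    )
    # islice with stop=None consumes the whole generator.
    return list(islice(batches, batch_per_epoch))


dataset = list(range(10))
subset_epoch = draw_epoch_batches(dataset, batch_size=2, batch_per_epoch=3, seed=0)
full_epoch = draw_epoch_batches(dataset, batch_size=2, seed=0)
```

Here subset_epoch covers 6 of the 10 samples, while full_epoch covers all of them.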
- property epoch#
Current epoch.
Meter#
- class Meter(log_interval=100, silent=False, logger=None)[source]#
Meter for recording metrics and training progress.
- Parameters
log_interval (int, optional) – log every n updates
silent (bool, optional) – suppress all outputs or not
logger (core.LoggerBase, optional) – log handler
- log(record, category='train/batch')[source]#
Log a record.
- Parameters
record (dict) – dict of any metric
category (str, optional) – log category. Available types are train/batch, train/epoch, valid/epoch and test/epoch.
- log_config(config)[source]#
Log a hyperparameter config.
- Parameters
config (dict) – hyperparameter config
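The interplay between log_interval and the log categories can be sketched with a small stand-in class. TinyMeter below is hypothetical: its update and epoch_summary methods, and the choice to average metrics at the end of an epoch, are assumptions made for illustration and are not taken from the Meter source.

```python
from collections import defaultdict


class TinyMeter:
    """Hypothetical meter: collects records and reports every log_interval updates."""

    def __init__(self, log_interval=100, silent=False):
        self.log_interval = log_interval
        self.silent = silent
        self.records = defaultdict(list)
        self.num_update = 0
        self.logged = []  # (category, record) pairs, kept so the behaviour can be inspected

    def log(self, record, category="train/batch"):
        self.logged.append((category, dict(record)))
        if not self.silent:
            for name, value in sorted(record.items()):
                print(f"{category} {name}: {value:g}")

    def update(self, record):
        """Store one batch of metrics and log it on every log_interval-th call."""
        for name, value in record.items():
            self.records[name].append(value)
        self.num_update += 1
        if self.num_update % self.log_interval == 0:
            self.log(record, category="train/batch")

    def epoch_summary(self):
        """Log the mean of each metric since the last summary, then reset."""
        summary = {name: sum(values) / len(values) for name, values in self.records.items()}
        self.log(summary, category="train/epoch")
        self.records.clear()
        return summary


meter = TinyMeter(log_interval=2, silent=True)
for loss in [4.0, 3.0, 2.0, 1.0]:
    meter.update({"loss": loss})
summary = meter.epoch_summary()
```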
Registry#
- class Registry[source]#
Registry class for managing all call-by-name access to objects.
Typical scenarios:
Create a model according to a string.
>>> gcn = R.search("GCN")(128, [128])
Register a custom hook with the package.
>>> @R.register("features.atom.my_feature")
>>> def my_featurizer(atom):
>>>     ...
>>>
>>> data.Molecule.from_smiles("C1=CC=CC=C1", atom_feature="my_feature")
- classmethod get(name)[source]#
Get an object with a canonical name. Hierarchical names are separated by a period (.).
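Hierarchical, dot-separated names map naturally onto nested dicts. The sketch below shows one possible layout; TinyRegistry is a hypothetical class written for illustration and does not reproduce the Registry implementation (in particular, it has no search method).

```python
class TinyRegistry:
    """Hypothetical registry that stores objects in nested dicts keyed by name parts."""

    table = {}

    @classmethod
    def register(cls, name):
        """Return a decorator that stores the decorated object under a dotted name."""
        def decorator(obj):
            keys = name.split(".")
            entry = cls.table
            for key in keys[:-1]:
                entry = entry.setdefault(key, {})
            if keys[-1] in entry:
                raise KeyError(f"`{name}` has already been registered")
            entry[keys[-1]] = obj
            return obj
        return decorator

    @classmethod
    def get(cls, name):
        """Look up an object by its full dotted name."""
        entry = cls.table
        for key in name.split("."):
            if not isinstance(entry, dict) or key not in entry:
                raise KeyError(f"Can't find `{name}`")
            entry = entry[key]
        return entry


@TinyRegistry.register("features.atom.my_feature")
def my_featurizer(atom):
    # Placeholder featurizer: one constant feature per atom.
    return [1.0]


found = TinyRegistry.get("features.atom.my_feature")
```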
Meta Container#
- class _MetaContainer(meta_dict=None, **kwargs)[source]#
Meta container that maintains meta types about members.
The meta type of each member is tracked when the member is assigned. A context manager defines the meta type for a group of assignments.
The meta types are stored as a dict in instance.meta_dict, where keys are member names and values are meta types.
>>> class MyClass(_MetaContainer):
>>>     ...
>>> instance = MyClass()
>>> with instance.context("important"):
>>>     instance.value = 1
>>> assert instance.meta_dict["value"] == "important"
Members assigned with context(None) or without a context won’t be tracked.
>>> instance.random = 0
>>> assert "random" not in instance.meta_dict
You can also restrict available meta types by defining a set _meta_types in the derived class.
Note
Meta container also supports auto inference of meta types. This can be enabled by setting enable_auto_context to True in the derived class.
Once auto inference is on, any member without an explicit context will be recognized through its name prefix. For example, instance.node_value will be recognized as node if node is defined in meta_types.
This may make code hard to maintain. Use with caution.
- data_by_meta(include=None, exclude=None)[source]#
Return members based on the specific meta types.
- Parameters
include (list of string, optional) – meta types to include
exclude (list of string, optional) – meta types to exclude
- Returns
data member dict and meta type dict
- Return type
(dict, dict)
- property data_dict#
A dict that maps tracked names to members.
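The tracking mechanism described for the meta container can be approximated with a context manager and a custom __setattr__. The sketch below is a simplified stand-in, not the _MetaContainer source: TinyMetaContainer is hypothetical, it supports a single meta type per member, and it omits _meta_types restriction and auto inference.

```python
from contextlib import contextmanager


class TinyMetaContainer:
    """Hypothetical container that tags members assigned inside a context."""

    def __init__(self):
        # Use object.__setattr__ so bookkeeping attributes are never tracked themselves.
        object.__setattr__(self, "meta_dict", {})
        object.__setattr__(self, "_meta_context", None)

    @contextmanager
    def context(self, meta_type):
        """Tag every member assigned inside the with-block with meta_type."""
        previous = self._meta_context
        object.__setattr__(self, "_meta_context", meta_type)
        try:
            yield
        finally:
            object.__setattr__(self, "_meta_context", previous)

    def __setattr__(self, name, value):
        if self._meta_context is not None:
            self.meta_dict[name] = self._meta_context
        object.__setattr__(self, name, value)

    def data_by_meta(self, include=None, exclude=None):
        """Return (data dict, meta dict) for tracked members that pass both filters."""
        data, meta = {}, {}
        for name, meta_type in self.meta_dict.items():
            if include is not None and meta_type not in include:
                continue
            if exclude is not None and meta_type in exclude:
                continue
            data[name] = getattr(self, name)
            meta[name] = meta_type
        return data, meta


instance = TinyMetaContainer()
with instance.context("important"):
    instance.value = 1
instance.random = 0  # assigned without a context, so not tracked

selected = instance.data_by_meta(include=["important"])
remaining = instance.data_by_meta(exclude=["important"])
```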