import os
import sys
import logging
from itertools import islice
import torch
from torch import distributed as dist
from torch import nn
from torch.utils import data as torch_data
from torchdrug import data, core, utils
from torchdrug.core import Registry as R
from torchdrug.utils import comm, pretty
module = sys.modules[__name__]
logger = logging.getLogger(__name__)
@R.register("core.Engine")
class Engine(core.Configurable):
"""
    General class that handles everything about training and testing of a task.

    This class can perform synchronous distributed parallel training over multiple CPUs or GPUs.
    To invoke parallel training, launch with one of the following commands.

    1. Single-node multi-process case.

    .. code-block:: bash

        python -m torch.distributed.launch --nproc_per_node={number_of_gpus} {your_script.py} {your_arguments...}

    2. Multi-node multi-process case.

    .. code-block:: bash

        python -m torch.distributed.launch --nnodes={number_of_nodes} --node_rank={rank_of_this_node} \
            --nproc_per_node={number_of_gpus} {your_script.py} {your_arguments...}

    If :meth:`preprocess` is defined by the task, it will be applied to ``train_set``, ``valid_set`` and ``test_set``.

    Parameters:
task (nn.Module): task
train_set (data.Dataset): training set
valid_set (data.Dataset): validation set
test_set (data.Dataset): test set
optimizer (optim.Optimizer): optimizer
scheduler (lr_scheduler._LRScheduler, optional): scheduler
gpus (list of int, optional): GPU ids. By default, CPUs will be used.
For multi-node multi-process case, repeat the GPU ids for each node.
batch_size (int, optional): batch size of a single CPU / GPU
gradient_interval (int, optional): perform a gradient update every n batches.
This creates an equivalent batch size of ``batch_size * gradient_interval`` for optimization.
num_worker (int, optional): number of CPU workers per GPU
logger (str or core.LoggerBase, optional): logger type or logger instance.
Available types are ``logging`` and ``wandb``.
log_interval (int, optional): log every n gradient updates
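
    A minimal single-GPU construction sketch. The dataset, model, task and
    hyperparameters below are illustrative choices, not requirements of this class:

    .. code-block:: python

        # illustrative dataset / model / task; substitute your own
        import torch
        from torch.utils import data as torch_data
        from torchdrug import core, datasets, models, tasks

        dataset = datasets.ClinTox("~/molecule-datasets/")
        lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
        lengths += [len(dataset) - sum(lengths)]
        train_set, valid_set, test_set = torch_data.random_split(dataset, lengths)

        model = models.GIN(input_dim=dataset.node_feature_dim, hidden_dims=[256, 256, 256, 256])
        task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                        criterion="bce", metric=("auprc", "auroc"))
        optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)

        solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                             gpus=[0], batch_size=256)
        solver.train(num_epoch=10)
        solver.evaluate("valid")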
"""
def __init__(self, task, train_set, valid_set, test_set, optimizer, scheduler=None, gpus=None, batch_size=1,
gradient_interval=1, num_worker=0, logger="logging", log_interval=100):
self.rank = comm.get_rank()
self.world_size = comm.get_world_size()
self.gpus = gpus
self.batch_size = batch_size
self.gradient_interval = gradient_interval
self.num_worker = num_worker
if gpus is None:
self.device = torch.device("cpu")
else:
if len(gpus) != self.world_size:
error_msg = "World size is %d but found %d GPUs in the argument"
if self.world_size == 1:
error_msg += ". Did you launch with `python -m torch.distributed.launch`?"
raise ValueError(error_msg % (self.world_size, len(gpus)))
self.device = torch.device(gpus[self.rank % len(gpus)])
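        # set up the distributed process group if it has not been initialized yet
        # (gloo backend on CPU, nccl on GPU)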
if self.world_size > 1 and not dist.is_initialized():
if self.rank == 0:
module.logger.info("Initializing distributed process group")
backend = "gloo" if gpus is None else "nccl"
comm.init_process_group(backend, init_method="env://")
if hasattr(task, "preprocess"):
if self.rank == 0:
module.logger.warning("Preprocess training set")
# TODO: more elegant implementation
# handle dynamic parameters in optimizer
old_params = list(task.parameters())
result = task.preprocess(train_set, valid_set, test_set)
if result is not None:
train_set, valid_set, test_set = result
new_params = list(task.parameters())
if len(new_params) != len(old_params):
optimizer.add_param_group({"params": new_params[len(old_params):]})
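        # distributed training: convert BatchNorm layers to SyncBatchNorm and mark
        # non-tensor buffers (e.g. graph buffers) so DDP does not try to broadcast them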
if self.world_size > 1:
task = nn.SyncBatchNorm.convert_sync_batchnorm(task)
buffers_to_ignore = []
for name, buffer in task.named_buffers():
if not isinstance(buffer, torch.Tensor):
buffers_to_ignore.append(name)
task._ddp_params_and_buffers_to_ignore = set(buffers_to_ignore)
if self.device.type == "cuda":
task = task.cuda(self.device)
self.model = task
self.train_set = train_set
self.valid_set = valid_set
self.test_set = test_set
self.optimizer = optimizer
self.scheduler = scheduler
if isinstance(logger, str):
if logger == "logging":
logger = core.LoggingLogger()
elif logger == "wandb":
logger = core.WandbLogger(project=task.__class__.__name__)
else:
raise ValueError("Unknown logger `%s`" % logger)
self.meter = core.Meter(log_interval=log_interval, silent=self.rank > 0, logger=logger)
self.meter.log_config(self.config_dict())
    def train(self, num_epoch=1, batch_per_epoch=None):
"""
Train the model.
If ``batch_per_epoch`` is specified, randomly draw a subset of the training set for each epoch.
Otherwise, the whole training set is used for each epoch.
Parameters:
num_epoch (int, optional): number of epochs
batch_per_epoch (int, optional): number of batches per epoch
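
        A typical call, assuming ``solver`` is an :class:`Engine` instance
        (the numbers are illustrative):

        .. code-block:: python

            # `solver` is an existing Engine; epoch and batch counts are illustrative
            solver.train(num_epoch=100, batch_per_epoch=500)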
"""
sampler = torch_data.DistributedSampler(self.train_set, self.world_size, self.rank)
dataloader = data.DataLoader(self.train_set, self.batch_size, sampler=sampler, num_workers=self.num_worker)
batch_per_epoch = batch_per_epoch or len(dataloader)
model = self.model
model.split = "train"
if self.world_size > 1:
if self.device.type == "cuda":
model = nn.parallel.DistributedDataParallel(model, device_ids=[self.device],
find_unused_parameters=True)
else:
model = nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
model.train()
for epoch in self.meter(num_epoch):
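            # reseed the distributed sampler so every epoch draws a different shuffling order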
sampler.set_epoch(epoch)
metrics = []
start_id = 0
            # the last gradient update may contain fewer than gradient_interval batches
gradient_interval = min(batch_per_epoch - start_id, self.gradient_interval)
for batch_id, batch in enumerate(islice(dataloader, batch_per_epoch)):
if self.device.type == "cuda":
batch = utils.cuda(batch, device=self.device)
loss, metric = model(batch)
if not loss.requires_grad:
raise RuntimeError("Loss doesn't require grad. Did you define any loss in the task?")
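                # scale the loss so that gradients accumulated over `gradient_interval`
                # batches average to the same magnitude as a single larger batch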
loss = loss / gradient_interval
loss.backward()
metrics.append(metric)
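                # after `gradient_interval` batches, apply the accumulated gradients
                # and record the metrics averaged over the accumulation window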
if batch_id - start_id + 1 == gradient_interval:
self.optimizer.step()
self.optimizer.zero_grad()
metric = utils.stack(metrics, dim=0)
metric = utils.mean(metric, dim=0)
if self.world_size > 1:
metric = comm.reduce(metric, op="mean")
self.meter.update(metric)
metrics = []
start_id = batch_id + 1
gradient_interval = min(batch_per_epoch - start_id, self.gradient_interval)
if self.scheduler:
self.scheduler.step()
    @torch.no_grad()
def evaluate(self, split, log=True):
"""
Evaluate the model.
Parameters:
split (str): split to evaluate. Can be ``train``, ``valid`` or ``test``.
log (bool, optional): log metrics or not
Returns:
dict: metrics
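
        For example, assuming ``solver`` is an :class:`Engine` instance:

        .. code-block:: python

            # `solver` is an existing Engine
            metric = solver.evaluate("valid")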
"""
if comm.get_rank() == 0:
logger.warning(pretty.separator)
logger.warning("Evaluate on %s" % split)
test_set = getattr(self, "%s_set" % split)
sampler = torch_data.DistributedSampler(test_set, self.world_size, self.rank)
dataloader = data.DataLoader(test_set, self.batch_size, sampler=sampler, num_workers=self.num_worker)
model = self.model
model.split = split
model.eval()
preds = []
targets = []
for batch in dataloader:
if self.device.type == "cuda":
batch = utils.cuda(batch, device=self.device)
pred, target = model.predict_and_target(batch)
preds.append(pred)
targets.append(target)
pred = utils.cat(preds)
target = utils.cat(targets)
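        # gather predictions and targets from all ranks so metrics cover the full split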
if self.world_size > 1:
pred = comm.cat(pred)
target = comm.cat(target)
metric = model.evaluate(pred, target)
if log:
self.meter.log(metric, category="%s/epoch" % split)
return metric
    def load(self, checkpoint, load_optimizer=True, strict=True):
"""
Load a checkpoint from file.
Parameters:
            checkpoint (str): path to the checkpoint file
load_optimizer (bool, optional): load optimizer state or not
strict (bool, optional): whether to strictly check the checkpoint matches the model parameters
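
        For example, assuming ``solver`` is an :class:`Engine` instance
        (the file name is illustrative):

        .. code-block:: python

            # the checkpoint path is an illustrative example
            solver.load("clintox_gin_epoch10.pth")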
"""
if comm.get_rank() == 0:
logger.warning("Load checkpoint from %s" % checkpoint)
checkpoint = os.path.expanduser(checkpoint)
state = torch.load(checkpoint, map_location=self.device)
self.model.load_state_dict(state["model"], strict=strict)
if load_optimizer:
self.optimizer.load_state_dict(state["optimizer"])
for state in self.optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(self.device)
comm.synchronize()
    def save(self, checkpoint):
"""
Save checkpoint to file.
Parameters:
            checkpoint (str): path to the checkpoint file
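
        For example, assuming ``solver`` is an :class:`Engine` instance
        (the file name is illustrative):

        .. code-block:: python

            # the checkpoint path is an illustrative example
            solver.save("clintox_gin_epoch10.pth")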
"""
if comm.get_rank() == 0:
logger.warning("Save checkpoint to %s" % checkpoint)
checkpoint = os.path.expanduser(checkpoint)
if self.rank == 0:
state = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict()
}
torch.save(state, checkpoint)
comm.synchronize()
    @classmethod
def load_config_dict(cls, config):
"""
Construct an instance from the configuration dict.
"""
if getattr(cls, "_registry_key", cls.__name__) != config["class"]:
raise ValueError("Expect config class to be `%s`, but found `%s`" % (cls.__name__, config["class"]))
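        # the optimizer is constructed last, since it needs the parameters of the
        # instantiated task rather than a serialized config entry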
optimizer_config = config.pop("optimizer")
new_config = {}
for k, v in config.items():
if isinstance(v, dict) and "class" in v:
v = core.Configurable.load_config_dict(v)
if k != "class":
new_config[k] = v
optimizer_config["params"] = new_config["task"].parameters()
new_config["optimizer"] = core.Configurable.load_config_dict(optimizer_config)
return cls(**new_config)
@property
def epoch(self):
"""Current epoch."""
return self.meter.epoch_id