Source code for torchdrug.core.meter

import time
import logging
from collections import defaultdict

import numpy as np
import torch

from torchdrug import core
from torchdrug.utils import pretty

logger = logging.getLogger(__name__)


[docs]class Meter(object): """ Meter for recording metrics and training progress. Parameters: log_interval (int, optional): log every n updates silent (int, optional): surpress all outputs or not logger (core.LoggerBase, optional): log handler """ def __init__(self, log_interval=100, silent=False, logger=None): self.records = defaultdict(list) self.log_interval = log_interval self.epoch2batch = [0] self.time = [time.time()] self.epoch_id = 0 self.batch_id = 0 self.silent = silent self.logger = logger
[docs] def log(self, record, category="train/batch"): """ Log a record. Parameters: record (dict): dict of any metric category (str, optional): log category. Available types are ``train/batch``, ``train/epoch``, ``valid/epoch`` and ``test/epoch``. """ if category.endswith("batch"): step_id = self.batch_id elif category.endswith("epoch"): step_id = self.epoch_id self.logger.log(record, step_id=step_id, category=category)
[docs] def log_config(self, config): """ Log a hyperparameter config. Parameters: config (dict): hyperparameter config """ self.logger.log_config(config)
[docs] def update(self, record): """ Update the meter with a record. Parameters: record (dict): dict of any metric """ if self.batch_id % self.log_interval == 0: self.log(record, category="train/batch") self.batch_id += 1 for k, v in record.items(): if isinstance(v, torch.Tensor): v = v.item() self.records[k].append(v)
[docs] def step(self): """ Step an epoch for this meter. Instead of manually invoking :meth:`step()`, it is suggested to use the following line >>> for epoch in meter(num_epoch): >>> # do something """ self.epoch_id += 1 self.epoch2batch.append(self.batch_id) self.time.append(time.time()) index = slice(self.epoch2batch[-2], self.epoch2batch[-1]) duration = self.time[-1] - self.time[-2] speed = (self.epoch2batch[-1] - self.epoch2batch[-2]) / duration if self.silent: return logger.warning("duration: %s" % pretty.time(duration)) logger.warning("speed: %.2f batch / sec" % speed) eta = (self.time[-1] - self.time[self.start_epoch]) \ / (self.epoch_id - self.start_epoch) * (self.end_epoch - self.epoch_id) logger.warning("ETA: %s" % pretty.time(eta)) if torch.cuda.is_available(): logger.warning("max GPU memory: %.1f MiB" % (torch.cuda.max_memory_allocated() / 1e6)) torch.cuda.reset_peak_memory_stats() record = {} for k, v in self.records.items(): record[k] = np.mean(v[index]) self.log(record, category="train/epoch")
def __call__(self, num_epoch): self.start_epoch = self.epoch_id self.end_epoch = self.start_epoch + num_epoch for epoch in range(self.start_epoch, self.end_epoch): if not self.silent: logger.warning(pretty.separator) logger.warning("Epoch %d begin" % epoch) yield epoch if not self.silent: logger.warning(pretty.separator) logger.warning("Epoch %d end" % epoch) self.step()