Source code for extorch.utils.stats

from typing import List, Tuple, Union
from collections import OrderedDict

import torch
import torch.nn as nn
from torch import Tensor


class OrderedStats(object):
    r"""
    Maintain running averages for an ordered collection of named statistics.
    One :class:`AverageMeter` is kept per statistic name.
    """

    def __init__(self) -> None:
        self.stat_meters = None

    def __nonzero__(self) -> bool:
        return self.stat_meters is not None

    # ``__nonzero__`` is the Python 2 name; alias it for Python 3 truth testing.
    __bool__ = __nonzero__

    def update(self, stats, n: int = 1) -> None:
        r"""
        Update the meters with a batch of statistics.

        Args:
            stats (OrderedDict): Mapping from statistic name to its value for this batch.
            n (int): Number of samples the values were averaged over. Default: 1.
        """
        if self.stat_meters is None:
            self.stat_meters = OrderedDict([(name, AverageMeter()) for name in stats])
        for name, v in stats.items():
            self.stat_meters[name].update(v, n)

    def avgs(self) -> Union[OrderedDict, None]:
        r"""
        Return the current averages, or ``None`` if nothing has been recorded yet.
        """
        if self.stat_meters is None:
            return None
        return OrderedDict((name, meter.avg) for name, meter in self.stat_meters.items())

    def items(self):
        return self.stat_meters.items() if self.stat_meters is not None else None
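
A minimal usage sketch of ``OrderedStats`` (illustrative only; the statistic names ``"loss"`` and ``"acc"`` and their values are invented)::

    stats = OrderedStats()
    stats.update(OrderedDict([("loss", 0.7), ("acc", 75.0)]), n=64)
    stats.update(OrderedDict([("loss", 0.5), ("acc", 85.0)]), n=64)
    print(stats.avgs())  # OrderedDict([('loss', 0.6), ('acc', 80.0)]), up to float rounding
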
class AverageMeter(object):
    r"""
    Compute and store the running average of a scalar statistic.
    """

    def __init__(self) -> None:
        self.reset()

    def is_empty(self) -> bool:
        return self.cnt == 0

    def reset(self) -> None:
        self.avg = 0.
        self.sum = 0.
        self.cnt = 0

    def update(self, val: Union[float, int], n: int = 1) -> None:
        r"""
        Record a new value.

        Args:
            val (Union[float, int]): The (mean) value for the current batch.
            n (int): Number of samples ``val`` was averaged over. Default: 1.
        """
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
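
For example, averaging a per-batch loss over batches of unequal size (values invented for illustration)::

    meter = AverageMeter()
    meter.update(0.9, n=32)  # batch of 32 samples with mean loss 0.9
    meter.update(0.5, n=96)  # batch of 96 samples with mean loss 0.5
    print(meter.avg)         # (0.9 * 32 + 0.5 * 96) / 128 = 0.6, up to float rounding
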
def accuracy(outputs: Tensor, targets: Tensor, topk: Tuple[int, ...] = (1, )) -> List[Tensor]:
    r"""
    Compute the top-k accuracies (in percent) of the predictions.

    Args:
        outputs (Tensor): Model outputs of shape ``(batch_size, class_num)``.
        targets (Tensor): Ground-truth labels of shape ``(batch_size, )``.
        topk (Tuple[int, ...]): The values of k for which accuracy is computed. Default: ``(1, )``.

    Returns:
        List[Tensor]: One scalar tensor per entry of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = targets.size(0)

        # Indices of the maxk largest logits per example, transposed to (maxk, batch_size).
        _, pred = outputs.topk(maxk, 1, True, True)
        pred = pred.t()
        # correct[i][j] is True iff the (i + 1)-th ranked prediction for example j is the target.
        correct = pred.eq(targets.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
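
A small sanity check (tensors invented for illustration; ties between logits are avoided so the expected ranking is unambiguous)::

    outputs = torch.tensor([[0.1, 0.7, 0.2],
                            [0.8, 0.15, 0.05],
                            [0.3, 0.35, 0.45]])
    targets = torch.tensor([1, 0, 1])
    top1, top2 = accuracy(outputs, targets, topk=(1, 2))
    # top1 is tensor(66.6667): the argmax predictions are 1, 0, 2, so 2 of 3 match.
    # top2 is tensor(100.): every target appears among the two largest logits of its row.
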
def get_params(model: nn.Module, only_trainable: bool = False) -> int:
    r"""
    Get the number of parameters of the model.

    Args:
        model (nn.Module): The model whose parameters are to be counted.
        only_trainable (bool): If true, only trainable parameters (i.e., those with
            ``requires_grad == True``) are counted. Default: ``False``.

    Returns:
        int: The parameter count.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not only_trainable)
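
For instance, with a small throwaway model (layer sizes invented)::

    model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 2))
    print(get_params(model))  # (10 * 20 + 20) + (20 * 2 + 2) = 262
    model[0].weight.requires_grad = False
    print(get_params(model, only_trainable=True))  # 262 - 200 = 62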