# extorch.utils.stats — running-average meters, classification accuracy,
# segmentation confusion-matrix metrics, and parameter/FLOPs counting.

from typing import List, Tuple, Union
from collections import OrderedDict

import torch
import torch.nn as nn
from torch import Tensor
from thop import profile

from extorch.nn.utils import net_device


class OrderedStats(object):
    r"""Ordered collection of running averages: one ``AverageMeter`` per stat name.

    The underlying ``OrderedDict`` of meters is created lazily on the first
    ``update`` call, so an un-updated instance is falsy.
    """

    def __init__(self) -> None:
        # Lazily initialized by the first ``update`` call.
        self.stat_meters = None

    def __nonzero__(self) -> bool:
        # Truthy once at least one update has been recorded.
        return self.stat_meters is not None

    __bool__ = __nonzero__

    def update(self, stats, n: int = 1) -> None:
        r"""Fold one observation of every statistic into its meter.

        Args:
            stats: Mapping from stat name to scalar value. The set of names
                seen on the first call fixes which meters exist.
            n (int): Weight of this observation (e.g. batch size).
        """
        if self.stat_meters is None:
            self.stat_meters = OrderedDict([(name, AverageMeter()) for name in stats])
        # Fix: a plain loop replaces the original list comprehension that was
        # built only for its side effects and then thrown away.
        for name, value in stats.items():
            self.stat_meters[name].update(value, n)

    def avgs(self) -> Union[OrderedDict, None]:
        r"""Return an ``OrderedDict`` of current averages, or ``None`` if never updated."""
        if self.stat_meters is None:
            return None
        return OrderedDict((name, meter.avg) for name, meter in self.stat_meters.items())

    def items(self):
        r"""Return ``(name, meter)`` pairs, or ``None`` if never updated."""
        return self.stat_meters.items() if self.stat_meters is not None else None
class AverageMeter(object):
    r"""Maintains a running sum, count, and mean of scalar observations."""

    def __init__(self) -> None:
        self.reset()

    def is_empty(self) -> bool:
        r"""Whether no value has been recorded since the last reset."""
        return self.cnt == 0

    def reset(self) -> None:
        r"""Discard all accumulated statistics."""
        self.avg, self.sum, self.cnt = 0., 0., 0

    def update(self, val: Union[float, int], n: int = 1) -> None:
        r"""Record ``val`` observed ``n`` times and refresh the running mean.

        Args:
            val: The observed value.
            n (int): How many times it was observed (e.g. batch size).
        """
        self.cnt += n
        self.sum += n * val
        self.avg = self.sum / self.cnt
def accuracy(outputs: Tensor, targets: Tensor, topk: Tuple[int] = (1, )) -> List[Tensor]:
    r"""Compute top-k classification accuracy (in percent) for each requested k.

    Args:
        outputs (Tensor): Class scores of shape (batch, num_classes).
        targets (Tensor): Ground-truth class indices of shape (batch,).
        topk (Tuple[int]): The k values to evaluate.

    Returns:
        List[Tensor]: One scalar tensor per k, each in [0, 100].
    """
    with torch.no_grad():
        batch_size = targets.size(0)
        # A single topk call with the largest k serves every requested value.
        pred = outputs.topk(max(topk), 1, True, True)[1].t()
        hits = pred.eq(targets.view(1, -1).expand_as(pred))
        return [
            hits[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size)
            for k in topk
        ]
class SegConfusionMatrix(object):
    r"""Confusion matrix as well as metric calculation for image segmentation.

    Args:
        num_classes (int): The number of classes, including the background.
    """

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        # Allocated lazily on the first ``update`` so it lands on the same
        # device as the predictions.
        self.mat = None

    def update(self, output: Tensor, target: Tensor) -> None:
        r"""Accumulate one batch of predictions into the confusion matrix.

        Args:
            output (Tensor): Raw class scores; ``argmax`` over dim 1 yields the
                predicted class per position.
            target (Tensor): Ground-truth labels. Values outside [0, num_classes)
                (e.g. an ignore index) are excluded from the count.
        """
        n = self.num_classes
        output = output.argmax(1).flatten()
        target = target.squeeze(0).to(torch.int64).flatten()
        # Mask out invalid / ignored labels before flattening to linear indices.
        k = (target >= 0) & (target < n)
        inds = n * target[k].to(torch.int64) + output[k]
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype = torch.int64, device = output.device)
        self.mat += torch.bincount(inds, minlength = n ** 2).reshape(n, n)

    def reset(self) -> None:
        r"""Zero the accumulated matrix (no-op before the first ``update``)."""
        # Bug fix: previously ``self.mat.zero_()`` raised AttributeError when
        # reset() was called before any update(), because ``mat`` was None.
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        r"""Compute metrics from the accumulated matrix.

        Returns:
            Tuple of (global pixel accuracy, per-class accuracy, per-class IoU).
        """
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu
def get_params(model: nn.Module, only_trainable: bool = False) -> int:
    r"""Get the parameter number of the model.

    Args:
        model (nn.Module): The model whose parameters are counted.
        only_trainable (bool): If only_trainable is true, only trainable
            parameters will be counted.

    Returns:
        int: The total number of parameter elements.
    """
    total = 0
    for param in model.parameters():
        # Skip frozen parameters only when the caller asked for trainable ones.
        if only_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total
def cal_flops(model: nn.Module, inputs: Tensor) -> float:
    r"""Calculate FLOPs of the given model.

    Args:
        model (nn.Module): The model whose FLOPs is to be calculated.
        inputs (Tensor): Example inputs to the model.

    Returns:
        flops (float): FLOPs of the model.

    Examples::
        >>> import torch
        >>> from extorch.nn import AuxiliaryHead
        >>> module = AuxiliaryHead(3, 10)
        >>> input = torch.randn((10, 3, 32, 32))
        >>> flops = cal_flops(module, input) / 1.e6 # 32.109868
    """
    device = net_device(model)
    # ``thop.profile`` reports FLOPs for the whole batch; normalize per sample.
    total_flops, _ = profile(model, inputs = [inputs.to(device)], verbose = False)
    return total_flops / len(inputs)