extorch.utils.stats
- class extorch.utils.stats.SegConfusionMatrix(num_classes: int)
Bases: object
Confusion matrix as well as metric calculation for image segmentation.
- Parameters
num_classes (int) – The number of classes, including the background.
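The entry above does not document how the matrix is filled or which metrics it exposes. The sketch below illustrates the usual technique, assuming rows index the ground-truth class and columns the predicted class, and that pixel accuracy and mean IoU are the metrics of interest. The class `SegConfusionMatrixSketch` and its methods `update`, `pixel_accuracy`, and `mean_iou` are hypothetical names for illustration, not extorch's API.

```python
import torch


class SegConfusionMatrixSketch:
    """Hypothetical illustration of a segmentation confusion matrix (not extorch's code)."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        # Assumption: rows index the ground-truth class, columns the predicted class.
        self.mat = torch.zeros((num_classes, num_classes), dtype=torch.int64)

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        # preds / targets: integer label maps of identical shape, e.g. (N, H, W).
        # preds are assumed to lie in [0, num_classes), as produced by argmax.
        preds = preds.flatten().long()
        targets = targets.flatten().long()
        # Skip pixels whose target label is out of range (e.g. an ignore index of 255).
        valid = (targets >= 0) & (targets < self.num_classes)
        idx = targets[valid] * self.num_classes + preds[valid]
        counts = torch.bincount(idx, minlength=self.num_classes ** 2)
        self.mat += counts.reshape(self.num_classes, self.num_classes)

    def pixel_accuracy(self) -> float:
        # Correctly classified pixels (the diagonal) over all counted pixels.
        return (self.mat.diag().sum() / self.mat.sum().clamp(min=1)).item()

    def mean_iou(self) -> float:
        tp = self.mat.diag()
        # IoU_c = TP / (TP + FP + FN) = diag / (column sum + row sum - diag)
        union = self.mat.sum(dim=0) + self.mat.sum(dim=1) - tp
        iou = tp.double() / union.double().clamp(min=1)
        # Average only over classes that occur in the predictions or the targets.
        present = union > 0
        return iou[present].mean().item()
```

For example, with `num_classes=3`, predictions `[[0, 1], [2, 2]]` and targets `[[0, 1], [1, 2]]` give a pixel accuracy of 0.75 and a mean IoU of 2/3. Whether extorch excludes absent classes from the mean, as this sketch does, is not stated in the documentation.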
- extorch.utils.stats.accuracy(outputs: torch.Tensor, targets: torch.Tensor, topk: Tuple[int] = (1,)) → List[torch.Tensor]
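`accuracy` has no description here, but its signature (a `topk` tuple in, one tensor per entry out) matches the common top-k accuracy pattern. The following is a minimal sketch of that pattern, modeled on the widely used PyTorch ImageNet training example. It assumes `outputs` has shape (N, C), `targets` holds N integer labels, and results are percentages; whether extorch returns percentages or fractions is not stated. `topk_accuracy_sketch` is a hypothetical name.

```python
from typing import List, Tuple

import torch


def topk_accuracy_sketch(
    outputs: torch.Tensor, targets: torch.Tensor, topk: Tuple[int, ...] = (1,)
) -> List[torch.Tensor]:
    """Hypothetical top-k accuracy (in percent) for each k in `topk`."""
    maxk = max(topk)
    batch_size = targets.size(0)
    # Indices of the maxk highest-scoring classes per sample, transposed to (maxk, N).
    _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    # correct[i, j] is True when the i-th ranked prediction of sample j is its label.
    correct = pred.eq(targets.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # A sample counts as correct if its label appears among its top-k predictions.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k * (100.0 / batch_size))
    return res
```

Computing the largest k once and slicing `correct[:k]` for each requested k avoids repeated `topk` calls.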
- extorch.utils.stats.cal_flops(model: torch.nn.modules.module.Module, inputs: torch.Tensor) → float
Calculate FLOPs of the given model.
- Parameters
model (nn.Module) – The model whose FLOPs are to be calculated.
inputs (Tensor) – Example inputs to the model.
- Returns
FLOPs of the model.
- Return type
float
- Examples
>>> import torch
>>> from extorch.nn import AuxiliaryHead
>>> from extorch.utils.stats import cal_flops
>>> module = AuxiliaryHead(3, 10)
>>> input = torch.randn((10, 3, 32, 32))
>>> flops = cal_flops(module, input) / 1.e6  # 32.109868
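The documentation does not say how `cal_flops` counts operations. One common technique is to register forward hooks on the layers that dominate compute and accumulate multiply-accumulates (MACs) during a single forward pass. The sketch below does this for `nn.Conv2d` and `nn.Linear` only, ignores biases and activations, and counts over the whole batch. These are assumptions, so its numbers need not match `cal_flops`. `count_flops_sketch` is a hypothetical name.

```python
import torch
import torch.nn as nn


def count_flops_sketch(model: nn.Module, inputs: torch.Tensor) -> float:
    """Hypothetical MAC counter for Conv2d and Linear layers using forward hooks."""
    total = [0]  # mutable cell so the hooks can accumulate into it

    def conv_hook(module: nn.Conv2d, inp, out: torch.Tensor) -> None:
        # Each output element needs (C_in / groups) * kH * kW multiply-accumulates.
        kh, kw = module.kernel_size
        per_output = (module.in_channels // module.groups) * kh * kw
        total[0] += out.numel() * per_output

    def linear_hook(module: nn.Linear, inp, out: torch.Tensor) -> None:
        # Each output feature is a dot product over in_features inputs.
        total[0] += out.numel() * module.in_features

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))
    with torch.no_grad():
        model(inputs)
    # Remove the hooks so the model is left unchanged.
    for h in handles:
        h.remove()
    return float(total[0])
```

Because `out.numel()` includes the batch dimension, the count scales with batch size. Dividing by `inputs.size(0)` gives a per-sample figure if that convention is preferred.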