Source code for extorch.adversarial.base

import abc
from typing import Optional

from torch import nn, Tensor


class BaseAdversary(nn.Module):
    r"""
    Base class for adversaries that generate adversarial examples.

    Subclasses implement :meth:`generate_adv`. Callers invoke the adversary
    like any module (``adversary(net, input, target)``), which dispatches to
    :meth:`forward`.

    Args:
        use_eval_mode (bool): Whether to switch the victim network to eval
            mode while generating adversarial examples. Default: ``False``.
    """

    def __init__(self, use_eval_mode: bool = False) -> None:
        super().__init__()
        self.use_eval_mode = use_eval_mode

    def forward(
        self,
        net: nn.Module,
        input: Tensor,
        target: Tensor,
        output: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Generate adversarial examples.

        If ``use_eval_mode`` is set, ``net`` is put into eval mode during
        generation. If it was in train mode beforehand, it is switched back
        to train mode afterwards. ``net.zero_grad()`` is called after
        generation. Both clean-up steps run even if :meth:`generate_adv`
        raises.

        Args:
            net (nn.Module): The victim network.
            input (Tensor): Original input.
            target (Tensor): Label of the input.
            output (Tensor, optional): Original output of ``net`` on
                ``input``. Default: ``None``.

        Returns:
            adv_examples (Tensor): The generated adversarial examples.
        """
        # Only a network that was training before we touched it gets put
        # back into train mode.
        restore_training = self.use_eval_mode and net.training
        if self.use_eval_mode:
            net.eval()
        try:
            adv_examples = self.generate_adv(net, input, target, output)
        finally:
            # generate_adv may back-propagate through ``net``. Clear any
            # gradients it left behind, and restore the original mode even
            # when generation fails, so the caller's network is not left
            # in a modified state.
            net.zero_grad()
            if restore_training:
                net.train()
        return adv_examples

    @abc.abstractmethod
    def generate_adv(
        self,
        net: nn.Module,
        input: Tensor,
        target: Tensor,
        output: Optional[Tensor],
    ) -> Tensor:
        r"""
        Adversarial example generation. Must be implemented by subclasses.

        Args:
            net (nn.Module): The victim network.
            input (Tensor): Original input.
            target (Tensor): Label of the input.
            output (Tensor or None): Original output of ``net`` on ``input``.

        Returns:
            adv_examples (Tensor): The generated adversarial examples.

        Raises:
            NotImplementedError: If a subclass does not override this method.
        """
        # ``nn.Module`` does not use ``abc.ABCMeta``, so ``@abc.abstractmethod``
        # is not enforced at instantiation time. Fail loudly here instead of
        # letting ``forward`` silently return ``None``.
        raise NotImplementedError(
            f"{type(self).__name__} must implement generate_adv()"
        )