extorch.nn.init

extorch.nn.init.normal_(module: torch.nn.modules.module.Module, conv_mean: float = 0.0, conv_std: float = 1.0, bn_mean: float = 0.0, bn_std: float = 1.0, linear_mean: float = 0.0, linear_std: float = 1.0) → None

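Initialize the parameters of the module's convolution, batch-normalization, and linear layers with values drawn from normal distributions, using a separate mean and standard deviation for each layer type.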

Parameters
  • module (nn.Module) – The PyTorch module to initialize.

  • conv_mean (float) – The mean of the normal distribution for convolution layers.

  • conv_std (float) – The standard deviation of the normal distribution for convolution layers.

  • bn_mean (float) – The mean of the normal distribution for batch-normalization layers.

  • bn_std (float) – The standard deviation of the normal distribution for batch-normalization layers.

  • linear_mean (float) – The mean of the normal distribution for linear layers.

  • linear_std (float) – The standard deviation of the normal distribution for linear layers.
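Because each layer type has its own mean and standard deviation, an initializer of this kind dispatches on the layer type. The following is a minimal sketch of how that dispatch could look; it is not the extorch source. It assumes (unconfirmed by this documentation) that only the weight of each convolution, batch-normalization, and linear layer is redrawn, that biases are left untouched, and that recursion into child modules is left to nn.Module.apply. The name normal_init_sketch is hypothetical; the sampling itself uses torch.nn.init.normal_.

```python
import functools

import torch
import torch.nn as nn


def normal_init_sketch(
    module: nn.Module,
    conv_mean: float = 0.0,
    conv_std: float = 1.0,
    bn_mean: float = 0.0,
    bn_std: float = 1.0,
    linear_mean: float = 0.0,
    linear_std: float = 1.0,
) -> None:
    """Hypothetical stand-in for extorch.nn.init.normal_ (not its actual code).

    Assumption: only the weight of `module` itself is redrawn; biases and
    child modules are untouched, so recursion is left to nn.Module.apply.
    """
    # _ConvNd is the common base class of Conv1d/2d/3d and the transposed variants.
    if isinstance(module, nn.modules.conv._ConvNd):
        nn.init.normal_(module.weight, mean=conv_mean, std=conv_std)
    # _BatchNorm covers BatchNorm1d/2d/3d; layers with affine=False have no weight.
    elif isinstance(module, nn.modules.batchnorm._BatchNorm) and module.weight is not None:
        nn.init.normal_(module.weight, mean=bn_mean, std=bn_std)
    elif isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=linear_mean, std=linear_std)


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(5, 10, 3),
    nn.BatchNorm2d(10),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(10, 4),
)

# Keep a copy of the convolution bias to show that the sketch leaves it alone.
conv_bias_before = model[0].bias.detach().clone()

# functools.partial fixes non-default keyword arguments, so the result still
# takes a single module argument, as nn.Module.apply requires.
model.apply(
    functools.partial(
        normal_init_sketch, conv_std=0.02, bn_mean=1.0, bn_std=0.02, linear_std=0.01
    )
)
```

The same functools.partial pattern is one way to pass non-default means or standard deviations when the initializer is handed to module.apply, since apply calls its argument with the submodule as the only parameter.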

Examples
>>> import torch.nn as nn
>>> from extorch.nn.init import normal_
>>> module = nn.Sequential(
...     nn.Conv2d(5, 10, 3),
...     nn.BatchNorm2d(10),
...     nn.ReLU()
... )
>>> module.apply(normal_)