extorch.nn.init
- extorch.nn.init.normal_(module: torch.nn.modules.module.Module, conv_mean: float = 0.0, conv_std: float = 1.0, bn_mean: float = 0.0, bn_std: float = 1.0, linear_mean: float = 0.0, linear_std: float = 1.0) → None
Initialize the module with values drawn from the normal distribution.
- Parameters
module (nn.Module) – A pytorch module.
conv_mean (float) – The mean of the normal distribution for convolution.
conv_std (float) – The standard deviation of the normal distribution for convolution.
bn_mean (float) – The mean of the normal distribution for batch-normalization.
bn_std (float) – The standard deviation of the normal distribution for batch-normalization.
linear_mean (float) – The mean of the normal distribution for linear layers.
linear_std (float) – The standard deviation of the normal distribution for linear layers.
- Examples::
>>> import torch.nn as nn
>>> from extorch.nn.init import normal_
>>> module = nn.Sequential(
...     nn.Conv2d(5, 10, 3),
...     nn.BatchNorm2d(10),
...     nn.ReLU()
... )
>>> module.apply(normal_)
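The example passes normal_ to nn.Module.apply, which calls the function once on every submodule with that submodule as its only argument. The sketch below illustrates how an initializer with this signature could dispatch on layer type. It is not extorch's implementation: the name normal_init_sketch is hypothetical, and it assumes that only the weight of convolution, batch-normalization and linear layers is redrawn while biases are left untouched.

```python
from functools import partial

import torch.nn as nn

CONV_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def normal_init_sketch(
    module: nn.Module,
    conv_mean: float = 0.0,
    conv_std: float = 1.0,
    bn_mean: float = 0.0,
    bn_std: float = 1.0,
    linear_mean: float = 0.0,
    linear_std: float = 1.0,
) -> None:
    # Hypothetical stand-in for extorch.nn.init.normal_: redraws only the
    # weight of matching layers; biases and other modules are unchanged.
    if isinstance(module, CONV_TYPES):
        nn.init.normal_(module.weight, mean=conv_mean, std=conv_std)
    elif isinstance(module, BN_TYPES):
        # BatchNorm built with affine=False has no weight to initialize.
        if module.weight is not None:
            nn.init.normal_(module.weight, mean=bn_mean, std=bn_std)
    elif isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=linear_mean, std=linear_std)


net = nn.Sequential(nn.Conv2d(5, 10, 3), nn.BatchNorm2d(10), nn.ReLU())
# apply() supplies only the submodule, so non-default values are bound with partial.
net.apply(partial(normal_init_sketch, conv_std=0.02, bn_mean=1.0, bn_std=0.02))
```

Because apply() cannot forward extra arguments, binding them with functools.partial is one straightforward way to use non-default means and standard deviations; the same pattern should work with extorch's normal_, given the keyword parameters listed above.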