extorch.nn.utils
- class extorch.nn.utils.WrapperModel(module: torch.nn.modules.module.Module, mean: torch.Tensor, std: torch.Tensor)
Bases: torch.nn.modules.module.Module
A wrapper model for computer vision tasks that normalizes the input before the forward pass.
- Parameters
module (nn.Module) – A network with input range [0., 1.].
mean (Tensor) – The mean value used for input transforms.
std (Tensor) – The standard deviation used for input transforms.
- forward(input: torch.Tensor) → torch.Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
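To illustrate the note above, the following sketch uses a plain `nn.Linear` (standing in for any module) with a forward hook: calling the instance runs the hook, while calling `forward` directly skips it. It demonstrates standard PyTorch behavior and does not depend on extorch.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []

# The hook records every invocation that goes through Module.__call__.
handle = layer.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

x = torch.randn(3, 4)
y_call = layer(x)            # goes through __call__, so the hook runs
y_direct = layer.forward(x)  # bypasses __call__, so the hook is skipped

print(calls)                          # ['hook']
print(torch.equal(y_call, y_direct))  # True: same computation, different hook handling
handle.remove()
```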
- training: bool
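The normalization step can be sketched as follows. This is a hypothetical re-implementation for illustration, not extorch's source: the name `NormalizedWrapper` is invented here, and it assumes NCHW image batches with per-channel `mean` and `std` tensors of shape `(C,)`. Registering them as buffers is an assumed design choice that lets them follow the model across devices without being trained.

```python
import torch
import torch.nn as nn


class NormalizedWrapper(nn.Module):
    """Hypothetical sketch: normalize [0, 1] images, then call the wrapped network."""

    def __init__(self, module: nn.Module, mean: torch.Tensor, std: torch.Tensor) -> None:
        super().__init__()
        self.module = module
        # Buffers move with .to(device) and appear in state_dict, but are not trained.
        # Reshape to (1, C, 1, 1) so they broadcast over NCHW batches.
        self.register_buffer("mean", mean.view(1, -1, 1, 1))
        self.register_buffer("std", std.view(1, -1, 1, 1))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Per-channel standardization: (x - mean) / std.
        return self.module((input - self.mean) / self.std)


mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.25, 0.25, 0.25])
model = NormalizedWrapper(nn.Identity(), mean, std)

x = torch.full((2, 3, 4, 4), 0.75)  # a batch of images already scaled to [0, 1]
out = model(x)                      # call the instance so registered hooks run
print(out[0, 0, 0, 0].item())       # 1.0, since (0.75 - 0.5) / 0.25 = 1.0
```

Because `nn.Identity` is wrapped, the output is simply the normalized input, which makes the arithmetic easy to check.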
- extorch.nn.utils.net_device(module: torch.nn.modules.module.Module) → torch.device
Get the current device of the network, assuming all of its weights are on the same device.
- Parameters
module (nn.Module) – The network.
- Returns
The device.
- Return type
torch.device
- Examples
>>> import torch.nn as nn
>>> from extorch.nn.utils import net_device
>>> module = nn.Conv2d(3, 3, 3)
>>> device = net_device(module)  # torch.device("cpu")
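One plausible way to obtain the device, consistent with the stated assumption that all weights share one device, is to read it off the first parameter. The helper `net_device_sketch` is hypothetical and may differ from extorch's actual implementation.

```python
import torch
import torch.nn as nn


def net_device_sketch(module: nn.Module) -> torch.device:
    """Hypothetical sketch: report the device of the module's first parameter.

    Like net_device, this assumes all weights live on one device.
    """
    # next() raises StopIteration for parameter-free modules such as nn.ReLU.
    return next(module.parameters()).device


conv = nn.Conv2d(3, 3, 3)
print(net_device_sketch(conv))  # cpu
```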
- extorch.nn.utils.use_params(module: torch.nn.modules.module.Module, params: collections.OrderedDict) → None
Temporarily replace the parameters of the module with the given ones, then restore the original parameters on exit.
- Parameters
module (nn.Module) – The targeted module.
params (OrderedDict) – The given parameters.
Examples
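>>> import torch
>>> import torch.nn as nn
>>> from extorch.nn.utils import use_params
>>> m = nn.Conv2d(1, 10, 3)
>>> params = m.state_dict()
>>> for p in params.values():
...     p.data = torch.zeros_like(p.data)
>>> input = torch.ones((2, 1, 10, 10))  # (N, C, H, W): Conv2d needs a 4-D batched input
>>> with use_params(m, params):
...     output = m(input)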
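The swap-and-restore behavior can be sketched as a context manager. This is a hypothetical sketch (`use_params_sketch` is not part of extorch) that copies the given values into the existing parameters in place and restores a backup on exit, even if the body raises. It assumes every key in `params` names a parameter of the module, so a `state_dict` that also contains buffers (e.g. from `nn.BatchNorm2d`) would need filtering first. The real `use_params` may instead swap the parameter objects.

```python
from contextlib import contextmanager
from typing import Iterator, Mapping

import torch
import torch.nn as nn


@contextmanager
def use_params_sketch(module: nn.Module, params: Mapping[str, torch.Tensor]) -> Iterator[None]:
    """Hypothetical sketch: temporarily copy `params` into `module`, restore on exit."""
    own = dict(module.named_parameters())
    # Back up the current values of every parameter that will be overwritten.
    backup = {name: own[name].detach().clone() for name in params}
    try:
        with torch.no_grad():
            for name, value in params.items():
                own[name].copy_(value)
        yield
    finally:
        # Restore the original values, even if the with-body raised.
        with torch.no_grad():
            for name, value in backup.items():
                own[name].copy_(value)


m = nn.Conv2d(1, 10, 3)
original_weight = m.weight.detach().clone()
zeros = {k: torch.zeros_like(v) for k, v in m.state_dict().items()}
x = torch.ones((2, 1, 10, 10))

with use_params_sketch(m, zeros):
    out = m(x)  # all-zero weight and bias give an all-zero output

print(out.abs().sum().item())                  # 0.0
print(torch.equal(m.weight, original_weight))  # True: original weights restored
```

Copying in place keeps the module's `Parameter` objects (and any optimizer references to them) intact, at the cost of one extra copy of each replaced tensor for the backup.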