extorch.nn.utils

class extorch.nn.utils.WrapperModel(module: torch.nn.modules.module.Module, mean: torch.Tensor, std: torch.Tensor)

Bases: torch.nn.modules.module.Module

A wrapper model for computer vision tasks that normalizes the input with the given mean and std before the forward pass of the wrapped network.

Parameters
  • module (nn.Module) – The network to wrap. Inputs to the wrapper are expected in the range [0., 1.] and are normalized before they reach this network.

  • mean (Tensor) – The mean used to normalize the input.

  • std (Tensor) – The standard deviation used to normalize the input.
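The normalize-then-forward behavior described above can be sketched as follows. This is a minimal illustration, not extorch's implementation: the class name `NormalizeWrapperSketch` is hypothetical, and it assumes `mean` and `std` are 1-D per-channel tensors and that inputs are `(N, C, H, W)` batches in [0., 1.]. The ImageNet statistics in the usage lines are example values only.

```python
import torch
from torch import nn


class NormalizeWrapperSketch(nn.Module):
    """Hypothetical sketch of a normalize-then-forward wrapper (not extorch's code)."""

    def __init__(self, module: nn.Module, mean: torch.Tensor, std: torch.Tensor) -> None:
        super().__init__()
        self.module = module
        # Assumption: mean/std are per-channel 1-D tensors; reshape them so they
        # broadcast over (N, C, H, W). Buffers move with .to(device) but are not trained.
        self.register_buffer("mean", mean.view(1, -1, 1, 1))
        self.register_buffer("std", std.view(1, -1, 1, 1))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Map a [0., 1.] image batch into the normalized space, then run the network.
        return self.module((input - self.mean) / self.std)


wrapper = NormalizeWrapperSketch(
    nn.Conv2d(3, 8, 3),
    mean=torch.tensor([0.485, 0.456, 0.406]),
    std=torch.tensor([0.229, 0.224, 0.225]),
)
out = wrapper(torch.rand(2, 3, 32, 32))  # inputs in [0., 1.]
print(out.shape)  # torch.Size([2, 8, 30, 30])
```

Registering mean and std as buffers rather than plain attributes keeps them on the same device as the wrapped network after `wrapper.to(device)`, which is one reason a wrapper like this is convenient when attacks or other code operate in [0., 1.] pixel space.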

forward(input: torch.Tensor) → torch.Tensor

Defines the computation performed at every call: normalizes input with mean and std, then passes the result through the wrapped network.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
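A minimal illustration of this note, using a plain `nn.Linear` rather than `WrapperModel`: a forward hook runs when the module instance is called, but not when `forward()` is called directly. The names `record_output` and `hook_calls` are illustrative.

```python
import torch
from torch import nn

hook_calls = []


def record_output(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
    # Forward hooks run after forward() and receive its inputs and output.
    hook_calls.append(tuple(output.shape))


layer = nn.Linear(4, 2)
layer.register_forward_hook(record_output)

x = torch.randn(3, 4)
layer(x)           # __call__ runs forward() and then the registered hook
layer.forward(x)   # direct call: the hook is silently skipped
print(hook_calls)  # [(3, 2)]
```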

training: bool

extorch.nn.utils.net_device(module: torch.nn.modules.module.Module) → torch.device

Get the current device of the network, assuming all of its weights are on the same device.

Parameters

module (nn.Module) – The network.

Returns

The device.

Return type

torch.device

Examples

>>> import torch.nn as nn
>>> from extorch.nn.utils import net_device
>>> module = nn.Conv2d(3, 3, 3)
>>> device = net_device(module)  # device(type='cpu')
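Given the stated assumption that all weights share one device, a function like this can simply read the device of the first parameter. The sketch below is an assumption about the approach, not extorch's source; `net_device_sketch` is a hypothetical name.

```python
import torch
from torch import nn


def net_device_sketch(module: nn.Module) -> torch.device:
    """Hypothetical sketch: report the device of the module's first parameter."""
    # Relies on every parameter living on the same device, as documented above.
    # next() raises StopIteration for a module that has no parameters.
    return next(module.parameters()).device


module = nn.Conv2d(3, 3, 3)
print(net_device_sketch(module))  # cpu
```

Only the first parameter is inspected, so a model split across several devices would be reported by whichever device holds its first registered parameter.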

extorch.nn.utils.use_params(module: torch.nn.modules.module.Module, params: collections.OrderedDict) → None

Temporarily replace the parameters of the module with the given ones; the original parameters are recovered when the with block exits.

Parameters
  • module (nn.Module) – The targeted module.

  • params (OrderedDict) – The parameters to use, keyed by parameter name (e.g., as returned by state_dict()).

Examples

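>>> import torch
>>> from torch import nn
>>> from extorch.nn.utils import use_params
>>> m = nn.Conv2d(1, 10, 3)
>>> params = m.state_dict()
>>> # state_dict() tensors share storage with m's weights, so rebind .data instead of zeroing in place.
>>> for p in params.values():
...     p.data = torch.zeros_like(p.data)
>>> input = torch.ones((2, 1, 10, 10))  # (N, C, H, W)
>>> with use_params(m, params):
...     output = m(input)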
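One way such a context manager could work is sketched below. This is an assumption about the approach, not extorch's source: the name `swap_params_sketch` is hypothetical, and it copies the given values into the existing parameters in place (so no gradients flow back to the tensors in `params`). A cloned backup is restored in a `finally` clause, so the original weights return even if the body of the `with` block raises.

```python
from collections import OrderedDict
from contextlib import contextmanager
from typing import Iterator, Mapping

import torch
from torch import nn


@contextmanager
def swap_params_sketch(module: nn.Module, params: Mapping[str, torch.Tensor]) -> Iterator[None]:
    """Hypothetical sketch: temporarily copy `params` into `module`, restore on exit."""
    own = dict(module.named_parameters())
    # Back up only the parameters that are about to be overwritten
    # (names in `params` that are not parameters, e.g. buffers, are ignored).
    backup = {name: own[name].detach().clone() for name in params if name in own}
    try:
        with torch.no_grad():
            for name in backup:
                own[name].copy_(params[name])
        yield
    finally:
        # Restore the original values even if the with-block raised.
        with torch.no_grad():
            for name, value in backup.items():
                own[name].copy_(value)


m = nn.Linear(2, 1)
zeros = OrderedDict((k, torch.zeros_like(v)) for k, v in m.state_dict().items())
with swap_params_sketch(m, zeros):
    print(m(torch.ones(3, 2)).abs().sum().item())  # 0.0
```

Copying values in place keeps the module's `Parameter` objects (and any optimizer references to them) intact; a functional approach would be needed if gradients must flow to the supplied tensors.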