extorch.nn.utils

extorch.nn.utils.use_params(module: torch.nn.modules.module.Module, params: collections.OrderedDict) → None

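Temporarily replace the parameters of the module with the given ones, then restore the original parameters. Use it as a context manager: inside the with block the module runs with params, and its old parameters are recovered when the block exits.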

Parameters
  • module (nn.Module) – The module whose parameters are replaced.

  • params (OrderedDict) – The replacement parameters, keyed by name (for example, a modified copy of module.state_dict()).

Examples

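>>> from collections import OrderedDict
>>> import torch
>>> from torch import nn
>>> from extorch.nn.utils import use_params
>>> m = nn.Conv2d(1, 10, 3)
>>> # Replacement parameters: same keys as m.state_dict(), all zeros
>>> params = OrderedDict(
...     (name, torch.zeros_like(p)) for name, p in m.state_dict().items())
>>> x = torch.ones((2, 1, 10, 10))  # (N, C, H, W), as Conv2d expects
>>> with use_params(m, params):
...     output = m(x)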
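The swap-and-restore behavior described above can be sketched as a context manager. This is a minimal, hypothetical sketch: use_params_sketch is not part of extorch, and the copy-based approach (clone the current state dict, copy params in with load_state_dict, restore in a finally clause) is an assumption, not necessarily how extorch implements use_params.

```python
from collections import OrderedDict
from contextlib import contextmanager
from typing import Iterator, Mapping

import torch
from torch import nn


@contextmanager
def use_params_sketch(
    module: nn.Module, params: Mapping[str, torch.Tensor]
) -> Iterator[nn.Module]:
    """Hypothetical copy-based sketch; NOT extorch's actual implementation."""
    # state_dict() returns detached tensors that share storage with the
    # module, so clone them before they are overwritten below.
    saved = OrderedDict((k, v.clone()) for k, v in module.state_dict().items())
    # load_state_dict copies the new values into the existing tensors in place;
    # with strict=True (the default) params must have exactly the same keys.
    module.load_state_dict(params)
    try:
        yield module
    finally:
        # Restore the original values even if the with-body raised.
        module.load_state_dict(saved)


if __name__ == "__main__":
    m = nn.Conv2d(1, 10, 3)
    zeros = OrderedDict((k, torch.zeros_like(v)) for k, v in m.state_dict().items())
    with use_params_sketch(m, zeros):
        out = m(torch.ones((2, 1, 10, 10)))  # (N, C, H, W)
    print(out.shape)                         # torch.Size([2, 10, 8, 8])
    print(torch.count_nonzero(out).item())   # 0: zero weight and zero bias
```

Because load_state_dict copies values into the module's existing tensors, gradients computed inside the block flow to the module's own parameters, not to the tensors in params. If gradients with respect to the substituted tensors are needed, torch.func.functional_call(module, params, args) evaluates the module with those tensors without mutating it.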