extorch.nn.functional

extorch.nn.functional.average_logits(logits_list: List[torch.Tensor]) → torch.Tensor

Aggregates logits from different networks by averaging them, and returns the aggregated logits. Used for neural network ensembles.

Parameters

logits_list (List[Tensor]) – A list of logits from different networks.

Returns

The aggregated logits (Tensor).
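A minimal sketch of logit averaging follows. It assumes that "average" means the element-wise mean over the networks and that all logits share one shape. The name `average_logits_sketch` is hypothetical and is not extorch's implementation.

```python
from typing import List

import torch


def average_logits_sketch(logits_list: List[torch.Tensor]) -> torch.Tensor:
    # Assumption: element-wise mean over networks.
    # Stack to [num_networks, batch size, num_classes], then average
    # over the network dimension.
    return torch.stack(logits_list, dim=0).mean(dim=0)


# Usage: two hypothetical networks, a batch of 2, and 3 classes.
a = torch.ones(2, 3)
b = torch.full((2, 3), 3.0)
avg = average_logits_sketch([a, b])
print(avg)  # every entry is 2.0
```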

extorch.nn.functional.dec_soft_assignment(input: torch.Tensor, centers: torch.Tensor, alpha: float) → torch.Tensor

Soft assignment used by Deep Embedded Clustering (DEC). Measures the similarity between each embedded point and each centroid using the Student’s t-distribution.

Parameters
  • input (Tensor) – A batch of embedded points. FloatTensor of shape [batch size, embedding dimension].

  • centers (Tensor) – The cluster centroids. FloatTensor of shape [cluster_number, embedding dimension].

  • alpha (float) – The degrees of freedom of the Student’s t-distribution. Default: 1.0.

Returns

The similarity between each embedded point and each centroid. FloatTensor of shape [batch size, cluster_number].

Return type

similarity (Tensor)

Examples

>>> import torch
>>> from extorch.nn.functional import dec_soft_assignment
>>> embeddings = torch.ones((2, 10))
>>> centers = torch.zeros((3, 10))
>>> similarity = dec_soft_assignment(embeddings, centers, alpha=1.0)

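The sketch below implements the Student’s t soft assignment as described in the DEC paper. The kernel is q_ij ∝ (1 + ||z_i − μ_j||² / α)^(−(α + 1) / 2), and each row is normalized over the clusters j. The name `dec_soft_assignment_sketch` is hypothetical. Whether extorch normalizes each row in exactly this way is an assumption.

```python
import torch


def dec_soft_assignment_sketch(input: torch.Tensor, centers: torch.Tensor,
                               alpha: float = 1.0) -> torch.Tensor:
    # Pairwise squared Euclidean distances: [batch size, cluster_number].
    sq_dist = ((input.unsqueeze(1) - centers.unsqueeze(0)) ** 2).sum(dim=2)
    # Unnormalized Student's t kernel with `alpha` degrees of freedom.
    kernel = (1.0 + sq_dist / alpha) ** (-(alpha + 1.0) / 2.0)
    # Assumed DEC-style normalization: each row sums to 1 over clusters.
    return kernel / kernel.sum(dim=1, keepdim=True)


embeddings = torch.ones((2, 10))
centers = torch.zeros((3, 10))
similarity = dec_soft_assignment_sketch(embeddings, centers, alpha=1.0)
print(similarity.shape)  # torch.Size([2, 3])
```

In this usage, every point is equally far from every centroid, so each row is uniform (1/3 per cluster).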
extorch.nn.functional.mix_data(input: torch.Tensor, target: torch.Tensor, alpha: float = 1.0)

Applies mixup to a batch of data, so that the network is trained on convex combinations of pairs of examples and their labels.

Parameters
  • input (Tensor) – Input examples.

  • target (Tensor) – Labels of input examples.

  • alpha (float) – Parameter of the beta distribution. Default: 1.0.

Returns

A tuple (mixed_input, mixed_target, _lambda):

  • mixed_input (Tensor) – Input examples after mixup.

  • mixed_target (Tensor) – Labels of mixed inputs.

  • _lambda (float) – Parameter sampled from the beta distribution.

Return type

Tuple[Tensor, Tensor, float]
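A minimal sketch of mixup follows. Its structure mirrors the returned triple (mixed_input, mixed_target, _lambda). The text above does not say whether mixed_target holds the partner examples' labels or a soft mixture of labels. This sketch assumes the former, so the loss combines the two targets using _lambda. It also assumes that alpha <= 0 disables mixing. The name `mix_data_sketch` is hypothetical and is not extorch's implementation.

```python
import random
from typing import Tuple

import torch
import torch.nn.functional as F


def mix_data_sketch(input: torch.Tensor, target: torch.Tensor,
                    alpha: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor, float]:
    # Sample the mixing coefficient from Beta(alpha, alpha).
    # Assumption: alpha <= 0 disables mixing (lambda = 1).
    _lambda = random.betavariate(alpha, alpha) if alpha > 0 else 1.0
    # Pair every example with a randomly chosen partner from the same batch.
    index = torch.randperm(input.size(0))
    mixed_input = _lambda * input + (1.0 - _lambda) * input[index]
    # Assumption: return the partners' labels rather than soft labels.
    mixed_target = target[index]
    return mixed_input, mixed_target, _lambda


# Usage with a hypothetical batch of 4 examples and 3 classes.
x = torch.randn(4, 5)
y = torch.tensor([0, 1, 2, 1])
mixed_x, mixed_y, lam = mix_data_sketch(x, y, alpha=1.0)
logits = torch.randn(4, 3)  # stand-in for model(mixed_x)
# Mix the losses against the original and the partner labels.
loss = lam * F.cross_entropy(logits, y) + (1.0 - lam) * F.cross_entropy(logits, mixed_y)
```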

extorch.nn.functional.soft_voting(logits_list: List[torch.Tensor]) → torch.Tensor

Aggregates logits from different sub-networks by soft voting, and returns the aggregated logits. Used for neural network ensembles.

Parameters

logits_list (List[Tensor]) – A list of logits from different networks.

Returns

The aggregated logits (Tensor).
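The sketch below follows a common reading of soft voting, which averages each network's class probabilities. To keep the output usable as logits, it returns the log of that average, computed stably with logsumexp. Both the averaging and the log output are assumptions, not necessarily extorch's choices. The name `soft_voting_sketch` is hypothetical.

```python
import math
from typing import List

import torch
import torch.nn.functional as F


def soft_voting_sketch(logits_list: List[torch.Tensor]) -> torch.Tensor:
    # Assumption: soft voting = mean of per-network class probabilities.
    # Convert each network's logits to log-probabilities over classes.
    log_probs = torch.stack([F.log_softmax(logits, dim=-1) for logits in logits_list], dim=0)
    # log(mean(p)) computed stably as logsumexp(log p) - log(N).
    return torch.logsumexp(log_probs, dim=0) - math.log(len(logits_list))


a = torch.tensor([[2.0, 0.0, -1.0]])
b = torch.tensor([[0.0, 3.0, -1.0]])
voted = soft_voting_sketch([a, b])
# exp(voted) is the averaged class-probability vector.
print(voted.exp().sum(dim=-1))  # close to 1
```

Unlike average_logits, this sketch averages after the softmax. Because softmax is nonlinear, the two aggregations generally give different results.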