extorch.nn.functional
- extorch.nn.functional.average_logits(logits_list: List[torch.Tensor]) → torch.Tensor
Aggregates logits from different networks by averaging them and returns the aggregated logits. Used for neural ensemble networks.
- Parameters
logits_list (List[Tensor]) – A list of logits from different networks.
- Returns
The aggregated logits (Tensor).
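The averaging described above can be sketched as follows. This is a minimal illustration, not extorch's source. The name `average_logits_sketch` is hypothetical, and the sketch assumes every tensor in the list has the same shape (for example [batch size, num classes]) and that all networks are weighted equally.

```python
from typing import List

import torch


def average_logits_sketch(logits_list: List[torch.Tensor]) -> torch.Tensor:
    """Hypothetical re-implementation: element-wise mean of the logits.

    Assumes every tensor in ``logits_list`` has the same shape,
    e.g. [batch size, num classes], and that networks are weighted equally.
    """
    # Stack to [num networks, batch size, num classes], then average over networks.
    return torch.stack(logits_list, dim=0).mean(dim=0)


# Usage: two "networks" producing logits for a batch of 2 examples and 3 classes.
logits_a = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
logits_b = torch.tensor([[3.0, 2.0, 1.0], [2.0, 4.0, 6.0]])
ensemble_logits = average_logits_sketch([logits_a, logits_b])
# ensemble_logits == [[2., 2., 2.], [1., 2., 3.]]
```

Because the mean is taken on raw logits, a single network with very large-magnitude logits can dominate the ensemble output.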
- extorch.nn.functional.dec_soft_assignment(input: torch.Tensor, centers: torch.Tensor, alpha: float) → torch.Tensor
Soft assignment used by Deep Embedded Clustering (DEC; Xie et al., 2016). Measures the similarity between each embedded point and each centroid with a Student’s t-distribution kernel.
- Parameters
input (Tensor) – A batch of embedded points. FloatTensor of shape [batch size, embedding dimension].
centers (Tensor) – The cluster centroids. FloatTensor of shape [cluster_number, embedding dimension].
alpha (float) – The degrees of freedom of the Student’s t-distribution. Default: 1.0.
- Returns
similarity (Tensor) – The similarity between each embedded point and each centroid. FloatTensor of shape [batch size, cluster_number].
- Return type
Tensor
- Examples
>>> import torch
>>> from extorch.nn.functional import dec_soft_assignment
>>> embeddings = torch.ones((2, 10))
>>> centers = torch.zeros((3, 10))
>>> similarity = dec_soft_assignment(embeddings, centers, alpha=1.0)
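To make the Student’s t similarity concrete, here is a sketch of the soft assignment as defined in the DEC paper, q_ij = (1 + ||z_i − μ_j||² / α)^(−(α+1)/2), normalized over all centroids j. It is not extorch's code: the name `dec_soft_assignment_sketch` is hypothetical, and it assumes extorch follows the paper's formula exactly.

```python
import torch


def dec_soft_assignment_sketch(
    embeddings: torch.Tensor, centers: torch.Tensor, alpha: float = 1.0
) -> torch.Tensor:
    """Hypothetical re-implementation of the DEC soft assignment.

    embeddings: [batch size, embedding dim]; centers: [num clusters, embedding dim].
    Returns q of shape [batch size, num clusters], each row summing to 1.
    """
    # Squared Euclidean distance between every point and every centroid:
    # [batch size, 1, dim] - [1, num clusters, dim] -> [batch size, num clusters].
    sq_dist = ((embeddings.unsqueeze(1) - centers.unsqueeze(0)) ** 2).sum(dim=2)
    # Unnormalized Student's t kernel with `alpha` degrees of freedom.
    kernel = (1.0 + sq_dist / alpha) ** (-(alpha + 1.0) / 2.0)
    # Normalize over clusters so each row is a distribution over centroids.
    return kernel / kernel.sum(dim=1, keepdim=True)


# Same shapes as the docstring example: 2 points, 3 centroids, 10-dim embeddings.
embeddings = torch.ones((2, 10))
centers = torch.zeros((3, 10))
q = dec_soft_assignment_sketch(embeddings, centers, alpha=1.0)
# All centroids are equidistant from every point here, so every entry is 1/3.

# A point sitting on the first centroid gets most of the probability mass.
q2 = dec_soft_assignment_sketch(
    torch.tensor([[0.0, 0.0]]), torch.tensor([[0.0, 0.0], [3.0, 4.0]])
)
# Squared distances are [0, 25], so q2 is [[26/27, 1/27]].
```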
- extorch.nn.functional.mix_data(input: torch.Tensor, target: torch.Tensor, alpha: float = 1.0)
Applies mixup (Zhang et al., 2018), which trains neural networks on convex combinations of pairs of examples and their labels.
- Parameters
input (Tensor) – Input examples.
target (Tensor) – Labels of input examples.
alpha (float) – Parameter of the Beta(alpha, alpha) distribution from which the mixing coefficient is sampled. Default: 1.0.
- Returns
mixed_input (Tensor) – Input examples after mixup.
mixed_target (Tensor) – Labels of the mixed inputs.
_lambda (float) – The mixing coefficient sampled from the beta distribution.
- Return type
Tuple[Tensor, Tensor, float]
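The mixup procedure can be sketched as below, following the standard recipe from the mixup paper: sample λ ~ Beta(α, α), pair each example with a randomly permuted partner, and mix both inputs and losses with λ. This is an illustration under assumptions, not extorch's implementation. The name `mixup_sketch` is hypothetical, and unlike `mix_data` it returns the two label sets separately (four values instead of three), because the documentation does not say how `mixed_target` is formed.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def mixup_sketch(
    inputs: torch.Tensor, targets: torch.Tensor, alpha: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """Hypothetical mixup: returns (mixed_inputs, targets_a, targets_b, lam)."""
    # Mixing coefficient lam ~ Beta(alpha, alpha); alpha must be > 0.
    concentration = torch.tensor(float(alpha))
    lam = torch.distributions.Beta(concentration, concentration).sample().item()
    # A random permutation pairs each example with another one in the batch.
    index = torch.randperm(inputs.size(0))
    mixed_inputs = lam * inputs + (1.0 - lam) * inputs[index]
    return mixed_inputs, targets, targets[index], lam


# Usage: the loss is mixed with the same coefficient as the inputs.
x = torch.randn(4, 3)
y = torch.tensor([0, 1, 2, 1])
model = torch.nn.Linear(3, 3)
mixed_x, y_a, y_b, lam = mixup_sketch(x, y, alpha=1.0)
logits = model(mixed_x)
loss = lam * F.cross_entropy(logits, y_a) + (1.0 - lam) * F.cross_entropy(logits, y_b)
```

Mixing the two cross-entropy terms is equivalent to using the soft label lam * onehot(y_a) + (1 - lam) * onehot(y_b), which is why the labels can be kept as class indices.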
- extorch.nn.functional.soft_voting(logits_list: List[torch.Tensor]) → torch.Tensor
Aggregates logits from different sub-networks by soft voting and returns the aggregated logits. Used for neural ensemble networks.
- Parameters
logits_list (List[Tensor]) – A list of logits from different networks.
- Returns
The aggregated logits (Tensor).
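Soft voting is commonly understood as averaging each network's class probabilities rather than its raw logits. The sketch below follows that interpretation and returns the log of the averaged probabilities so the result can be used like logits. Both choices are assumptions: the documentation does not state which probabilities are averaged or what exact form extorch returns, and the name `soft_voting_sketch` is hypothetical.

```python
from typing import List

import torch
import torch.nn.functional as F


def soft_voting_sketch(logits_list: List[torch.Tensor]) -> torch.Tensor:
    """Hypothetical soft voting: average per-network class probabilities.

    Returns the log of the averaged probabilities, so applying softmax to
    the result recovers the averaged probabilities.
    """
    # [num networks, batch size, num classes] of per-network probabilities.
    probs = torch.stack([F.softmax(logits, dim=-1) for logits in logits_list], dim=0)
    return probs.mean(dim=0).log()


# One very confident network against two that mildly prefer class 1.
a = torch.tensor([[10.0, 0.0]])
b = torch.tensor([[0.0, 2.0]])
c = torch.tensor([[0.0, 2.0]])
voted = soft_voting_sketch([a, b, c])
# Averaged probabilities are roughly [0.41, 0.59], so soft voting picks class 1.

# Averaging raw logits instead gives roughly [3.33, 1.33], which picks class 0.
mean_logits = torch.stack([a, b, c], dim=0).mean(dim=0)
```

The contrast shows the usual motivation for soft voting: each network's vote is bounded to [0, 1] per class, so one network with large-magnitude logits cannot outweigh the rest on its own.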