extorch.utils.visual

extorch.utils.visual.denormalize(image: torch.Tensor, mean: List[float], std: List[float], transpose: bool = False, detach_numpy: bool = False) → Union[torch.Tensor, numpy.ndarray]

De-normalize an image tensor that was normalized per channel with the given mean and standard deviation.

Parameters
  • image (Tensor) – The image to be de-normalized with shape [B, C, H, W] or [C, H, W].

  • mean (List[float]) – Per-channel means that were used to normalize the original image.

  • std (List[float]) – Per-channel standard deviations that were used to normalize the original image.

  • transpose (bool) – Whether to transpose the image to channels-last layout, i.e. [B, H, W, C] or [H, W, C]. Default: False.

  • detach_numpy (bool) – If True, detach the result, move it to the CPU, and return it as a NumPy array (Tensor.detach().cpu().numpy()). Default: False.

Returns

The de-normalized image.

Return type

image (Union[Tensor, np.ndarray])
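
The mean and std parameters describe the inverse of the usual per-channel normalization x_norm = (x - mean) / std, so de-normalization computes x = x_norm * std + mean with mean and std broadcast over the spatial dimensions. The sketch below illustrates that arithmetic with NumPy instead of torch to stay dependency-light; the helper denormalize_sketch is hypothetical, and the assumption that extorch's denormalize computes exactly x_norm * std + mean is inferred from the parameter descriptions, not taken from its source.

```python
import numpy as np
from typing import List


def denormalize_sketch(image: np.ndarray, mean: List[float], std: List[float],
                       transpose: bool = False) -> np.ndarray:
    """Hypothetical helper: invert per-channel normalization x_norm = (x - mean) / std.

    Accepts channel-first arrays of shape [C, H, W] or [B, C, H, W].
    """
    # Reshape mean/std to [C, 1, 1] so they broadcast over H and W (and B if present)
    mean_arr = np.asarray(mean, dtype=image.dtype).reshape(-1, 1, 1)
    std_arr = np.asarray(std, dtype=image.dtype).reshape(-1, 1, 1)
    out = image * std_arr + mean_arr
    if transpose:
        # Move the channel axis last: [C, H, W] -> [H, W, C], [B, C, H, W] -> [B, H, W, C]
        out = np.moveaxis(out, -3, -1)
    return out


# Round trip with illustrative statistics: normalize, then de-normalize
rng = np.random.default_rng(0)
x = rng.random((5, 3, 32, 32))
mean = [0.5, 0.4, 0.3]
std = [0.2, 0.25, 0.3]
x_norm = (x - np.asarray(mean).reshape(-1, 1, 1)) / np.asarray(std).reshape(-1, 1, 1)
restored = denormalize_sketch(x_norm, mean, std, transpose=True)
print(restored.shape)  # (5, 32, 32, 3)
print(np.allclose(restored, np.moveaxis(x, 1, -1)))  # True
```

Reshaping the statistics to [C, 1, 1] lets the same code handle both batched and single images, because NumPy broadcasting aligns trailing dimensions.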

Examples

>>> import torch
>>> from extorch.utils.visual import denormalize
>>> image = torch.randn((5, 3, 32, 32)).cuda()  # Shape: [5, 3, 32, 32] (cuda)
>>> mean = [0.5, 0.5, 0.5]
>>> std = [1., 1., 1.]
>>> de_image = denormalize(image, mean, std, transpose=True, detach_numpy=True)  # Shape: [5, 32, 32, 3] (cpu)

extorch.utils.visual.tsne_fit(feature: numpy.ndarray, n_components: int = 2, init: str = 'pca', **kwargs)

Embed the input features into a low-dimensional space with t-SNE and return the embedded coordinates.

Parameters
  • feature (np.ndarray) – The features to be embedded, with shape (n_samples, n_features).

  • n_components (int) – Dimension of the embedded space. Default: 2.

  • init (str) – Initialization of the embedding. Possible options are “random”, “pca”, and a NumPy array of shape (n_samples, n_components). PCA initialization cannot be used with precomputed distances and is usually more globally stable than random initialization. Default: “pca”.

  • kwargs – Additional keyword arguments forwarded to the TSNE model constructor.

Returns

The representation in the embedding space.

Return type

node_pos (np.ndarray)
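
The parameter descriptions (notably the init options) match scikit-learn's sklearn.manifold.TSNE, so tsne_fit most likely builds a TSNE model and calls fit_transform. The sketch below shows that pattern under this assumption; tsne_fit_sketch is a hypothetical name, not the library's code, and the perplexity and random_state values are illustrative keyword arguments passed through **kwargs.

```python
import numpy as np
from sklearn.manifold import TSNE


def tsne_fit_sketch(feature: np.ndarray, n_components: int = 2, init="pca", **kwargs) -> np.ndarray:
    """Hypothetical wrapper: embed `feature` with scikit-learn's TSNE.

    Assumes tsne_fit forwards **kwargs to the TSNE constructor.
    """
    model = TSNE(n_components=n_components, init=init, **kwargs)
    # fit_transform returns an array of shape (n_samples, n_components)
    return model.fit_transform(feature)


features = np.random.RandomState(0).randn(50, 10)
# perplexity must be smaller than n_samples (50 here)
node_pos = tsne_fit_sketch(features, 2, "pca", perplexity=10.0, random_state=0)
print(node_pos.shape)  # (50, 2)
```

Fixing random_state makes the embedding reproducible across runs, which helps when comparing plots.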

Examples

>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from extorch.utils.visual import tsne_fit
>>> features = np.random.randn(50, 10)
>>> labels = np.random.randint(0, 2, 50)
>>> node_pos = tsne_fit(features, 2, "pca")  # Shape: [50, 2]
>>> plt.figure()
>>> plt.scatter(node_pos[:, 0], node_pos[:, 1], c=labels)
>>> plt.show()