Source code for extorch.vision.transforms.functional

from typing import Union, Tuple

import numpy as np
import torch
from torch import Tensor
from torchvision.transforms import functional as F


[docs]def cutout(image: Tensor, length: int, n_holes: int = 1) -> Tensor: r""" Cutout: Randomly mask out one or more patches from an image (`Link`_). Args: image (Tensor): Image of size (C, H, W). length (int): The length (in pixels) of each square patch. n_holes (int): Number of patches to cut out of each image. Default: 1. Returns: image (Tensor): Image with n_holes of dimension length x length cut out of it. Examples:: >>> image = torch.ones((3, 32, 32)) >>> image = cutout(image, 16, 1) .. _Link: https://arxiv.org/abs/1708.04552 """ h, w = image.shape[1:] mask = np.ones((h, w), np.float32) for n in range(n_holes): y = np.random.randint(h) x = np.random.randint(w) y_1 = np.clip(y - length // 2, 0, h) y_2 = np.clip(y + length // 2, 0, h) x_1 = np.clip(x - length // 2, 0, w) x_2 = np.clip(x + length // 2, 0, w) mask[y_1 : y_2, x_1 : x_2] = 0. mask = torch.from_numpy(mask) mask = mask.expand_as(image) image *= mask return image
def _get_image_size(img: Tensor) -> Tuple[int]: r""" Returns the size of an image as [width, height]. Args: img (PIL Image or Tensor): The image to be checked. Returns: width (int): Width of the image. height (int): Height of the image. """ try: width, height = F._get_image_size(img) except: width, height = F.get_image_size(img) return width, height def _totuple(size: Union[int, Tuple[int, int]]) -> Tuple[int, int]: r""" Transform input size from int to tuple. Args: size (Union[int, Tuple[int, int]]): Size to be checked and transformed. """ if isinstance(size, int): size = (size, size) return size