import random
from typing import Union, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms as T
from torchvision.transforms import functional as F
from .functional import _totuple, _get_image_size
# Public API of this module: the paired (image, target) segmentation transforms defined below.
__all__ = [
"SegCompose",
"SegRandomHorizontalFlip",
"SegNormalize",
"SegCenterCrop",
"SegRandomCrop",
"SegResize",
"SegRandomResize",
"SegPILToTensor",
"SegConvertImageDtype"
]
class SegCompose(T.Compose):
    r"""
    Chain several segmentation transforms into one callable.

    Each transform in ``self.transforms`` is invoked in order with both the image and the
    target, and must return the updated ``(image, target)`` pair that is fed to the next one.
    """

    def __call__(self, image, target):
        # Thread the pair through every stage; each stage sees the previous stage's output.
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target
class SegRandomHorizontalFlip(nn.Module):
    r"""
    Randomly mirror an image and its label left-to-right, with a given probability (`Link`_).

    The image and the label are always flipped together, so they stay aligned.
    If they are torch Tensors, they are expected to have [..., H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        p (float): probability of the image and label being flipped. Default: 0.5.

    .. _Link:
        https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#RandomHorizontalFlip
    """

    def __init__(self, p: float = 0.5) -> None:
        super().__init__()
        self.p = p

    def forward(self, image, target):
        # One draw decides for both inputs; the draw comes from Python's `random` module.
        should_flip = random.random() < self.p
        if not should_flip:
            return image, target
        return F.hflip(image), F.hflip(target)
class SegNormalize(nn.Module):
    r"""
    Normalize the input image of a segmentation sample; the target is passed through untouched.

    Args:
        mean (List[float]): List of means for each channel.
        std (List[float]): List of standard deviations for each channel.
    """

    def __init__(self, mean: List[float], std: List[float]) -> None:
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, image, target):
        # Only the image is normalized; the target is returned exactly as received.
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target
class SegCenterCrop(nn.Module):
    r"""
    Center crop for segmentation: the same centered crop is taken from the image and the target.

    Args:
        size (Union[int, List[int]]): Height and width of the crop box.
    """

    def __init__(self, size: Union[int, List[int]]):
        super().__init__()
        self.size = size

    def forward(self, image, target):
        # `self.size` is handed to `F.center_crop` as given, for both inputs.
        cropped_image = F.center_crop(image, self.size)
        cropped_target = F.center_crop(target, self.size)
        return cropped_image, cropped_target
class SegRandomCrop(nn.Module):
    r"""
    Random cropping for segmentation.

    One crop window is sampled and applied to both the image and the target. When the image
    is narrower or shorter than the requested crop, both are padded first: the image with 0
    and the target with 255.

    Args:
        size (Union[int, Tuple[int, int]]): Desired output size of the crop.
    """

    def __init__(self, size: Union[int, Tuple[int, int]]) -> None:
        super().__init__()
        self.size = _totuple(size)

    @staticmethod
    def _pad_pair(image, target, padding):
        # Constant-pad both inputs with the same padding: fill 0 for the image, 255 for the target.
        padded_image = F.pad(image, padding, fill=0, padding_mode="constant")
        padded_target = F.pad(target, padding, fill=255, padding_mode="constant")
        return padded_image, padded_target

    def forward(self, image, target):
        width, height = _get_image_size(image)

        # self.size[1] is compared against the width, self.size[0] against the height.
        missing_width = self.size[1] - width
        if missing_width > 0:
            image, target = self._pad_pair(image, target, [missing_width, 0])

        missing_height = self.size[0] - height
        if missing_height > 0:
            image, target = self._pad_pair(image, target, [0, missing_height])

        # Sample the window once (from the possibly padded image) and reuse it for the target.
        top, left, crop_height, crop_width = T.RandomCrop.get_params(image, self.size)
        cropped_image = F.crop(image, top, left, crop_height, crop_width)
        cropped_target = F.crop(target, top, left, crop_height, crop_width)
        return cropped_image, cropped_target
class SegResize(nn.Module):
    r"""
    Resize for segmentation.

    The image is resized with the default interpolation of ``F.resize``; the target is
    resized with nearest-neighbour interpolation.

    Args:
        size (Union[int, Tuple[int, int]]): Desired output size.
    """

    def __init__(self, size: Union[int, Tuple[int, int]]) -> None:
        super().__init__()
        self.size = _totuple(size)

    def forward(self, image, target):
        resized_image = F.resize(image, self.size)
        # Nearest-neighbour picks existing target values instead of blending neighbours.
        resized_target = F.resize(
            target, self.size, interpolation=T.InterpolationMode.NEAREST
        )
        return resized_image, resized_target
class SegRandomResize(nn.Module):
    r"""
    Random resize for segmentation.

    On every call each of the two output dimensions is drawn independently and uniformly
    from the inclusive integer range between the minimum and maximum size, and the image
    and target are resized to that common size.

    Args:
        min_size (Union[int, Tuple[int, int]]): Desired minimum output size.
        max_size (Optional[Union[int, Tuple[int, int]]]): Desired maximum output size. Default: `None`.
            Any falsy value (including the default) makes the maximum equal to the minimum.
    """

    def __init__(self, min_size: Union[int, Tuple[int, int]],
                 max_size: Optional[Union[int, Tuple[int, int]]] = None) -> None:
        super().__init__()
        self.min_size = _totuple(min_size)
        if max_size:
            self.max_size = _totuple(max_size)
        else:
            # No usable maximum given: the range collapses to the minimum size.
            self.max_size = self.min_size

    def forward(self, image, target):
        # `random.randint` includes both endpoints.
        first_dim = random.randint(self.min_size[0], self.max_size[0])
        second_dim = random.randint(self.min_size[1], self.max_size[1])
        size = (first_dim, second_dim)

        resized_image = F.resize(image, size)
        # The target uses nearest-neighbour so its values are picked, not blended.
        resized_target = F.resize(
            target, size, interpolation=T.InterpolationMode.NEAREST
        )
        return resized_image, resized_target
class SegPILToTensor(nn.Module):
    r"""
    PIL to Tensor for segmentation.

    The image goes through ``F.pil_to_tensor``; the target is converted via a NumPy array
    into a ``torch.int64`` tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, image, target):
        image_tensor = F.pil_to_tensor(image)
        # The target is converted through NumPy and cast to int64.
        target_array = np.array(target)
        target_tensor = torch.as_tensor(target_array, dtype=torch.int64)
        return image_tensor, target_tensor
class SegConvertImageDtype(nn.Module):
    r"""
    Convert image dtype for segmentation.

    Only the image is converted (via ``F.convert_image_dtype``); the target is returned as is.

    Args:
        dtype: Desired dtype of the output image, forwarded to ``F.convert_image_dtype``.
    """

    def __init__(self, dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(self, image, target):
        converted = F.convert_image_dtype(image, self.dtype)
        return converted, target