"""CIFAR-10 and CIFAR-100 classification datasets (module ``extorch.vision.dataset.cifar``)."""

from torchvision import datasets, transforms

from extorch.vision.dataset import CVClassificationDataset


# Standard transformation for CIFAR10 datasets.
# Normalization statistics: one mean / one standard deviation per image channel.
CIFAR10_MEAN = [0.49139968, 0.48215827, 0.44653124]
CIFAR10_STD = [0.24703233, 0.24348505, 0.26158768]

# Training pipeline: augmentation (pad by 4 pixels, random 32x32 crop, random
# horizontal flip), then conversion to a tensor and per-channel normalization.
CIFAR10_TRAIN_TRANSFORM = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and the same
# normalization used for training.
CIFAR10_TEST_TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)


class CIFAR10(CVClassificationDataset):
    """CIFAR-10 image classification dataset (10 classes).

    Builds the train and test splits with ``torchvision.datasets.CIFAR10`` and
    stores them in ``self.datasets["train"]`` and ``self.datasets["test"]``.

    Args:
        data_dir: Root directory holding the CIFAR-10 files (and the download
            target when ``download`` is ``True``).
        train_transform: Transform for the training split, handed to the base
            class. Defaults to ``CIFAR10_TRAIN_TRANSFORM``.
        test_transform: Transform for the test split, handed to the base
            class. Defaults to ``CIFAR10_TEST_TRANSFORM``.
        download: Keyword-only. Forwarded to torchvision; when ``True`` (the
            default, matching the previous behavior) the dataset is fetched
            into ``data_dir`` if it is not already there. Pass ``False`` on
            machines without network access.
    """

    def __init__(
        self,
        data_dir: str,
        train_transform: transforms.Compose = CIFAR10_TRAIN_TRANSFORM,
        test_transform: transforms.Compose = CIFAR10_TEST_TRANSFORM,
        *,
        download: bool = True,
    ) -> None:
        # NOTE(review): the base class is assumed to populate ``self.data_dir``,
        # ``self.transforms`` and ``self.datasets`` from these arguments —
        # confirm against CVClassificationDataset.
        super().__init__(data_dir, train_transform, test_transform)
        self.datasets["train"] = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=download,
            transform=self.transforms["train"],
        )
        self.datasets["test"] = datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=download,
            transform=self.transforms["test"],
        )
        # Presumably read by the base class to report the number of labels.
        self._num_classes = 10


# Standard transformation for CIFAR100 datasets.
# Normalization statistics: one mean / one standard deviation per image channel.
CIFAR100_MEAN = [0.5070751592371322, 0.4865488733149497, 0.44091784336703466]
CIFAR100_STD = [0.26733428587924063, 0.25643846291708833, 0.27615047132568393]

# Training pipeline: augmentation (pad by 4 pixels, random 32x32 crop, random
# horizontal flip), then conversion to a tensor and per-channel normalization.
CIFAR100_TRAIN_TRANSFORM = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and the same
# normalization used for training.
CIFAR100_TEST_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])


class CIFAR100(CVClassificationDataset):
    """CIFAR-100 image classification dataset (100 classes).

    Builds the train and test splits with ``torchvision.datasets.CIFAR100`` and
    stores them in ``self.datasets["train"]`` and ``self.datasets["test"]``.

    Args:
        data_dir: Root directory holding the CIFAR-100 files (and the download
            target when ``download`` is ``True``).
        train_transform: Transform for the training split, handed to the base
            class. Defaults to ``CIFAR100_TRAIN_TRANSFORM``.
        test_transform: Transform for the test split, handed to the base
            class. Defaults to ``CIFAR100_TEST_TRANSFORM``.
        download: Keyword-only. Forwarded to torchvision; when ``True`` (the
            default, matching the previous behavior) the dataset is fetched
            into ``data_dir`` if it is not already there. Pass ``False`` on
            machines without network access.
    """

    def __init__(
        self,
        data_dir: str,
        train_transform: transforms.Compose = CIFAR100_TRAIN_TRANSFORM,
        test_transform: transforms.Compose = CIFAR100_TEST_TRANSFORM,
        *,
        download: bool = True,
    ) -> None:
        # NOTE(review): the base class is assumed to populate ``self.data_dir``,
        # ``self.transforms`` and ``self.datasets`` from these arguments —
        # confirm against CVClassificationDataset.
        super().__init__(data_dir, train_transform, test_transform)
        self.datasets["train"] = datasets.CIFAR100(
            root=self.data_dir,
            train=True,
            download=download,
            transform=self.transforms["train"],
        )
        self.datasets["test"] = datasets.CIFAR100(
            root=self.data_dir,
            train=False,
            download=download,
            transform=self.transforms["test"],
        )
        # Presumably read by the base class to report the number of labels.
        self._num_classes = 100