I am trying to train a classifier on the MNIST dataset using PyTorch Lightning.
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST, SVHN
from torch.utils.data import DataLoader, random_split
class MNISTData(pl.LightningDataModule):
    """Lightning data module serving MNIST train/validation/test splits.

    Note that ``len()`` of a ``DataLoader`` is the number of *batches*, not
    the number of samples: with the default ``batch_size=256`` the 55000
    training samples yield ceil(55000 / 256) = 215 batches and the 5000
    validation samples yield ceil(5000 / 256) = 20 batches.
    """

    def __init__(self, data_dir='./', batch_size=256):
        """Store the dataset location, batch size and per-sample transform.

        Args:
            data_dir: Directory passed to ``MNIST`` as its root.
            batch_size: Batch size used by every dataloader of this module.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def download(self):
        """Download both the train and test MNIST files into ``data_dir``."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def prepare_data(self):
        """Lightning hook for one-time data preparation; fetches the files.

        ``download`` is not a name the Trainer calls, so this hook delegates
        to it; ``download()`` stays available for direct use.
        """
        self.download()

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``'fit'`` builds the train/val split, ``'test'`` builds the
                test set, and ``None`` builds both.
        """
        if stage == 'fit' or stage is None:
            mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
            # 60000 training images are split into 55000 train / 5000 val.
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a ``DataLoader`` over the training split."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation split."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the test set."""
        # The original built the loader but never returned it, so callers
        # received ``None``.
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
After calling MNISTData().setup(), I get MNISTData().mnist_train, MNISTData().mnist_val, and MNISTData().mnist_test, whose lengths are 55000, 5000, and 10000, with type torch.utils.data.dataset.Subset.
But when I call the dataloaders via MNISTData().train_dataloader, MNISTData().val_dataloader, and MNISTData().test_dataloader, I only get DataLoaders with 215, 20, and None items in them, respectively.
Does anyone know the reason, or how to fix the problem?
215, 20, None? BTW, there is no `return` in the `test_dataloader(...)`. Regarding the `return` of `test_dataloader()`: I still have an issue. `a = MNISTData()`, `a.setup()`, `b, c, d = a.train_dataloader(), a.val_dataloader(), a.test_dataloader()` — could you try the above code and check the variables?