I am trying to train a classifier on the MNIST dataset using pytorch-lightning.

import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split


class MNISTData(pl.LightningDataModule):

    def __init__(self, data_dir='./', batch_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def download(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
        return mnist_train

    def val_dataloader(self):
        mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
        return mnist_val

    def test_dataloader(self):
        mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)

After calling MNISTData().setup(), I got MNISTData().mnist_train, MNISTData().mnist_val, and MNISTData().mnist_test, whose lengths are 55000, 5000, and 10000 respectively, each of type torch.utils.data.dataset.Subset.

But when I call the dataloaders via MNISTData().train_dataloader(), MNISTData().val_dataloader(), and MNISTData().test_dataloader(), I only get DataLoaders with 215, 20, and None entries in them.

Does anyone know the reason, or how to fix the problem?

  • Where is the code that returns 215, 20, None? BTW, there is no return in test_dataloader(...). Commented Sep 14, 2021 at 1:44
  • After correcting the return of test_dataloader() I still have an issue. Commented Sep 14, 2021 at 2:26
  • Could you try a = MNISTData(); a.setup(); b, c, d = a.train_dataloader(), a.val_dataloader(), a.test_dataloader() and check the variables? Commented Sep 14, 2021 at 2:28

3 Answers


As I said in the comments, and as Ivan posted in his answer, there was a missing return statement:

def test_dataloader(self):
    mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
    return mnist_test  # <<< missing return

As per your comment, if we try:

a = MNISTData()
# skip download, assuming you already have it
a.setup()

b, c, d = a.train_dataloader(), a.val_dataloader(), a.test_dataloader()
# len(b)=215, len(c)=20, len(d)=40

I think your question is why the lengths of b, c, and d differ from the lengths of the datasets. The answer is that the len() of a DataLoader is the number of batches, not the number of samples, therefore:

import math

batch_size = 256
math.ceil(55000 / batch_size)  # len(b) == 215
math.ceil(5000 / batch_size)   # len(c) == 20
math.ceil(10000 / batch_size)  # len(d) == 40

BTW, we use math.ceil because DataLoader has drop_last=False by default; with drop_last=True it would be math.floor.
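To see that in action without MNIST, here is a minimal sketch with a dummy dataset (the TensorDataset below is just an illustration, not part of your code):

import torch
from torch.utils.data import DataLoader, TensorDataset

# dummy dataset with the same size as the MNIST train split
ds = TensorDataset(torch.zeros(55000, 1))

print(len(DataLoader(ds, batch_size=256)))                  # 215 = ceil(55000 / 256)
print(len(DataLoader(ds, batch_size=256, drop_last=True)))  # 214 = floor(55000 / 256)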




Your test_dataloader function is missing a return statement!

def test_dataloader(self):
    mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
    return mnist_test

>>> ds = MNISTData()
>>> ds.download()
>>> ds.setup()

Then:

>>> [len(subset) for subset in \
          (ds.mnist_train, ds.mnist_val, ds.mnist_test)]
[55000, 5000, 10000]


>>> [len(loader) for loader in \
         (ds.train_dataloader(), ds.val_dataloader(), ds.test_dataloader())]
[215, 20, 40]
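With the fixed test_dataloader(), a quick sanity check shows no samples are lost: the per-batch sizes sum back to the number of samples (39 full batches of 256 plus one batch of 16):

>>> sum(images.shape[0] for images, labels in ds.test_dataloader())
10000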



Others pointing out that you are missing a return in test_dataloader() are certainly correct.

Judging by how the question is framed, it seems you are confused about the difference between the length of a Dataset and the length of a DataLoader.

len(Dataset(..)) returns the number of data samples in your dataset,

whereas len(DataLoader(ds, ...)) returns the number of batches; that depends on the batch_size=... you requested, whether you drop the last incomplete batch (drop_last=...), etc. The exact calculations are provided correctly by @Berriel.
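A tiny self-contained example (a toy dataset, not the MNIST code above) makes the distinction obvious:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10).float())  # 10 samples
loader = DataLoader(ds, batch_size=4)

print(len(ds))      # 10 -> number of samples
print(len(loader))  # 3  -> number of batches, ceil(10 / 4)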

