The PyTorch tutorial for data loading and processing is quite specific to one example. Could someone help me with what the function should look like for more generic, simple loading of images?

Tutorial: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

My Data:

I have the MNIST dataset as PNG files in the following folder structure. (I know I can just use the built-in dataset class, but this is purely to see how to load simple images into PyTorch without CSVs or complex features.)

The folder name is the label, and the images are 28x28 grayscale PNGs; no transformations are required.

data
    train
        0
            3.png
            5.png
            13.png
            23.png
            ...
        1
            3.png
            10.png
            11.png
            ...
        2
            4.png
            13.png
            ...
        3
            8.png
            ...
        4
            ...
        5
            ...
        6
            ...
        7
            ...
        8
            ...
        9
            ...
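Because the label is simply the name of the parent folder, collecting (path, label) pairs for this layout needs nothing beyond the standard library. A minimal sketch, assuming every subfolder of the root is named after an integer label and holds PNG files; the helper name `list_labelled_images` is hypothetical:

```python
import os


def list_labelled_images(root):
    """Return (path, label) pairs for a root/<label>/<image>.png layout.

    Hypothetical helper: assumes every subfolder name is an integer label.
    """
    samples = []
    # Sort folder and file names so the order is deterministic across runs
    for label_name in sorted(os.listdir(root)):
        label_dir = os.path.join(root, label_name)
        if not os.path.isdir(label_dir):
            continue  # skip stray files sitting directly in root
        for file_name in sorted(os.listdir(label_dir)):
            if file_name.lower().endswith(".png"):
                samples.append((os.path.join(label_dir, file_name), int(label_name)))
    return samples
```

Calling `list_labelled_images("data/train")` would pair each image path with the digit of its folder; a Dataset then only has to open the image at a given index.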

2 Answers

Here's what I did for PyTorch 0.4.1 (it should still work in 1.3):

import torch
import torchvision


def load_dataset():
    data_path = 'data/train/'
    # ImageFolder assigns each class subfolder (0, 1, ..., 9) its own label
    train_dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=torchvision.transforms.ToTensor()
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=64,
        num_workers=0,
        shuffle=True
    )
    return train_loader


for batch_idx, (data, target) in enumerate(load_dataset()):
    # train network
    pass

3 Comments

How is the class label specified in your load_dataset() function?
It's generated by ImageFolder from the class folder names: pytorch.org/docs/stable/torchvision/…
For MNIST it may be necessary to use transforms.Grayscale(), since ImageFolder's default loader converts every image to RGB: test_dataset = torchvision.datasets.ImageFolder(root=data_path, transform=transforms.Compose([transforms.Grayscale(), transforms.ToTensor()]))

If you're using MNIST, there's already a built-in dataset in PyTorch via torchvision.
You could do:

import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd  # used by the custom Dataset further down

# MNIST images have a single channel, so Normalize takes one mean and one std
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

mnistTrainSet = torchvision.datasets.MNIST(root='./data', train=True,
                                           download=True, transform=transform)
mnistTrainLoader = torch.utils.data.DataLoader(mnistTrainSet, batch_size=16,
                                               shuffle=True, num_workers=2)
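Normalize applies x' = (x - mean) / std to each channel, with one mean and one std value per channel. Since ToTensor() already scales pixel values to [0, 1], choosing mean = std = 0.5 maps them to [-1, 1]:

```latex
x' = \frac{x - 0.5}{0.5} = 2x - 1, \qquad x \in [0, 1] \;\Longrightarrow\; x' \in [-1, 1]
```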

If you want to generalize to a directory of images listed in a label text file, you could write your own Dataset (same imports as above, plus os and PIL's Image):

import os
from PIL import Image

# Assumed definition: transformMnistm is not shown in the original answer.
# MNIST-M images are RGB, hence one mean/std value per each of the 3 channels.
transformMnistm = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


class mnistmTrainingDataset(torch.utils.data.Dataset):

    def __init__(self, text_file, root_dir, transform=transformMnistm):
        """
        Args:
            text_file (string): path to a text file of "<filename> <label>" lines
            root_dir (string): directory with all train images
        """
        # header=None: assumes the label file has no header row
        self.name_frame = pd.read_csv(text_file, sep=" ", header=None, usecols=range(1))
        self.label_frame = pd.read_csv(text_file, sep=" ", header=None, usecols=range(1, 2))
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.name_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.name_frame.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')  # ensure 3 channels
        image = self.transform(image)
        labels = self.label_frame.iloc[idx, 0]
        sample = {'image': image, 'labels': labels}

        return sample


mnistmTrainSet = mnistmTrainingDataset(text_file='Downloads/mnist_m/mnist_m_train_labels.txt',
                                       root_dir='Downloads/mnist_m/mnist_m_train')

mnistmTrainLoader = torch.utils.data.DataLoader(mnistmTrainSet, batch_size=16, shuffle=True, num_workers=2)

You can then iterate over it like this:

for i_batch,sample_batched in enumerate(mnistmTrainLoader,0):
    print("training sample for mnist-m")
    print(i_batch,sample_batched['image'],sample_batched['labels'])

There are a bunch of ways to generalize PyTorch image dataset loading; the method I know of is subclassing torch.utils.data.Dataset.

2 Comments

Loading the same file and accessing its two columns two times independently is highly inefficient!
Yes. It's better to use a single DataFrame: load it with index_col=False in read_csv to get a numeric index, then use self.df.at[idx, "filename"] and self.df.at[idx, "label"] in __getitem__.
