I am working on an image retrieval project. To make the model fairer, I want to construct batches that contain:

  • 5 images per class, and
  • 75 images per batch

I have 300 classes in total in my dataset, so each batch can contain images from only 15 classes (75 / 5 = 15). The data is balanced, meaning every class has the same number of images. I am using PyTorch.

I have created a PyTorch dataset and want to add this functionality to my ImageFolderLoader class, whose code is below.

import os

from PIL import Image
from torch.utils.data import Dataset

IMG_EXTENSIONS = [
   '.jpg', '.JPG', '.jpeg', '.JPEG',
   '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def find_classes(dir):
    classes = os.listdir(dir)
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    # assumes class folders named like '<number>.<name>'; keep only the name part
    classes = [clss.split('.')[1] for clss in classes]
    return classes, class_to_idx

def make_dataset(dir, class_to_idx):
    images = []
    for target in os.listdir(dir):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for filename in os.listdir(d):
            if is_image_file(filename):
                path = '{0}/{1}'.format(target, filename)
                item = (path, class_to_idx[target])
                images.append(item)
                
    return images

def default_loader(path):
    return Image.open(path).convert('RGB')

class ImageFolderLoader(Dataset):
    def __init__(self, root, transform=None, loader=default_loader,):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.loader = loader
        
    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(os.path.join(self.root, path))
        if self.transform is not None:
            img = self.transform(img)
            
        return img, target
    
    def __len__(self):
        return len(self.imgs)

If there is a way to do this, please let me know.

Edit: for anyone who wants to see the solution, I added it below after solving this problem.

2 Answers

I solved the problem by passing a batch_sampler to the DataLoader. For this I used the pytorch-balanced-sampler GitHub project, which allows a lot of customization of the batch sampler; the repo is worth a visit.
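
For reference, DataLoader's batch_sampler argument accepts an object that yields one list of dataset indices per batch; DataLoader then calls __getitem__ for each index and collates the results. When batch_sampler is given, batch_size, shuffle, sampler and drop_last must be left at their defaults. As a minimal sketch of the 15-classes-with-5-images-each scheme without pytorch-balanced-sampler (this is not what the answer uses; RandomClassBatchSampler and its parameters are hypothetical names), assuming labels holds the class index of every dataset item in order:

```python
import random


class RandomClassBatchSampler:
    """Hypothetical batch sampler: every batch holds n_classes distinct classes
    with n_per_class distinct images each. Batches are drawn independently, so
    an image can repeat across batches and some images may never be drawn."""

    def __init__(self, labels, n_classes=15, n_per_class=5, n_batches=100, seed=None):
        # group dataset indices by their class label
        self.class_idxs = {}
        for idx, label in enumerate(labels):
            self.class_idxs.setdefault(label, []).append(idx)
        if len(self.class_idxs) < n_classes:
            raise ValueError("not enough classes to fill one batch")
        if min(len(v) for v in self.class_idxs.values()) < n_per_class:
            raise ValueError("some class has fewer than n_per_class images")
        self.classes = sorted(self.class_idxs)
        self.n_classes = n_classes
        self.n_per_class = n_per_class
        self.n_batches = n_batches
        self.rng = random.Random(seed)

    def __iter__(self):
        for _ in range(self.n_batches):
            batch = []
            # sample classes without replacement, then images within each class
            for cls in self.rng.sample(self.classes, self.n_classes):
                batch.extend(self.rng.sample(self.class_idxs[cls], self.n_per_class))
            yield batch

    def __len__(self):
        return self.n_batches
```

With the question's dataset, the labels could come from train_data.imgs, e.g. DataLoader(train_data, batch_sampler=RandomClassBatchSampler([t for _, t in train_data.imgs], n_batches=146)).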

My custom dataset:

import os

from PIL import Image
from torch.utils.data import Dataset

IMG_EXTENSIONS = [
   '.jpg', '.JPG', '.jpeg', '.JPEG',
   '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def find_classes(dir):
    classes = os.listdir(dir)
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    classes = [clss.split('.')[1] for clss in classes]
    return classes, class_to_idx

def make_dataset(dir, class_to_idx):
    images = []
    for target in os.listdir(dir):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for filename in os.listdir(d):
            if is_image_file(filename):
                path = '{0}/{1}'.format(target, filename)
                item = (path, class_to_idx[target])
                images.append(item)
        
    data_dict = {}
    for item in images:
        cls = item[1]
        if cls not in data_dict.keys():
            data_dict[cls] = [item]
        else:
            data_dict[cls].append(item) 
        
    return images,data_dict

def default_loader(path):
    return Image.open(path).convert('RGB')

class ImageFolderLoader(Dataset):
    def __init__(self, root, transform=None, loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs,instance_labels = make_dataset(root, class_to_idx)
        
        
        self.instance_labels = instance_labels
        
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.loader = loader
        
    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(os.path.join(self.root, path))
        if self.transform is not None:
            img = self.transform(img)
            
        return img, target

    def __len__(self):
        return len(self.imgs)

Then I used the SamplerFactory class from the pytorch-balanced-sampler project; see that repository for an explanation of the parameters:

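from torch.utils.data import DataLoader

train_data = ImageFolderLoader(root=TRAIN_PATH, transform=transform)
# my_list is not shown in the answer; class_idxs takes one list of dataset indices per class
my_list = [[] for _ in train_data.class_to_idx]
for idx, (_, target) in enumerate(train_data.imgs):
    my_list[target].append(idx)
batch_sampler = SamplerFactory().get(
    class_idxs=my_list,
    batch_size=75,
    n_batches=146,
    alpha=1,
    kind="fixed"
)
train_loader = DataLoader(train_data, batch_sampler=batch_sampler)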
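
Since the parameter meanings are left to the repository, it may be worth confirming that the batches really contain 15 classes with 5 images each before training. A small check along these lines (has_class_layout is a hypothetical helper, not part of PyTorch or pytorch-balanced-sampler) takes the label list of one batch:

```python
from collections import Counter


def has_class_layout(labels, n_classes=15, n_per_class=5):
    """Return True if labels contain exactly n_classes distinct classes,
    each appearing exactly n_per_class times."""
    counts = Counter(labels)
    return len(counts) == n_classes and all(c == n_per_class for c in counts.values())
```

With the default collate function the labels arrive as a tensor, so a loop like for imgs, labels in train_loader: assert has_class_layout(labels.tolist()) would flag any batch with a different layout.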

There are a few open-ended questions about how this should be implemented. For instance, do you want each class to be equally represented regardless of that class's actual frequency? Note that this may improve performance on minority classes at the expense of performance on majority classes.

Also, do you want each example to be used at most once per epoch, or at least once per epoch?

In any case, this will likely be difficult to accomplish with the standard __getitem__ method, because it returns one example with no regard for the other examples returned in the same batch. You'll likely need a custom data-loading object to ensure good class distribution and usage properties, which is a bit unfortunate because PyTorch's DataLoader and Dataset objects work together quite nicely and efficiently for most simple use cases. Perhaps someone else has a solution that uses these objects.
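
One way to keep DataLoader and Dataset is DataLoader's batch_sampler argument: the sampler decides which indices go together, while __getitem__ still loads items one at a time. As a sketch for the "at most once per epoch" variant, assuming balanced classes and integer labels (EpochClassBatchSampler is a hypothetical name, not a PyTorch class): each class's indices are shuffled, and each round takes the next 5 unused indices from every class and groups the classes into batches of 15.

```python
import random


class EpochClassBatchSampler:
    """Hypothetical batch sampler: each batch has n_classes distinct classes with
    n_per_class images each, and no index is repeated within one epoch.
    Assumes (roughly) balanced classes; leftovers that don't fill a chunk or a
    batch are skipped for that epoch."""

    def __init__(self, labels, n_classes=15, n_per_class=5, seed=None):
        self.class_idxs = {}
        for idx, label in enumerate(labels):
            self.class_idxs.setdefault(label, []).append(idx)
        self.n_classes = n_classes
        self.n_per_class = n_per_class
        self.rng = random.Random(seed)
        # one "round" takes one chunk of n_per_class indices from every class
        smallest = min(len(v) for v in self.class_idxs.values())
        self.n_rounds = smallest // n_per_class
        self.batches_per_round = len(self.class_idxs) // n_classes

    def __iter__(self):
        # reshuffle every class's indices at the start of each epoch
        shuffled = {c: self.rng.sample(v, len(v)) for c, v in self.class_idxs.items()}
        classes = sorted(shuffled)
        k = self.n_per_class
        for r in range(self.n_rounds):
            # a new random grouping of classes into batches in every round
            self.rng.shuffle(classes)
            for b in range(self.batches_per_round):
                batch = []
                for c in classes[b * self.n_classes:(b + 1) * self.n_classes]:
                    batch.extend(shuffled[c][r * k:(r + 1) * k])
                yield batch

    def __len__(self):
        return self.n_rounds * self.batches_per_round
```

For the question's 300 classes, each round yields 300 / 15 = 20 batches, and an epoch has (images per class // 5) rounds.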

Here's a sketch that samples randomly for every batch (with replacement across batches), so there's no guarantee that every example will be used. It also loads images in a plain Python loop, so a parallelized version would likely be faster.

import os
import random

import torch
from torch.utils.data import Dataset

# find_classes, make_dataset and default_loader are the functions from the question

class ImageFolderLoader(Dataset):
    def __init__(self, root, transform=None, loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)

        # currently, imgs items are of the form (path, class)

        # each class is the key for a list of all items belonging to that class
        data_dict = {}
        for item in imgs:
            data_dict.setdefault(item[1], []).append(item)
        self.data_dict = data_dict

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.loader = loader

    def get_batch(self):
        img_batch = []
        label_batch = []

        # pick 15 distinct classes, then 5 distinct images from each (75 in total)
        classes = random.sample(list(self.data_dict.keys()), 15)
        for cls in classes:
            class_data = self.data_dict[cls]
            for path, _ in random.sample(class_data, 5):
                img = self.loader(os.path.join(self.root, path))
                if self.transform is not None:
                    img = self.transform(img)
                img_batch.append(img)
                label_batch.append(cls)

        # torch.stack needs tensors, so transform must end with e.g. ToTensor()
        img_batch = torch.stack(img_batch)
        # labels are plain ints: build a tensor directly instead of stacking
        label_batch = torch.tensor(label_batch)

        return img_batch, label_batch

    def __len__(self):
        return len(self.imgs)

2 Comments

More of a sketch of how you might do it than polished code, sorry. Obviously there are lots of unanswered questions that prevent a full implementation.
I have already made these edits, but the code doesn't run correctly; it gives me errors. There are logical errors in your code.
