Following my last post, I am now trying to implement a subclass of the torchvision.datasets.ImageFolder class. The following code returns an error ("name 'default_loader' is not defined"), and I can't figure out why. Will you please help me?

    class ExtendingImageFolder(torchvision.datasets.ImageFolder):
        def __init__(self, root, transform=None, target_transform=None, loader=default_loader):
            super().__init__(root, transform, target_transform, loader)

When I remove the default values (the "None" and "default_loader") and write it like this:

    class ExtendingImageFolder(torchvision.datasets.ImageFolder):
        def __init__(self, root, transform, target_transform, loader):
            super().__init__(root, transform, target_transform, loader)

I get an error about missing arguments when I try to create an instance of this class, for example:

    JJ = ExtendingImageFolder(root='C:/', transform=transform)

What am I doing wrong here?

Thanks in advance!

benjaminplanche
Dr. John

1 Answer

default_loader() is a function defined in torchvision/datasets/folder.py, alongside ImageFolder and other folder-based dataset helpers.

However, it is not exported in torchvision/datasets/__init__.py (unlike ImageFolder), so the bare name default_loader is undefined in your script. You can still import it directly with "from torchvision.datasets.folder import default_loader", which should resolve the NameError.

benjaminplanche