
I am using the EMNIST dataset via the PyTorch datasets together with a neural network that expects a 3-channel input.

I would like to use PyTorch transforms to copy my single-channel grayscale images into 3 channels so I can use the same net for both single-channel and 3-channel data.

Which transform can I use? Or how would I extend PyTorch transforms as suggested here: https://stackoverflow.com/a/50530923/18895809

Levin Kobelke

3 Answers


You can achieve this by using torchvision.transforms.Grayscale with the num_output_channels parameter set to 3.

Example usage:

trafos = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(num_output_channels=3),
    torchvision.transforms.ToTensor(),
])

The functional interface (i.e. torchvision.transforms.functional.to_grayscale(img, num_output_channels=3)) also supports the same parameter.

Tyathalae

For training I recommend using img.expand(...) instead of concatenation with torch.cat, since expand does not allocate new memory (see here). While doing so, keep in mind that a 3-channel (possibly RGB) image is structurally quite different from a grayscale one, so I suspect you may see some degradation in your results.

import torch

img = torch.zeros((1, 200, 200))        # single-channel image, shape (1, 200, 200)
img = img.expand(3, *img.shape[1:])     # view with 3 channels, no copy
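A quick way to convince yourself that expand returns a view rather than a copy is to compare data pointers (a small check, assuming CPU tensors):

```python
import torch

img = torch.zeros((1, 200, 200))

expanded = img.expand(3, *img.shape[1:])      # view: no new memory
stacked = torch.cat([img, img, img], dim=0)   # copy: new memory

# The expanded tensor shares storage with the original...
print(expanded.data_ptr() == img.data_ptr())  # True
# ...while the concatenated one does not.
print(stacked.data_ptr() == img.data_ptr())   # False
# Both contain the same values.
print(torch.equal(expanded, stacked))         # True
```

One caveat: because expand produces a stride-0 view along the new dimension, in-place writes to it would alias across all three channels; call .clone() first if you need an independently writable tensor.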
D. ACAR

You can simply use:

import torch

img = torch.zeros((1, 200, 200))            # img.shape  = (1, 200, 200)
img2 = torch.cat([img, img, img], dim=0)    # img2.shape = (3, 200, 200)

If you prefer, you can even code your own transformation based on the above snippet. You just need to create a simple callable class, as described in the tutorial at: https://pytorch.org/tutorials/recipes/recipes/custom_dataset_transforms_loader.html
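A minimal sketch of such a callable transform (the class name RepeatChannels is my own; it assumes the input is already a (1, H, W) tensor, e.g. produced by ToTensor):

```python
import torch

class RepeatChannels:
    """Repeat a single-channel image tensor along the channel dimension."""

    def __init__(self, num_channels=3):
        self.num_channels = num_channels

    def __call__(self, img):
        # img: tensor of shape (1, H, W) -> (num_channels, H, W)
        return img.expand(self.num_channels, *img.shape[1:])

transform = RepeatChannels(3)
out = transform(torch.zeros((1, 28, 28)))
print(out.shape)  # torch.Size([3, 28, 28])
```

Placed after ToTensor in a torchvision.transforms.Compose, this would do the channel replication as part of the dataset's transform pipeline.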

Deusy94