
I want to use the pre-trained model MobileNetV2 to classify the Binary Alpha Digits data. However, this data comes in shape (20, 16, 1) (grayscale, one channel) and not in the required shape (20, 16, 3) (three RGB channels). I also have to resize, since 20x16 is not a valid input size either, but I know how to do that. So my question is: how can I convert a 1-channel grayscale dataset (DatasetV1Adapter) into a 3-channel one?

My code so far:

import tensorflow as tf
import os
import PIL
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow_hub as hub

from tensorflow.keras.models import Sequential

from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import keras_preprocessing
from keras_preprocessing import image

from tensorflow.python.keras.utils.version_utils import training
from tensorflow.keras.optimizers import RMSprop
(raw_train, raw_test, raw_validation), metadata = tfds.load(
    'binary_alpha_digits',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)
IMG_SIZE = 96

def format_example(image, label):
  image = tf.cast(image, tf.float32)
  image = image*1/255.0
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000

train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)

When I check train_batches I get the output:

<DatasetV1Adapter shapes: ((None, 96, 96, 1), (None,)), types: (tf.float32, tf.int64)>

This gives an error later, when I try to fit the model:

ValueError: The input must have 3 channels; got `input_shape=(96, 96, 1)`

So according to this post I tried:

def load_image_into_numpy_array(image):
    # The function supports only grayscale images
    # assert len(image.shape) == 2, "Not a grayscale input image" 
    last_axis = -1
    dim_to_repeat = 2
    repeats = 3
    grscale_img_3dims = np.expand_dims(image, last_axis)
    training_image = np.repeat(grscale_img_3dims, repeats, dim_to_repeat).astype('uint8')
    assert len(training_image.shape) == 3
    assert training_image.shape[-1] == 3
    return training_image

train_mod=load_image_into_numpy_array(raw_train)

But I get an error:

AxisError: axis 2 is out of bounds for array of dimension 1
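For reference, this AxisError can be reproduced with NumPy alone. Passing the whole dataset object to `np.expand_dims` makes NumPy wrap it in a 0-d object array, so after expanding there is only one axis and `np.repeat(..., 2)` has no axis 2. In this sketch `StandInDataset` is a hypothetical placeholder for the tf.data object, not a TensorFlow class; the last lines show the same steps working on a single 2-D image array, which is what the function was written for.

```python
import numpy as np


class StandInDataset:
    """Hypothetical placeholder for a tf.data dataset (not array-like)."""


# NumPy wraps a non-array object in a 0-d object array; expand_dims then
# yields a 1-D array of length 1, so there is no axis 2 to repeat along.
expanded = np.expand_dims(StandInDataset(), -1)
print(expanded.shape, expanded.dtype)

axis_error = None
try:
    np.repeat(expanded, 3, 2)
except ValueError as err:  # numpy's AxisError subclasses ValueError
    axis_error = err
    print(type(err).__name__, err)

# On an actual 2-D grayscale image the same steps give shape (H, W, 3).
gray = np.zeros((20, 16), dtype=np.uint8)
rgb = np.repeat(np.expand_dims(gray, -1), 3, axis=2)
print(rgb.shape)
```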

How can I get this into input_shape=(96, 96, 3)?

today
Stat Tistician

3 Answers


Just from looking at it, shouldn't

def format_example(image, label):
  image = tf.cast(image, tf.float32)
  image = image*1/255.0
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  image = tf.image.grayscale_to_rgb(image)
  return image, label

be the easiest solution?
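A minimal, self-contained sketch of this suggestion, assuming TensorFlow 2.x in eager mode. The dataset here is a synthetic stand-in for `raw_train` (zeros shaped like the 20x16x1 images, with made-up labels), not the real data; it only shows that the mapped and batched elements come out as `(None, 96, 96, 3)`.

```python
import tensorflow as tf

IMG_SIZE = 96


def format_example(image, label):
    # Scale to [0, 1] as in the question (assumes uint8 input in 0..255).
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))  # (96, 96, 1)
    image = tf.image.grayscale_to_rgb(image)              # (96, 96, 3)
    return image, label


# Synthetic stand-in for raw_train: four 20x16 single-channel images.
images = tf.zeros((4, 20, 16, 1), dtype=tf.uint8)
labels = tf.constant([0, 1, 2, 3], dtype=tf.int64)
raw_train = tf.data.Dataset.from_tensor_slices((images, labels))

train_batches = raw_train.map(format_example).batch(2)
print(train_batches.element_spec)
```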

BertHobe

I would like to help. You can try the grayscale_to_rgb API; you will have to use the map API to transform your dataset. I tried your code on Google Colab, and the following code works.

In your code, you can make the following changes.

def load_image_into_numpy_array(image):
    # Despite its name, this now works on a tf.Tensor, not a NumPy array.
    # The function supports only single-channel (grayscale) images.
    # This way you can change the image to RGB: (H, W, 1) -> (H, W, 3).
    training_image = tf.image.grayscale_to_rgb(image)
    assert len(training_image.shape) == 3
    assert training_image.shape[-1] == 3
    return training_image

Then you will have to call the map function on your dataset like below.

raw_train = raw_train.map(lambda x, y: (load_image_into_numpy_array(x), y))

Then you can check whether the images now have three (RGB) channels:

for image, label in raw_train.take(1):
  print(image.shape)

I hope my answer serves you well.
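Note that the function passed to `map` has to return the label as well as the image: a lambda that returns only the converted image silently drops the labels that `model.fit` needs. A small sketch on a synthetic stand-in dataset (zeros and made-up labels, not the real data; `to_rgb` is a hypothetical helper name) makes the difference visible in `element_spec`:

```python
import tensorflow as tf


def to_rgb(image):
    # grayscale_to_rgb copies the single channel three times: (H, W, 1) -> (H, W, 3)
    return tf.image.grayscale_to_rgb(image)


images = tf.zeros((4, 20, 16, 1), dtype=tf.uint8)
labels = tf.constant([0, 1, 2, 3], dtype=tf.int64)
raw_train = tf.data.Dataset.from_tensor_slices((images, labels))

# Returning only the image drops the label: elements are single tensors.
images_only = raw_train.map(lambda x, y: to_rgb(x))

# Returning a tuple keeps the (image, label) pairs used for training.
pairs = raw_train.map(lambda x, y: (to_rgb(x), y))

print(images_only.element_spec)
print(pairs.element_spec)

for image, label in pairs.take(1):
    print(image.shape, label.numpy())
```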

pratsbhatt

First of all, note that MobileNetV2 has been trained on 8-bit color images (commonly known as RGB), whereas the dataset you are trying to use for fine-tuning consists of 1-bit monochrome (binary) images. Because of this difference in data format, fine-tuning may not necessarily produce good results.

With that said, to achieve the conversion by repeating the channel dimension, you can simply use the tf.repeat function inside the mapping function (i.e. format_example):

image = tf.repeat(image, repeats=[3], axis=-1)
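Repeating the single channel along the last axis and `tf.image.grayscale_to_rgb` (from the other answers) produce the same tensor. A quick eager-mode check on a random single-channel tensor, assuming TensorFlow 2.x; note that in eager mode shapes are always concrete, so this does not reproduce the `None` static shape discussed in the comments, which arises while `Dataset.map` traces the function:

```python
import tensorflow as tf

# A random single-channel "image" of shape (96, 96, 1).
gray = tf.random.uniform((96, 96, 1))

via_repeat = tf.repeat(gray, repeats=3, axis=-1)  # copy the channel 3 times
via_helper = tf.image.grayscale_to_rgb(gray)      # same result via the image API

print(via_repeat.shape, via_helper.shape)
print(bool(tf.reduce_all(via_repeat == via_helper)))
```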
today
  • I put the code line into format_example and ran it. The resulting shape is (96, 96, None), but I need (96, 96, 3). – Stat Tistician Jul 22 '20 at 12:57
  • Or am I wrong? According to the TensorFlow Hub common signatures for images specification, the last dimension is fixed to 3 (the number of RGB channels), and the shape must be [..., 3]. – Stat Tistician Jul 22 '20 at 13:03
  • @StatTistician It's `None` because the shape has not been inferred automatically, otherwise it works and produces images of shape `(96, 96, 3)` (you can verify this by iterating over the `train` dataset and printing the shape of its first element which is the image). If you really want the shape to be determined explicitly, you can use `image.set_shape((IMG_SIZE, IMG_SIZE, 3))` just before the `return` statement in `format_example`. – today Jul 22 '20 at 14:41
  • @StatTistician Alternatively, you can use `tf.image.grayscale_to_rgb` as suggested in the other answer. However, unlike in that answer, you can use it in `format_example` same as `tf.repeat`. This, i.e. using it in `format_example`, would be more efficient since you are working directly on Tensor objects (instead of Numpy arrays). – today Jul 22 '20 at 14:45
  • "otherwise it works and produces images of shape" The problem is that it does not work: when I continue with my code and try to fit my model, I get an error message. – Stat Tistician Jul 22 '20 at 14:58
  • @StatTistician Even if you set the shape as I suggested? Can you edit your question and add the code where you load the model and start finetuning it? – today Jul 22 '20 at 14:59
  • @StatTistician I just tried my answer (without any further modification) on MobileNetV2 and it starts training without any problems. – today Jul 22 '20 at 15:15
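Putting the comments' suggestion together: a sketch of `format_example` that uses `tf.repeat` and then pins the static shape with `set_shape`, so the dataset's `element_spec` reports `(96, 96, 3)` instead of a possibly unknown channel dimension. It assumes TensorFlow 2.x and uses a synthetic stand-in dataset (zeros and made-up labels), not the real data. `set_shape` only merges shape information; it raises an error if the declared shape is incompatible with what is already known.

```python
import tensorflow as tf

IMG_SIZE = 96


def format_example(image, label):
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.repeat(image, repeats=3, axis=-1)
    # Inside Dataset.map the static channel dimension may not be inferred
    # (the comments report (96, 96, None)); declare it explicitly.
    image.set_shape((IMG_SIZE, IMG_SIZE, 3))
    return image, label


images = tf.zeros((4, 20, 16, 1), dtype=tf.uint8)
labels = tf.constant([0, 1, 2, 3], dtype=tf.int64)
train = tf.data.Dataset.from_tensor_slices((images, labels)).map(format_example)

print(train.element_spec)
for image, label in train.take(1):
    print(image.shape)
```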