
I am trying to learn image autoencoding, but I can't figure out how to train the model with separate input and output images.

For example, the input images are in ".../Pictures/Input" and the output images are in ".../Pictures/Output".

import tensorflow as tf

# get input images from data_dir_input
ds_input = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir_input,
    seed=123,
    image_size=(img_height, img_width),
    label_mode=None,
    batch_size=batch_size)

# get output images from data_dir_output
ds_output = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir_output,
    seed=123,
    image_size=(img_height, img_width),
    label_mode=None,
    batch_size=batch_size)

# --------- model init etc --------------
# ...


model.fit(x=ds_input, y=ds_output, batch_size=32, epochs=50)

But I get this error:

`y` argument is not supported when using dataset as input

How can I use my own input and output images when training the model?

1 Answer


You could build a tf.data.Dataset yourself for more flexibility. From what I read, image_dataset_from_directory doesn't support custom labels other than integers, so it can't give you target images directly.

Try this:

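import os
import tensorflow as tf
from glob2 import glob  # the standard library's `from glob import glob` also works here

os.chdir(r'c:/users/user/Pictures')

# sort both lists so that x_files[i] lines up with y_files[i]
# (assumes matching input/target files have the same names)
x_files = sorted(glob('inputs/*.jpg'))
y_files = sorted(glob('targets/*.jpg'))

# each element is a pair of file paths: (input path, target path)
files_ds = tf.data.Dataset.from_tensor_slices((x_files, y_files))

def process_img(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img, channels=3)  # use tf.image.decode_png for PNG files
    img = tf.image.convert_image_dtype(img, tf.float32)  # uint8 -> float32 in [0, 1]
    img = tf.image.resize(img, size=(28, 28))
    return img

# turn each pair of paths into a pair of images, then batch
files_ds = files_ds.map(lambda x, y: (process_img(x), process_img(y))).batch(1)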
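Relying on two globbed lists coming back in the same order is fragile. If each input file has a target with the same file name, you could pair them explicitly by name instead. This is a minimal sketch with a hypothetical helper `pair_files` (not part of the original answer); it keeps the full paths, since a bare file name such as `'000128_17.png'` can't be opened from another working directory.

```python
import os
from glob import glob


def pair_files(input_dir, target_dir, pattern="*.png"):
    """Hypothetical helper: return (x_files, y_files) where x_files[i] and
    y_files[i] have the same file name. Assumes every input image has a
    target image with the same name in target_dir."""
    # Map file name -> full path (the folder is kept in the path)
    inputs = {os.path.basename(p): p for p in glob(os.path.join(input_dir, pattern))}
    targets = {os.path.basename(p): p for p in glob(os.path.join(target_dir, pattern))}

    # Names present in only one of the two folders cannot be paired
    unpaired = sorted(set(inputs) ^ set(targets))
    if unpaired:
        raise ValueError(f"files without a partner: {unpaired[:5]}")

    names = sorted(inputs)
    return [inputs[n] for n in names], [targets[n] for n in names]
```

The two returned lists can then go into `tf.data.Dataset.from_tensor_slices((x_files, y_files))` exactly as before.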

original, target = next(iter(files_ds))
original  # displaying the first input batch gives:
<tf.Tensor: shape=(1, 28, 28, 3), dtype=float32, numpy=
array([[[[0.357423  , 0.3325731 , 0.20412168],
         [0.36274514, 0.21940777, 0.17623049],
         [0.34821934, 0.13921566, 0.06858743],
         ...,
         [0.25486213, 0.27446997, 0.2520612 ],
         [0.04925931, 0.26666668, 0.07619007],
         [0.48167226, 0.5287311 , 0.520888  ]]]
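The question leaves the model out ("model init etc"), so for completeness here is a hypothetical small convolutional autoencoder (not from the original answer) that accepts batches like the one above: shape (1, 28, 28, 3), values in [0, 1]. The layer sizes are arbitrary assumptions; what matters is that the model's output shape matches the target images.

```python
import tensorflow as tf


def build_autoencoder(height=28, width=28, channels=3):
    # Hypothetical architecture, sized for the 28x28 RGB images produced by process_img
    inputs = tf.keras.Input(shape=(height, width, channels))
    # Encoder: 28x28 -> 14x14 -> 7x7
    x = tf.keras.layers.Conv2D(16, 3, strides=2, padding="same", activation="relu")(inputs)
    x = tf.keras.layers.Conv2D(32, 3, strides=2, padding="same", activation="relu")(x)
    # Decoder: 7x7 -> 14x14 -> 28x28
    x = tf.keras.layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu")(x)
    x = tf.keras.layers.Conv2DTranspose(16, 3, strides=2, padding="same", activation="relu")(x)
    # Sigmoid keeps the output in [0, 1], the same range as the target images
    outputs = tf.keras.layers.Conv2D(channels, 3, padding="same", activation="sigmoid")(x)
    return tf.keras.Model(inputs, outputs)


model = build_autoencoder()
model.compile(optimizer="adam", loss="mse")

# A dummy batch with the same shape as `original`
dummy = tf.random.uniform((1, 28, 28, 3))
reconstruction = model(dummy)
print(reconstruction.shape)  # (1, 28, 28, 3)
```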

Because each element of files_ds is already an (input, target) pair, you don't need to pass y to the fit() call. You can use it like this:

model.fit(files_ds, epochs=5)
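Since image_dataset_from_directory also returns a tf.data.Dataset, another option (a sketch, not part of the original answer) is to keep the two datasets from the question and zip them into (input, target) pairs. This assumes `shuffle=False` on both calls: the function shuffles by default, which could pair the wrong images. With shuffling off, files are read in alphanumeric order, so it also assumes matching input/output files sort identically (e.g. same names). The demo writes a few dummy PNGs to a temporary folder so it runs on its own; `tf.keras.utils.image_dataset_from_directory` is the location of the function in newer TensorFlow versions.

```python
import os
import tempfile

import tensorflow as tf

img_height, img_width, batch_size = 28, 28, 2

# --- Demo setup only: dummy input/output PNGs with matching names ---
root = tempfile.mkdtemp()
data_dir_input = os.path.join(root, "Input")
data_dir_output = os.path.join(root, "Output")
for d in (data_dir_input, data_dir_output):
    os.makedirs(d)
for i in range(4):
    name = f"{i:03d}.png"
    img_in = tf.cast(tf.fill((32, 32, 3), i * 10), tf.uint8)
    # the "output" image is 100 brighter than its input, so pairing is easy to check
    img_out = tf.cast(tf.fill((32, 32, 3), i * 10 + 100), tf.uint8)
    tf.io.write_file(os.path.join(data_dir_input, name), tf.io.encode_png(img_in))
    tf.io.write_file(os.path.join(data_dir_output, name), tf.io.encode_png(img_out))


def load_images(directory):
    # labels=None / label_mode=None yields images only;
    # shuffle=False keeps alphanumeric file order so both datasets line up
    return tf.keras.utils.image_dataset_from_directory(
        directory,
        labels=None,
        label_mode=None,
        shuffle=False,
        image_size=(img_height, img_width),
        batch_size=batch_size)


ds_input = load_images(data_dir_input)
ds_output = load_images(data_dir_output)

# Each element is now an (input_batch, target_batch) tuple
ds = tf.data.Dataset.zip((ds_input, ds_output))
# model.fit(ds, epochs=50)  # no separate y needed
```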
Nicolas Gervais
  • Thanks for the response, but my input and output images are in different directories, so I can't pass the same images for both – Emre Özincegedik Aug 18 '20 at 00:22
  • OK, I edited my answer with this constraint in mind. Then you just need to get a list of matching files for x and y. Just replace `os.listdir()` with a list of files that makes sense for your task. Make sure it's ordered correctly with the corresponding y. – Nicolas Gervais Aug 18 '20 at 06:57
  • When I use the new code block I get this error `OP_REQUIRES failed at whole_file_read_ops.cc:116 : Not found: NewRandomAccessFile failed to Create/Open: 000128_17.png : The system cannot find the file specified` (by the way, I changed decode_jpeg to decode_png) – Emre Özincegedik Aug 18 '20 at 12:08
  • That's probably because `'000128_17.png'` doesn't have the folder name. I made you an example with `glob2.glob` instead which takes care of that – Nicolas Gervais Aug 18 '20 at 12:25
  • This worked, thanks. Can I augment this dataset, e.g. rescaling, rotating, etc.? `image_dataset_from_directory` had some options for augmentation, but I'm curious about this method as well – Emre Özincegedik Aug 20 '20 at 05:16
  • `tf.image` has a bunch of random transformations, for instance https://www.tensorflow.org/api_docs/python/tf/image/random_contrast. BTW, `image_dataset_from_directory` returns a `tf.data.Dataset`, so anything you can do with the former can also be done with your own `tf.data.Dataset` – Nicolas Gervais Aug 20 '20 at 05:29
  • I see, thanks for the detailed explanation. Due to my stupidity I couldn't handle this either, lol. I asked another question for that so this question stays in its own scope. [link](https://stackoverflow.com/questions/63581973/input-and-output-image-dataset-augmentation) – Emre Özincegedik Aug 25 '20 at 15:21
  • How can I use non-TF functions in `process_img`? – hans Mar 12 '21 at 11:25
  • @NicolasGervais In `model.fit(ds, epochs=5)`, `ds` is `files_ds`, right? – justRandomLearner May 31 '23 at 16:38
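The augmentation and non-TF-function questions in the comments can both be handled inside a `map` on the paired dataset. This is a minimal sketch under stated assumptions: `augment_pair` and `add_noise_np` are hypothetical names, a left/right flip stands in for whatever augmentation you need, and the NumPy noise step (applied to the input only, as in a denoising autoencoder) stands in for any non-TF code. The key idea is to draw one seed per pair and pass it to the `tf.image.stateless_random_*` op for both images, so the input and its target receive the same random transformation. `tf.numpy_function` runs plain NumPy code inside the pipeline but loses the static shape, hence the `set_shape` call.

```python
import numpy as np
import tensorflow as tf


def add_noise_np(img):
    # Hypothetical NumPy-only step, standing in for any non-TF function:
    # add Gaussian noise (std 0.1) and clip back to [0, 1]
    noisy = img + np.random.normal(0.0, 0.1, size=img.shape)
    return np.clip(noisy, 0.0, 1.0).astype(np.float32)


def augment_pair(x, y):
    # One seed per pair, reused for both images, so the input and its
    # target always receive the same random flip
    seed = tf.random.uniform([2], maxval=2**31 - 1, dtype=tf.int32)
    x = tf.image.stateless_random_flip_left_right(x, seed=seed)
    y = tf.image.stateless_random_flip_left_right(y, seed=seed)
    # Run the NumPy function inside the tf.data pipeline (input only)
    x_noisy = tf.numpy_function(add_noise_np, [x], tf.float32)
    x_noisy.set_shape(x.shape)  # numpy_function drops the static shape
    return x_noisy, y


# Demo pairs standing in for files_ds *before* .batch():
# left half black, right half white, so a flip is easy to detect
img = tf.concat([tf.zeros((28, 14, 3)), tf.ones((28, 14, 3))], axis=1)
pairs = tf.data.Dataset.from_tensors((img, img)).repeat(8)

aug_ds = pairs.map(augment_pair).batch(4)
```

With the answer's pipeline, the augmentation would go before batching, e.g. `files_ds.map(lambda x, y: (process_img(x), process_img(y))).map(augment_pair).batch(1)`.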