I am trying to use tf.data pipeline to get finer control over loading image data but I receive the following error which I think is because of usage of list comprehension. My Code looks like this:
def load_files(data_dir: str, val_split=0.2):
assert len(os.listdir(os.path.join(data_dir + 'images/'))) == \
len(os.listdir(os.path.join(data_dir, 'ground_truth/'))), print("No. of image files != No. of gt files")
image_count = len(os.listdir(os.path.join(data_dir + 'images/')))
files = os.listdir(os.path.join(data_dir + 'images/'))
image_files = [os.path.join(data_dir + 'images/', file) for file in files]
image_files = np.array(image_files)
ds = tf.data.Dataset.from_tensor_slices(files)
ds = ds.map(process_data)
# train_ds = ds.skip(int(val_split * image_count))
# val_ds = ds.take(int(val_split * image_count))
return ds
def process_data(file_path):
image, = tf.io.read_file(file_path)
image = tf.io.decode_jpeg(image, channels=3)
label = tf.strings.split(file_path)
label = tf.io.decode_png(label, channels=0, dtype=tf.uint8)
return image, label
some_dir = "../../../TuSimple_lane_detection/"
img_dir = some_dir + "images/"
mask_dir = some_dir + "ground_truth/"
data_train = load_files(some_dir)
for f in data_train.take(5):
print(f.numpy())
The error looks like this:
File "E:\Datasets\KITTI_3D_Object_detection\venv\PycharmProjects\lib\site-packages\tensorflow\python\eager\function.py", line 3210, in _get_concrete_function_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "E:\Datasets\KITTI_3D_Object_detection\venv\PycharmProjects\lib\site-packages\tensorflow\python\eager\function.py", line 3557, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "E:\Datasets\KITTI_3D_Object_detection\venv\PycharmProjects\lib\site-packages\tensorflow\python\eager\function.py", line 3392, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "E:\Datasets\KITTI_3D_Object_detection\venv\PycharmProjects\lib\site-packages\tensorflow\python\framework\func_graph.py", line 1143, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "E:\Datasets\KITTI_3D_Object_detection\venv\PycharmProjects\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 4510, in wrapped_fn
ret = wrapper_helper(*args)
File "E:\Datasets\KITTI_3D_Object_detection\venv\PycharmProjects\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 4440, in wrapper_helper
ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
File "E:\Datasets\KITTI_3D_Object_detection\venv\PycharmProjects\lib\site-packages\tensorflow\python\autograph\impl\api.py", line 699, in wrapper
raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in user code:
File "E:/Datasets/KITTI_3D_Object_detection/KITTI_2D/EndToEndLaneDetection/Dataloader.py", line 21, in process_data *
image, = tf.io.read_file(file_path)
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
I did not have the problem when I was working with TF1.12 before, when I shifted o TF 2.3 I am encountering this error.