For a project, I am using `tf.data.Dataset` to write the input pipeline. The input is an RGB image, and the label is a list of 2D coordinates of the objects in the image, which are used to generate a heatmap.
Here is an MWE to reproduce the problem:
```python
import math

import numpy as np
import tensorflow as tf


def encode_images(image, label):
    """Load an image and pass its label through unchanged.

    Parameters
    ----------
    image : path to the image file (string tensor)
    label : object coordinates, a ragged tensor of [x, y] pairs

    Returns
    -------
    (image, label), with image as a (256, 256, 3) float32 tensor
    """
    # load image
    # here the normal code
    # img_contents = tf.io.read_file(image)
    # # decode the image
    # img = tf.image.decode_jpeg(img_contents, channels=3)
    # img = tf.image.resize(img, (256, 256))
    # img = tf.cast(img, tf.float32)
    # this is just for testing
    image = tf.random.uniform(
        (256, 256, 3), minval=0, maxval=255, dtype=tf.dtypes.float32, seed=None, name=None
    )
    return image, label


def generate_heatmap(image, label):
    """Build a Gaussian heatmap with one peak per object.

    Parameters
    ----------
    image : (H, W, 3) image tensor, used only for its shape
    label : [x, y] coordinates of the objects

    Returns
    -------
    (H, W) float32 density map
    """
    start = 0.5
    sigma = 3
    img_shape = (image.shape[0], image.shape[1])
    # NumPy array modified by index: this is what breaks inside Dataset.map
    density_map = np.zeros(img_shape, dtype=np.float32)
    # iterate over all [x, y] centers of this image
    for center_x, center_y in label:
        for v_y in range(img_shape[0]):
            for v_x in range(img_shape[1]):
                x = start + v_x
                y = start + v_y
                d2 = (x - center_x) * (x - center_x) + (y - center_y) * (y - center_y)
                exp = d2 / (2.0 * sigma**2)
                if exp > math.log(100):
                    continue
                density_map[v_y, v_x] = math.exp(-exp)
    return density_map


X = ["img1.png", "img2.png", "img3.png", "img4.png", "img5.png"]
y = [[[2, 3], [100, 120], [100, 120]],
     [[2, 3], [100, 120], [100, 120], [2, 1]],
     [[2, 3], [100, 120], [100, 120], [10, 10], [11, 12]],
     [[2, 3], [100, 120], [100, 120], [10, 10], [11, 12], [10, 2]],
     [[2, 3], [100, 120], [100, 120]]
     ]
dataset = tf.data.Dataset.from_tensor_slices((X, tf.ragged.constant(y)))
dataset = dataset.map(encode_images, num_parallel_calls=8)
dataset = dataset.map(generate_heatmap, num_parallel_calls=8)
dataset = dataset.batch(1, drop_remainder=False)
```
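For reference, the loop in `generate_heatmap` is meant to compute, for every pixel $(v_x, v_y)$ and every labelled center $(c_{x,k}, c_{y,k})$ with $\sigma = 3$, the Gaussian exponent below. Pixels where the exponent exceeds $\ln 100$ (a value under 0.01) are skipped, and a later center overwrites the value written by an earlier one:

$$
e_k(v_x, v_y) = \frac{(v_x + 0.5 - c_{x,k})^2 + (v_y + 0.5 - c_{y,k})^2}{2\sigma^2}
$$

$$
H(v_y, v_x) = \begin{cases}
\exp\bigl(-e_{k^*}(v_x, v_y)\bigr) & \text{if } e_k(v_x, v_y) \le \ln 100 \text{ for some } k, \text{ with } k^* \text{ the last such } k \\
0 & \text{otherwise}
\end{cases}
$$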
The problem is that in the `generate_heatmap()` function I use a NumPy array and modify its elements by index, which is not possible with TensorFlow tensors because they are immutable. I also iterate over the `label` tensor, which TensorFlow does not allow here either. On top of that, eager mode does not seem to be enabled inside `tf.data.Dataset.map` (the mapped function is traced into a graph, so it receives symbolic tensors rather than concrete values).
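The index-assignment and iteration problems above could, as one possible direction, be avoided by computing the whole heatmap with broadcasting instead of per-pixel loops. This is a sketch under assumptions, not a tested answer: `generate_heatmap_tf` is a hypothetical name, each label is assumed to be an (N, 2) tensor of (x, y) pixel coordinates (a `RaggedTensor` row is densified first), and overlapping objects are combined with a per-pixel maximum rather than letting the last center overwrite earlier ones as the loop does.

```python
import math

import tensorflow as tf


def generate_heatmap_tf(image, label, sigma=3.0, start=0.5):
    """Hypothetical graph-compatible version of generate_heatmap.

    image: tensor of shape (H, W, C).
    label: (N, 2) tensor or RaggedTensor of (x, y) object centers.
    Returns a float32 heatmap of shape (H, W).
    """
    if isinstance(label, tf.RaggedTensor):
        # Densify the ragged row of [x, y] pairs into an (N, 2) tensor.
        label = label.to_tensor()
    centers = tf.reshape(tf.cast(label, tf.float32), [-1, 2])

    height = tf.shape(image)[0]
    width = tf.shape(image)[1]
    # Pixel-center coordinates, like `start + v_x` / `start + v_y` in the loop.
    xs = tf.cast(tf.range(width), tf.float32) + start
    ys = tf.cast(tf.range(height), tf.float32) + start
    grid_y, grid_x = tf.meshgrid(ys, xs, indexing="ij")  # both (H, W)

    # Broadcast (1, H, W) pixel grids against (N, 1, 1) centers -> (N, H, W).
    cx = centers[:, 0][:, tf.newaxis, tf.newaxis]
    cy = centers[:, 1][:, tf.newaxis, tf.newaxis]
    d2 = (grid_x[tf.newaxis] - cx) ** 2 + (grid_y[tf.newaxis] - cy) ** 2
    exponent = d2 / (2.0 * sigma ** 2)
    # Same cut-off as the loop: values below exp(-log(100)) = 0.01 become 0.
    gauss = tf.where(exponent > math.log(100.0), tf.zeros_like(exponent), tf.exp(-exponent))

    # Per-pixel max over objects; the leading zero slice keeps N == 0 valid.
    zeros = tf.zeros([1, height, width], dtype=tf.float32)
    return tf.reduce_max(tf.concat([zeros, gauss], axis=0), axis=0)
```

In the pipeline it could be used as `dataset.map(lambda img, lbl: (img, generate_heatmap_tf(img, lbl)))`, which also keeps the image next to its heatmap. The cost is one (N, H, W) intermediate tensor per example, which should be small for a handful of objects on a 256×256 image.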
Any suggestions on how to deal with this? I think in PyTorch such code could be written quickly without this kind of suffering :)