0

I'm filtering a dataset according to certain labels. Calling the filtering function itself works fine, but when I then call next(iter(dataset)), for certain label values it keeps processing for more than 12 hours, while for other values it returns the result right away.

My filtering code is:

 def balanced_dataset(dataset, labels_list, sample_size=1000):
     """Build a dataset with at most ``sample_size`` examples per label.

     For every label in ``labels_list`` the input is filtered down to the
     elements whose ``y`` component contains that label, the first
     ``sample_size`` matches are kept, and the per-label datasets are then
     interleaved into a single dataset.

     Args:
         dataset: a ``tf.data.Dataset`` whose elements unpack as ``(x, y)``.
             ``y`` is compared against an int64 constant, so it is assumed
             to be an int64 tensor of labels -- TODO confirm with caller.
         labels_list: non-empty sized collection of integer labels.
         sample_size: maximum number of elements kept per label.

     Returns:
         A ``tf.data.Dataset`` interleaving the per-label datasets. The
         element order is not deterministic (``deterministic=False``).

     Raises:
         ValueError: if ``labels_list`` is empty.

     NOTE(review): ``filter(...).take(sample_size)`` only stops reading the
     input once ``sample_size`` matching elements have been produced. For a
     label with fewer matches the whole input has to be scanned (and the
     scan never ends if the input repeats indefinitely), which is a likely
     cause of very slow iteration for rare labels.
     """
     if len(labels_list) == 0:
         raise ValueError('labels_list must contain at least one label')

     def _contains_label(label):
         # Factory so that each predicate is bound to its own label value
         # instead of relying on the loop variable at call time.
         def predicate(x, y):
             target = tf.constant(label, dtype=tf.int64)
             # True when at least one entry of y equals the label.
             return tf.reduce_any(tf.equal(target, y))
         return predicate

     datasets_list = []
     for label in labels_list:
         print(f'Preparando o dataset {label}')
         matching = dataset.filter(_contains_label(label))
         datasets_list.append(matching.take(sample_size))

     # Wrap the per-label datasets in a dataset of datasets ...
     ds = tf.data.Dataset.from_tensor_slices(datasets_list)
     # ... then pull the elements out of all of them into one dataset.
     concat_ds = ds.interleave(
         lambda x: x,
         cycle_length=len(labels_list),
         num_parallel_calls=tf.data.AUTOTUNE,
         deterministic=False,
     )
     return concat_ds
Marlon Teixeira
  • 334
  • 1
  • 14

0 Answers0