
I am trying to apply a filter to a `tf.data.Dataset` that removes any string in which a single character makes up more than 50% of the string. Here is my dataset:

```python
import tensorflow as tf


strings = [
    ["ABCDEFGABCDEFG\tUseless\tLabel1"],
    ["AAAAAAAADEFGAB\tUseless\tLabel2"],
    ["HIJKLMNHIJKLMN\tUseless\tLabel3"],
    ["HIJKLMMMMMMMNH\tUseless\tLabel4"],
]
ds = tf.data.Dataset.from_tensor_slices(strings)

def _clean(x):
    x = tf.strings.split(x, "\t")
    return x[0], x[2]

def _filter(x):
    s = tf.strings.bytes_split(x)
    _, _, count = tf.unique_with_counts(s)
    percent = tf.reduce_max(count) / tf.shape(s)[0]
    return tf.less_equal(percent, 0.5)

ds = ds.map(_clean)
ds = ds.filter(lambda x, y: _filter(x))

for x, y in ds:
    tf.print(x, y)
```

This creates the following error:

```
TypeError: Failed to convert elements of tf.RaggedTensor(values=Tensor("StringsByteSplit/StringSplit:1", shape=(None,), dtype=string), row_splits=Tensor("StringsByteSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)) to Tensor. Consider casting elements to a supported type.
```

Is there any way to solve this problem inside a `tf.data.Dataset` graph?

Oliver

1 Answer


You can solve this using `tf.strings`:

```python
import tensorflow as tf

def filter_data(x):
  s = tf.strings.strip(tf.strings.regex_replace(x, '', ' '))
  s = tf.strings.split(s, sep=" ")
  _, _, count = tf.unique_with_counts(s)
  return tf.less_equal(tf.reduce_max(count) / tf.shape(s)[0], 0.25)

ds = tf.data.Dataset.from_tensor_slices([["AAAABBBCC", "Label1"], ["AAAAAABC", "Label2"], ["ABBAABCCCCAB", "Label3"], ["ABDC", "Label4"]])
ds = ds.map(lambda x: (x[0], x[1]))

ds = ds.filter(lambda x, y: filter_data(x))
for x, y in ds:
  tf.print(x, y)
```

Output:

```
"ABDC" "Label4"
```

However, I would reconsider the 25% threshold: every sample in your example dataset is above it, so all of them are filtered out. I added a fourth example, `ABDC`, to show that the method with `tf.less_equal` does keep strings that fall at or below the threshold.

Take `AAAABBBCC` as an example: `A` occurs most often (4 times), and its count is divided by the total length of the string (9), giving 4/9 ≈ 0.44, so the string is excluded from the dataset. Maybe this behavior is desired; I just wanted to point it out.
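To see where the 4/9 comes from, the intermediate tensors of `filter_data` can be inspected eagerly for `AAAABBBCC`. This is a minimal sketch, assuming TensorFlow 2.x with eager execution (the default); it also compares the regex trick with `tf.strings.bytes_split`, which should yield the same characters for a scalar input.

```python
import tensorflow as tf

x = tf.constant("AAAABBBCC")

# The answer's trick: an empty regex pattern matches between characters, so
# regex_replace inserts a space there; strip drops any leading/trailing spaces.
spaced = tf.strings.strip(tf.strings.regex_replace(x, "", " "))
chars = tf.strings.split(spaced, sep=" ")  # scalar input -> dense 1-D tensor

# For a scalar input, bytes_split returns the single bytes directly.
chars_direct = tf.strings.bytes_split(x)

# count[i] is how often the i-th distinct character occurs.
_, _, count = tf.unique_with_counts(chars)

# int32 / int32 with "/" performs true division and gives a float64 ratio.
ratio = tf.reduce_max(count) / tf.shape(chars)[0]
keep = tf.less_equal(ratio, 0.25)

print(chars.numpy())
print(ratio.numpy(), keep.numpy())
```

The most frequent character (`A`, 4 occurrences) over the string length (9) gives the ratio, which is above 0.25, so `keep` is false.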

AloneTogether
  • Thank you so much for your help!! One issue: the original strings are actually tab-separated, so earlier in the pipeline I use `tf.strings.split(x, "\t")` to split them into separate parts. This causes a problem for the filtering. I've edited the question so it serves as a better example. – Oliver Feb 01 '22 at 17:27
  • I think you can figure that out using the `tf.strings` tools ;) – AloneTogether Feb 01 '22 at 17:30
  • Also, `s = tf.strings.bytes_split(x)` is a great way to do the first two lines of the filter function :) – Oliver Feb 01 '22 at 17:31
  • Change your clean function to this and all will work: `def _clean(x): x = tf.squeeze(tf.strings.split(x, "\t"), axis=0) return x[0], x[2]` – AloneTogether Feb 01 '22 at 17:36
  • You are a wizard ... why does removing an axis solve the issue? – Oliver Feb 01 '22 at 17:43
  • A ragged tensor always has at least two dimensions
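As far as I can tell, this is why squeezing helps: each element of `ds` has shape `(1,)` because every line is wrapped in its own list, so `tf.strings.split` returns a `RaggedTensor` of shape `(1, None)`. Indexing that with `x[0]` selects the whole first row (all three tab-separated fields) rather than the first field, and `tf.strings.bytes_split` on that 1-D tensor returns another `RaggedTensor`, which `tf.unique_with_counts` cannot convert. Below is a sketch of the question's pipeline with both comment suggestions applied (`tf.squeeze` in `_clean`, `bytes_split` in `_filter`), assuming TensorFlow 2.x and the question's 50% threshold.

```python
import tensorflow as tf

strings = [
    ["ABCDEFGABCDEFG\tUseless\tLabel1"],
    ["AAAAAAAADEFGAB\tUseless\tLabel2"],
    ["HIJKLMNHIJKLMN\tUseless\tLabel3"],
    ["HIJKLMMMMMMMNH\tUseless\tLabel4"],
]
# Each element is a string tensor of shape (1,).
ds = tf.data.Dataset.from_tensor_slices(strings)

def _clean(x):
    # split() on a shape-(1,) input gives a RaggedTensor of shape (1, None);
    # squeezing axis 0 leaves a 1-D tensor holding the three fields.
    x = tf.squeeze(tf.strings.split(x, "\t"), axis=0)
    return x[0], x[2]  # scalar text and scalar label

def _filter(x):
    # x is a scalar string now, so bytes_split returns a dense 1-D tensor.
    s = tf.strings.bytes_split(x)
    _, _, count = tf.unique_with_counts(s)
    percent = tf.reduce_max(count) / tf.shape(s)[0]
    return tf.less_equal(percent, 0.5)

ds = ds.map(_clean)
ds = ds.filter(lambda x, y: _filter(x))

for x, y in ds:
    tf.print(x, y)
```

With this data, `Label2` is dropped (`A` makes up 9 of 14 characters), while `Label4` is kept because `M` is exactly 7/14 = 0.5 and `tf.less_equal` includes the boundary. An alternative that avoids the squeeze is to build the dataset from a flat list of strings, so each element is already a scalar.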