TensorFlow beginner here. This is my first project, and I am working with pre-defined estimators.
I have an extremely unbalanced dataset in which positive outcomes represent roughly 0.1% of the total data, and I suspect this imbalance considerably affects the performance of my model. As a first attempt to solve the issue, since I have tons of data, I would like to throw away most of my negatives in order to create a balanced dataset. I can see two ways of doing it: (1) preprocess the data to keep only a thousandth of the negatives and save the result in a new file before passing it to TensorFlow, for example with pyspark; or (2) ask TensorFlow to use only one out of every thousand negatives it finds.
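For illustration, the preprocessing option could look roughly like the sketch below. It uses plain Python rather than pyspark, and the two-column layout (feature, label), the "0" string for negatives, and the helper name `downsample_negatives` are all assumptions made up for the example, not taken from my real data.

```python
import csv
import io
import random


def downsample_negatives(rows, label_index, keep_prob, rng):
    """Yield every positive row, and each negative row with probability keep_prob."""
    for row in rows:
        is_negative = row[label_index] == "0"
        if is_negative and rng.random() >= keep_prob:
            continue  # drop this negative
        yield row


# Tiny in-memory stand-in for the real CSV file (hypothetical layout: feature,label).
# One row in every thousand is a positive, mimicking the 0.1% rate.
raw = io.StringIO(
    "".join("{},{}\n".format(i, 1 if i % 1000 == 0 else 0) for i in range(20000))
)
rng = random.Random(0)  # fixed seed so the run is repeatable
kept = list(downsample_negatives(csv.reader(raw), label_index=1, keep_prob=0.001, rng=rng))

positives = sum(1 for row in kept if row[1] == "1")
negatives = len(kept) - positives
print("kept {} positives and {} negatives".format(positives, negatives))
```

With a real file, the kept rows would be written out with `csv.writer` instead of being collected in a list.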
I tried to code the second idea but couldn't get it to work. I modified my input function to read:
    def train_input_fn(data_file="../data/train_input.csv", shuffle_size=100_000, batch_size=128):
        """Generate an input function for the Estimator."""
        # Extract lines from input files using the Dataset API.
        dataset = tf.data.TextLineDataset(data_file)
        dataset = dataset.map(parse_csv, num_parallel_calls=3)
        dataset = dataset.shuffle(shuffle_size).repeat().batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()

        # TRY TO IMPLEMENT THE SELECTION OF NEGATIVES
        thrown = 0
        flag = np.random.randint(1000)
        while labels == 0 and flag != 0:
            features, labels = iterator.get_next()
            thrown += 1
            flag = np.random.randint(1000)
        print("I've thrown away {} negative examples before going for label {}!".format(thrown, labels))
        return features, labels
This, of course, doesn't work: when this Python code runs, labels is a symbolic tensor rather than an actual value, so the labels == 0 condition is never satisfied. Also, there is only one print in stdout, which means the function is called only once (and that I still don't understand how TensorFlow really works). Anyway, is there a way to implement what I want?
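One direction that seems to fit is to express the selection as a per-example predicate and let the Dataset API apply it. The following is only a minimal sketch under assumptions: it needs a TensorFlow version where `tf.random.uniform` and `Dataset.filter` are available, the helper name `keep_example` is made up, and a toy in-memory dataset stands in for the parsed CSV.

```python
import tensorflow as tf


def keep_example(features, label, keep_prob=0.001):
    """Scalar bool: True for positives, True with probability keep_prob for negatives."""
    is_positive = tf.not_equal(label, 0)
    lucky = tf.random.uniform([]) < keep_prob  # fresh draw for every element
    return tf.logical_or(is_positive, lucky)


# Toy stand-in for the parsed CSV: 10 positives followed by 10,000 negatives.
features = tf.range(10010)
labels = tf.concat([tf.ones([10], tf.int32), tf.zeros([10000], tf.int32)], axis=0)

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# keep_prob is raised to 0.01 here only so the toy dataset keeps a visible number of negatives.
dataset = dataset.filter(lambda f, l: keep_example(f, l, keep_prob=0.01))
```

In an input function like mine, the filter would have to sit between the map(parse_csv, ...) step and batch(...), because the predicate must return a single boolean per example rather than one per batch.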
PS: I suspect that the loop in my train_input_fn, even if it worked as intended, would return fewer than a thousandth of the initial negatives, because the count restarts every time it finds a positive. This is a minor issue, and for now I could simply look for a magic number for the flag that gives me the expected result, without worrying too much about the mathematical beauty of it.
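This suspicion can be checked without TensorFlow by running the same loop logic on plain Python values. The sketch below is a simulation under assumptions: the label stream is synthetic (one positive per thousand), and the names `next_kept` and `simulate` are invented for the example. Since every label examined gets its own fresh flag, the simulation should show whether the keep rate really drifts below one in a thousand.

```python
import random


def next_kept(stream, rng, one_in=1000):
    """Mimic the while loop: skip negatives until the flag hits 0 or a positive shows up."""
    label = next(stream)
    flag = rng.randrange(one_in)  # same range as np.random.randint(one_in)
    while label == 0 and flag != 0:
        label = next(stream)
        flag = rng.randrange(one_in)
    return label


def simulate(labels, one_in=1000, seed=0):
    """Run next_kept until the stream is exhausted and return the labels it kept."""
    rng = random.Random(seed)
    stream = iter(labels)
    kept = []
    while True:
        try:
            kept.append(next_kept(stream, rng, one_in))
        except StopIteration:  # stream ran out while looking for the next example
            break
    return kept


# Synthetic stream: 200 positives and 199,800 negatives.
labels = [1 if i % 1000 == 0 else 0 for i in range(200000)]
kept = simulate(labels)

kept_pos = sum(kept)
kept_neg = len(kept) - kept_pos
print("kept {} positives and {} of {} negatives".format(kept_pos, kept_neg, labels.count(0)))
```

If the number of kept negatives comes out close to the number of positives, the flag value of 1000 needs no tuning for this purpose.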