I applied a queued reader to the TensorFlow seq2seq model to avoid reading the whole dataset into memory and preprocessing it all in advance. I also didn't split the dataset into separate bucket files beforehand (which would guarantee a single bucket-size per batch), because that would take a lot of time as well. As a consequence, each batch coming from the queued reader may contain sequences of different bucket-sizes, which breaks the original seq2seq model (it assumes all data in one batch has the same bucket-size, and picks a single sub-graph to execute based on that bucket-size).
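For reference, this is roughly the kind of queued input pipeline I mean; the file name and the line parsing are just placeholders:

```python
import tensorflow as tf

# Sketch of a queue-based reader (TF 1.x). "train.txt" and the parsing below
# are hypothetical; the point is that records are read and batched on the fly
# instead of loading and bucketing the whole dataset up front.
filename_queue = tf.train.string_input_producer(["train.txt"])
reader = tf.TextLineReader()
_, line = reader.read(filename_queue)

# Made-up parsing: split the line into a variable-length vector of token ids.
tokens = tf.string_to_number(tf.string_split([line]).values, out_type=tf.int32)

# dynamic_pad pads each batch to the longest sequence in that batch, so one
# batch can mix sequences that would fall into different buckets.
batch = tf.train.batch([tokens], batch_size=32, dynamic_pad=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # ... run training steps that consume `batch` ...
    coord.request_stop()
    coord.join(threads)
```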
What I have tried:
In the original implementation, one sub-graph is constructed per bucket, and all of them share the same parameters. The only difference between them is the number of computation steps taken when unrolling the RNN.
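For concreteness, a toy version of that layout (not the real seq2seq code, just the parameter-sharing pattern, with made-up sizes) looks like this:

```python
import tensorflow as tf

buckets = [(5, 10), (10, 15), (20, 25)]  # hypothetical (encoder, decoder) lengths

def build_bucket_graph(bucket, batch=4, hidden=32):
    # Toy stand-in for the real RNN unrolling: the only structural difference
    # between buckets is how many steps get unrolled.
    w = tf.get_variable("w", [hidden, hidden])
    x = tf.zeros([batch, hidden])
    for _ in range(bucket[1]):
        x = tf.tanh(tf.matmul(x, w))
    return tf.reduce_mean(tf.square(x))

losses = []
with tf.variable_scope("seq2seq") as scope:
    for j, bucket in enumerate(buckets):
        if j > 0:
            scope.reuse_variables()      # every sub-graph shares the same weights
        losses.append(build_bucket_graph(bucket))
```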
I changed each sub-graph into a conditional one: when its switch is `True`, it computes the `bucket_loss` for that bucket and appends it to `loss_list`; when the switch is `False`, it does nothing and appends `tf.constant(0.0)` to `loss_list` instead. Finally, I use `total_loss = tf.reduce_sum(loss_list)` to collect all the losses and build the gradient graph on it. I also feed a `switches_list` into the model at every step. `switches_list` has the same length as the list of buckets; if the batch contains any data of the i-th bucket-size, the i-th switch in `switches_list` is `True`, otherwise `False`.
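Here is a simplified sketch of that construction (not my actual code: a toy loss stands in for the real per-bucket sequence loss, and all sizes are made up):

```python
import tensorflow as tf

buckets = [(5, 10), (10, 15), (20, 25)]   # hypothetical bucket sizes
batch, hidden = 4, 32

# One boolean switch per bucket, fed at every step.
switches_list = [tf.placeholder(tf.bool, name="switch_%d" % j)
                 for j in range(len(buckets))]

# Shared parameters are created outside tf.cond, so every branch reuses them.
w = tf.get_variable("w", [hidden, hidden])
inputs = tf.zeros([batch, hidden])

def bucket_loss(bucket):
    # Toy stand-in for the per-bucket sub-graph: only the unroll length differs.
    x = inputs
    for _ in range(bucket[1]):
        x = tf.tanh(tf.matmul(x, w))
    return tf.reduce_mean(tf.square(x))

loss_list = []
for j, bucket in enumerate(buckets):
    # Switch on: compute this bucket's loss; switch off: contribute 0.0.
    # (The default argument pins `bucket` to this iteration of the loop.)
    loss_list.append(tf.cond(switches_list[j],
                             lambda b=bucket: bucket_loss(b),
                             lambda: tf.constant(0.0)))

total_loss = tf.reduce_sum(loss_list)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(total_loss)
```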
The problems I encountered:

- When the backpropagation pass went through the `tf.cond(...)` node, I was warned by `gradients.py` that some sparse tensors were being converted to dense ones.
- When I tried to fetch `total_loss` or `bucket_loss`, I got:

      ValueError: Operation u'cond/model_with_one_buckets/sequence_loss/truediv' has been marked as not fetchable.
Would you please help me:
- How can I solve the two problems above?
- How should I modify the graph to meet my requirements?
- Any better ideas for training on data of different bucket-sizes within one batch?
- Any better ideas for applying an asynchronous queue reader to the seq2seq framework without bucketing the whole dataset first?