
I have recently been working with TensorFlow and am coding a seq2seq model. I am writing a condition to select which of the helpers provided by the API to use.

When I use the following code, I get an error.

# Training Helper
helper1 = tf.contrib.seq2seq.TrainingHelper(inputs = decoder_embedded_input,sequence_length = dec_seqLen,time_major=True)


helper2 = tf.contrib.seq2seq.GreedyEmbeddingHelper(output_embedding,
                                                   tf.fill([batchSize], outT2N['<GO>']),
                                                   outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1
helper = tf.cond(helperDecider,helper1,helper2)

I get an error saying the helper must be callable, so I changed the code to

def helper1():
    return tf.contrib.seq2seq.TrainingHelper(inputs = dec_embedded_input,sequence_length = dec_seqLen,time_major=True)

def helper2():
    return tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddingMatrixOut,
                                                   tf.fill([batchSize], outT2N['<GO>']),
                                                   outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1

helper = tf.cond(helperDecider,helper1,helper2)

decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell,helper,encoder_final_state ,output_layer = projection_layer)

Now it throws this error:

Expected binary or unicode string, got <tensorflow.contrib.seq2seq.python.ops.helper.TrainingHelper object at 0x7fc32b96b908>

So, finally I chose the good old if-else, and it works as it should. I just need to know whether it is valid to use the following code.

#Training Helper
helper1 = tf.contrib.seq2seq.TrainingHelper(inputs = dec_embedded_input,sequence_length = dec_seqLen,time_major=True)


helper2 = tf.contrib.seq2seq.GreedyEmbeddingHelper(embeddingMatrixOut,
                                                   tf.fill([batchSize], outT2N['<GO>']),
                                                   outT2N['<EOS>'])
helperDecider = tf.placeholder(tf.bool)
# when 0 : helper2
# when 1 : helper1

if someCondition:
    helper = helper1
else:
    helper = helper2


decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell,helper,encoder_final_state ,output_layer = projection_layer)

A possible problem is that I cannot switch to the other helper at run time, since the choice is hard-coded. Can someone suggest an alternative approach?

lifeisshubh

1 Answer


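A plain Python if does not give you a run-time switch in TensorFlow: it is evaluated once, while the graph is being built, so only the branch chosen at that moment ends up in the graph (if the condition is true, helper will always be helper1), and feeding helperDecider later cannot change it.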
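
The difference between a build-time Python if and a run-time conditional can be sketched with a toy model in plain Python. This is not TensorFlow: the Placeholder, Const and Cond classes below are hypothetical stand-ins invented for illustration, and the string values only stand in for tensors.

```python
# Toy model of "build the graph first, run it later".
# NOT TensorFlow: Placeholder, Const and Cond are hypothetical stand-ins.

class Placeholder:
    """A value that is only known at run time, via the feed dict."""
    def __init__(self, name):
        self.name = name

    def run(self, feed):
        return feed[self.name]


class Const:
    """A node that always produces the same value."""
    def __init__(self, value):
        self.value = value

    def run(self, feed):
        return self.value


class Cond:
    """Deferred branch: the predicate is looked up on every run."""
    def __init__(self, pred, true_node, false_node):
        self.pred = pred
        self.true_node = true_node
        self.false_node = false_node

    def run(self, feed):
        chosen = self.true_node if self.pred.run(feed) else self.false_node
        return chosen.run(feed)


training_out = Const("training branch")
greedy_out = Const("greedy branch")
decider = Placeholder("helperDecider")

# Build-time choice: the Python `if` runs once, so only one branch is kept.
some_condition = True
fixed = training_out if some_condition else greedy_out

# Run-time choice: both branches are in the graph, the feed picks one.
switchable = Cond(decider, training_out, greedy_out)

print(fixed.run({"helperDecider": False}))       # training branch
print(switchable.run({"helperDecider": False}))  # greedy branch
print(switchable.run({"helperDecider": True}))   # training branch
```

In this sketch the value fed for helperDecider has no effect on fixed, because the decision was made before any feed existed, while switchable re-evaluates the predicate on every run.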

For this you can use tf.cond to define a conditional graph, as you tried in the first snippet. The error you got there comes from incorrect use of tf.cond; the call should look like this:

helper = tf.cond(helperDecider,lambda: helper1, lambda: helper2)

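The second and third arguments of tf.cond must be functions (callables), and in your first snippet they were helper objects instead. Note that tf.cond also expects those functions to return tensors, so returning a helper object will most likely still produce the "Expected binary or unicode string" error from your second snippet. Hope this works.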