The relevant values are:
rnn_size: 512
batch_size: 128
rnn_inputs: Tensor("embedding_lookup/Identity_1:0", shape=(?, ?, 128), dtype=float32)
sequence_length: Tensor("inputs_length:0", shape=(?,), dtype=int32)
cell_fw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb6d0>
cell_bw: <tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.DropoutWrapper object at 0x7f4f534eb910>
I get the enc_state value with:
enc_output, enc_state = tf.compat.v1.nn.bidirectional_dynamic_rnn(
    cell_fw,
    cell_bw,
    rnn_inputs,
    sequence_length,
    dtype=tf.float32)
The resulting enc_state value is:
enc_state: LSTMStateTuple(c=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_3:0' shape=(?, 512) dtype=float32>, h=<tf.Tensor 'RNN_Encoder_Cell_2D/encoder_1/bidirectional_rnn/fw/fw/while/Exit_4:0' shape=(?, 512) dtype=float32>)
The original TF1 code:
initial_state = tf.contrib.seq2seq.DynamicAttentionWrapperState(
    enc_state,
    _zero_state_tensors(rnn_size, batch_size, tf.float32))
I tried converting it to TF2 with:
initial_state = tfa.seq2seq.AttentionWrapper(enc_state,_zero_state_tensors(rnn_size, batch_size, tf.float32))
This raises the following error:
TypeError Traceback (most recent call last)
<ipython-input-54-d87646b9df5d> in <module>()
8 threshold)
9 model = build_graph(keep_probability, rnn_size, num_layers, batch_size,
---> 10 learning_rate, embedding_size, direction)
11 train(model, epochs, log_string)
6 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_type(argname, value, expected_type, memo)
596 raise TypeError(
597 'type of {} must be {}; got {} instead'.
--> 598 format(argname, qualified_name(expected_type), qualified_name(value)))
599 elif isinstance(expected_type, TypeVar):
600 # Only happens on < 3.6
TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead
Also, can you explain the last line of the error, i.e.:
TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead
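One possible reading of that last line, sketched in plain Python. This is a minimal illustration, not the typeguard or TensorFlow Addons source: `Layer`, `LSTMStateTuple`, `check_type` and `AttentionWrapper` below are hypothetical stand-ins defined only for this sketch. It assumes that `tfa.seq2seq.AttentionWrapper` is a wrapper class whose first constructor argument, `cell`, is type-checked as a Keras layer, so passing `enc_state` (an `LSTMStateTuple`) in that position fails the runtime check.

```python
from collections import namedtuple


class Layer:
    """Hypothetical stand-in for tf.keras.layers.Layer."""


# Hypothetical stand-in for LSTMStateTuple: a named tuple holding c and h.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def check_type(argname, value, expected_type):
    """Simplified runtime check, in the spirit of typeguard's check_type."""
    if not isinstance(value, expected_type):
        raise TypeError(
            'type of argument "{}" must be {}; got {} instead'.format(
                argname, expected_type.__name__, type(value).__name__))


class AttentionWrapper:
    """Hypothetical stand-in: the first argument must be a Layer (a cell)."""

    def __init__(self, cell, attention_mechanism):
        check_type("cell", cell, Layer)
        self.cell = cell
        self.attention_mechanism = attention_mechanism


# enc_state is a state (a pair of tensors in the real code), not a layer.
enc_state = LSTMStateTuple(c=[0.0], h=[0.0])

try:
    # Mirrors the failing call: a state is passed where a cell is expected.
    AttentionWrapper(enc_state, None)
    message = None
except TypeError as err:
    message = str(err)

print(message)
```

Read this way, the message says that the value supplied for the parameter named `cell` was an `LSTMStateTuple`, while a Keras `Layer` (an RNN cell) was expected. In other words, the old call built a state object, whereas the new call invokes a wrapper-class constructor.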