I am trying to generate comments for subreddit posts from the post's image, title, and source subreddit.
If you don't know what a subreddit is, just think of it as the category of a post, e.g. cats, dogs, cars.
I am using a CNN for the images, a simple neural network for the subreddits, an LSTM for the titles, and finally an LSTM for the comment generation.
I had it working, but the inference part was hard to code, so I refactored it so that each component of the final model is a small model of its own. However, when I try to put everything together, I get the error shown below.
What is causing the problem and how can I fix it?
Error
```
AssertionError Traceback (most recent call last)
/tmp/ipykernel_42/2610877267.py in <module>
5 inputs = encoder.inputs
6
----> 7 result = decoder(decoder_target, encoder(inputs))
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
950 if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
951 return self._functional_construction_call(inputs, args, kwargs,
--> 952 input_list)
953
954 # Maintains info about the `Layer.call` stack.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
1089 # Check input assumptions set after layer building, e.g. input shape.
1090 outputs = self._keras_tensor_symbolic_call(
-> 1091 inputs, input_masks, args, kwargs)
1092
1093 if outputs is None:
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
820 return nest.map_structure(keras_tensor.KerasTensor, output_signature)
821 else:
--> 822 return self._infer_output_signature(inputs, args, kwargs, input_masks)
823
824 def _infer_output_signature(self, inputs, args, kwargs, input_masks):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
861 # TODO(kaftan): do we maybe_build here, or have we already done it?
862 self._maybe_build(inputs)
--> 863 outputs = call_fn(inputs, *args, **kwargs)
864
865 self._handle_activity_regularization(inputs, outputs)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in call(self, inputs, training, mask)
423 """
424 return self._run_internal_graph(
--> 425 inputs, training=training, mask=mask)
426
427 def compute_output_shape(self, input_shape):
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in _run_internal_graph(self, inputs, training, mask)
567 for x in self.outputs:
568 x_id = str(id(x))
--> 569 assert x_id in tensor_dict, 'Could not compute output ' + str(x)
570 output_tensors.append(tensor_dict[x_id].pop())
571
AssertionError: Could not compute output KerasTensor(type_spec=TensorSpec(shape=(None, 80, 5000), dtype=tf.float32, name=None), name='dense_25/truediv:0', description="created by layer 'dense_25'")
```
Title encoder
Takes the tokenized title of the post, embeds it, and returns the state of the LSTM.
```
def build_title_embeddings():
    encoder_inputs = Input(shape=(max_title_len,), name='Post title')
    encoder_emb = Embedding(vocab_size, title_embedding_size)(encoder_inputs)
    encoder = LSTM(title_hidden_size, return_state=True)
    _, state_h, state_c = encoder(encoder_emb)
    return Model(inputs=encoder_inputs, outputs=[state_h, state_c], name="Title encoder")
```
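Each of these sub-models can be sanity-checked on its own. For example, a quick shape check with dummy data (using the hyperparameter names defined above) looks roughly like this:

```
import numpy as np

# Dummy batch of 2 padded, tokenized titles; the values don't matter for a shape check.
title_encoder = build_title_embeddings()
dummy_titles = np.zeros((2, max_title_len), dtype="int32")
state_h, state_c = title_encoder.predict(dummy_titles)
# Both states should have shape (2, title_hidden_size).
```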
CNN
A CNN pretrained on ImageNet, from keras.applications.EfficientNetB3.
```
def build_cnn():
    cnn = EfficientNetB3(include_top=False, input_shape=(300, 300, 3), pooling="avg")
    cnn.trainable = False
    return Model(inputs=cnn.inputs, outputs=cnn.outputs, name="cnn")
```
Subreddit embedding
Takes a one-hot encoded vector as input. Its purpose is to reduce the dimensionality of the input.
```
def build_subreddit_embeddings():
    input_subreddit = Input(shape=(num_subreddits,), name='One hot encoded subreddit')
    sub_emb = Dense(subreddit_embedding_size, activation='relu')(input_subreddit)
    return Model(inputs=input_subreddit, outputs=sub_emb, name="Subreddit embedding")
```
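Again, a quick check with a dummy one-hot batch shows the dimensionality reduction:

```
import numpy as np

sub_model = build_subreddit_embeddings()
dummy_subs = np.eye(num_subreddits, dtype="float32")[:2]  # 2 one-hot rows
sub_emb = sub_model.predict(dummy_subs)
# sub_emb should have shape (2, subreddit_embedding_size), much smaller than num_subreddits.
```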
Final encoder
Takes everything above and produces the initial state for the decoder.
```
def build_final_encoder():
    cnn = build_cnn()
    subreddit_embeddings = build_subreddit_embeddings()
    title_embeddings = build_title_embeddings()
    merged = Concatenate()([cnn.output, subreddit_embeddings.output, *title_embeddings.output])
    intermediate_layer = Dense(intermediate_layer_size, activation='relu')(merged)
    state1 = Dense(decoder_hidden_size, activation='relu')(intermediate_layer)
    state2 = Dense(decoder_hidden_size, activation='relu')(intermediate_layer)
    return Model(inputs=[cnn.inputs, subreddit_embeddings.inputs, title_embeddings.inputs], outputs=[state1, state2], name="post_encoder")
```
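Built on its own, the encoder seems fine; as far as I can tell, a call like the following (dummy arrays, batch of 2) works:

```
import numpy as np

post_encoder = build_final_encoder()
imgs = np.zeros((2, 300, 300, 3), dtype="float32")     # EfficientNetB3 input size
subs = np.zeros((2, num_subreddits), dtype="float32")  # one-hot subreddits
titles = np.zeros((2, max_title_len), dtype="int32")   # tokenized titles
state1, state2 = post_encoder.predict([imgs, subs, titles])
# Each state should have shape (2, decoder_hidden_size).
```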
Decoder
Takes the output of the final encoder as its initial state and produces comment tokens.
The decoder's inputs are the previous tokens predicted by the decoder itself.
```
def build_decoder():
    decoder_inputs = Input(shape=(max_decoder_len,), name="Decoder inputs")
    state_inputs1 = Input(shape=(decoder_hidden_size,), name="Decoder state input 1")
    state_inputs2 = Input(shape=(decoder_hidden_size,), name="Decoder state input 2")
    decoder_emb = Embedding(vocab_size, decoder_embedding_size)(decoder_inputs)
    decoder_lstm = LSTM(decoder_hidden_size, return_sequences=True, return_state=True, name="decoder_lstm")
    decoder_outputs, _, _ = decoder_lstm(decoder_emb, initial_state=[state_inputs1, state_inputs2])
    decoder_outputs = Dense(vocab_size, activation="softmax")(decoder_outputs)
    return Model(inputs=[decoder_inputs, [state_inputs1, state_inputs2]], outputs=decoder_outputs, name="final_decoder")
```
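The decoder also runs standalone when I feed it token ids plus two state vectors directly, something like:

```
import numpy as np

final_decoder = build_decoder()
tokens = np.zeros((2, max_decoder_len), dtype="int32")
s1 = np.zeros((2, decoder_hidden_size), dtype="float32")
s2 = np.zeros((2, decoder_hidden_size), dtype="float32")
# As far as I know, Keras flattens the nested input structure,
# so a flat list of three arrays works here.
probs = final_decoder.predict([tokens, s1, s2])
# probs should have shape (2, max_decoder_len, vocab_size): a softmax over the vocabulary per step.
```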
Attempt at bringing everything together
The decoder target is the correct token ids shifted by one during training, and the decoder's previous predictions during inference.
E.g. during training:
Decoder target - ['<start_of_sentence_token>', 'I', 'like', 'pizza']
Correct answer - ['I', 'like', 'pizza', '<end_of_sentence_token>']
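For reference, a minimal sketch of how such a training pair can be built (make_teacher_forcing_pair is just an illustrative helper name; it assumes the comment is already tokenized and padded to max_decoder_len + 1):

```
import numpy as np

def make_teacher_forcing_pair(tokenized_comment):
    # Drop the last token to get the decoder input sequence
    # (starts with <start_of_sentence_token>).
    decoder_target = np.asarray(tokenized_comment[:-1])
    # Shift left by one to get the expected output sequence
    # (ends with <end_of_sentence_token>).
    correct_answer = np.asarray(tokenized_comment[1:])
    return decoder_target, correct_answer
```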
```
decoder_target = Input(shape=(max_decoder_len,), name="Decoder target")
encoder = build_final_encoder()
decoder = build_decoder()

inputs = encoder.inputs

result = decoder(decoder_target, encoder(inputs))  # The error happens here
```