I'm trying to implement an attention mechanism for abstractive text summarization in Keras, borrowing heavily from this GitHub thread, which contains a lot of informative discussion about the implementation. I'm struggling to understand certain very basic bits of the code and what I need to modify to get the output I want. I know that attention produces a context vector as a weighted sum over the hidden states of all previous timesteps, and that is what we are trying to do below.
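To make sure my understanding of that weighted sum is right, here is a minimal NumPy sketch of it (my own illustration, with shapes matching the parameters below, not code from the thread):

```python
import numpy as np

timesteps, hidden = 200, 64               # match max_len and rnn_size below
h = np.random.rand(timesteps, hidden)     # LSTM hidden states, one per timestep

scores = np.random.rand(timesteps)        # unnormalized importance per timestep
weights = np.exp(scores) / np.exp(scores).sum()  # softmax over timesteps

# context vector: attention-weighted sum of the hidden states
context = (weights[:, None] * h).sum(axis=0)     # shape (hidden,)
```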
Data:
I'm using the BBC news dataset, which consists of news articles and their headlines for various categories such as politics, entertainment, and sports.
Parameters:
n_embeddings = 64
vocab_size = len(vocabulary) + 1
max_len = 200
rnn_size = 64
Code:
from keras.layers import (Input, Embedding, LSTM, Dense, Flatten, Activation,
                          RepeatVector, Permute, Multiply, Lambda)
from keras.models import Model
from keras import backend as K

_input = Input(shape=(max_len,), dtype='int32')
embedding = Embedding(input_dim=vocab_size, output_dim=n_embeddings, input_length=max_len)(_input)
activations = LSTM(rnn_size, return_sequences=True)(embedding)

# compute an importance score for each timestep
attention = Dense(1, activation='tanh')(activations)  # (batch, max_len, 1)
attention = Flatten()(attention)                      # (batch, max_len)
attention = Activation('softmax')(attention)          # weights sum to 1 over timesteps
attention = RepeatVector(rnn_size)(attention)         # was RepeatVector(units); `units` is undefined
attention = Permute([2, 1])(attention)                # (batch, max_len, rnn_size)

# apply the attention weights to the hidden states
sent_representation = Multiply()([activations, attention])  # replaces the deprecated merge(..., mode='mul')
sent_representation = Lambda(lambda xin: K.sum(xin, axis=1))(sent_representation)

probabilities = Dense(max_len, activation='softmax')(sent_representation)
model = Model(inputs=_input, outputs=probabilities)   # `input=`/`output=` are the old argument names
model.compile(optimizer='adam', loss='categorical_crossentropy')
print(model.summary())
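For reference, here is my own NumPy trace of what I think the attention block above does to the tensor shapes (illustrative only, using the same parameter values):

```python
import numpy as np

batch, max_len, rnn_size = 2, 200, 64
attention = np.random.rand(batch, max_len)            # softmax output, one weight per timestep

# RepeatVector(rnn_size): tile the weights once per hidden unit
repeated = np.repeat(attention[:, None, :], rnn_size, axis=1)  # (batch, rnn_size, max_len)
# Permute([2, 1]): swap the last two axes to line up with the LSTM output
permuted = np.transpose(repeated, (0, 2, 1))                   # (batch, max_len, rnn_size)

activations = np.random.rand(batch, max_len, rnn_size)         # LSTM return_sequences output
weighted = activations * permuted                              # elementwise multiply
sent_representation = weighted.sum(axis=1)                     # sum over timesteps -> (batch, rnn_size)
```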
My Questions:
- The linked thread uses attention for classification, whereas I want to generate a text sequence (a summary). How should I use `probabilities` and decode it to generate the summary?
- What is `RepeatVector` used for here? Is it for getting the activation and attention probability of each word at timestep `T`?
- I didn't find much explanation of what the `Permute` layer does. What is it for?
- What is `Lambda(lambda xin: K.sum(xin, axis=1))(sent_representation)` for?
- What should `model.fit()` look like? I have created padded sequences of fixed length for `X` and `y`.
I would really appreciate any help you could provide. Thanks a lot in advance.