
I'm using Keras for the rest of my project, but I'm also hoping to make use of the Bahdanau attention module that TensorFlow implements (see tf.contrib.seq2seq.BahdanauAttention). I've been attempting to wrap it following the Keras Layer convention, but I'm not sure whether this is an appropriate fit.

Is there some convention for wrapping TensorFlow components in this way so that they are compatible with the Keras computation graph?

I've included the code that I've written thus far (not working yet) and would appreciate any pointers.

from keras import backend as K
from keras.engine.topology import Layer
from keras.models import Model
import numpy as np
import tensorflow as tf

class BahdanauAttention(Layer):

    # The Bahdanau attention layer has to attend to a particular set of memory
    # states. These are usually the outputs of some encoder process, e.g. the
    # hidden states of a GRU encoder.
    def __init__(self, memory, num_units, **kwargs):
        self.memory = memory
        self.num_units = num_units
        super(BahdanauAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        # The attention mechanism attends over the given memory; wrap a GRU
        # cell with it so the cell receives an attention-weighted context.
        attention = tf.contrib.seq2seq.BahdanauAttention(self.num_units, self.memory)
        cell = tf.contrib.rnn.GRUCell(self.num_units)
        self.cell_with_attention = tf.contrib.seq2seq.DynamicAttentionWrapper(
            cell, attention, self.num_units)
        super(BahdanauAttention, self).build(input_shape)

    def call(self, x):
        # Run the attention-wrapped cell over the layer's inputs.
        outputs, _ = tf.nn.dynamic_rnn(self.cell_with_attention, x, dtype=tf.float32)
        return outputs

    def compute_output_shape(self, input_shape):
        # One output vector of size num_units per input time step.
        return (input_shape[0], input_shape[1], self.num_units)
individualtermite

1 Answer


Newer versions of Keras provide tf.keras.layers.AdditiveAttention(), which should work out of the box.
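
For example, a minimal usage sketch (the tensor names and shapes here are illustrative, not taken from the question):

import tensorflow as tf

# Illustrative shapes: a batch of decoder states attending over encoder outputs.
query = tf.random.normal((4, 10, 64))   # (batch, query_steps, units)
value = tf.random.normal((4, 20, 64))   # (batch, memory_steps, units)

attention = tf.keras.layers.AdditiveAttention()
context = attention([query, value])     # -> (batch, query_steps, units)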

Alternatively, a custom Bahdanau attention layer can be written in half a dozen lines of code, as shown here: Custom Attention Layer using in Keras
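
If you prefer rolling your own, a sketch of an additive (Bahdanau-style) layer in tf.keras might look like the following; the class and variable names are just illustrative:

import tensorflow as tf

class BahdanauAttentionLayer(tf.keras.layers.Layer):
    # Additive attention: score = v^T tanh(W1 * query + W2 * values)
    def __init__(self, units, **kwargs):
        super(BahdanauAttentionLayer, self).__init__(**kwargs)
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query: (batch, units), values: (batch, time_steps, units)
        query_with_time_axis = tf.expand_dims(query, 1)
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))
        weights = tf.nn.softmax(score, axis=1)             # (batch, time_steps, 1)
        context = tf.reduce_sum(weights * values, axis=1)  # (batch, units)
        return context, weights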

Allohvk