
I would like to use a tree-LSTM in Keras, similar to what is described in this article: https://arxiv.org/abs/1503.00075. It is essentially a Long Short-Term Memory network, but one that operates over tree-structured input instead of a chain-like sequence.

I think it is a relatively standard architecture and would find uses in a lot of contexts, but I couldn't find any public implementation of it. Is this something that already exists somewhere?

The closest I could find is this implementation: https://github.com/stanfordnlp/treelstm, but that won't integrate well with the rest of my project.

The question is: how can I implement a Tree-RNN or Tree-LSTM in Keras? FYI, as far as I know it isn't possible to implement such an architecture with the Sequential or Functional API, but it can be implemented with the subclassing API (source).

Innat
  • Since "tree-LSTM" is a concept introduced in the article you cite, there clearly won't be a ready-made Keras implementation of it. Nonetheless, you can use the code [here](https://adventuresinmachinelearning.com/keras-lstm-tutorial/) (also see the relevant git repo it points to) to see how an LSTM is handled in Keras, and make the adaptations needed to turn it into a "tree-LSTM" as the article describes. – LemonPy Jan 22 '19 at 15:37
  • Thanks for the reply. I am not sure what you mean by an "introduced concept"; LSTMs are also a concept that was introduced in an article (Hochreiter & Schmidhuber, 1997), and the Tree-LSTM paper is 4 years old with almost 1000 citations, so it is not a fringe object. They are substantially different from a simple LSTM, as a Tree-LSTM is a recursive neural network rather than a recurrent one. – deSitterUniverse Jan 23 '19 at 16:10
  • Exactly my point. LSTM was introduced 22 years ago and has over 15,000 citations - more than an order of magnitude over tree-LSTM. If you go to the actual code of it (https://github.com/keras-team/keras/blob/master/keras/layers/recurrent.py#L2051) you will see that it was only written in 2015. – LemonPy Jan 24 '19 at 13:27
  • Well, the first release of keras was in 2015 so it would have been difficult to write the code much before :) But point taken, I will check back in 2033! – deSitterUniverse Jan 25 '19 at 18:32
  • Why is this question closed? It is a valid question. – Exploring May 12 '21 at 23:34
  • Voting to re-open this question; it's valid to ask. – Innat May 29 '21 at 01:37

2 Answers


You can implement a tree-LSTM in Keras using the Subclassing API. This will allow you to define your own custom layers and models by subclassing the tf.keras.layers.Layer and tf.keras.Model classes, respectively.

To implement a tree-LSTM with the subclassing API, you will need to define a custom layer that takes a tree-structured input and applies the LSTM update at each node, combining the children's states bottom-up. Here is a sketch of the steps you can follow:

import tensorflow as tf

class TreeLSTMLayer(tf.keras.layers.Layer):
  def __init__(self, units, **kwargs):
    super(TreeLSTMLayer, self).__init__(**kwargs)
    self.units = units

  def build(self, input_shape):
    # Create the weight matrices and bias for the four gates (i, f, o, g),
    # packed into single tensors: W acts on a node's input vector, U on the
    # summed hidden state of its children, as in the Child-Sum Tree-LSTM.
    input_dim = input_shape[-1]
    self.W = self.add_weight(name="W", shape=(input_dim, 4 * self.units))
    self.U = self.add_weight(name="U", shape=(self.units, 4 * self.units))
    self.b = self.add_weight(name="b", shape=(4 * self.units,), initializer="zeros")

  def call(self, inputs, tree):
    # `inputs` is a (num_nodes, input_dim) tensor of node features; `tree` is
    # a plain Python structure whose nodes expose `index` (a row of `inputs`)
    # and `children`. Because the tree is a Python object, the layer is meant
    # to run eagerly.

    # List collecting the (hidden, cell) state of every node, in the order
    # the nodes finish processing (children before their parent).
    output_states = []

    # Recursively traverse the tree bottom-up and apply the LSTM update at
    # each node, summing the children's states as in the Child-Sum Tree-LSTM.
    def traverse_tree(node):
      child_states = [traverse_tree(child) for child in node.children]
      if child_states:
        h_sum = tf.add_n([h for h, _ in child_states])
        c_sum = tf.add_n([c for _, c in child_states])
      else:
        h_sum = tf.zeros((self.units,))
        c_sum = tf.zeros((self.units,))

      # Compute all four gates in one affine map, then split.
      x_t = inputs[node.index]
      z = tf.matmul(x_t[None, :], self.W) + tf.matmul(h_sum[None, :], self.U) + self.b
      i_t, f_t, o_t, g_t = tf.split(z[0], 4)
      i_t, f_t, o_t = tf.sigmoid(i_t), tf.sigmoid(f_t), tf.sigmoid(o_t)
      g_t = tf.tanh(g_t)

      # Update the cell and hidden states. (The paper uses one forget gate
      # per child; a single forget gate over the summed child cell state is
      # a simplification.)
      c_t = f_t * c_sum + i_t * g_t
      h_t = o_t * tf.tanh(c_t)
      output_states.append((h_t, c_t))
      return h_t, c_t

    # Start the recursive traversal at the root of the tree.
    traverse_tree(tree.root)

    # Return the output states for each node in the tree (root last).
    return output_states

Once you have defined your custom TreeLSTMLayer, you can use it to build a tree-LSTM model by subclassing the tf.keras.Model class and using the TreeLSTMLayer as one of the layers in your model.
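For example, a minimal sketch of that wrapping step, using the layer above, could look like this; the sigmoid readout over the root node's hidden state is just an illustrative choice, not something the layer dictates:

class TreeLSTMModel(tf.keras.Model):
  def __init__(self, units):
    super(TreeLSTMModel, self).__init__()
    self.tree_lstm = TreeLSTMLayer(units)
    # Illustrative readout: a binary prediction from the root's hidden state.
    self.classifier = tf.keras.layers.Dense(1, activation="sigmoid")

  def call(self, node_features, tree):
    states = self.tree_lstm(node_features, tree=tree)
    root_h, _ = states[-1]  # the root finishes last in the bottom-up traversal
    return self.classifier(root_h[None, :])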

s16h

You can flatten the input tree and map the problem onto the normal use case of a flat input.

The effective solution will be the same as long as you make sure that your degrees of freedom are the same.
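As a rough illustration of that idea, here is a minimal sketch, assuming a simple hypothetical Node class (not from the answer): the tree is linearized in post-order, the node depth is appended as an extra structural feature, and the resulting sequence is fed to an ordinary chain LSTM. Whether this really preserves the same degrees of freedom depends on how much of the tree structure you encode into the per-node features.

import tensorflow as tf

class Node:
  # Hypothetical tree node: a feature vector plus a list of children.
  def __init__(self, features, children=None):
    self.features = features
    self.children = children or []

def flatten_postorder(node, depth=0):
  # Linearize the tree bottom-up (children before their parent) and append
  # the node depth as one extra structural feature.
  rows = []
  for child in node.children:
    rows.extend(flatten_postorder(child, depth + 1))
  rows.append(list(node.features) + [float(depth)])
  return rows

# Toy tree with 3-dimensional node features.
tree = Node([1., 0., 0.], [Node([0., 1., 0.]), Node([0., 0., 1.])])
sequence = tf.constant([flatten_postorder(tree)])  # shape (1, num_nodes, 4)

# An ordinary chain LSTM over the flattened sequence.
model = tf.keras.Sequential([
  tf.keras.layers.LSTM(8),
  tf.keras.layers.Dense(1),
])
print(model(sequence).shape)  # (1, 1)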

DataYoda