
I would like to create an LSTM class myself, but I don't want to rewrite the classic LSTM functions from scratch.

Digging through the PyTorch code, I found only a dirty implementation involving at least 3-4 classes with inheritance:

  1. https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/modules/rnn.py#L323
  2. https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/modules/rnn.py#L12
  3. https://github.com/pytorch/pytorch/blob/98c24fae6b6400a7d1e13610b20aa05f86f77070/torch/nn/_functions/rnn.py#L297

Does a clean PyTorch implementation of an LSTM exist somewhere? Any links would help.

For example, I know that clean implementations of an LSTM exist in TensorFlow, but I would need to derive a PyTorch one.

For a clear example, what I'm searching for is an implementation as clean as this, but in PyTorch:

imbr
Guillaume Chevalier

2 Answers


The best implementation I found is here:
https://github.com/pytorch/benchmark/blob/master/rnns/benchmarks/lstm_variants/lstm.py

It even implements four different variants of recurrent dropout, which is very useful!
If you take the dropout parts away, you get:

import math
import torch as th
import torch.nn as nn

class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

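    def forward(self, x, hidden):
        # x: (1, batch, input_size); h and c: (1, batch, hidden_size)
        h, c = hidden
        # Drop the leading dimension so the linear layers see 2-D inputs
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x = x.view(x.size(1), -1)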

        # Linear mappings
        preact = self.i2h(x) + self.h2h(h)

        # Activations: the first three blocks are the sigmoid gates (i, f, o),
        # the last block is the tanh candidate g
        gates = preact[:, :3 * self.hidden_size].sigmoid()
        g_t = preact[:, 3 * self.hidden_size:].tanh()
        i_t = gates[:, :self.hidden_size]
        f_t = gates[:, self.hidden_size:2 * self.hidden_size]
        o_t = gates[:, -self.hidden_size:]

        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)

        h_t = th.mul(o_t, c_t.tanh())

        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)
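
One way to sanity-check a hand-written cell like the one above is to compare it against `torch.nn.LSTMCell`. The sketch below restates the same gate arithmetic as a plain function and feeds it the built-in cell's weights. The names `lstm_step` and `to_ifog` are made up for this illustration. The reordering is needed because `nn.LSTMCell` stores its gate blocks in the order (i, f, g, o), while the code above slices them as (i, f, o, g).

```python
import torch
import torch.nn as nn


def lstm_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """One LSTM step with gate blocks laid out as (i, f, o, g), as in the code above.

    Hypothetical helper: x is (batch, input_size), h and c are (batch, hidden_size).
    """
    hidden_size = h.size(1)
    preact = x @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh
    gates = preact[:, :3 * hidden_size].sigmoid()
    g_t = preact[:, 3 * hidden_size:].tanh()
    i_t = gates[:, :hidden_size]
    f_t = gates[:, hidden_size:2 * hidden_size]
    o_t = gates[:, -hidden_size:]
    c_t = c * f_t + i_t * g_t
    h_t = o_t * c_t.tanh()
    return h_t, c_t


def to_ifog(t):
    """Reorder nn.LSTMCell's (i, f, g, o) blocks along dim 0 into (i, f, o, g)."""
    i, f, g, o = t.chunk(4, dim=0)
    return torch.cat([i, f, o, g], dim=0)


torch.manual_seed(0)
batch, input_size, hidden_size = 3, 5, 4
ref = nn.LSTMCell(input_size, hidden_size)
x = torch.randn(batch, input_size)
h0 = torch.randn(batch, hidden_size)
c0 = torch.randn(batch, hidden_size)

with torch.no_grad():
    h_ref, c_ref = ref(x, (h0, c0))
    h_new, c_new = lstm_step(
        x, h0, c0,
        to_ifog(ref.weight_ih), to_ifog(ref.weight_hh),
        to_ifog(ref.bias_ih), to_ifog(ref.bias_hh),
    )

# Largest absolute difference; should be float rounding noise only.
max_diff = max((h_ref - h_new).abs().max().item(),
               (c_ref - c_new).abs().max().item())
print(max_diff)
```

If the two results agree, the slicing and the cell update are consistent with the built-in cell, so any custom changes can be added on top of a known-good baseline.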

PS: The repository contains many more variants of LSTM and other RNNs:
https://github.com/pytorch/benchmark/tree/master/rnns/benchmarks.
Check it out, maybe the extension you had in mind is already there!

EDIT:
As mentioned in the comments, you can wrap the LSTM cell above to process sequential input:

import math
import torch as th
import torch.nn as nn


class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        # As before

    def reset_parameters(self):
        # As before

    def forward(self, x, hidden):

        if hidden is None:
            hidden = self._init_hidden(x, self.hidden_size)

        # Rest as before

    @staticmethod
    def _init_hidden(input_, hidden_size):
        # Zero states of shape (1, batch, hidden_size) with input_'s dtype and device
        h = input_.new_zeros(1, input_.size(1), hidden_size)
        c = input_.new_zeros(1, input_.size(1), hidden_size)
        return h, c


class LSTM(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.lstm_cell = LSTMCell(input_size, hidden_size, bias)

    def forward(self, input_, hidden=None):
        # input_ is of dimensionality (1, time, input_size, ...)

        outputs = []
        for x in th.unbind(input_, dim=1):
            # The cell expects (1, batch, input_size) and returns (output, (h, c))
            out, hidden = self.lstm_cell(x.unsqueeze(0), hidden)
            outputs.append(out.squeeze(0))

        # Shape (1, time, hidden_size)
        return th.stack(outputs, dim=1)

I haven't tested the code since I'm working with a convLSTM implementation. Please let me know if something is wrong.
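
For reference, here is a minimal, self-contained sketch of the cell-plus-wrapper idea that can be run as is. It makes two simplifying assumptions that are not in the code above: tensors use a plain `(batch, time, input_size)` layout without the leading dimension of size 1, and the wrapper also returns the final `(h, c)` pair. Treat it as an illustration rather than a drop-in replacement for `torch.nn.LSTM`.

```python
import math
import torch as th
import torch.nn as nn


class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        std = 1.0 / math.sqrt(hidden_size)
        for w in self.parameters():
            nn.init.uniform_(w, -std, std)

    def forward(self, x, hidden=None):
        # Assumed layout: x is (batch, input_size), h and c are (batch, hidden_size)
        if hidden is None:
            zeros = x.new_zeros(x.size(0), self.hidden_size)
            hidden = (zeros, zeros)
        h, c = hidden
        preact = self.i2h(x) + self.h2h(h)
        # First three blocks are the sigmoid gates (i, f, o), the last is tanh (g)
        i_t, f_t, o_t = preact[:, :3 * self.hidden_size].sigmoid().chunk(3, dim=1)
        g_t = preact[:, 3 * self.hidden_size:].tanh()
        c_t = c * f_t + i_t * g_t
        h_t = o_t * c_t.tanh()
        return h_t, (h_t, c_t)


class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.lstm_cell = LSTMCell(input_size, hidden_size, bias)

    def forward(self, input_, hidden=None):
        # Assumed layout: input_ is (batch, time, input_size)
        outputs = []
        for x in th.unbind(input_, dim=1):
            out, hidden = self.lstm_cell(x, hidden)
            outputs.append(out)
        # Stack the per-step outputs back along the time dimension
        return th.stack(outputs, dim=1), hidden


th.manual_seed(0)
lstm = LSTM(input_size=10, hidden_size=20)
seq = th.randn(1, 7, 10)  # one sequence with 7 time steps
output, (h_n, c_n) = lstm(seq)
print(output.shape, h_n.shape)  # torch.Size([1, 7, 20]) torch.Size([1, 20])
```

Because the loop only relies on the cell returning `(output, (h, c))`, a different cell (for example one with recurrent dropout) can be swapped in without touching the wrapper.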

UPDATE: Fixed links.

Richard
  • After a bit of testing, I realized that it says `For now, they only support a sequence size of 1`. So this code might need a lot of refactoring to be useable. – Guillaume Chevalier May 04 '18 at 20:37
  • The code I've given above is often referred to as an LSTM cell. In order to process a sequential input, just wrap it in a module which sets initial hidden states and then iterates over the temporal dimension of the input, calling the LSTM cell at each time point (similar to how it is done here https://discuss.pytorch.org/t/implementation-of-multiplicative-lstm/2328/9) – Richard May 06 '18 at 23:12
  • Also have a look at how it is done here for a convolutional LSTM https://github.com/automan000/Convolution_LSTM_PyTorch/blob/master/convolution_lstm.py. It’s very similar to the regular LSTM. – Richard May 06 '18 at 23:20
  • Could you explain why you add an extra dimension for h and c using `view` in the static method? – An Ignorant Wanderer May 31 '20 at 02:48
  • same question for input_ dimensionality. Why is there a leading 1? – An Ignorant Wanderer May 31 '20 at 03:03
  • @AnIgnorantWanderer this is because this implementation assumes a batch size of 1 so that it works for input sequences of any length. If you want to use larger batch sizes, you would either need a fixed sequence length or [sequence padding](https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pad_sequence), and you'd need to modify the code accordingly. – Richard Jun 02 '20 at 11:39
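
The sequence padding mentioned in the last comment can be sketched as follows. The example uses `torch.nn.utils.rnn.pad_sequence` to batch sequences of different lengths and builds a boolean mask for the real time steps. The sizes and the masking approach are illustrative assumptions, not part of the answer's code.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of different lengths, each with 10 features per time step.
seqs = [torch.randn(5, 10), torch.randn(3, 10), torch.randn(7, 10)]
lengths = torch.tensor([s.size(0) for s in seqs])

# Pad with zeros up to the longest length; result is (batch, max_time, features).
padded = pad_sequence(seqs, batch_first=True)

# Boolean mask of shape (batch, max_time): True where a time step is real data.
mask = torch.arange(padded.size(1)).unsqueeze(0) < lengths.unsqueeze(1)

print(padded.shape)     # torch.Size([3, 7, 10])
print(mask.sum(dim=1))  # tensor([5, 3, 7])
```

A batched LSTM loop would then need to use such a mask (or the lengths) so that padded steps do not update the hidden state of the shorter sequences.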

I made a simple and general frame to customize LSTMs: https://github.com/daehwannam/pytorch-rnn-util

You can implement custom LSTMs by designing LSTM cells and providing them to LSTMFrame. An example of a custom LSTM is LayerNormLSTM in the package:

# snippet from rnn_util/seq.py
class LayerNormLSTM(LSTMFrame):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0, r_dropout=0, bidirectional=False, layer_norm_enabled=True):
        r_dropout_layer = nn.Dropout(r_dropout)
        rnn_cells = tuple(
            tuple(
                LayerNormLSTMCell(
                    input_size if layer_idx == 0 else hidden_size * (2 if bidirectional else 1),
                    hidden_size,
                    dropout=r_dropout_layer,
                    layer_norm_enabled=layer_norm_enabled)
                for _ in range(2 if bidirectional else 1))
            for layer_idx in range(num_layers))

        super().__init__(rnn_cells, dropout, bidirectional)

LayerNormLSTM has the key options of PyTorch's standard LSTM and additional options, r_dropout and layer_norm_enabled:

# example.py
import torch
import rnn_util


bidirectional = True
num_directions = 2 if bidirectional else 1

rnn = rnn_util.LayerNormLSTM(10, 20, 2, dropout=0.3, r_dropout=0.25,
                             bidirectional=bidirectional, layer_norm_enabled=True)
# rnn = torch.nn.LSTM(10, 20, 2, bidirectional=bidirectional)

input = torch.randn(5, 3, 10)
h0 = torch.randn(2 * num_directions, 3, 20)
c0 = torch.randn(2 * num_directions, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))

print(output.size())
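
To see which shapes to expect without installing the package, the commented-out `torch.nn.LSTM` line from the example can be run on its own, as in the sketch below. Whether `rnn_util.LayerNormLSTM` returns exactly the same shapes is an assumption based on the statement above that it mirrors the key options of the standard LSTM. The variable names `num_layers` and `inputs` are introduced here for readability.

```python
import torch

bidirectional = True
num_layers = 2
num_directions = 2 if bidirectional else 1

# Built-in LSTM with the same positional arguments as in the example above.
rnn = torch.nn.LSTM(10, 20, num_layers, bidirectional=bidirectional)

inputs = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
h0 = torch.randn(num_layers * num_directions, 3, 20)
c0 = torch.randn(num_layers * num_directions, 3, 20)
output, (hn, cn) = rnn(inputs, (h0, c0))

print(output.size())  # torch.Size([5, 3, 40])
print(hn.size())      # torch.Size([4, 3, 20])
```

The last dimension of `output` is `hidden_size * num_directions`, which is why a bidirectional model with 20 hidden units yields 40 features per time step.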
dhnam