To the best of my knowledge, there is no more efficient way than implementing it yourself in PyTorch, i.e., there is no simple argument option for it.
As you said, the standard mode corresponds to TensorFlow's 'concat'. If we want to verify this, we can test it as follows:
import torch
from torch import nn
# Create the LSTMs
in_dim = 5
out_dim = 100
lstm = nn.LSTM(in_dim, out_dim, batch_first=True)
bilstm = nn.LSTM(in_dim, out_dim, batch_first=True, bidirectional=True)
# Copy forward weights
bilstm.weight_ih_l0 = lstm.weight_ih_l0
bilstm.weight_hh_l0 = lstm.weight_hh_l0
bilstm.bias_ih_l0 = lstm.bias_ih_l0
bilstm.bias_hh_l0 = lstm.bias_hh_l0
# Execute on random example
x = torch.randn(1, 3, in_dim)
output1, (h_n1, c_n1) = lstm(x)
output2, (h_n2, c_n2) = bilstm(x)
# Assert that the forward direction of the BiLSTM matches the unidirectional LSTM
assert torch.allclose(output1, output2[:, :, :out_dim]) # Output is the same
assert torch.allclose(h_n1, h_n2[0]) # Hidden state is the same
assert torch.allclose(c_n1, c_n2[0]) # Cell state is the same
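For completeness (this check is my own addition, not strictly needed for the question): the backward direction can be verified the same way by copying the *_reverse parameters into a second unidirectional LSTM, feeding it the time-reversed input, and flipping its output back:
# Copy backward (reverse-direction) weights
lstm_rev = nn.LSTM(in_dim, out_dim, batch_first=True)
lstm_rev.weight_ih_l0 = bilstm.weight_ih_l0_reverse
lstm_rev.weight_hh_l0 = bilstm.weight_hh_l0_reverse
lstm_rev.bias_ih_l0 = bilstm.bias_ih_l0_reverse
lstm_rev.bias_hh_l0 = bilstm.bias_hh_l0_reverse
# Run on the time-reversed input, then flip the output back into the original time order
output_rev, (h_n_rev, c_n_rev) = lstm_rev(torch.flip(x, dims=[1]))
assert torch.allclose(output2[:, :, out_dim:], torch.flip(output_rev, dims=[1]))  # Backward output is the same
assert torch.allclose(h_n2[1], h_n_rev[0])  # Hidden state is the same
assert torch.allclose(c_n2[1], c_n_rev[0])  # Cell state is the same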
Below are examples for the other three merge modes, for future reference:
Initialization
in_dim = 5
out_dim = 100
bilstm = nn.LSTM(in_dim, out_dim, batch_first=True, bidirectional=True)
x = torch.randn(1, 3, in_dim)
output, (h_n, c_n) = bilstm(x)
Sum ('sum')
# Merge Mode: 'sum'
# Simple version
output_sum = output[:, :, :out_dim] + output[:, :, out_dim:]
assert output_sum.shape == (1, 3, out_dim)
# Faster version
output_sum2 = torch.sum(output.view(x.size(0), x.size(1), 2, -1), dim=2)
assert torch.allclose(output_sum, output_sum2)
On my machine, the "faster version" needs approx. half the time of the simple version.
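If you want to reproduce the timing comparison yourself, a quick (and admittedly unscientific) sketch with timeit could look like this; the exact numbers will of course depend on your hardware and tensor sizes:
import timeit
simple = lambda: output[:, :, :out_dim] + output[:, :, out_dim:]
fast = lambda: torch.sum(output.view(x.size(0), x.size(1), 2, -1), dim=2)
print("simple:", timeit.timeit(simple, number=10000))
print("fast:  ", timeit.timeit(fast, number=10000))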
Multiplication ('mul')
# Merge Mode: 'mul'
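# Simple version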
output_mul = output[:, :, :out_dim] * output[:, :, out_dim:]
assert output_mul.shape == (1, 3, out_dim)
# Faster version
output_mul2 = torch.prod(output.view(x.size(0), x.size(1), 2, -1), dim=2)
assert torch.allclose(output_mul, output_mul2)
Average ('ave')
# Merge Mode: 'ave'
# Simple version
output_ave = (output[:, :, :out_dim] + output[:, :, out_dim:]) / 2
assert output_ave.shape == (1, 3, out_dim)
# Faster version
output_ave2 = torch.mean(output.view(x.size(0), x.size(1), 2, -1), dim=2)
assert torch.allclose(output_ave, output_ave2)
Again, the faster version takes approx. 50% of the time of the simple version on my device.
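Finally, if you need this repeatedly, everything above can be wrapped into a small module that mimics Keras' merge_mode argument. The following class (MergedBiLSTM) is just a sketch of my own making, not part of PyTorch:
class MergedBiLSTM(nn.Module):
    # Bidirectional LSTM whose outputs are merged like Keras' merge_mode ('concat', 'sum', 'mul', 'ave')
    def __init__(self, in_dim, out_dim, merge_mode="concat", **kwargs):
        super().__init__()
        self.merge_mode = merge_mode
        self.lstm = nn.LSTM(in_dim, out_dim, batch_first=True, bidirectional=True, **kwargs)

    def forward(self, x):
        output, state = self.lstm(x)  # state = (h_n, c_n), returned unmerged
        if self.merge_mode == "concat":
            return output, state
        # Separate the forward and backward directions: (batch, seq, 2, out_dim)
        stacked = output.view(output.size(0), output.size(1), 2, -1)
        if self.merge_mode == "sum":
            return torch.sum(stacked, dim=2), state
        if self.merge_mode == "mul":
            return torch.prod(stacked, dim=2), state
        if self.merge_mode == "ave":
            return torch.mean(stacked, dim=2), state
        raise ValueError("Unknown merge mode: " + str(self.merge_mode))
Usage would then be e.g. merged = MergedBiLSTM(in_dim, out_dim, merge_mode="ave") followed by output, (h_n, c_n) = merged(x).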
I hope this helps people who find this in the future. :)