I have been implementing BeamSearch in my RNN Decoder in order to avoid repetitive predictions in my output sequences. Now I have run into the issue that my model quickly learns to predict the EOS
token immediately. I would have thought that this happens when the total probability of a sequence path is not normalized by its length, since shorter sequences then get a higher probability, but I thought I had taken care of that.
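To make the normalization idea concrete, here is a toy comparison (all numbers are made up and only illustrate the effect, with alpha = 1):

# Toy numbers only: a short path that emits EOS right away vs. a longer path.
short_logit, short_len = -1.0, 2
long_logit, long_len = -3.5, 7
alpha = 1
# Without normalization the short path wins (-1.0 > -3.5);
# with length normalization both paths score -0.5.
print(short_logit / short_len ** alpha)
print(long_logit / long_len ** alpha)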
Would anybody be able to help me with this issue? My BeamSearch implementation is as follows:
import torch


class BeamSearchDecoder:
    def __init__(self, decoder, vocab, b, d=-1, alpha=1, block_repeat=True):
        self.decoder = decoder  # The Decoder RNN
        self.vocab = vocab
        self.eos_idx = vocab.EOS['idx']
        self.b = b  # Beam width
        # Search depth - take d to the power of b because at each iteration we add
        # b items to the stack which are all still on the same depth level
        self.d = (d ** self.b)
        self.alpha = alpha
        # Block paths with immediate repetitions
        self.block_repeat = block_repeat
        self.roots = []  # Will be the root nodes of the current tree
        self.leafs = []  # Will contain all leaf nodes after building the graph

    def clear_tree(self):
        # Reset roots and leafs before building a new tree
        self.roots = []
        self.leafs = []
    def _build_tree(self, decoder_output, decoder_hidden):
        """
        Builds the search tree.

        Parameters
        ----------
        decoder_output : torch.Tensor
            Previous output of the DecoderRNN.
        decoder_hidden : torch.Tensor
            Previous hidden state of the DecoderRNN.

        Returns
        -------
        None.
        """
        self.clear_tree()
        # Set the root nodes of this tree
        topv, topi = decoder_output.topk(self.b)
        for tv, ti in zip(topv.view(-1, 1), topi.view(-1, 1)):
            self.roots.append(BeamNode(ti.item(), tv.item()))
        # Disable gradient computation during beam search
        with torch.no_grad():
            stack = [*self.roots]  # Nodes to iterate. Initialize with roots
            depth = 0  # Keep track of the depth of the search tree
            while stack:
                curr_node = stack.pop(0)
                # Place the input on the same device as the decoder hidden state
                decoder_input = torch.tensor([[curr_node.idx]],
                                             device=decoder_hidden.device)
                # Stop condition: EOS token
                if curr_node.idx == self.eos_idx:
                    self.leafs.append(curr_node)
                    continue  # Don't continue expanding the current node
                # Stop condition: reached maximum search depth.
                # self.d = d**self.b (see __init__)
                if (self.d > 0) and (depth >= self.d):
                    self.leafs.append(curr_node)
                    continue  # Don't continue expanding the current node
                # Stop condition: repeated prediction
                if self.block_repeat and curr_node.parent \
                        and curr_node.idx == curr_node.parent.idx:
                    continue
                # Forward pass to obtain the logits of the next words (returns log-softmax)
                decoder_output, decoder_hidden = self.decoder(
                    decoder_input, decoder_hidden)
                # Get the top b children
                topv, topi = decoder_output.topk(self.b)
                # For each child, calculate the path probability and add it to the stack
                for tv, ti in zip(topv.view(-1, 1), topi.view(-1, 1)):
                    path_logit = curr_node.path_logit + tv.item()  # Total path logit
                    path_length = curr_node.path_length + 1  # Depth of the current child
                    # Create the child node
                    child_node = BeamNode(
                        ti.item(), tv.item(), path_logit, path_length, parent=curr_node)
                    # Add `child_node` to the children of `curr_node`
                    curr_node.childs.append(child_node)
                    # Append to the stack so children of `child_node` get expanded too
                    stack.append(child_node)
                depth += 1
    def _backtrace(self):
        max_path_logit = float('-inf')
        max_leaf = None
        for leaf in self.leafs:
            # Normalize the logit of the current path by its length
            # (alpha is a hyperparameter)
            path_logit = leaf.path_logit / pow(leaf.path_length, self.alpha)
            if path_logit > max_path_logit:
                # Keep track of the path with the highest logit
                max_path_logit = path_logit
                max_leaf = leaf
        # Backtrace the leaf-to-root path to obtain the root node,
        # which will be the output of the BeamSearch
        root_node_idx = max_leaf.get_root().idx  # Word index in vocab
        return root_node_idx

    def decode(self, decoder_output, decoder_hidden):
        self._build_tree(decoder_output, decoder_hidden)
        return self._backtrace()
A BeamNode is defined as follows:
class BeamNode:
    def __init__(self, idx, logit, path_logit=0, path_length=1, parent=None):
        self.idx = idx  # Vocabulary index of this node
        self.logit = logit  # The logit value of this node in the tree
        self.path_logit = path_logit  # Total logit of the path up to here
        # Total number of parents - used for normalizing the path
        self.path_length = path_length
        self.parent = parent  # Parent node. Can only be one
        self.childs = []  # Child nodes. Has size of beam width

    def get_root(self):
        # Gets the root node that leads to this (self) node
        if not self.parent:
            return self
        node = self.parent
        while node.parent:
            node = node.parent
        return node
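To show how I intend these nodes to chain together, here is a tiny standalone example (the indices and logit values are made up):

# Made-up index/logit values, only to illustrate the parent/child chaining
root = BeamNode(idx=12, logit=-0.3)
child = BeamNode(idx=7, logit=-0.9,
                 path_logit=root.path_logit + (-0.9),
                 path_length=root.path_length + 1,
                 parent=root)
root.childs.append(child)
assert child.get_root() is root  # Backtracing any leaf ends at its root word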
I call BeamSearchDecoder from my training loop:
decoder_input = self.beam_decoder.decode(decoder_output, decoder_hidden)
decoder_input = torch.tensor([[decoder_input]], device=self.device)
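For completeness, the beam decoder itself is constructed once before the loop, roughly like this (the hyperparameter values here are just placeholders, not my actual settings):

self.beam_decoder = BeamSearchDecoder(
    decoder=self.decoder,  # My DecoderRNN instance
    vocab=self.vocab,
    b=3,          # Beam width (placeholder value)
    d=10,         # Maximum search depth (placeholder value)
    alpha=1,
    block_repeat=True,
)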
The design of the BeamSearchDecoder is as follows:
_build_tree()
Builds the search tree. For each decoder output, the b most probable predictions (children) are kept. Each of them is fed into the Decoder iteratively to obtain the next b most probable predictions each (and so on). This repeats until a stop condition is met (A. predicting EOS, B. reaching the user-defined depth, or C. repeating the same token again). While building the stack of child nodes, the current path length and path score are calculated on the fly.
_backtrace()
Obtains the total path logit for each path, finds the leaf with the maximum path logit, and backtraces its path to the root node. This root node will be the final output of BeamSearch.
decoder_output
Contains the outputs of torch.nn.LogSoftmax.
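For reference, the decoder interface that the beam search assumes looks roughly like this (a simplified sketch, not my exact architecture; layer and parameter names are illustrative):

import torch.nn as nn

class DecoderRNN(nn.Module):
    # Simplified sketch: takes one token index plus the previous hidden state
    # and returns log-probabilities over the vocabulary.
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, decoder_input, decoder_hidden):
        # decoder_input: (1, 1) tensor holding a single token index
        embedded = self.embedding(decoder_input).view(1, 1, -1)
        output, hidden = self.gru(embedded, decoder_hidden)
        output = self.log_softmax(self.out(output[0]))
        return output, hidden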
Implementing BeamSearch is a follow-up to my previous topic: Repetitive word predictions in RNN
Please let me know if you need additional information! Thank you!