I have been implementing BeamSearch in my RNN decoder in order to avoid repetitive predictions in my output sequences. Now I have run into the issue that my model quickly learns to predict the EOS token immediately. I would have thought that this happens when the total probability of a sequence path is not normalized by its length, since shorter sequences then get higher scores, but I thought I had taken care of that.
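
For illustration, here is a minimal standalone sketch of the normalization I mean (the alpha exponent matches the one in my code below; the scores are toy numbers):

import math

def normalized_score(path_logit, path_length, alpha=1.0):
    # Divide the summed log-probability by length**alpha so longer paths
    # are not penalized merely for having more (negative) terms in the sum
    return path_logit / math.pow(path_length, alpha)

# Toy example: a 2-token path ending in EOS vs. a 6-token path
short_path = normalized_score(-1.0, 2)  # -0.5
long_path = normalized_score(-2.4, 6)   # -0.4
# Unnormalized, the short path wins (-1.0 > -2.4); normalized, it does not
assert long_path > short_path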

Would anybody be able to help me with this issue? My BeamSearch implementation is as follows:

class BeamSearchDecoder:
    def __init__(self, decoder, vocab, b, d=-1, alpha=1, block_repeat=True):
        self.decoder = decoder  # The decoder RNN
        self.vocab = vocab
        self.eos_idx = vocab.EOS['idx']
        self.b = b  # Beam width
        """Search depth - take d to the power of b because at each iteration
        we add b items to the stack which are all still on the same depth
        level. d <= 0 disables the depth limit."""
        self.d = d ** self.b if d > 0 else -1
        self.alpha = alpha
        # Block paths with immediate repetitions
        self.block_repeat = block_repeat
        # Build input tensors on the same device as the decoder's parameters
        self.device = next(decoder.parameters()).device
        self.roots = []  # Will be the root nodes of the current tree
        self.leafs = []  # Will contain all leaf nodes after building the tree
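
    def clear_tree(self):
        # Reset the search tree (self.roots & self.leafs) so that state from
        # the previous decoding step does not leak into the next search; this
        # helper is called at the top of _build_tree() (see the answer below)
        self.roots = []
        self.leafs = []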

    def _build_tree(self, decoder_output, decoder_hidden):
        """
        Builds the search tree.

        Parameters
        ----------
        decoder_output : torch.Tensor
            Previous Output of the DecoderRNN.
        decoder_hidden : torch.Tensor
            Previous Hidden state of the DecoderRNN.

        Returns
        -------
        None.

        """

        self.clear_tree()

        # Set the root nodes of this tree. Each root keeps its own logit as
        # the initial path logit and stores the hidden state it came from,
        # so siblings are later expanded from the correct state.
        topv, topi = decoder_output.topk(self.b)
        for tv, ti in zip(topv.view(-1, 1), topi.view(-1, 1)):
            self.roots.append(BeamNode(ti.item(), tv.item(),
                                       path_logit=tv.item(),
                                       hidden=decoder_hidden))

        # Disable gradient computation during beam search
        with torch.no_grad():
            stack = [*self.roots]  # Nodes to iterate. Initialize with roots

            depth = 0 # keep track of depth of search tree
            while stack:
                curr_node = stack.pop(0)
                decoder_input = torch.tensor([[curr_node.idx]], device=self.device)

                # Stop condition: EOS token
                if curr_node.idx == self.eos_idx:
                    self.leafs.append(curr_node)
                    continue  # Don't continue expanding current node
                # Stop condition: reached maximum search depth (self.d = d**b, see __init__)
                if (self.d > 0) and (depth >= self.d):
                    self.leafs.append(curr_node)
                    continue  # Don't continue expanding current node
                # Stop condition: repeated prediction
                if self.block_repeat:
                    if curr_node.parent:
                        if curr_node.idx == curr_node.parent.idx:
                            continue

                # Forward pass to obtain scores for the next words (the
                # decoder returns LogSoftmax outputs). Expand from the hidden
                # state stored on this node, not from the state left over by
                # the previously expanded node.
                decoder_output, decoder_hidden = self.decoder(
                    decoder_input, curr_node.hidden)

                # Get the top b children
                topv, topi = decoder_output.topk(self.b)

                # For each child, calculate the path score and add it to the stack
                for tv, ti in zip(topv.view(-1, 1), topi.view(-1, 1)):
                    path_logit = curr_node.path_logit + tv.item()  # Total path logit
                    path_length = curr_node.path_length + 1  # Depth of this child
                    # Create the child node, remembering the hidden state
                    # that will be needed to expand it later
                    child_node = BeamNode(
                        ti.item(), tv.item(), path_logit, path_length,
                        parent=curr_node, hidden=decoder_hidden)
                    # Add `child_node` to the childs of `curr_node`
                    curr_node.childs.append(child_node)
                    # Append to the stack so the children of `child_node` get
                    # expanded too
                    stack.append(child_node)

                depth += 1

    def _backtrace(self):
        max_path_logit = float('-inf')  # Track the best normalized score
        max_leaf = None
        for leaf in self.leafs:
            # Normalize the logit of the current path by its length
            # (alpha is a hyperparameter)
            path_logit = leaf.path_logit / pow(leaf.path_length, self.alpha)
            if path_logit > max_path_logit:
                # Keep track of the path with the highest normalized logit
                max_path_logit = path_logit
                max_leaf = leaf
        # Backtrace the leaf-to-root path to obtain the root node, which is
        # the final output of the beam search
        root_node_idx = max_leaf.get_root().idx  # Word index in the vocab
        return root_node_idx

    def decode(self, decoder_output, decoder_hidden):
        self._build_tree(decoder_output, decoder_hidden)
        return self._backtrace()

A BeamNode is defined as follows:

class BeamNode:
    def __init__(self, idx, logit, path_logit=0, path_length=1, parent=None,
                 hidden=None):
        self.idx = idx  # Vocabulary index of this node
        self.logit = logit  # The logit value of this node in the tree
        self.path_logit = path_logit  # Total logit of the path up to here
        # Number of nodes on the path from the root to here (including this
        # node) - used for normalizing the path score
        self.path_length = path_length
        self.parent = parent  # Parent node. There can only be one
        self.hidden = hidden  # Decoder hidden state this node was produced with
        self.childs = []  # Child nodes. Has at most beam-width entries

    def get_root(self):
        # Gets the root node that leads to this (self) node
        if not self.parent:
            return self
        node = self.parent
        while node.parent:
            node = node.parent
        return node
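
To make the backtrace step concrete, here is a tiny hand-built example (the indices and logit values are made up):

root = BeamNode(idx=7, logit=-0.4, path_logit=-0.4)
child = BeamNode(idx=3, logit=-0.9, path_logit=-1.3, path_length=2,
                 parent=root)
root.childs.append(child)

# Backtracing any leaf returns the root node, whose vocabulary index (7)
# is what decode() ultimately emits for this step
assert child.get_root() is root
assert child.get_root().idx == 7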

I call BeamSearchDecoder from my training loop:

decoder_input = self.beam_decoder.decode(decoder_output, decoder_hidden)                
decoder_input = torch.tensor([[decoder_input]], device=self.device)
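
For completeness, the beam decoder itself is constructed once before the loop; the parameter values here are hypothetical:

self.beam_decoder = BeamSearchDecoder(self.decoder, self.vocab, b=3, d=4, alpha=0.7)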

The design of the BeamSearchDecoder is as follows:

  1. _build_tree() builds the search tree. For each decoder output, the b most probable predictions (children) are kept. Each of them is fed into the decoder iteratively to obtain the next set of b most probable predictions each (and so on). This repeats until a stop condition is met (A. predicting EOS, B. reaching the user-defined depth, or C. repeating the same token again). While building the stack of child nodes, the current path length and path score are calculated on the fly.
  2. _backtrace() obtains the total path logit for each path, finds the leaf with the maximum (length-normalized) path logit, and backtraces its path to the root node. This root node is the final output of the beam search.

decoder_output contains outputs of torch.nn.LogSoftmax.
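
Since these are log-probabilities, summing them along a path is equivalent to multiplying the underlying probabilities, which is why path_logit is accumulated with +. A quick sanity check:

import torch

logp = torch.log_softmax(torch.randn(4), dim=-1)
# The sum of two log-probabilities equals the log of the product of the
# corresponding probabilities
total = logp[0] + logp[1]
assert torch.isclose(total, torch.log(logp[0].exp() * logp[1].exp()))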

Implementing BeamSearch is a follow-up to my previous topic: Repetitive word predictions in RNN

Please let me know if you need additional information! Thank you!

1 Answer

Solution in my case: bugs. If you use this code, make sure to properly reset the search tree (self.roots & self.leafs) for the next decoding step, as clear_tree() above does.

Apart from that, if you have suggestions on how to improve the computational performance of my (slow but working) implementation, I would be very happy to hear them!
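
For anyone wondering the same: a common way to speed this up (a sketch under my own assumptions, not a drop-in replacement for the class above) is to skip the explicit tree entirely and keep only the b best partial hypotheses per step, so the frontier stays at size b instead of growing to b**depth nodes:

import torch

def beam_step(decoder, beams, b):
    """One flat beam-search step. `beams` is a list of
    (tokens, score, hidden) tuples; `decoder` is assumed to take
    (input, hidden) and return (log_probs, hidden) like the RNN above.
    Device handling is omitted for brevity."""
    candidates = []
    for tokens, score, hidden in beams:
        inp = torch.tensor([[tokens[-1]]])
        log_probs, new_hidden = decoder(inp, hidden)
        topv, topi = log_probs.view(-1).topk(b)
        for tv, ti in zip(topv, topi):
            candidates.append(
                (tokens + [ti.item()], score + tv.item(), new_hidden))
    # Prune: keep only the b best hypotheses overall
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:b]

A further improvement would be to stack the b inputs and hidden states into a single batch so the decoder runs once per step instead of b times.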