
I am facing an issue with my post-processing method. I have a pipeline with preprocessing, inference, and post-processing steps. During preprocessing, I tokenize the input data and handle token overflow for sequences longer than 512 tokens: the overflowing tokens are split into chunks, which are processed in the subsequent steps.

The inference method processes the entire input correctly and returns predictions for all tokens. However, my post-processing method only seems to handle the first 512 tokens and does not return any processed data beyond that.

Preprocessing Method:

  • I use a stride length of 128 for token overflow.
  • The processor splits the encoded input into 512-token chunks, and overflow_to_sample_mapping records which sample each chunk comes from.
  • The debug log for the overflow_to_sample_mapping shows tensor([0, 0, 0, 0]), which means there are 4 chunks, all corresponding to the first sample.
  • input_ids and attention_mask have shape (4, 512); bbox has shape (4, 512, 4) and pixel_values has shape (4, 3, 224, 224).
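For reference, the token-level overflow can be approximated in plain Python. This is a simplified sketch, not the Hugging Face implementation: it assumes that `stride` means the number of tokens shared by consecutive chunks (which is how the tokenizer documents `stride` together with `return_overflowing_tokens=True`), and it assumes two special tokens per chunk count toward `max_length`. The function name `overflow_windows` and the 1,500-token length are hypothetical, chosen only for illustration.

```python
from typing import List, Tuple


def overflow_windows(n_tokens: int, window: int, overlap: int) -> List[Tuple[int, int]]:
    """Return (start, end) spans covering n_tokens content tokens.

    Consecutive spans share `overlap` tokens, mirroring the meaning of
    `stride` with return_overflowing_tokens=True. Special tokens are ignored.
    """
    if window <= 0 or not 0 <= overlap < window:
        raise ValueError("need window > 0 and 0 <= overlap < window")
    step = window - overlap  # each new chunk starts `overlap` tokens before the previous end
    spans = []
    start = 0
    while True:
        end = min(start + window, n_tokens)
        spans.append((start, end))
        if end >= n_tokens:
            break
        start += step
    return spans


# Hypothetical sample of 1,500 content tokens, 510 content tokens per chunk
# (assuming 512 minus two special tokens), 128 tokens of overlap.
spans = overflow_windows(1500, window=510, overlap=128)
sample_mapping = [0] * len(spans)  # every chunk comes from sample 0
print(spans)           # [(0, 510), (382, 892), (764, 1274), (1146, 1500)]
print(sample_mapping)  # [0, 0, 0, 0]
```

A single long sample produces several chunks that all map back to sample 0, which is the same shape of mapping as in the log below.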

Inference Method:

  • The method runs the model on all four chunks and returns the predicted label indices, flattened into a single list.
  • The debug log reports an inference prediction indices length of 2048 (4 chunks × 512 tokens), confirming that all chunks are processed.

Postprocessing Method:

  • The method is designed to process each page of data and extract relevant information based on the predictions.
  • Currently, the method seems to only handle the first 512 tokens. The results beyond the first 512 tokens are not being processed or returned.
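Since overflow_to_sample_mapping has one entry per chunk, a single page (sample) can own several rows of input_ids, bbox, and the predictions. A minimal sketch of grouping chunk rows by the sample they came from, using a plain list in place of the tensor (`chunks_per_sample` is a hypothetical helper; `self._overflow_to_sample_mapping.tolist()` would give such a list):

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def chunks_per_sample(sample_mapping: Sequence[int]) -> Dict[int, List[int]]:
    """Map each sample (page) index to the chunk rows that belong to it."""
    groups: Dict[int, List[int]] = defaultdict(list)
    for chunk_idx, sample_idx in enumerate(sample_mapping):
        groups[int(sample_idx)].append(chunk_idx)
    return dict(groups)


print(chunks_per_sample([0, 0, 0, 0]))     # {0: [0, 1, 2, 3]}
print(chunks_per_sample([0, 0, 1, 2, 2]))  # {0: [0, 1], 1: [2], 2: [3, 4]}
```

With the mapping from the log, tensor([0, 0, 0, 0]), page 0 owns four chunk rows, whereas an index such as `self._processed_data['bbox'].tolist()[page]` selects only one row of the batch dimension.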

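Methods:

import json
import logging
import traceback

import torch
from PIL import Image

# load_processor, normalize_box, compare_boxes and adjacent are helpers defined elsewhere (not shown).
logger = logging.getLogger(__name__)
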
def preprocess(self, batch):
    """
    Transform raw input into model input data.
    :param batch: list of raw requests, should match batch size
    :return: list of preprocessed model input data
    """
    logger.debug(f"Processing batch of size: {len(batch)}")
    inference_dict = batch
    self._raw_input_data = inference_dict
    processor = load_processor(self.model_dir)
    images = [Image.open(path).convert("RGB") for path in inference_dict['image_path']]
    # Compute the sizes first so the debug line below logs the current batch.
    self._images_size = [img.size for img in images]
    logger.debug(f"Loaded {len(images)} images with sizes: {self._images_size}")
    words = inference_dict['words']

    boxes = [[normalize_box(box, images[i].size[0], images[i].size[1])
            for box in doc] for i, doc in enumerate(inference_dict['bboxes'])]

    stride_length = 128  
    new_words, new_boxes = [], []

    for w, b in zip(words, boxes):
        if len(w) <= 512: 
            new_words.append(w)
            new_boxes.append(b)
        else:
            for i in range(0, len(w) - 512 + stride_length, stride_length):
                new_words.append(w[i:i+512])
                new_boxes.append(b[i:i+512])

    encoded_inputs = processor(images, new_words, boxes=new_boxes, return_tensors="pt", padding="max_length",
                            truncation=True, stride=128, return_overflowing_tokens=True,
                            return_offsets_mapping=True)
    logger.debug(f"Encoded inputs with keys: {encoded_inputs.keys()}")
    self._offset_mapping = encoded_inputs['offset_mapping']
    self._overflow_to_sample_mapping = encoded_inputs.get('overflow_to_sample_mapping', None)
    logger.debug(f"Overflow to sample mapping: {self._overflow_to_sample_mapping}")
    logger.debug(f"ofset mapping: {self._offset_mapping}")
    if 'overflow_to_sample_mapping' in encoded_inputs:
        encoded_inputs.pop('overflow_to_sample_mapping')

    for key in ['pixel_values', 'input_ids', 'attention_mask', 'bbox']:
        if key in encoded_inputs:
            if isinstance(encoded_inputs[key], list):
                if key == 'pixel_values':
                    encoded_inputs[key] = [i.view(1, *i.shape) for i in encoded_inputs[key]]
                    encoded_inputs[key] = torch.cat(encoded_inputs[key], dim=0)
                else:
                    encoded_inputs[key] = [i.view(-1, 512) for i in encoded_inputs[key]]
                    encoded_inputs[key] = torch.cat(encoded_inputs[key], dim=0)
    logger.debug(f"input ids are: {encoded_inputs['input_ids'].shape}")
    logger.debug(f"attention mask is: {encoded_inputs['attention_mask'].shape}")
    logger.debug(f"bbox is: {encoded_inputs['bbox'].shape}")
    logger.debug(f"pixel values are: {encoded_inputs['pixel_values'].shape}")
    self._processed_data = encoded_inputs
    return encoded_inputs
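The manual loop in preprocess works on words rather than tokens: it takes 512-word windows that start every stride_length (128) words, so consecutive windows overlap by 384 words, unlike the tokenizer's stride, which counts overlapping tokens. Because one word can become several subword tokens, a window of at most 512 words can still exceed 512 tokens, which is when the processor's own overflow handling adds chunks. Note also that when a page has more than 512 words, new_words ends up with more entries than images. The sketch below only reproduces the loop's window boundaries; `word_windows` is a hypothetical name.

```python
from typing import List, Tuple


def word_windows(n_words: int, window: int = 512, step: int = 128) -> List[Tuple[int, int]]:
    """Reproduce the (start, end) word slices built by the manual loop in preprocess."""
    if n_words <= window:
        return [(0, n_words)]  # short pages are passed through unchanged
    # Same range as the loop: range(0, len(w) - 512 + stride_length, stride_length)
    return [(i, min(i + window, n_words))
            for i in range(0, n_words - window + step, step)]


print(word_windows(300))  # [(0, 300)]
print(word_windows(700))  # [(0, 512), (128, 640), (256, 700)]
```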


def inference(self, model_input):
    if "offset_mapping" in model_input:
        model_input.pop("offset_mapping")

    logger.debug(f"Inference input shape: {model_input.get('input_ids', {}).shape if 'input_ids' in model_input else 'Unknown'}")

    with torch.no_grad():
        inference_outputs = self.model(**model_input)
        predictions = inference_outputs.logits
        predictions = predictions.view(-1, predictions.shape[-1])
        predicted_indices = predictions.argmax(-1)
        predictions = [self.model.config.id2label[index] for index in predicted_indices.tolist()]
    logger.debug(f"Inference output predictions: {predicted_indices.tolist()[:10]} (first 10 shown for brevity)")
    logger.debug(f"inference predictions indices length: {len(predicted_indices.tolist())}")
    logger.debug(f"inference predictions: {predictions}")
    return predicted_indices.tolist()
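Because inference flattens the logits with `predictions.view(-1, predictions.shape[-1])` before taking the argmax, the 2,048 returned indices are in chunk-major order: position p of chunk c sits at flat index c * 512 + p. The sketch below shows that arithmetic with plain lists instead of tensors; the helper names are hypothetical.

```python
from typing import List, Tuple

SEQ_LEN = 512  # max_length used in preprocess


def flat_index(chunk: int, position: int, seq_len: int = SEQ_LEN) -> int:
    """Index into the flattened prediction list for (chunk, position)."""
    return chunk * seq_len + position


def chunk_position(index: int, seq_len: int = SEQ_LEN) -> Tuple[int, int]:
    """Inverse of flat_index: recover (chunk, position) from a flat index."""
    return divmod(index, seq_len)


def split_by_chunk(flat: List[int], seq_len: int = SEQ_LEN) -> List[List[int]]:
    """Undo the flattening: one list of seq_len predictions per chunk."""
    if len(flat) % seq_len:
        raise ValueError("length is not a multiple of seq_len")
    return [flat[i:i + seq_len] for i in range(0, len(flat), seq_len)]


flat = list(range(2048))  # stand-in for the 2,048 predicted indices
chunks = split_by_chunk(flat)
print(len(chunks), flat_index(3, 0), chunk_position(1500))  # 4 1536 (2, 476)
```

Flat indices 512 and above therefore belong to chunks 1 to 3.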


def postprocess(self, inference_output):
    try:
        docs = []
        k = 0
        if isinstance(inference_output[0], list):
            inference_output = [item for sublist in inference_output for item in sublist]

        for page, doc_words in enumerate(self._raw_input_data['words']):
            doc_list = []
            width, height = self._images_size[page]
            for i, doc_word in enumerate(doc_words):
                word_tagging = None
                word_labels = []
                word = dict()
                word['id'] = k
                k += 1
                word['text'] = doc_word
                word['pageNum'] = page + 1
                word['box'] = self._raw_input_data['bboxes'][page][i]
                _normalized_box = normalize_box(self._raw_input_data['bboxes'][page][i], width, height)
                
                for j, box in enumerate(self._processed_data['bbox'].tolist()[page]):
                    if compare_boxes(box, _normalized_box):
                        if self.model.config.id2label[inference_output[j]] != 'O':
                            word_labels.append(self.model.config.id2label[inference_output[j]][2:])
                        else:
                            word_labels.append('other')

                if word_labels:
                    word_tagging = word_labels[0] if word_labels[0] != 'other' else word_labels[-1]
                else:
                    word_tagging = 'other'

                word['label'] = word_tagging
                word['pageSize'] = {'width': width, 'height': height}

                if word['label'] != 'other':
                    doc_list.append(word)

            spans = []
            def adjacents(entity): return [adj for adj in doc_list if adjacent(entity, adj)]
            output_test_tmp = doc_list[:]
            
            for entity in doc_list:
                if not adjacents(entity):
                    spans.append([entity])
                    output_test_tmp.remove(entity)

            while output_test_tmp:
                span = [output_test_tmp[0]]
                output_test_tmp = output_test_tmp[1:]
                
                while output_test_tmp and adjacent(span[-1], output_test_tmp[0]):
                    span.append(output_test_tmp[0])
                    output_test_tmp.remove(output_test_tmp[0])
                
                spans.append(span)

            output_spans = []
            
            for span in spans:
                if len(span) == 1:
                    output_span = {
                        "text": span[0]['text'],
                        "label": span[0]['label'],
                        "words": [{
                            'id': span[0]['id'],
                            'box': span[0]['box'],
                            'text': span[0]['text']
                        }]
                    }
                else:
                    output_span = {
                        "text": ' '.join([entity['text'] for entity in span]),
                        "label": span[0]['label'],
                        "words": [{
                            'id': entity['id'],
                            'box': entity['box'],
                            'text': entity['text']
                        } for entity in span]
                    }
                output_spans.append(output_span)

            docs.append({'output': output_spans})

        logger.debug(f"post-processing results: {docs}")
        # Pass every page's results, not only the last page's output_spans.
        filtered_docs = self.filterLabels(docs)
        cleaned_docs = self.validate_fields(filtered_docs)
        ordered_docs = self.order_data_by_position(cleaned_docs)
        logger.info(f"Post-processing completed. {len(ordered_docs)} documents processed.")
        return [json.dumps(ordered_docs, ensure_ascii=False)]

    except Exception as e:
        logger.error(f"Error in postprocess: {e}")
        traceback.print_exc()
        raise e
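One possible way to collect a label for every word, whichever chunk it landed in, is to key predictions by (sample, word index) instead of comparing boxes within a single batch row. This sketch assumes you have, for each chunk, a list giving the word index of every token position (None for special and padding tokens), such as the fast-tokenizer method `BatchEncoding.word_ids(batch_index)` returns; whether your processor output exposes it depends on the processor and tokenizer in use. `labels_per_word` and its inputs are hypothetical, and it keeps the first prediction seen for each word.

```python
from typing import Dict, Optional, Sequence, Tuple


def labels_per_word(
    flat_predictions: Sequence[int],
    word_ids_per_chunk: Sequence[Sequence[Optional[int]]],
    sample_mapping: Sequence[int],
    seq_len: int = 512,
) -> Dict[Tuple[int, int], int]:
    """Return {(sample_idx, word_idx): predicted label id}.

    flat_predictions is the flattened inference output (chunk-major order).
    word_ids_per_chunk[c][p] is the word index of token p in chunk c, or None.
    The first prediction seen for a word wins (its first sub-token).
    """
    result: Dict[Tuple[int, int], int] = {}
    for chunk_idx, word_ids in enumerate(word_ids_per_chunk):
        sample_idx = int(sample_mapping[chunk_idx])
        for pos, word_idx in enumerate(word_ids):
            if word_idx is None:
                continue  # special or padding token
            key = (sample_idx, word_idx)
            if key not in result:
                result[key] = flat_predictions[chunk_idx * seq_len + pos]
    return result


# Toy example: 2 chunks of length 4 from one page, overlapping on word 2.
word_ids = [[None, 0, 1, 2], [None, 2, 3, None]]
preds = [9, 5, 5, 7, 9, 8, 6, 9]
print(labels_per_word(preds, word_ids, [0, 0], seq_len=4))
# {(0, 0): 5, (0, 1): 5, (0, 2): 7, (0, 3): 6}
```

Keeping the first chunk's prediction is the simplest choice; since a word near the end of one chunk is seen with more context in the next chunk, preferring the later prediction in the overlap is a reasonable alternative. The resulting ids can be turned into strings with self.model.config.id2label, as the existing code does.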

And here are my current debug logs:

[2023-08-14 16:13:51] DEBUG: Encoded inputs with keys: dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping', 'bbox', 'pixel_values'])
[2023-08-14 16:13:51] DEBUG: Overflow to sample mapping: tensor([0, 0, 0, 0])
[2023-08-14 16:13:51] DEBUG: ofset mapping: tensor([[[0, 0],
         [0, 3],
         [3, 5],
         ...,
         [2, 3],
         [3, 4],
         [0, 0]],

        [[0, 0],
         [1, 2],
         [0, 1],
         ...,
         [4, 6],
         [0, 2],
         [0, 0]],

        [[0, 0],
         [2, 3],
         [3, 5],
         ...,
         [7, 9],
         [0, 3],
         [0, 0]],

        [[0, 0],
         [5, 8],
         [0, 4],
         ...,
         [0, 0],
         [0, 0],
         [0, 0]]])
[2023-08-14 16:13:51] DEBUG: input ids are: torch.Size([4, 512])
[2023-08-14 16:13:51] DEBUG: attention mask is: torch.Size([4, 512])
[2023-08-14 16:13:51] DEBUG: bbox is: torch.Size([4, 512, 4])
[2023-08-14 16:13:51] DEBUG: pixel values are: torch.Size([4, 3, 224, 224])
[2023-08-14 16:13:51] DEBUG: Inference input shape: torch.Size([4, 512])

[2023-08-14 16:13:55] DEBUG: Inference output predictions truncated for brevity: [31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 24, 24, 82, 82, 82, 82, 82, 82, 74, 28, 28]
[2023-08-14 16:13:55] DEBUG: inference predictions indices length: 2048
[2023-08-14 16:13:55] DEBUG: inference predictions truncated for brevity: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-INVOICE NUMBER', 'B-INVOICE NUMBER', 'E-INVOICE NUMBER', 'E-INVOICE NUMBER', 'E-INVOICE NUMBER', 'E-INVOICE NUMBER', 'E-INVOICE NUMBER', 'E-INVOICE NUMBER', 'B-ISSUE DATE', 'I-ISSUE DATE', 'I-ISSUE DATE', 'I-ISSUE DATE', 'I-ISSUE DATE', 'I-ISSUE DATE', 'I-ISSUE DATE', 'I-ISSUE DATE', 'I-ISSUE DATE', 'E-ISSUE DATE', 'E-ISSUE DATE']

Can you help me identify and rectify the issue in the post-processing method that causes it not to return results for tokens beyond the first 512, even though the inference method processes all tokens?
