1

I'm attempting to serve a pretrained Hugging Face model with TorchServe, and I've managed to save the model as a TorchScript file (.pt). However, I do not know what the handler should look like for such a model. A handler seems to be required for the very last step, which is to use torch-model-archiver to create the model archive file.

Here is the code I have so far.

# load model and tokenizer
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base", torchscript=True)
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

# create dummy input and output data
tokenized_dict = tokenizer(["please answer the following question: what is the boiling point of nitrogen",], ["-320.4F",], return_tensors="pt")

input_tuple = (tokenized_dict['input_ids'], tokenized_dict['attention_mask'], torch.Tensor([[2]]).long())

# Set model to eval mode
model.eval()

# Creating the trace and save model
traced_model = torch.jit.trace(model, input_tuple)
torch.jit.save(traced_model, "./model/traced_flan-t5-base.pt")

# create archive file
!torch-model-archiver --model-name t5_flan_base --version 1.0 --serialized-file traced_flan_t5.pt --handler handler.py --export-path ./deployment/model-store -f

Any help on creating a handler for this model would be greatly appreciated!

maxwellspi
  • 11
  • 2

1 Answer

0

A handler file is a file that looks like this:

# heavily inspired by https://medium.com/analytics-vidhya/deploy-huggingface-s-bert-to-production-with-pytorch-serve-27b068026d18
import json
import logging
import os

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TransformersClassifierHandler(BaseHandler):
    """
    TorchServe handler for a Hugging Face sequence-to-sequence checkpoint.

    The handler takes a single text (string or UTF-8 bytes) as input, tokenizes
    it, runs ``generate`` on the loaded model and returns the decoded text.
    The class name is kept as-is for compatibility with existing references,
    although the model generates text rather than classifying it.
    """

    def __init__(self):
        super().__init__()
        # Flipped to True at the end of initialize(); checked by handle().
        self.initialized = False

    def initialize(self, ctx):
        """
        Load the model and tokenizer and move the model to the target device.

        Args:
            ctx: TorchServe context; ``ctx.manifest`` and
                ``ctx.system_properties`` ("model_dir", "gpu_id") are read.
        """
        self.manifest = ctx.manifest

        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # Only pick a CUDA device when a GPU id was actually supplied;
        # otherwise the device string would be the invalid "cuda:None".
        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available() and gpu_id is not None:
            self.device = torch.device("cuda:" + str(gpu_id))
        else:
            self.device = torch.device("cpu")

        # Load the checkpoint from the extracted model archive directory.
        # NOTE: this expects a Hugging Face checkpoint directory, not a
        # TorchScript (.pt) file.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

        # NOTE(review): the tokenizer is resolved by hub name rather than from
        # model_dir — confirm the serving environment can resolve it.
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

        self.model.to(self.device)
        self.model.eval()

        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))

        self.initialized = True

    def preprocess(self, data):
        """
        Extract the text from the first request in the batch and tokenize it.

        Args:
            data: list of request dicts; the text is read from the "data" key,
                falling back to "body".

        Returns:
            The tokenizer output (PyTorch tensors) for the request text.

        Raises:
            ValueError: if neither "data" nor "body" is present.
            TypeError: if the payload is neither bytes nor str.
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        if text is None:
            raise ValueError("Request contains neither a 'data' nor a 'body' field")

        # The payload may arrive as raw bytes or as an already-decoded string.
        if isinstance(text, (bytes, bytearray)):
            sentences = text.decode('utf-8')
        elif isinstance(text, str):
            sentences = text
        else:
            raise TypeError(
                "Unsupported request payload type: {0}".format(type(text).__name__)
            )
        logger.info("Received text: '%s'", sentences)

        # Calling the tokenizer directly replaces the legacy encode_plus API.
        inputs = self.tokenizer(
            sentences,
            add_special_tokens=True,
            truncation=True,
            return_tensors="pt"
        )
        return inputs

    def inference(self, inputs):
        """
        Generate text for the tokenized input.

        Args:
            inputs: tokenizer output containing "input_ids" and, when
                available, "attention_mask".

        Returns:
            A one-element list holding the decoded generation.
        """
        # Forward the attention mask when the tokenizer produced one so that
        # generate() does not have to infer it.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        prediction = self.model.generate(
            inputs['input_ids'].to(self.device),
            attention_mask=attention_mask,
            max_length=256
        )

        # Only a single request is tokenized in preprocess(), so take the
        # first (and only) decoded sequence.
        string_pred = self.tokenizer.batch_decode(prediction, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        logger.info("Model predicted: '%s'", string_pred)

        return [string_pred]

    def postprocess(self, inference_output):
        """Return the inference output unchanged."""
        # TODO: Add any needed post-processing of the model predictions here
        return inference_output


# Module-level handler instance, created once at import time and reused by
# every call to handle().
_service = TransformersClassifierHandler()


def handle(data, context):
    """
    Module-level entry point: run one request through the handler.

    Initializes the shared handler on first use, then runs preprocess,
    inference and postprocess in order. Any exception raised along the way
    propagates to the caller unchanged.

    Args:
        data: the incoming request batch, or None.
        context: the context object passed on to ``initialize``.

    Returns:
        The postprocessed inference output, or None when ``data`` is None.
    """
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    data = _service.preprocess(data)
    data = _service.inference(data)
    data = _service.postprocess(data)

    return data

It defines a handle function and a subclass of BaseHandler, which initializes the model if it is not initialized already, preprocesses the input, runs inference, and optionally does some postprocessing.

fkmjec
  • 1
  • 1