I'm trying to call my SageMaker model endpoint from both Postman and the AWS CLI. The endpoint's status is "InService", but every call I make returns an error. When I use the predict function in the SageMaker notebook and pass it a NumPy array (e.g. np.array([1,2,3,4])), it returns an output successfully. I'm unsure what I'm doing wrong.

$ aws2 sagemaker-runtime invoke-endpoint \
    --endpoint-name=pytorch-model \
    --body=1,2 \
    --content-type=text/csv \
    --cli-binary-format=raw-in-base64-out \
    output.json

An error occurred (ModelError) when calling the InvokeEndpoint operation: Received server error (500) from model with message "tensors used as indices must be long, byte or bool tensors
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/sagemaker_inference/transformer.py", line 125, in transform
    result = self._transform_fn(self._model, input_data, content_type, accept)
  File "/opt/conda/lib/python3.6/site-packages/sagemaker_inference/transformer.py", line 215, in _default_transform_fn
    prediction = self._predict_fn(data, model)
  File "/opt/ml/model/code/pytorch-model-reco.py", line 268, in predict_fn
    return torch.argsort(- final_matrix[input_data, :], dim = 1)
IndexError: tensors used as indices must be long, byte or bool tensors

The clue is in the last few lines of your stack trace:

  File "/opt/ml/model/code/pytorch-model-reco.py", line 268, in predict_fn
    return torch.argsort(- final_matrix[input_data, :], dim = 1)
IndexError: tensors used as indices must be long, byte or bool tensors

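In predict_fn (line 268 of pytorch-model-reco.py), you use input_data as indices into final_matrix, but input_data has the wrong dtype. As the error says, PyTorch only accepts long (int64), byte, or bool tensors as indices, and the CSV body most likely reaches predict_fn as a float tensor.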
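To see the failure in isolation, here is a minimal sketch outside SageMaker. The 4x3 final_matrix is a hypothetical stand-in for the one in your script, and the float index tensor is an assumption about what the text/csv body "1,2" turns into by the time it reaches predict_fn:

```python
import torch

# Hypothetical 4x3 score matrix standing in for the script's final_matrix
final_matrix = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Assumption: the CSV body "1,2" arrives as a float tensor like this one
float_idx = torch.tensor([1.0, 2.0])

try:
    final_matrix[float_idx, :]
    float_index_failed = False
except IndexError as err:
    # Same class of error as the one in the endpoint logs
    print("IndexError:", err)
    float_index_failed = True

# Casting to int64 (torch.long) makes the same indices valid
long_idx = float_idx.long()
rows = final_matrix[long_idx, :]
ranking = torch.argsort(-rows, dim=1)
print(ranking)
```

The exact wording of the error differs slightly between PyTorch versions, but float indices are rejected in all of them.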

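I would guess your predict_fn needs to cast its input when the content type is text/csv. When you call predict from the notebook with a NumPy array, the payload probably keeps its integer dtype, so indexing works; a CSV body carries no dtype information and is probably decoded into floats. Casting input_data to a long tensor (for example with input_data.long()) before indexing should fix it. Taking a look at the sagemaker_inference source code might reveal more about how each content type is decoded.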
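A defensive variant of predict_fn could normalize its input before indexing, so the same code handles both the CSV path and the NumPy path. This is a sketch under stated assumptions: it assumes `model` is the score matrix your script calls final_matrix (adapt that line if the matrix lives elsewhere, e.g. as an attribute of the model), and the `scores` tensor in the usage lines is made-up example data:

```python
import torch


def predict_fn(input_data, model):
    # Assumption: `model` is the score matrix the original script calls
    # final_matrix; adjust this if it is stored elsewhere.
    final_matrix = model
    # CSV input likely arrives as floats, NumPy input as int64.
    # Convert either to a flat int64 (torch.long) tensor; reshape(-1) also
    # turns a single scalar index into a 1-element batch so dim=1 exists.
    indices = torch.as_tensor(input_data).long().reshape(-1)
    indices = indices.to(final_matrix.device)
    return torch.argsort(-final_matrix[indices, :], dim=1)


# Usage with made-up scores: float indices, as a CSV body would likely produce
scores = torch.tensor([[0.1, 0.9, 0.5], [0.7, 0.2, 0.4]])
print(predict_fn(torch.tensor([1.0, 0.0]), scores))
```

Note that `.long()` truncates, so a malformed value like 1.9 would silently become row 1; validating the input first may be worth it if the endpoint is public.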

Yoav Zimmerman