
I am trying to export a custom PyTorch model to ONNX to perform inference, but without success... The tricky thing here is that I'm trying to use the script-based exporter (as shown in the example here) in order to call a function from my model.

I can export the model without any complaint, but then when trying to start an InferenceSession I get the following error:

Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ner.onnx failed:Type Error: Type parameter (T) bound to different types (tensor(int64) and tensor(float) in node (Concat_1260).
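For completeness, there's nothing fancy on the inference side; simply creating the session is what raises the failure (a minimal repro, assuming the exported ner.onnx):

import onnxruntime as ort

# instantiating the session is enough to trigger the Type Error above
session = ort.InferenceSession("ner.onnx", providers=["CPUExecutionProvider"])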

I tried to identify the root cause of the problem, and it seems to be the use of torch.matmul() in the following function (which is quite nasty, since I'm trying to only use PyTorch operators):

import torch

@torch.jit.script
def valid_sequence_output(sequence_output, valid_mask):
    # zero out the hidden vectors of the invalid positions
    X = torch.where(valid_mask.unsqueeze(-1) == 1, sequence_output, torch.zeros_like(sequence_output))
    bs, max_len, _ = X.shape

    # unique (batch, row) index pairs of the remaining non-zero positions
    tu = torch.unique(torch.nonzero(X)[:, :2], dim=0)
    batch_axis = tu[:, 0]
    rows_axis = tu[:, 1]

    # rank each valid position within its batch entry: T[b, i] is the
    # number of valid positions of batch b seen up to index i, minus one
    a = torch.arange(bs).repeat(batch_axis.shape).reshape(batch_axis.shape[0], -1)
    a = torch.transpose(a, 0, 1)

    T = torch.cumsum(torch.where(batch_axis == a, torch.ones_like(a), torch.zeros_like(a)), dim=1) - 1
    cols_axis = T[batch_axis, torch.arange(batch_axis.shape[0])]

    # permutation-like matrix that left-packs the valid rows of X
    A = torch.zeros((bs, max_len, max_len))
    A[(batch_axis, cols_axis, rows_axis)] = 1.0

    valid_output = torch.matmul(A, X)
    # mark the rows of the packed output that actually carry a token
    valid_attention_mask = torch.where(valid_output[:, :, 0] != 0, torch.ones_like(valid_mask),
                                       torch.zeros_like(valid_mask))
    return valid_output, valid_attention_mask
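
To make the intent concrete, here is a toy call of the function above (hypothetical shapes; eager execution works fine, the error only appears once the exported model is loaded):

sequence_output = torch.randn(2, 4, 8)
valid_mask = torch.tensor([[1, 0, 1, 1],
                           [1, 1, 0, 0]])
valid_output, valid_attention_mask = valid_sequence_output(sequence_output, valid_mask)
# valid_output[0, :3] now holds the three valid vectors of the first
# example, left-packed; valid_attention_mask is [[1, 1, 1, 0], [1, 1, 0, 0]]
print(valid_output.shape, valid_attention_mask)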

It seems like torch.matmul isn't supported (according to the docs), so I tried a bunch of workarounds (e.g. A.matmul(X), torch.baddbmm) but I still get the same issue...
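In code, these are the substitutions I tried in place of torch.matmul(A, X), each one still yielding the identical Concat type error at load time:

# batched matmul via the Tensor method instead of the free function
valid_output = A.matmul(X)
# batched matmul expressed as baddbmm with a zero bias
valid_output = torch.baddbmm(torch.zeros_like(X), A, X)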

Any suggestions on how to fix this behavior would be awesome :D Thanks for your help!

Jules

1 Answer


This points to a model conversion issue. Please open an issue against the PyTorch ONNX exporter. A type parameter (T) has to be bound to a single type for the model to be valid, and ORT is basically complaining that the inputs of Concat_1260 disagree (tensor(int64) vs. tensor(float)).
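
If you want to see exactly which inputs disagree before filing the issue, the onnx package can dump the offending node. A minimal sketch, assuming the ner.onnx file from the question:

import onnx

model = onnx.load("ner.onnx")
# print the inputs of the Concat node named in the error, so you can
# trace back which of them carry int64 vs. float values
for node in model.graph.node:
    if node.name == "Concat_1260":
        print(node.op_type, list(node.input))

Tracing those inputs back often reveals an implicitly-typed constant (e.g. a default-float torch.zeros or a bare 1.0 literal) being combined with int64 shape values, so pinning dtypes explicitly in the scripted function may be worth trying as a workaround.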