I have a trained model weights file named model.pt for brain segmentation from head CT scans. How can I convert it into a TorchScript file so that I can use the model for deployment?

Network defn:

3D UNet
in channels: 1 (image)
out channels: 2 (brain label and background)
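
The network definition above does not say which 3D UNet implementation was used, and TorchScript conversion needs the exact Python class. As a hypothetical illustration only, here is a minimal two-level 3D UNet with 1 input channel and 2 output channels. `TinyUNet3D`, `conv_block` and `base=8` are invented for this sketch and are not taken from the original training code. It shows the shape contract: (N, 1, 96, 96, 96) in, (N, 2, 96, 96, 96) out.

```python
import torch
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # Two 3x3x3 convolutions, each followed by batch norm and ReLU.
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
    )


class TinyUNet3D(nn.Module):
    """Hypothetical two-level 3D UNet: 1 input channel, 2 output channels."""

    def __init__(self, in_channels: int = 1, out_channels: int = 2, base: int = 8):
        super().__init__()
        self.enc1 = conv_block(in_channels, base)
        self.pool = nn.MaxPool3d(2)
        self.enc2 = conv_block(base, base * 2)
        self.up = nn.ConvTranspose3d(base * 2, base, kernel_size=2, stride=2)
        self.dec1 = conv_block(base * 2, base)
        self.head = nn.Conv3d(base, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        e1 = self.enc1(x)                            # (N, base, D, H, W)
        e2 = self.enc2(self.pool(e1))                # (N, 2*base, D/2, H/2, W/2)
        d1 = self.up(e2)                             # back to (N, base, D, H, W)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))   # skip connection
        return self.head(d1)                         # (N, out_channels, D, H, W) logits


model = TinyUNet3D().eval()
with torch.no_grad():
    out = model(torch.rand(1, 1, 96, 96, 96))
print(out.shape)  # torch.Size([1, 2, 96, 96, 96])
```

The weights in model.pt can only be loaded into a class whose layer names and tensor sizes match the one used during training, so this stand-in would not load them as-is.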

Input defn:

"image": {
    "type": "image",
    "format": "hounsfield",
    "modality": "CT",
    "num_channels": 1,
    "spatial_shape": [96, 96, 96],
    "dtype": "float32",
    "value_range": [0, 1],
    "is_patch_data": true,
    "channel_def": {
        "0": "image"
    }
}
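
Read literally, the input definition describes a 5-D float32 tensor laid out as (batch, channels, depth, height, width) = (N, 1, 96, 96, 96), with values in [0, 1]. The attempt below uses torch.rand(13, 96, 96, 96), which is 4-D. The 13 looks like the number of training images rather than a batch or channel size, and if the first layer is an nn.Conv3d with one input channel it would reject that shape. Here is a sketch of a matching dummy input, assuming a batch size of 1 (the constant names are illustrative):

```python
import torch

# Spatial patch size and channel count taken from the "image" input definition.
SPATIAL_SHAPE = (96, 96, 96)
NUM_CHANNELS = 1
BATCH_SIZE = 1  # assumption: 1 is the usual choice when tracing

# Layout expected by nn.Conv3d: (batch, channels, depth, height, width).
example = torch.rand(BATCH_SIZE, NUM_CHANNELS, *SPATIAL_SHAPE, dtype=torch.float32)

print(example.shape)  # torch.Size([1, 1, 96, 96, 96])
print(example.dtype)  # torch.float32
print(float(example.min()) >= 0.0, float(example.max()) <= 1.0)  # True True
```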

Train/val split: 13 images for training and 3 for validation

Output defn:

"pred": {
    "type": "image",
    "format": "segmentation",
    "num_channels": 2,
    "spatial_shape": [96, 96, 96],
    "dtype": "float32",
    "value_range": [0, 1],
    "is_patch_data": true,
    "channel_def": {
        "0": "background",
        "1": "brain"
    }
}
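
The output definition describes two channels per voxel, with channel 0 as background and channel 1 as brain. It does not say whether the network emits raw logits or softmax probabilities, so this sketch takes the per-voxel argmax over the channel dimension, which gives the same mask either way. The helper name `to_brain_mask` is hypothetical.

```python
import torch


def to_brain_mask(pred: torch.Tensor) -> torch.Tensor:
    """Convert a (N, 2, D, H, W) prediction into a (N, 1, D, H, W) brain mask.

    Channel 0 is background and channel 1 is brain, per the "pred" channel_def.
    argmax gives the same result for raw logits and for softmax probabilities.
    """
    if pred.ndim != 5 or pred.shape[1] != 2:
        raise ValueError(f"expected shape (N, 2, D, H, W), got {tuple(pred.shape)}")
    return pred.argmax(dim=1, keepdim=True).to(torch.uint8)


# Toy prediction: the brain channel wins at a single voxel only.
pred = torch.zeros(1, 2, 2, 2, 2)
pred[:, 0] = 0.6           # background score everywhere
pred[0, 1, 0, 0, 0] = 0.9  # brain score beats background at voxel (0, 0, 0)
mask = to_brain_mask(pred)
print(mask.shape)       # torch.Size([1, 1, 2, 2, 2])
print(int(mask.sum()))  # 1
```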
      

Now, how can I use tracing or scripting to convert it to TorchScript? Is this information enough?
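
For reference, the two TorchScript entry points need different things. torch.jit.trace runs the module once on an example input and records the operations it executes. torch.jit.script compiles forward() from its Python source and takes no example input. Here is a minimal sketch on a hypothetical one-layer stand-in (`SegHead`), since the real UNet class is not shown:

```python
import torch
import torch.nn as nn


class SegHead(nn.Module):
    """Hypothetical stand-in for the real 3D UNet (1 -> 2 channels)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 2, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


model = SegHead().eval()
example = torch.rand(1, 1, 96, 96, 96)  # (batch, channels, D, H, W)

# Tracing: runs the model once on the example and records the executed ops.
traced = torch.jit.trace(model, example)

# Scripting: compiles the source of forward(); no example input is needed.
scripted = torch.jit.script(model)

with torch.no_grad():
    ref = model(example)
    print(torch.allclose(traced(example), ref))    # True
    print(torch.allclose(scripted(example), ref))  # True
```

Tracing is usually enough for a UNet whose forward pass has no data-dependent branches. Scripting is needed when control flow depends on tensor values, because a trace only records the path taken for the example input.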

I tried the following:

import torch

model = torch.load('model/model.pt')

example = torch.rand(13, 96, 96, 96)

traced_script_module = torch.jit.script(model, (example))
torch.save(traced_script_module, "model/traced_resnet_model.ts")

I only provided the model's input size. I also tried torch.jit.trace, but both attempts failed.
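
One likely reason both attempts failed, although the error messages are not shown, is that many training pipelines save model.pt as a state_dict (a dict of weight tensors) rather than as a full nn.Module. Neither torch.jit.script nor torch.jit.trace can compile a dict. Separately, torch.jit.script does not take an example input as its second positional argument. The sketch below shows the usual workflow: rebuild the architecture, load the weights, switch to eval mode, trace with a correctly shaped input, save with torch.jit.save, and reload to check. It assumes model.pt is a plain state_dict and uses the hypothetical `SegHead` in place of the real UNet class. The temporary file paths are also illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SegHead(nn.Module):
    """Hypothetical stand-in; replace with the exact 3D UNet class used in training."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 2, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


workdir = tempfile.mkdtemp()
weights_path = os.path.join(workdir, "model.pt")
ts_path = os.path.join(workdir, "model.ts")

# Simulate a checkpoint that stores only a state_dict (an assumption about model.pt).
torch.save(SegHead().state_dict(), weights_path)

checkpoint = torch.load(weights_path, map_location="cpu")
if isinstance(checkpoint, nn.Module):
    model = checkpoint  # the whole module was pickled
else:
    # A dict of tensors: rebuild the architecture, then load the weights into it.
    model = SegHead()
    model.load_state_dict(checkpoint)
model.eval()

example = torch.rand(1, 1, 96, 96, 96)  # (batch, channels, D, H, W)
with torch.no_grad():
    traced = torch.jit.trace(model, example)
torch.jit.save(traced, ts_path)

# Reload without needing the Python class definition and compare with eager mode.
loaded = torch.jit.load(ts_path, map_location="cpu")
with torch.no_grad():
    print(torch.allclose(loaded(example), model(example)))  # True
```

If torch.load returns a dict with other keys, the state_dict may be nested inside it, and printing checkpoint.keys() shows what was stored. On PyTorch 2.6 and later, torch.load defaults to weights_only=True. That setting loads plain state_dicts but refuses fully pickled modules unless weights_only=False is passed for a trusted file.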

Any help would be much appreciated.
