Unexpected behaviour with torch.jit.load

Hi everyone!

I’m trying to use a custom pre-trained torch model in a scripted module, but I encountered a problem that I couldn’t fix. Following the guide “Deploying your MONAI machine learning model within 3D Slicer”, I exported my model with torch.jit.script, and I am now trying to reload it using torch.jit.load.
When I try to load the model in a scripted module, a MemoryError: std::bad_alloc is raised. Here is the detailed error:

[...]
    self.model = torch.jit.load(modelPath)
  File "/home/yyy/Slicer-5.8.0-linux-amd64/lib/Python/lib/python3.9/site-packages/torch/jit/_serialization.py", line 162, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files, _restore_shapes)  # type: ignore[call-arg]
MemoryError: std::bad_alloc

The same error is generated if I repeat the operation inside the Python terminal within Slicer.

However, if I launch PythonSlicer from a terminal with [slicer_path]/Slicer --launch PythonSlicer and then run torch.jit.load(modelPath), the model loads without any problems.

I have tried exporting the model both as a zip file and as a pt file.
I have enough memory to allocate the model.
I’m working on Ubuntu 24.10. The Slicer version is 5.8.1, and the model was exported using torch 2.1.0, which is also the torch version bundled with Slicer.

What could be causing the loading failure within Slicer or the scripted module?

Thanks in advance!

Hi again!

I’m following up on this issue as I haven’t been able to resolve it yet. I’m hoping that someone in the community might have more insight into it. The problem appears to be specific to loading TorchScript models within Slicer’s embedded GUI rather than from the standalone PythonSlicer interpreter.

I tried exporting a dummy model with the following code, to confirm that the problem was not with my model or with available memory:

import torch
import torch.nn as nn

class DummyModel(nn.Module):
    """Minimal model used only to reproduce the loading error."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 1)
        )

    def forward(self, x):
        return self.net(x)

model = DummyModel()
model.eval()
model.cpu()

# Scripting (unlike tracing) needs no example input.
scripted_model = torch.jit.script(model)
scripted_model.save("dummy_model.pt")

Once again, I got the same error message:

File "/home/yyy/Slicer-5.8.0-linux-amd64/lib/Python/lib/python3.9/site-packages/torch/jit/_serialization.py", line 161, in load
    cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
MemoryError: std::bad_alloc

Could someone replicate and confirm the error? Does anyone have another solution to recommend?

The problem has not been solved, and the incompatibility with Torch JIT remains. In the meantime, however, I have found a workaround: I export the Torch model with torch.onnx.export and run it with onnxruntime instead.