AI segmentation of bird brain

Hello! I am an undergraduate student researcher who is new to this area. I’ve recently started working on a project involving automatic segmentation of CT scans of bird brains. I was wondering if you had tips on how to get started with this process? You mentioned that it is a matter of putting GPUs to use. I have a handful of unsegmented NRRD files, as well as some manually segmented NRRD files.

There are multiple frameworks to train your own custom segmentation network. Here are pointers to two:

  1. The Slicer extension for MONAILabel: https://github.com/Project-MONAI/MONAILabel/tree/main/plugins/slicer

  2. An example Python script that uses nnU-Net to train a segmentation network, which we used in a recent SlicerMorph workshop: https://github.com/pieper/nnmouse (sample scripts for training nnU-Net on mouse fetus data)

In both cases the important thing is to have some labeled (segmented) data at hand so that you can do supervised training.
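For example, most supervised training pipelines expect the images and labels paired up front. Here is a minimal sketch of that pairing in the MONAI dictionary style (the data directory and the *-label.nrrd naming are hypothetical; use whatever convention your exported labelmaps follow):

import glob
import os

data_dir = "/path/to/data"  # hypothetical location of your NRRD files

# Pair each image with its exported labelmap; supervised training needs both
images = sorted(f for f in glob.glob(os.path.join(data_dir, "*.nrrd"))
                if not f.endswith("-label.nrrd"))
labels = [f.replace(".nrrd", "-label.nrrd") for f in images]
train_files = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]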


Thank you for your help! I am currently trying to follow this guide, but am running into a lot of issues, specifically around data loading. As it is, I have unsegmented NRRD files, manually segmented .seg.nrrd files, as well as folders containing .nii files that are the slices of the 3D NRRD volumes.

I also started out looking at the MEMOS module, but that is proving difficult to rework into code that works for my purposes. I was wondering whether the code that was used to train/create the best_metric_model_largePatch_noise.pth model is available?

Edit: Here is the current code that I have tried (scratch code), but I am coming across an error that I don’t quite understand:

RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 1, 0, 108, 108, 455]

import os
import torch
from torch.utils.data import DataLoader
from monai.transforms import (
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityd,
    SpatialCropd,
    ToTensord,
    Compose,
)
from monai.data import Dataset, NrrdReader
from monai.networks.nets import UNet
from monai.losses import DiceLoss

# Set the device (cuda if available, otherwise cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        # Collect the image volumes only; each is assumed to have a matching
        # segmentation saved next to it as <name>.seg.nrrd (the training loop
        # below reads batch["label"], so labels must be loaded here too)
        self.image_files = [f for f in os.listdir(data_dir)
                            if f.endswith(".nrrd") and not f.endswith(".seg.nrrd")]
        self.transform = Compose([
            LoadImaged(keys=["image", "label"], reader=NrrdReader()),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys=["image"]),
            # NOTE: this fixed crop assumes every volume is at least 118 voxels
            # along each spatial axis; a smaller (or extra) axis produces a
            # zero-sized dimension and the conv3d shape error reported above
            SpatialCropd(keys=["image", "label"], roi_start=(10, 10, 10), roi_end=(118, 118, 118)),
            ToTensord(keys=["image", "label"]),
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        image_path = os.path.join(self.data_dir, self.image_files[index])
        label_path = image_path.replace(".nrrd", ".seg.nrrd")
        data = {"image": image_path, "label": label_path}
        transformed_data = self.transform(data)

        # Skip images that end up empty after cropping; check the spatial
        # dimensions (index 0 is the channel dimension and is always 1 here)
        if 0 in transformed_data["image"].shape[1:]:
            print(f"Skipping empty image after cropping: {image_path}")
            return self.__getitem__((index + 1) % len(self))  # move to the next item

        return transformed_data

# Create an instance of the dataset
data_dir = "/panfs/jay/groups/25/barkerfk/nguy4214/SimoneData/test"
dataset = CustomDataset(data_dir)

# Create a DataLoader for the dataset
batch_size = 1
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

# Create a simple 3D U-Net model (binary segmentation: one output channel)
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(device)

# Set up optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = DiceLoss(sigmoid=True)

# Train the model
max_epochs = 10
for epoch in range(max_epochs):
    model.train()
    total_loss = 0.0
    for batch in data_loader:
        inputs = batch["image"].to(device)
        targets = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(data_loader)
    print(f"Epoch [{epoch+1}/{max_epochs}], Average Loss: {average_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "segmentation_model.pth")
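For reference, a minimal shape check like the following (purely illustrative) narrows down this kind of error, since conv3d expects exactly 5 dimensions for a batched input:

# Quick shape check: conv3d expects (N, C, D, H, W); anything else, or a
# zero-sized axis as in the error above, will fail
sample = dataset[0]
print(sample["image"].shape)  # expected: torch.Size([1, 108, 108, 108])
print(sample["label"].shape)  # should match the image's spatial shape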

I don’t have access to that guide, so I can’t comment on how appropriate it is.

You need to convert the segmentation files to labelmaps for training. That might be your data loading issue.

The code in MEMOS is inference only; I don’t think it will help you much.

Try with the segmentations converted to labelmaps (right-click the segmentation, export it as a labelmap, and save it).
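If you have many segmentations, the same export can be scripted from the Slicer Python console. A minimal sketch (the file paths are placeholders):

import slicer

# Load the image and its segmentation (hypothetical file names)
volumeNode = slicer.util.loadVolume("/path/to/brain.nrrd")
segmentationNode = slicer.util.loadSegmentation("/path/to/brain.seg.nrrd")

# Export all visible segments to a labelmap aligned with the image geometry
labelmapNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode")
slicer.modules.segmentations.logic().ExportVisibleSegmentsToLabelmapNode(
    segmentationNode, labelmapNode, volumeNode
)

# Save the labelmap next to the image for training
slicer.util.saveNode(labelmapNode, "/path/to/brain-label.nrrd")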


Thank you for all of your help! I have been able to successfully create a model that makes predictions (accuracy TBD), and was looking to integrate it into the MEMOS module. It seems it is not as easy as plugging in my .pth file and input images and running the module, as that has not worked for me so far.

My model is a 3D U-Net designed for volumetric data; it takes a single-channel (grayscale) input and outputs segmentation maps with a specified number of classes (5 in my case). The model was trained on patches extracted from 3D volumes. I was wondering whether this approach affects whether I am able to use the model in MEMOS, and if you had any advice on how to go about this implementation?

Thank you again.

MEMOS is not a general-purpose tool for U-Net inference. It shouldn’t be expected to work as-is, because you presumably have different spatial dimensions, labels, etc. than what MEMOS expects.

Having said that, you can easily modify those. Edit the source code of MEMOS.py and change the relevant sections based on your MONAI model settings. Most of those should be in the logic section.
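For example (an illustrative sketch, not the actual MEMOS code; the variable names and the 128³ patch size are assumptions), the two parts that have to match your training setup are the model definition and the sliding-window inference call, since a patch-trained model is usually applied window by window:

import torch
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Match the architecture used during training (5 output classes in this case)
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=5,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(device)
model.load_state_dict(torch.load("segmentation_model.pth", map_location=device))
model.eval()

# roi_size should match the training patch size (128^3 is an assumption here);
# `inputs` stands for a preprocessed (1, 1, D, H, W) volume tensor
with torch.no_grad():
    logits = sliding_window_inference(
        inputs, roi_size=(128, 128, 128), sw_batch_size=1, predictor=model
    )
    prediction = torch.argmax(logits, dim=1)  # per-voxel class labels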


Hello! I was making some changes to try to implement what you had mentioned, but was wondering how I can actually test the changes I make in 3D Slicer. Is there a simpler way to test changes that I have made to MEMOS.py?

You can edit and save MEMOS.py, restart Slicer, and give it a try.

You can save a few clicks if you enable Developer Mode (see the Slicer documentation).


I am able to download/edit MEMOS.py locally using VSCode, but was confused about how to load and use the edited version in 3D Slicer. I’m unsure what you mean by editing; is there a way to edit the module in 3D Slicer?

As long as you save the changed/edited MEMOS.py in the same location, you can simply restart Slicer and the “changed” MEMOS will be loaded into Slicer (not ours).

If you enable Developer Mode, things will be easier. Please look at the Slicer documentation at slicer.readthedocs.io.


Thank you. I will look into Developer Mode!

https://slicer.readthedocs.io/en/latest/user_guide/settings.html#developer-mode

In short, you do not need to rebuild Slicer from source to make changes to a Python module. Just enable Developer Mode, click the Edit button that will appear at the top of the MEMOS module to make your changes to MEMOS.py, save, and then hit the Reload button to reinitialize the MEMOS module with your changes.

These buttons will appear at the top of Python modules after you enable Developer Mode and restart Slicer.
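The same reload can also be triggered from the Slicer Python console (a one-liner, assuming the module name is MEMOS):

# Reload a scripted module in place without restarting Slicer
slicer.util.reloadScriptedModule("MEMOS")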


Hello again. I have been working on implementing my custom model in the MEMOS.py module, debugging, etc., using Developer Mode. However, I am now encountering the issue shown in the attached JPEG when trying to run my model on my own input. I was wondering if you have encountered this issue before? Thank you!

If you altered MEMOS.py to fit your needs, then we can’t really help you; you will need to debug it yourself. Is there anything in the Slicer log file? The screenshot only displays a crash notification.

Also, you are using a fairly old Slicer. I would suggest working with the latest preview (or at least the current stable) version if you are using MEMOS as a template to build your own module.