MEMOS: New extension for deep-learning-based segmentation of 3D fetal mouse scans.

We would like to announce a new Slicer extension, MEMOS, for deep-learning based segmentation of diceCT scans of fetal mice. See the repository for quick installation instructions.

The accompanying paper is published in Biology Open (The Company of Biologists): "Deep learning enabled multi-organ segmentation of mouse embryos".

Using a GPU, MEMOS segments the 50 anatomical structures of the E15.5 atlas provided by the International Mouse Phenotyping Consortium in about 40 to 60 seconds.

In the future, we plan to add more developmental time points as separately trained networks.


Hi, that sounds great and I’d like to try it, but I’m unable to install this extension.

Which OS do you use?

That might be related to the patch release last night. I just downloaded and installed it fresh on Windows (v5.2.2). Can you try again?

I’m using Windows 10 and Slicer 5.2.1 r31317 / 77da381. It’s still not working; I’ll update to 5.2.2 and try again. Thank you.

I see. As it is a new extension, it is only available for the latest stable release or for 5.3.0 (or higher) previews.
Yes, try with the latest stable release.

I have only used TotalSegmentator so far, but MEMOS also looks fantastic.

Is there anything comparable available for adult mice? I have >1000 CT scans of mice that need to be segmented (mainly liver, kidneys, and heart).

Best, Jan

No, we don’t have anything for adult mice. We don’t really do physiology, mostly development.

However, it is getting easier to train deep learning networks for custom segmentation tasks. So if you already have manually segmented scans that can be used as training data, it might just be a matter of putting some GPUs to use.

Hello! I am an undergraduate student researcher who is new to this area. I’ve recently started working on a project on automatically segmenting CT scans of bird brains. Do you have any tips on how to get started? You mentioned that it is a matter of putting GPUs to use. I have a handful of unsegmented NRRD files, as well as some manually segmented NRRD files.

There are multiple frameworks to train your own custom segmentation network. Here are pointers to two.

  1. Slicer extension for MONAILabel: MONAILabel/plugins/slicer (main branch) in the Project-MONAI/MONAILabel repository on GitHub.

  2. An example Python script for training a segmentation network with nnU-Net, which we used in a recent SlicerMorph workshop: pieper/nnmouse on GitHub (sample scripts for training nnU-Net on mouse fetus data).

In both cases, the important thing is to have some labeled (segmented) data at hand so that you can do supervised training.
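For the nnU-Net route, the labeled data also needs a small description file. The sketch below writes one, assuming nnU-Net v2's raw-data layout: a `DatasetXXX_Name` folder containing `imagesTr/` (images named like `case_0000.nii.gz`), `labelsTr/` (labelmaps named like `case.nii.gz`) and a `dataset.json`. The nnmouse scripts may target a different nnU-Net version, so treat the field names as an assumption to check against the nnU-Net documentation; the folder name and the `brain` label used below are hypothetical.

```python
import json
import os


def write_dataset_json(dataset_dir, label_names, channel_name="CT", file_ending=".nii.gz"):
    """Write a dataset.json in the (assumed) nnU-Net v2 format for one input channel.

    label_names: foreground structure names; they receive values 1..N,
    and 0 is reserved for background.
    """
    labels_dir = os.path.join(dataset_dir, "labelsTr")
    # Count training cases from the labelmap files present in labelsTr/.
    cases = [f for f in os.listdir(labels_dir) if f.endswith(file_ending)]

    labels = {"background": 0}
    for value, name in enumerate(label_names, start=1):
        labels[name] = value

    description = {
        "channel_names": {"0": channel_name},
        "labels": labels,
        "numTraining": len(cases),
        "file_ending": file_ending,
    }
    with open(os.path.join(dataset_dir, "dataset.json"), "w") as fh:
        json.dump(description, fh, indent=2)
    return description
```

Calling `write_dataset_json("nnUNet_raw/Dataset001_BirdBrain", ["brain"])` after filling `imagesTr/` and `labelsTr/` would describe a single foreground label with value 1. The label values must match the integer values stored in the labelmaps.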


Thank you for your help. I am currently trying to follow this guide, but I am running into a lot of issues, specifically around data loading. At the moment I have unsegmented NRRD files, manually segmented .seg.nrrd files, and folders of .nii files that contain the individual slices of the 3D NRRD volumes.
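A common first step for supervised training is to pair each scan with its manual segmentation. Below is a minimal sketch, assuming each volume `name.nrrd` has its segmentation saved next to it as `name.seg.nrrd` (a hypothetical naming convention). Note that a filter on `.endswith(".nrrd")` alone also matches the `.seg.nrrd` files, so they have to be excluded explicitly. The output is a list of `{"image": ..., "label": ...}` dictionaries, the kind of input that MONAI's dictionary-based transforms such as `LoadImaged` operate on. The label files would still need to be labelmaps before training.

```python
import os


def pair_images_with_segmentations(data_dir):
    """Pair each volume 'name.nrrd' with its segmentation 'name.seg.nrrd'.

    Returns a list of {"image": path, "label": path} dicts.
    Volumes without a matching segmentation are skipped.
    """
    files = set(os.listdir(data_dir))
    pairs = []
    for name in sorted(files):
        # ".seg.nrrd" also ends with ".nrrd", so skip segmentations here.
        if not name.endswith(".nrrd") or name.endswith(".seg.nrrd"):
            continue
        stem = name[: -len(".nrrd")]
        seg_name = stem + ".seg.nrrd"
        if seg_name in files:
            pairs.append({
                "image": os.path.join(data_dir, name),
                "label": os.path.join(data_dir, seg_name),
            })
    return pairs
```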

I also started out by looking at the MEMOS module, but it is proving difficult to adapt that code for my purposes. Is the code that was used to train the best_metric_model_largePatch_noise.pth model available?

Edit: Here is my current (scratch) code, but I am running into an error that I don’t quite understand:

RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 1, 0, 108, 108, 455]

import os
import torch
from torch.utils.data import DataLoader
from monai.transforms import (
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityd,
    SpatialCropd,
    ToTensord,
    Compose,
)
from monai.data import Dataset, NrrdReader
from monai.networks.nets import UNet
from monai.losses import DiceLoss

# Set your device (cuda if available, otherwise cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.unsegmented_files = [file for file in os.listdir(data_dir) if file.endswith(".nrrd")]
        self.transform = Compose([
            LoadImaged(keys=["image"], reader=NrrdReader()),
            EnsureChannelFirstd(keys=["image"]),
            ScaleIntensityd(keys=["image"]),
            SpatialCropd(keys=["image"], roi_start=(10, 10, 10), roi_end=(118, 118, 118)),
            ToTensord(keys=["image"]),
        ])

    def __len__(self):
        return len(self.unsegmented_files)

    def __getitem__(self, index):
        unsegmented_path = os.path.join(self.data_dir, self.unsegmented_files[index])
        data = {"image": unsegmented_path}
        transformed_data = self.transform(data)
        
        # Check if the transformed image is empty after cropping
        if transformed_data["image"].shape[0] == 0:
            print(f"Skipping empty image after cropping: {unsegmented_path}")
            return self.__getitem__((index + 1) % len(self))  # move to the next item

        return transformed_data

# Create an instance of your dataset

data_dir = "/panfs/jay/groups/25/barkerfk/nguy4214/SimoneData/test"
dataset = CustomDataset(data_dir)

# print(dataset)

# Create a DataLoader for your dataset
batch_size = 1
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

# Create a simple 3D U-Net model
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(device)

# Set up optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = DiceLoss(sigmoid=True)
# print(optimizer)

# Train the model
max_epochs = 10
for epoch in range(max_epochs):
    model.train()
    total_loss = 0.0
    for batch in data_loader:
        inputs = batch["image"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        
        targets = batch["label"].to(device) 
        loss = loss_function(outputs, targets)
        
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(data_loader)
    print(f"Epoch [{epoch+1}/{max_epochs}], Average Loss: {average_loss}")

# Save the trained model
torch.save(model.state_dict(), "segmentation_model.pth")
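One plausible reading of the error above, offered as a hypothesis rather than a confirmed diagnosis: the input has six dimensions (batch, channel, then four spatial axes), so at least one loaded file appears to be 4D, for example a multi-layer .seg.nrrd picked up by the `.endswith(".nrrd")` filter. A crop with three-element `roi_start`/`roi_end` would slice only the first three spatial axes, so a leading axis shorter than 10 voxels becomes empty and the last axis (455) stays uncropped. The emptiness check in `__getitem__` tests `shape[0]`, which is the channel axis (1 here), so it never fires. Separately, the training loop reads `batch["label"]`, which the dataset never creates, so it would fail next with a `KeyError`. The NumPy sketch below reproduces the shape with plain slicing as a stand-in for `SpatialCropd` (its handling of extra axes is assumed) and checks the spatial axes instead; the 5-layer volume is hypothetical.

```python
import numpy as np


def crop_leading_spatial_axes(img, roi_start, roi_end):
    """Crop a channel-first array on as many spatial axes as roi_start has entries.

    Any remaining spatial axes are kept whole (assumed to mirror the MONAI crop).
    """
    slices = [slice(None)]  # keep the channel axis untouched
    slices += [slice(s, e) for s, e in zip(roi_start, roi_end)]
    return img[tuple(slices)]


def has_empty_spatial_axis(img):
    """True if any spatial axis (everything after channel axis 0) has length 0."""
    return any(size == 0 for size in img.shape[1:])


# Hypothetical 4D volume (channel, 5 layers, x, y, z) after a channel axis was added.
volume = np.zeros((1, 5, 130, 130, 455), dtype=np.float32)
cropped = crop_leading_spatial_axes(volume, (10, 10, 10), (118, 118, 118))
print(cropped.shape)                    # (1, 0, 108, 108, 455)
print(cropped.shape[0] == 0)            # False: the channel axis still has length 1
print(has_empty_spatial_axis(cropped))  # True
```

A DataLoader with `batch_size=1` then prepends a batch axis, giving `[1, 1, 0, 108, 108, 455]`, which a 3D convolution rejects.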

I don’t have access to that guide, so I can’t comment on how appropriate it is.

You need to convert the segmentation files to labelmaps for training. That might be the cause of your data-loading issue.
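To illustrate what a labelmap is: a single integer volume with the same shape as the image, where 0 is background and each segmented structure has its own integer value. Exporting from Slicer is the reliable way to produce one. The NumPy sketch below only shows the idea, assuming the segmentation has already been read as a stack of binary masks with one mask per segment along axis 0; that layout is an assumption, since .seg.nrrd files can also store several segments in one shared layer.

```python
import numpy as np


def masks_to_labelmap(masks):
    """Collapse binary masks of shape (n_segments, X, Y, Z) into one labelmap.

    Voxel value k (1..n_segments) means "inside segment k"; 0 is background.
    Where masks overlap, the later segment wins (an arbitrary choice here).
    """
    masks = np.asarray(masks, dtype=bool)
    labelmap = np.zeros(masks.shape[1:], dtype=np.uint16)
    for index, mask in enumerate(masks, start=1):
        labelmap[mask] = index
    return labelmap
```

With several structures in one labelmap, the network would typically need one output channel per label value (including background), and MONAI's `DiceLoss` would usually be configured with `to_onehot_y=True, softmax=True` rather than `sigmoid=True`.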

The code in MEMOS is inference-only, so I don’t think it will help you much.

Try again with the segmentations converted to labelmaps (right-click each segmentation, export it as a labelmap, and save the result).
