Thank you for your help. I am currently trying to follow this guide, but I am running into a lot of issues, specifically around data loading. As it stands, I have unsegmented .nrrd files, manually segmented .seg.nrrd files, and folders of .nii files that contain the individual slices of the 3D .nrrd volumes.
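The pairing I have in mind is roughly the following (a minimal sketch, assuming each volume sits next to a matching .seg.nrrd with the same base name; the directory path and naming pattern are placeholders, not my actual layout):

import glob
import os

# Minimal sketch: pair each unsegmented volume with its manual segmentation,
# assuming "scan.nrrd" sits next to "scan.seg.nrrd" in the same directory.
data_dir = "/path/to/data"  # placeholder
volumes = sorted(
    f for f in glob.glob(os.path.join(data_dir, "*.nrrd"))
    if not f.endswith(".seg.nrrd")
)
data_dicts = []
for img in volumes:
    seg = img[: -len(".nrrd")] + ".seg.nrrd"
    if os.path.exists(seg):
        data_dicts.append({"image": img, "label": seg})
print(f"Found {len(data_dicts)} image/label pairs")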
I also started out by looking at the memos module, but it is proving difficult to rework into code that suits my purposes. Is the code that was used to train/create the best_metric_model_largePatch_noise.pth model available anywhere?
Edit: Here is the scratch code I have tried so far, but I am running into an error that I don't quite understand:
RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 1, 0, 108, 108, 455]
import os
import torch
from torch.utils.data import DataLoader
from monai.transforms import (
LoadImaged,
EnsureChannelFirstd,
ScaleIntensityd,
SpatialCropd,
ToTensord,
Compose,
)
from monai.data import Dataset, NrrdReader
from monai.networks.nets import UNet
from monai.losses import DiceLoss
# Set your device (cuda if available, otherwise cpu)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class CustomDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.unsegmented_files = [file for file in os.listdir(data_dir) if file.endswith(".nrrd")]
        self.transform = Compose([
            LoadImaged(keys=["image"], reader=NrrdReader()),
            EnsureChannelFirstd(keys=["image"]),
            ScaleIntensityd(keys=["image"]),
            SpatialCropd(keys=["image"], roi_start=(10, 10, 10), roi_end=(118, 118, 118)),
            ToTensord(keys=["image"]),
        ])

    def __len__(self):
        return len(self.unsegmented_files)

    def __getitem__(self, index):
        unsegmented_path = os.path.join(self.data_dir, self.unsegmented_files[index])
        data = {"image": unsegmented_path}
        transformed_data = self.transform(data)
        # Check if the transformed image is empty after cropping
        if transformed_data["image"].shape[0] == 0:
            print(f"Skipping empty image after cropping: {unsegmented_path}")
            return self.__getitem__((index + 1) % len(self))  # move to the next item
        return transformed_data
# Create an instance of your dataset
data_dir = "/panfs/jay/groups/25/barkerfk/nguy4214/SimoneData/test"
dataset = CustomDataset(data_dir)
# print(dataset)
# Create a DataLoader for your dataset
batch_size = 1
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
# Create a simple 3D U-Net model
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(device)
# Set up optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_function = DiceLoss(sigmoid=True)
# print(optimizer)
# Train the model
max_epochs = 10
for epoch in range(max_epochs):
    model.train()
    total_loss = 0.0
    for batch in data_loader:
        inputs = batch["image"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        targets = batch["label"].to(device)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    average_loss = total_loss / len(data_loader)
    print(f"Epoch [{epoch+1}/{max_epochs}], Average Loss: {average_loss}")
# Save the trained model
torch.save(model.state_dict(), "segmentation_model.pth")
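A minimal way to narrow this down might be to inspect the shape of a single sample before and after the crop, to see which axis ends up with size 0 (just a sketch; the file path is a placeholder):

from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, SpatialCropd

# Load one volume and print its shape before and after the fixed crop;
# a zero-sized axis here would explain the conv3d input-size error above.
sample = {"image": "/path/to/one_volume.nrrd"}  # placeholder path
loaded = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
])(sample)
print("after load:", loaded["image"].shape)
cropped = SpatialCropd(keys=["image"], roi_start=(10, 10, 10), roi_end=(118, 118, 118))(loaded)
print("after crop:", cropped["image"].shape)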