ML Training with MSD Lung RadiObject: Segmentation¶
This notebook trains a 3D UNet for lung tumor segmentation using the Medical Segmentation Decathlon data.
Overview¶
- Load RadiObject from URI (S3 or local)
- Explore data and segmentation masks
- Split into train/validation sets
- Train a MONAI UNet model
- Evaluate with Dice score
Task¶
Semantic segmentation: Predict lung tumor mask from CT volume patches.
Prerequisites: Run 03_ingest_msd.ipynb first to create the MSD Lung RadiObject with CT and seg collections.
In [ ]:
Copied!
import matplotlib.pyplot as plt
import numpy as np
import torch
from monai.inferers import sliding_window_inference
from monai.losses import DiceFocalLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import Compose, NormalizeIntensityd, RandFlipd, RandRotate90d
from radiobject import RadiObject, S3Config, configure
from radiobject.ml import (
create_segmentation_dataloader,
)
# ── Storage URI (must match notebook 03 ingest output) ──────────
# Default: S3 (requires AWS credentials)
MSD_LUNG_URI = "s3://souzy-scratch/msd-lung/radiobject-2mm"
# For local storage, comment out the line above and uncomment:
# MSD_LUNG_URI = "./data/msd_lung_radiobject"
# ─────────────────────────────────────────────────────────────────
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"RadiObject URI: {MSD_LUNG_URI}")
import matplotlib.pyplot as plt
import numpy as np
import torch
from monai.inferers import sliding_window_inference
from monai.losses import DiceFocalLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import Compose, NormalizeIntensityd, RandFlipd, RandRotate90d
from radiobject import RadiObject, S3Config, configure
from radiobject.ml import (
create_segmentation_dataloader,
)
# ── Storage URI (must match notebook 03 ingest output) ──────────
# Default: S3 (requires AWS credentials)
MSD_LUNG_URI = "s3://souzy-scratch/msd-lung/radiobject-2mm"
# For local storage, comment out the line above and uncomment:
# MSD_LUNG_URI = "./data/msd_lung_radiobject"
# ─────────────────────────────────────────────────────────────────
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"RadiObject URI: {MSD_LUNG_URI}")
In [2]:
Copied!
# Pick the training accelerator: Apple-Silicon MPS first, then CUDA,
# otherwise fall back to CPU (same priority as before).
_candidates = (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
)
DEVICE = torch.device(next((name for name, is_up in _candidates if is_up()), "cpu"))
print(f"Training device: {DEVICE}")
# Determine compute device
if torch.backends.mps.is_available():
DEVICE = torch.device("mps")
elif torch.cuda.is_available():
DEVICE = torch.device("cuda")
else:
DEVICE = torch.device("cpu")
print(f"Training device: {DEVICE}")
Training device: mps
In [ ]:
Copied!
# Configure S3 access for RadiObject reads: region plus the number of
# parallel object-store operations (8 concurrent reads).
configure(s3=S3Config(region="us-east-2", max_parallel_ops=8))
# Configure S3 region
configure(s3=S3Config(region="us-east-2", max_parallel_ops=8))
In [4]:
Copied!
# Open the RadiObject at the configured URI and show a quick summary
# of its subjects and collections via describe().
radi = RadiObject(MSD_LUNG_URI)
print(radi.describe())
# Load RadiObject
radi = RadiObject(MSD_LUNG_URI)
# Quick summary using describe()
print(radi.describe())
RadiObject Summary ================== URI: s3://souzy-scratch/msd-lung/radiobject-2mm Subjects: 63 Collections: 4 Collections: - seg_resampled: 63 volumes, shape=heterogeneous (mixed shapes) - seg: 63 volumes, shape=heterogeneous (mixed shapes) - CT_resampled: 63 volumes, shape=heterogeneous (mixed shapes) - CT: 63 volumes, shape=heterogeneous (mixed shapes)
In [5]:
Copied!
# Verify the resampled CT and seg collections exist BEFORE dereferencing
# them: the original order accessed radi.CT_resampled / radi.seg_resampled
# first, so a missing collection raised an opaque AttributeError instead of
# the intended RuntimeError with re-ingest instructions.
missing = {"CT_resampled", "seg_resampled"} - set(radi.collection_names)
if missing:
    raise RuntimeError(
        f"Resampled collections not found: {sorted(missing)}. "
        "Please re-run 03_ingest_msd.ipynb with FORCE_REINGEST=True"
    )
print(f"Collections: {radi.collection_names}")
print(f"CT_resampled shape: {radi.CT_resampled.shape}")
print(f"seg_resampled shape: {radi.seg_resampled.shape}")
# Verify the resampled CT and seg collections exist BEFORE dereferencing
# them: the original order accessed radi.CT_resampled / radi.seg_resampled
# first, so a missing collection raised an opaque AttributeError instead of
# the intended RuntimeError with re-ingest instructions.
missing = {"CT_resampled", "seg_resampled"} - set(radi.collection_names)
if missing:
    raise RuntimeError(
        f"Resampled collections not found: {sorted(missing)}. "
        "Please re-run 03_ingest_msd.ipynb with FORCE_REINGEST=True"
    )
print(f"Collections: {radi.collection_names}")
print(f"CT_resampled shape: {radi.CT_resampled.shape}")
print(f"seg_resampled shape: {radi.seg_resampled.shape}")
Collections: ('seg_resampled', 'seg', 'CT_resampled', 'CT')
CT_resampled shape: None
seg_resampled shape: None
In [6]:
Copied!
# Visualize CT and segmentation overlay for the first three subjects to
# sanity-check that image and mask are spatially aligned before training.
subject_ids = list(radi.obs_subject_ids)[:3]
# Standard CT lung window (W=1500, L=-600 -> range -1350 to 150 HU)
ct_vmin, ct_vmax = -1350, 150
fig, axes = plt.subplots(len(subject_ids), 3, figsize=(12, 4 * len(subject_ids)))
for row, subject_id in enumerate(subject_ids):
    # First volume of each resampled collection for this subject.
    ct_vol = radi.loc[subject_id].CT_resampled.iloc[0]
    seg_vol = radi.loc[subject_id].seg_resampled.iloc[0]
    # Find slice with tumor: take the middle of the z-slices containing any
    # foreground; fall back to the volume midpoint if the mask is empty.
    seg_data = seg_vol.to_numpy()
    tumor_slices = np.where(seg_data.sum(axis=(0, 1)) > 0)[0]
    mid_z = (
        tumor_slices[len(tumor_slices) // 2] if len(tumor_slices) > 0 else seg_data.shape[2] // 2
    )
    # .T transposes (x, y) so imshow displays the axial slice consistently.
    ct_slice = ct_vol.axial(z=mid_z).T
    # CT panel
    axes[row, 0].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
    axes[row, 0].set_title(f"{subject_id} - CT")
    axes[row, 0].axis("off")
    # Segmentation panel
    axes[row, 1].imshow(seg_vol.axial(z=mid_z).T, cmap="hot", origin="lower")
    axes[row, 1].set_title(f"{subject_id} - Tumor Mask")
    axes[row, 1].axis("off")
    # Overlay panel: masked array hides background so CT shows through.
    axes[row, 2].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
    mask = seg_vol.axial(z=mid_z).T > 0
    axes[row, 2].imshow(np.ma.masked_where(~mask, mask), cmap="Reds", alpha=0.5, origin="lower")
    axes[row, 2].set_title(f"{subject_id} - Overlay")
    axes[row, 2].axis("off")
plt.tight_layout()
plt.show()
# Visualize CT and segmentation overlay for a few subjects
subject_ids = list(radi.obs_subject_ids)[:3]
# Standard CT lung window (W=1500, L=-600 -> range -1350 to 150 HU)
ct_vmin, ct_vmax = -1350, 150
fig, axes = plt.subplots(len(subject_ids), 3, figsize=(12, 4 * len(subject_ids)))
for row, subject_id in enumerate(subject_ids):
ct_vol = radi.loc[subject_id].CT_resampled.iloc[0]
seg_vol = radi.loc[subject_id].seg_resampled.iloc[0]
# Find slice with tumor
seg_data = seg_vol.to_numpy()
tumor_slices = np.where(seg_data.sum(axis=(0, 1)) > 0)[0]
mid_z = (
tumor_slices[len(tumor_slices) // 2] if len(tumor_slices) > 0 else seg_data.shape[2] // 2
)
ct_slice = ct_vol.axial(z=mid_z).T
# CT
axes[row, 0].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
axes[row, 0].set_title(f"{subject_id} - CT")
axes[row, 0].axis("off")
# Segmentation
axes[row, 1].imshow(seg_vol.axial(z=mid_z).T, cmap="hot", origin="lower")
axes[row, 1].set_title(f"{subject_id} - Tumor Mask")
axes[row, 1].axis("off")
# Overlay
axes[row, 2].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
mask = seg_vol.axial(z=mid_z).T > 0
axes[row, 2].imshow(np.ma.masked_where(~mask, mask), cmap="Reds", alpha=0.5, origin="lower")
axes[row, 2].set_title(f"{subject_id} - Overlay")
axes[row, 2].axis("off")
plt.tight_layout()
plt.show()
In [7]:
Copied!
# Reproducible 80/20 subject-level split (fixed global NumPy seed, then an
# in-place shuffle of the subject id list).
all_ids = list(radi.obs_subject_ids)
np.random.seed(42)
np.random.shuffle(all_ids)
n_train = int(0.8 * len(all_ids))
train_ids, val_ids = all_ids[:n_train], all_ids[n_train:]
print(f"Training subjects: {len(train_ids)}")
print(f"Validation subjects: {len(val_ids)}")
# 80/20 split
all_ids = list(radi.obs_subject_ids)
np.random.seed(42)
np.random.shuffle(all_ids)
split_idx = int(0.8 * len(all_ids))
train_ids = all_ids[:split_idx]
val_ids = all_ids[split_idx:]
print(f"Training subjects: {len(train_ids)}")
print(f"Validation subjects: {len(val_ids)}")
Training subjects: 50 Validation subjects: 13
In [8]:
Copied!
# Build train/val views over the same backing store (no data duplication!).
# Views are fully supported by the ML API - VolumeReader respects view filtering.
radi_train, radi_val = radi.loc[train_ids], radi.loc[val_ids]
print(f"Train RadiObject: {radi_train} (is_view: {radi_train.is_view})")
print(f"Val RadiObject: {radi_val} (is_view: {radi_val.is_view})")
# Create train/val views (no data duplication!)
# Views are fully supported by the ML API - VolumeReader respects view filtering
radi_train = radi.loc[train_ids]
radi_val = radi.loc[val_ids]
print(f"Train RadiObject: {radi_train} (is_view: {radi_train.is_view})")
print(f"Val RadiObject: {radi_val} (is_view: {radi_val.is_view})")
Train RadiObject: RadiObject(50 subjects, 4 collections: [seg_resampled, seg, CT_resampled, CT]) (view) (is_view: True) Val RadiObject: RadiObject(13 subjects, 4 collections: [seg_resampled, seg, CT_resampled, CT]) (view) (is_view: True)
In [ ]:
Copied!
# Training hyperparameters
BATCH_SIZE = 2
PATCH_SIZE = (96, 96, 96)  # cubic patches fed to the 3D UNet
# ---------------------------------------------------------------------------
# Using create_segmentation_dataloader for cleaner image/mask separation
#
# This returns {"image": (B,1,D,H,W), "mask": (B,1,D,H,W)} instead of
# stacking CT and mask as channels. Much cleaner for segmentation!
#
# Key parameters:
# - transform: single Compose of MONAI dict transforms. Use key selection
#   to control which tensors are affected (e.g., RandFlipd(keys=["image", "mask"])
#   for spatial transforms, NormalizeIntensityd(keys="image") for image-only).
# - foreground_sampling: bias patches toward tumor regions (helps with class imbalance)
#   When enabled, foreground coordinates are pre-computed once at init to avoid
#   repeated I/O during training. Patches are centered on random foreground voxels.
# - patches_per_volume: extract multiple patches per volume per epoch
# ---------------------------------------------------------------------------
# Spatial augmentations apply to image AND mask; normalization to image only.
train_transform = Compose(
    [
        RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=[0, 1, 2]),
        RandRotate90d(keys=["image", "mask"], prob=0.3, spatial_axes=(0, 1)),
        NormalizeIntensityd(keys="image"),
    ]
)
train_loader = create_segmentation_dataloader(
    image=radi_train.CT_resampled,
    mask=radi_train.seg_resampled,
    batch_size=BATCH_SIZE,
    patch_size=PATCH_SIZE,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False,
    transform=train_transform,
    foreground_sampling=True,
    patches_per_volume=2,
)
# Validation: no augmentation, just normalization
val_transform = Compose([NormalizeIntensityd(keys="image")])
# NOTE(review): foreground_sampling=True here biases validation patches toward
# tumor voxels, so the reported val Dice reflects tumor-centered patches rather
# than randomly sampled ones — presumably intentional (random patches would be
# mostly background), but confirm this is the desired metric.
val_loader = create_segmentation_dataloader(
    image=radi_val.CT_resampled,
    mask=radi_val.seg_resampled,
    batch_size=BATCH_SIZE,
    patch_size=PATCH_SIZE,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False,
    transform=val_transform,
    foreground_sampling=True,
)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
# Training hyperparameters
BATCH_SIZE = 2
PATCH_SIZE = (96, 96, 96)
# ---------------------------------------------------------------------------
# Using create_segmentation_dataloader for cleaner image/mask separation
#
# This returns {"image": (B,1,D,H,W), "mask": (B,1,D,H,W)} instead of
# stacking CT and mask as channels. Much cleaner for segmentation!
#
# Key parameters:
# - transform: single Compose of MONAI dict transforms. Use key selection
# to control which tensors are affected (e.g., RandFlipd(keys=["image", "mask"])
# for spatial transforms, NormalizeIntensityd(keys="image") for image-only).
# - foreground_sampling: bias patches toward tumor regions (helps with class imbalance)
# When enabled, foreground coordinates are pre-computed once at init to avoid
# repeated I/O during training. Patches are centered on random foreground voxels.
# - patches_per_volume: extract multiple patches per volume per epoch
# ---------------------------------------------------------------------------
train_transform = Compose(
[
RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=[0, 1, 2]),
RandRotate90d(keys=["image", "mask"], prob=0.3, spatial_axes=(0, 1)),
NormalizeIntensityd(keys="image"),
]
)
train_loader = create_segmentation_dataloader(
image=radi_train.CT_resampled,
mask=radi_train.seg_resampled,
batch_size=BATCH_SIZE,
patch_size=PATCH_SIZE,
num_workers=0,
pin_memory=False,
persistent_workers=False,
transform=train_transform,
foreground_sampling=True,
patches_per_volume=2,
)
# Validation: no augmentation, just normalization
val_transform = Compose([NormalizeIntensityd(keys="image")])
val_loader = create_segmentation_dataloader(
image=radi_val.CT_resampled,
mask=radi_val.seg_resampled,
batch_size=BATCH_SIZE,
patch_size=PATCH_SIZE,
num_workers=0,
pin_memory=False,
persistent_workers=False,
transform=val_transform,
foreground_sampling=True,
)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
In [10]:
Copied!
# Pull one batch to verify the loader contract: separate image/mask keys,
# expected shapes/dtypes, and sensible value ranges after normalization.
batch = next(iter(train_loader))
images, masks = batch["image"], batch["mask"]
print(f"Batch keys: {list(batch.keys())}")
print(f"Image shape: {images.shape}")  # (B, 1, D, H, W) - CT only
print(f"Mask shape: {masks.shape}")  # (B, 1, D, H, W) - segmentation
print(f"Image dtype: {images.dtype}")
print(f"Memory per batch: {(images.nbytes + masks.nbytes) / 1024 / 1024:.1f} MB")
# Verify data ranges after normalization
print(f"\nImage (normalized) range: [{images.min():.2f}, {images.max():.2f}]")
print(f"Mask unique values: {torch.unique(masks).tolist()}")
fg_frac = (masks > 0).float().mean().item()
print(f"Foreground fraction: {fg_frac:.4f}")
# Inspect a batch - now with separate image and mask keys
batch = next(iter(train_loader))
print(f"Batch keys: {list(batch.keys())}")
print(f"Image shape: {batch['image'].shape}") # (B, 1, D, H, W) - CT only
print(f"Mask shape: {batch['mask'].shape}") # (B, 1, D, H, W) - segmentation
print(f"Image dtype: {batch['image'].dtype}")
print(f"Memory per batch: {(batch['image'].nbytes + batch['mask'].nbytes) / 1024 / 1024:.1f} MB")
# Verify data ranges after normalization
print(f"\nImage (normalized) range: [{batch['image'].min():.2f}, {batch['image'].max():.2f}]")
print(f"Mask unique values: {torch.unique(batch['mask']).tolist()}")
fg_frac = (batch["mask"] > 0).float().mean().item()
print(f"Foreground fraction: {fg_frac:.4f}")
Batch keys: ['image', 'mask', 'idx', 'patch_idx', 'patch_start', 'obs_id', 'obs_subject_id'] Image shape: torch.Size([2, 1, 96, 96, 96]) Mask shape: torch.Size([2, 1, 96, 96, 96]) Image dtype: torch.float32 Memory per batch: 8.4 MB Image (normalized) range: [-1.49, 6.67] Mask unique values: [0, 1] Foreground fraction: 0.0109
In [11]:
Copied!
# 3D UNet for binary segmentation: 1-channel CT in, 2 logit channels out
# (background vs tumor), three stride-2 downsamplings with residual units.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,  # background + tumor
    channels=(32, 64, 128, 256),
    strides=(2, 2, 2),
    num_res_units=2,
).to(DEVICE)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")
# MONAI UNet for 3D segmentation
model = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2, # background + tumor
channels=(32, 64, 128, 256),
strides=(2, 2, 2),
num_res_units=2,
).to(DEVICE)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
Model parameters: 4,739,869
In [12]:
Copied!
# Training configuration
NUM_EPOCHS = 30
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Cosine annealing with warm restarts: first restart after 10 epochs,
# each subsequent cycle twice as long (T_mult=2).
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
# Dice+Focal loss copes with heavy class imbalance; labels arrive as class
# indices, hence one-hot encoding and softmax are applied inside the loss.
criterion = DiceFocalLoss(to_onehot_y=True, softmax=True)
dice_metric = DiceMetric(include_background=False, reduction="mean")
# Per-epoch metric history, one list per curve.
history = {key: [] for key in ("train_loss", "train_dice", "val_loss", "val_dice")}
# Training configuration
NUM_EPOCHS = 30
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
criterion = DiceFocalLoss(to_onehot_y=True, softmax=True)
dice_metric = DiceMetric(include_background=False, reduction="mean")
# Training history
history = {
"train_loss": [],
"train_dice": [],
"val_loss": [],
"val_dice": [],
}
In [13]:
Copied!
print(f"Training on {DEVICE} for {NUM_EPOCHS} epochs...\n")

best_val_dice = 0.0
# Fix: the original tracked best_val_dice but never captured the weights, so
# the in-memory model always ended at the LAST epoch, not the best one.
best_state = None
for epoch in range(NUM_EPOCHS):
    # Training phase
    model.train()
    train_loss = 0.0
    dice_metric.reset()
    for batch in train_loader:
        images = batch["image"].to(DEVICE)
        labels = batch["mask"].long().to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        # Hard per-voxel prediction; keepdim matches label layout (B,1,D,H,W).
        preds = torch.argmax(outputs, dim=1, keepdim=True)
        dice_metric(preds, labels)
    train_dice = dice_metric.aggregate().item()
    # Validation phase (no gradients, metric accumulator reused after reset)
    model.eval()
    val_loss = 0.0
    dice_metric.reset()
    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(DEVICE)
            labels = batch["mask"].long().to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            preds = torch.argmax(outputs, dim=1, keepdim=True)
            dice_metric(preds, labels)
    val_dice = dice_metric.aggregate().item()
    scheduler.step()
    # Record per-epoch metrics (losses averaged over batches)
    history["train_loss"].append(train_loss / len(train_loader))
    history["train_dice"].append(train_dice)
    history["val_loss"].append(val_loss / len(val_loader))
    history["val_dice"].append(val_dice)
    improved = val_dice > best_val_dice
    if improved:
        best_val_dice = val_dice
        # Snapshot best-epoch weights on CPU so they can be restored later via
        # model.load_state_dict(best_state) if the final epochs regress.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    # Log every 5th epoch, the first epoch, and every new best
    if (epoch + 1) % 5 == 0 or epoch == 0 or improved:
        print(
            f"Epoch {epoch + 1:3d}/{NUM_EPOCHS}: "
            f"Train Loss={history['train_loss'][-1]:.4f}, "
            f"Train Dice={history['train_dice'][-1]:.4f}, "
            f"Val Loss={history['val_loss'][-1]:.4f}, "
            f"Val Dice={history['val_dice'][-1]:.4f} "
            f"{'*BEST*' if improved else ''}"
        )
print(f"Training on {DEVICE} for {NUM_EPOCHS} epochs...\n")
best_val_dice = 0.0
for epoch in range(NUM_EPOCHS):
# Training phase
model.train()
train_loss = 0.0
dice_metric.reset()
for batch in train_loader:
images = batch["image"].to(DEVICE)
labels = batch["mask"].long().to(DEVICE)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
preds = torch.argmax(outputs, dim=1, keepdim=True)
dice_metric(preds, labels)
train_dice = dice_metric.aggregate().item()
# Validation phase
model.eval()
val_loss = 0.0
dice_metric.reset()
with torch.no_grad():
for batch in val_loader:
images = batch["image"].to(DEVICE)
labels = batch["mask"].long().to(DEVICE)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item()
preds = torch.argmax(outputs, dim=1, keepdim=True)
dice_metric(preds, labels)
val_dice = dice_metric.aggregate().item()
scheduler.step()
# Record metrics
history["train_loss"].append(train_loss / len(train_loader))
history["train_dice"].append(train_dice)
history["val_loss"].append(val_loss / len(val_loader))
history["val_dice"].append(val_dice)
improved = val_dice > best_val_dice
if improved:
best_val_dice = val_dice
if (epoch + 1) % 5 == 0 or epoch == 0 or improved:
print(
f"Epoch {epoch + 1:3d}/{NUM_EPOCHS}: "
f"Train Loss={history['train_loss'][-1]:.4f}, "
f"Train Dice={history['train_dice'][-1]:.4f}, "
f"Val Loss={history['val_loss'][-1]:.4f}, "
f"Val Dice={history['val_dice'][-1]:.4f} "
f"{'*BEST*' if improved else ''}"
)
Training on mps for 30 epochs...
Epoch 1/30: Train Loss=0.6532, Train Dice=0.0019, Val Loss=0.5503, Val Dice=0.0000
Epoch 2/30: Train Loss=0.5281, Train Dice=0.0135, Val Loss=0.5156, Val Dice=0.0444 *BEST*
Epoch 3/30: Train Loss=0.5067, Train Dice=0.0987, Val Loss=0.5026, Val Dice=0.1010 *BEST*
Epoch 4/30: Train Loss=0.4937, Train Dice=0.1415, Val Loss=0.4940, Val Dice=0.1500 *BEST*
Epoch 5/30: Train Loss=0.4694, Train Dice=0.2145, Val Loss=0.4468, Val Dice=0.2956 *BEST*
Epoch 7/30: Train Loss=0.3939, Train Dice=0.3716, Val Loss=0.3684, Val Dice=0.4242 *BEST*
Epoch 8/30: Train Loss=0.3440, Train Dice=0.4574, Val Loss=0.3156, Val Dice=0.5220 *BEST*
Epoch 9/30: Train Loss=0.3097, Train Dice=0.5166, Val Loss=0.2947, Val Dice=0.5370 *BEST*
Epoch 10/30: Train Loss=0.2873, Train Dice=0.5576, Val Loss=0.2730, Val Dice=0.5891 *BEST*
Epoch 15/30: Train Loss=0.2825, Train Dice=0.4814, Val Loss=0.2536, Val Dice=0.5345
Epoch 19/30: Train Loss=0.2018, Train Dice=0.6295, Val Loss=0.1858, Val Dice=0.6604 *BEST*
Epoch 20/30: Train Loss=0.1948, Train Dice=0.6388, Val Loss=0.2081, Val Dice=0.6118
Epoch 23/30: Train Loss=0.1808, Train Dice=0.6635, Val Loss=0.1684, Val Dice=0.6845 *BEST*
Epoch 24/30: Train Loss=0.1823, Train Dice=0.6606, Val Loss=0.1705, Val Dice=0.6877 *BEST*
Epoch 25/30: Train Loss=0.1809, Train Dice=0.6633, Val Loss=0.1606, Val Dice=0.7024 *BEST*
Epoch 27/30: Train Loss=0.1693, Train Dice=0.6848, Val Loss=0.1557, Val Dice=0.7099 *BEST*
Epoch 29/30: Train Loss=0.1639, Train Dice=0.6950, Val Loss=0.1550, Val Dice=0.7113 *BEST*
Epoch 30/30: Train Loss=0.1583, Train Dice=0.7036, Val Loss=0.1555, Val Dice=0.7129 *BEST*
In [14]:
Copied!
# Plot training curves: loss and Dice score side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
panels = [
    (axes[0], "loss", "Loss", "Training & Validation Loss"),
    (axes[1], "dice", "Dice Score", "Training & Validation Dice"),
]
for ax, metric, ylabel, title in panels:
    ax.plot(history[f"train_{metric}"], label="Train")
    ax.plot(history[f"val_{metric}"], label="Validation")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
# Loss
axes[0].plot(history["train_loss"], label="Train")
axes[0].plot(history["val_loss"], label="Validation")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training & Validation Loss")
axes[0].legend()
axes[0].grid(True, alpha=0.3)
# Dice Score
axes[1].plot(history["train_dice"], label="Train")
axes[1].plot(history["val_dice"], label="Validation")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Dice Score")
axes[1].set_title("Training & Validation Dice")
axes[1].legend()
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
In [15]:
Copied!
# Summarize best/final metrics from the recorded history.
divider = "=" * 50
print(divider)
print("Final Results")
print(divider)
print(f"Best Train Dice: {max(history['train_dice']):.4f}")
print(f"Best Val Dice: {max(history['val_dice']):.4f}")
print(f"Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"Final Val Loss: {history['val_loss'][-1]:.4f}")
# Final metrics
print("=" * 50)
print("Final Results")
print("=" * 50)
print(f"Best Train Dice: {max(history['train_dice']):.4f}")
print(f"Best Val Dice: {max(history['val_dice']):.4f}")
print(f"Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"Final Val Loss: {history['val_loss'][-1]:.4f}")
================================================== Final Results ================================================== Best Train Dice: 0.7036 Best Val Dice: 0.7129 Final Train Loss: 0.1583 Final Val Loss: 0.1555
In [16]:
Copied!
# Full-volume inference on a validation subject
# We run sliding-window inference on the full volume and display the slice with
# the most tumor, making the visualization informative.
model.eval()
# Pick first validation subject and load full volumes via RadiObject API
subject_id = val_ids[0]
ct_vol = radi.loc[subject_id].CT_resampled.iloc[0]
seg_vol = radi.loc[subject_id].seg_resampled.iloc[0]
ct_data = ct_vol.to_numpy().astype(np.float32)
seg_data = seg_vol.to_numpy()
# Find the axial slice with the most tumor area
best_z = int(np.argmax(seg_data.sum(axis=(0, 1))))
# Zero-mean/unit-variance normalization computed over the WHOLE volume.
# NOTE(review): training normalized each 96^3 patch independently via
# NormalizeIntensityd, so the statistics here come from a different support —
# presumably close enough in practice, but confirm.
ct_tensor = torch.from_numpy(ct_data).unsqueeze(0).unsqueeze(0)  # (1,1,X,Y,Z)
ct_tensor = (ct_tensor - ct_tensor.mean()) / (ct_tensor.std() + 1e-8)
with torch.no_grad():
    # Tile the volume with overlapping PATCH_SIZE windows and blend logits.
    pred_vol = sliding_window_inference(
        ct_tensor.to(DEVICE),
        roi_size=PATCH_SIZE,
        sw_batch_size=4,
        predictor=model,
        overlap=0.25,
    )
pred_mask = torch.argmax(pred_vol, dim=1).squeeze().cpu().numpy()  # (X,Y,Z)
# Compute val Dice on this subject for context (binary Dice on flattened
# masks; the 1e-8 epsilon guards against an empty ground truth + prediction).
gt_flat = (seg_data > 0).astype(np.float32).ravel()
pred_flat = (pred_mask > 0).astype(np.float32).ravel()
intersection = (gt_flat * pred_flat).sum()
subject_dice = 2 * intersection / (gt_flat.sum() + pred_flat.sum() + 1e-8)
# 3-panel figure at the best slice with standard CT lung windowing
ct_slice = ct_data[:, :, best_z].T
ct_vmin, ct_vmax = -1350, 150  # Standard CT lung window (W=1500, L=-600)
gt_slice = seg_data[:, :, best_z].T > 0
pred_slice = pred_mask[:, :, best_z].T > 0
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# CT
axes[0].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
axes[0].set_title(f"{subject_id} - CT (z={best_z})")
axes[0].axis("off")
# Ground Truth overlay (red)
axes[1].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
axes[1].imshow(np.ma.masked_where(~gt_slice, gt_slice), cmap="Reds", alpha=0.5, origin="lower")
axes[1].set_title("Ground Truth")
axes[1].axis("off")
# Prediction overlay (blue)
axes[2].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
axes[2].imshow(np.ma.masked_where(~pred_slice, pred_slice), cmap="Blues", alpha=0.5, origin="lower")
axes[2].set_title(f"Prediction (Dice={subject_dice:.4f})")
axes[2].axis("off")
plt.suptitle("Full-Volume Inference - Validation Subject")
plt.tight_layout()
plt.show()
# Full-volume inference on a validation subject
# We run sliding-window inference on the full volume and display the slice with
# the most tumor, making the visualization informative.
model.eval()
# Pick first validation subject and load full volumes via RadiObject API
subject_id = val_ids[0]
ct_vol = radi.loc[subject_id].CT_resampled.iloc[0]
seg_vol = radi.loc[subject_id].seg_resampled.iloc[0]
ct_data = ct_vol.to_numpy().astype(np.float32)
seg_data = seg_vol.to_numpy()
# Find the axial slice with the most tumor area
best_z = int(np.argmax(seg_data.sum(axis=(0, 1))))
# Normalize CT the same way as training and run sliding-window inference
ct_tensor = torch.from_numpy(ct_data).unsqueeze(0).unsqueeze(0) # (1,1,X,Y,Z)
ct_tensor = (ct_tensor - ct_tensor.mean()) / (ct_tensor.std() + 1e-8)
with torch.no_grad():
pred_vol = sliding_window_inference(
ct_tensor.to(DEVICE),
roi_size=PATCH_SIZE,
sw_batch_size=4,
predictor=model,
overlap=0.25,
)
pred_mask = torch.argmax(pred_vol, dim=1).squeeze().cpu().numpy() # (X,Y,Z)
# Compute val Dice on this subject for context
gt_flat = (seg_data > 0).astype(np.float32).ravel()
pred_flat = (pred_mask > 0).astype(np.float32).ravel()
intersection = (gt_flat * pred_flat).sum()
subject_dice = 2 * intersection / (gt_flat.sum() + pred_flat.sum() + 1e-8)
# 3-panel figure at the best slice with standard CT lung windowing
ct_slice = ct_data[:, :, best_z].T
ct_vmin, ct_vmax = -1350, 150 # Standard CT lung window (W=1500, L=-600)
gt_slice = seg_data[:, :, best_z].T > 0
pred_slice = pred_mask[:, :, best_z].T > 0
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# CT
axes[0].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
axes[0].set_title(f"{subject_id} - CT (z={best_z})")
axes[0].axis("off")
# Ground Truth overlay (red)
axes[1].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
axes[1].imshow(np.ma.masked_where(~gt_slice, gt_slice), cmap="Reds", alpha=0.5, origin="lower")
axes[1].set_title("Ground Truth")
axes[1].axis("off")
# Prediction overlay (blue)
axes[2].imshow(ct_slice, cmap="gray", origin="lower", vmin=ct_vmin, vmax=ct_vmax)
axes[2].imshow(np.ma.masked_where(~pred_slice, pred_slice), cmap="Blues", alpha=0.5, origin="lower")
axes[2].set_title(f"Prediction (Dice={subject_dice:.4f})")
axes[2].axis("off")
plt.suptitle("Full-Volume Inference - Validation Subject")
plt.tight_layout()
plt.show()
/Users/samueldsouza/Desktop/Code/RadiObject/.venv/lib/python3.11/site-packages/monai/inferers/utils.py:226: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_variable_indexing.cpp:353.) win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) /Users/samueldsouza/Desktop/Code/RadiObject/.venv/lib/python3.11/site-packages/monai/inferers/utils.py:370: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/python_variable_indexing.cpp:353.) out[idx_zm] += p