
ML Architecture

RadiObject's ML module bridges TileDB-backed medical imaging storage with PyTorch's training ecosystem. This document explains the data flow, design decisions, and trade-offs.

Data Flow

TileDB Dense Array (on disk / S3)
       │
       ▼
VolumeCollection.iloc[i] → Volume
       ├── .to_numpy()     → full 3D array (FULL_VOLUME mode)
       ├── .slice(x,y,z)   → sub-array read (PATCH mode)
       └── .axial(z)       → 2D slice (SLICE_2D mode)
       │
       ▼
Dataset.__getitem__(idx)
       │  Returns dict: {"image": Tensor, "mask": Tensor, "idx": int, ...}
       ▼
transform(sample_dict)     → MONAI/TorchIO dict transforms
       │
       ▼
DataLoader (batching, shuffling, multi-process prefetch)
       │
       ▼
Training loop

Key insight: PATCH and SLICE_2D modes rely on TileDB's cloud-native sub-array reads, so only the requested voxels are decompressed and transferred. The full volume is never loaded into memory.
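To give a sense of scale, the following back-of-the-envelope sketch compares how many voxels each mode requests per sample. The volume shape, patch size, and dtype are illustrative assumptions, not values taken from RadiObject, and the sizes are uncompressed; actual bytes transferred depend on compression and tiling.

```python
import math

# Illustrative assumptions: a 512 x 512 x 300 float32 volume and a 96^3 patch.
volume_shape = (512, 512, 300)
patch_shape = (96, 96, 96)
bytes_per_voxel = 4  # float32

full_voxels = math.prod(volume_shape)
patch_voxels = math.prod(patch_shape)
slice_voxels = volume_shape[0] * volume_shape[1]  # one axial (X, Y) slice

for name, n in [
    ("FULL_VOLUME", full_voxels),
    ("PATCH", patch_voxels),
    ("SLICE_2D", slice_voxels),
]:
    mib = n * bytes_per_voxel / 2**20
    print(f"{name:12s} {n:>11,d} voxels {mib:8.2f} MiB {n / full_voxels:8.3%} of volume")
```

Under these assumptions a single patch or slice is a small fraction of the volume, which is where sub-array reads pay off.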

Dataset Types

                    | VolumeCollectionDataset                    | SegmentationDataset
Input               | 1+ VolumeCollections (stacked as channels) | Separate image + mask collections
Output keys         | "image" (multi-channel)                    | "image", "mask" (single-channel each)
Use case            | Classification, regression, multi-modal    | Segmentation with per-key transforms
Labels              | Flexible LabelSource (column, dict, fn)    | Mask is the label
Foreground sampling | No                                         | Yes (pre-computed coordinates)

Both datasets share the same loading modes (FULL_VOLUME, PATCH, SLICE_2D) and configuration model (DatasetConfig).
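As a rough illustration of how the two output layouts differ, the sketch below uses NumPy arrays in place of tensors. The modality names (t1, flair), the shapes, and the channel-first layout are assumptions made for illustration; the actual tensor layout is defined by the dataset implementation.

```python
import numpy as np

shape = (32, 32, 16)  # illustrative volume shape
t1 = np.zeros(shape, dtype=np.float32)    # hypothetical modality 1
flair = np.ones(shape, dtype=np.float32)  # hypothetical modality 2
seg = np.zeros(shape, dtype=np.uint8)     # hypothetical segmentation mask

# VolumeCollectionDataset-style sample: collections stacked as channels.
multi_modal_sample = {"image": np.stack([t1, flair], axis=0)}

# SegmentationDataset-style sample: one channel per key.
segmentation_sample = {"image": t1[np.newaxis], "mask": seg[np.newaxis]}

print(multi_modal_sample["image"].shape)   # (2, 32, 32, 16)
print(segmentation_sample["image"].shape)  # (1, 32, 32, 16)
print(segmentation_sample["mask"].shape)   # (1, 32, 32, 16)
```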

Loading Modes

FULL_VOLUME

Loads entire 3D volumes. Best for small volumes or when the model expects full spatial context (e.g., whole-brain classification).

  • Length: n_volumes
  • Requires uniform shapes across all collections
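A uniform-shape check of this kind can be sketched as below. check_uniform_shapes is a hypothetical helper, not part of RadiObject's API; it only illustrates validating shapes once up front rather than failing later when samples are batched together.

```python
def check_uniform_shapes(shapes):
    """Return the single common shape, or raise ValueError if shapes differ."""
    unique = {tuple(s) for s in shapes}
    if len(unique) != 1:
        raise ValueError(f"expected one common shape, found {sorted(unique)}")
    return unique.pop()


print(check_uniform_shapes([(240, 240, 155), (240, 240, 155)]))  # (240, 240, 155)

try:
    check_uniform_shapes([(240, 240, 155), (256, 256, 128)])
except ValueError as err:
    print("rejected:", err)
```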

PATCH

Extracts random 3D patches. The primary mode for training on large medical volumes.

  • Length: n_volumes × patches_per_volume
  • Supports heterogeneous volume shapes (each volume can differ)
  • Foreground sampling: pre-computes nonzero mask coordinates at init, then samples patch centers from those coordinates (no extra I/O during training)
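The foreground sampling described above can be sketched with NumPy. This is a minimal illustration under stated assumptions, not RadiObject's implementation: sample_patch_start is a hypothetical helper, the mask is synthetic, and the patch is assumed to fit inside the volume along every axis.

```python
import numpy as np


def sample_patch_start(fg_coords, volume_shape, patch_size, rng):
    """Pick a foreground voxel as the patch center and return the patch start corner."""
    center = fg_coords[rng.integers(len(fg_coords))]
    size = np.asarray(patch_size)
    start = center - size // 2
    # Clamp so the whole patch stays inside the volume.
    return np.clip(start, 0, np.asarray(volume_shape) - size)


# Done once at init: coordinates of nonzero mask voxels.
mask = np.zeros((64, 64, 32), dtype=np.uint8)
mask[40:50, 10:20, 5:9] = 1
fg_coords = np.argwhere(mask > 0)  # shape (n_foreground, 3)

# Done per sample during training: an in-memory lookup, no mask I/O.
rng = np.random.default_rng(0)
patch_size = (16, 16, 16)
start = sample_patch_start(fg_coords, mask.shape, patch_size, rng)
stop = start + np.asarray(patch_size)
patch = mask[tuple(slice(a, b) for a, b in zip(start, stop))]
print(start, patch.shape, bool(patch.any()))
```

Because the patch is centered on a foreground voxel and only shifted to stay in bounds, it always contains at least that voxel.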

SLICE_2D

Extracts 2D slices along a configurable orientation axis.

  • Length: n_volumes × n_slices (where n_slices depends on orientation)
  • slice_orientation: AXIAL (default, Z-axis), SAGITTAL (X-axis), CORONAL (Y-axis)
  • Requires uniform shapes across all collections
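The length formula and orientation axes can be made concrete with a short sketch. The helper names, the (X, Y, Z) array order, and the divmod-based mapping from a flat index to (volume, slice) are assumptions for illustration; the source does not specify how RadiObject maps indices internally.

```python
import numpy as np

# Assumed array axis order (X, Y, Z), following the orientation list above.
ORIENTATION_AXIS = {"SAGITTAL": 0, "CORONAL": 1, "AXIAL": 2}


def dataset_length(n_volumes, volume_shape, orientation="AXIAL"):
    """n_volumes x n_slices, where n_slices is the extent along the slicing axis."""
    return n_volumes * volume_shape[ORIENTATION_AXIS[orientation]]


def locate(idx, volume_shape, orientation="AXIAL"):
    """Map a flat dataset index to (volume index, slice index)."""
    return divmod(idx, volume_shape[ORIENTATION_AXIS[orientation]])


def extract_slice(volume, slice_idx, orientation="AXIAL"):
    """Take one 2D slice perpendicular to the chosen axis."""
    return np.take(volume, slice_idx, axis=ORIENTATION_AXIS[orientation])


shape = (240, 240, 155)
volume = np.zeros(shape, dtype=np.float32)
print(dataset_length(10, shape, "AXIAL"))          # 1550
print(locate(160, shape, "AXIAL"))                 # (1, 5)
print(extract_slice(volume, 5, "AXIAL").shape)     # (240, 240)
print(extract_slice(volume, 5, "SAGITTAL").shape)  # (240, 155)
```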

Transform Integration

Both datasets accept a single transform callable that receives the full sample dict. This aligns with MONAI's dict-transform pattern:

from monai.transforms import Compose, RandFlipd, NormalizeIntensityd

transform = Compose([
    RandFlipd(keys=["image", "mask"], prob=0.5),       # spatial: both
    NormalizeIntensityd(keys="image"),                   # intensity: image only
])

dataset = SegmentationDataset(image=ct, mask=seg, transform=transform)

MONAI dict transforms use keys to select which tensors to modify, so users control transform scope naturally. TorchIO SubjectsDataset is supported via the separate VolumeCollectionSubjectsDataset compatibility layer.
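Because the transform is just a callable over the sample dict, a hand-written function can play the same role as a MONAI Compose. The sketch below is a hypothetical transform that uses NumPy arrays in place of tensors; the function name and the channel-first (C, X, Y, Z) layout are assumptions.

```python
import numpy as np


def flip_and_normalize(sample, rng=None, prob=0.5):
    """Hypothetical dict transform: flip image and mask together, normalize image only."""
    if rng is None:
        rng = np.random.default_rng()
    out = dict(sample)  # shallow copy; untouched keys such as "idx" pass through
    if rng.random() < prob:
        # Spatial transform: apply the same flip to both keys so they stay aligned.
        out["image"] = np.flip(out["image"], axis=1).copy()
        out["mask"] = np.flip(out["mask"], axis=1).copy()
    # Intensity transform: image only, so the mask keeps its label values.
    image = out["image"].astype(np.float32)
    out["image"] = (image - image.mean()) / (image.std() + 1e-8)
    return out


sample = {
    "image": np.random.default_rng(0).normal(100.0, 20.0, size=(1, 8, 8, 4)),
    "mask": np.zeros((1, 8, 8, 4), dtype=np.uint8),
    "idx": 0,
}
result = flip_and_normalize(sample, rng=np.random.default_rng(1))
print(sorted(result), result["image"].shape, result["mask"].dtype)
```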

Distributed Training

create_distributed_dataloader wraps VolumeCollectionDataset with PyTorch's DistributedSampler for DDP. It handles data partitioning only — process group initialization and model wrapping are the user's responsibility. SegmentationDataset does not have distributed support; users needing distributed segmentation training should manage the DistributedSampler manually.

loader = create_distributed_dataloader(collections, rank=rank, world_size=world_size)
# User must call set_epoch(loader, epoch) each epoch for proper shuffling
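For the manual route mentioned above, a minimal sketch with PyTorch's own DistributedSampler might look like this. ToyDataset is a hypothetical stand-in for a SegmentationDataset instance, and rank and world_size are passed explicitly so the snippet runs without an initialized process group; in a real DDP job they would come from torch.distributed.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


class ToyDataset(Dataset):
    """Hypothetical stand-in for a SegmentationDataset instance."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        zeros = torch.zeros(1, 4, 4, 4)
        return {"image": zeros, "mask": zeros.clone(), "idx": idx}


def make_loader(dataset, rank, world_size, batch_size=2):
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0
    )
    # The sampler does the shuffling, so DataLoader's own shuffle stays off.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler), sampler


dataset = ToyDataset()
seen = {}
for rank in range(2):  # simulate two ranks inside one process
    loader, sampler = make_loader(dataset, rank=rank, world_size=2)
    sampler.set_epoch(0)  # call once per epoch, before iterating
    seen[rank] = sorted(int(i) for batch in loader for i in batch["idx"])
print(seen)
```

Each simulated rank sees a disjoint half of the indices; calling set_epoch with the current epoch number changes the shuffle order from one epoch to the next.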

Worker Process Isolation

With num_workers > 0, PyTorch's DataLoader loads data in separate worker processes (forked by default on Linux). Each worker process needs its own TileDB context because libtiledb's internal state (memory pools, S3 connections) cannot be shared across process boundaries.

worker_init_fn calls ctx_for_process() to create an isolated TileDB context per worker. This is configured automatically by all factory functions when num_workers > 0.

Trade-offs

Optimized for                                      | Not covered
Random patch access (TileDB sub-array reads)       | Streaming / sequential access patterns
Medical imaging volumes (3D/4D dense arrays)       | Sparse data (point clouds, meshes)
MONAI dict-transforms + TorchIO (via compat layer) | Custom transform protocols
Single-node + DDP training                         | Model-parallel or pipeline-parallel
Foreground-biased sampling (pre-computed)          | Online hard-example mining

RNG Design

Patch sampling RNG differs by dataset. SegmentationDataset uses np.random.default_rng(seed=None), which is unseeded, so each call produces a different patch. VolumeCollectionDataset uses np.random.default_rng(seed=idx), which is deterministic per index for reproducibility. Combined with DataLoader shuffling and persistent_workers, this ensures diversity across epochs. For distributed training, DistributedSampler.set_epoch() re-seeds the sampler's shuffle each epoch, so every epoch uses a different ordering.
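The two seeding styles can be reproduced in isolation. The helper names and the max_start bound are hypothetical; only the two default_rng call styles come from the description above.

```python
import numpy as np


def sample_start_seeded(idx, max_start):
    """Seeded by index: the same idx always yields the same patch start."""
    rng = np.random.default_rng(seed=idx)
    return tuple(int(v) for v in rng.integers(0, np.asarray(max_start) + 1))


def sample_start_unseeded(max_start):
    """Unseeded: fresh OS entropy on every call, so starts vary between calls."""
    rng = np.random.default_rng()
    return tuple(int(v) for v in rng.integers(0, np.asarray(max_start) + 1))


max_start = (448, 448, 236)  # illustrative: volume shape minus patch size
print(sample_start_seeded(7, max_start) == sample_start_seeded(7, max_start))  # True
print(sample_start_unseeded(max_start), sample_start_unseeded(max_start))
```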