Machine learning models

The qim3d library aims to ease the creation of ML models for volumetric images.

qim3d.ml.models

qim3d.ml.models.UNet

Bases: Module

3D UNet model designed for volumetric image segmentation tasks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `size` | `str` | Size of the UNet model. Must be one of `'xxsmall'`, `'xsmall'`, `'small'`, `'medium'`, `'large'`, `'xlarge'`, or `'xxlarge'`. | `'medium'` |
| `dropout` | `float` | Dropout rate between 0 and 1. | `0` |
| `kernel_size` | `int` | Convolution kernel size. | `3` |
| `up_kernel_size` | `int` | Up-convolution kernel size. | `3` |
| `activation` | `str` | Activation function. | `'PReLU'` |
| `bias` | `bool` | Whether to include bias in convolutions. | `True` |
| `adn_order` | `str` | ADN (Activation, Dropout, Normalization) ordering. | `'NDA'` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `torch.nn.Module` | 3D UNet model. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `size` is not one of `'xxsmall'`, `'xsmall'`, `'small'`, `'medium'`, `'large'`, `'xlarge'`, or `'xxlarge'`. |

Example

```python
import qim3d

model = qim3d.ml.models.UNet(size = 'small')
```
Source code in qim3d/ml/models/_unet.py
class UNet(nn.Module):
    """
    3D UNet model designed for volumetric image segmentation tasks.

    Args:
        size (str, optional): Size of the UNet model. Must be one of 'xxsmall', 'xsmall', 'small', 'medium', 'large', 'xlarge' or 'xxlarge'. Default is 'medium'.
        dropout (float, optional): Dropout rate between 0 and 1. Default is 0.
        kernel_size (int, optional): Convolution kernel size. Default is 3.
        up_kernel_size (int, optional): Up-convolution kernel size. Default is 3.
        activation (str, optional): Activation function. Default is 'PReLU'.
        bias (bool, optional): Whether to include bias in convolutions. Default is True.
        adn_order (str, optional): ADN (Activation, Dropout, Normalization) ordering. Default is 'NDA'.

    Returns:
        model (torch.nn.Module): 3D UNet model.

    Raises:
        ValueError: If `size` is not one of 'xxsmall', 'xsmall', 'small', 'medium', 'large', 'xlarge' or 'xxlarge'.

    Example:
        ```python
        import qim3d

        model = qim3d.ml.models.UNet(size = 'small')
        ```

    """

    def __init__(
        self,
        size: str = 'medium',
        dropout: float = 0,
        kernel_size: int = 3,
        up_kernel_size: int = 3,
        activation: str = 'PReLU',
        bias: bool = True,
        adn_order: str = 'NDA',
    ):
        super().__init__()

        self.size = size
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.activation = activation
        self.bias = bias
        self.adn_order = adn_order

        self.model = self._model_choice()

    def _model_choice(self) -> nn.Module:
        from monai.networks.nets import UNet as monai_UNet

        size_options = {
            'xxsmall': (4, 8),  # 2 layers
            'xsmall': (16, 32),  # 2 layers
            'small': (32, 64, 128),  # 3 layers
            'medium': (64, 128, 256),  # 3 layers
            'large': (64, 128, 256, 512, 1024),  # 5 layers
            'xlarge': (64, 128, 256, 512, 1024, 2048),  # 6 layers
            'xxlarge': (64, 128, 256, 512, 1024, 2048, 4096),  # 7 layers
        }

        if self.size in size_options:
            self.channels = size_options[self.size]
        else:
            message = (
                f"Unknown size '{self.size}'. Choose from {list(size_options.keys())}"
            )
            raise ValueError(message)

        model = monai_UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            channels=self.channels,
            strides=(2,) * (len(self.channels) - 1),
            num_res_units=2,
            kernel_size=self.kernel_size,
            up_kernel_size=self.up_kernel_size,
            act=self.activation,
            dropout=self.dropout,
            bias=self.bias,
            adn_ordering=self.adn_order,
        )
        return model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        return x
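
To make the expected tensor layout concrete, here is a minimal forward-pass sketch. It assumes PyTorch and MONAI are installed and uses the 5D convention (batch, channel, depth, height, width) implied by `spatial_dims=3` and `in_channels=1` in the source above; the spatial size must be divisible by the product of the strides.

```python
import torch
import qim3d

model = qim3d.ml.models.UNet(size = 'small')

# Dummy single-channel volume: (batch, channel, D, H, W)
x = torch.rand(1, 1, 32, 32, 32)

with torch.no_grad():
    y = model(x)

print(y.shape)  # torch.Size([1, 1, 32, 32, 32]): one output channel, same spatial size
```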

qim3d.ml

qim3d.ml.Augmentation

Class for defining image augmentation transformations using the MONAI library.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `resize` | `str` | Specifies how the images should be reshaped to the appropriate size: `'crop'`, `'reshape'`, or `'padding'`. | `'crop'` |
| `transform_train` | `str` | Level of transformation for the training set: `'light'`, `'moderate'`, `'heavy'`, or `None`. | `'moderate'` |
| `transform_validation` | `str` | Level of transformation for the validation set: `'light'`, `'moderate'`, `'heavy'`, or `None`. | `None` |
| `transform_test` | `str` | Level of transformation for the test set: `'light'`, `'moderate'`, `'heavy'`, or `None`. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `resize` is not one of `'crop'`, `'reshape'`, or `'padding'`. |

Example

```python
import qim3d

augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
```
Source code in qim3d/ml/_augmentations.py
class Augmentation:
    """
    Class for defining image augmentation transformations using the MONAI library.

    Args:
        resize (str, optional): Specifies how the images should be reshaped to the appropriate size, either 'crop', 'reshape', or 'padding'. Defaults to 'crop'.
        transform_train (str, optional): Level of transformation for the training set, either 'light', 'moderate', 'heavy' or None. Defaults to 'moderate'.
        transform_validation (str, optional): Level of transformation for the validation set, either 'light', 'moderate', 'heavy' or None. Defaults to None.
        transform_test (str, optional): Level of transformation for the test set, either 'light', 'moderate', 'heavy' or None. Defaults to None.

    Raises:
        ValueError: If `resize` is not one of 'crop', 'reshape' or 'padding'.

    Example:
        ```python
        import qim3d

        augmentation =  qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
        ```

    """

    def __init__(
        self,
        resize: str = 'crop',
        transform_train: str | None = 'moderate',
        transform_validation: str | None = None,
        transform_test: str | None = None,
    ):
        if resize not in ['crop', 'reshape', 'padding']:
            msg = f"Invalid resize type: {resize}. Use either 'crop', 'resize' or 'padding'."
            raise ValueError(msg)

        self.resize = resize
        self.transform_train = transform_train
        self.transform_validation = transform_validation
        self.transform_test = transform_test

    def augment(
        self, img_shape: tuple, level: str | None = None
    ) -> monai.transforms.Compose:
        """
        Creates an augmentation pipeline based on the specified level.

        Args:
            img_shape (tuple): Dimensions of the volume as (D, H, W).
            level (str, optional): Level of augmentation, either 'light', 'moderate', 'heavy' or None. Defaults to None.

        Returns:
            Compose (monai.transforms.Compose): Compose object with the specified augmentations.

        Raises:
            ValueError: If `img_shape` is not 3D.
            ValueError: If `level` is not one of None, 'light', 'moderate' or 'heavy'.

        """
        from monai.transforms import (
            CenterSpatialCropd,
            Compose,
            RandAffined,
            RandFlipd,
            RandGaussianSmoothd,
            RandRotate90d,
            Resized,
            SpatialPadd,
            ToTensor,
        )

        # Check if image is 3D
        if len(img_shape) == 3:
            im_d, im_h, im_w = img_shape

        else:
            msg = f'Invalid image shape: {img_shape}. Must be 3D.'
            raise ValueError(msg)

        # Check if one of standard augmentation levels
        if level not in [None, 'light', 'moderate', 'heavy']:
            msg = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."
            raise ValueError(msg)

        # Baseline augmentations
        # TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise)
        baseline_aug = [ToTensor()]  # , NormalizeIntensityd(keys=["image"])]

        # Resize augmentations
        if self.resize == 'crop':
            resize_aug = [
                CenterSpatialCropd(keys=['image', 'label'], roi_size=(im_d, im_h, im_w))
            ]

        elif self.resize == 'reshape':
            resize_aug = [
                Resized(keys=['image', 'label'], spatial_size=(im_d, im_h, im_w))
            ]

        elif self.resize == 'padding':
            resize_aug = [
                SpatialPadd(keys=['image', 'label'], spatial_size=(im_d, im_h, im_w))
            ]

        # Level of augmentation
        if level is None:
            # No augmentation for the validation and test sets
            level_aug = []
            resize_aug = []

        elif level == 'light':
            # TODO: Do rotations along other axes?
            level_aug = [
                RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1))
            ]

        elif level == 'moderate':
            level_aug = [
                RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1)),
                RandFlipd(keys=['image', 'label'], prob=0.3, spatial_axis=0),
                RandFlipd(keys=['image', 'label'], prob=0.3, spatial_axis=1),
                RandGaussianSmoothd(keys=['image'], sigma_x=(0.7, 0.7), prob=0.1),
                RandAffined(
                    keys=['image', 'label'],
                    prob=0.5,
                    translate_range=(0.1, 0.1),
                    scale_range=(0.9, 1.1),
                ),
            ]

        elif level == 'heavy':
            level_aug = [
                RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1)),
                RandFlipd(keys=['image', 'label'], prob=0.7, spatial_axis=0),
                RandFlipd(keys=['image', 'label'], prob=0.7, spatial_axis=1),
                RandGaussianSmoothd(keys=['image'], sigma_x=(1.2, 1.2), prob=0.3),
                RandAffined(
                    keys=['image', 'label'],
                    prob=0.5,
                    translate_range=(0.2, 0.2),
                    scale_range=(0.8, 1.4),
                    shear_range=(-15, 15),
                ),
            ]

        return Compose(baseline_aug + resize_aug + level_aug)

qim3d.ml.Augmentation.augment

augment(img_shape, level=None)

Creates an augmentation pipeline based on the specified level.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `img_shape` | `tuple` | Dimensions of the volume as (D, H, W). | required |
| `level` | `str` | Level of augmentation: `'light'`, `'moderate'`, `'heavy'`, or `None`. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `Compose` | `monai.transforms.Compose` | Compose object with the specified augmentations. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `img_shape` is not 3D. |
| `ValueError` | If `level` is not one of `None`, `'light'`, `'moderate'`, or `'heavy'`. |

Source code in qim3d/ml/_augmentations.py
def augment(
    self, img_shape: tuple, level: str | None = None
) -> monai.transforms.Compose:
    """
    Creates an augmentation pipeline based on the specified level.

    Args:
        img_shape (tuple): Dimensions of the volume as (D, H, W).
        level (str, optional): Level of augmentation, either 'light', 'moderate', 'heavy' or None. Defaults to None.

    Returns:
        Compose (monai.transforms.Compose): Compose object with the specified augmentations.

    Raises:
        ValueError: If `img_shape` is not 3D.
        ValueError: If `level` is not one of None, 'light', 'moderate' or 'heavy'.

    """
    from monai.transforms import (
        CenterSpatialCropd,
        Compose,
        RandAffined,
        RandFlipd,
        RandGaussianSmoothd,
        RandRotate90d,
        Resized,
        SpatialPadd,
        ToTensor,
    )

    # Check if image is 3D
    if len(img_shape) == 3:
        im_d, im_h, im_w = img_shape

    else:
        msg = f'Invalid image shape: {img_shape}. Must be 3D.'
        raise ValueError(msg)

    # Check if one of standard augmentation levels
    if level not in [None, 'light', 'moderate', 'heavy']:
        msg = f"Invalid transformation level: {level}. Please choose one of the following levels: None, 'light', 'moderate', 'heavy'."
        raise ValueError(msg)

    # Baseline augmentations
    # TODO: Figure out how to properly do normalization in 3D (normalization should be done channel-wise)
    baseline_aug = [ToTensor()]  # , NormalizeIntensityd(keys=["image"])]

    # Resize augmentations
    if self.resize == 'crop':
        resize_aug = [
            CenterSpatialCropd(keys=['image', 'label'], roi_size=(im_d, im_h, im_w))
        ]

    elif self.resize == 'reshape':
        resize_aug = [
            Resized(keys=['image', 'label'], spatial_size=(im_d, im_h, im_w))
        ]

    elif self.resize == 'padding':
        resize_aug = [
            SpatialPadd(keys=['image', 'label'], spatial_size=(im_d, im_h, im_w))
        ]

    # Level of augmentation
    if level is None:
        # No augmentation for the validation and test sets
        level_aug = []
        resize_aug = []

    elif level == 'light':
        # TODO: Do rotations along other axes?
        level_aug = [
            RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1))
        ]

    elif level == 'moderate':
        level_aug = [
            RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1)),
            RandFlipd(keys=['image', 'label'], prob=0.3, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.3, spatial_axis=1),
            RandGaussianSmoothd(keys=['image'], sigma_x=(0.7, 0.7), prob=0.1),
            RandAffined(
                keys=['image', 'label'],
                prob=0.5,
                translate_range=(0.1, 0.1),
                scale_range=(0.9, 1.1),
            ),
        ]

    elif level == 'heavy':
        level_aug = [
            RandRotate90d(keys=['image', 'label'], prob=1, spatial_axes=(0, 1)),
            RandFlipd(keys=['image', 'label'], prob=0.7, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.7, spatial_axis=1),
            RandGaussianSmoothd(keys=['image'], sigma_x=(1.2, 1.2), prob=0.3),
            RandAffined(
                keys=['image', 'label'],
                prob=0.5,
                translate_range=(0.2, 0.2),
                scale_range=(0.8, 1.4),
                shear_range=(-15, 15),
            ),
        ]

    return Compose(baseline_aug + resize_aug + level_aug)
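
A quick way to see what a given level produces is to build the pipeline and list its transforms. This sketch assumes MONAI is installed and a target volume shape of (64, 64, 64):

```python
import qim3d

augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'moderate')

# Build the pipeline used for the training set
train_transform = augmentation.augment((64, 64, 64), level = augmentation.transform_train)

# Compose stores the individual MONAI transforms in order:
# ToTensor, CenterSpatialCropd, RandRotate90d, RandFlipd, ...
for t in train_transform.transforms:
    print(type(t).__name__)
```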

qim3d.ml.Hyperparameters

Hyperparameters for training the 3D UNet model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `torch.nn.Module` | PyTorch model. | required |
| `n_epochs` | `int` | Number of training epochs. | `10` |
| `learning_rate` | `float` | Learning rate for the optimizer. | `1e-3` |
| `optimizer` | `str` | Optimizer algorithm. Must be one of `'Adam'`, `'SGD'`, or `'RMSprop'`. | `'Adam'` |
| `momentum` | `float` | Momentum value for the SGD and RMSprop optimizers. | `0` |
| `weight_decay` | `float` | Weight decay (L2 penalty) for the optimizer. | `0` |
| `loss_function` | `str` | Loss function criterion. Must be one of `'BCE'`, `'Dice'`, `'Focal'`, or `'DiceCE'`. | `'Focal'` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `hyperparameters` | `dict` | Dictionary of hyperparameters. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `loss_function` is not one of `'BCE'`, `'Dice'`, `'Focal'`, or `'DiceCE'`. |
| `ValueError` | If `optimizer` is not one of `'Adam'`, `'SGD'`, or `'RMSprop'`. |

Example

```python
import qim3d

# Set up the model and hyperparameters
model = qim3d.ml.UNet(size = 'small')

hyperparameters = qim3d.ml.Hyperparameters(
    model = model,
    n_epochs = 10,
    learning_rate = 5e-3,
    loss_function = 'DiceCE',
    weight_decay = 1e-3
    )

# Retrieve the hyperparameters
params_dict = hyperparameters()

optimizer = params_dict['optimizer']
criterion = params_dict['criterion']
n_epochs  = params_dict['n_epochs']
```
Source code in qim3d/ml/models/_unet.py
class Hyperparameters:
    """
    Hyperparameters for training the 3D UNet model.

    Args:
        model (torch.nn.Module): PyTorch model.
        n_epochs (int, optional): Number of training epochs. Default is 10.
        learning_rate (float, optional): Learning rate for the optimizer. Default is 1e-3.
        optimizer (str, optional): Optimizer algorithm. Must be one of 'Adam', 'SGD', 'RMSprop'. Default is 'Adam'.
        momentum (float, optional): Momentum value for SGD and RMSprop optimizers. Default is 0.
        weight_decay (float, optional): Weight decay (L2 penalty) for the optimizer. Default is 0.
        loss_function (str, optional): Loss function criterion. Must be one of 'BCE', 'Dice', 'Focal', 'DiceCE'. Default is 'Focal'.

    Returns:
        hyperparameters (dict): Dictionary of hyperparameters.

    Raises:
        ValueError: If `loss_function` is not one of 'BCE', 'Dice', 'Focal', 'DiceCE'.
        ValueError: If `optimizer` is not one of 'Adam', 'SGD', 'RMSprop'.

    Example:
        ```python
        import qim3d

        # Set up the model and hyperparameters
        model = qim3d.ml.UNet(size = 'small')

        hyperparameters = qim3d.ml.Hyperparameters(
            model = model,
            n_epochs = 10,
            learning_rate = 5e-3,
            loss_function = 'DiceCE',
            weight_decay  = 1e-3
            )

        # Retrieve the hyperparameters
        params_dict = hyperparameters()

        optimizer = params_dict['optimizer']
        criterion = params_dict['criterion']
        n_epochs  = params_dict['n_epochs']
        ```

    """

    def __init__(
        self,
        model: nn.Module,
        n_epochs: int = 10,
        learning_rate: float = 1e-3,
        optimizer: str = 'Adam',
        momentum: float = 0,
        weight_decay: float = 0,
        loss_function: str = 'Focal',
    ):
        # TODO: Implement custom loss_functions? Then add a check to see if loss works for segmentation.
        if loss_function not in ['BCE', 'Dice', 'Focal', 'DiceCE']:
            msg = f'Invalid loss function: {loss_function}. Loss criterion must be one of the following: "BCE", "Dice", "Focal", "DiceCE".'
            raise ValueError(msg)

        # TODO: Implement custom optimizer? And add check to see if valid.
        if optimizer not in ['Adam', 'SGD', 'RMSprop']:
            msg = f'Invalid optimizer: {optimizer}. Optimizer must be one of the following: "Adam", "SGD", "RMSprop".'
            raise ValueError(msg)

        if (momentum != 0) and optimizer == 'Adam':
            log.info(
                "Momentum isn't an input in the 'Adam' optimizer. "
                "Change optimizer to 'SGD' or 'RMSprop' to use momentum."
            )

        self.model = model
        self.n_epochs = n_epochs
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.loss_function = loss_function

    def __call__(self):
        return self.model_params(
            self.model,
            self.n_epochs,
            self.optimizer,
            self.learning_rate,
            self.weight_decay,
            self.momentum,
            self.loss_function,
        )

    def model_params(
        self,
        model: nn.Module,
        n_epochs: int,
        optimizer: str,
        learning_rate: float,
        weight_decay: float,
        momentum: float,
        loss_function: str,
    ) -> dict:
        optim = self._optimizer(model, optimizer, learning_rate, weight_decay, momentum)
        criterion = self._loss_functions(loss_function)

        hyper_dict = {
            'optimizer': optim,
            'criterion': criterion,
            'n_epochs': n_epochs,
        }
        return hyper_dict

    # Selecting the optimizer
    def _optimizer(
        self,
        model: nn.Module,
        optimizer: str,
        learning_rate: float,
        weight_decay: float,
        momentum: float,
    ) -> torch.optim.Optimizer:
        from torch.optim import SGD, Adam, RMSprop

        if optimizer == 'Adam':
            optim = Adam(
                model.parameters(), lr=learning_rate, weight_decay=weight_decay
            )
        elif optimizer == 'SGD':
            optim = SGD(
                model.parameters(),
                lr=learning_rate,
                momentum=momentum,
                weight_decay=weight_decay,
            )
        elif optimizer == 'RMSprop':
            optim = RMSprop(
                model.parameters(),
                lr=learning_rate,
                weight_decay=weight_decay,
                momentum=momentum,
            )
        return optim

    # Selecting the loss function
    def _loss_functions(self, loss_function: str) -> torch.nn:
        from monai.losses import DiceCELoss, DiceLoss, FocalLoss
        from torch.nn import BCEWithLogitsLoss

        if loss_function == 'BCE':
            criterion = BCEWithLogitsLoss(reduction='mean')
        elif loss_function == 'Dice':
            criterion = DiceLoss(sigmoid=True, reduction='mean')
        elif loss_function == 'Focal':
            criterion = FocalLoss(reduction='mean')
        elif loss_function == 'DiceCE':
            criterion = DiceCELoss(sigmoid=True, reduction='mean')
        return criterion
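
Because the `Adam` branch in `_optimizer` ignores momentum (as the log message above points out), momentum-based training requires `'SGD'` or `'RMSprop'`. A short sketch of such a configuration:

```python
import qim3d

model = qim3d.ml.models.UNet(size = 'small')

# Momentum is only forwarded to the 'SGD' and 'RMSprop' optimizers
hyperparameters = qim3d.ml.Hyperparameters(
    model = model,
    optimizer = 'SGD',
    momentum = 0.9,
    loss_function = 'Dice',
    )

params_dict = hyperparameters()
print(params_dict['optimizer'])  # torch.optim.SGD instance
print(params_dict['criterion'])  # monai.losses.DiceLoss instance
```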

qim3d.ml.prepare_datasets

prepare_datasets(path, val_fraction, model, augmentation)

Splits and augments the train/validation/test datasets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `str` | Path to the dataset. | required |
| `val_fraction` | `float` | Fraction of the data for the validation set. | required |
| `model` | `torch.nn.Module` | PyTorch model. | required |
| `augmentation` | `Augmentation` | Augmentation class for the dataset with predefined augmentation levels. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `train_set` | `torch.utils.data.Subset` | Training dataset. |
| `val_set` | `torch.utils.data.Subset` | Validation dataset. |
| `test_set` | `torch.utils.data.Subset` | Testing dataset. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the validation fraction is not a float between 0 and 1. |

Example

```python
import qim3d

base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')

# Set up datasets
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path = base_path,
    val_fraction = 0.5,
    model = model,
    augmentation = augmentation
    )
```
Source code in qim3d/ml/_data.py
def prepare_datasets(
    path: str,
    val_fraction: float,
    model: torch.nn.Module,
    augmentation: Augmentation,
) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset, torch.utils.data.Subset]:
    """
    Splits and augments the train/validation/test datasets.

    Args:
        path (str): Path to the dataset.
        val_fraction (float): Fraction of the data for the validation set.
        model (torch.nn.Module): PyTorch model.
        augmentation (Augmentation): Augmentation class for the dataset with predefined augmentation levels.

    Returns:
        train_set (torch.utils.data.Subset): Training dataset.
        val_set (torch.utils.data.Subset): Validation dataset.
        test_set (torch.utils.data.Subset): Testing dataset.

    Raises:
        ValueError: If the validation fraction is not a float between 0 and 1.

    Example:
        ```python
        import qim3d

        base_path = "C:/dataset/"
        model = qim3d.ml.models.UNet(size = 'small')
        augmentation =  qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')

        # Set up datasets
        train_set, val_set, test_set = qim3d.ml.prepare_datasets(
            path = base_path,
            val_fraction = 0.5,
            model = model,
            augmentation = augmentation
            )
        ```

    """

    if not isinstance(val_fraction, float) or not (0 <= val_fraction < 1):
        msg = 'The validation fraction must be a float between 0 and 1.'
        raise ValueError(msg)

    resize = augmentation.resize
    n_channels = len(model.channels)

    # Get the first image to check the shape
    im_path = Path(path) / 'train'
    first_img = sorted((im_path / 'images').iterdir())[0]

    # Load 3D volume
    image = qim3d.io.load(first_img)
    orig_shape = image.shape

    final_shape = check_resize(orig_shape, resize, n_channels)

    train_set = Dataset(
        root_path=path,
        transform=augmentation.augment(final_shape, level=augmentation.transform_train),
    )
    val_set = Dataset(
        root_path=path,
        transform=augmentation.augment(
            final_shape, level=augmentation.transform_validation
        ),
    )
    test_set = Dataset(
        root_path=path,
        split='test',
        transform=augmentation.augment(final_shape, level=augmentation.transform_test),
    )

    split_idx = int(np.floor(val_fraction * len(train_set)))
    indices = torch.randperm(len(train_set))

    train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
    val_set = torch.utils.data.Subset(val_set, indices[:split_idx])

    return train_set, val_set, test_set
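
The function reads the first volume from `<path>/train/images` to infer the input shape, and builds the test set with `split='test'`, so the dataset directory must follow a fixed layout. The sketch below illustrates a plausible structure; the folder names beyond `train/images` are assumptions, since the `Dataset` class itself is not shown here:

```python
from pathlib import Path

# Assumed layout (only train/images is confirmed by the source above):
#
# dataset/
# ├── train/
# │   ├── images/   # training volumes
# │   └── labels/   # corresponding ground-truth masks
# └── test/
#     ├── images/
#     └── labels/

base_path = Path("C:/dataset/")
first_img = sorted((base_path / 'train' / 'images').iterdir())[0]
print(first_img)  # the volume used to infer the input shape
```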

qim3d.ml.prepare_dataloaders

prepare_dataloaders(train_set, val_set, test_set, batch_size, shuffle_train=True, num_workers=8, pin_memory=False)

Prepares the dataloaders for model training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `train_set` | `torch.utils.data.Dataset` | Training dataset. | required |
| `val_set` | `torch.utils.data.Dataset` | Validation dataset. | required |
| `test_set` | `torch.utils.data.Dataset` | Testing dataset. | required |
| `batch_size` | `int` | Number of samples per batch. | required |
| `shuffle_train` | `bool` | If True, shuffles the training data between epochs. | `True` |
| `num_workers` | `int` | Number of worker processes used for data loading. | `8` |
| `pin_memory` | `bool` | If True, uses pinned (page-locked) host memory, which speeds up transfers to CUDA devices. | `False` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `train_loader` | `torch.utils.data.DataLoader` | Training dataloader. |
| `val_loader` | `torch.utils.data.DataLoader` | Validation dataloader. |
| `test_loader` | `torch.utils.data.DataLoader` | Testing dataloader. |

Example

```python
import qim3d

base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')

# Set up datasets
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path = base_path,
    val_fraction = 0.5,
    model = model,
    augmentation = augmentation
    )

# Set up dataloaders
train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set = train_set,
    val_set = val_set,
    test_set = test_set,
    batch_size = 1,
    )
```
Source code in qim3d/ml/_data.py
def prepare_dataloaders(
    train_set: torch.utils.data,
    val_set: torch.utils.data,
    test_set: torch.utils.data,
    batch_size: int,
    shuffle_train: bool = True,
    num_workers: int = 8,
    pin_memory: bool = False,
) -> tuple[
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
]:
    """
    Prepares the dataloaders for model training.

    Args:
        train_set (torch.utils.data): Training dataset.
        val_set (torch.utils.data): Validation dataset.
        test_set (torch.utils.data): Testing dataset.
        batch_size (int): Number of samples per batch.
        shuffle_train (bool, optional): If True, shuffles the training data between epochs. Default is True.
        num_workers (int, optional): Number of worker processes used for data loading. Default is 8.
        pin_memory (bool, optional): If True, uses pinned (page-locked) host memory to speed up transfers to CUDA devices. Default is False.

    Returns:
        train_loader (torch.utils.data.DataLoader): Training dataloader.
        val_loader (torch.utils.data.DataLoader): Validation dataloader.
        test_loader (torch.utils.data.DataLoader): Testing dataloader.

    Example:
        ```python
        import qim3d

        base_path = "C:/dataset/"
        model = qim3d.ml.models.UNet(size = 'small')
        augmentation =  qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')

        # Set up datasets
        train_set, val_set, test_set = qim3d.ml.prepare_datasets(
            path = base_path,
            val_fraction = 0.5,
            model = model,
            augmentation = augmentation
            )

        # Set up dataloaders
        train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
            train_set = train_set,
            val_set = val_set,
            test_set = test_set,
            batch_size = 1,
            )
        ```

    """
    from torch.utils.data import DataLoader

    train_loader = DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=shuffle_train,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        dataset=val_set,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader, test_loader
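
Pulling a single batch is a quick sanity check on the resulting loaders. This sketch assumes the loaders from the example above and that each dataset item is an (image, label) pair of tensors, which is also what `model_summary` below relies on:

```python
# Fetch one batch from the training loader
images, labels = next(iter(train_loader))

# With batch_size = 1 and single-channel volumes: (1, 1, D, H, W)
print(images.shape, labels.shape)
```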

qim3d.ml.model_summary

model_summary(model, dataloader)

Generates a summary of a PyTorch model's architecture.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `torch.nn.Module` | The PyTorch model to summarize. | required |
| `dataloader` | `torch.utils.data.DataLoader` | The data loader used to determine the input shape. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `summary` | `ModelStatistics` | Summary of the model architecture. |

Example

```python
import qim3d

base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')

# Set up datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path = base_path,
    val_fraction = 0.5,
    model = model,
    augmentation = augmentation
    )

train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set = train_set,
    val_set = val_set,
    test_set = test_set,
    batch_size = 1,
    )

# Get model summary
summary = qim3d.ml.model_summary(model, train_loader)
print(summary)
```
Source code in qim3d/ml/_ml_utils.py
def model_summary(
    model: torch.nn.Module, dataloader: torch.utils.data.DataLoader
) -> ModelStatistics:
    """
    Generates a summary of a PyTorch model's architecture.

    Args:
        model (torch.nn.Module): The PyTorch model to summarize.
        dataloader (torch.utils.data.DataLoader): The data loader used to determine the input shape.

    Returns:
        summary (ModelStatistics): Summary of the model architecture.

    Example:
        ```python
        import qim3d

        base_path = "C:/dataset/"
        model = qim3d.ml.models.UNet(size = 'small')
        augmentation =  qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')

        # Set up datasets and dataloaders
        train_set, val_set, test_set = qim3d.ml.prepare_datasets(
            path = base_path,
            val_fraction = 0.5,
            model = model,
            augmentation = augmentation
            )

        train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
            train_set = train_set,
            val_set = val_set,
            test_set = test_set,
            batch_size = 1,
            )

        # Get model summary
        summary = qim3d.ml.model_summary(model, train_loader)
        print(summary)
        ```

    """
    images, _ = next(iter(dataloader))
    batch_size = tuple(images.shape)
    model_s = summary(model, batch_size, depth=torch.inf)

    return model_s

qim3d.ml.train_model

train_model(model, hyperparameters, train_loader, val_loader, checkpoint_directory=None, eval_every=1, print_every=1, plot=True, return_loss=False)

Trains the specified model.

The function trains the model using the data from the training and validation data loaders, according to the specified hyperparameters. Optionally, the final checkpoint of the trained model is saved as a .pth file, the loss curves are plotted, and the loss values are returned.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `torch.nn.Module` | PyTorch model. | required |
| `hyperparameters` | `Hyperparameters` | Hyperparameters instance providing `n_epochs`, `optimizer` and `criterion`. | required |
| `train_loader` | `torch.utils.data.DataLoader` | DataLoader for the training data. | required |
| `val_loader` | `torch.utils.data.DataLoader` | DataLoader for the validation data. | required |
| `checkpoint_directory` | `str` | Directory to save the model checkpoint. | `None` |
| `eval_every` | `int` | Frequency of model evaluation. Default is every epoch. | `1` |
| `print_every` | `int` | Frequency of logging the model performance. Default is every epoch. | `1` |
| `plot` | `bool` | If True, plots the training and validation loss after the model is done training. | `True` |
| `return_loss` | `bool` | If True, returns a dictionary with the history of the train and validation losses. | `False` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `train_loss` | `dict` | Dictionary with average losses and batch losses for the training loop. Only returned when `return_loss = True`. |
| `val_loss` | `dict` | Dictionary with average losses and batch losses for the validation loop. Only returned when `return_loss = True`. |

Example

```python
import qim3d

base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs = 10)

# Set up datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path = base_path,
    val_fraction = 0.5,
    model = model,
    augmentation = augmentation
    )

train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set = train_set,
    val_set = val_set,
    test_set = test_set,
    batch_size = 1,
    )

# Train model
qim3d.ml.train_model(
    model = model,
    hyperparameters = hyperparameters,
    train_loader = train_loader,
    val_loader = val_loader,
    checkpoint_directory = base_path,
    plot = True)
```
Source code in qim3d/ml/_ml_utils.py
def train_model(
    model: torch.nn.Module,
    hyperparameters: Hyperparameters,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    checkpoint_directory: str = None,
    eval_every: int = 1,
    print_every: int = 1,
    plot: bool = True,
    return_loss: bool = False,
) -> tuple[tuple[float], tuple[float]]:
    """
    Trains the specified model.

    The function trains the model using the data from the training and validation data loaders, according to the specified hyperparameters.
    Optionally, the final checkpoint of the trained model is saved as a .pth file, the loss curves are plotted, and the loss values are returned.

    Args:
        model (torch.nn.Module): PyTorch model.
        hyperparameters (Hyperparameters): Hyperparameters instance providing n_epochs, optimizer and criterion.
        train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
        checkpoint_directory (str, optional): Directory to save model checkpoint. Default is None.
        eval_every (int, optional): Frequency of model evaluation. Default is every epoch.
        print_every (int, optional): Frequency of logging the model performance. Default is every epoch.
        plot (bool, optional): If True, plots the training and validation loss after the model is done training. Default is True.
        return_loss (bool, optional): If True, returns a dictionary with the history of the train and validation losses. Default is False.

    Returns:
        train_loss (dict): Dictionary with average losses and batch losses for training loop. Only returned when `return_loss = True`.
        val_loss (dict): Dictionary with average losses and batch losses for validation loop. Only returned when `return_loss = True`.

    Example:
        ```python
        import qim3d

        base_path = "C:/dataset/"
        model = qim3d.ml.models.UNet(size = 'small')
        augmentation =  qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
        hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs = 10)

        # Set up datasets and dataloaders
        train_set, val_set, test_set = qim3d.ml.prepare_datasets(
            path = base_path,
            val_fraction = 0.5,
            model = model,
            augmentation = augmentation
            )

        train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
            train_set = train_set,
            val_set = val_set,
            test_set = test_set,
            batch_size = 1,
            )

        # Train model
        qim3d.ml.train_model(
            model = model,
            hyperparameters = hyperparameters,
            train_loader = train_loader,
            val_loader = val_loader,
            checkpoint_directory = base_path,
            plot = True)
        ```

    """
    # Get hyperparameters
    params_dict = hyperparameters()

    n_epochs = params_dict['n_epochs']
    optimizer = params_dict['optimizer']
    criterion = params_dict['criterion']

    # Choosing best device available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)

    # Avoid logging twice
    log.propagate = False

    # Set up dictionaries to store training and validation losses
    train_loss = {'loss': [], 'batch_loss': []}
    val_loss = {'loss': [], 'batch_loss': []}

    with logging_redirect_tqdm():
        for epoch in tqdm(range(n_epochs), desc='Training epochs', unit='epoch'):
            epoch_loss = 0
            step = 0

            model.train()

            for data in train_loader:
                inputs, targets = data
                inputs = inputs.to(device)
                targets = targets.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)

                loss = criterion(outputs, targets)

                # Backpropagation
                loss.backward()
                optimizer.step()

                epoch_loss += loss.detach().item()
                step += 1

                # Log and store batch training loss
                train_loss['batch_loss'].append(loss.detach().item())

            # Log and store average training loss per epoch
            epoch_loss = epoch_loss / step
            train_loss['loss'].append(epoch_loss)

            if epoch % eval_every == 0:
                eval_loss = 0
                step = 0

                model.eval()

                for data in val_loader:
                    inputs, targets = data
                    inputs = inputs.to(device)
                    targets = targets.to(device)

                    with torch.no_grad():
                        outputs = model(inputs)
                        loss = criterion(outputs, targets)

                    eval_loss += loss.item()
                    step += 1

                    # Log and store batch validation loss
                    val_loss['batch_loss'].append(loss.item())

                # Log and store average validation loss
                eval_loss = eval_loss / step
                val_loss['loss'].append(eval_loss)

                if epoch % print_every == 0:
                    log.info(
                        f"Epoch {epoch: 3}, train loss: {train_loss['loss'][epoch]:.4f}, "
                        f"val loss: {val_loss['loss'][epoch]:.4f}"
                    )

    if checkpoint_directory:
        checkpoint_filename = f'model_{n_epochs}epochs.pth'
        checkpoint_path = os.path.join(checkpoint_directory, checkpoint_filename)

        # Save model checkpoint to .pth file
        torch.save(model.state_dict(), checkpoint_path)
        log.info(f'Model checkpoint saved at: {checkpoint_path}')

    if plot:
        plot_metrics(train_loss, val_loss, labels=['Train', 'Valid.'], show=True)

    if return_loss:
        return train_loss, val_loss
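
When `return_loss = True`, the two returned dictionaries mirror the structure built in the source: `'loss'` holds one averaged value per epoch and `'batch_loss'` one value per batch. This makes custom plotting straightforward; a sketch assuming matplotlib is available and that `model`, `hyperparameters` and the loaders are set up as in the example above:

```python
import matplotlib.pyplot as plt

train_loss, val_loss = qim3d.ml.train_model(
    model = model,
    hyperparameters = hyperparameters,
    train_loader = train_loader,
    val_loader = val_loader,
    plot = False,
    return_loss = True,
    )

plt.plot(train_loss['loss'], label = 'train')
plt.plot(val_loss['loss'], label = 'validation')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
```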

qim3d.ml.load_checkpoint

load_checkpoint(model, checkpoint_path)

Loads a trained model checkpoint from a .pth file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `torch.nn.Module` | The PyTorch model to load the checkpoint into. | required |
| `checkpoint_path` | `str` | The path to the model checkpoint .pth file. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `model` | `torch.nn.Module` | The model with the loaded checkpoint. |

Example

```python
import qim3d

# Instantiate model architecture
model = qim3d.ml.models.UNet(size = 'small')
checkpoint_path = "C:/dataset/model_10epochs.pth"

# Load checkpoint into model
model = qim3d.ml.load_checkpoint(model, checkpoint_path)
```
Source code in qim3d/ml/_ml_utils.py
def load_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    """
    Loads a trained model checkpoint from a .pth file.

    Args:
        model (torch.nn.Module): The PyTorch model to load the checkpoint into.
        checkpoint_path (str): The path to the model checkpoint .pth file.

    Returns:
        model (torch.nn.Module): The model with the loaded checkpoint.

    Example:
        ```python
        import qim3d

        # Instantiate model architecture
        model = qim3d.ml.models.UNet(size = 'small')
        checkpoint_path = "C:/dataset/model_10epochs.pth"

        # Load checkpoint into model
        model = qim3d.ml.load_checkpoint(model, checkpoint_path)
        ```

    """
    model.load_state_dict(torch.load(checkpoint_path))
    log.info(f'Model checkpoint loaded from: {checkpoint_path}')

    return model
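
Note that `load_checkpoint` calls `torch.load` without a `map_location`, so a checkpoint saved on a GPU machine may fail to load on a CPU-only one. A hypothetical workaround using plain PyTorch, outside of the qim3d API:

```python
import torch

# Remap GPU-saved weights to the CPU before loading them into the model
state_dict = torch.load("C:/dataset/model_10epochs.pth", map_location = 'cpu')
model.load_state_dict(state_dict)
```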

qim3d.ml.test_model

test_model(model, test_set, threshold=0.5)

Performs inference on input data using the specified model.

The test set should yield (image, label) pairs of tensors. The function checks the validity of the input data, puts the model in evaluation mode, and generates predictions using the model. The input images, target labels, and predicted labels are returned as a list of (volume, target, pred) tuples.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `torch.nn.Module` | The trained model used for predicting segmentations. | required |
| `test_set` | `torch.utils.data.Dataset` | A test dataset containing input images and ground truth label data. | required |
| `threshold` | `float` | The threshold value used to binarize the model predictions. | `0.5` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `results` | `list` | List of (volume, target, pred) tuples containing the input images, target labels, and predicted labels. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the data items do not consist of tensors. |

Notes

- The function assumes that the model is not already in evaluation mode; `model.eval()` is called internally.
Example

```python
import qim3d

base_path = "C:/dataset/"
model = qim3d.ml.models.UNet(size = 'small')
augmentation = qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs = 10)

# Set up datasets and dataloaders
train_set, val_set, test_set = qim3d.ml.prepare_datasets(
    path = base_path,
    val_fraction = 0.5,
    model = model,
    augmentation = augmentation
    )

train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
    train_set = train_set,
    val_set = val_set,
    test_set = test_set,
    batch_size = 1,
    )

# Train model
qim3d.ml.train_model(
    model = model,
    hyperparameters = hyperparameters,
    train_loader = train_loader,
    val_loader = val_loader,
    plot = True)

# Test model
results = qim3d.ml.test_model(
    model = model,
    test_set = test_set
    )

# Get the result of the first test image
volume, target, pred = results[0]
qim3d.viz.slices_grid(pred, num_slices = 5)
```
Source code in qim3d/ml/_ml_utils.py
def test_model(
    model: torch.nn.Module,
    test_set: torch.utils.data.Dataset,
    threshold: float = 0.5,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Performs inference on input data using the specified model.

    The test set should yield (image, label) pairs of tensors.
    The function checks the validity of the input data, puts the model in evaluation mode,
    and generates predictions using the model. The input images, target labels, and predicted labels
    are returned as a list of (volume, target, pred) tuples.

    Args:
        model (torch.nn.Module): The trained model used for predicting segmentations.
        test_set (torch.utils.data.Dataset): A test dataset containing input images and ground truth label data.
        threshold (float): The threshold value used to binarize the model predictions.

    Returns:
        results (list): List of tuples (volume, target, pred) containing the input images, target labels, and predicted labels.

    Raises:
        ValueError: If the data items do not consist of tensors.

    Notes:
        - The function assumes that the model is not already in evaluation mode (`model.eval()`).

    Example:
        ```python
        import qim3d

        base_path = "C:/dataset/"
        model = qim3d.ml.models.UNet(size = 'small')
        augmentation =  qim3d.ml.Augmentation(resize = 'crop', transform_train = 'light')
        hyperparameters = qim3d.ml.Hyperparameters(model, n_epochs = 10)

        # Set up datasets and dataloaders
        train_set, val_set, test_set = qim3d.ml.prepare_datasets(
            path = base_path,
            val_fraction = 0.5,
            model = model,
            augmentation = augmentation
            )

        train_loader, val_loader, test_loader = qim3d.ml.prepare_dataloaders(
            train_set = train_set,
            val_set = val_set,
            test_set = test_set,
            batch_size = 1,
            )

        # Train model
        qim3d.ml.train_model(
            model = model,
            hyperparameters = hyperparameters,
            train_loader = train_loader,
            val_loader = val_loader,
            plot = True)

        # Test model
        results = qim3d.ml.test_model(
            model = model,
            test_set = test_set
            )

        # Get the result of the first test image
        volume, target, pred = results[0]
        qim3d.viz.slices_grid(pred, num_slices = 5)
        ```

    """
    # Set model to evaluation mode
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    # List to store results
    results = []

    for volume, target in test_set:
        if not isinstance(volume, torch.Tensor) or not isinstance(target, torch.Tensor):
            msg = 'Data items must consist of tensors'
            raise ValueError(msg)

        # Add batch and channel dimensions
        volume = volume.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]
        target = target.unsqueeze(0).to(device)  # Shape: [1, 1, D, H, W]

        with torch.no_grad():
            # Get model predictions (logits)
            output = model(volume)

            # Convert logits to probabilities [0, 1]
            pred = torch.sigmoid(output)

            # Convert to binary mask by thresholding the probabilities
            pred = (pred > threshold).float()

            # Remove batch and channel dimensions
            volume = volume.squeeze().cpu().numpy()
            target = target.squeeze().cpu().numpy()
            pred = pred.squeeze().cpu().numpy()

        # TODO: Compute DICE score between target and prediction?

        # Append results to list
        results.append((volume, target, pred))

    return results
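
The source leaves the Dice score as a TODO. Since `test_model` returns binary masks as NumPy arrays, one could compute it per test volume as sketched below; the `dice_score` helper is hypothetical and not part of qim3d:

```python
import numpy as np

def dice_score(target: np.ndarray, pred: np.ndarray, eps: float = 1e-8) -> float:
    """Dice coefficient between two binary masks."""
    intersection = np.sum(target * pred)
    return (2.0 * intersection + eps) / (np.sum(target) + np.sum(pred) + eps)

for i, (volume, target, pred) in enumerate(results):
    print(f'Volume {i}: Dice = {dice_score(target, pred):.3f}')
```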