
Data visualization

The qim3d library aims to provide easy ways to explore and get insights from volumetric data.

qim3d.viz

qim3d.viz.histogram

histogram(volume, bins='auto', slice_idx=None, vertical_line=None, axis=0, kde=True, log_scale=False, despine=True, show_title=True, color='qim3d', edgecolor=None, figsize=(8, 4.5), element='step', return_fig=False, show=True, ax=None, **sns_kwargs)

Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.

Utilizes seaborn.histplot for visualization.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| volume | ndarray | A 3D NumPy array representing the volume to be visualized. | required |
| bins | Union[int, str] | Number of histogram bins or a binning strategy (e.g., "auto"). | 'auto' |
| axis | int | Axis along which to take a slice. | 0 |
| slice_idx | Union[int, str] | The slice to visualize. An integer is the slice index along `axis`; "middle" selects the middle slice; None uses the entire volume. | None |
| vertical_line | int | Intensity value at which a dashed red vertical line is drawn on the histogram. | None |
| kde | bool | Whether to overlay a kernel density estimate. | True |
| log_scale | bool | Whether to use a logarithmic scale on the y-axis. | False |
| despine | bool | If True, removes the top and right spines from the plot for a cleaner appearance. | True |
| show_title | bool | If True, displays a title with slice information. | True |
| color | str | Color of the histogram bars. "qim3d" selects the qim3d color. | 'qim3d' |
| edgecolor | str | Color of the histogram bar edges. | None |
| figsize | tuple | Figure size as (width, height). Only used when `ax` is None. | (8, 4.5) |
| element | str | Type of histogram to draw: 'bars', 'step', or 'poly'. | 'step' |
| return_fig | bool | If True, returns the Figure object. | False |
| show | bool | If True, calls plt.show() on a newly created figure. Has no effect when `ax` is provided. | True |
| ax | Axes | Existing Axes to draw on. If None, a new figure and Axes are created. | None |
| **sns_kwargs | Union[str, float, int, bool] | Additional keyword arguments passed to seaborn.histplot. | {} |
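According to the seaborn documentation, histplot passes `bins` on to `numpy.histogram_bin_edges`, so a string such as "auto" names one of NumPy's binning rules. The following NumPy-only sketch shows the bin edges and counts behind such a plot; the random volume is a stand-in assumption, not qim3d data.

```python
import numpy as np

# Stand-in volume: random float intensities (an assumption, not qim3d data)
rng = np.random.default_rng(0)
volume = rng.normal(loc=100.0, scale=20.0, size=(32, 32, 32))

data = volume.ravel()  # whole volume, as with slice_idx=None

# "auto" uses the smaller bin width of the Sturges and Freedman-Diaconis rules
edges = np.histogram_bin_edges(data, bins="auto")
counts, _ = np.histogram(data, bins=edges)

print(f"{len(edges) - 1} bins spanning [{edges[0]:.1f}, {edges[-1]:.1f}]")
print(f"total count: {counts.sum()}")  # equals the number of voxels
```

By default the edges span the minimum to the maximum of the data, so every voxel falls into exactly one bin.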

Returns:

| Type | Description |
|---|---|
| Optional[Union[Figure, Axes]] | If return_fig is True, returns the Figure (None when ax was provided, because no new figure is created). Otherwise, returns the Axes the histogram was drawn on. |

Raises:

| Type | Description |
|---|---|
| ValueError | If axis is not a valid axis index for the volume (0, 1, or 2 for a 3D volume). |
| ValueError | If slice_idx is an integer outside the range of the specified axis. |
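The source below selects the voxels to plot before any drawing happens. A standalone sketch of that selection and validation step follows; `select_histogram_data` is a hypothetical name used only for illustration, and the logic mirrors the listed code.

```python
import numpy as np


def select_histogram_data(volume, slice_idx=None, axis=0):
    """Hypothetical helper mirroring how histogram() picks the voxels to plot."""
    if not (0 <= axis < volume.ndim):
        raise ValueError(f"Axis must be an integer between 0 and {volume.ndim - 1}.")
    if slice_idx == "middle":
        slice_idx = volume.shape[axis] // 2
    if slice_idx is None:
        return volume.ravel()  # every voxel in the volume
    if not (0 <= slice_idx < volume.shape[axis]):
        raise ValueError(
            f"Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}."
        )
    # np.take drops the sliced axis, leaving a 2D image that is then flattened
    return np.take(volume, indices=slice_idx, axis=axis).ravel()


volume = np.arange(4 * 5 * 6).reshape(4, 5, 6)
print(select_histogram_data(volume).size)                    # 120 voxels
print(select_histogram_data(volume, "middle", axis=1).size)  # 4 * 6 = 24 voxels
```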

Source code in qim3d/viz/_data_exploration.py
def histogram(
    volume: np.ndarray,
    bins: Union[int, str] = 'auto',
    slice_idx: Union[int, str, None] = None,
    vertical_line: int = None,
    axis: int = 0,
    kde: bool = True,
    log_scale: bool = False,
    despine: bool = True,
    show_title: bool = True,
    color: str = 'qim3d',
    edgecolor: Optional[str] = None,
    figsize: Tuple[float, float] = (8, 4.5),
    element: str = 'step',
    return_fig: bool = False,
    show: bool = True,
    ax: Optional[plt.Axes] = None,
    **sns_kwargs: Union[str, float, int, bool],
) -> Optional[Union[plt.Figure, plt.Axes]]:
    """
    Plots a histogram of voxel intensities from a 3D volume, with options to show a specific slice or the entire volume.

    Utilizes [seaborn.histplot](https://seaborn.pydata.org/generated/seaborn.histplot.html) for visualization.

    Args:
        volume (np.ndarray): A 3D NumPy array representing the volume to be visualized.
        bins (Union[int, str], optional): Number of histogram bins or a binning strategy (e.g., "auto"). Default is "auto".
        axis (int, optional): Axis along which to take a slice. Default is 0.
        slice_idx (Union[int, str], optional): Specifies the slice to visualize. If an integer, it represents the slice index along the selected axis.
                                               If "middle", the function uses the middle slice. If None, the entire volume is visualized. Default is None.
        vertical_line (int, optional): Intensity value for a vertical line to be drawn on the histogram. Default is None.
        kde (bool, optional): Whether to overlay a kernel density estimate. Default is True.
        log_scale (bool, optional): Whether to use a logarithmic scale on the y-axis. Default is False.
        despine (bool, optional): If True, removes the top and right spines from the plot for cleaner appearance. Default is True.
        show_title (bool, optional): If True, displays a title with slice information. Default is True.
        color (str, optional): Color for the histogram bars. If "qim3d", defaults to the qim3d color. Default is "qim3d".
        edgecolor (str, optional): Color for the edges of the histogram bars. Default is None.
        figsize (tuple, optional): Size of the figure (width, height). Default is (8, 4.5).
        element (str, optional): Type of histogram to draw ('bars', 'step', or 'poly'). Default is "step".
        return_fig (bool, optional): If True, returns the figure object. Default is False.
        show (bool, optional): If True, calls `plt.show()` on a newly created figure; ignored when `ax` is provided. Default is True.
        ax (matplotlib.axes.Axes, optional): Axes object where the histogram will be plotted. Default is None.
        **sns_kwargs: Additional keyword arguments for `seaborn.histplot`.

    Returns:
        Optional[matplotlib.figure.Figure or matplotlib.axes.Axes]:
            If `return_fig` is True, returns the generated figure object (None if `ax` was provided,
            since no new figure is created). Otherwise, returns the `Axes` the histogram was drawn on.

    Raises:
        ValueError: If `axis` is not a valid axis index (0, 1, or 2).
        ValueError: If `slice_idx` is an integer and is out of range for the specified axis.

    """
    if not (0 <= axis < volume.ndim):
        raise ValueError(f'Axis must be an integer between 0 and {volume.ndim - 1}.')

    if slice_idx == 'middle':
        slice_idx = volume.shape[axis] // 2

    if slice_idx is not None:
        if 0 <= slice_idx < volume.shape[axis]:
            img_slice = np.take(volume, indices=slice_idx, axis=axis)
            data = img_slice.ravel()
            title = f'Intensity histogram of slice #{slice_idx} {img_slice.shape} along axis {axis}'
        else:
            raise ValueError(
                f'Slice index out of range. Must be between 0 and {volume.shape[axis] - 1}.'
            )
    else:
        data = volume.ravel()
        title = f'Intensity histogram for whole volume {volume.shape}'

    # Use provided Axes or create new figure
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = None

    if log_scale:
        ax.set_yscale('log')

    if color == 'qim3d':
        color = qim3d.viz.colormaps.qim(1.0)

    sns.histplot(
        data,
        bins=bins,
        kde=kde,
        color=color,
        element=element,
        edgecolor=edgecolor,
        ax=ax,  # Plot directly on the specified Axes
        **sns_kwargs,
    )

    if vertical_line is not None:
        ax.axvline(
            x=vertical_line,
            color='red',
            linestyle='--',
            linewidth=2,
        )

    if despine:
        sns.despine(
            fig=None,
            ax=ax,
            top=True,
            right=True,
            left=False,
            bottom=False,
            offset={'left': 0, 'bottom': 18},
            trim=True,
        )

    ax.set_xlabel('Voxel Intensity')
    ax.set_ylabel('Frequency')

    if show_title:
        ax.set_title(title, fontsize=10)

    # Handle show and return
    if show and fig is not None:
        plt.show()

    if return_fig:
        return fig
    elif ax is not None:
        return ax

qim3d.viz.slicer

slicer(volume, slice_axis=0, color_map='magma', value_min=None, value_max=None, image_height=3, image_width=3, display_positions=False, interpolation=None, image_size=None, color_bar=None, **matplotlib_imshow_kwargs)

Interactive widget for visualizing slices of a 3D volume.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| volume | ndarray | The 3D volume to be sliced. | required |
| slice_axis | int | Axis (dimension) along which to slice. | 0 |
| color_map | str or LinearSegmentedColormap | Colormap for the image. | 'magma' |
| value_min | float | Together with value_max, defines the data range covered by the colormap. By default, the colormap covers the full range. | None |
| value_max | float | Together with value_min, defines the data range covered by the colormap. By default, the colormap covers the full range. | None |
| image_height | int | Height of the figure, in inches. | 3 |
| image_width | int | Width of the figure, in inches. | 3 |
| display_positions | bool | If True, displays the position of the slice. | False |
| interpolation | str | Interpolation method for the image. | None |
| image_size | int | If set, overrides both image_height and image_width with this value. | None |
| color_bar | str | Controls the colorbar. None: no colorbar. 'volume': the colormap range stays constant across slices. 'slices': the colormap range is recomputed for each slice. | None |
| **matplotlib_imshow_kwargs | Any | Additional keyword arguments passed through slices_grid to matplotlib's imshow. | {} |
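With color_bar='slices', the source code precomputes every slice's minimum and maximum by reducing over all axes except slice_axis, so moving the slider only needs an array lookup. A small NumPy sketch of that reduction follows; the random volume is a stand-in assumption.

```python
import numpy as np

rng = np.random.default_rng(1)
volume = rng.normal(size=(6, 7, 8))  # stand-in volume
slice_axis = 1

# Reduce over all axes except slice_axis: one min and one max per slice
non_slice_axes = tuple(i for i in range(volume.ndim) if i != slice_axis)
slice_mins = np.min(volume, axis=non_slice_axes)
slice_maxs = np.max(volume, axis=non_slice_axes)

# Moving the slider to position k then costs only two lookups
k = 3
print(slice_mins[k], slice_maxs[k])
```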

Returns:

| Name | Type | Description |
|---|---|---|
| slicer_obj | interactive | The interactive widget for visualizing slices of a 3D volume. |

Example

import qim3d

vol = qim3d.examples.bone_128x128x128
qim3d.viz.slicer(vol)
[Figure: viz slicer]

Source code in qim3d/viz/_data_exploration.py
def slicer(
    volume: np.ndarray,
    slice_axis: int = 0,
    color_map: str = 'magma',
    value_min: float = None,
    value_max: float = None,
    image_height: int = 3,
    image_width: int = 3,
    display_positions: bool = False,
    interpolation: Optional[str] = None,
    image_size: int = None,
    color_bar: str = None,
    **matplotlib_imshow_kwargs,
) -> widgets.interactive:
    """
    Interactive widget for visualizing slices of a 3D volume.

    Args:
        volume (np.ndarray): The 3D volume to be sliced.
        slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
        color_map (str or matplotlib.colors.LinearSegmentedColormap, optional): Specifies the color map for the image. Defaults to 'magma'.
        value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
        value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
        image_height (int, optional): Height of the figure. Defaults to 3.
        image_width (int, optional): Width of the figure. Defaults to 3.
        display_positions (bool, optional): If True, displays the position of the slices. Defaults to False.
        interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
        color_bar (str, optional): Controls the options for color bar. If None, no color bar is included. If 'volume', the color map range is constant for each slice. If 'slices', the color map range changes dynamically according to the slice. Defaults to None.

    Returns:
        slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.bone_128x128x128
        qim3d.viz.slicer(vol)
        ```
        ![viz slicer](../../assets/screenshots/viz-slicer.gif)

    """

    if image_size:
        image_height = image_size
        image_width = image_size

    color_bar_options = [None, 'slices', 'volume']
    if color_bar not in color_bar_options:
        raise ValueError(
            f"Unrecognized value '{color_bar}' for parameter color_bar. "
            f'Expected one of {color_bar_options}.'
        )
    show_color_bar = color_bar is not None
    if color_bar == 'slices':
        # Precompute the minimum and maximum along each slice for faster widget sliding.
        non_slice_axes = tuple(i for i in range(volume.ndim) if i != slice_axis)
        slice_mins = np.min(volume, axis=non_slice_axes)
        slice_maxs = np.max(volume, axis=non_slice_axes)

    # Create the interactive widget
    def _slicer(slice_positions):
        if color_bar == 'slices':
            dynamic_min = slice_mins[slice_positions]
            dynamic_max = slice_maxs[slice_positions]
        else:
            dynamic_min = value_min
            dynamic_max = value_max

        fig = slices_grid(
            volume,
            slice_axis=slice_axis,
            color_map=color_map,
            value_min=dynamic_min,
            value_max=dynamic_max,
            image_height=image_height,
            image_width=image_width,
            display_positions=display_positions,
            interpolation=interpolation,
            slice_positions=slice_positions,
            num_slices=1,
            display_figure=True,
            color_bar=show_color_bar,
            **matplotlib_imshow_kwargs,
        )
        return fig

    position_slider = widgets.IntSlider(
        value=volume.shape[slice_axis] // 2,
        min=0,
        max=volume.shape[slice_axis] - 1,
        description='Slice',
        continuous_update=True,
    )
    slicer_obj = widgets.interactive(_slicer, slice_positions=position_slider)
    slicer_obj.layout = widgets.Layout(align_items='flex-start')

    return slicer_obj

qim3d.viz.slices_grid

slices_grid(volume, slice_axis=0, slice_positions=None, num_slices=15, max_columns=5, color_map='magma', value_min=None, value_max=None, image_size=None, image_height=2, image_width=2, display_figure=False, display_positions=True, interpolation=None, color_bar=False, color_bar_style='small', **matplotlib_imshow_kwargs)

Displays one or several slices from a 3D volume.

By default (slice_positions is None), slices_grid plots num_slices linearly spaced slices. If slice_positions is a string or an integer, slices_grid plots an overview of num_slices slices around that position. If slice_positions is a list, num_slices is ignored and the slices in slice_positions are plotted.
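When slice_positions is None, the source code picks the slices with np.linspace(0, n_total - 1, num_slices, dtype=int): the first and last slices are always included and the rest are spread evenly, truncated to integer indices. A sketch for the 225-slice shell example with the default num_slices:

```python
import numpy as np

n_total = 225    # slices along slice_axis, e.g. shell_225x128x128 along axis 0
num_slices = 15  # the default

# dtype=int truncates the evenly spaced values to slice indices
slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int)
print(slice_idxs)  # 0, 16, 32, ..., 224: evenly spaced by 16
```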

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| volume | ndarray | The 3D volume to be sliced. | required |
| slice_axis | int | Axis (dimension) along which to slice. | 0 |
| slice_positions | int or list[int] or str or None | One or more slice indices, or one of "start", "mid" or "end". If None, linearly spaced slices are displayed. | None |
| num_slices | int | Number of slices to display. | 15 |
| max_columns | int | Maximum number of columns in the grid. | 5 |
| color_map | str or LinearSegmentedColormap | Colormap for the images. | 'magma' |
| value_min | float | Together with value_max, defines the data range covered by the colormap. By default, the colormap covers the full range. | None |
| value_max | float | Together with value_min, defines the data range covered by the colormap. By default, the colormap covers the full range. | None |
| image_size | int | If set, overrides both image_height and image_width with this value. | None |
| image_height | int | Height of each image in the grid, in inches. | 2 |
| image_width | int | Width of each image in the grid, in inches. | 2 |
| display_figure | bool | If True, displays the plot (i.e. calls plt.show()). | False |
| display_positions | bool | If True, labels each image with its slice index and axis. | True |
| interpolation | str | Interpolation method passed to imshow. None is mapped to 'none' (no interpolation). | None |
| color_bar | bool | Adds a colorbar at the top right for the colormap and data range. With a colorbar, all slices share one intensity range. | False |
| color_bar_style | str | Colorbar style: 'small' is the height of one image row; 'large' spans the full height of the image grid. | 'small' |
| **matplotlib_imshow_kwargs | Any | Additional keyword arguments passed to matplotlib's imshow. | {} |
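The grid shape follows from num_slices and max_columns. A sketch of that arithmetic, assuming image_width sizes each column and image_height each row in line with matplotlib's (width, height) figsize convention; `grid_layout` is a hypothetical name used only for illustration.

```python
import math


def grid_layout(num_slices, max_columns=5, image_height=2, image_width=2):
    """Hypothetical helper reproducing the grid arithmetic of slices_grid."""
    nrows = math.ceil(num_slices / max_columns)
    ncols = min(num_slices, max_columns)
    # matplotlib's figsize is (width, height) in inches
    figsize = (ncols * image_width, nrows * image_height)
    return nrows, ncols, figsize


print(grid_layout(15))                                # 3 rows x 5 columns, 10 x 6 inches
print(grid_layout(3, image_height=4, image_width=2))  # 1 row x 3 columns, 6 x 4 inches
```

Grid cells beyond the last slice are simply left blank with their axes hidden.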

Returns:

| Name | Type | Description |
|---|---|---|
| fig | Figure | The figure with the slices from the 3D array. |

Raises:

| Type | Description |
|---|---|
| ValueError | If the input is not a numpy.ndarray or a Dask array (da.core.Array). |
| ValueError | If slice_axis is not an integer between 0 and the number of dimensions of the volume minus 1. |
| ValueError | If the array is not a volume with at least 3 dimensions. |
| ValueError | If slice_positions is not an integer, a list of integers, or one of the strings "start", "mid" or "end". |
| ValueError | If color_bar_style is not 'small' or 'large'. |

Example

import qim3d

vol = qim3d.examples.shell_225x128x128
qim3d.viz.slices_grid(vol, num_slices=15)
[Figure: Grid of slices]

Source code in qim3d/viz/_data_exploration.py
def slices_grid(
    volume: np.ndarray,
    slice_axis: int = 0,
    slice_positions: Optional[Union[str, int, List[int]]] = None,
    num_slices: int = 15,
    max_columns: int = 5,
    color_map: str = 'magma',
    value_min: float = None,
    value_max: float = None,
    image_size: int = None,
    image_height: int = 2,
    image_width: int = 2,
    display_figure: bool = False,
    display_positions: bool = True,
    interpolation: Optional[str] = None,
    color_bar: bool = False,
    color_bar_style: str = 'small',
    **matplotlib_imshow_kwargs,
) -> matplotlib.figure.Figure:
    """
    Displays one or several slices from a 3d volume.

    By default if `slice_positions` is None, slices_grid plots `num_slices` linearly spaced slices.
    If `slice_positions` is given as a string or integer, slices_grid will plot an overview with `num_slices` figures around that position.
    If `slice_positions` is given as a list, `num_slices` will be ignored and the slices from `slice_positions` will be plotted.

    Args:
        volume (np.ndarray): The 3D volume to be sliced.
        slice_axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
        slice_positions (int or list[int] or str or None, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
        num_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 15.
        max_columns (int, optional): The maximum number of columns to be plotted. Defaults to 5.
        color_map (str or matplotlib.colors.LinearSegmentedColormap, optional): Specifies the color map for the image. Defaults to "magma".
        value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
        value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
        image_height (int, optional): Height of the figure.
        image_width (int, optional): Width of the figure.
        display_figure (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
        display_positions (bool, optional): If True, displays the position of the slices. Defaults to True.
        interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.
        color_bar (bool, optional): Adds a colorbar positioned in the top-right for the corresponding colormap and data range. Defaults to False.
        color_bar_style (str, optional): Determines the style of the colorbar. Option 'small' is height of one image row. Option 'large' spans full height of image grid. Defaults to 'small'.

    Returns:
        fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.

    Raises:
        ValueError: If the input is not a numpy.ndarray or da.core.Array.
        ValueError: If the slice_axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
        ValueError: If the file or array is not a volume with at least 3 dimensions.
        ValueError: If the `slice_positions` argument is not an integer, a list of integers or one of the following strings: "start", "mid" or "end".
        ValueError: If the color_bar_style keyword argument is not one of the following strings: 'small' or 'large'.

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.shell_225x128x128
        qim3d.viz.slices_grid(vol, num_slices=15)
        ```
        ![Grid of slices](../../assets/screenshots/viz-slices.png)

    """
    if image_size:
        image_height = image_size
        image_width = image_size

    # If we pass python None to the imshow function, it will set to
    # default value 'antialiased'
    if interpolation is None:
        interpolation = 'none'

    # Only NumPy and Dask arrays are supported
    if not isinstance(volume, (np.ndarray, da.core.Array)):
        raise ValueError('Data type not supported')

    if volume.ndim < 3:
        raise ValueError(
            'The provided object is not a volume as it has less than 3 dimensions.'
        )

    color_bar_style_options = ['small', 'large']
    if color_bar_style not in color_bar_style_options:
        raise ValueError(
            f"Value '{color_bar_style}' is not valid for colorbar style. Please select from {color_bar_style_options}."
        )

    if isinstance(volume, da.core.Array):
        volume = volume.compute()

    # Ensure axis is a valid choice
    if not (0 <= slice_axis < volume.ndim):
        raise ValueError(
            f"Invalid value for 'slice_axis'. It should be an integer between 0 and {volume.ndim - 1}."
        )

    # Here we deal with the case that the user wants to use the objects colormap directly
    if (
        type(color_map) == matplotlib.colors.LinearSegmentedColormap
        or color_map == 'segmentation'
    ):
        num_labels = len(np.unique(volume))

        if color_map == 'segmentation':
            color_map = qim3d.viz.colormaps.segmentation(num_labels)
        # If value_min and value_max are not set like this, then in case the
        # number of objects changes on new slice, objects might change
        # colors. So when using a slider, the same object suddently
        # changes color (flickers), which is confusing and annoying.
        value_min = 0
        value_max = num_labels

    # Get total number of slices in the specified dimension
    n_total = volume.shape[slice_axis]

    # Position is not provided - will use linearly spaced slices
    if slice_positions is None:
        slice_idxs = np.linspace(0, n_total - 1, num_slices, dtype=int)
    # Position is a string
    elif isinstance(slice_positions, str) and slice_positions.lower() in [
        'start',
        'mid',
        'end',
    ]:
        if slice_positions.lower() == 'start':
            slice_idxs = _get_slice_range(0, num_slices, n_total)
        elif slice_positions.lower() == 'mid':
            slice_idxs = _get_slice_range(n_total // 2, num_slices, n_total)
        elif slice_positions.lower() == 'end':
            slice_idxs = _get_slice_range(n_total - 1, num_slices, n_total)
    #  Position is an integer
    elif isinstance(slice_positions, int):
        slice_idxs = _get_slice_range(slice_positions, num_slices, n_total)
    # Position is a list of integers
    elif isinstance(slice_positions, list) and all(
        isinstance(idx, int) for idx in slice_positions
    ):
        slice_idxs = slice_positions
        # The grid is sized from num_slices, so make it match the list length
        num_slices = len(slice_idxs)
    else:
        raise ValueError(
            'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
        )

    # Make grid
    nrows = math.ceil(num_slices / max_columns)
    ncols = min(num_slices, max_columns)

    # Generate figure
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * image_width, nrows * image_height),
        constrained_layout=True,
    )

    if nrows == 1:
        axs = [axs]  # Convert to a list for uniformity

    # Convert to NumPy array in order to use the numpy.take method
    if isinstance(volume, da.core.Array):
        volume = volume.compute()

    if color_bar:
        # In this case, we want the vrange to be constant across the
        # slices, which makes them all comparable to a single color_bar.
        new_value_min = value_min if value_min is not None else np.min(volume)
        new_value_max = value_max if value_max is not None else np.max(volume)

    # Run through each ax of the grid
    for i, ax_row in enumerate(axs):
        for j, ax in enumerate(np.atleast_1d(ax_row)):
            slice_idx = i * max_columns + j
            try:
                slice_img = volume.take(slice_idxs[slice_idx], axis=slice_axis)

                if not color_bar:
                    # If value_min is higher than the highest value in the
                    # image ValueError is raised. We don't want to
                    # override the values because next slices might be okay
                    new_value_min = (
                        None
                        if (
                            isinstance(value_min, (float, int))
                            and value_min > np.max(slice_img)
                        )
                        else value_min
                    )
                    new_value_max = (
                        None
                        if (
                            isinstance(value_max, (float, int))
                            and value_max < np.min(slice_img)
                        )
                        else value_max
                    )

                ax.imshow(
                    slice_img,
                    cmap=color_map,
                    interpolation=interpolation,
                    vmin=new_value_min,
                    vmax=new_value_max,
                    **matplotlib_imshow_kwargs,
                )

                if display_positions:
                    ax.text(
                        0.0,
                        1.0,
                        f'slice {slice_idxs[slice_idx]} ',
                        transform=ax.transAxes,
                        color='white',
                        fontsize=8,
                        va='top',
                        ha='left',
                        bbox=dict(facecolor='#303030', linewidth=0, pad=0),
                    )

                    ax.text(
                        1.0,
                        0.0,
                        f'axis {slice_axis} ',
                        transform=ax.transAxes,
                        color='white',
                        fontsize=8,
                        va='bottom',
                        ha='right',
                        bbox=dict(facecolor='#303030', linewidth=0, pad=0),
                    )

            except IndexError:
                # Not a problem, because we simply do not have a slice to show
                pass

            # Hide the axis, so that we have a nice grid
            ax.axis('off')

    if color_bar:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=UserWarning)
            fig.tight_layout()

        norm = matplotlib.colors.Normalize(
            vmin=new_value_min, vmax=new_value_max, clip=True
        )
        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=color_map)

        if color_bar_style == 'small':
            # Figure coordinates of top-right axis
            tr_pos = np.atleast_1d(axs[0])[-1].get_position()
            # The width is divided by ncols to make it the same relative size to the images
            color_bar_ax = fig.add_axes(
                [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
            )
            fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation='vertical')
        elif color_bar_style == 'large':
            # Figure coordinates of bottom- and top-right axis
            br_pos = np.atleast_1d(axs[-1])[-1].get_position()
            tr_pos = np.atleast_1d(axs[0])[-1].get_position()
            # The width is divided by ncols to make it the same relative size to the images
            color_bar_ax = fig.add_axes(
                [
                    br_pos.xmax + 0.05 / ncols,
                    br_pos.y0 + 0.0015,
                    0.05 / ncols,
                    (tr_pos.y1 - br_pos.y0) - 0.0015,
                ]
            )
            fig.colorbar(mappable=mappable, cax=color_bar_ax, orientation='vertical')

    if display_figure:
        plt.show()

    plt.close()

    return fig
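Without a colorbar, slices_grid checks each slice before calling imshow: a value_min above the slice's maximum, or a value_max below its minimum, is replaced by None for that slice only, so imshow autoscales it instead. With color_bar=True this check is skipped and every slice shares one range. A standalone sketch of the per-slice rule; `per_slice_limits` is a hypothetical name.

```python
import numpy as np


def per_slice_limits(slice_img, value_min=None, value_max=None):
    """Hypothetical helper mirroring the per-slice vmin/vmax fallback in slices_grid."""
    vmin = value_min
    if isinstance(value_min, (float, int)) and value_min > np.max(slice_img):
        vmin = None  # limit lies above every voxel in this slice
    vmax = value_max
    if isinstance(value_max, (float, int)) and value_max < np.min(slice_img):
        vmax = None  # limit lies below every voxel in this slice
    return vmin, vmax


dark_slice = np.zeros((4, 4))
bright_slice = np.full((4, 4), 200.0)
print(per_slice_limits(dark_slice, value_min=50, value_max=150))    # (None, 150)
print(per_slice_limits(bright_slice, value_min=50, value_max=150))  # (50, None)
```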

qim3d.viz.slicer_orthogonal

slicer_orthogonal(volume, color_map='magma', value_min=None, value_max=None, image_height=3, image_width=3, display_positions=False, interpolation=None, image_size=None)

Interactive widget for visualizing orthogonal slices of a 3D volume.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| volume | ndarray | The 3D volume to be sliced. | required |
| color_map | str or LinearSegmentedColormap | Colormap for the images. | 'magma' |
| value_min | float | Together with value_max, defines the data range covered by the colormap. By default, the colormap covers the full range. | None |
| value_max | float | Together with value_min, defines the data range covered by the colormap. By default, the colormap covers the full range. | None |
| image_height | int | Height of each slice figure, in inches. | 3 |
| image_width | int | Width of each slice figure, in inches. | 3 |
| display_positions | bool | If True, displays the position of the slices. | False |
| interpolation | str | Interpolation method for the images. | None |
| image_size | int | If set, overrides both image_height and image_width with this value. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| slicer_orthogonal_obj | HBox | The interactive widget for visualizing orthogonal slices of a 3D volume. |
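slicer_orthogonal builds one slicer per axis (axis 0 labeled Z, axis 1 Y, axis 2 X), and each slider starts at the middle index volume.shape[axis] // 2. A NumPy sketch of the three initial views for a volume shaped like the fly example; the zero-filled array is a stand-in assumption.

```python
import numpy as np

volume = np.zeros((150, 256, 256), dtype=np.uint8)  # stand-in for fly_150x256x256

# One view per axis, each starting at that axis's middle slice
views = {
    name: np.take(volume, volume.shape[axis] // 2, axis=axis)
    for axis, name in enumerate(["Z", "Y", "X"])
}
for name, img in views.items():
    print(name, img.shape)
```

The Z view is a 256 x 256 image, while the Y and X views are 150 x 256, which is why the three panels can differ in aspect ratio.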

Example

import qim3d

vol = qim3d.examples.fly_150x256x256
qim3d.viz.slicer_orthogonal(vol, color_map="magma")
[Figure: viz slicer_orthogonal]

Source code in qim3d/viz/_data_exploration.py
def slicer_orthogonal(
    volume: np.ndarray,
    color_map: str = 'magma',
    value_min: float = None,
    value_max: float = None,
    image_height: int = 3,
    image_width: int = 3,
    display_positions: bool = False,
    interpolation: Optional[str] = None,
    image_size: int = None,
) -> widgets.HBox:
    """
    Interactive widget for visualizing orthogonal slices of a 3D volume.

    Args:
        volume (np.ndarray): The 3D volume to be sliced.
        color_map (str or matplotlib.colors.LinearSegmentedColormap, optional): Specifies the color map for the image. Defaults to "magma".
        value_min (float, optional): Together with value_max define the data range the colormap covers. By default colormap covers the full range. Defaults to None.
        value_max (float, optional): Together with value_min define the data range the colormap covers. By default colormap covers the full range. Defaults to None
        image_height (int, optional): Height of the figure.
        image_width (int, optional): Width of the figure.
        display_positions (bool, optional): If True, displays the position of the slices. Defaults to False.
        interpolation (str, optional): Specifies the interpolation method for the image. Defaults to None.

    Returns:
        slicer_orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.fly_150x256x256
        qim3d.viz.slicer_orthogonal(vol, color_map="magma")
        ```
        ![viz slicer_orthogonal](../../assets/screenshots/viz-orthogonal.gif)

    """

    if image_size:
        image_height = image_size
        image_width = image_size

    get_slicer_for_axis = lambda slice_axis: slicer(
        volume,
        slice_axis=slice_axis,
        color_map=color_map,
        value_min=value_min,
        value_max=value_max,
        image_height=image_height,
        image_width=image_width,
        display_positions=display_positions,
        interpolation=interpolation,
    )

    z_slicer = get_slicer_for_axis(slice_axis=0)
    y_slicer = get_slicer_for_axis(slice_axis=1)
    x_slicer = get_slicer_for_axis(slice_axis=2)

    z_slicer.children[0].description = 'Z'
    y_slicer.children[0].description = 'Y'
    x_slicer.children[0].description = 'X'

    return widgets.HBox([z_slicer, y_slicer, x_slicer])
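Each of the three panels is an ordinary 2D slice taken along one axis. The NumPy sketch below extracts the slices at the middle index of each axis, chosen here only as an illustrative position; the helper name `orthogonal_slices` is hypothetical and not part of qim3d.

```python
import numpy as np


def orthogonal_slices(volume: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the middle slice along axes 0 (Z), 1 (Y) and 2 (X).

    Hypothetical helper for illustration; not part of qim3d.
    """
    if volume.ndim != 3:
        raise ValueError("Expected a 3D volume")
    z, y, x = (s // 2 for s in volume.shape)
    # Indexing one axis with an integer drops that axis, leaving a 2D image
    return volume[z, :, :], volume[:, y, :], volume[:, :, x]


# Same shape as the fly example volume, filled with random data
vol = np.random.default_rng(0).random((150, 256, 256))
z_img, y_img, x_img = orthogonal_slices(vol)
print(z_img.shape, y_img.shape, x_img.shape)  # (256, 256) (150, 256) (150, 256)
```

The Y and X panels keep the depth axis (150 slices), which is why they are shaped differently from the Z panel.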

qim3d.viz.circles

circles(blobs, vol, alpha=0.5, color='#ff9900', **kwargs)

Plots the blobs found on a slice of the volume.

This function takes a 3D volume and a list of blobs (detected features) and draws the blobs as circles on one slice of the volume. The slice is chosen with a slider, which starts at the middle slice.

Parameters:

Name Type Description Default
blobs ndarray

An array-like object of blobs, where each blob is represented as a 4-tuple (p, r, c, radius). Usually the first value returned by qim3d.detection.blobs(vol).

required
vol ndarray

The 3D volume on which to plot the blobs.

required
alpha float

The transparency of the blobs. Defaults to 0.5.

0.5
color str

The color of the blobs. Defaults to "#ff9900".

'#ff9900'
**kwargs Any

Arbitrary keyword arguments passed on to qim3d.viz.slices_grid.

{}

Returns:

Name Type Description
slicer_obj interactive

An interactive widget for visualizing the blobs.

Example

import qim3d
import qim3d.detection

# Get data
vol = qim3d.examples.cement_128x128x128

# Detect blobs, and get binary mask
blobs, _ = qim3d.detection.blobs(
    vol,
    min_sigma=1,
    max_sigma=8,
    threshold=0.001,
    overlap=0.1,
    background="bright"
    )

# Visualize detected blobs with circles method
qim3d.viz.circles(blobs, vol, alpha=0.8, color='blue')

Source code in qim3d/viz/_detection.py
def circles(
    blobs: np.ndarray,
    vol: np.ndarray,
    alpha: float = 0.5,
    color: str = '#ff9900',
    **kwargs,
) -> widgets.interactive:
    """
    Plots the blobs found on a slice of the volume.

    This function takes a 3D volume and a list of blobs (detected features)
    and draws the blobs as circles on one slice of the volume. The slice is
    chosen with a slider, which starts at the middle slice.

    Args:
        blobs (np.ndarray): An array-like object of blobs, where each blob is represented
            as a 4-tuple (p, r, c, radius). Usually the first value returned by `qim3d.detection.blobs(vol)`.
        vol (np.ndarray): The 3D volume on which to plot the blobs.
        alpha (float, optional): The transparency of the blobs. Defaults to 0.5.
        color (str, optional): The color of the blobs. Defaults to "#ff9900".
        **kwargs (Any): Arbitrary keyword arguments passed on to `qim3d.viz.slices_grid`.

    Returns:
        slicer_obj (ipywidgets.interactive): An interactive widget for visualizing the blobs.

    Example:
        ```python
        import qim3d
        import qim3d.detection

        # Get data
        vol = qim3d.examples.cement_128x128x128

        # Detect blobs, and get binary mask
        blobs, _ = qim3d.detection.blobs(
            vol,
            min_sigma=1,
            max_sigma=8,
            threshold=0.001,
            overlap=0.1,
            background="bright"
            )

        # Visualize detected blobs with circles method
        qim3d.viz.circles(blobs, vol, alpha=0.8, color='blue')
        ```
        ![blob detection](../../assets/screenshots/blob_detection.gif)

    """

    def _slicer(z_slice):
        clear_output(wait=True)
        fig = qim3d.viz.slices_grid(
            vol[z_slice : z_slice + 1],
            num_slices=1,
            color_map='gray',
            display_figure=False,
            display_positions=False,
            **kwargs,
        )
        # Add circles from detected blobs
        for detected in blobs:
            z, y, x, s = detected
            if abs(z - z_slice) < s:  # The blob is in the slice
                # Adjust the radius based on the distance from the center of the sphere
                distance_from_center = abs(z - z_slice)
                angle = (
                    np.pi / 2 * (distance_from_center / s)
                )  # Angle varies from 0 at the center to pi/2 at the edge
                adjusted_radius = s * np.cos(angle)  # Radius follows a cosine curve

                if adjusted_radius > 0.5:
                    c = plt.Circle(
                        (x, y),
                        adjusted_radius,
                        color=color,
                        linewidth=0,
                        fill=True,
                        alpha=alpha,
                    )
                    fig.get_axes()[0].add_patch(c)

        display(fig)
        return fig

    position_slider = widgets.IntSlider(
        value=vol.shape[0] // 2,
        min=0,
        max=vol.shape[0] - 1,
        description='Slice',
        continuous_update=True,
    )
    slicer_obj = widgets.interactive(_slicer, z_slice=position_slider)
    slicer_obj.layout = widgets.Layout(align_items='flex-start')

    return slicer_obj
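The `_slicer` callback above shrinks each circle as the slice moves away from the blob centre, using `s * cos(pi/2 * d / s)` for a blob of radius `s` seen at distance `d`. For comparison, the exact cross-section of a sphere of radius `s` cut at distance `d` from its centre has radius `sqrt(s**2 - d**2)`. The plain NumPy sketch below (not qim3d code; both function names are illustrative) puts the two side by side: they agree at the centre and at the edge, while the cosine rule is never larger in between.

```python
import numpy as np


def cosine_radius(s: float, d: np.ndarray) -> np.ndarray:
    """Radius rule used in qim3d.viz.circles: s * cos(pi/2 * d/s), for 0 <= d <= s."""
    return s * np.cos(np.pi / 2 * (d / s))


def sphere_section_radius(s: float, d: np.ndarray) -> np.ndarray:
    """Exact radius of the circle cut from a sphere of radius s at distance d from its centre."""
    return np.sqrt(np.clip(s**2 - d**2, 0.0, None))


s = 8.0
d = np.linspace(0.0, s, 5)  # 0, 2, 4, 6 and 8 slices away from the centre
for di, rc, re in zip(d, cosine_radius(s, d), sphere_section_radius(s, d)):
    print(f"d={di:.0f}  cosine={rc:.2f}  exact={re:.2f}")
```

Which rule fits better depends on how the detector's radius estimate relates to the true blob size.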

qim3d.viz.chunks

chunks(zarr_path, **kwargs)

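Interactive explorer for the chunks of an OME-Zarr dataset. Dropdowns select the scale, the chunk coordinates, and the visualization method (slicer, slices, or volume); the selected chunk is displayed together with its shape, coordinate ranges, dtype, and min, max, and mean values.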

Parameters:

Name Type Description Default
zarr_path str or PathLike

Path to the Zarr dataset.

required
**kwargs Any

Additional keyword arguments passed to the selected visualization function (qim3d.viz.slicer, qim3d.viz.slices_grid, or qim3d.viz.volumetric).

{}
Example

import qim3d

# Download dataset
downloader = qim3d.io.Downloader()
data = downloader.Snail.Escargot(load_file=True)

# Export as OME-Zarr
qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2, replace=True)

# Explore chunks
qim3d.viz.chunks("Escargot.zarr")

Source code in qim3d/viz/_data_exploration.py
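def chunks(zarr_path: str, **kwargs) -> None:
    """
    Interactive explorer for the chunks of an OME-Zarr dataset. Dropdowns select the scale,
    the chunk coordinates, and the visualization method ('slicer', 'slices', or 'volume').

    Args:
        zarr_path (str or os.PathLike): Path to the Zarr dataset.
        **kwargs (Any): Additional keyword arguments passed to the selected visualization function
            (`qim3d.viz.slicer`, `qim3d.viz.slices_grid`, or `qim3d.viz.volumetric`).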

    Example:
        ```python
        import qim3d

        # Download dataset
        downloader = qim3d.io.Downloader()
        data = downloader.Snail.Escargot(load_file=True)

        # Export as OME-Zarr
        qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2, replace=True)

        # Explore chunks
        qim3d.viz.chunks("Escargot.zarr")
        ```
        ![chunks-visualization](../../assets/screenshots/chunks_visualization.gif)

    """
    import zarr

    # Load the Zarr dataset
    zarr_data = zarr.open(zarr_path, mode='r')


    # Create label to display the chunk coordinates
    widget_title = widgets.HTML('<h2>Chunk Explorer</h2>')
    chunk_info_label = widgets.HTML(value='Chunk info will be displayed here')

    def load_and_visualize(
        scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
    ):
        # Get chunk shape for the selected scale
        chunk_shape = zarr_data[scale].chunks

        # Calculate slice indices for the selected chunk
        slices = (
            slice(
                z_coord * chunk_shape[0],
                min((z_coord + 1) * chunk_shape[0], zarr_data[scale].shape[0]),
            ),
            slice(
                y_coord * chunk_shape[1],
                min((y_coord + 1) * chunk_shape[1], zarr_data[scale].shape[1]),
            ),
            slice(
                x_coord * chunk_shape[2],
                min((x_coord + 1) * chunk_shape[2], zarr_data[scale].shape[2]),
            ),
        )

        # Extract start and stop values from each slice object
        z_start, z_stop = slices[0].start, slices[0].stop
        y_start, y_stop = slices[1].start, slices[1].stop
        x_start, x_stop = slices[2].start, slices[2].stop

        # Extract the chunk
        chunk = zarr_data[scale][slices]

        # Update the chunk info label with the chunk coordinates
        info_string = (
            f'<b>shape:</b> {chunk_shape}\n'
            + f'<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n'
            + f'<b>ranges: </b>Z({z_start}-{z_stop})   Y({y_start}-{y_stop})   X({x_start}-{x_stop})\n'
            + f'<b>dtype:</b> {chunk.dtype}\n'
            + f'<b>min value:</b> {np.min(chunk)}\n'
            + f'<b>max value:</b> {np.max(chunk)}\n'
            + f'<b>mean value:</b> {np.mean(chunk)}\n'
        )

        chunk_info_label.value = f"""
            <div style="font-size: 14px; text-align: left; margin-left:32px">
                <h3 style="margin: 0px">Chunk Info</h3>
                    <div style="font-size: 14px; text-align: left;">
                    <pre>{info_string}</pre>
                    </div>
            </div>

            """

        # Prepare chunk visualization based on the selected method
        if visualization_method == 'slicer':  # return a widget
            viz_widget = qim3d.viz.slicer(chunk, **kwargs)
        elif visualization_method == 'slices':  # return a plt.Figure
            viz_widget = widgets.Output()
            with viz_widget:
                viz_widget.clear_output(wait=True)
                fig = qim3d.viz.slices_grid(chunk, **kwargs)
                display(fig)
        elif visualization_method == 'volume':
            viz_widget = widgets.Output()
            with viz_widget:
                viz_widget.clear_output(wait=True)
                out = qim3d.viz.volumetric(chunk, show=False, **kwargs)
                display(out)
        else:
            log.warning(f'Invalid visualization method: {visualization_method}')
            viz_widget = widgets.HTML(f'Invalid visualization method: {visualization_method}')

        return viz_widget

    # Function to calculate the number of chunks for each dimension, including partial chunks
    def get_num_chunks(shape, chunk_size):
        return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)]

    scale_options = {
        f'{i} {zarr_data[i].shape}': i for i in range(len(zarr_data))
    }  # len(zarr_data) gives number of scales

    description_width = '128px'
    # Create dropdown for scale
    scale_dropdown = widgets.Dropdown(
        options=scale_options,
        value=0,  # Default to first scale
        description='OME-Zarr scale',
        style={'description_width': description_width, 'text_align': 'left'},
    )

    # Initialize the options for x, y, and z based on the first scale by default
    multiscale_shape = zarr_data[0].shape
    chunk_shape = zarr_data[0].chunks
    num_chunks = get_num_chunks(multiscale_shape, chunk_shape)

    z_dropdown = widgets.Dropdown(
        options=list(range(num_chunks[0])),
        value=0,
        description='First dimension (Z)',
        style={'description_width': description_width, 'text_align': 'left'},
    )

    y_dropdown = widgets.Dropdown(
        options=list(range(num_chunks[1])),
        value=0,
        description='Second dimension (Y)',
        style={'description_width': description_width, 'text_align': 'left'},
    )

    x_dropdown = widgets.Dropdown(
        options=list(range(num_chunks[2])),
        value=0,
        description='Third dimension (X)',
        style={'description_width': description_width, 'text_align': 'left'},
    )

    method_dropdown = widgets.Dropdown(
        options=['slicer', 'slices', 'volume'],
        value='slicer',
        description='Visualization',
        style={'description_width': description_width, 'text_align': 'left'},
    )

    # Function to temporarily disable observers
    def disable_observers():
        x_dropdown.unobserve(update_visualization, names='value')
        y_dropdown.unobserve(update_visualization, names='value')
        z_dropdown.unobserve(update_visualization, names='value')
        method_dropdown.unobserve(update_visualization, names='value')

    # Function to enable observers
    def enable_observers():
        x_dropdown.observe(update_visualization, names='value')
        y_dropdown.observe(update_visualization, names='value')
        z_dropdown.observe(update_visualization, names='value')
        method_dropdown.observe(update_visualization, names='value')

    # Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0
    def update_coordinate_dropdowns(scale):
        disable_observers()  # to avoid multiple reload of the visualization when updating the dropdowns

        multiscale_shape = zarr_data[scale].shape
        chunk_shape = zarr_data[scale].chunks
        num_chunks = get_num_chunks(
            multiscale_shape, chunk_shape
        )  # Calculate new chunk options

        # Reset X, Y, Z dropdowns to 0
        z_dropdown.options = list(range(num_chunks[0]))
        z_dropdown.value = 0  # Reset to 0
        z_dropdown.disabled = (
            len(z_dropdown.options) == 1
        )  # Disable if only one option (0) is available

        y_dropdown.options = list(range(num_chunks[1]))
        y_dropdown.value = 0  # Reset to 0
        y_dropdown.disabled = (
            len(y_dropdown.options) == 1
        )  # Disable if only one option (0) is available

        x_dropdown.options = list(range(num_chunks[2]))
        x_dropdown.value = 0  # Reset to 0
        x_dropdown.disabled = (
            len(x_dropdown.options) == 1
        )  # Disable if only one option (0) is available

        enable_observers()

        update_visualization()

    # Function to update the visualization when any dropdown value changes
    def update_visualization(*args):
        scale = scale_dropdown.value
        x_coord = x_dropdown.value
        y_coord = y_dropdown.value
        z_coord = z_dropdown.value
        visualization_method = method_dropdown.value

        # Clear and update the chunk visualization
        slicer_widget = load_and_visualize(
            scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
        )

        # Recreate the layout and display the new visualization
        final_layout.children = [widget_title, hbox_layout, slicer_widget]

    # Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes
    scale_dropdown.observe(
        lambda change: update_coordinate_dropdowns(scale_dropdown.value), names='value'
    )

    enable_observers()

    # Create first visualization
    slicer_widget = load_and_visualize(
        scale_dropdown.value,
        z_dropdown.value,
        y_dropdown.value,
        x_dropdown.value,
        method_dropdown.value,
        **kwargs,
    )

    # Create the layout
    vbox_dropbox = widgets.VBox(
        [scale_dropdown, z_dropdown, y_dropdown, x_dropdown, method_dropdown]
    )
    hbox_layout = widgets.HBox([vbox_dropbox, chunk_info_label])
    final_layout = widgets.VBox([widget_title, hbox_layout, slicer_widget])

    # Display the VBox
    display(final_layout)
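The chunk grid offered by the dropdowns comes from two small calculations in the source: a ceiling division for the number of chunks per axis (a trailing partial chunk counts as a chunk), and start/stop bounds clipped to the array shape. The standalone sketch below reproduces both in plain Python; the function names and the example shapes are illustrative, not part of qim3d's public API.

```python
def num_chunks(shape: tuple[int, ...], chunk_shape: tuple[int, ...]) -> list[int]:
    """Number of chunks along each axis, counting a trailing partial chunk as one."""
    return [(s + c - 1) // c for s, c in zip(shape, chunk_shape)]


def chunk_bounds(
    coord: tuple[int, ...], shape: tuple[int, ...], chunk_shape: tuple[int, ...]
) -> tuple[slice, ...]:
    """Slices selecting chunk `coord`, with each stop clipped to the array edge."""
    return tuple(
        slice(i * c, min((i + 1) * c, s)) for i, s, c in zip(coord, shape, chunk_shape)
    )


shape, chunk_shape = (250, 256, 256), (100, 100, 100)
print(num_chunks(shape, chunk_shape))               # [3, 3, 3]
print(chunk_bounds((2, 2, 2), shape, chunk_shape))  # the last chunk is partial on every axis
```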

qim3d.viz.itk_vtk

itk_vtk(filename=None, open_browser=True, file_server_port=8042, viewer_port=3000)

Entry point used by the command-line interface (cli/__init__.py). Tries to open the visualization in itk-vtk-viewer; if the viewer is not installed, asks the user whether qim3d should install it. This function lives here so that the CLI does not have to import NotInstalledError and Installer, which would expose them to the user.

Source code in qim3d/viz/itk_vtk_viewer/run.py
def itk_vtk(
    filename: str = None,
    open_browser: bool = True,
    file_server_port: int = 8042,
    viewer_port: int = 3000,
):
    """
    Command to run in cli/__init__.py. Tries to run the visualization;
    if that fails, asks the user to install it. This function is needed
    here so we don't have to import NotInstalledError and Installer,
    which would expose them to the user.
    """

    try:
        try_opening_itk_vtk(
            filename,
            open_browser=open_browser,
            file_server_port=file_server_port,
            viewer_port=viewer_port,
        )

    except NotInstalledError:
        message = "Itk-vtk-viewer is not installed or qim3d cannot find it.\nYou can either:\n\to  Use 'qim3d viz SOURCE -m k3d' to display data using a different method\n\to  Install itk-vtk-viewer yourself following https://kitware.github.io/itk-vtk-viewer/docs/cli.html#Installation\n\to  Let qim3d install itk-vtk-viewer now (it will also install node.js in the qim3d library)\nDo you want qim3d to install itk-vtk-viewer now?"
        print(message)
        answer = input('[Y/n]:')
        if answer.strip().lower() in ('', 'y', 'yes'):
            Installer().install()
            try_opening_itk_vtk(
                filename,
                open_browser=open_browser,
                file_server_port=file_server_port,
                viewer_port=viewer_port,
            )
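The control flow of `itk_vtk` is a common "try, offer to install, retry" pattern. The self-contained sketch below injects the launcher, installer, and prompt as callables so the logic can be exercised without itk-vtk-viewer. Every name here (`run_with_install_fallback`, `ViewerMissing`, and the fake callables) is hypothetical and stands in for qim3d's internal `try_opening_itk_vtk`, `NotInstalledError`, and `Installer`. Following the `[Y/n]` convention, an empty answer counts as yes.

```python
from typing import Callable


class ViewerMissing(Exception):
    """Hypothetical stand-in for qim3d's NotInstalledError."""


def run_with_install_fallback(
    open_viewer: Callable[[], None],
    install: Callable[[], None],
    ask: Callable[[str], str] = input,
) -> bool:
    """Try to open the viewer; if it is missing, offer to install it and retry.

    Returns True if the viewer was opened, False if the user declined.
    """
    try:
        open_viewer()
        return True
    except ViewerMissing:
        answer = ask("Viewer not found. Install it now? [Y/n]: ")
        # [Y/n] convention: pressing Enter (empty answer) means yes
        if answer.strip().lower() in ("", "y", "yes"):
            install()
            open_viewer()  # a second failure propagates to the caller
            return True
        return False


# Example with fakes: the "viewer" only works after "installation"
state = {"installed": False}


def fake_open() -> None:
    if not state["installed"]:
        raise ViewerMissing()


def fake_install() -> None:
    state["installed"] = True


print(run_with_install_fallback(fake_open, fake_install, ask=lambda _: ""))  # True
```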

qim3d.viz.volumetric

volumetric(img, aspectmode='data', show=True, save=False, grid_visible=False, color_map='magma', constant_opacity=False, vmin=None, vmax=None, samples='auto', max_voxels=512 ** 3, data_type='scaled_float16', **kwargs)

Visualizes a 3D volume using volumetric rendering.

Parameters:

Name Type Description Default
img ndarray

The input 3D image data. It should be a 3D numpy array.

required
aspectmode str

Determines the proportions of the scene's axes. Defaults to "data". If 'data', the axes are drawn in proportion with the axes' ranges. If 'cube', the axes are drawn as a cube, regardless of the axes' ranges.

'data'
show bool

If True, displays the visualization inline. Defaults to True.

True
save bool or str

If True, saves the visualization as an HTML file. If a string is provided, it's interpreted as the file path where the HTML file will be saved. Defaults to False.

False
grid_visible bool

If True, the grid is visible in the plot. Defaults to False.

False
color_map str or Colormap or list

The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to 'magma'.

'magma'
constant_opacity bool

Set to True if doing an object label visualization with a corresponding color_map; otherwise, the plot may render poorly. Defaults to False.

False
vmin float or None

Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None.

None
vmax float or None

Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None.

None
samples int or auto

The number of samples to be used for the volume rendering in k3d. Input 'auto' for auto selection. Defaults to 'auto'. Lower values will render faster but with lower quality.

'auto'
max_voxels int

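Maximum number of voxels to render. Larger volumes are downsampled before visualization, and a warning is logged. Defaults to 512^3.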

512 ** 3
data_type str

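Data type the volume is converted to before rendering. 'scaled_float16' converts it to float16 via an internal scaling step; any other value is passed to NumPy's astype. Ignored when save is set, in which case float64 is used. Defaults to 'scaled_float16'.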

'scaled_float16'
**kwargs Any

Additional keyword arguments to be passed to the k3d.plot function.

{}
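The `max_voxels` limit means that large volumes are downsampled before rendering. The internal `downscale_img` helper is not shown on this page, so the sketch below is only one plausible approach and not qim3d's implementation: pick the smallest integer stride that brings the voxel count under the limit, then subsample with slicing. The helper name `downsample_to_budget` is hypothetical.

```python
import math

import numpy as np


def downsample_to_budget(img: np.ndarray, max_voxels: int = 512**3) -> np.ndarray:
    """Subsample a 3D array with one integer stride until it has at most max_voxels voxels.

    Hypothetical helper for illustration; qim3d's downscale_img may work differently.
    """
    if max_voxels < 1:
        raise ValueError("max_voxels must be at least 1")
    step = 1
    # Slicing with [::step] keeps ceil(n / step) elements along an axis of length n
    while math.prod(-(-n // step) for n in img.shape) > max_voxels:
        step += 1
    return img[::step, ::step, ::step]


vol = np.zeros((300, 300, 300), dtype=np.uint8)
print(downsample_to_budget(vol, max_voxels=128**3).shape)  # (100, 100, 100)
```

A uniform stride keeps the aspect ratio of the volume, at the cost of simply dropping voxels rather than averaging them.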

Returns:

Name Type Description
plot plot

If show=False, returns the K3D plot object.

Raises:

Type Description
ValueError

If aspectmode is not 'data' or 'cube'.

Tip

The function can be used for object label visualization using a color_map created with qim3d.viz.colormaps.objects along with setting constant_opacity=True. The latter ensures appropriate rendering.

Example

Display a volume inline:

import qim3d

vol = qim3d.examples.bone_128x128x128
qim3d.viz.volumetric(vol)

Save a plot to an HTML file:

import qim3d
vol = qim3d.examples.bone_128x128x128
plot = qim3d.viz.volumetric(vol, show=False, save="plot.html")
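When setting `vmin` or `vmax`, a colour range has to treat 0 as a real bound: a truthiness test such as `if vmin:` skips `vmin=0`, while an explicit `is not None` check keeps it. A minimal sketch of the explicit version follows; the helper name `color_range` is hypothetical.

```python
from typing import Optional

import numpy as np


def color_range(
    img: np.ndarray, vmin: Optional[float] = None, vmax: Optional[float] = None
) -> list[float]:
    """Colormap range: the data min/max unless vmin/vmax are explicitly given."""
    lo = float(np.min(img)) if vmin is None else float(vmin)
    hi = float(np.max(img)) if vmax is None else float(vmax)
    if lo > hi:
        raise ValueError("vmin must not exceed vmax")
    return [lo, hi]


img = np.array([[-5.0, 10.0], [3.0, 7.0]])
print(color_range(img))          # [-5.0, 10.0]
print(color_range(img, vmin=0))  # [0.0, 10.0]; an `if vmin:` check would have ignored 0
```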
Source code in qim3d/viz/_k3d.py
def volumetric(
    img: np.ndarray,
    aspectmode: str = 'data',
    show: bool = True,
    save: bool = False,
    grid_visible: bool = False,
    color_map: str = 'magma',
    constant_opacity: bool = False,
    vmin: float | None = None,
    vmax: float | None = None,
    samples: int | str = 'auto',
    max_voxels: int = 512**3,
    data_type: str = 'scaled_float16',
    **kwargs,
):
    """
    Visualizes a 3D volume using volumetric rendering.

    Args:
        img (numpy.ndarray): The input 3D image data. It should be a 3D numpy array.
        aspectmode (str, optional): Determines the proportions of the scene's axes. Defaults to "data".
            If `'data'`, the axes are drawn in proportion with the axes' ranges.
            If `'cube'`, the axes are drawn as a cube, regardless of the axes' ranges.
        show (bool, optional): If True, displays the visualization inline. Defaults to True.
        save (bool or str, optional): If True, saves the visualization as an HTML file.
            If a string is provided, it's interpreted as the file path where the HTML
            file will be saved. Defaults to False.
        grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
        color_map (str or matplotlib.colors.Colormap or list, optional): The color map to be used for the volume rendering. If a string is passed, it should be a matplotlib colormap name. Defaults to 'magma'.
        constant_opacity (bool): Set to True if doing an object label visualization with a corresponding color_map; otherwise, the plot may render poorly. Defaults to False.
        vmin (float or None, optional): Together with vmax defines the data range the colormap covers. By default colormap covers the full range. Defaults to None.
        vmax (float or None, optional): Together with vmin defines the data range the colormap covers. By default colormap covers the full range. Defaults to None.
        samples (int or 'auto', optional): The number of samples to be used for the volume rendering in k3d. Input 'auto' for auto selection. Defaults to 'auto'.
            Lower values will render faster but with lower quality.
        max_voxels (int, optional): Maximum number of voxels to render; larger volumes are downsampled
            before visualization. Defaults to 512^3.
        data_type (str, optional): Data type the volume is converted to before rendering. `'scaled_float16'`
            converts it to float16 via an internal scaling step; any other value is passed to `numpy.ndarray.astype`.
            Ignored when `save` is set (float64 is used). Defaults to 'scaled_float16'.
        **kwargs (Any): Additional keyword arguments to be passed to the `k3d.plot` function.

    Returns:
        plot (k3d.plot): If `show=False`, returns the K3D plot object.

    Raises:
        ValueError: If `aspectmode` is not `'data'` or `'cube'`.

    Tip:
        The function can be used for object label visualization using a `color_map` created with `qim3d.viz.colormaps.objects` along with setting `constant_opacity=True`. The latter ensures appropriate rendering.

    Example:
        Display a volume inline:

        ```python
        import qim3d

        vol = qim3d.examples.bone_128x128x128
        qim3d.viz.volumetric(vol)
        ```
        <iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe>

        Save a plot to an HTML file:

        ```python
        import qim3d
        vol = qim3d.examples.bone_128x128x128
        plot = qim3d.viz.volumetric(vol, show=False, save="plot.html")
        ```

    """

    pixel_count = img.shape[0] * img.shape[1] * img.shape[2]
    # target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html
    if samples == 'auto':
        y1, x1 = 256, 16777216  # 256 samples at res 256*256*256=16.777.216
        y2, x2 = 32, 134217728  # 32 samples at res 512*512*512=134.217.728

        # we fit linear function to the two points
        a = (y1 - y2) / (x1 - x2)
        b = y1 - a * x1

        samples = int(min(max(a * pixel_count + b, 64), 512))
    else:
        samples = int(samples)  # make sure it's an integer

    if aspectmode.lower() not in ['data', 'cube']:
        raise ValueError("aspectmode should be either 'data' or 'cube'")
    # check if image should be downsampled for visualization
    original_shape = img.shape
    img = downscale_img(img, max_voxels=max_voxels)

    new_shape = img.shape

    if original_shape != new_shape:
        log.warning(
            f'Downsampled image for visualization, from {original_shape} to {new_shape}'
        )

    # Scale the image to float16 if needed
    if save:
        # When saving, we need float64
        img = img.astype(np.float64)
    else:
        if data_type == 'scaled_float16':
            img = scale_to_float16(img)
        else:
            img = img.astype(data_type)

    # Set color ranges
    color_range = [np.min(img), np.max(img)]
    if vmin is not None:
        color_range[0] = vmin
    if vmax is not None:
        color_range[1] = vmax

    # Handle the different formats that color_map can take
    if color_map:
        if isinstance(color_map, str):
            color_map = plt.get_cmap(color_map)  # Convert to Colormap object
        if isinstance(color_map, Colormap):
            # Convert to the format of color_map required by k3d.volume
            attr_vals = np.linspace(0.0, 1.0, num=color_map.N)
            RGB_vals = color_map(np.arange(0, color_map.N))[:, :3]
            color_map = np.column_stack((attr_vals, RGB_vals)).tolist()

    # Default k3d.volume settings
    opacity_function = []
    interpolation = True
    if constant_opacity:
        # without these settings, the plot will look bad when color_map is created with qim3d.viz.colormaps.objects
        opacity_function = [0.0, float(constant_opacity), 1.0, float(constant_opacity)]
        interpolation = False

    # Create the volume plot
    plt_volume = k3d.volume(
        img,
        bounds=(
            [0, img.shape[2], 0, img.shape[1], 0, img.shape[0]]
            if aspectmode.lower() == 'data'
            else None
        ),
        color_map=color_map,
        samples=samples,
        color_range=color_range,
        opacity_function=opacity_function,
        interpolation=interpolation,
    )
    plot = k3d.plot(grid_visible=grid_visible, **kwargs)
    plot += plt_volume
    if save:
        # Save html to disk
        with open(str(save), 'w', encoding='utf-8') as fp:
            fp.write(plot.get_snapshot())

    if show:
        plot.display()
    else:
        return plot
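With `samples='auto'`, the source fits a straight line through two reference points (256 samples at 256³ voxels, 32 samples at 512³ voxels) and clamps the result to the range [64, 512]. Because of the lower clamp, volumes of 512³ voxels or more get 64 samples rather than the 32 the line would give. The same arithmetic as a standalone function (the name `auto_samples` is illustrative):

```python
def auto_samples(voxel_count: int) -> int:
    """Rendering sample count chosen by volumetric() when samples='auto'."""
    y1, x1 = 256, 256**3  # 256 samples at 256^3 voxels
    y2, x2 = 32, 512**3   # 32 samples at 512^3 voxels
    a = (y1 - y2) / (x1 - x2)  # slope of the line through both points
    b = y1 - a * x1            # intercept
    return int(min(max(a * voxel_count + b, 64), 512))


for n in (128, 256, 512):
    print(f"{n}^3 voxels -> {auto_samples(n**3)} samples")  # 284, 256 and 64 samples
```

Since the line's intercept is 288 samples, the upper bound of 512 is never reached with these reference points.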

qim3d.viz.mesh

mesh(mesh, backend='pygel3d', wireframe=True, flat_shading=True, grid_visible=False, show=True, save=False, **kwargs)

Visualize a 3D mesh using pygel3d or k3d.

Parameters:

Name Type Description Default
mesh Manifold

The input mesh object.

required
backend str

The visualization backend to use. Choose between pygel3d (default) and k3d.

'pygel3d'
wireframe bool

If True, displays the mesh as a wireframe. Works both with pygel3d and k3d. Defaults to True.

True
flat_shading bool

If True, applies flat shading to the mesh. Works only with k3d. Defaults to True.

True
grid_visible bool

If True, shows a grid in the visualization. Works only with k3d. Defaults to False.

False
show bool

If True, displays the visualization inline. Works only with k3d. Defaults to True.

True
save bool or str

If True, saves the visualization as an HTML file. If a string is provided, it's interpreted as the file path where the HTML file will be saved. Works only with k3d. Defaults to False.

False
**kwargs Any

Additional keyword arguments specific to the chosen backend:

  • k3d.plot kwargs: Arguments that customize the k3d.plot visualization.

  • pygel3d.display kwargs: Arguments that customize the pygel3d.display visualization.

{}

Returns:

Type Description
Optional[Plot]

k3d.Plot or None:

  • If backend="k3d", returns a k3d.Plot object.
  • If backend="pygel3d", the function displays the mesh but does not return a plot object.

Raises:

Type Description
ValueError

If backend is not pygel3d or k3d.

Example
import qim3d
synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015)
mesh = qim3d.mesh.from_volume(synthetic_blob)
qim3d.viz.mesh(mesh, backend="pygel3d") # or qim3d.viz.mesh(mesh, backend="k3d")


Source code in qim3d/viz/_k3d.py
def mesh(
    mesh,
    backend: str = 'pygel3d',
    wireframe: bool = True,
    flat_shading: bool = True,
    grid_visible: bool = False,
    show: bool = True,
    save: bool = False,
    **kwargs,
) -> Optional[k3d.Plot]:
    """
    Visualize a 3D mesh using `pygel3d` or `k3d`.

    Args:
        mesh (pygel3d.hmesh.Manifold): The input mesh object.
        backend (str, optional): The visualization backend to use.
            Choose between `pygel3d` (default) and `k3d`.
        wireframe (bool, optional): If True, displays the mesh as a wireframe.
            Works both with `pygel3d` and `k3d`. Defaults to True.
        flat_shading (bool, optional): If True, applies flat shading to the mesh.
            Works only with `k3d`. Defaults to True.
        grid_visible (bool, optional): If True, shows a grid in the visualization.
            Works only with `k3d`. Defaults to False.
        show (bool, optional): If True, displays the visualization inline.
            Works only with `k3d`. Defaults to True.
        save (bool or str, optional): If True, saves the visualization as an HTML file.
            If a string is provided, it's interpreted as the file path where the HTML
            file will be saved. Works only with `k3d`. Defaults to False.
        **kwargs (Any): Additional keyword arguments specific to the chosen backend:

            - `k3d.plot` kwargs: Arguments that customize the [`k3d.plot`](https://k3d-jupyter.org/reference/factory.plot.html) visualization.

            - `pygel3d.display` kwargs: Arguments that customize the [`pygel3d.display`](https://www2.compute.dtu.dk/projects/GEL/PyGEL/pygel3d/jupyter_display.html#display) visualization.

    Returns:
        k3d.Plot or None:

            - If `backend="k3d"`, returns a `k3d.Plot` object.
            - If `backend="pygel3d"`, the function displays the mesh but does not return a plot object.

    Raises:
        ValueError: If `backend` is not `pygel3d` or `k3d`.

    Example:
        ```python
        import qim3d
        synthetic_blob = qim3d.generate.noise_object(noise_scale = 0.015)
        mesh = qim3d.mesh.from_volume(synthetic_blob)
        qim3d.viz.mesh(mesh, backend="pygel3d") # or qim3d.viz.mesh(mesh, backend="k3d")
        ```
    ![pygel3d_visualization](../../assets/screenshots/pygel3d_visualization.png)

    """

    if backend not in ['k3d', 'pygel3d']:
        raise ValueError("Invalid backend. Choose 'pygel3d' or 'k3d'.")

    # Extract vertex positions and face indices
    face_indices = list(mesh.faces())
    vertices_array = np.array(mesh.positions())

    # Extract face vertex indices
    face_vertices = [
        list(mesh.circulate_face(int(fid), mode='v'))[:3] for fid in face_indices
    ]
    face_vertices = np.array(face_vertices, dtype=np.uint32)

    # Validate the mesh structure
    if vertices_array.shape[1] != 3 or face_vertices.shape[1] != 3:
        raise ValueError('Vertices must have shape (N, 3) and faces (M, 3)')

    # Separate valid kwargs for each backend
    valid_k3d_kwargs = {k: v for k, v in kwargs.items() if k not in ['smooth', 'data']}
    valid_pygel_kwargs = {k: v for k, v in kwargs.items() if k in ['smooth', 'data']}

    if backend == 'k3d':
        vertices_array = np.ascontiguousarray(vertices_array.astype(np.float32))
        face_vertices = np.ascontiguousarray(face_vertices)

        mesh_plot = k3d.mesh(
            vertices=vertices_array,
            indices=face_vertices,
            wireframe=wireframe,
            flat_shading=flat_shading,
        )

        # Create plot
        plot = k3d.plot(grid_visible=grid_visible, **valid_k3d_kwargs)
        plot += mesh_plot

        if save:
            # Save html to disk
            with open(str(save), 'w', encoding='utf-8') as fp:
                fp.write(plot.get_snapshot())

        if show:
            plot.display()
        else:
            return plot

    elif backend == 'pygel3d':
        jd.set_export_mode(True)
        return jd.display(mesh, wireframe=wireframe, **valid_pygel_kwargs)
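For the k3d backend, `mesh` turns the pygel3d manifold into two arrays: an (N, 3) float32 array of vertex positions and an (M, 3) uint32 array of triangle indices, both C-contiguous. It also routes the `smooth` and `data` keyword arguments to pygel3d and everything else to `k3d.plot`. The sketch below repeats those steps on a hand-written tetrahedron in plain NumPy, so it runs without pygel3d or k3d; `prepare_k3d_arrays` and `split_backend_kwargs` are illustrative names, not qim3d functions.

```python
import numpy as np

PYGEL_KEYS = ("smooth", "data")  # keyword arguments the source forwards to pygel3d


def prepare_k3d_arrays(vertices, faces) -> tuple[np.ndarray, np.ndarray]:
    """Validate and convert mesh arrays to the dtypes used for k3d.mesh in the source."""
    v = np.ascontiguousarray(np.asarray(vertices, dtype=np.float32))
    f = np.ascontiguousarray(np.asarray(faces, dtype=np.uint32))
    if v.ndim != 2 or v.shape[1] != 3 or f.ndim != 2 or f.shape[1] != 3:
        raise ValueError("Vertices must have shape (N, 3) and faces (M, 3)")
    return v, f


def split_backend_kwargs(kwargs: dict) -> tuple[dict, dict]:
    """Split kwargs into (k3d.plot kwargs, pygel3d.display kwargs)."""
    k3d_kwargs = {k: v for k, v in kwargs.items() if k not in PYGEL_KEYS}
    pygel_kwargs = {k: v for k, v in kwargs.items() if k in PYGEL_KEYS}
    return k3d_kwargs, pygel_kwargs


# A tetrahedron: 4 vertices and 4 triangular faces
verts = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
faces = [[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]]
v, f = prepare_k3d_arrays(verts, faces)
print(v.dtype, v.shape, f.dtype, f.shape)  # float32 (4, 3) uint32 (4, 3)
print(split_backend_kwargs({"height": 400, "smooth": True}))  # ({'height': 400}, {'smooth': True})
```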

qim3d.viz.local_thickness

local_thickness(image, image_lt, max_projection=False, axis=0, slice_idx=None, show=False, figsize=(15, 5))

Visualizes the local thickness of a 2D or 3D image.

Parameters:

Name Type Description Default
image ndarray

2D or 3D NumPy array representing the image/volume.

required
image_lt ndarray

2D or 3D NumPy array representing the local thickness of the input image/volume.

required
max_projection bool

If True, displays the maximum projection of the local thickness. Only used for 3D images. Defaults to False.

False
axis int

The axis along which to visualize the local thickness. Unused for 2D images. Defaults to 0.

0
slice_idx int or float

The initial slice to be visualized. The slice index can afterwards be changed. If value is an integer, it will be the index of the slice to be visualized. If value is a float between 0 and 1, it will be multiplied by the number of slices and rounded to the nearest integer. If None, the middle slice will be used for 3D images. Unused for 2D images. Defaults to None.

None
show bool

If True, displays the plot (i.e. calls plt.show()). Defaults to False.

False
figsize tuple

The size of the figure. Defaults to (15, 5).

(15, 5)

Raises:

Type Description
ValueError

If the slice index is not an integer or a float between 0 and 1.

Returns:

Name Type Description
local_thickness interactive or Figure

If the input is 3D, returns an interactive widget. Otherwise, returns a matplotlib figure.
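The slice_idx rules in the parameter table above can be sketched as a small stand-alone helper: None selects the middle slice, a float in [0, 1] is a fraction of the number of slices, and an integer is used as a direct index. `resolve_slice_idx` is a hypothetical name and is not part of qim3d. The sketch assumes the fraction is taken along `axis` and that 1.0 should map to the last slice rather than one past it.

```python
def resolve_slice_idx(shape, axis=0, slice_idx=None):
    """Hypothetical helper: turn a slice_idx argument into a valid integer index."""
    n_slices = shape[axis]
    if slice_idx is None:
        # Default: the middle slice along `axis`
        return n_slices // 2
    if isinstance(slice_idx, float):
        if not 0.0 <= slice_idx <= 1.0:
            raise ValueError('Values of slice_idx of float type must be between 0 and 1.')
        # Fraction of the slices along `axis`, rounded; clamp so 1.0 gives the last slice
        return min(int(round(slice_idx * n_slices)), n_slices - 1)
    # Integers are used directly as the slice index
    return int(slice_idx)
```

For the 150 x 256 x 256 fly volume used in the example, `slice_idx=1.0` resolves to 149. With `slice_idx=0.5` along `axis=1`, the result is 128.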

Example

import qim3d

fly = qim3d.examples.fly_150x256x256
lt_fly = qim3d.processing.local_thickness(fly)
qim3d.viz.local_thickness(fly, lt_fly, axis=0)

Source code in qim3d/viz/_local_thickness.py
def local_thickness(
    image: np.ndarray,
    image_lt: np.ndarray,
    max_projection: bool = False,
    axis: int = 0,
    slice_idx: Optional[Union[int, float]] = None,
    show: bool = False,
    figsize: Tuple[int, int] = (15, 5),
) -> Union[plt.Figure, widgets.interactive]:
    """
    Visualizes the local thickness of a 2D or 3D image.

    Args:
        image (np.ndarray): 2D or 3D NumPy array representing the image/volume.
        image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input
            image/volume.
        max_projection (bool, optional): If True, displays the maximum projection of the local
            thickness. Only used for 3D images. Defaults to False.
        axis (int, optional): The axis along which to visualize the local thickness.
            Unused for 2D images.
            Defaults to 0.
        slice_idx (int or float, optional): The initial slice to be visualized. The slice index
            can afterwards be changed. If value is an integer, it will be the index of the slice
            to be visualized. If value is a float between 0 and 1, it will be multiplied by the
            number of slices and rounded to the nearest integer. If None, the middle slice will
            be used for 3D images. Unused for 2D images. Defaults to None.
        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
        figsize (tuple, optional): The size of the figure. Defaults to (15, 5).

    Raises:
        ValueError: If the slice index is not an integer or a float between 0 and 1.

    Returns:
        local_thickness (widgets.interactive or plt.Figure): If the input is 3D, returns an interactive widget. Otherwise, returns a matplotlib figure.

    Example:
        ```python
        import qim3d

        fly = qim3d.examples.fly_150x256x256
        lt_fly = qim3d.processing.local_thickness(fly)
        qim3d.viz.local_thickness(fly, lt_fly, axis=0)
        ```
        ![local thickness 3d](../../assets/screenshots/local_thickness_3d.gif)


    """

    def _local_thickness(image, image_lt, show, figsize, axis=None, slice_idx=None):
        if slice_idx is not None:
            image = image.take(slice_idx, axis=axis)
            image_lt = image_lt.take(slice_idx, axis=axis)

        fig, axs = plt.subplots(1, 3, figsize=figsize, layout='constrained')

        axs[0].imshow(image, cmap='gray')
        axs[0].set_title('Original image')
        axs[0].axis('off')

        # Draw the thickness map once and attach the colorbar to that image
        im_lt = axs[1].imshow(image_lt, cmap='viridis')
        axs[1].set_title('Local thickness')
        axs[1].axis('off')

        plt.colorbar(im_lt, ax=axs[1], orientation='vertical')

        axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor='black')
        axs[2].set_title('Local thickness histogram')
        axs[2].set_xlabel('Local thickness')
        axs[2].set_ylabel('Count')

        if show:
            plt.show()

        plt.close()

        return fig

    # 3D input: show either a maximum projection or an interactive slice viewer
    if len(image.shape) == 3:
        if max_projection:
            if slice_idx is not None:
                log.warning(
                    'slice_idx is not used for max_projection. It will be ignored.'
                )
            image = image.max(axis=axis)
            image_lt = image_lt.max(axis=axis)
            return _local_thickness(image, image_lt, show, figsize)
        else:
            if slice_idx is None:
                slice_idx = image.shape[axis] // 2
            elif isinstance(slice_idx, float):
                if slice_idx < 0 or slice_idx > 1:
                    raise ValueError(
                        'Values of slice_idx of float type must be between 0 and 1.'
                    )
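                # Scale by the number of slices along `axis`; clamp so that 1.0 maps to the last slice
                slice_idx = min(int(round(slice_idx * image.shape[axis])), image.shape[axis] - 1)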
            slide_idx_slider = widgets.IntSlider(
                min=0,
                max=image.shape[axis] - 1,
                step=1,
                value=slice_idx,
                description='Slice index',
                layout=widgets.Layout(width='450px'),
            )
            widget_obj = widgets.interactive(
                _local_thickness,
                image=widgets.fixed(image),
                image_lt=widgets.fixed(image_lt),
                show=widgets.fixed(True),
                figsize=widgets.fixed(figsize),
                axis=widgets.fixed(axis),
                slice_idx=slide_idx_slider,
            )
            widget_obj.layout = widgets.Layout(align_items='center')
            if show:
                display(widget_obj)
            return widget_obj
    else:
        if max_projection:
            log.warning(
                'max_projection is only used for 3D images. It will be ignored.'
            )
        if slice_idx is not None:
            log.warning('slice_idx is only used for 3D images. It will be ignored.')
        return _local_thickness(image, image_lt, show, figsize)

qim3d.viz.vectors

vectors(volume, vec, axis=0, volume_cmap='grey', vmin=None, vmax=None, slice_idx=None, grid_size=10, interactive=True, figsize=(10, 5), show=False)

Visualizes the orientation of the structures in a 3D volume using the eigenvectors of the structure tensor.

Parameters:

Name Type Description Default
volume ndarray

The 3D volume to be sliced.

required
vec ndarray

The eigenvectors of the structure tensor.

required
axis int

The axis along which to visualize the orientation. Defaults to 0.

0
volume_cmap str

Colormap used to display the volume slice. Defaults to 'grey'.

'grey'
vmin float

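Together with vmax, defines the data range that the colormap covers. The limits apply to the displayed values: if the volume's maximum exceeds 1, the volume is first divided by 255. By default, the colormap covers the full range. Defaults to None.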

None
vmax float

Together with vmin, defines the data range that the colormap covers (see vmin for how the volume is rescaled first). By default, the colormap covers the full range. Defaults to None.

None
slice_idx int or float or None

The initial slice to be visualized. The slice index can afterwards be changed. If value is an integer, it will be the index of the slice to be visualized. If value is a float between 0 and 1, it will be multiplied by the number of slices and rounded to the nearest integer. If None, the middle slice will be used. Defaults to None.

None
grid_size int

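Spacing, in pixels, between the orientation vectors drawn on the slice. Values outside the allowed range, which is derived from the volume's size along axis, are clamped with a warning. Defaults to 10.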

10
interactive bool

If True, returns an interactive widget. Defaults to True.

True
figsize tuple

The size of the figure. Defaults to (10, 5).

(10, 5)
show bool

If True, displays the plot (i.e. calls plt.show()). Defaults to False.

False
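The source code of `vectors` further down this page applies two rules to grid_size. First, it bounds grid_size by limits computed from the volume's size along `axis`. Second, it draws one vector every grid_size pixels, starting half a cell in from the edge. The sketch below restates both rules in plain Python so their effect is easy to check. `clamp_grid_size` and `grid_positions` are hypothetical names, and the warning that qim3d logs is omitted.

```python
def clamp_grid_size(n_slices, grid_size=10):
    """Hypothetical restatement of the grid_size bounds used in qim3d.viz.vectors."""
    min_grid_size = max(1, n_slices // 50)
    max_grid_size = max(1, n_slices // 10)
    if max_grid_size <= min_grid_size:
        # Very small volumes: widen the allowed range
        max_grid_size = min_grid_size * 5
    if not grid_size:
        # 0 or None falls back to the middle of the range
        grid_size = (min_grid_size + max_grid_size) // 2
    # Move out-of-range values to the nearest bound
    return min(max(min_grid_size, grid_size), max_grid_size)


def grid_positions(length, grid_size):
    """Pixel indices along one image dimension at which a vector is drawn."""
    # Equivalent to indexing with slice(grid_size // 2, None, grid_size)
    return list(range(grid_size // 2, length, grid_size))
```

For the 128 x 128 x 128 example volume, the allowed range is 2 to 12. The default of 10 is therefore kept, while grid_size=50 would be reduced to 12.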

Raises:

Type Description
ValueError

If the axis to slice along is not 0, 1, or 2.

ValueError

If the slice index is not an integer or a float between 0 and 1.

Returns:

Name Type Description
fig interactive or Figure

If interactive is True, returns an interactive widget. Otherwise, returns a matplotlib figure.

Note

The orientation of the vectors is visualized using an HSV color map. The hue encodes the in-plane angle of the vector in [0, π), so opposite vectors get the same color. The squared vector component along the slicing direction (i.e. the z-component when choosing visualization along axis = 0) blends the color toward gray. Hence, orientations parallel to the slicing direction (perpendicular to the displayed slice) appear gray, while orientations lying in the slice plane are shown in full color.
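The color encoding can be reproduced per pixel with NumPy. The sketch follows the blending used in the source code below, `hue * (1 - sat) + 0.5 * sat`, where `sat` is the squared component along the slicing direction. As an assumption, it uses Python's `colorsys` HSV conversion in place of Matplotlib's `hsv` colormap and returns RGB only. `orientation_colors` is a hypothetical name.

```python
import colorsys

import numpy as np


def orientation_colors(vx, vy, v_normal):
    """Hypothetical helper: RGB color per pixel from eigenvector components.

    vx, vy are the in-plane components and v_normal is the component along the
    slicing direction (arrays of the same shape, unit-length vectors assumed).
    """
    # In-plane angle folded into [0, pi): v and -v get the same hue
    hue = np.mod(np.arctan2(vy, vx), np.pi) / np.pi
    # Fully saturated HSV colors (stand-in for matplotlib's 'hsv' colormap)
    rgb = np.array(
        [colorsys.hsv_to_rgb(h, 1.0, 1.0) for h in hue.ravel()]
    ).reshape(hue.shape + (3,))
    # Squared out-of-plane component blends the color toward mid-gray
    weight = (np.asarray(v_normal) ** 2)[..., np.newaxis]
    return rgb * (1 - weight) + 0.5 * weight
```

A vector lying in the slice plane along x comes out pure red, and its negation gets the same red. A vector pointing straight along the slicing direction comes out as (0.5, 0.5, 0.5) gray.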

Example

import qim3d

vol = qim3d.examples.NT_128x128x128
val, vec = qim3d.processing.structure_tensor(vol)

# Visualize the structure tensor
qim3d.viz.vectors(vol, vec, axis = 2, interactive = True)

Source code in qim3d/viz/_structure_tensor.py
def vectors(
    volume: np.ndarray,
    vec: np.ndarray,
    axis: int = 0,
    volume_cmap: str = 'grey',
    vmin: float | None = None,
    vmax: float | None = None,
    slice_idx: Union[int, float] | None = None,
    grid_size: int = 10,
    interactive: bool = True,
    figsize: Tuple[int, int] = (10, 5),
    show: bool = False,
) -> Union[plt.Figure, widgets.interactive]:
    """
    Visualizes the orientation of the structures in a 3D volume using the eigenvectors of the structure tensor.

    Args:
        volume (np.ndarray): The 3D volume to be sliced.
        vec (np.ndarray): The eigenvectors of the structure tensor.
        axis (int, optional): The axis along which to visualize the orientation. Defaults to 0.
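        volume_cmap (str, optional): Colormap used to display the volume slice. Defaults to 'grey'.
        vmin (float, optional): Together with vmax, defines the data range that the colormap covers. The limits apply to the displayed values: if the volume's maximum exceeds 1, the volume is first divided by 255. By default, the colormap covers the full range. Defaults to None.
        vmax (float, optional): Together with vmin, defines the data range that the colormap covers (see vmin for how the volume is rescaled first). By default, the colormap covers the full range. Defaults to None.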
        slice_idx (int or float or None, optional): The initial slice to be visualized. The slice index
            can afterwards be changed. If value is an integer, it will be the index of the slice
            to be visualized. If value is a float between 0 and 1, it will be multiplied by the
            number of slices and rounded to the nearest integer. If None, the middle slice will
            be used. Defaults to None.
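        grid_size (int, optional): Spacing, in pixels, between the orientation vectors drawn on the slice. Values outside the allowed range, which is derived from the volume's size along `axis`, are clamped with a warning. Defaults to 10.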
        interactive (bool, optional): If True, returns an interactive widget. Defaults to True.
        figsize (tuple, optional): The size of the figure. Defaults to (10, 5).
        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.

    Raises:
        ValueError: If the axis to slice along is not 0, 1, or 2.
        ValueError: If the slice index is not an integer or a float between 0 and 1.

    Returns:
        fig (widgets.interactive or plt.Figure): If `interactive` is True, returns an interactive widget. Otherwise, returns a matplotlib figure.

    Note:
        The orientation of the vectors is visualized using an HSV color map. The hue encodes the in-plane angle
        of the vector in [0, pi), so opposite vectors get the same color. The squared vector component along the
        slicing direction (i.e. the z-component when choosing visualization along `axis = 0`) blends the color
        toward gray. Hence, orientations parallel to the slicing direction (perpendicular to the displayed slice)
        appear gray, while orientations lying in the slice plane are shown in full color.

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.NT_128x128x128
        val, vec = qim3d.processing.structure_tensor(vol)

        # Visualize the structure tensor
        qim3d.viz.vectors(vol, vec, axis = 2, interactive = True)
        ```
        ![structure tensor](../../assets/screenshots/structure_tensor_visualization.gif)

    """

    # Ensure volume is a float
    if volume.dtype != np.float32 and volume.dtype != np.float64:
        volume = volume.astype(np.float32)

    # Rescale to [0, 1] if needed, assuming 8-bit intensities in [0, 255]
    if volume.max() > 1.0:
        volume = volume / 255.0

    # Define grid size limits
    min_grid_size = max(1, volume.shape[axis] // 50)
    max_grid_size = max(1, volume.shape[axis] // 10)
    if max_grid_size <= min_grid_size:
        max_grid_size = min_grid_size * 5

    if not grid_size:
        grid_size = (min_grid_size + max_grid_size) // 2

    # Clamp grid_size to the allowed range
    if grid_size < min_grid_size or grid_size > max_grid_size:
        # Adjust grid size as little as possible to be within the limits
        grid_size = min(max(min_grid_size, grid_size), max_grid_size)
        log.warning(f'Adjusting grid size to {grid_size} as it is out of bounds.')

    def _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show):
        # Choose the appropriate slice based on the specified dimension
        if axis == 0:
            data_slice = volume[slice_idx, :, :]
            vectors_slice_x = vec[0, slice_idx, :, :]
            vectors_slice_y = vec[1, slice_idx, :, :]
            vectors_slice_z = vec[2, slice_idx, :, :]

        elif axis == 1:
            data_slice = volume[:, slice_idx, :]
            vectors_slice_x = vec[0, :, slice_idx, :]
            vectors_slice_y = vec[2, :, slice_idx, :]
            vectors_slice_z = vec[1, :, slice_idx, :]

        elif axis == 2:
            data_slice = volume[:, :, slice_idx]
            vectors_slice_x = vec[1, :, :, slice_idx]
            vectors_slice_y = vec[2, :, :, slice_idx]
            vectors_slice_z = vec[0, :, :, slice_idx]

        else:
            raise ValueError('Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.')

        # Create three subplots
        fig, ax = plt.subplots(1, 3, figsize=figsize, layout='constrained')

        blend_hue_saturation = (
            lambda hue, sat: hue * (1 - sat) + 0.5 * sat
        )  # Function for blending hue and saturation
        blend_slice_colors = lambda slice, colors: 0.5 * (
            slice + colors
        )  # Function for blending image slice with orientation colors

        # ----- Subplot 1: Image slice with orientation vectors ----- #
        # Create meshgrid with the correct dimensions
        xmesh, ymesh = np.mgrid[0 : data_slice.shape[0], 0 : data_slice.shape[1]]

        # Create a slice object for selecting the grid points
        g = slice(grid_size // 2, None, grid_size)

        # Angles from 0 to pi
        angles_quiver = np.mod(
            np.arctan2(vectors_slice_y[g, g], vectors_slice_x[g, g]), np.pi
        )

        # Calculate z-component (saturation)
        saturation_quiver = (vectors_slice_z[g, g] ** 2)[:, :, np.newaxis]

        # Calculate hue
        hue_quiver = plt.cm.hsv(angles_quiver / np.pi)

        # Blend hue and saturation
        rgba_quiver = blend_hue_saturation(hue_quiver, saturation_quiver)
        rgba_quiver = np.clip(
            rgba_quiver, 0, 1
        )  # Ensure rgba values are values within [0, 1]
        rgba_quiver_flat = rgba_quiver.reshape(
            (rgba_quiver.shape[0] * rgba_quiver.shape[1], 4)
        )  # Flatten array for quiver plot

        # Plot vectors
        ax[0].quiver(
            ymesh[g, g],
            xmesh[g, g],
            vectors_slice_x[g, g],
            vectors_slice_y[g, g],
            color=rgba_quiver_flat,
            angles='xy',
        )
        ax[0].quiver(
            ymesh[g, g],
            xmesh[g, g],
            -vectors_slice_x[g, g],
            -vectors_slice_y[g, g],
            color=rgba_quiver_flat,
            angles='xy',
        )

        ax[0].imshow(data_slice, cmap=volume_cmap, vmin=vmin, vmax=vmax)
        ax[0].set_title(
            f'Orientation vectors (slice {slice_idx})'
            if not interactive
            else 'Orientation vectors'
        )
        ax[0].set_axis_off()

        # ----- Subplot 2: Orientation histogram ----- #
        nbins = 36

        # Angles from 0 to pi
        angles = np.mod(np.arctan2(vectors_slice_y, vectors_slice_x), np.pi)

        # Orientation histogram over angles
        distribution, bin_edges = np.histogram(angles, bins=nbins, range=(0.0, np.pi))

        # Half circle (180 deg)
        bin_centers = (np.arange(nbins) + 0.5) * np.pi / nbins

        # Calculate z-component (saturation) for each bin
        bins = np.digitize(angles.ravel(), bin_edges)
        saturation_bin = np.array(
            [
                (
                    np.mean((vectors_slice_z**2).ravel()[bins == i])
                    if np.sum(bins == i) > 0
                    else 0
                )
                for i in range(1, len(bin_edges))
            ]
        )

        # Calculate hue for each bin
        hue_bin = plt.cm.hsv(bin_centers / np.pi)

        # Blend hue and saturation
        rgba_bin = hue_bin.copy()
        rgba_bin[:, :3] = blend_hue_saturation(
            hue_bin[:, :3], saturation_bin[:, np.newaxis]
        )

        ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=rgba_bin)
        ax[1].set_xlabel('Angle [radians]')
        ax[1].set_xlim([0, np.pi])
        ax[1].set_aspect(np.pi / ax[1].get_ylim()[1])
        ax[1].set_xticks([0, np.pi / 2, np.pi])
        ax[1].set_xticklabels(['0', '$\\frac{\\pi}{2}$', '$\\pi$'])
        ax[1].set_yticks([])
        ax[1].set_ylabel('Frequency')
        ax[1].set_title('Histogram over orientation angles')

        # ----- Subplot 3: Image slice colored according to orientation ----- #
        # Calculate z-component (saturation)
        saturation = (vectors_slice_z**2)[:, :, np.newaxis]

        # Calculate hue
        hue = plt.cm.hsv(angles / np.pi)

        # Blend hue and saturation
        rgba = blend_hue_saturation(hue, saturation)

        # Grayscale image slice blended with orientation colors
        data_slice_orientation_colored = (
            blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255
        ).astype('uint8')

        ax[2].imshow(data_slice_orientation_colored)
        ax[2].set_title(
            f'Colored orientations (slice {slice_idx})'
            if not interactive
            else 'Colored orientations'
        )
        ax[2].set_axis_off()

        if show:
            plt.show()

        plt.close()

        return fig

    if vec.ndim == 5:
        vec = vec[0, ...]
        log.warning(
            'Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used.'
        )

    if slice_idx is None:
        slice_idx = volume.shape[axis] // 2

    elif isinstance(slice_idx, float):
        if slice_idx < 0 or slice_idx > 1:
            raise ValueError(
                'Values of slice_idx of float type must be between 0 and 1.'
            )
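        # Scale by the number of slices along `axis`; clamp so that 1.0 maps to the last slice
        slice_idx = min(int(round(slice_idx * volume.shape[axis])), volume.shape[axis] - 1)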

    if interactive:
        slide_idx_slider = widgets.IntSlider(
            min=0,
            max=volume.shape[axis] - 1,
            step=1,
            value=slice_idx,
            description='Slice index',
            layout=widgets.Layout(width='450px'),
        )

        grid_size_slider = widgets.IntSlider(
            min=min_grid_size,
            max=max_grid_size,
            step=1,
            value=grid_size,
            description='Grid size',
            layout=widgets.Layout(width='450px'),
        )

        widget_obj = widgets.interactive(
            _structure_tensor,
            volume=widgets.fixed(volume),
            vec=widgets.fixed(vec),
            axis=widgets.fixed(axis),
            slice_idx=slide_idx_slider,
            grid_size=grid_size_slider,
            figsize=widgets.fixed(figsize),
            show=widgets.fixed(True),
        )
        # Arrange sliders horizontally
        sliders_box = widgets.HBox([slide_idx_slider, grid_size_slider])
        widget_obj = widgets.VBox([sliders_box, widget_obj.children[-1]])
        widget_obj.layout.align_items = 'center'

        if show:
            display(widget_obj)

        return widget_obj

    else:
        return _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show)

qim3d.viz.plot_cc

plot_cc(connected_components, component_indexs=None, max_cc_to_plot=32, overlay=None, crop=False, display_figure=True, color_map='viridis', value_min=None, value_max=None, **kwargs)

Plots the connected components from a qim3d.processing.cc.CC object. If an overlay image is provided, the connected component will be masked to the overlay image.

Parameters:

Name Type Description Default
connected_components CC

The connected components object.

required
component_indexs list or tuple

The components to plot. If None, the first max_cc_to_plot components are plotted. Defaults to None.

None
max_cc_to_plot int

The maximum number of connected components to plot when component_indexs is None. Defaults to 32.

32
overlay ndarray or None

Overlay image with the same shape as the connected components. If given, overlay voxels outside each component are set to 0. Defaults to None.

None
crop bool

Whether to crop the image to the bounding box of each component. Defaults to False.

False
display_figure bool

Whether to show the figures. Defaults to True.

True
color_map str

Specifies the color map for the image. Only applied when overlay is given; otherwise a discrete segmentation colormap is used. Defaults to "viridis".

'viridis'
value_min float or None

Together with value_max, defines the data range that the colormap covers. Only applied when overlay is given. By default, the colormap covers the full range. Defaults to None.

None
value_max float or None

Together with value_min, defines the data range that the colormap covers. Only applied when overlay is given. By default, the colormap covers the full range. Defaults to None.

None
**kwargs Any

Additional keyword arguments to pass to qim3d.viz.slices_grid.

{}

Returns:

Name Type Description
figs list[Figure]

List of figures if display_figure=False; otherwise None.
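With an overlay, each component acts as a mask. Overlay voxels outside the component are set to 0, and with crop=True both arrays are first cut to the component's bounding box. The NumPy sketch below shows this masking on a plain integer label volume instead of a qim3d CC object. It assumes 0 marks the background and components are numbered from 1, which matches the default component range in the source code. `mask_component` is a hypothetical helper.

```python
import numpy as np


def mask_component(labels, overlay, component, crop=False):
    """Hypothetical helper: keep overlay values only inside one labelled component."""
    if overlay.shape != labels.shape:
        raise ValueError('overlay must have the same shape as the label volume')
    # Boolean mask of the chosen component (0 is assumed to be background)
    cc = labels == component
    if crop:
        coords = np.nonzero(cc)
        if coords[0].size == 0:
            raise ValueError(f'component {component} not found')
        # Bounding box of the component: one slice per axis
        bbox = tuple(slice(c.min(), c.max() + 1) for c in coords)
        cc, overlay = cc[bbox], overlay[bbox]
    # Voxels outside the component become 0
    return np.where(cc, overlay, 0)
```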

Example

import qim3d
vol = qim3d.examples.cement_128x128x128[50:150]
vol_bin = vol<80
cc = qim3d.segmentation.get_3d_cc(vol_bin)
qim3d.viz.plot_cc(cc, crop=True, display_figure=True, overlay=None, num_slices=5, component_indexs=[4,6,7])
qim3d.viz.plot_cc(cc, crop=True, display_figure=True, overlay=vol, num_slices=5, component_indexs=[4,6,7])

Source code in qim3d/viz/_cc.py
def plot_cc(
    connected_components: CC,
    component_indexs: list | tuple = None,
    max_cc_to_plot: int = 32,
    overlay: np.ndarray = None,
    crop: bool = False,
    display_figure: bool = True,
    color_map: str = 'viridis',
    value_min: float = None,
    value_max: float = None,
    **kwargs,
) -> list[plt.Figure]:
    """
    Plots the connected components from a `qim3d.processing.cc.CC` object. If an overlay image is provided, the connected component will be masked to the overlay image.

    Args:
        connected_components (CC): The connected components object.
        component_indexs (list or tuple, optional): The components to plot. If None, the first `max_cc_to_plot` components are plotted. Defaults to None.
        max_cc_to_plot (int, optional): The maximum number of connected components to plot when `component_indexs` is None. Defaults to 32.
        overlay (np.ndarray or None, optional): Overlay image with the same shape as the connected components. If given, overlay voxels outside each component are set to 0. Defaults to None.
        crop (bool, optional): Whether to crop the image to the bounding box of each component. Defaults to False.
        display_figure (bool, optional): Whether to show the figures. Defaults to True.
        color_map (str, optional): Specifies the color map for the image. Only applied when `overlay` is given; otherwise a discrete segmentation colormap is used. Defaults to "viridis".
        value_min (float or None, optional): Together with value_max, defines the data range that the colormap covers. Only applied when `overlay` is given. By default, the colormap covers the full range. Defaults to None.
        value_max (float or None, optional): Together with value_min, defines the data range that the colormap covers. Only applied when `overlay` is given. By default, the colormap covers the full range. Defaults to None.
        **kwargs (Any): Additional keyword arguments to pass to `qim3d.viz.slices_grid`.

    Returns:
        figs (list[plt.Figure]): List of figures if `display_figure=False`; otherwise None.

    Example:
        ```python
        import qim3d
        vol = qim3d.examples.cement_128x128x128[50:150]
        vol_bin = vol<80
        cc = qim3d.segmentation.get_3d_cc(vol_bin)
        qim3d.viz.plot_cc(cc, crop=True, display_figure=True, overlay=None, num_slices=5, component_indexs=[4,6,7])
        qim3d.viz.plot_cc(cc, crop=True, display_figure=True, overlay=vol, num_slices=5, component_indexs=[4,6,7])
        ```
        ![plot_cc_no_overlay](../../assets/screenshots/plot_cc_no_overlay.png)
        ![plot_cc_overlay](../../assets/screenshots/plot_cc_overlay.png)

    """
    # if no components are given, plot the first max_cc_to_plot=32 components
    if component_indexs is None:
        if len(connected_components) > max_cc_to_plot:
            log.warning(
                f'More than {max_cc_to_plot} connected components found. Only the first {max_cc_to_plot} will be plotted. Change max_cc_to_plot to plot more components.'
            )
        component_indexs = range(
            1, min(max_cc_to_plot + 1, len(connected_components) + 1)
        )

    figs = []
    for component in component_indexs:
        if overlay is not None:
            assert (
                overlay.shape == connected_components.shape
            ), f'Overlay image must have the same shape as the connected components. overlay.shape=={overlay.shape} != connected_components.shape={connected_components.shape}.'

            # plots overlay masked to connected component
            if crop:
                # Crop the overlay image based on the bounding box of the component
                bb = connected_components.get_bounding_box(component)[0]
                cc = connected_components.get_cc(component, crop=True)
                overlay_crop = overlay[bb]
                # use cc as mask for overlay_crop, where all values in cc set to 0 should be masked out, cc contains integers
                overlay_crop = np.where(cc == 0, 0, overlay_crop)
            else:
                cc = connected_components.get_cc(component, crop=False)
                overlay_crop = np.where(cc == 0, 0, overlay)
            fig = qim3d.viz.slices_grid(
                overlay_crop,
                display_figure=display_figure,
                color_map=color_map,
                value_min=value_min,
                value_max=value_max,
                **kwargs,
            )
        else:
            # Without an overlay, use a discrete segmentation colormap (color_map only applies to overlays)
            if 'color_map' not in kwargs:
                kwargs['color_map'] = qim3d.viz.colormaps.segmentation(
                    len(component_indexs)
                )

            # Plot the connected component without overlay
            fig = qim3d.viz.slices_grid(
                connected_components.get_cc(component, crop=crop),
                display_figure=display_figure,
                **kwargs,
            )

        figs.append(fig)

    if not display_figure:
        return figs

    return

qim3d.viz.fade_mask

fade_mask(volume, axis=0, color_map='magma', value_min=None, value_max=None)

Interactive widget for visualizing the effect of edge fading on a 3D volume.

This can be used to select the best parameters before applying the mask.

Parameters:

Name Type Description Default
volume ndarray

The volume to apply edge fading to.

required
axis int

The axis along which to apply the fading. The displayed slices are always taken along axis 0. Defaults to 0.

0
color_map str

Specifies the color map for the image. Defaults to "magma".

'magma'
value_min float or None

Together with value_max, defines the data range that the colormap covers. By default, the colormap covers the full range. Defaults to None.

None
value_max float or None

Together with value_min, defines the data range that the colormap covers. By default, the colormap covers the full range. Defaults to None.

None

Returns:

Name Type Description
slicer_obj interactive

The interactive widget for visualizing fade mask on slices of a 3D volume.
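Inside the widget, value_min and value_max are checked slice by slice. A value_min above the slice's maximum, or a value_max below its minimum, is dropped for that slice only. According to the source comments, this avoids a ValueError from Matplotlib. The following is a minimal restatement of that rule, using `safe_limits` as a hypothetical name:

```python
import numpy as np


def safe_limits(slice_img, value_min=None, value_max=None):
    """Hypothetical helper: drop colormap limits that lie entirely outside a slice."""
    # A lower limit above every pixel leaves no valid range, so fall back to autoscaling
    new_min = (
        None
        if isinstance(value_min, (float, int)) and value_min > np.max(slice_img)
        else value_min
    )
    # Likewise for an upper limit below every pixel
    new_max = (
        None
        if isinstance(value_max, (float, int)) and value_max < np.min(slice_img)
        else value_max
    )
    return new_min, new_max
```

The user's values are not overwritten, so the next slice can use them again if they fit its range.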

Example

import qim3d
vol = qim3d.examples.cement_128x128x128
qim3d.viz.fade_mask(vol)

Source code in qim3d/viz/_data_exploration.py
def fade_mask(
    volume: np.ndarray,
    axis: int = 0,
    color_map: str = 'magma',
    value_min: float = None,
    value_max: float = None,
) -> widgets.interactive:
    """
    Interactive widget for visualizing the effect of edge fading on a 3D volume.

    This can be used to select the best parameters before applying the mask.

    Args:
        volume (np.ndarray): The volume to apply edge fading to.
        axis (int, optional): The axis along which to apply the fading. The displayed slices are always taken along axis 0. Defaults to 0.
        color_map (str, optional): Specifies the color map for the image. Defaults to "magma".
        value_min (float or None, optional): Together with value_max, defines the data range that the colormap covers. By default, the colormap covers the full range. Defaults to None.
        value_max (float or None, optional): Together with value_min, defines the data range that the colormap covers. By default, the colormap covers the full range. Defaults to None.

    Returns:
        slicer_obj (widgets.interactive): The interactive widget for visualizing fade mask on slices of a 3D volume.

    Example:
        ```python
        import qim3d
        vol = qim3d.examples.cement_128x128x128
        qim3d.viz.fade_mask(vol)
        ```
        ![operations-edge_fade_before](../../assets/screenshots/viz-fade_mask.gif)

    """

    # Create the interactive widget
    def _slicer(position, decay_rate, ratio, geometry, invert):
        fig, axes = plt.subplots(1, 3, figsize=(9, 3))

        slice_img = volume[position, :, :]
        # If value_min is higher than the highest value in the image ValueError is raised
        # We don't want to override the values because next slices might be okay
        new_value_min = (
            None
            if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img))
            else value_min
        )
        new_value_max = (
            None
            if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img))
            else value_max
        )

        axes[0].imshow(
            slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max
        )
        axes[0].set_title('Original')
        axes[0].axis('off')

        mask = qim3d.operations.fade_mask(
            np.ones_like(volume),
            decay_rate=decay_rate,
            ratio=ratio,
            geometry=geometry,
            axis=axis,
            invert=invert,
        )
        axes[1].imshow(mask[position, :, :], cmap=color_map)
        axes[1].set_title('Mask')
        axes[1].axis('off')

        masked_volume = qim3d.operations.fade_mask(
            volume,
            decay_rate=decay_rate,
            ratio=ratio,
            geometry=geometry,
            axis=axis,
            invert=invert,
        )
        # If value_min is higher than the highest value in the image ValueError is raised
        # We don't want to override the values because next slices might be okay
        slice_img = masked_volume[position, :, :]
        new_value_min = (
            None
            if (isinstance(value_min, (float, int)) and value_min > np.max(slice_img))
            else value_min
        )
        new_value_max = (
            None
            if (isinstance(value_max, (float, int)) and value_max < np.min(slice_img))
            else value_max
        )
        axes[2].imshow(
            slice_img, cmap=color_map, vmin=new_value_min, vmax=new_value_max
        )
        axes[2].set_title('Masked')
        axes[2].axis('off')

        return fig

    shape_dropdown = widgets.Dropdown(
        options=['spherical', 'cylindrical'],
        value='spherical',  # default value
        description='Geometry',
    )

    position_slider = widgets.IntSlider(
        value=volume.shape[0] // 2,
        min=0,
        max=volume.shape[0] - 1,
        description='Slice',
        continuous_update=False,
    )
    decay_rate_slider = widgets.FloatSlider(
        value=10,
        min=1,
        max=50,
        step=1.0,
        description='Decay Rate',
        continuous_update=False,
    )
    ratio_slider = widgets.FloatSlider(
        value=0.5,
        min=0.1,
        max=1,
        step=0.01,
        description='Ratio',
        continuous_update=False,
    )

    # Create the Checkbox widget
    invert_checkbox = widgets.Checkbox(
        value=False,  # default value
        description='Invert',
    )

    slicer_obj = widgets.interactive(
        _slicer,
        position=position_slider,
        decay_rate=decay_rate_slider,
        ratio=ratio_slider,
        geometry=shape_dropdown,
        invert=invert_checkbox,
    )
    slicer_obj.layout = widgets.Layout(align_items='flex-start')

    return slicer_obj
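The comments above explain why an out-of-range limit is only dropped for the slice being drawn instead of being overwritten: a later slice may still contain values inside that limit. The same rule appears again in `threshold` for `vmin` and `vmax`. A standalone sketch of the rule follows; the helper name `safe_display_limits` is hypothetical and does not exist in qim3d.

```python
import numpy as np


def safe_display_limits(slice_img, vmin=None, vmax=None):
    """Return (vmin, vmax), dropping any limit that lies outside the slice's range.

    Hypothetical helper mirroring the inline checks in the source above.
    """
    lo, hi = np.min(slice_img), np.max(slice_img)
    # A lower limit above the brightest voxel would leave nothing to display
    new_vmin = None if isinstance(vmin, (int, float)) and vmin > hi else vmin
    # An upper limit below the darkest voxel has the same problem
    new_vmax = None if isinstance(vmax, (int, float)) and vmax < lo else vmax
    return new_vmin, new_vmax


img = np.array([[0, 10], [20, 30]])
print(safe_display_limits(img, vmin=5, vmax=25))     # both limits kept
print(safe_display_limits(img, vmin=50, vmax=None))  # vmin dropped for this slice
```

Because the user's original values are never modified, the next slice starts again from the limits that were passed in.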

qim3d.viz.line_profile

line_profile(volume, slice_axis=0, slice_index='middle', vertical_position='middle', horizontal_position='middle', angle=0, fraction_range=(0.0, 1.0))

Returns an interactive widget for visualizing the intensity profiles of lines on slices.

Parameters:

Name Type Description Default
volume ndarray

The 3D volume of interest.

required
slice_axis int

Specifies the initial axis along which to slice.

0
slice_index int or str

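Specifies the initial slice index along slice_axis: an int, or one of 'start', 'middle' or 'end'.

'middle'
vertical_position int or str

Specifies the initial vertical position of the line's pivot point: an int, or one of 'start', 'middle' or 'end'. The first and last rows of the slice are excluded to avoid border issues.

'middle'
horizontal_position int or str

Specifies the initial horizontal position of the line's pivot point: an int, or one of 'start', 'middle' or 'end'. The first and last columns of the slice are excluded to avoid border issues.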

'middle'
angle int or float

Specifies the initial angle (°) of the line around the pivot point. A float is rounded to the nearest int, and the result is wrapped into [0, 360) with modulo 360.

0
fraction_range tuple or list

Specifies the fraction of the line segment to use from border to border. Both values must lie in the range [0.0, 1.0], and the start must not exceed the end.

(0.0, 1.0)

Returns:

Name Type Description
widget VBox

The interactive widget.

Example

import qim3d

vol = qim3d.examples.bone_128x128x128
qim3d.viz.line_profile(vol)
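To make the geometry concrete: the widget places a line through a pivot point at a given angle, extends it from border to border of the slice, and keeps the part selected by fraction_range. The sketch below reproduces that idea on a single 2D slice with nearest-neighbour sampling. It is not qim3d's `_LineProfile` implementation; the helper name `line_profile_2d`, the angle convention (counter-clockwise from the horizontal axis) and the `n_samples` parameter are assumptions made for illustration.

```python
import numpy as np


def line_profile_2d(img, pivot, angle_deg, fraction_range=(0.0, 1.0), n_samples=200):
    """Sample intensities along a line through `pivot` at `angle_deg` (hypothetical helper)."""
    rows, cols = img.shape
    r0, c0 = pivot
    theta = np.deg2rad(angle_deg % 360)
    # Angle measured counter-clockwise from the horizontal; row indices grow downwards
    dr, dc = -np.sin(theta), np.cos(theta)

    # Parameter interval [t_lo, t_hi] for which the line stays inside the image
    t_lo, t_hi = -np.inf, np.inf
    for p0, d, size in ((r0, dr, rows), (c0, dc, cols)):
        if abs(d) > 1e-12:  # skip an axis the line runs parallel to
            a, b = (0 - p0) / d, (size - 1 - p0) / d
            t_lo, t_hi = max(t_lo, min(a, b)), min(t_hi, max(a, b))

    # Keep only the requested fraction of the border-to-border segment
    start = t_lo + fraction_range[0] * (t_hi - t_lo)
    stop = t_lo + fraction_range[1] * (t_hi - t_lo)
    t = np.linspace(start, stop, n_samples)

    # Nearest-neighbour sampling, clipped to valid indices
    r = np.clip(np.rint(r0 + t * dr).astype(int), 0, rows - 1)
    c = np.clip(np.rint(c0 + t * dc).astype(int), 0, cols - 1)
    return img[r, c]


slice_img = np.random.default_rng(0).random((128, 128))
profile = line_profile_2d(slice_img, pivot=(64, 64), angle_deg=45)
print(profile.shape)
```

Nearest-neighbour sampling keeps the sketch short; an interpolating sampler would give smoother profiles at oblique angles.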

Source code in qim3d/viz/_data_exploration.py
def line_profile(
    volume: np.ndarray,
    slice_axis: int = 0,
    slice_index: int | str = 'middle',
    vertical_position: int | str = 'middle',
    horizontal_position: int | str = 'middle',
    angle: int | float = 0,
    fraction_range: Tuple[float, float] = (0.00, 1.00),
) -> widgets.interactive:
    """
    Returns an interactive widget for visualizing the intensity profiles of lines on slices.

    Args:
        volume (np.ndarray): The 3D volume of interest.
        slice_axis (int, optional): Specifies the initial axis along which to slice.
        slice_index (int or str, optional): Specifies the initial slice index along slice_axis.
        vertical_position (int or str, optional): Specifies the initial vertical position of the line's pivot point.
        horizontal_position (int or str, optional): Specifies the initial horizontal position of the line's pivot point.
        angle (int or float, optional): Specifies the initial angle (°) of the line around the pivot point. A float is rounded to the nearest int, and the result is wrapped into [0, 360) with modulo 360.
        fraction_range (tuple or list, optional): Specifies the fraction of the line segment to use from border to border. Both values must lie in the range [0.0, 1.0], and the start must not exceed the end.

    Returns:
        widget (widgets.widget_box.VBox): The interactive widget.


    Example:
        ```python
        import qim3d

        vol = qim3d.examples.bone_128x128x128
        qim3d.viz.line_profile(vol)
        ```
        ![viz line_profile](../../assets/screenshots/viz-line_profile.gif)

    """

    def parse_position(pos, pos_range, name):
        if isinstance(pos, int):
            if not pos_range[0] <= pos <= pos_range[1]:
                raise ValueError(
                    f'Value for {name} must be inside [{pos_range[0]}, {pos_range[1]}]'
                )
            return pos
        elif isinstance(pos, str):
            pos = pos.lower()
            if pos == 'start':
                return pos_range[0]
            elif pos == 'middle':
                return pos_range[0] + (pos_range[1] - pos_range[0]) // 2
            elif pos == 'end':
                return pos_range[1]
            else:
                raise ValueError(
                    f"Invalid string '{pos}' for {name}. "
                    "Must be 'start', 'middle', or 'end'."
                )
        else:
            raise TypeError('Axis position must be of type int or str.')

    if not isinstance(volume, (np.ndarray, da.core.Array)):
        raise ValueError('Data type for volume not supported.')
    if volume.ndim != 3:
        raise ValueError('Volume must be 3D.')

    dims = volume.shape
    slice_index = parse_position(slice_index, (0, dims[slice_axis] - 1), 'slice_index')
    # the omission of the ends for the pivot point is due to border issues.
    vertical_position = parse_position(
        vertical_position, (1, np.delete(dims, slice_axis)[0] - 2), 'vertical_position'
    )
    horizontal_position = parse_position(
        horizontal_position,
        (1, np.delete(dims, slice_axis)[1] - 2),
        'horizontal_position',
    )

    if not isinstance(angle, int | float):
        raise ValueError('Invalid type for angle.')
    angle = round(angle) % 360

    if not (
        0.0 <= fraction_range[0] <= 1.0
        and 0.0 <= fraction_range[1] <= 1.0
        and fraction_range[0] <= fraction_range[1]
    ):
        raise ValueError('Invalid values for fraction_range.')

    lp = _LineProfile(
        volume,
        slice_axis,
        slice_index,
        vertical_position,
        horizontal_position,
        angle,
        fraction_range,
    )
    return lp.build_interactive()
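The angle handling in line_profile reduces to `round(angle) % 360`. Two Python details matter here: the built-in `round()` rounds halves to the nearest even integer (so 2.5 becomes 2, not 3), and `%` with a positive modulus always returns a non-negative result, so negative angles wrap correctly. A hypothetical standalone helper, `wrap_angle`, showing the same normalization:

```python
def wrap_angle(angle):
    """Normalize an angle the way line_profile does (hypothetical standalone helper)."""
    if not isinstance(angle, (int, float)):
        raise ValueError('Invalid type for angle.')
    # round() uses round-half-to-even; % 360 maps any integer into [0, 360)
    return round(angle) % 360


print(wrap_angle(-30))    # 330
print(wrap_angle(370.6))  # 11
print(wrap_angle(2.5))    # 2
```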

qim3d.viz.threshold

threshold(volume, cmap_image='magma', vmin=None, vmax=None)

This function provides an interactive interface to explore thresholding on a 3D volume, slice by slice along axis 0. Users can either manually set the threshold value using a slider or select an automatic thresholding method from skimage.

The visualization has four panels: the original slice, a histogram with the threshold marked, a binary mask of the voxels above the threshold, and an overlay of the binary mask on the original image.

Parameters:

Name Type Description Default
volume ndarray

3D volume to threshold.

required
cmap_image str

Colormap for the original image. Defaults to 'magma'.

'magma'
vmin float

Minimum value for the colormap. Defaults to None.

None
vmax float

Maximum value for the colormap. Defaults to None.

None

Returns:

Name Type Description
slicer_obj VBox

The interactive widget for thresholding a 3D volume.

Interactivity
  • Manual Thresholding: Select 'Manual' from the dropdown menu to manually adjust the threshold using the slider.
  • Automatic Thresholding: Choose a method from the dropdown menu to apply an automatic thresholding algorithm. Available methods include:

    • Otsu
    • Isodata
    • Li
    • Mean
    • Minimum
    • Triangle
    • Yen

    In this mode, the threshold is computed from the current slice; the slider displays the computed value and is disabled.

Example

import qim3d

# Load a sample volume
vol = qim3d.examples.bone_128x128x128

# Visualize interactive thresholding
qim3d.viz.threshold(vol)
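In the automatic modes, the widget calls the selected `skimage.filters` function on the current slice (along axis 0) and shows the voxels above the returned value. The numpy-only sketch below mimics that flow for two of the listed methods so that it runs without scikit-image. The Otsu function is a compact reimplementation of the usual histogram-based formulation and may differ from skimage's `threshold_otsu` on edge cases (it assumes a non-constant slice); `threshold_slice` and `METHODS` are hypothetical names.

```python
import numpy as np


def otsu_threshold(image, nbins=256):
    """Histogram-based Otsu threshold (numpy-only sketch; assumes a non-constant image)."""
    counts, edges = np.histogram(np.ravel(image), bins=nbins)
    centers = (edges[:-1] + edges[1:]) / 2
    counts = counts.astype(float)
    # Cumulative class weights, accumulated from the left and from the right
    weight_lo = np.cumsum(counts)
    weight_hi = np.cumsum(counts[::-1])[::-1]
    # Mean intensity of each class for every candidate split
    mean_lo = np.cumsum(counts * centers) / weight_lo
    mean_hi = (np.cumsum((counts * centers)[::-1]) / weight_hi[::-1])[::-1]
    # Between-class variance for splitting after bin i; its maximizer is the threshold
    between = weight_lo[:-1] * weight_hi[1:] * (mean_lo[:-1] - mean_hi[1:]) ** 2
    return centers[np.argmax(between)]


# Two of the dropdown methods; 'Mean' is simply the mean intensity of the slice
METHODS = {'Otsu': otsu_threshold, 'Mean': lambda img: float(np.mean(img))}


def threshold_slice(volume, position, method='Otsu'):
    """Threshold one slice along axis 0 and return (threshold, binary mask)."""
    slice_img = volume[position, :, :]
    t = METHODS[method](slice_img)
    return t, slice_img > t


rng = np.random.default_rng(0)
vol = rng.normal(50, 10, size=(8, 32, 32))
vol[:, :16, :] += 100  # make every slice bimodal
t, mask = threshold_slice(vol, position=4)
print(round(float(t), 1), int(mask.sum()))
```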

Source code in qim3d/viz/_data_exploration.py
def threshold(
    volume: np.ndarray,
    cmap_image: str = 'magma',
    vmin: float = None,
    vmax: float = None,
) -> widgets.VBox:
    """
    This function provides an interactive interface to explore thresholding on a
    3D volume slice-by-slice. Users can either manually set the threshold value
    using a slider or select an automatic thresholding method from `skimage`.

    The visualization has four panels: the original slice, a histogram with the threshold
    marked, a binary mask of the voxels above the threshold, and an overlay of the binary
    mask on the original image.

    Args:
        volume (np.ndarray): 3D volume to threshold.
        cmap_image (str, optional): Colormap for the original image. Defaults to 'magma'.
        vmin (float, optional): Minimum value for the colormap. Defaults to None.
        vmax (float, optional): Maximum value for the colormap. Defaults to None.

    Returns:
        slicer_obj (widgets.VBox): The interactive widget for thresholding a 3D volume.

    Interactivity:
        - **Manual Thresholding**:
            Select 'Manual' from the dropdown menu to manually adjust the threshold
            using the slider.
        - **Automatic Thresholding**:
            Choose a method from the dropdown menu to apply an automatic thresholding
            algorithm. Available methods include:
            - Otsu
            - Isodata
            - Li
            - Mean
            - Minimum
            - Triangle
            - Yen

            The threshold slider will display the computed value and will be disabled
            in this mode.

    Example:
        ```python
        import qim3d

        # Load a sample volume
        vol = qim3d.examples.bone_128x128x128

        # Visualize interactive thresholding
        qim3d.viz.threshold(vol)
        ```
        ![interactive threshold](../../assets/screenshots/interactive_thresholding.gif)

    """

    # Centralized state dictionary to track current parameters
    state = {
        'position': volume.shape[0] // 2,
        'threshold': int((volume.min() + volume.max()) / 2),
        'method': 'Manual',
    }

    threshold_methods = {
        'Otsu': threshold_otsu,
        'Isodata': threshold_isodata,
        'Li': threshold_li,
        'Mean': threshold_mean,
        'Minimum': threshold_minimum,
        'Triangle': threshold_triangle,
        'Yen': threshold_yen,
    }

    # Create an output widget to display the plot
    output = widgets.Output()

    # Function to update the state and trigger visualization
    def update_state(change):
        # Update state based on widget values
        state['position'] = position_slider.value
        state['method'] = method_dropdown.value

        if state['method'] == 'Manual':
            state['threshold'] = threshold_slider.value
            threshold_slider.disabled = False
        else:
            threshold_func = threshold_methods.get(state['method'])
            if threshold_func:
                slice_img = volume[state['position'], :, :]
                computed_threshold = threshold_func(slice_img)
                state['threshold'] = computed_threshold

                # Programmatically update the slider without triggering callbacks
                threshold_slider.unobserve(update_state, names='value')
                threshold_slider.value = computed_threshold
                threshold_slider.disabled = True
                threshold_slider.observe(update_state, names='value')
            else:
                raise ValueError(f"Unsupported thresholding method: {state['method']}")

        # Trigger visualization
        update_visualization()

    # Visualization function
    def update_visualization():
        slice_img = volume[state['position'], :, :]
        with output:
            output.clear_output(wait=True)  # Clear previous plot
            fig, axes = plt.subplots(1, 4, figsize=(25, 5))

            # Original image
            new_vmin = (
                None
                if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
                else vmin
            )
            new_vmax = (
                None
                if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
                else vmax
            )
            axes[0].imshow(slice_img, cmap=cmap_image, vmin=new_vmin, vmax=new_vmax)
            axes[0].set_title('Original')
            axes[0].axis('off')

            # Histogram
            histogram(
                volume=volume,
                bins=32,
                slice_idx=state['position'],
                vertical_line=state['threshold'],
                axis=0,
                kde=False,
                ax=axes[1],
                show=False,
            )
            axes[1].set_title(f"Histogram with Threshold = {int(state['threshold'])}")

            # Binary mask
            mask = slice_img > state['threshold']
            axes[2].imshow(mask, cmap='gray')
            axes[2].set_title('Binary mask')
            axes[2].axis('off')

            # Overlay
            mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
            mask_rgb[:, :, 0] = mask
            masked_volume = qim3d.operations.overlay_rgb_images(
                background=slice_img,
                foreground=mask_rgb,
            )
            axes[3].imshow(masked_volume, vmin=new_vmin, vmax=new_vmax)
            axes[3].set_title('Overlay')
            axes[3].axis('off')

            plt.show()

    # Widgets
    position_slider = widgets.IntSlider(
        value=state['position'],
        min=0,
        max=volume.shape[0] - 1,
        description='Slice',
    )

    threshold_slider = widgets.IntSlider(
        value=state['threshold'],
        min=volume.min(),
        max=volume.max(),
        description='Threshold',
    )

    method_dropdown = widgets.Dropdown(
        options=[
            'Manual',
            'Otsu',
            'Isodata',
            'Li',
            'Mean',
            'Minimum',
            'Triangle',
            'Yen',
        ],
        value=state['method'],
        description='Method',
    )

    # Attach the state update function to widgets
    position_slider.observe(update_state, names='value')
    threshold_slider.observe(update_state, names='value')
    method_dropdown.observe(update_state, names='value')

    # Layout
    controls_left = widgets.VBox([position_slider, threshold_slider])
    controls_right = widgets.VBox([method_dropdown])
    controls_layout = widgets.HBox(
        [controls_left, controls_right],
        layout=widgets.Layout(justify_content='flex-start'),
    )
    interactive_ui = widgets.VBox([controls_layout, output])
    update_visualization()

    return interactive_ui

qim3d.viz.colormaps

qim3d.viz.colormaps.qim module-attribute

qim = from_list('qim', [(0.6, 0.0, 0.0), (1.0, 0.6, 0.0)])

Defines a colormap in the QIM logo colors. It can be accessed as a module attribute or, more simply, by passing cmap = 'qim'.

Example

from IPython.display import display
import qim3d

display(qim3d.viz.colormaps.qim)
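`from_list` builds a colormap that interpolates linearly between evenly spaced color stops. For the two-stop qim map, a value x in [0, 1] maps to (1 - x) times the dark red stop plus x times the orange stop. Matplotlib additionally quantizes this into a lookup table of N entries (256 by default), so colors sampled from the real colormap can differ very slightly. A numpy sketch of the idealized mapping; `qim_color` and `QIM_STOPS` are hypothetical names:

```python
import numpy as np

# The two RGB stops passed to from_list for the 'qim' colormap
QIM_STOPS = np.array([(0.6, 0.0, 0.0), (1.0, 0.6, 0.0)])


def qim_color(x):
    """Linearly interpolate between the two stops; x is clipped to [0, 1]."""
    x = np.clip(np.asarray(x, dtype=float), 0.0, 1.0)[..., None]
    return (1.0 - x) * QIM_STOPS[0] + x * QIM_STOPS[1]


print(qim_color(0.5))
print(qim_color([0.0, 1.0]))
```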

qim3d.viz.colormaps.segmentation

segmentation(num_labels, style='bright', first_color_background=True, last_color_background=False, background_color=(0.0, 0.0, 0.0), min_dist=0.5, seed=19)

Creates a random colormap for use with matplotlib. Useful for segmentation tasks.

Parameters:

Name Type Description Default
num_labels int

Number of labels (size of colormap).

required
style str

'bright' for strong colors, 'soft' for pastel colors, 'earth' for yellow/green/blue colors, 'ocean' for blue/purple/pink colors. Defaults to 'bright'.

'bright'
first_color_background bool

If True, the first color is used as background. Defaults to True.

True
last_color_background bool

If True, the last color is used as background. Defaults to False.

False
background_color tuple or str

Background color, either as an RGB tuple or as one of the strings "black" or "white". Defaults to (0.0, 0.0, 0.0).

(0.0, 0.0, 0.0)
min_dist float

Minimum distance between neighboring colors. Defaults to 0.5.

0.5
seed int

Seed for random number generator. Defaults to 19.

19

Returns:

Name Type Description
color_map LinearSegmentedColormap

Colormap for matplotlib
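For the 'bright' style, the source draws one random HSV color per label, plus one for the background, from a locally seeded generator and converts each color to RGB. Hue spans the full circle while saturation and value are kept high. A reduced sketch of that step follows; the name `bright_palette` is hypothetical, and the sketch skips the min_dist rearrangement because its helper, `rearrange_colors`, is not shown on this page.

```python
import colorsys

import numpy as np


def bright_palette(num_labels, seed=19, background=(0.0, 0.0, 0.0)):
    """Sketch of the 'bright' palette: seeded random HSV colors converted to RGB.

    Hypothetical helper; it omits the min_dist rearrangement step.
    """
    rng = np.random.default_rng(seed)  # local generator, so the global seed is untouched
    colors = []
    for _ in range(num_labels + 1):  # one extra entry for the background
        hue = rng.uniform(0.0, 1.0)
        saturation = rng.uniform(0.4, 1.0)
        value = rng.uniform(0.9, 1.0)  # high value keeps every color bright
        colors.append(colorsys.hsv_to_rgb(hue, saturation, value))
    colors[0] = background  # reserve the first entry for label 0
    return colors


palette = bright_palette(5)
print(len(palette), palette[0])
```

A list like this could then be turned into a colormap with `LinearSegmentedColormap.from_list('objects', colors, N=len(colors))`, which is how the source finishes `segmentation`.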

Example

from IPython.display import display
import qim3d

cmap_bright = qim3d.viz.colormaps.segmentation(num_labels=100, style = 'bright', first_color_background=True, background_color="black", min_dist=0.7)
cmap_soft = qim3d.viz.colormaps.segmentation(num_labels=100, style = 'soft', first_color_background=True, background_color="black", min_dist=0.2)
cmap_earth = qim3d.viz.colormaps.segmentation(num_labels=100, style = 'earth', first_color_background=True, background_color="black", min_dist=0.8)
cmap_ocean = qim3d.viz.colormaps.segmentation(num_labels=100, style = 'ocean', first_color_background=True, background_color="black", min_dist=0.9)

display(cmap_bright)
display(cmap_soft)
display(cmap_earth)
display(cmap_ocean)

import qim3d

vol = qim3d.examples.cement_128x128x128
binary = qim3d.filters.gaussian(vol, sigma = 2) < 60
labeled_volume, num_labels = qim3d.segmentation.watershed(binary)

color_map = qim3d.viz.colormaps.segmentation(num_labels, style = 'bright')
qim3d.viz.slicer(labeled_volume, slice_axis = 1, color_map=color_map)

Tip

It can be used directly when calling visualization functions, as in

qim3d.viz.slices_grid(segmented_volume, color_map = 'objects')
which automatically detects the number of unique classes and creates the colormap object with default arguments.

Tip

The min_dist parameter controls the minimum distance between neighboring colors in the colormap.
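The `rearrange_colors` helper that enforces min_dist is not shown on this page, so its exact behavior is not documented here. The sketch below is one plausible way to realize the idea: a greedy reordering that uses Euclidean distance in RGB space. It is a hypothetical illustration (`greedy_reorder` is not a qim3d function), not qim3d's algorithm.

```python
import numpy as np


def greedy_reorder(colors, min_dist):
    """Reorder colors so consecutive entries are at least min_dist apart when possible.

    Hypothetical illustration only; uses Euclidean distance in RGB space.
    """
    remaining = list(colors)
    ordered = [remaining.pop(0)]
    while remaining:
        last = np.asarray(ordered[-1], dtype=float)
        dists = [np.linalg.norm(np.asarray(c, dtype=float) - last) for c in remaining]
        # Take the first color far enough away; otherwise fall back to the farthest one
        idx = next((i for i, d in enumerate(dists) if d >= min_dist), int(np.argmax(dists)))
        ordered.append(remaining.pop(idx))
    return ordered


print(greedy_reorder([(0.0, 0.0, 0.0), (0.1, 0.0, 0.0), (1.0, 1.0, 1.0)], 0.5))
```

Larger min_dist values push visually similar colors apart in the list, which matters because neighboring label values receive neighboring colors.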

Source code in qim3d/viz/colormaps/_segmentation.py
def segmentation(
    num_labels: int,
    style: str = 'bright',
    first_color_background: bool = True,
    last_color_background: bool = False,
    background_color: Union[Tuple[float, float, float], str] = (0.0, 0.0, 0.0),
    min_dist: float = 0.5,
    seed: int = 19,
) -> LinearSegmentedColormap:
    """
    Creates a random colormap for use with matplotlib. Useful for segmentation tasks.

    Args:
        num_labels (int): Number of labels (size of colormap).
        style (str, optional): 'bright' for strong colors, 'soft' for pastel colors, 'earth' for yellow/green/blue colors, 'ocean' for blue/purple/pink colors. Defaults to 'bright'.
        first_color_background (bool, optional): If True, the first color is used as background. Defaults to True.
        last_color_background (bool, optional): If True, the last color is used as background. Defaults to False.
        background_color (tuple or str, optional): RGB tuple or string for background color. Can be "black" or "white". Defaults to (0.0, 0.0, 0.0).
        min_dist (float, optional): Minimum distance between neighboring colors. Defaults to 0.5.
        seed (int, optional): Seed for random number generator. Defaults to 19.

    Returns:
        color_map (matplotlib.colors.LinearSegmentedColormap): Colormap for matplotlib


    Example:
        ```python
        import qim3d

        cmap_bright = qim3d.viz.colormaps.segmentation(num_labels=100, style = 'bright', first_color_background=True, background_color="black", min_dist=0.7)
        cmap_soft = qim3d.viz.colormaps.segmentation(num_labels=100, style = 'soft', first_color_background=True, background_color="black", min_dist=0.2)
        cmap_earth = qim3d.viz.colormaps.segmentation(num_labels=100, style = 'earth', first_color_background=True, background_color="black", min_dist=0.8)
        cmap_ocean = qim3d.viz.colormaps.segmentation(num_labels=100, style = 'ocean', first_color_background=True, background_color="black", min_dist=0.9)

        display(cmap_bright)
        display(cmap_soft)
        display(cmap_earth)
        display(cmap_ocean)
        ```
        ![colormap objects](../../assets/screenshots/viz-colormaps-objects-all.png)

        ```python
        import qim3d

        vol = qim3d.examples.cement_128x128x128
        binary = qim3d.filters.gaussian(vol, sigma = 2) < 60
        labeled_volume, num_labels = qim3d.segmentation.watershed(binary)

        color_map = qim3d.viz.colormaps.segmentation(num_labels, style = 'bright')
        qim3d.viz.slicer(labeled_volume, slice_axis = 1, color_map=color_map)
        ```
        ![colormap objects](../../assets/screenshots/viz-colormaps-objects.gif)

    Tip:
        It can be easily used when calling visualization functions as
        ```python
        qim3d.viz.slices_grid(segmented_volume, color_map = 'objects')
        ```
        which automatically detects the number of unique classes
        and creates the colormap object with default arguments.

    Tip:
        The `min_dist` parameter can be used to control the distance between neighboring colors.
        ![colormap objects min_dist](../../assets/screenshots/viz-colormaps-min_dist.gif)

    """
    from skimage import color

    # Check style
    if style not in ('bright', 'soft', 'earth', 'ocean'):
        raise ValueError(
            f'Please choose "bright", "soft", "earth" or "ocean" for style in qim3dCmap not "{style}"'
        )

    # Translate strings to background color
    color_dict = {'black': (0.0, 0.0, 0.0), 'white': (1.0, 1.0, 1.0)}
    if not isinstance(background_color, tuple):
        try:
            background_color = color_dict[background_color]
        except KeyError:
            raise ValueError(
                f'Invalid color name "{background_color}". Please choose from {list(color_dict.keys())}.'
            )

    # Add one to num_labels to include the background color
    num_labels += 1

    # Create a new random generator, to locally set seed
    rng = np.random.default_rng(seed)

    # Generate color map for bright colors, based on hsv
    if style == 'bright':
        randHSVcolors = [
            (
                rng.uniform(low=0.0, high=1),
                rng.uniform(low=0.4, high=1),
                rng.uniform(low=0.9, high=1),
            )
            for i in range(num_labels)
        ]

        # Convert HSV list to RGB
        randRGBcolors = []
        for HSVcolor in randHSVcolors:
            randRGBcolors.append(
                colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])
            )

    # Generate soft pastel colors, by limiting the RGB spectrum
    if style == 'soft':
        low = 0.6
        high = 0.95
        randRGBcolors = [
            (
                rng.uniform(low=low, high=high),
                rng.uniform(low=low, high=high),
                rng.uniform(low=low, high=high),
            )
            for i in range(num_labels)
        ]

    # Generate color map for earthy colors, based on LAB
    if style == 'earth':
        randLABColors = [
            (
                rng.uniform(low=25, high=110),
                rng.uniform(low=-120, high=70),
                rng.uniform(low=-70, high=70),
            )
            for i in range(num_labels)
        ]

        # Convert LAB list to RGB
        randRGBcolors = []
        for LabColor in randLABColors:
            randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())

    # Generate color map for ocean colors, based on LAB
    if style == 'ocean':
        randLABColors = [
            (
                rng.uniform(low=0, high=110),
                rng.uniform(low=-128, high=160),
                rng.uniform(low=-128, high=0),
            )
            for i in range(num_labels)
        ]

        # Convert LAB list to RGB
        randRGBcolors = []
        for LabColor in randLABColors:
            randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())

    # Re-arrange colors to have a minimum distance between neighboring colors
    randRGBcolors = rearrange_colors(randRGBcolors, min_dist)

    # Set first and last color to background
    if first_color_background:
        randRGBcolors[0] = background_color

    if last_color_background:
        randRGBcolors[-1] = background_color

    # Create colormap
    objects = LinearSegmentedColormap.from_list('objects', randRGBcolors, N=num_labels)

    return objects