
Data visualization

The qim3d library aims to provide easy ways to explore and get insights from volumetric data.

qim3d.viz

qim3d.viz.slices

slices(vol, axis=0, position=None, n_slices=5, max_cols=5, cmap='viridis', vmin=None, vmax=None, img_height=2, img_width=2, show=False, show_position=True, interpolation='none', img_size=None, cbar=False, **imshow_kwargs)

Displays one or several slices from a 3d volume.

By default, if position is None, slices plots n_slices linearly spaced slices. If position is a string or an integer, slices plots an overview of n_slices slices around that position. If position is a list, n_slices is ignored and the slices listed in position are plotted.
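A minimal pure-Python sketch of how the default slice positions and the grid size can be worked out. `linspace_indices` and `grid_shape` are hypothetical helpers, not part of qim3d: the first approximates `np.linspace(0, n_total - 1, n_slices, dtype=int)`, the call slices uses when position is None, and the second mirrors the nrows/ncols formulas in the source. The string and integer cases go through the internal `_get_slice_range` helper, which is not shown on this page and is left out of the sketch.

```python
import math
from typing import List, Tuple


def linspace_indices(n_total: int, n_slices: int) -> List[int]:
    """Hypothetical stand-in for np.linspace(0, n_total - 1, n_slices, dtype=int)."""
    if n_slices <= 0:
        return []
    if n_slices == 1:
        return [0]  # a single sample sits at the start
    step = (n_total - 1) / (n_slices - 1)
    # Casting to int truncates toward zero, like dtype=int; the last index is pinned
    # to n_total - 1 because linspace includes its endpoint.
    idxs = [int(i * step) for i in range(n_slices)]
    idxs[-1] = n_total - 1
    return idxs


def grid_shape(n_images: int, max_cols: int) -> Tuple[int, int]:
    """Hypothetical helper: rows and columns of the subplot grid (rows of at most max_cols)."""
    nrows = math.ceil(n_images / max_cols)
    ncols = min(n_images, max_cols)
    return nrows, ncols


# The shell example volume has 225 slices along axis 0.
print(linspace_indices(225, 15))  # every 16th slice: 0, 16, ..., 224
print(grid_shape(15, 5))  # (3, 5)
```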

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vol | ndarray | The 3D volume to be sliced. | required |
| axis | int | The axis, or dimension, along which to slice. | 0 |
| position | str, int or list | One or several slicing positions. If None, linearly spaced slices are displayed. | None |
| n_slices | int | Number of slices to display. | 5 |
| max_cols | int | Maximum number of columns in the grid. | 5 |
| cmap | str | Colormap for the images. | 'viridis' |
| vmin | float | Together with vmax, defines the data range the colormap covers. If None, the colormap covers the full range. | None |
| vmax | float | Together with vmin, defines the data range the colormap covers. If None, the colormap covers the full range. | None |
| img_height | int | Height allotted to each slice image, in inches. | 2 |
| img_width | int | Width allotted to each slice image, in inches. | 2 |
| show | bool | If True, displays the plot (i.e. calls plt.show()). | False |
| show_position | bool | If True, labels each image with its slice index and axis. | True |
| interpolation | str | Interpolation method passed to imshow. | 'none' |
| img_size | int | If set, overrides both img_height and img_width. | None |
| cbar | bool | Adds a colorbar to the right of the top-right image, covering the colormap and data range shared by all slices. | False |
| **imshow_kwargs | | Additional keyword arguments passed to matplotlib's imshow. | {} |
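How vmin and vmax are applied depends on cbar. The sketch below restates the two branches of the source as small hypothetical functions (not qim3d API): without a colorbar, a limit that lies entirely outside a slice's values is dropped for that slice only; with a colorbar, a single range is shared by all slices. One deliberate deviation, as an assumption: it tests `is not None`, so a limit of 0 is honoured, whereas the source's `vmin if vmin else ...` treats 0 as unset.

```python
from typing import Optional, Tuple

Limit = Optional[float]


def per_slice_limits(vmin: Limit, vmax: Limit, slice_min: float, slice_max: float) -> Tuple[Limit, Limit]:
    """cbar=False: drop a limit for this slice only if it lies outside the slice's values.

    A dropped limit (None) lets imshow autoscale that side for this image;
    the next slice is checked against the original vmin/vmax again.
    """
    new_vmin = None if (vmin is not None and vmin > slice_max) else vmin
    new_vmax = None if (vmax is not None and vmax < slice_min) else vmax
    return new_vmin, new_vmax


def shared_limits(vmin: Limit, vmax: Limit, vol_min: float, vol_max: float) -> Tuple[float, float]:
    """cbar=True: one range for every slice and the colorbar, defaulting to the volume's range."""
    return (
        vmin if vmin is not None else vol_min,
        vmax if vmax is not None else vol_max,
    )


print(per_slice_limits(0.8, None, slice_min=0.0, slice_max=0.5))  # (None, None)
print(shared_limits(0, None, vol_min=-3.0, vol_max=7.0))  # (0, 7.0)
```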

Returns:

| Name | Type | Description |
|---|---|---|
| fig | Figure | The figure with the slices from the 3D array. |

Raises:

| Type | Description |
|---|---|
| ValueError | If the input is not a numpy.ndarray or da.core.Array. |
| ValueError | If axis is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1. |
| ValueError | If the input is not a volume with at least 3 dimensions. |
| ValueError | If position is not an integer, a list of integers or one of the strings "start", "mid" or "end". |
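The position checks accept only a few shapes of input. This hypothetical `position_kind` function, a sketch rather than part of qim3d, reproduces the order of the checks in the source, including the case-insensitive string match. Note that a tuple of integers is rejected, because only `list` is accepted.

```python
from typing import List, Optional, Union

Position = Optional[Union[str, int, List[int]]]


def position_kind(position: Position) -> str:
    """Hypothetical helper: classify `position` in the same order as the checks in slices."""
    if position is None:
        return "linspace"
    if isinstance(position, str) and position.lower() in ("start", "mid", "end"):
        return position.lower()  # strings are matched case-insensitively
    if isinstance(position, int):
        return "int"
    if isinstance(position, list) and all(isinstance(idx, int) for idx in position):
        return "list"
    raise ValueError(
        "Position not recognized. Choose an integer, list of integers or one of "
        'the following strings: "start", "mid" or "end".'
    )


print(position_kind("Mid"))  # mid
print(position_kind([10, 20]))  # list
```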

Example

import qim3d

vol = qim3d.examples.shell_225x128x128
qim3d.viz.slices(vol, n_slices=15)
[Figure: Grid of slices]

Source code in qim3d/viz/explore.py
def slices(
    vol: np.ndarray,
    axis: int = 0,
    position: Optional[Union[str, int, List[int]]] = None,
    n_slices: int = 5,
    max_cols: int = 5,
    cmap: str = "viridis",
    vmin: float = None,
    vmax: float = None,
    img_height: int = 2,
    img_width: int = 2,
    show: bool = False,
    show_position: bool = True,
    interpolation: Optional[str] = "none",
    img_size=None,
    cbar: bool = False,
    **imshow_kwargs,
) -> plt.Figure:
    """Displays one or several slices from a 3d volume.

    By default, if `position` is None, `slices` plots `n_slices` linearly spaced slices.
    If `position` is a string or an integer, `slices` plots an overview of `n_slices` slices around that position.
    If `position` is a list, `n_slices` is ignored and the slices listed in `position` are plotted.

    Args:
        vol (np.ndarray): The 3D volume to be sliced.
        axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
        position (str, int, list, optional): One or several slicing levels. If None, linearly spaced slices will be displayed. Defaults to None.
        n_slices (int, optional): Defines how many slices the user wants to be displayed. Defaults to 5.
        max_cols (int, optional): The maximum number of columns to be plotted. Defaults to 5.
        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
        vmin (float, optional): Together with vmax, defines the data range the colormap covers. If None, the colormap covers the full range. Defaults to None.
        vmax (float, optional): Together with vmin, defines the data range the colormap covers. If None, the colormap covers the full range. Defaults to None.
        img_height (int, optional): Height allotted to each slice image, in inches. Defaults to 2.
        img_width (int, optional): Width allotted to each slice image, in inches. Defaults to 2.
        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
        show_position (bool, optional): If True, labels each image with its slice index and axis. Defaults to True.
        interpolation (str, optional): Interpolation method passed to imshow. Defaults to 'none'.
        img_size (int, optional): If set, overrides both img_height and img_width. Defaults to None.
        cbar (bool, optional): Adds a colorbar to the right of the top-right image, covering the colormap and data range shared by all slices. Defaults to False.
        **imshow_kwargs: Additional keyword arguments passed to matplotlib's imshow.

    Returns:
        fig (matplotlib.figure.Figure): The figure with the slices from the 3d array.

    Raises:
        ValueError: If the input is not a numpy.ndarray or da.core.Array.
        ValueError: If the axis to slice along is not a valid choice, i.e. not an integer between 0 and the number of dimensions of the volume minus 1.
        ValueError: If the file or array is not a volume with at least 3 dimensions.
        ValueError: If the `position` keyword argument is not an integer, a list of integers or one of the following strings: "start", "mid" or "end".

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.shell_225x128x128
        qim3d.viz.slices(vol, n_slices=15)
        ```
        ![Grid of slices](assets/screenshots/viz-slices.png)
    """
    if img_size:
        img_height = img_size
        img_width = img_size

    # NumPy array or Dask array input
    if not isinstance(vol, (np.ndarray, da.core.Array)):
        raise ValueError("Data type not supported")

    if vol.ndim < 3:
        raise ValueError(
            "The provided object is not a volume as it has less than 3 dimensions."
        )

    if isinstance(vol, da.core.Array):
        vol = vol.compute()

    # Ensure axis is a valid choice
    if not (0 <= axis < vol.ndim):
        raise ValueError(
            f"Invalid value for 'axis'. It should be an integer between 0 and {vol.ndim - 1}."
        )

    # Get total number of slices in the specified dimension
    n_total = vol.shape[axis]

    # Position is not provided - will use linearly spaced slices
    if position is None:
        slice_idxs = np.linspace(0, n_total - 1, n_slices, dtype=int)
    # Position is a string
    elif isinstance(position, str) and position.lower() in ["start", "mid", "end"]:
        if position.lower() == "start":
            slice_idxs = _get_slice_range(0, n_slices, n_total)
        elif position.lower() == "mid":
            slice_idxs = _get_slice_range(n_total // 2, n_slices, n_total)
        elif position.lower() == "end":
            slice_idxs = _get_slice_range(n_total - 1, n_slices, n_total)
    # Position is an integer
    elif isinstance(position, int):
        slice_idxs = _get_slice_range(position, n_slices, n_total)
    # Position is a list of integers
    elif isinstance(position, list) and all(isinstance(idx, int) for idx in position):
        slice_idxs = position
        n_slices = len(slice_idxs)  # n_slices is ignored for lists, so size the grid to the list
    else:
        raise ValueError(
            'Position not recognized. Choose an integer, list of integers or one of the following strings: "start", "mid" or "end".'
        )

    # Make grid
    nrows = math.ceil(n_slices / max_cols)
    ncols = min(n_slices, max_cols)

    # Generate figure
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * img_width, nrows * img_height),  # figsize is (width, height)
        constrained_layout=True,
    )

    if nrows == 1:
        axs = [axs]  # Convert to a list for uniformity

    # Convert to NumPy array in order to use the numpy.take method
    if isinstance(vol, da.core.Array):
        vol = vol.compute()

    if cbar:
        # In this case, we want the vrange to be constant across the slices, which makes them all comparable to a single cbar.
        new_vmin = vmin if vmin is not None else np.min(vol)
        new_vmax = vmax if vmax is not None else np.max(vol)

    # Run through each ax of the grid
    for i, ax_row in enumerate(axs):
        for j, ax in enumerate(np.atleast_1d(ax_row)):
            slice_idx = i * max_cols + j
            try:
                slice_img = vol.take(slice_idxs[slice_idx], axis=axis)

                if not cbar:
                    # If vmin is higher than the highest value in the image ValueError is raised
                    # We don't want to override the values because next slices might be okay
                    new_vmin = (
                        None
                        if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
                        else vmin
                    )
                    new_vmax = (
                        None
                        if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
                        else vmax
                    )

                ax.imshow(
                    slice_img,
                    cmap=cmap,
                    interpolation=interpolation,
                    vmin=new_vmin,
                    vmax=new_vmax,
                    **imshow_kwargs,
                )

                if show_position:
                    ax.text(
                        0.0,
                        1.0,
                        f"slice {slice_idxs[slice_idx]} ",
                        transform=ax.transAxes,
                        color="white",
                        fontsize=8,
                        va="top",
                        ha="left",
                        bbox=dict(facecolor="#303030", linewidth=0, pad=0),
                    )

                    ax.text(
                        1.0,
                        0.0,
                        f"axis {axis} ",
                        transform=ax.transAxes,
                        color="white",
                        fontsize=8,
                        va="bottom",
                        ha="right",
                        bbox=dict(facecolor="#303030", linewidth=0, pad=0),
                    )

            except IndexError:
                # Not a problem, because we simply do not have a slice to show
                pass

            # Hide the axis, so that we have a nice grid
            ax.axis("off")

    if cbar:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            fig.tight_layout()
        norm = matplotlib.colors.Normalize(vmin=new_vmin, vmax=new_vmax, clip=True)
        mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)

        # Figure coordinates of top-right axis
        tr_pos = np.atleast_1d(axs[0])[-1].get_position()
        # The width is divided by ncols to make it the same relative size to the images
        cbar_ax = fig.add_axes(
            [tr_pos.x1 + 0.05 / ncols, tr_pos.y0, 0.05 / ncols, tr_pos.height]
        )
        fig.colorbar(mappable=mappable, cax=cbar_ax, orientation="vertical")

    if show:
        plt.show()

    plt.close()

    return fig

qim3d.viz.slicer

slicer(vol, axis=0, cmap='viridis', vmin=None, vmax=None, img_height=3, img_width=3, show_position=False, interpolation='none', img_size=None, cbar=False, **imshow_kwargs)

Interactive widget for visualizing slices of a 3D volume.
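slicer wraps slices in an ipywidgets IntSlider: the slider spans every index along axis, starts at the middle slice, and each move re-renders slices with n_slices=1 at the chosen position. The hypothetical helper below (not qim3d API) only computes those slider settings, so it runs without ipywidgets or a notebook.

```python
from typing import Dict, Sequence


def slider_settings(shape: Sequence[int], axis: int = 0) -> Dict[str, int]:
    """Hypothetical helper: bounds and start value of the slice slider along `axis`."""
    n_total = shape[axis]
    return {
        "min": 0,               # first slice
        "max": n_total - 1,     # last slice
        "value": n_total // 2,  # the widget opens on the middle slice
    }


# bone_128x128x128 has 128 slices along each axis.
print(slider_settings((128, 128, 128), axis=0))  # {'min': 0, 'max': 127, 'value': 64}
```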

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vol | ndarray | The 3D volume to be sliced. | required |
| axis | int | The axis, or dimension, along which to slice. | 0 |
| cmap | str | Colormap for the image. | 'viridis' |
| vmin | float | Together with vmax, defines the data range the colormap covers. If None, the colormap covers the full range. | None |
| vmax | float | Together with vmin, defines the data range the colormap covers. If None, the colormap covers the full range. | None |
| img_height | int | Height of the figure, in inches. | 3 |
| img_width | int | Width of the figure, in inches. | 3 |
| show_position | bool | If True, labels the image with its slice index and axis. | False |
| interpolation | str | Interpolation method passed to imshow. | 'none' |
| img_size | int | If set, overrides both img_height and img_width. | None |
| cbar | bool | Adds a colorbar for the corresponding colormap and data range. | False |
| **imshow_kwargs | | Additional keyword arguments passed to matplotlib's imshow. | {} |

Returns:

| Name | Type | Description |
|---|---|---|
| slicer_obj | interactive | The interactive widget for visualizing slices of a 3D volume. |

Example

import qim3d

vol = qim3d.examples.bone_128x128x128
qim3d.viz.slicer(vol)
[Figure: viz slicer]

Source code in qim3d/viz/explore.py
def slicer(
    vol: np.ndarray,
    axis: int = 0,
    cmap: str = "viridis",
    vmin: float = None,
    vmax: float = None,
    img_height: int = 3,
    img_width: int = 3,
    show_position: bool = False,
    interpolation: Optional[str] = "none",
    img_size=None,
    cbar: bool = False,
    **imshow_kwargs,
) -> widgets.interactive:
    """Interactive widget for visualizing slices of a 3D volume.

    Args:
        vol (np.ndarray): The 3D volume to be sliced.
        axis (int, optional): Specifies the axis, or dimension, along which to slice. Defaults to 0.
        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
        vmin (float, optional): Together with vmax, defines the data range the colormap covers. If None, the colormap covers the full range. Defaults to None.
        vmax (float, optional): Together with vmin, defines the data range the colormap covers. If None, the colormap covers the full range. Defaults to None.
        img_height (int, optional): Height of the figure, in inches. Defaults to 3.
        img_width (int, optional): Width of the figure, in inches. Defaults to 3.
        show_position (bool, optional): If True, labels the image with its slice index and axis. Defaults to False.
        interpolation (str, optional): Interpolation method passed to imshow. Defaults to 'none'.
        img_size (int, optional): If set, overrides both img_height and img_width. Defaults to None.
        cbar (bool, optional): Adds a colorbar for the corresponding colormap and data range. Defaults to False.
        **imshow_kwargs: Additional keyword arguments passed to matplotlib's imshow.

    Returns:
        slicer_obj (widgets.interactive): The interactive widget for visualizing slices of a 3D volume.

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.bone_128x128x128
        qim3d.viz.slicer(vol)
        ```
        ![viz slicer](assets/screenshots/viz-slicer.gif)
    """

    if img_size:
        img_height = img_size
        img_width = img_size

    # Create the interactive widget
    def _slicer(position):
        fig = slices(
            vol,
            axis=axis,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            img_height=img_height,
            img_width=img_width,
            show_position=show_position,
            interpolation=interpolation,
            position=position,
            n_slices=1,
            show=True,
            cbar=cbar,
            **imshow_kwargs,
        )
        return fig

    position_slider = widgets.IntSlider(
        value=vol.shape[axis] // 2,
        min=0,
        max=vol.shape[axis] - 1,
        description="Slice",
        continuous_update=True,
    )
    slicer_obj = widgets.interactive(_slicer, position=position_slider)
    slicer_obj.layout = widgets.Layout(align_items="flex-start")

    return slicer_obj

qim3d.viz.orthogonal

orthogonal(vol, cmap='viridis', vmin=None, vmax=None, img_height=3, img_width=3, show_position=False, interpolation=None, img_size=None)

Interactive widget for visualizing orthogonal slices of a 3D volume.
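orthogonal places three slicer widgets side by side, for axes 0, 1 and 2, labelled Z, Y and X. Taking a slice along one axis drops that axis, so each panel shows the two remaining dimensions. The sketch below, a hypothetical helper rather than qim3d API, lists the 2D image shape each panel displays; the example shape matches fly_150x256x256.

```python
from typing import Dict, Tuple

# Axis-to-label mapping used for the three panels.
AXIS_LABELS = {0: "Z", 1: "Y", 2: "X"}


def panel_shapes(shape: Tuple[int, int, int]) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: 2D image shape shown in each orthogonal panel."""
    return {
        label: tuple(size for i, size in enumerate(shape) if i != axis)
        for axis, label in AXIS_LABELS.items()
    }


print(panel_shapes((150, 256, 256)))
# {'Z': (256, 256), 'Y': (150, 256), 'X': (150, 256)}
```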

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| vol | ndarray | The 3D volume to be sliced. | required |
| cmap | str | Colormap for the images. | 'viridis' |
| vmin | float | Together with vmax, defines the data range the colormap covers. If None, the colormap covers the full range. | None |
| vmax | float | Together with vmin, defines the data range the colormap covers. If None, the colormap covers the full range. | None |
| img_height | int | Height of each panel, in inches. | 3 |
| img_width | int | Width of each panel, in inches. | 3 |
| show_position | bool | If True, labels each panel with its slice index and axis. | False |
| interpolation | str | Interpolation method passed to imshow. If None, matplotlib's default is used. | None |
| img_size | int | If set, overrides both img_height and img_width. | None |

Returns:

| Name | Type | Description |
|---|---|---|
| orthogonal_obj | HBox | The interactive widget for visualizing orthogonal slices of a 3D volume. |

Example

import qim3d

vol = qim3d.examples.fly_150x256x256
qim3d.viz.orthogonal(vol, cmap="magma")
[Figure: viz orthogonal]

Source code in qim3d/viz/explore.py
def orthogonal(
    vol: np.ndarray,
    cmap: str = "viridis",
    vmin: float = None,
    vmax: float = None,
    img_height: int = 3,
    img_width: int = 3,
    show_position: bool = False,
    interpolation: Optional[str] = None,
    img_size=None,
):
    """Interactive widget for visualizing orthogonal slices of a 3D volume.

    Args:
        vol (np.ndarray): The 3D volume to be sliced.
        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
        vmin (float, optional): Together with vmax, defines the data range the colormap covers. If None, the colormap covers the full range. Defaults to None.
        vmax (float, optional): Together with vmin, defines the data range the colormap covers. If None, the colormap covers the full range. Defaults to None.
        img_height (int, optional): Height of each panel, in inches. Defaults to 3.
        img_width (int, optional): Width of each panel, in inches. Defaults to 3.
        show_position (bool, optional): If True, labels each panel with its slice index and axis. Defaults to False.
        interpolation (str, optional): Interpolation method passed to imshow. If None, matplotlib's default is used. Defaults to None.
        img_size (int, optional): If set, overrides both img_height and img_width. Defaults to None.

    Returns:
        orthogonal_obj (widgets.HBox): The interactive widget for visualizing orthogonal slices of a 3D volume.

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.fly_150x256x256
        qim3d.viz.orthogonal(vol, cmap="magma")
        ```
        ![viz orthogonal](assets/screenshots/viz-orthogonal.gif)
    """

    if img_size:
        img_height = img_size
        img_width = img_size

    get_slicer_for_axis = lambda axis: slicer(
        vol,
        axis=axis,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        img_height=img_height,
        img_width=img_width,
        show_position=show_position,
        interpolation=interpolation,
    )

    z_slicer = get_slicer_for_axis(axis=0)
    y_slicer = get_slicer_for_axis(axis=1)
    x_slicer = get_slicer_for_axis(axis=2)

    z_slicer.children[0].description = "Z"
    y_slicer.children[0].description = "Y"
    x_slicer.children[0].description = "X"

    return widgets.HBox([z_slicer, y_slicer, x_slicer])

qim3d.viz.vol

vol(img, aspectmode='data', show=True, save=False, grid_visible=False, cmap=None, vmin=None, vmax=None, samples='auto', max_voxels=512 ** 3, data_type='scaled_float16', **kwargs)

Visualizes a 3D volume using volumetric rendering.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| img | ndarray | The input 3D image data, as a 3D NumPy array. | required |
| aspectmode | str | Determines the proportions of the scene's axes. If 'data', the axes are drawn in proportion with the axes' ranges. If 'cube', the axes are drawn as a cube, regardless of the axes' ranges. | 'data' |
| show | bool | If True, displays the visualization inline. | True |
| save | bool or str | If True, saves the visualization as an HTML file. If a string is provided, it is interpreted as the path where the HTML file will be saved. | False |
| grid_visible | bool | If True, the grid is visible in the plot. | False |
| cmap | list | The colormap used for the volume rendering (passed to k3d as color_map). | None |
| vmin | float | Together with vmax, defines the data range the colormap covers. If None, the colormap covers the full range. | None |
| vmax | float | Together with vmin, defines the data range the colormap covers. If None, the colormap covers the full range. | None |
| samples | int or str | The number of samples used for the volume rendering in k3d. Lower values render faster but with lower quality. If 'auto', the number is derived from the voxel count and clamped between 64 and 512. | 'auto' |
| max_voxels | int | Maximum number of voxels to render; larger volumes are downsampled for visualization. | 512 ** 3 |
| data_type | str | Data type the volume is converted to before rendering. Ignored when save is set, in which case float64 is used. | 'scaled_float16' |
| **kwargs | | Additional keyword arguments passed to the k3d.plot function. | {} |
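When samples is 'auto', the source fits a straight line through two reference points (256 samples at 256^3 voxels, 32 samples at 512^3 voxels), tuned according to a source comment for about 60 fps on an M1 MacBook Pro, and clamps the result between 64 and 512. The voxel count is taken from the volume before any max_voxels downsampling. The function below restates that computation in pure Python; the name `auto_samples` is hypothetical.

```python
def auto_samples(voxel_count: int) -> int:
    """Hypothetical restatement of the samples="auto" rule in qim3d.viz.vol."""
    y1, x1 = 256, 256**3  # 256 samples at 256^3 voxels
    y2, x2 = 32, 512**3   # 32 samples at 512^3 voxels
    a = (y1 - y2) / (x1 - x2)  # negative slope: larger volumes get fewer samples
    b = y1 - a * x1
    return int(min(max(a * voxel_count + b, 64), 512))  # clamp to [64, 512]


for side in (128, 256, 512):
    print(side, auto_samples(side**3))
```

Because the fitted line has intercept 288 and a negative slope, any non-empty volume gets a value below 288, so the 512 cap is not reached in practice; large volumes bottom out at the floor of 64.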

Returns:

| Name | Type | Description |
|---|---|---|
| plot | plot | If show=False, the K3D plot object. If show=True, the plot is displayed and None is returned. |

Raises:

| Type | Description |
|---|---|
| ValueError | If aspectmode is not 'data' or 'cube'. |
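For aspectmode='data', the volume is given explicit bounds so that each axis is drawn in proportion to its size. Arrays are indexed (z, y, x), so the bounds list starts with the x extent from shape[2] and ends with the z extent from shape[0]; the shape used is the one after any downsampling. The hypothetical `volume_bounds` helper below (not qim3d or k3d API) rebuilds that list and the aspectmode check in pure Python.

```python
from typing import List, Optional, Sequence


def volume_bounds(shape: Sequence[int], aspectmode: str = "data") -> Optional[List[int]]:
    """Hypothetical helper: the bounds list vol passes to k3d.volume.

    Arrays are indexed (z, y, x), so the list runs [x0, x1, y0, y1, z0, z1]
    with the x extent from shape[2] and the z extent from shape[0].
    """
    if aspectmode.lower() not in ("data", "cube"):
        raise ValueError("aspectmode should be either 'data' or 'cube'")
    if aspectmode.lower() == "data":
        return [0, shape[2], 0, shape[1], 0, shape[0]]
    return None  # 'cube': no explicit bounds are passed


print(volume_bounds((150, 256, 256)))  # [0, 256, 0, 256, 0, 150]
```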

Example

Display a volume inline:

import qim3d

vol = qim3d.examples.bone_128x128x128
qim3d.viz.vol(vol)

Save a plot to an HTML file:

import qim3d
vol = qim3d.examples.bone_128x128x128
plot = qim3d.viz.vol(vol, show=False, save="plot.html")

Source code in qim3d/viz/k3d.py
def vol(
    img,
    aspectmode="data",
    show=True,
    save=False,
    grid_visible=False,
    cmap=None,
    vmin=None,
    vmax=None,
    samples="auto",
    max_voxels=512**3,
    data_type="scaled_float16",
    **kwargs,
):
    """
    Visualizes a 3D volume using volumetric rendering.

    Args:
        img (numpy.ndarray): The input 3D image data. It should be a 3D numpy array.
        aspectmode (str, optional): Determines the proportions of the scene's axes. Defaults to "data".

            If `'data'`, the axes are drawn in proportion with the axes' ranges.
            If `'cube'`, the axes are drawn as a cube, regardless of the axes' ranges.
        show (bool, optional): If True, displays the visualization inline. Defaults to True.
        save (bool or str, optional): If True, saves the visualization as an HTML file.
            If a string is provided, it's interpreted as the file path where the HTML
            file will be saved. Defaults to False.
        grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
        cmap (list, optional): The color map to be used for the volume rendering. Defaults to None.
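        vmin (float, optional): Together with vmax, defines the data range the colormap covers. If None, the colormap covers the full range. Defaults to None.
        vmax (float, optional): Together with vmin, defines the data range the colormap covers. If None, the colormap covers the full range. Defaults to None.
        samples (int or str, optional): The number of samples to be used for the volume rendering in k3d.
            Lower values render faster but with lower quality. If "auto", the number is derived
            from the voxel count and clamped between 64 and 512. Defaults to "auto".
        max_voxels (int, optional): Maximum number of voxels to render; larger volumes are
            downsampled for visualization. Defaults to 512^3.
        data_type (str, optional): Data type the volume is converted to before rendering.
            Ignored when `save` is set, in which case float64 is used. Defaults to 'scaled_float16'.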
        **kwargs: Additional keyword arguments to be passed to the `k3d.plot` function.

    Returns:
        plot (k3d.plot): If `show=False`, returns the K3D plot object.

    Raises:
        ValueError: If `aspectmode` is not `'data'` or `'cube'`.

    Example:
        Display a volume inline:

        ```python
        import qim3d

        vol = qim3d.examples.bone_128x128x128
        qim3d.viz.vol(vol)
        ```
        <iframe src="https://platform.qim.dk/k3d/fima-bone_128x128x128-20240221113459.html" width="100%" height="500" frameborder="0"></iframe>

        Save a plot to an HTML file:

        ```python
        import qim3d
        vol = qim3d.examples.bone_128x128x128
        plot = qim3d.viz.vol(vol, show=False, save="plot.html")
        ```

    """
    import k3d

    pixel_count = img.shape[0] * img.shape[1] * img.shape[2]
    # target is 60fps on m1 macbook pro, using test volume: https://data.qim.dk/pages/foam.html
    if samples == "auto":
        y1, x1 = 256, 16777216  # 256 samples at res 256*256*256=16.777.216
        y2, x2 = 32, 134217728  # 32 samples at res 512*512*512=134.217.728

        # we fit linear function to the two points
        a = (y1 - y2) / (x1 - x2)
        b = y1 - a * x1

        samples = int(min(max(a * pixel_count + b, 64), 512))
    else:
        samples = int(samples)  # make sure it's an integer

    if aspectmode.lower() not in ["data", "cube"]:
        raise ValueError("aspectmode should be either 'data' or 'cube'")
    # check if image should be downsampled for visualization
    original_shape = img.shape
    img = downscale_img(img, max_voxels=max_voxels)

    new_shape = img.shape

    if original_shape != new_shape:
        log.warning(
            f"Downsampled image for visualization, from {original_shape} to {new_shape}"
        )

    # Scale the image to float16 if needed
    if save:
        # When saving, we need float64
        img = img.astype(np.float64)
    else:

        if data_type == "scaled_float16":
            img = scale_to_float16(img)
        else:
            img = img.astype(data_type)

    # Set color ranges
    color_range = [np.min(img), np.max(img)]
    if vmin is not None:
        color_range[0] = vmin
    if vmax is not None:
        color_range[1] = vmax

    # Create the volume plot
    plt_volume = k3d.volume(
        img,
        bounds=(
            [0, img.shape[2], 0, img.shape[1], 0, img.shape[0]]
            if aspectmode.lower() == "data"
            else None
        ),
        color_map=cmap,
        samples=samples,
        color_range=color_range,
    )
    plot = k3d.plot(grid_visible=grid_visible, **kwargs)
    plot += plt_volume
    if save:
        # Save html to disk
        with open(str(save), "w", encoding="utf-8") as fp:
            fp.write(plot.get_snapshot())

    if show:
        plot.display()
    else:
        return plot  

qim3d.viz.chunks

chunks(zarr_path, **kwargs)

Interactive widget for exploring the chunks of a multiscale Zarr dataset (such as an OME-Zarr export) one at a time, using the selected visualization method ("slicer", "slices" or "vol").
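Each coordinate dropdown in the chunk explorer lists the chunk indices along one dimension, computed with ceiling division so that a trailing partial chunk is included; selecting a chunk reads the index ranges below, clipped to the array shape. The sketch restates that arithmetic with hypothetical helper names; the 300 x 250 x 250 shape and 100^3 chunks are made-up example values, not a real dataset.

```python
from typing import List, Sequence, Tuple


def num_chunks(shape: Sequence[int], chunk_shape: Sequence[int]) -> List[int]:
    """Chunks per dimension, counting a trailing partial chunk (ceiling division)."""
    return [(s + c - 1) // c for s, c in zip(shape, chunk_shape)]


def chunk_ranges(
    coords: Sequence[int], shape: Sequence[int], chunk_shape: Sequence[int]
) -> List[Tuple[int, int]]:
    """(start, stop) per dimension for the chunk at (z, y, x), clipped to the array shape."""
    return [(k * c, min((k + 1) * c, s)) for k, s, c in zip(coords, shape, chunk_shape)]


# Made-up example: a 300 x 250 x 250 scale stored in 100^3 chunks.
shape, chunk_shape = (300, 250, 250), (100, 100, 100)
print(num_chunks(shape, chunk_shape))  # [3, 3, 3]
print(chunk_ranges((2, 2, 2), shape, chunk_shape))  # [(200, 300), (200, 250), (200, 250)]
```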

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| zarr_path | str | Path to the Zarr dataset. | required |
| **kwargs | | Additional keyword arguments to pass to the visualization method. | {} |

Example

import qim3d

# Download dataset
downloader = qim3d.io.Downloader()
data = downloader.Snail.Escargot(load_file=True)

# Export as OME-Zarr
qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2, replace=True)

# Explore chunks
qim3d.viz.chunks("Escargot.zarr")
[Figure: chunks-visualization]

Source code in qim3d/viz/explore.py
def chunks(zarr_path: str, **kwargs):
    """
    Interactive widget for exploring the chunks of a multiscale Zarr dataset (such as an OME-Zarr export) one at a time, using the selected visualization method ("slicer", "slices" or "vol").

    Args:
        zarr_path (str): Path to the Zarr dataset.
        **kwargs: Additional keyword arguments to pass to the visualization method.

    Example:
        ```python
        import qim3d

        # Download dataset
        downloader = qim3d.io.Downloader()
        data = downloader.Snail.Escargot(load_file=True)

        # Export as OME-Zarr
        qim3d.io.export_ome_zarr("Escargot.zarr", data, chunk_size=100, downsample_rate=2, replace=True)

        # Explore chunks
        qim3d.viz.chunks("Escargot.zarr")
        ```
        ![chunks-visualization](assets/screenshots/chunks_visualization.gif)
    """

    # Load the Zarr dataset
    zarr_data = zarr.open(zarr_path, mode="r")

    # Save arguments for later use
    # visualization_method = visualization_method
    # preserved_kwargs = kwargs

    # Create label to display the chunk coordinates
    widget_title = widgets.HTML("<h2>Chunk Explorer</h2>")
    chunk_info_label = widgets.HTML(value="Chunk info will be displayed here")

    def load_and_visualize(
        scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
    ):
        # Get chunk shape for the selected scale
        chunk_shape = zarr_data[scale].chunks

        # Calculate slice indices for the selected chunk
        slices = (
            slice(
                z_coord * chunk_shape[0],
                min((z_coord + 1) * chunk_shape[0], zarr_data[scale].shape[0]),
            ),
            slice(
                y_coord * chunk_shape[1],
                min((y_coord + 1) * chunk_shape[1], zarr_data[scale].shape[1]),
            ),
            slice(
                x_coord * chunk_shape[2],
                min((x_coord + 1) * chunk_shape[2], zarr_data[scale].shape[2]),
            ),
        )

        # Extract start and stop values from each slice object
        z_start, z_stop = slices[0].start, slices[0].stop
        y_start, y_stop = slices[1].start, slices[1].stop
        x_start, x_stop = slices[2].start, slices[2].stop

        # Extract the chunk
        chunk = zarr_data[scale][slices]

        # Update the chunk info label with the chunk coordinates
        info_string = (
            f"<b>shape:</b> {chunk_shape}\n"
            + f"<b>coordinates:</b> ({z_coord}, {y_coord}, {x_coord})\n"
            + f"<b>ranges: </b>Z({z_start}-{z_stop})   Y({y_start}-{y_stop})   X({x_start}-{x_stop})\n"
            + f"<b>dtype:</b> {chunk.dtype}\n"
            + f"<b>min value:</b> {np.min(chunk)}\n"
            + f"<b>max value:</b> {np.max(chunk)}\n"
            + f"<b>mean value:</b> {np.mean(chunk)}\n"
        )

        chunk_info_label.value = f"""
            <div style="font-size: 14px; text-align: left; margin-left:32px">
                <h3 style="margin: 0px">Chunk Info</h3>
                    <div style="font-size: 14px; text-align: left;">
                    <pre>{info_string}</pre>
                    </div>
            </div>

            """

        # Prepare chunk visualization based on the selected method
        if visualization_method == "slicer":  # return a widget
            viz_widget = qim3d.viz.slicer(chunk, **kwargs)
        elif visualization_method == "slices":  # return a plt.Figure
            viz_widget = widgets.Output()
            with viz_widget:
                viz_widget.clear_output(wait=True)
                fig = qim3d.viz.slices(chunk, **kwargs)
                display(fig)
        elif visualization_method == "vol":
            viz_widget = widgets.Output()
            with viz_widget:
                viz_widget.clear_output(wait=True)
                out = qim3d.viz.vol(chunk, show=False, **kwargs)
                display(out)
        else:
            log.info(f"Invalid visualization method: {visualization_method}")
            viz_widget = widgets.Output()  # Empty placeholder so a widget is always returned

        return viz_widget

    # Function to calculate the number of chunks for each dimension, including partial chunks
    def get_num_chunks(shape, chunk_size):
        return [(s + chunk_size[i] - 1) // chunk_size[i] for i, s in enumerate(shape)]

    scale_options = {
        f"{i} {zarr_data[i].shape}": i for i in range(len(zarr_data))
    }  # len(zarr_data) gives number of scales

    description_width = "128px"
    # Create dropdown for scale
    scale_dropdown = widgets.Dropdown(
        options=scale_options,
        value=0,  # Default to first scale
        description="OME-Zarr scale",
        style={"description_width": description_width, "text_align": "left"},
    )

    # Initialize the options for x, y, and z based on the first scale by default
    multiscale_shape = zarr_data[0].shape
    chunk_shape = zarr_data[0].chunks
    num_chunks = get_num_chunks(multiscale_shape, chunk_shape)

    z_dropdown = widgets.Dropdown(
        options=list(range(num_chunks[0])),
        value=0,
        description="First dimension (Z)",
        style={"description_width": description_width, "text_align": "left"},
    )

    y_dropdown = widgets.Dropdown(
        options=list(range(num_chunks[1])),
        value=0,
        description="Second dimension (Y)",
        style={"description_width": description_width, "text_align": "left"},
    )

    x_dropdown = widgets.Dropdown(
        options=list(range(num_chunks[2])),
        value=0,
        description="Third dimension (X)",
        style={"description_width": description_width, "text_align": "left"},
    )

    method_dropdown = widgets.Dropdown(
        options=["slicer", "slices", "vol"],
        value="slicer",
        description="Visualization",
        style={"description_width": description_width, "text_align": "left"},
    )

    # Function to temporarily disable observers
    def disable_observers():
        x_dropdown.unobserve(update_visualization, names="value")
        y_dropdown.unobserve(update_visualization, names="value")
        z_dropdown.unobserve(update_visualization, names="value")
        method_dropdown.unobserve(update_visualization, names="value")

    # Function to re-enable observers
    def enable_observers():
        x_dropdown.observe(update_visualization, names="value")
        y_dropdown.observe(update_visualization, names="value")
        z_dropdown.observe(update_visualization, names="value")
        method_dropdown.observe(update_visualization, names="value")

    # Function to update the x, y, z dropdowns when the scale changes and reset the coordinates to 0
    def update_coordinate_dropdowns(scale):

        disable_observers()  # avoid reloading the visualization several times while the dropdowns are updated

        multiscale_shape = zarr_data[scale].shape
        chunk_shape = zarr_data[scale].chunks
        num_chunks = get_num_chunks(
            multiscale_shape, chunk_shape
        )  # Calculate new chunk options

        # Reset X, Y, Z dropdowns to 0
        z_dropdown.options = list(range(num_chunks[0]))
        z_dropdown.value = 0  # Reset to 0
        z_dropdown.disabled = (
            len(z_dropdown.options) == 1
        )  # Disable if only one option (0) is available

        y_dropdown.options = list(range(num_chunks[1]))
        y_dropdown.value = 0  # Reset to 0
        y_dropdown.disabled = (
            len(y_dropdown.options) == 1
        )  # Disable if only one option (0) is available

        x_dropdown.options = list(range(num_chunks[2]))
        x_dropdown.value = 0  # Reset to 0
        x_dropdown.disabled = (
            len(x_dropdown.options) == 1
        )  # Disable if only one option (0) is available

        enable_observers()

        update_visualization()

    # Function to update the visualization when any dropdown value changes
    def update_visualization(*args):
        scale = scale_dropdown.value
        x_coord = x_dropdown.value
        y_coord = y_dropdown.value
        z_coord = z_dropdown.value
        visualization_method = method_dropdown.value

        # Clear and update the chunk visualization
        slicer_widget = load_and_visualize(
            scale, z_coord, y_coord, x_coord, visualization_method, **kwargs
        )

        # Recreate the layout and display the new visualization
        final_layout.children = [widget_title, hbox_layout, slicer_widget]

    # Attach an observer to scale dropdown to update x, y, z dropdowns when the scale changes
    scale_dropdown.observe(
        lambda change: update_coordinate_dropdowns(scale_dropdown.value), names="value"
    )

    enable_observers()

    # Create first visualization
    slicer_widget = load_and_visualize(
        scale_dropdown.value,
        z_dropdown.value,
        y_dropdown.value,
        x_dropdown.value,
        method_dropdown.value,
        **kwargs,
    )

    # Create the layout
    vbox_dropbox = widgets.VBox(
        [scale_dropdown, z_dropdown, y_dropdown, x_dropdown, method_dropdown]
    )
    hbox_layout = widgets.HBox([vbox_dropbox, chunk_info_label])
    final_layout = widgets.VBox([widget_title, hbox_layout, slicer_widget])

    # Display the VBox
    display(final_layout)
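
The chunk dropdowns above depend on `get_num_chunks`. It counts chunks per dimension with integer ceiling division, so a trailing partial chunk still gets its own index. The standalone sketch below reproduces that arithmetic outside the widget and checks it against `math.ceil`. The volume shape is scale 0 from the `itk_vtk` example output further down. The chunk shape `(100, 100, 100)` is an illustrative assumption, not necessarily the chunking qim3d writes.

```python
import math


def get_num_chunks(shape, chunk_size):
    # (s + c - 1) // c is integer ceiling division: a partial chunk at the end still counts
    return [(s + c - 1) // c for s, c in zip(shape, chunk_size)]


# Scale 0 of the Okinawa_Foram_1 example, with an assumed chunk shape of 100 voxels per axis
shape = (995, 1014, 984)
chunks = (100, 100, 100)

num_chunks = get_num_chunks(shape, chunks)
print(num_chunks)  # [10, 11, 10]

# Same result as the floating-point ceiling, but computed purely with integers
assert num_chunks == [math.ceil(s / c) for s, c in zip(shape, chunks)]
```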

qim3d.viz.itk_vtk

itk_vtk(filename=None, open_browser=True, file_server_port=8042, viewer_port=3000)

Opens a visualization window using the itk-vtk-viewer. Works both for common file types (Tiff, Nifti, etc.) and for OME-Zarr stores.

This function starts the itk-vtk-viewer, either using a global installation or a local installation within the QIM package. It also starts an HTTP server to serve the file to the viewer. Optionally, it can automatically open a browser window to display the viewer. If the viewer is not installed, it raises a NotInstalledError.

Parameters:

Name Type Description Default
filename str

Path to the file or OME-Zarr store to be visualized. Trailing slashes in the path are normalized. Defaults to None.

None
open_browser bool

If True, opens the visualization in a new browser tab. Defaults to True.

True
file_server_port int

The port number for the local file server that hosts the store. Defaults to 8042.

8042
viewer_port int

The port number for the itk-vtk-viewer server. Defaults to 3000.

3000

Raises:

Type Description
NotInstalledError

Raised if the itk-vtk-viewer is not installed in the expected location.

Example
import qim3d

# Download data
downloader = qim3d.io.Downloader()
data = downloader.Okinawa_Forams.Okinawa_Foram_1(load_file=True, virtual_stack=True)

# Export to OME-Zarr
qim3d.io.export_ome_zarr("Okinawa_Foram_1.zarr", data)

# Start visualization
qim3d.viz.itk_vtk("Okinawa_Foram_1.zarr")
Downloading Okinawa_Foram_1.tif
https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository/Okinawa_Forams/Okinawa_Foram_1.tif
1.85GB [00:17, 111MB/s]                                                         

Loading Okinawa_Foram_1.tif
Loading: 100%
 1.85GB/1.85GB  [00:02<00:00, 762MB/s]
Loaded shape: (995, 1014, 984)
Using virtual stack
Exporting data to OME-Zarr format at Okinawa_Foram_1.zarr
Number of scales: 5
Creating a multi-scale pyramid
- Scale 0: (995, 1014, 984)
- Scale 1: (498, 507, 492)
- Scale 2: (249, 254, 246)
- Scale 3: (124, 127, 123)
- Scale 4: (62, 63, 62)
Writing data to disk
All done!

itk-vtk-viewer
=> Serving /home/fima/Notebooks/Qim3d on port 3000

    enp0s31f6 => http://10.52.0.158:3000/
    wlp0s20f3 => http://10.197.104.229:3000/

Serving directory '/home/fima/Notebooks/Qim3d'
http://localhost:8042/

Visualization url:
http://localhost:3000/?rotate=false&fileToLoad=http://localhost:8042/Okinawa_Foram_1.zarr

[Animation: itk-vtk-viewer]
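
The visualization URL printed above tells the viewer (on `viewer_port`) to load the store from the local file server (on `file_server_port`). The sketch below mirrors how the source code further down builds that URL:

- the path is normalized, so a trailing slash does not produce an empty base name;
- the parent directory is what gets served;
- only the base name appears in `fileToLoad`.

`build_viewer_url` is a hypothetical helper for illustration. It only computes strings and does not start any server.

```python
import os


def build_viewer_url(filename, file_server_port=8042, viewer_port=3000):
    # Normalize first: "store.zarr/" and "store.zarr" must give the same base name
    filename_norm = os.path.normpath(os.path.abspath(filename))
    served_dir = os.path.dirname(filename_norm)  # directory the HTTP server would expose
    name = os.path.basename(filename_norm)  # store path relative to that directory
    url = (
        f"http://localhost:{viewer_port}/?rotate=false"
        f"&fileToLoad=http://localhost:{file_server_port}/{name}"
    )
    return served_dir, url


served_dir, url = build_viewer_url("Okinawa_Foram_1.zarr/")
print(url)
# http://localhost:3000/?rotate=false&fileToLoad=http://localhost:8042/Okinawa_Foram_1.zarr
```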

Source code in qim3d/viz/itk_vtk_viewer/run.py
def itk_vtk(
    filename: str = None,
    open_browser: bool = True,
    file_server_port: int = 8042,
    viewer_port: int = 3000,
):
    """
    Opens a visualization window using the itk-vtk-viewer. Works both for common file types (Tiff, Nifti, etc.) and for **OME-Zarr stores**.

    This function starts the itk-vtk-viewer, either using a global
    installation or a local installation within the QIM package. It also starts
    an HTTP server to serve the file to the viewer. Optionally, it can
    automatically open a browser window to display the viewer. If the viewer
    is not installed, it raises a NotInstalledError.

    Args:
        filename (str, optional): Path to the file or OME-Zarr store to be visualized. Trailing slashes in
            the path are normalized. Defaults to None.
        open_browser (bool, optional): If True, opens the visualization in a new browser tab.
            Defaults to True.
        file_server_port (int, optional): The port number for the local file server that hosts
            the store. Defaults to 8042.
        viewer_port (int, optional): The port number for the itk-vtk-viewer server. Defaults to 3000.

    Raises:
        NotInstalledError: Raised if the itk-vtk-viewer is not installed in the expected location.

    Example:
        ```python
        import qim3d

        # Download data
        downloader = qim3d.io.Downloader()
        data = downloader.Okinawa_Forams.Okinawa_Foram_1(load_file=True, virtual_stack=True)

        # Export to OME-Zarr
        qim3d.io.export_ome_zarr("Okinawa_Foram_1.zarr", data)

        # Start visualization
        qim3d.viz.itk_vtk("Okinawa_Foram_1.zarr")
        ```
        <pre style="margin-left: 12px; margin-right: 12px; color:#454545">
        Downloading Okinawa_Foram_1.tif
        https://archive.compute.dtu.dk/download/public/projects/viscomp_data_repository/Okinawa_Forams/Okinawa_Foram_1.tif
        1.85GB [00:17, 111MB/s]                                                         

        Loading Okinawa_Foram_1.tif
        Loading: 100%
         1.85GB/1.85GB  [00:02<00:00, 762MB/s]
        Loaded shape: (995, 1014, 984)
        Using virtual stack
        Exporting data to OME-Zarr format at Okinawa_Foram_1.zarr
        Number of scales: 5
        Creating a multi-scale pyramid
        - Scale 0: (995, 1014, 984)
        - Scale 1: (498, 507, 492)
        - Scale 2: (249, 254, 246)
        - Scale 3: (124, 127, 123)
        - Scale 4: (62, 63, 62)
        Writing data to disk
        All done!

        itk-vtk-viewer
        => Serving /home/fima/Notebooks/Qim3d on port 3000

            enp0s31f6 => http://10.52.0.158:3000/
            wlp0s20f3 => http://10.197.104.229:3000/

        Serving directory '/home/fima/Notebooks/Qim3d'
        http://localhost:8042/

        Visualization url:
        http://localhost:3000/?rotate=false&fileToLoad=http://localhost:8042/Okinawa_Foram_1.zarr

        </pre>

        ![itk-vtk-viewer](assets/screenshots/itk-vtk-viewer.gif)

    """

    global is_installed
    # This might seem redundant, but it is needed in case we have to go through the installation first.
    # If we have to install first, this variable is set to False and keeps that value,
    # so when we run the newly installed viewer it would still be False and the web browser wouldn't open.
    c.acquire()
    is_installed = True
    c.release()

    # We do a delay open for the browser, just so that the itk-vtk-viewer has time to start.
    # Timing is not critical, this is just so that the user does not see the "server cannot be reached" page
    def delayed_open():
        time.sleep(3)
        global is_installed
        c.acquire()
        if is_installed:

            # Normalize the filename. This is necessary to handle trailing slashes at the end of the path
            filename_norm = os.path.normpath(os.path.abspath(filename))

            # Start the http server
            qim3d.utils.start_http_server(
                os.path.dirname(filename_norm), port=file_server_port
            )

            viz_url = f"http://localhost:{viewer_port}/?rotate=false&fileToLoad=http://localhost:{file_server_port}/{os.path.basename(filename_norm)}"

            if open_browser:
                webbrowser.open_new_tab(viz_url)

            log.info(f"\nVisualization url:\n{viz_url}\n")
        c.release()

    # Start the delayed open in a separate thread
    delayed_window = threading.Thread(target=delayed_open)
    delayed_window.start()

    # First try a global installation of the viewer
    run_global(port=viewer_port)

    # Then try the node.js installation within the qim package
    run_within_qim_dir(port=viewer_port)

    # If we got to this point, the viewer is not installed and we don't want to
    # open a browser with a non-working window.
    # We set the flag is_installed to False, which the other thread reads to know not to open the browser.
    c.acquire()
    is_installed = False
    c.release()

    delayed_window.join()

    # If we still get here, the viewer is not installed in the location we expect, so we raise an error,
    # which is caught by the command line interface, which then asks for installation
    raise NotInstalledError
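
The lock `c` and the module-level `is_installed` flag implement a simple handshake. The browser opens after a short delay unless the main thread has meanwhile found that no viewer could be started. The same handshake can be written with `threading.Event`, which combines the flag and the waiting. Because the event is created per call, there is no stale flag to reset.

This is a minimal sketch of that alternative, not qim3d's implementation. `start_viewer` and `open_url` are hypothetical callables. They stand in for the `run_global`/`run_within_qim_dir` calls and `webbrowser.open_new_tab`, and `start_viewer` is assumed to return False when no viewer is available.

```python
import threading
import time


def launch_with_delayed_open(start_viewer, open_url, delay=3.0):
    """Run start_viewer() here and call open_url() from a helper thread after `delay` seconds,
    unless start_viewer() has already reported that no viewer could be started."""
    viewer_failed = threading.Event()

    def delayed_open():
        # Event.wait returns True as soon as the event is set, or False once the timeout expires
        if not viewer_failed.wait(timeout=delay):
            open_url()

    opener = threading.Thread(target=delayed_open)
    opener.start()
    started = start_viewer()  # assumed to block while serving and return False if unavailable
    if not started:
        viewer_failed.set()  # wakes the helper thread immediately so it skips the browser
    opener.join()
    return started


opened = []
# A viewer that is "not installed" returns False at once, so the browser is never opened
launch_with_delayed_open(lambda: False, lambda: opened.append("url"), delay=0.5)
# A viewer that keeps running past the delay (time.sleep returns None, so "or True" yields True)
launch_with_delayed_open(lambda: time.sleep(0.2) or True, lambda: opened.append("url"), delay=0.05)
print(opened)  # ['url']
```

Unlike the sleep-then-check version, the helper thread reacts as soon as the failure is signalled instead of always waiting the full delay.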

qim3d.viz.mesh

mesh(verts, faces, wireframe=True, flat_shading=True, grid_visible=False, show=True, save=False, **kwargs)

Visualizes a 3D mesh using K3D.

Parameters:

Name Type Description Default
verts ndarray

A 2D array (Nx3) containing the vertices of the mesh.

required
faces ndarray

A 2D array (Mx3) containing the indices of the mesh faces.

required
wireframe bool

If True, the mesh is rendered as a wireframe. Defaults to True.

True
flat_shading bool

If True, flat shading is applied to the mesh. Defaults to True.

True
grid_visible bool

If True, the grid is visible in the plot. Defaults to False.

False
show bool

If True, displays the visualization inline. Defaults to True.

True
save bool or str

If True, saves the visualization as an HTML file. If a string is provided, it's interpreted as the file path where the HTML file will be saved. Defaults to False.

False
**kwargs

Additional keyword arguments to be passed to the k3d.plot function.

{}

Returns:

Name Type Description
plot plot

If show=False, returns the K3D plot object.

Example
import qim3d

vol = qim3d.generate.blob(base_shape=(128,128,128),
                          final_shape=(128,128,128),
                          noise_scale=0.03,
                          order=1,
                          gamma=1,
                          max_value=255,
                          threshold=0.5,
                          dtype='uint8'
                          )
mesh = qim3d.processing.create_mesh(vol, step_size=3)
qim3d.viz.mesh(mesh.vertices, mesh.faces)
Source code in qim3d/viz/k3d.py
def mesh(
    verts,
    faces,
    wireframe=True,
    flat_shading=True,
    grid_visible=False,
    show=True,
    save=False,
    **kwargs,
):
    """
    Visualizes a 3D mesh using K3D.

    Args:
        verts (numpy.ndarray): A 2D array (Nx3) containing the vertices of the mesh.
        faces (numpy.ndarray): A 2D array (Mx3) containing the indices of the mesh faces.
        wireframe (bool, optional): If True, the mesh is rendered as a wireframe. Defaults to True.
        flat_shading (bool, optional): If True, flat shading is applied to the mesh. Defaults to True.
        grid_visible (bool, optional): If True, the grid is visible in the plot. Defaults to False.
        show (bool, optional): If True, displays the visualization inline. Defaults to True.
        save (bool or str, optional): If True, saves the visualization as an HTML file.
            If a string is provided, it's interpreted as the file path where the HTML
            file will be saved. Defaults to False.
        **kwargs: Additional keyword arguments to be passed to the `k3d.plot` function.

    Returns:
        plot (k3d.plot): If `show=False`, returns the K3D plot object.

    Example:
        ```python
        import qim3d

        vol = qim3d.generate.blob(base_shape=(128,128,128),
                                  final_shape=(128,128,128),
                                  noise_scale=0.03,
                                  order=1,
                                  gamma=1,
                                  max_value=255,
                                  threshold=0.5,
                                  dtype='uint8'
                                  )
        mesh = qim3d.processing.create_mesh(vol, step_size=3)
        qim3d.viz.mesh(mesh.vertices, mesh.faces)
        ```
        <iframe src="https://platform.qim.dk/k3d/mesh_visualization.html" width="100%" height="500" frameborder="0"></iframe>
    """
    import k3d

    # Validate the inputs
    if verts.shape[1] != 3:
        raise ValueError("Vertices array must have shape (N, 3)")
    if faces.shape[1] != 3:
        raise ValueError("Faces array must have shape (M, 3)")

    # Ensure the correct data types and memory layout
    verts = np.ascontiguousarray(verts.astype(np.float32))  # Cast and ensure C-contiguous layout
    faces = np.ascontiguousarray(faces.astype(np.uint32))  # Cast and ensure C-contiguous layout

    # Create the mesh plot
    plt_mesh = k3d.mesh(
        vertices=verts,
        indices=faces,
        wireframe=wireframe,
        flat_shading=flat_shading,
    )

    # Create plot
    plot = k3d.plot(grid_visible=grid_visible, **kwargs)
    plot += plt_mesh

    if save:
        # Save html to disk
        with open(str(save), "w", encoding="utf-8") as fp:
            fp.write(plot.get_snapshot())

    if show:
        plot.display()
    else:
        return plot
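
Before building the K3D mesh, `mesh` prepares its inputs in two steps. It checks that both arrays have three columns, then converts them to C-contiguous `float32` vertices and `uint32` indices. The sketch below isolates that preparation step so it can be tested without K3D. `prepare_mesh_arrays` is a hypothetical helper. Its check that face indices stay within the vertex count is an addition that the function above does not perform.

```python
import numpy as np


def prepare_mesh_arrays(verts, faces):
    verts = np.asarray(verts)
    faces = np.asarray(faces)
    # Same shape checks as qim3d.viz.mesh: one row per vertex/face, three columns each
    if verts.ndim != 2 or verts.shape[1] != 3:
        raise ValueError("Vertices array must have shape (N, 3)")
    if faces.ndim != 2 or faces.shape[1] != 3:
        raise ValueError("Faces array must have shape (M, 3)")
    # Additional sanity check (not in qim3d): every face must reference an existing vertex
    if faces.size and (faces.min() < 0 or faces.max() >= len(verts)):
        raise ValueError("Face indices must lie in [0, N)")
    # Cast and force a C-contiguous layout, as in the function above
    verts = np.ascontiguousarray(verts, dtype=np.float32)
    faces = np.ascontiguousarray(faces, dtype=np.uint32)
    return verts, faces


# A single tetrahedron: 4 vertices, 4 triangular faces
tet_verts = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
tet_faces = [[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]
v, f = prepare_mesh_arrays(tet_verts, tet_faces)
print(v.dtype, f.dtype, v.flags["C_CONTIGUOUS"])  # float32 uint32 True
```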

qim3d.viz.local_thickness

local_thickness(image, image_lt, max_projection=False, axis=0, slice_idx=None, show=False, figsize=(15, 5))

Visualizes the local thickness of a 2D or 3D image.

Parameters:

Name Type Description Default
image ndarray

2D or 3D NumPy array representing the image/volume.

required
image_lt ndarray

2D or 3D NumPy array representing the local thickness of the input image/volume.

required
max_projection bool

If True, displays the maximum projection of the local thickness. Only used for 3D images. Defaults to False.

False
axis int

The axis along which to visualize the local thickness. Unused for 2D images. Defaults to 0.

0
slice_idx int or float

The initial slice to be visualized. The slice index can afterwards be changed. If value is an integer, it will be the index of the slice to be visualized. If value is a float between 0 and 1, it will be multiplied by the number of slices and rounded to the nearest integer. If None, the middle slice will be used for 3D images. Unused for 2D images. Defaults to None.

None
show bool

If True, displays the plot (i.e. calls plt.show()). Defaults to False.

False
figsize Tuple[int, int]

The size of the figure. Defaults to (15, 5).

(15, 5)

Raises:

Type Description
ValueError

If the slice index is not an integer or a float between 0 and 1.

Returns:

Type Description
Union[Figure, interactive]

If the input is 3D, returns an interactive widget. Otherwise, returns a matplotlib figure.

Example

import qim3d

fly = qim3d.examples.fly_150x256x256 # 3D volume
lt_fly = qim3d.processing.local_thickness(fly)
qim3d.viz.local_thickness(fly, lt_fly, axis=0)
[Animation: local thickness 3d]

Source code in qim3d/viz/local_thickness_.py
def local_thickness(
    image: np.ndarray,
    image_lt: np.ndarray,
    max_projection: bool = False,
    axis: int = 0,
    slice_idx: Optional[Union[int, float]] = None,
    show: bool = False,
    figsize: Tuple[int, int] = (15, 5),
) -> Union[plt.Figure, widgets.interactive]:
    """Visualizes the local thickness of a 2D or 3D image.

    Args:
        image (np.ndarray): 2D or 3D NumPy array representing the image/volume.
        image_lt (np.ndarray): 2D or 3D NumPy array representing the local thickness of the input
            image/volume.
        max_projection (bool, optional): If True, displays the maximum projection of the local
            thickness. Only used for 3D images. Defaults to False.
        axis (int, optional): The axis along which to visualize the local thickness.
            Unused for 2D images.
            Defaults to 0.
        slice_idx (int or float, optional): The initial slice to be visualized. The slice index
            can afterwards be changed. If value is an integer, it will be the index of the slice
            to be visualized. If value is a float between 0 and 1, it will be multiplied by the
            number of slices and rounded to the nearest integer. If None, the middle slice will
            be used for 3D images. Unused for 2D images. Defaults to None.
        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.
        figsize (Tuple[int, int], optional): The size of the figure. Defaults to (15, 5).

    Raises:
        ValueError: If the slice index is not an integer or a float between 0 and 1.

    Returns:
        If the input is 3D, returns an interactive widget. Otherwise, returns a matplotlib figure.

    Example:
        ```python
        import qim3d

        fly = qim3d.examples.fly_150x256x256 # 3D volume
        lt_fly = qim3d.processing.local_thickness(fly)
        qim3d.viz.local_thickness(fly, lt_fly, axis=0)
        ```
        ![local thickness 3d](assets/screenshots/local_thickness_3d.gif)


    """

    def _local_thickness(image, image_lt, show, figsize, axis=None, slice_idx=None):
        if slice_idx is not None:
            image = image.take(slice_idx, axis=axis)
            image_lt = image_lt.take(slice_idx, axis=axis)

        fig, axs = plt.subplots(1, 3, figsize=figsize, layout="constrained")

        axs[0].imshow(image, cmap="gray")
        axs[0].set_title("Original image")
        axs[0].axis("off")

        im = axs[1].imshow(image_lt, cmap="viridis")
        axs[1].set_title("Local thickness")
        axs[1].axis("off")

        # Reuse the image handle for the colorbar instead of drawing the image a second time
        plt.colorbar(im, ax=axs[1], orientation="vertical")

        axs[2].hist(image_lt[image_lt > 0].ravel(), bins=32, edgecolor="black")
        axs[2].set_title("Local thickness histogram")
        axs[2].set_xlabel("Local thickness")
        axs[2].set_ylabel("Count")

        if show:
            plt.show()

        plt.close()

        return fig

    # 3D input: show either a maximum projection or an interactive slice viewer
    if len(image.shape) == 3:
        if max_projection:
            if slice_idx is not None:
                log.warning(
                    "slice_idx is not used for max_projection. It will be ignored."
                )
            image = image.max(axis=axis)
            image_lt = image_lt.max(axis=axis)
            return _local_thickness(image, image_lt, show, figsize)
        else:
            if slice_idx is None:
                slice_idx = image.shape[axis] // 2
            elif isinstance(slice_idx, float):
                if slice_idx < 0 or slice_idx > 1:
                    raise ValueError(
                        "Values of slice_idx of float type must be between 0 and 1."
                    )
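                # Scale by the number of slices along axis, round, and clamp 1.0 onto the last slice
                slice_idx = min(round(slice_idx * image.shape[axis]), image.shape[axis] - 1)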
            slide_idx_slider = widgets.IntSlider(
                min=0,
                max=image.shape[axis] - 1,
                step=1,
                value=slice_idx,
                description="Slice index",
                layout=widgets.Layout(width="450px"),
            )
            widget_obj = widgets.interactive(
                _local_thickness,
                image=widgets.fixed(image),
                image_lt=widgets.fixed(image_lt),
                show=widgets.fixed(True),
                figsize=widgets.fixed(figsize),
                axis=widgets.fixed(axis),
                slice_idx=slide_idx_slider,
            )
            widget_obj.layout = widgets.Layout(align_items="center")
            if show:
                display(widget_obj)
            return widget_obj
    else:
        if max_projection:
            log.warning(
                "max_projection is only used for 3D images. It will be ignored."
            )
        if slice_idx is not None:
            log.warning("slice_idx is only used for 3D images. It will be ignored.")
        return _local_thickness(image, image_lt, show, figsize)
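
The docstring describes `slice_idx` as either an integer index or a fraction in [0, 1] of the number of slices along `axis`. Below is a minimal sketch of that conversion. It clamps the result so that `slice_idx=1.0` still maps to the last valid slice. `resolve_slice_idx` is a hypothetical helper, and the clamping is an assumption about the intended behavior at the upper bound.

```python
def resolve_slice_idx(slice_idx, n_slices):
    """Turn None, an int index, or a float fraction in [0, 1] into a valid slice index."""
    if slice_idx is None:
        return n_slices // 2  # middle slice, as for 3D input in local_thickness
    if isinstance(slice_idx, float):
        if not 0.0 <= slice_idx <= 1.0:
            raise ValueError("Values of slice_idx of float type must be between 0 and 1.")
        # Scale by the number of slices, round, and clamp 1.0 onto the last index
        return min(round(slice_idx * n_slices), n_slices - 1)
    if isinstance(slice_idx, int):
        return slice_idx
    raise ValueError("slice_idx must be an int, a float between 0 and 1, or None.")


# The fly example's name suggests 150 slices along axis 0
print([resolve_slice_idx(s, 150) for s in (None, 0.0, 0.5, 1.0, 10)])  # [75, 0, 75, 149, 10]
```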

qim3d.viz.vectors

vectors(volume, vec, axis=0, volume_cmap='grey', vmin=None, vmax=None, slice_idx=None, grid_size=10, interactive=True, figsize=(10, 5), show=False)

Visualizes the orientation of the structures in a 3D volume using the eigenvectors of the structure tensor.

Parameters:

Name Type Description Default
volume ndarray

The 3D volume to be sliced.

required
vec ndarray

The eigenvectors of the structure tensor.

required
axis int

The axis along which to visualize the orientation. Defaults to 0.

0
volume_cmap str

Colormap used to display the volume. Defaults to 'grey'.

'grey'
vmin float

Together with vmax, defines the data range the colormap covers. By default, the colormap covers the full range. Defaults to None.

None
vmax float

Together with vmin, defines the data range the colormap covers. By default, the colormap covers the full range. Defaults to None.

None
slice_idx int or float

The initial slice to be visualized. The slice index can afterwards be changed. If value is an integer, it will be the index of the slice to be visualized. If value is a float between 0 and 1, it will be multiplied by the number of slices and rounded to the nearest integer. If None, the middle slice will be used. Defaults to None.

None
grid_size int

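Spacing, in pixels, between the orientation vectors drawn on the slice. Values outside a range derived from the volume size along axis are clamped to it. Defaults to 10.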

10
interactive bool

If True, returns an interactive widget. Defaults to True.

True
figsize Tuple[int, int]

The size of the figure. Defaults to (10, 5).

(10, 5)
show bool

If True, displays the plot (i.e. calls plt.show()). Defaults to False.

False
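
In the source code below, `grid_size` is clamped to a range that scales with `n = volume.shape[axis]`:

- the lower bound is `max(1, n // 50)`;
- the upper bound is `max(1, n // 10)`, raised to five times the lower bound when the two collide;
- a falsy value (`None` or `0`) falls back to the midpoint of that range.

The sketch below reproduces that rule on its own. `effective_grid_size` is a hypothetical name.

```python
def effective_grid_size(grid_size, n):
    """Sketch of the grid spacing qim3d.viz.vectors ends up using for n slices along the chosen axis."""
    min_grid_size = max(1, n // 50)
    max_grid_size = max(1, n // 10)
    if max_grid_size <= min_grid_size:
        max_grid_size = min_grid_size * 5
    if not grid_size:
        # None or 0: use the midpoint of the allowed range
        grid_size = (min_grid_size + max_grid_size) // 2
    # Clamp into [min_grid_size, max_grid_size]
    return min(max(min_grid_size, grid_size), max_grid_size)


# For the 128^3 NT example the allowed range is [2, 12]
print(effective_grid_size(10, 128), effective_grid_size(50, 128), effective_grid_size(None, 128))  # 10 12 7
```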

Raises:

Type Description
ValueError

If the axis to slice along is not 0, 1, or 2.

ValueError

If the slice index is not an integer or a float between 0 and 1.

Returns:

Name Type Description
fig Union[Figure, interactive]

If interactive is True, returns an interactive widget. Otherwise, returns a matplotlib figure.

Note

The orientation of the vectors is visualized using an HSV color map. The hue encodes the in-plane angle, and the squared vector component along the slicing direction (i.e. the z-component when choosing visualization along axis = 0) controls how strongly that hue is blended toward gray. Hence, if an orientation in the volume is parallel to the slicing direction (i.e. orthogonal to the slice plane), the corresponding color of the visualization will be gray.
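
The color encoding can be reproduced for individual vectors. The sketch below follows the blending used in the source code, `hue * (1 - sat) + 0.5 * sat`, where `sat` is the squared slicing-direction component, but keeps only the RGB channels. `orientation_colors` is a hypothetical helper, and it assumes matplotlib 3.5 or later for `matplotlib.colormaps`.

```python
import matplotlib
import numpy as np


def orientation_colors(vx, vy, vz):
    """RGB colors for vectors with in-plane components (vx, vy) and slicing-direction component vz."""
    # In-plane angle folded into [0, pi): v and -v describe the same orientation
    angles = np.mod(np.arctan2(vy, vx), np.pi)
    # Hue from matplotlib's HSV colormap, alpha channel dropped
    hue = matplotlib.colormaps["hsv"](angles / np.pi)[..., :3]
    # The squared out-of-plane component sets how far the hue is pulled toward mid-gray
    sat = (np.asarray(vz) ** 2)[..., np.newaxis]
    return hue * (1 - sat) + 0.5 * sat


# In-plane along x, the same axis flipped, and a vector pointing along the slicing direction
colors = orientation_colors(
    np.array([1.0, -1.0, 0.0]), np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
)
print(np.round(colors, 2))  # rows: red, red, mid-gray
```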

Example

import qim3d

vol = qim3d.examples.NT_128x128x128
val, vec = qim3d.processing.structure_tensor(vol)

# Visualize the structure tensor
qim3d.viz.vectors(vol, vec, axis = 2, interactive = True)
[Animation: structure tensor]

Source code in qim3d/viz/structure_tensor.py
def vectors(
    volume: np.ndarray,
    vec: np.ndarray,
    axis: int = 0,
    volume_cmap: str = "grey",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    slice_idx: Optional[Union[int, float]] = None,
    grid_size: int = 10,
    interactive: bool = True,
    figsize: Tuple[int, int] = (10, 5),
    show: bool = False,
) -> Union[plt.Figure, widgets.interactive]:
    """
    Visualizes the orientation of the structures in a 3D volume using the eigenvectors of the structure tensor.

    Args:
        volume (np.ndarray): The 3D volume to be sliced.
        vec (np.ndarray): The eigenvectors of the structure tensor.
        axis (int, optional): The axis along which to visualize the orientation. Defaults to 0.
        volume_cmap (str, optional): Colormap used to display the volume. Defaults to "grey".
        vmin (float, optional): Together with vmax, defines the data range the colormap covers. By default, the colormap covers the full range. Defaults to None.
        vmax (float, optional): Together with vmin, defines the data range the colormap covers. By default, the colormap covers the full range. Defaults to None.
        slice_idx (int or float, optional): The initial slice to be visualized. The slice index
            can afterwards be changed. If value is an integer, it will be the index of the slice
            to be visualized. If value is a float between 0 and 1, it will be multiplied by the
            number of slices and rounded to the nearest integer. If None, the middle slice will
            be used. Defaults to None.
        grid_size (int, optional): Spacing, in pixels, between the orientation vectors drawn on the slice.
            Values outside a range derived from the volume size along `axis` are clamped to it. Defaults to 10.
        interactive (bool, optional): If True, returns an interactive widget. Defaults to True.
        figsize (Tuple[int, int], optional): The size of the figure. Defaults to (10, 5).
        show (bool, optional): If True, displays the plot (i.e. calls plt.show()). Defaults to False.

    Raises:
        ValueError: If the axis to slice along is not 0, 1, or 2.
        ValueError: If the slice index is not an integer or a float between 0 and 1.

    Returns:
        fig (Union[plt.Figure, widgets.interactive]): If `interactive` is True, returns an interactive widget. Otherwise, returns a matplotlib figure.

    Note:
        The orientation of the vectors is visualized using an HSV color map. The hue encodes the in-plane angle, and the squared
        vector component along the slicing direction (i.e. the z-component when choosing visualization along `axis = 0`) controls how
        strongly that hue is blended toward gray. Hence, if an orientation in the volume is parallel to the slicing direction
        (i.e. orthogonal to the slice plane), the corresponding color of the visualization will be gray.

    Example:
        ```python
        import qim3d

        vol = qim3d.examples.NT_128x128x128
        val, vec = qim3d.processing.structure_tensor(vol)

        # Visualize the structure tensor
        qim3d.viz.vectors(vol, vec, axis = 2, interactive = True)
        ```
        ![structure tensor](assets/screenshots/structure_tensor_visualization.gif)

    """

    # Ensure volume is a float
    if volume.dtype != np.float32 and volume.dtype != np.float64:
        volume = volume.astype(np.float32)

    # Rescale to [0, 1] if needed (assumes an 8-bit [0, 255] range whenever the maximum exceeds 1)
    if volume.max() > 1.0:
        volume = volume / 255.0

    # Define grid size limits
    min_grid_size = max(1, volume.shape[axis] // 50)
    max_grid_size = max(1, volume.shape[axis] // 10)
    if max_grid_size <= min_grid_size:
        max_grid_size = min_grid_size * 5

    if not grid_size:
        grid_size = (min_grid_size + max_grid_size) // 2

    # Check that grid_size lies within the allowed range
    if grid_size < min_grid_size or grid_size > max_grid_size:
        # Adjust grid size as little as possible to be within the limits
        grid_size = min(max(min_grid_size, grid_size), max_grid_size)
        log.warning(f"Adjusting grid size to {grid_size} as it is out of bounds.")

    def _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show):

        # Choose the appropriate slice based on the specified dimension
        if axis == 0:
            data_slice = volume[slice_idx, :, :]
            vectors_slice_x = vec[0, slice_idx, :, :]
            vectors_slice_y = vec[1, slice_idx, :, :]
            vectors_slice_z = vec[2, slice_idx, :, :]

        elif axis == 1:
            data_slice = volume[:, slice_idx, :]
            vectors_slice_x = vec[0, :, slice_idx, :]
            vectors_slice_y = vec[2, :, slice_idx, :]
            vectors_slice_z = vec[1, :, slice_idx, :]

        elif axis == 2:
            data_slice = volume[:, :, slice_idx]
            vectors_slice_x = vec[1, :, :, slice_idx]
            vectors_slice_y = vec[2, :, :, slice_idx]
            vectors_slice_z = vec[0, :, :, slice_idx]

        else:
            raise ValueError("Invalid dimension. Use 0 for Z, 1 for Y, or 2 for X.")

        # Create three subplots
        fig, ax = plt.subplots(1, 3, figsize=figsize, layout="constrained")

        blend_hue_saturation = (
            lambda hue, sat: hue * (1 - sat) + 0.5 * sat
        )  # Function for blending hue and saturation
        blend_slice_colors = lambda slice, colors: 0.5 * (
            slice + colors
        )  # Function for blending image slice with orientation colors

        # ----- Subplot 1: Image slice with orientation vectors ----- #
        # Create meshgrid with the correct dimensions
        xmesh, ymesh = np.mgrid[0 : data_slice.shape[0], 0 : data_slice.shape[1]]

        # Create a slice object for selecting the grid points
        g = slice(grid_size // 2, None, grid_size)

        # Angles from 0 to pi
        angles_quiver = np.mod(
            np.arctan2(vectors_slice_y[g, g], vectors_slice_x[g, g]), np.pi
        )

        # Calculate z-component (saturation)
        saturation_quiver = (vectors_slice_z[g, g] ** 2)[:, :, np.newaxis]

        # Calculate hue
        hue_quiver = plt.cm.hsv(angles_quiver / np.pi)

        # Blend hue and saturation
        rgba_quiver = blend_hue_saturation(hue_quiver, saturation_quiver)
        rgba_quiver = np.clip(rgba_quiver, 0, 1)  # Clip RGBA values to [0, 1]
        rgba_quiver_flat = rgba_quiver.reshape(
            (rgba_quiver.shape[0] * rgba_quiver.shape[1], 4)
        )  # Flatten array for quiver plot

        # Plot vectors
        ax[0].quiver(
            ymesh[g, g],
            xmesh[g, g],
            vectors_slice_x[g, g],
            vectors_slice_y[g, g],
            color=rgba_quiver_flat,
            angles="xy",
        )
        ax[0].quiver(
            ymesh[g, g],
            xmesh[g, g],
            -vectors_slice_x[g, g],
            -vectors_slice_y[g, g],
            color=rgba_quiver_flat,
            angles="xy",
        )

        ax[0].imshow(data_slice, cmap = volume_cmap, vmin = vmin, vmax = vmax)
        ax[0].set_title(
            f"Orientation vectors (slice {slice_idx})"
            if not interactive
            else "Orientation vectors"
        )
        ax[0].set_axis_off()

        # ----- Subplot 2: Orientation histogram ----- #
        nbins = 36

        # Angles from 0 to pi
        angles = np.mod(np.arctan2(vectors_slice_y, vectors_slice_x), np.pi)

        # Orientation histogram over angles
        distribution, bin_edges = np.histogram(angles, bins=nbins, range=(0.0, np.pi))

        # Half circle (180 deg)
        bin_centers = (np.arange(nbins) + 0.5) * np.pi / nbins

        # Calculate z-component (saturation) for each bin
        bins = np.digitize(angles.ravel(), bin_edges)
        saturation_bin = np.array(
            [
                (
                    np.mean((vectors_slice_z**2).ravel()[bins == i])
                    if np.sum(bins == i) > 0
                    else 0
                )
                for i in range(1, len(bin_edges))
            ]
        )

        # Calculate hue for each bin
        hue_bin = plt.cm.hsv(bin_centers / np.pi)

        # Blend hue and saturation
        rgba_bin = hue_bin.copy()
        rgba_bin[:, :3] = blend_hue_saturation(
            hue_bin[:, :3], saturation_bin[:, np.newaxis]
        )

        ax[1].bar(bin_centers, distribution, width=np.pi / nbins, color=rgba_bin)
        ax[1].set_xlabel("Angle [radians]")
        ax[1].set_xlim([0, np.pi])
        ax[1].set_aspect(np.pi / ax[1].get_ylim()[1])
        ax[1].set_xticks([0, np.pi / 2, np.pi])
        ax[1].set_xticklabels(["0", "$\\frac{\\pi}{2}$", "$\\pi$"])
        ax[1].set_yticks([])
        ax[1].set_ylabel("Frequency")
        ax[1].set_title(f"Histogram over orientation angles")

        # ----- Subplot 3: Image slice colored according to orientation ----- #
        # Calculate z-component (saturation)
        saturation = (vectors_slice_z**2)[:, :, np.newaxis]

        # Calculate hue
        hue = plt.cm.hsv(angles / np.pi)

        # Blend hue and saturation
        rgba = blend_hue_saturation(hue, saturation)

        # Grayscale image slice blended with orientation colors
        data_slice_orientation_colored = (
            blend_slice_colors(plt.cm.gray(data_slice), rgba) * 255
        ).astype("uint8")

        ax[2].imshow(data_slice_orientation_colored)
        ax[2].set_title(
            f"Colored orientations (slice {slice_idx})"
            if not interactive
            else "Colored orientations"
        )
        ax[2].set_axis_off()

        if show:
            plt.show()

        plt.close()

        return fig

    if vec.ndim == 5:
        vec = vec[0, ...]
        log.warning(
            "Eigenvector array is full. Only the eigenvectors corresponding to the first eigenvalue will be used."
        )

    if slice_idx is None:
        slice_idx = volume.shape[axis] // 2

    elif isinstance(slice_idx, float):
        if slice_idx < 0 or slice_idx > 1:
            raise ValueError(
                "Values of slice_idx of float type must be between 0 and 1."
            )
        slice_idx = min(int(slice_idx * volume.shape[axis]), volume.shape[axis] - 1)

    if interactive:
        slide_idx_slider = widgets.IntSlider(
            min=0,
            max=volume.shape[axis] - 1,
            step=1,
            value=slice_idx,
            description="Slice index",
            layout=widgets.Layout(width="450px"),
        )

        grid_size_slider = widgets.IntSlider(
            min=min_grid_size,
            max=max_grid_size,
            step=1,
            value=grid_size,
            description="Grid size",
            layout=widgets.Layout(width="450px"),
        )

        widget_obj = widgets.interactive(
            _structure_tensor,
            volume=widgets.fixed(volume),
            vec=widgets.fixed(vec),
            axis=widgets.fixed(axis),
            slice_idx=slide_idx_slider,
            grid_size=grid_size_slider,
            figsize=widgets.fixed(figsize),
            show=widgets.fixed(True),
        )
        # Arrange sliders horizontally
        sliders_box = widgets.HBox([slide_idx_slider, grid_size_slider])
        widget_obj = widgets.VBox([sliders_box, widget_obj.children[-1]])
        widget_obj.layout.align_items = "center"

        if show:
            display(widget_obj)

        return widget_obj

    else:
        return _structure_tensor(volume, vec, axis, slice_idx, grid_size, figsize, show)
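The coloring in the listing above folds each in-plane eigenvector into an angle in [0, π), uses that angle as hue, uses the squared z-component as a saturation weight, and summarizes the angles in a 36-bin histogram. The sketch below reproduces only those steps with NumPy and matplotlib on a few synthetic vectors. The helper name `orientation_hue_and_histogram` is hypothetical, and `blend_hue_saturation` is not shown in this section, so the sketch stops at the raw hue and saturation arrays.

```python
import numpy as np
import matplotlib.pyplot as plt


def orientation_hue_and_histogram(vx, vy, vz, nbins=36):
    """Hypothetical helper mirroring the angle, hue and histogram steps above."""
    # Fold the in-plane angle into [0, pi): v and -v describe the same orientation,
    # which is also why the listing draws each quiver arrow in both directions
    angles = np.mod(np.arctan2(vy, vx), np.pi)

    # Hue from the normalized angle; saturation weight from the squared z-component
    hue = plt.cm.hsv(angles / np.pi)  # RGBA array with a trailing axis of length 4
    saturation = (vz**2)[..., np.newaxis]

    # Histogram over half a circle, with bin centers at (k + 0.5) * pi / nbins
    counts, bin_edges = np.histogram(angles, bins=nbins, range=(0.0, np.pi))
    bin_centers = (np.arange(nbins) + 0.5) * np.pi / nbins

    # Mean squared z-component per bin (0 for empty bins); digitize yields 1..nbins here
    bins = np.digitize(angles.ravel(), bin_edges)
    zsq = (vz**2).ravel()
    saturation_bin = np.array(
        [zsq[bins == i].mean() if np.any(bins == i) else 0.0 for i in range(1, nbins + 1)]
    )
    return angles, hue, saturation, counts, bin_centers, saturation_bin


# Four in-plane unit vectors in a 2x2 slice: (1, 0), (0, 1), (-1, 0), (0, -1)
vx = np.array([[1.0, 0.0], [-1.0, 0.0]])
vy = np.array([[0.0, 1.0], [0.0, -1.0]])
vz = np.zeros((2, 2))
angles, hue, saturation, counts, bin_centers, saturation_bin = orientation_hue_and_histogram(vx, vy, vz)
# Opposite vectors share an angle, so only two histogram bins are occupied
```

Folding with `np.mod(..., np.pi)` rather than keeping the full [-π, π] range is what makes a vector and its negation receive the same color and the same histogram bin.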

qim3d.viz.plot_cc

plot_cc(connected_components, component_indexs=None, max_cc_to_plot=32, overlay=None, crop=False, show=True, cmap='viridis', vmin=None, vmax=None, **kwargs)

Plots the connected components from a qim3d.processing.cc.CC object, one figure per component. If an overlay image is provided, it is masked so that only the voxels belonging to the plotted component are shown.

Parameters:

Name Type Description Default
connected_components CC

The connected components object.

required
component_indexs list | tuple

The (1-based) labels of the components to plot. If None, the first max_cc_to_plot components (32 by default) are plotted. Defaults to None.

None
max_cc_to_plot int

The maximum number of connected components to plot. Defaults to 32.

32
overlay ndarray

Overlay image with the same shape as the connected components volume. Defaults to None.

None
crop bool

Whether to crop each plot to the bounding box of the connected component. Defaults to False.

False
show bool

Whether to show the figure. Defaults to True.

True
cmap str

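Specifies the color map for the image. Only applied when overlay is given; without an overlay, each component is drawn with a discrete qim3d.viz.colormaps.objects colormap. Defaults to "viridis".

'viridis'
vmin float

Together with vmax, defines the data range the colormap covers. By default, the colormap covers the full range. Only applied when overlay is given. Defaults to None.

None
vmax float

Together with vmin, defines the data range the colormap covers. By default, the colormap covers the full range. Only applied when overlay is given. Defaults to None.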

None
**kwargs

Additional keyword arguments to pass to qim3d.viz.slices.

{}

Returns:

Name Type Description
figs list[Figure]

List of figures if show=False; otherwise None.

Example

import qim3d
vol = qim3d.examples.cement_128x128x128[50:150]
vol_bin = vol<80
cc = qim3d.processing.get_3d_cc(vol_bin)
qim3d.viz.plot_cc(cc, crop=True, show=True, overlay=None, n_slices=5, component_indexs=[4,6,7])
qim3d.viz.plot_cc(cc, crop=True, show=True, overlay=vol, n_slices=5, component_indexs=[4,6,7])
[Figures: plot_cc output without an overlay (plot_cc_no_overlay) and with an overlay (plot_cc_overlay)]

Source code in qim3d/viz/cc.py
def plot_cc(
    connected_components,
    component_indexs: list | tuple = None,
    max_cc_to_plot=32,
    overlay=None,
    crop=False,
    show=True,
    cmap:str = 'viridis',
    vmin:float = None,
    vmax:float = None,
    **kwargs,
) -> list[plt.Figure]:
    """
    Plots the connected components from a `qim3d.processing.cc.CC` object, one figure per component. If an overlay image is provided, it is masked so that only the voxels belonging to the plotted component are shown.

    Parameters:
        connected_components (CC): The connected components object.
        component_indexs (list | tuple, optional): The (1-based) labels of the components to plot. If None, the first `max_cc_to_plot` components (32 by default) are plotted. Defaults to None.
        max_cc_to_plot (int, optional): The maximum number of connected components to plot. Defaults to 32.
        overlay (np.ndarray, optional): Overlay image with the same shape as the connected components volume. Defaults to None.
        crop (bool, optional): Whether to crop each plot to the bounding box of the connected component. Defaults to False.
        show (bool, optional): Whether to show the figure. Defaults to True.
        cmap (str, optional): Specifies the color map for the image. Only applied when `overlay` is given. Defaults to "viridis".
        vmin (float, optional): Together with vmax, defines the data range the colormap covers. By default, the colormap covers the full range. Only applied when `overlay` is given. Defaults to None.
        vmax (float, optional): Together with vmin, defines the data range the colormap covers. By default, the colormap covers the full range. Only applied when `overlay` is given. Defaults to None.
        **kwargs: Additional keyword arguments to pass to `qim3d.viz.slices`.

    Returns:
        figs (list[plt.Figure]): List of figures if `show=False`; otherwise None.

    Example:
        ```python
        import qim3d
        vol = qim3d.examples.cement_128x128x128[50:150]
        vol_bin = vol<80
        cc = qim3d.processing.get_3d_cc(vol_bin)
        qim3d.viz.plot_cc(cc, crop=True, show=True, overlay=None, n_slices=5, component_indexs=[4,6,7])
        qim3d.viz.plot_cc(cc, crop=True, show=True, overlay=vol, n_slices=5, component_indexs=[4,6,7])
        ```
        ![plot_cc_no_overlay](assets/screenshots/plot_cc_no_overlay.png)
        ![plot_cc_overlay](assets/screenshots/plot_cc_overlay.png)
    """
    # if no components are given, plot the first max_cc_to_plot=32 components
    if component_indexs is None:
        if len(connected_components) > max_cc_to_plot:
            log.warning(
                f"More than {max_cc_to_plot} connected components found. Only the first {max_cc_to_plot} will be plotted. Change max_cc_to_plot to plot more components."
            )
        component_indexs = range(
            1, min(max_cc_to_plot + 1, len(connected_components) + 1)
        )

    figs = []
    for component in component_indexs:
        if overlay is not None:
            assert (
                overlay.shape == connected_components.shape
            ), f"Overlay image must have the same shape as the connected components. overlay.shape={overlay.shape} != connected_components.shape={connected_components.shape}."

            # plots overlay masked to connected component
            if crop:
                # Crop the overlay image based on the bounding box of the component
                bb = connected_components.get_bounding_box(component)[0]
                cc = connected_components.get_cc(component, crop=True)
                overlay_crop = overlay[bb]
                # use cc as mask for overlay_crop, where all values in cc set to 0 should be masked out, cc contains integers
                overlay_crop = np.where(cc == 0, 0, overlay_crop)
            else:
                cc = connected_components.get_cc(component, crop=False)
                overlay_crop = np.where(cc == 0, 0, overlay)
            fig = qim3d.viz.slices(overlay_crop, show=show, cmap = cmap, vmin = vmin, vmax = vmax, **kwargs)
        else:
            # assigns discrete color map to each connected component if not given
            if "cmap" not in kwargs:
                kwargs["cmap"] = qim3d.viz.colormaps.objects(len(component_indexs))

            # Plot the connected component without overlay
            fig = qim3d.viz.slices(
                connected_components.get_cc(component, crop=crop), show=show, **kwargs
            )

        figs.append(fig)

    if not show:
        return figs

    return
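The selection and masking logic of `plot_cc` can be illustrated without the `CC` object. The sketch below is an assumption-laden stand-in: it uses a plain integer label array (0 = background) in place of `qim3d.processing.cc.CC`, and the hypothetical helpers `default_component_indices`, `bounding_box`, and `mask_overlay` approximate what `get_bounding_box` and `get_cc` provide in the listing above.

```python
import numpy as np


def default_component_indices(n_components, max_cc_to_plot=32):
    # Labels are 1-based (0 is background); at most max_cc_to_plot are selected
    return range(1, min(max_cc_to_plot + 1, n_components + 1))


def bounding_box(labels, component):
    # Tuple of slices covering every voxel that carries the given label
    coords = np.nonzero(labels == component)
    return tuple(slice(c.min(), c.max() + 1) for c in coords)


def mask_overlay(labels, overlay, component, crop=False):
    """Hypothetical helper: keep overlay values only inside one component."""
    if overlay.shape != labels.shape:
        raise ValueError("Overlay must have the same shape as the label volume.")
    cc = labels == component
    if crop:
        # Crop both the mask and the overlay to the component's bounding box
        bb = bounding_box(labels, component)
        return np.where(cc[bb], overlay[bb], 0)
    return np.where(cc, overlay, 0)


# Two non-overlapping box-shaped components in a small volume
labels = np.zeros((4, 6, 6), dtype=int)
labels[1:3, 1:3, 1:3] = 1
labels[0:2, 4:6, 3:6] = 2
overlay = np.arange(labels.size, dtype=float).reshape(labels.shape)

print(list(default_component_indices(2)))  # [1, 2]
masked = mask_overlay(labels, overlay, component=2, crop=True)
print(masked.shape)  # (2, 2, 3)
```

Each masked (and optionally cropped) array is what `plot_cc` then hands to `qim3d.viz.slices`, one call per component.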

qim3d.viz.interactive_fade_mask

interactive_fade_mask(vol, axis=0, cmap='viridis', vmin=None, vmax=None)

Interactive widget for visualizing the effect of edge fading on a 3D volume.

Use it to choose the fade_mask parameters (decay rate, ratio, geometry and inversion) before applying the mask to the full volume.

Parameters:

Name Type Description Default
vol ndarray

The volume to apply edge fading to.

required
axis int

The axis along which to apply the fading. Defaults to 0.

0
cmap str

Specifies the color map for the image. Defaults to "viridis".

'viridis'
vmin float

Together with vmax, defines the data range the colormap covers. By default, the colormap covers the full range. Defaults to None.

None
vmax float

Together with vmin, defines the data range the colormap covers. By default, the colormap covers the full range. Defaults to None.

None
Example

import qim3d
vol = qim3d.examples.cement_128x128x128
qim3d.viz.interactive_fade_mask(vol)
[Animation: the interactive fade-mask widget]

Source code in qim3d/viz/explore.py
def interactive_fade_mask(
    vol: np.ndarray,
    axis: int = 0,
    cmap: str = "viridis",
    vmin: float = None,
    vmax: float = None,
):
    """Interactive widget for visualizing the effect of edge fading on a 3D volume.

    This can be used to select the best parameters before applying the mask.

    Args:
        vol (np.ndarray): The volume to apply edge fading to.
        axis (int, optional): The axis along which to apply the fading. Defaults to 0.
        cmap (str, optional): Specifies the color map for the image. Defaults to "viridis".
        vmin (float, optional): Together with vmax, defines the data range the colormap covers. By default, the colormap covers the full range. Defaults to None.
        vmax (float, optional): Together with vmin, defines the data range the colormap covers. By default, the colormap covers the full range. Defaults to None.

    Example:
        ```python
        import qim3d
        vol = qim3d.examples.cement_128x128x128
        qim3d.viz.interactive_fade_mask(vol)
        ```
        ![operations-edge_fade_before](assets/screenshots/viz-fade_mask.gif)

    """

    # Create the interactive widget
    def _slicer(position, decay_rate, ratio, geometry, invert):
        fig, axes = plt.subplots(1, 3, figsize=(9, 3))

        slice_img = vol[position, :, :]
        # If vmin is above the slice maximum (or vmax below its minimum), a ValueError is raised.
        # Drop the limit for this slice only, since later slices may still be within range.
        new_vmin = (
            None
            if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
            else vmin
        )
        new_vmax = (
            None
            if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
            else vmax
        )

        axes[0].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
        axes[0].set_title("Original")
        axes[0].axis("off")

        mask = qim3d.processing.operations.fade_mask(
            np.ones_like(vol),
            decay_rate=decay_rate,
            ratio=ratio,
            geometry=geometry,
            axis=axis,
            invert=invert,
        )
        axes[1].imshow(mask[position, :, :], cmap=cmap)
        axes[1].set_title("Mask")
        axes[1].axis("off")

        masked_vol = qim3d.processing.operations.fade_mask(
            vol,
            decay_rate=decay_rate,
            ratio=ratio,
            geometry=geometry,
            axis=axis,
            invert=invert,
        )
        # If vmin is above the slice maximum (or vmax below its minimum), a ValueError is raised.
        # Drop the limit for this slice only, since later slices may still be within range.
        slice_img = masked_vol[position, :, :]
        new_vmin = (
            None
            if (isinstance(vmin, (float, int)) and vmin > np.max(slice_img))
            else vmin
        )
        new_vmax = (
            None
            if (isinstance(vmax, (float, int)) and vmax < np.min(slice_img))
            else vmax
        )
        axes[2].imshow(slice_img, cmap=cmap, vmin=new_vmin, vmax=new_vmax)
        axes[2].set_title("Masked")
        axes[2].axis("off")

        return fig

    shape_dropdown = widgets.Dropdown(
        options=["spherical", "cylindrical"],
        value="spherical",  # default value
        description="Geometry",
    )

    position_slider = widgets.IntSlider(
        value=vol.shape[0] // 2,
        min=0,
        max=vol.shape[0] - 1,
        description="Slice",
        continuous_update=False,
    )
    decay_rate_slider = widgets.FloatSlider(
        value=10,
        min=1,
        max=50,
        step=1.0,
        description="Decay Rate",
        continuous_update=False,
    )
    ratio_slider = widgets.FloatSlider(
        value=0.5,
        min=0.1,
        max=1,
        step=0.01,
        description="Ratio",
        continuous_update=False,
    )

    # Create the Checkbox widget
    invert_checkbox = widgets.Checkbox(
        value=False, description="Invert"  # default value
    )

    slicer_obj = widgets.interactive(
        _slicer,
        position=position_slider,
        decay_rate=decay_rate_slider,
        ratio=ratio_slider,
        geometry=shape_dropdown,
        invert=invert_checkbox,
    )
    slicer_obj.layout = widgets.Layout(align_items="flex-start")

    return slicer_obj
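The per-slice guard on `vmin`/`vmax` in `_slicer` is easy to miss inside the plotting code. The following sketch isolates it as a standalone function; the name `safe_color_limits` is hypothetical, and it only reproduces the two conditional expressions from the listing above.

```python
import numpy as np


def safe_color_limits(slice_img, vmin=None, vmax=None):
    """Hypothetical helper: drop a color limit that lies outside this slice's data.

    A vmin above the slice maximum, or a vmax below the slice minimum, is
    replaced by None for this slice only; the caller's values are not modified.
    """
    new_vmin = None if isinstance(vmin, (float, int)) and vmin > np.max(slice_img) else vmin
    new_vmax = None if isinstance(vmax, (float, int)) and vmax < np.min(slice_img) else vmax
    return new_vmin, new_vmax


dark = np.zeros((4, 4))
bright = np.full((4, 4), 200.0)
ramp = np.arange(256.0).reshape(16, 16)

print(safe_color_limits(dark, vmin=100, vmax=250))  # (None, 250)
print(safe_color_limits(bright, vmin=100, vmax=150))  # (100, None)
print(safe_color_limits(ramp, vmin=100, vmax=150))  # (100, 150)
```

Because the check is repeated for every slice, moving the slider back to a slice whose values fall within range restores the user's limits.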

qim3d.viz.colormaps

This module provides a collection of colormaps useful for 3D visualization.

qim3d.viz.colormaps.qim module-attribute

qim = from_list('qim', [(0.6, 0.0, 0.0), (1.0, 0.6, 0.0)])

Defines a colormap in the QIM logo colors. It can be accessed as a module attribute or, more simply, by passing cmap='qim'.

Example

import qim3d

display(qim3d.viz.colormaps.qim)
[Figure: preview of the qim colormap]
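The attribute definition above is a two-color linear gradient. Assuming `from_list` refers to matplotlib's `LinearSegmentedColormap.from_list` (which the `objects` source below also uses), the same gradient can be rebuilt and inspected with matplotlib alone. This is a sketch for inspection, not the module's own code.

```python
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Two-color gradient from dark red to orange, matching the attribute definition above
qim = LinearSegmentedColormap.from_list("qim", [(0.6, 0.0, 0.0), (1.0, 0.6, 0.0)])

# Sample the endpoints and the midpoint; each call returns an RGBA tuple
start = qim(0.0)
middle = qim(0.5)
end = qim(1.0)
```

Calling the colormap with floats in [0, 1] interpolates along the gradient, so the red channel rises monotonically from 0.6 to 1.0.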

qim3d.viz.colormaps.objects

objects(nlabels, style='bright', first_color_background=True, last_color_background=False, background_color=(0.0, 0.0, 0.0), min_dist=0.5, seed=19)

Creates a random colormap for use with matplotlib. Useful for segmentation tasks, such as displaying a labeled volume.

Parameters:

Name Type Description Default
nlabels int

Number of labels. The colormap contains nlabels + 1 colors, since one extra color is added for the background.

required
style str

'bright' for strong colors, 'soft' for pastel colors, 'earth' for yellow/green/blue colors, 'ocean' for blue/purple/pink colors. Defaults to 'bright'.

'bright'
first_color_background bool

If True, the first color is used as background. Defaults to True.

True
last_color_background bool

If True, the last color is used as background. Defaults to False.

False
background_color tuple or str

RGB tuple or string for background color. Can be "black" or "white". Defaults to (0.0, 0.0, 0.0).

(0.0, 0.0, 0.0)
min_dist float

Minimum distance between neighboring colors. Defaults to 0.5.

0.5
seed int

Seed for random number generator. Defaults to 19.

19

Returns:

Name Type Description
cmap LinearSegmentedColormap

Colormap for use with matplotlib.

Example

import qim3d

cmap_bright = qim3d.viz.colormaps.objects(nlabels=100, style = 'bright', first_color_background=True, background_color="black", min_dist=0.7)
cmap_soft = qim3d.viz.colormaps.objects(nlabels=100, style = 'soft', first_color_background=True, background_color="black", min_dist=0.2)
cmap_earth = qim3d.viz.colormaps.objects(nlabels=100, style = 'earth', first_color_background=True, background_color="black", min_dist=0.8)
cmap_ocean = qim3d.viz.colormaps.objects(nlabels=100, style = 'ocean', first_color_background=True, background_color="black", min_dist=0.9)

display(cmap_bright)
display(cmap_soft)
display(cmap_earth)
display(cmap_ocean)
[Figure: the bright, soft, earth and ocean objects colormaps]

import qim3d

vol = qim3d.examples.cement_128x128x128
binary = qim3d.processing.filters.gaussian(vol, sigma = 2) < 60
labeled_volume, num_labels = qim3d.processing.operations.watershed(binary)

cmap = qim3d.viz.colormaps.objects(num_labels, style = 'bright')
qim3d.viz.slicer(labeled_volume, axis = 1, cmap=cmap)
[Animation: watershed labels shown in qim3d.viz.slicer with an objects colormap]

Tip

The min_dist parameter controls the minimum distance between neighboring colors in the colormap.

[Animation: effect of min_dist on the objects colormap]

Source code in qim3d/viz/colormaps.py
def objects(
    nlabels: int,
    style: str = "bright",
    first_color_background: bool = True,
    last_color_background: bool = False,
    background_color: Union[Tuple[float, float, float], str] = (0.0, 0.0, 0.0),
    min_dist: float = 0.5,
    seed: int = 19,
) -> LinearSegmentedColormap:
    """
    Creates a random colormap for use with matplotlib. Useful for segmentation tasks, such as displaying a labeled volume.

    Args:
        nlabels (int): Number of labels. The colormap contains `nlabels + 1` colors, since one extra color is added for the background.
        style (str, optional): 'bright' for strong colors, 'soft' for pastel colors, 'earth' for yellow/green/blue colors, 'ocean' for blue/purple/pink colors. Defaults to 'bright'.
        first_color_background (bool, optional): If True, the first color is used as background. Defaults to True.
        last_color_background (bool, optional): If True, the last color is used as background. Defaults to False.
        background_color (tuple or str, optional): RGB tuple or string for background color. Can be "black" or "white". Defaults to (0.0, 0.0, 0.0).
        min_dist (float, optional): Minimum distance between neighboring colors. Defaults to 0.5.
        seed (int, optional): Seed for random number generator. Defaults to 19.

    Returns:
        cmap (matplotlib.colors.LinearSegmentedColormap): Colormap for matplotlib


    Example:
        ```python
        import qim3d

        cmap_bright = qim3d.viz.colormaps.objects(nlabels=100, style = 'bright', first_color_background=True, background_color="black", min_dist=0.7)
        cmap_soft = qim3d.viz.colormaps.objects(nlabels=100, style = 'soft', first_color_background=True, background_color="black", min_dist=0.2)
        cmap_earth = qim3d.viz.colormaps.objects(nlabels=100, style = 'earth', first_color_background=True, background_color="black", min_dist=0.8)
        cmap_ocean = qim3d.viz.colormaps.objects(nlabels=100, style = 'ocean', first_color_background=True, background_color="black", min_dist=0.9)

        display(cmap_bright)
        display(cmap_soft)
        display(cmap_earth)
        display(cmap_ocean)
        ```
        ![colormap objects](assets/screenshots/viz-colormaps-objects-all.png)

        ```python
        import qim3d

        vol = qim3d.examples.cement_128x128x128
        binary = qim3d.processing.filters.gaussian(vol, sigma = 2) < 60
        labeled_volume, num_labels = qim3d.processing.operations.watershed(binary)

        cmap = qim3d.viz.colormaps.objects(num_labels, style = 'bright')
        qim3d.viz.slicer(labeled_volume, axis = 1, cmap=cmap)
        ```
        ![colormap objects](assets/screenshots/viz-colormaps-objects.gif)

    Tip:
        The `min_dist` parameter can be used to control the distance between neighboring colors.
        ![colormap objects min_dist](assets/screenshots/viz-colormaps-min_dist.gif)


    """
    from skimage import color

    # Check style
    if style not in ("bright", "soft", "earth", "ocean"):
        raise ValueError(
            f'Please choose "bright", "soft", "earth" or "ocean" for style, not "{style}"'
        )

    # Translate strings to background color
    color_dict = {"black": (0.0, 0.0, 0.0), "white": (1.0, 1.0, 1.0)}
    if not isinstance(background_color, tuple):
        try:
            background_color = color_dict[background_color]
        except KeyError:
            raise ValueError(
                f'Invalid color name "{background_color}". Please choose from {list(color_dict.keys())}.'
            )

    # Add one to nlabels to include the background color
    nlabels += 1

    # Create a new random generator, to locally set seed
    rng = np.random.default_rng(seed)

    # Generate color map for bright colors, based on hsv
    if style == "bright":
        randHSVcolors = [
            (
                rng.uniform(low=0.0, high=1),
                rng.uniform(low=0.4, high=1),
                rng.uniform(low=0.9, high=1),
            )
            for i in range(nlabels)
        ]

        # Convert HSV list to RGB
        randRGBcolors = []
        for HSVcolor in randHSVcolors:
            randRGBcolors.append(
                colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])
            )

    # Generate soft pastel colors, by limiting the RGB spectrum
    if style == "soft":
        low = 0.6
        high = 0.95
        randRGBcolors = [
            (
                rng.uniform(low=low, high=high),
                rng.uniform(low=low, high=high),
                rng.uniform(low=low, high=high),
            )
            for i in range(nlabels)
        ]

    # Generate color map for earthy colors, based on LAB
    if style == "earth":
        randLABColors = [
            (
                rng.uniform(low=25, high=110),
                rng.uniform(low=-120, high=70),
                rng.uniform(low=-70, high=70),
            )
            for i in range(nlabels)
        ]

        # Convert LAB list to RGB
        randRGBcolors = []
        for LabColor in randLABColors:
            randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())

    # Generate color map for ocean colors, based on LAB
    if style == "ocean":
        randLABColors = [
            (
                rng.uniform(low=0, high=110),
                rng.uniform(low=-128, high=160),
                rng.uniform(low=-128, high=0),
            )
            for i in range(nlabels)
        ]

        # Convert LAB list to RGB
        randRGBcolors = []
        for LabColor in randLABColors:
            randRGBcolors.append(color.lab2rgb([[LabColor]])[0][0].tolist())

    # Re-arrange colors to have a minimum distance between neighboring colors
    randRGBcolors = rearrange_colors(randRGBcolors, min_dist)

    # Set first and last color to background
    if first_color_background:
        randRGBcolors[0] = background_color

    if last_color_background:
        randRGBcolors[-1] = background_color

    # Create colormap
    objects = LinearSegmentedColormap.from_list("objects", randRGBcolors, N=nlabels)

    return objects
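The core of the `'bright'` branch above is: draw `nlabels + 1` random HSV colors in constrained ranges, convert them to RGB, overwrite the first entry with the background color, and build a colormap with exactly one entry per label. The sketch below reproduces only that path. It is a simplification under stated assumptions: the `rearrange_colors`/`min_dist` step, `last_color_background`, and the other styles are omitted, and the function name `bright_label_cmap` is hypothetical.

```python
import colorsys

import numpy as np
from matplotlib.colors import LinearSegmentedColormap


def bright_label_cmap(nlabels, background=(0.0, 0.0, 0.0), seed=19):
    """Hypothetical, simplified 'bright' style: random HSV colors, label 0 as background.

    The rearrange_colors (min_dist) step of the original is intentionally omitted.
    """
    n = nlabels + 1  # one extra entry for the background label 0
    rng = np.random.default_rng(seed)  # local generator, so the global seed is untouched
    colors = [
        colorsys.hsv_to_rgb(
            rng.uniform(0.0, 1.0),  # any hue
            rng.uniform(0.4, 1.0),  # moderate to full saturation
            rng.uniform(0.9, 1.0),  # high value, hence "bright"
        )
        for _ in range(n)
    ]
    colors[0] = background
    # N == number of colors, so integer label i maps to colors[i]
    return LinearSegmentedColormap.from_list("objects", colors, N=n)


cmap = bright_label_cmap(5)
print(cmap.N)  # 6
```

Setting `N` equal to the number of colors is the design choice that makes the colormap usable for label volumes: each integer label indexes its own color rather than an interpolated blend.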