# Source code for timagetk.tasks.fusion

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
#  Copyright (c) 2022 Univ. Lyon, ENS de Lyon, UCB Lyon 1, CNRS, INRAe, Inria
#  All rights reserved.
#  This file is part of the TimageTK library, and is released under the "GPLv3"
#  license. Please see the LICENSE.md file that should have been included as
#  part of this package.
# ------------------------------------------------------------------------------

import time

import numpy as np

from timagetk.algorithms.averaging import images_averaging
from timagetk.algorithms.blockmatching import blockmatching
from timagetk.algorithms.resample import resample
from timagetk.algorithms.template import iso_template
from timagetk.algorithms.trsf import apply_trsf
from timagetk.algorithms.trsf import compose_trsf
from timagetk.algorithms.trsf import create_trsf
from timagetk.algorithms.trsf import inv_trsf
from timagetk.algorithms.trsf import trsfs_averaging
from timagetk.bin.logger import get_logger
from timagetk.components.image import get_image_attributes
from timagetk.components.spatial_image import SpatialImage
from timagetk.util import elapsed_time

# Module-level logger, named after this module (`__name__`).
log = get_logger(__name__)


def fusion_on_first(images, method="mean", init_trsfs=None, **kwargs):
    """Fuse a list of image by registering them on the first of the list.

    Parameters
    ----------
    images : list of SpatialImage
        List of image to fuse together, the first one is used as reference.
    method : str, optional
        Image fusion method to use, 'mean' by default, see ``AVERAGING_METHODS``.
    init_trsfs : list of Trsf or None, optional
        List of transformations to use to initialize registration, must be of type *rigid*.
        There must be one transformation per floating image (``len(images) - 1``).
        If ``None`` (default), they are estimated by rigid block-matching.

    Other Parameters
    ----------------
    vectorfield : bool
        If ``True`` (default is ``False``), also compute a non-linear deformation to match the images.
    fuse_reference : bool
        If ``True`` (default), add the reference image to the list of images to fuse.
    final_vxs : float
        If specified, the template image will be the isometric resampling of the
        reference image to this voxel-size float value.
    super_resolution : bool
        If ``True`` (default), compute a reference frame with minimum voxel-size.
    global_averaging : bool
        If ``True`` (default is ``False``), perform global averaging of each voxel by the total number of images.
        Else performs local averaging based on the number of defined images per voxel after registrations.
    interpolation : {'linear', 'cspline'}
        The type of interpolation to perform when resampling and applying transformation to intensity images.
        Defaults to 'linear'.

    Returns
    -------
    timagetk.SpatialImage
        The averaged images after registration.
    list of Trsf
        The list of (composed) transformations that were averaged to center the
        fused image, one per fused image (starting with an identity when the
        reference image is fused).

    See Also
    --------
    timagetk.third_party.vt_parser.AVERAGING_METHODS

    Notes
    -----
    Unless `global_averaging` is ``True``, masks are generated prior to averaging,
    so that each voxel is averaged by the number of locally defined images and
    not the total number of images.

    Examples
    --------
    >>> import numpy as np
    >>> from timagetk.tasks.fusion import fusion_on_first
    >>> from timagetk.algorithms.pointmatching import pointmatching
    >>> from timagetk.algorithms.resample import resample
    >>> from timagetk.algorithms.trsf import inv_trsf
    >>> from timagetk.io import imread
    >>> from timagetk.io.util import shared_data
    >>> from timagetk.visu.stack import orthogonal_view
    >>> # Using shared data as example, images are multi-angle view of the same floral meristem:
    >>> fname = 'p58-t0-a{}.lsm'
    >>> list_files = [shared_data(fname.format(a), 'p58') for a in range(3)]
    >>> list_img = [imread(fname) for fname in list_files]
    >>> # Load shared multi-angle landmarks for the first time point (t0) of 'p58' shared time-series
    >>> ref_pts_01 = np.loadtxt(shared_data('p58_t0_reference_ldmk-01.txt', 'p58'))
    >>> flo_pts_01 = np.loadtxt(shared_data('p58_t0_floating_ldmk-01.txt', 'p58'))
    >>> ref_pts_02 = np.loadtxt(shared_data('p58_t0_reference_ldmk-02.txt', 'p58'))
    >>> flo_pts_02 = np.loadtxt(shared_data('p58_t0_floating_ldmk-02.txt', 'p58'))
    >>> # Creates manual initialization transformations with `pointmatching` algorithm:
    >>> trsf_01 = pointmatching(flo_pts_01, ref_pts_01, template_img=list_img[0], method='rigid')
    >>> trsf_02 = pointmatching(flo_pts_02, ref_pts_02, template_img=list_img[0], method='rigid')
    >>> # Example 1 - Fuse the image, in super-resolution, after affine registration with manual initialization
    >>> fused_img, _ = fusion_on_first(list_img, init_trsfs=[trsf_01, trsf_02], super_resolution=True)
    >>> orthogonal_view(fused_img,suptitle="Fusion: centered local average image after affine registration with manual initialization")
    >>> # Example 2 - Same as Example 1, but with global averaging
    >>> fused_img, _ = fusion_on_first(list_img, init_trsfs=[trsf_01, trsf_02], super_resolution=True, global_averaging=True)
    >>> orthogonal_view(fused_img,suptitle="Fusion: centered global average image after affine registration with manual initialization")
    >>> # Example 3 - Fuse the image, in super-resolution, after affine & non-linear registration with manual initialization
    >>> fused_img, _ = fusion_on_first(list_img, init_trsfs=[trsf_01, trsf_02], super_resolution=True, vectorfield=True)
    >>> orthogonal_view(fused_img,suptitle="Fusion: centered average image after affine & non-linear registration with manual initialization")
    """
    # Validate (or estimate) the initial rigid transformations, refined by a quick rigid registration:
    init_trsfs = _check_init_rigid_trsf(init_trsfs, images)
    return _fusion_on_first(images, method=method, init_trsfs=init_trsfs, **kwargs)


def iterative_fusion(images, method="mean", init_trsfs=None, n_iter=3, **kwargs):
    """Iterative fusion of a list of multi-angle images.

    Parameters
    ----------
    images : list of SpatialImage
        List of image to fuse together.
    method : str, optional
        Image fusion method to use, 'mean' by default, see ``AVERAGING_METHODS``.
    init_trsfs : list of Trsf, optional
        List of transformations to use to initialize registration, must be of type *rigid*.
    n_iter : int, optional
        Number of iterations to perform, including the initial registration on first image.
        Must be at least ``2``.

    Other Parameters
    ----------------
    fuse_reference : bool
        If ``True`` (default), add the reference image to the list of images to fuse
        during the initial fusion. The following iterations always fuse all the input
        images registered onto the previously fused image.
    global_averaging : bool
        If ``True`` (default is ``False``), perform global averaging of each voxel by the total number of images.
        Else performs local averaging based on the number of defined images per voxel after registrations.
    super_resolution : bool
        Defines the voxelsize of the returned fused image.
        If ``True`` (default is ``False``), use the smallest voxelsize of the reference image as isometric voxelsize.
        Else, use the largest voxelsize of the reference image as isometric voxelsize.
    vectorfield_at_last : bool
        If ``True`` (default is ``False``), perform non-linear registration at last iteration step.
        Else, do not perform non-linear registration, only rigid and affine.
    final_vxs : float
        The voxel-size of the final isometric image, overrides the voxelsize computed by `super_resolution`.
    interpolation : {'linear', 'cspline'}
        The type of interpolation to perform when resampling and applying transformation to intensity images.
        Defaults to 'linear'.

    Returns
    -------
    timagetk.SpatialImage
        The fused image after iterative registration.

    Raises
    ------
    ValueError
        If `n_iter` is lower than ``2``.

    See Also
    --------
    timagetk.third_party.vt_parser.AVERAGING_METHODS

    Notes
    -----
    Only the last iteration will use the maximum resolution.
    The previous steps are performed on 2x down-sampled images.

    Examples
    --------
    >>> import numpy as np
    >>> from timagetk.tasks.fusion import iterative_fusion
    >>> from timagetk.algorithms.pointmatching import pointmatching
    >>> from timagetk.algorithms.resample import resample
    >>> from timagetk.algorithms.trsf import inv_trsf
    >>> from timagetk.io import imread
    >>> from timagetk.io.util import shared_data
    >>> from timagetk.visu.stack import stack_browser
    >>> from timagetk.visu.stack import orthogonal_view
    >>> # Using shared data as example, images are multi-angle view of the same floral meristem:
    >>> fname = 'p58-t0-a{}.lsm'
    >>> list_files = [shared_data(fname.format(a), 'p58') for a in range(3)]
    >>> list_img = [imread(fname) for fname in list_files]
    >>> ref_pts_01 = np.loadtxt(shared_data('p58_t0_reference_ldmk-01.txt', 'p58'))
    >>> flo_pts_01 = np.loadtxt(shared_data('p58_t0_floating_ldmk-01.txt', 'p58'))
    >>> ref_pts_02 = np.loadtxt(shared_data('p58_t0_reference_ldmk-02.txt', 'p58'))
    >>> flo_pts_02 = np.loadtxt(shared_data('p58_t0_floating_ldmk-02.txt', 'p58'))
    >>> # Creates manual initialization transformations with `pointmatching` algorithm:
    >>> trsf_01 = pointmatching(flo_pts_01, ref_pts_01, template_img=list_img[0], method='rigid')
    >>> trsf_02 = pointmatching(flo_pts_02, ref_pts_02, template_img=list_img[0], method='rigid')
    >>> # Example #1 - Iterative fusion with non-linear registration only at last iteration
    >>> vf_fused_img = iterative_fusion(list_img, init_trsfs=[trsf_01, trsf_02], vectorfield_at_last=True)
    >>> orthogonal_view(vf_fused_img,suptitle="p58-t0 multi-angle iterative fusion from AFFINE & VECTORFIELD registration with manual initialization")
    >>> # Example #2 - Iterative fusion with affine registration only (no non-linear registration)
    >>> fused_img = iterative_fusion(list_img, init_trsfs=[trsf_01, trsf_02])
    >>> orthogonal_view(fused_img,suptitle="p58-t0 multi-angle iterative fusion from AFFINE registration with manual initialization")
    >>> from timagetk.components.multi_channel import combine_channels
    >>> from timagetk.visu.mplt import grayscale_imshow
    >>> cuts = [1 / 6., 1 / 4., 1 / 2., 2 / 3., 3 / 4.]
    >>> z_sh = fused_img.get_shape('z')
    >>> z_slices = [int(round(c * z_sh, 0)) for c in cuts]
    >>> blend = combine_channels([fused_img, vf_fused_img], colors=["red", "green"])
    >>> thumbs = [blend.get_slice(zsl) for zsl in z_slices]
    >>> titles = [f"z-slice {zsl}/{z_sh}" for zsl in z_slices]
    >>> grayscale_imshow(thumbs, titles=titles)
    """
    # Explicit check: an `assert` would be stripped when running with `python -O`.
    if n_iter < 2:
        raise ValueError("You need a minimum of 2 iterations, otherwise use the `fusion` method!")

    ref_img, _ = get_ref_im_float_list(images, kwargs.pop('ref_id_img', 0))

    # Defines the voxel-size of the final isometric template using the reference image:
    final_vxs = kwargs.pop('final_vxs', None)
    super_res = kwargs.pop('super_resolution', False)
    if final_vxs is None:
        ref_vxs = ref_img.get_voxelsize()
        final_vxs = min(ref_vxs) if super_res else max(ref_vxs)

    # Options handled at this level must not stay in `kwargs`: `fuse_reference` is
    # passed explicitly to `_fusion_on_first` below (a duplicate would raise a
    # TypeError) and `vectorfield_at_last` must not be forwarded to `blockmatching`.
    fuse_reference = kwargs.pop('fuse_reference', True)
    vectorfield_at_last = kwargs.pop('vectorfield_at_last', False)

    # Initialize the rigid transformations
    log.info("Initial rigid registration on the first image of the list...")
    init_trsfs = _check_init_rigid_trsf(init_trsfs, images)

    # Intermediate iterations are performed on 2x down-sampled isometric templates:
    template_vxs = final_vxs * 2.

    # Initial fusion to isometric template (including the reference image by default)
    it = 0  # Initialize fusion iteration counter to 0
    log.info(f"Fusion - Initial registration (iteration {it})...")
    fused_img, _ = _fusion_on_first(images, method=method, init_trsfs=init_trsfs,
                                    fuse_reference=fuse_reference, super_resolution=False,
                                    final_vxs=template_vxs, **kwargs)

    # As we will now register the reference image on the previously fused image
    # we have to add the identity trsf in first position of `init_trsfs`
    init_trsfs = [create_trsf('identity', trsf_type='rigid', template_img=ref_img)] + init_trsfs

    for it in range(1, n_iter):
        if it == n_iter - 1:
            # maximum resolution only at last iteration to save processing time & memory
            template_vxs = final_vxs
            # if required, perform vectorfield registration at last iteration
            if vectorfield_at_last:
                kwargs["vectorfield"] = True
        log.info(f"Fusion - Registration iteration {it}...")
        fused_img, _ = _fusion_on_first([fused_img] + images, method=method, init_trsfs=init_trsfs,
                                        fuse_reference=False, super_resolution=False,
                                        final_vxs=template_vxs, **kwargs)

    return fused_img


def get_ref_im_float_list(images, ref_img_id):
    """Return the reference image and the list of floating images from given index.

    Parameters
    ----------
    images : list(SpatialImages)
        The list of SpatialImage to consider.
    ref_img_id : int
        The index of the reference image in the initial images list,
        must satisfy ``0 <= ref_img_id < len(images)``.

    Returns
    -------
    timagetk.SpatialImage
        The reference SpatialImage.
    list(SpatialImages)
        The list of floating SpatialImages, in their original order.

    Raises
    ------
    ValueError
        If `ref_img_id` is not a valid index of the `images` list.

    Examples
    --------
    >>> from timagetk.io import imread
    >>> from timagetk.io.util import shared_data
    >>> from timagetk.tasks.fusion import get_ref_im_float_list
    >>> # Using shared data as example, images are multi-angle view of the same floral meristem:
    >>> fname = 'p58-t0-a{}.lsm'
    >>> list_files = [shared_data(fname.format(a), 'p58') for a in range(3)]
    >>> list_img = [imread(fname) for fname in list_files]
    >>> # Example 1 - Use the first image of the list as reference:
    >>> ref_img, float_imgs = get_ref_im_float_list(list_img, 0)  # index starts at 0!
    >>> print(f"Reference image: {ref_img.filename}")
    Reference image: p58-t0-a0.lsm
    >>> print(f"Floating images: {', '.join([flo_img.filename for flo_img in float_imgs])}")
    Floating images: p58-t0-a1.lsm, p58-t0-a2.lsm
    >>> # Example 2 - Use the third image of the list as reference:
    >>> ref_img, float_imgs = get_ref_im_float_list(list_img, 2)  # index starts at 0!
    >>> print(f"Reference image: {ref_img.filename}")
    Reference image: p58-t0-a2.lsm
    >>> print(f"Floating images: {', '.join([flo_img.filename for flo_img in float_imgs])}")
    Floating images: p58-t0-a0.lsm, p58-t0-a1.lsm
    """
    n_imgs = len(images)
    # Valid indices are 0 <= ref_img_id < n_imgs (strict upper bound).
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not 0 <= ref_img_id < n_imgs:
        raise ValueError(f"Index `{ref_img_id}` is not accessible in a list of {n_imgs} images!")
    ref_img = images[ref_img_id]  # All images are registered on the chosen reference
    # Floating images to register, keeping their original order:
    float_imgs = [img for idx, img in enumerate(images) if idx != ref_img_id]
    return ref_img, float_imgs


def _check_init_rigid_trsf(init_trsfs, images, **kwargs): """Hidden procedure for initial RIGID transformations. If no initial transformation is provided (`init_trsfs is None`), compute it. Else should be a list of same size as the images list. Also, make sure they are linear transformations or `None`. Finally, improve them with a rigid registration by `blockmatching`. Parameters ---------- init_trsfs : list of Trsf, optional List of transformations to use to initialize registration, must be of type *rigid*. images : list of SpatialImage List of image to fuse together Other Parameters ---------------- ref_id_img : int The reference image index in the ``images`` list, ``0`` by default. Returns ------- init_trsfs List of transformations to use to initialize registration Raises ------ TypeError If the given list of transformations is not linear or `None`. ValueError If the given list of transformations is of different size than the list of images. """ ref_img, float_imgs = get_ref_im_float_list(images, kwargs.pop('ref_id_img', 0)) # Check given init_trsf are LINEAR trsf: if init_trsfs is not None: try: assert all(trsf is None or trsf.is_linear() for trsf in init_trsfs) except AssertionError: msg = "Input list `init_trsf` should contain only linear transformations!" raise TypeError(msg) try: assert len(init_trsfs) == len(images) - 1 except AssertionError: msg = "There should be one more images than transformations!" raise ValueError(msg) # If initial transformation is given, improve it with a quick round of rigid registration: for flo_idx, flo_img in enumerate(float_imgs): init_trsf = blockmatching(flo_img, ref_img, method='rigid', left_trsf=init_trsfs[flo_idx], pyramid_lowest_level=2, quiet=True) init_trsfs[flo_idx] = compose_trsf([init_trsfs[flo_idx], init_trsf]) else: init_trsfs = [] # If no initial transformation(s), for each floating image... for flo_img in float_imgs: # ... 
compute an initial RIGID registration onto the reference image init_trsf = blockmatching(flo_img, ref_img, method='rigid', pyramid_lowest_level=2, quiet=True) init_trsfs.append(init_trsf) return init_trsfs def _fusion_on_first(images, method, init_trsfs, **kwargs): """Fuse a list of image by registering them on the first of the list. The first image of the list is the reference image, the rest are the floating images. All floating image are registered twice, first the affine then the vectorfield. These transformations (and the initial one, if any), are composed to be applied to their respective floating image. The registered images are then 'averaged' with the selected `method`. The reference image is included in this list by default. Finally, to "center" the transformation, we: 1. compose the estimated affine and vectorfield transformation of each floating image, 2. invert them, 3. average them with the same image averaging method 4. apply this *averaged transformation* to the *averaged image* Parameters ---------- images : list of SpatialImage List of image to fuse together. method : str Image fusion method to use, 'mean' by default, see ``AVERAGING_METHODS`` init_trsfs : list of Trsf List of rigid transformations used to initialize registration, must be of type *rigid*. See the "Notes" section for a detailled explanations. Other Parameters ---------------- vectorfield : bool If ``True`` (default is ``False``), also compute a non-linear deformation to match the images. fuse_reference : bool If ``True`` (default), add the reference image to the list of images to fuse. final_vxs : float If specified, the template image will be the isometric resampling of the reference image to this voxel-size float value super_resolution : bool If ``True`` (default), compute a reference frame with minimum voxelsize. global_averaging : bool If ``True`` (default is ``False``), perform global averaging of each voxel by the total number of images. 
Else performs local averaging based on the number of defined imaged per voxels after registrations. interpolation : {'linear', 'cspline'} The type of interpolation to performs when resampling and applying transformation to intensity images. Defaults to 'linear'. Returns ------- timagetk.SpatialImage The fused image after registration, averaging and frame centering. list of Trsf The list of transformations to averge. See Also -------- timagetk.third_party.vt_parser.AVERAGING_METHODS Notes ----- The rigid part of the registration should be excluded from the averaged transformations, so the initial transformations should be processed by ``_check_init_rigid_trsf`` prior to using this method! Prior to averaging, masks are generated to average each voxel by the number of locally defined images and not the total number of images. Examples -------- >>> import numpy as np >>> from timagetk.tasks.fusion import _check_init_rigid_trsf >>> from timagetk.tasks.fusion import _fusion_on_first >>> from timagetk.tasks.fusion import fusion_on_first >>> from timagetk.algorithms.pointmatching import pointmatching >>> from timagetk.algorithms.resample import resample >>> from timagetk.algorithms.trsf import apply_trsf >>> from timagetk.algorithms.trsf import inv_trsf >>> from timagetk.io import imread >>> from timagetk.io.util import shared_data >>> from timagetk.visu.stack import orthogonal_view >>> from timagetk.visu.stack import stack_browser >>> from timagetk.visu.registration import registration_snapshot >>> from timagetk.visu.mplt import grayscale_imshow >>> # Using shared data as example, images are multi-angle view of the same floral meristem: >>> fname = 'p58-t0-a{}.lsm' >>> list_files = [shared_data(fname.format(a), 'p58') for a in range(3)] >>> list_img = [imread(fname) for fname in list_files] >>> img_titles = [img.filename for img in list_img] >>> grayscale_imshow(list_img, title=img_titles, val_range=[0, 255], threshold=40) >>> # Load shared multi-angle landmarks for the 
first time point (t0) of 'p58' shared time-series >>> ref_pts_01 = np.loadtxt(shared_data('p58_t0_reference_ldmk-01.txt', 'p58')) >>> flo_pts_01 = np.loadtxt(shared_data('p58_t0_floating_ldmk-01.txt', 'p58')) >>> ref_pts_02 = np.loadtxt(shared_data('p58_t0_reference_ldmk-02.txt', 'p58')) >>> flo_pts_02 = np.loadtxt(shared_data('p58_t0_floating_ldmk-02.txt', 'p58')) >>> # Creates manual initialization transformations with `pointmatching` algorithm: >>> trsf_01 = pointmatching(flo_pts_01, ref_pts_01, method='rigid', real=True) >>> trsf_02 = pointmatching(flo_pts_02, ref_pts_02, method='rigid', real=True) >>> init_trsfs = _check_init_rigid_trsf([trsf_01, trsf_02], list_img) >>> registered_images = [apply_trsf(img, trsf, template_img=list_img[0]) for img, trsf in zip(list_img[1:], init_trsfs)] >>> grayscale_imshow([list_img[0]]+registered_images, suptitle="Landmark registration on 'a0'.", title=img_titles, val_range=[0, 255], threshold=40) >>> registration_snapshot(list_img, init_trsfs, "p58-t0 manual RIGID registration") >>> fused_img, trsfs = _fusion_on_first(list_img, "mean", init_trsfs, vectorfield=False) >>> registration_snapshot(list_img, trsfs[1:], "p58-t0 multi-angle AFFINE registration with manual initialization") >>> orthogonal_view(list_img[0],suptitle="p58-t0 reference image for multi-angle fusion") >>> orthogonal_view(fused_img,suptitle="p58-t0 multi-angle fusion from AFFINE registration with manual initialization") >>> stack_browser(fused_img, title="p58-t0 multi-angle fusion from AFFINE registration with manual initialization") >>> fused_img, _ = _fusion_on_first(list_img, "mean", init_trsfs, final_vxs=0.3) >>> print(fused_img.get_shape()) >>> print(fused_img.get_voxelsize()) >>> orthogonal_view(fused_img,suptitle="p58-t0 multi-angle fusion from AFFINE registration with manual initialization") >>> stack_browser(fused_img, title="p58-t0 multi-angle fusion from AFFINE registration with manual initialization") >>> fused_img, init_trsfs = 
_fusion_on_first(list_img, "mean", init_trsfs, final_vxs=0.3, vectorfield=True) >>> registration_snapshot(list_img, init_trsfs, "p58-t0 multi-angle VECTORFIELD registration with manual initialization") >>> print(fused_img.get_shape()) >>> print(fused_img.get_voxelsize()) >>> orthogonal_view(fused_img,suptitle="p58-t0 multi-angle fusion from VECTORFIELD registration with manual initialization") >>> stack_browser(fused_img, title="p58-t0 multi-angle fusion from VECTORFIELD registration with manual initialization") >>> fused_img, init_trsfs = _fusion_on_first(list_img, 'mean', init_trsfs, super_resolution=False) >>> print(fused_img.get_shape()) >>> print(fused_img.get_voxelsize()) >>> orthogonal_view(fused_img,suptitle="p58-t0 multi-angle fusion in super-resolution from AFFINE registration with manual initialization") >>> stack_browser(fused_img) >>> fused_img, init_trsfs = _fusion_on_first(list_img, 'mean', init_trsfs, super_resolution=True) >>> print(fused_img.get_shape()) >>> print(fused_img.get_voxelsize()) >>> orthogonal_view(fused_img,suptitle="p58-t0 multi-angle fusion in super-resolution from AFFINE registration with manual initialization") >>> stack_browser(fused_img) """ # All images in 'list_images', are registered on the first of the list: ref_img, float_imgs = get_ref_im_float_list(images, 0) log.info(f"Reference image shape: {ref_img.get_shape()}") log.info(f"Reference image voxel-sizes: {ref_img.get_voxelsize()}") for flo_idx, flo_img in enumerate(float_imgs): log.info(f"Floating image #{flo_idx} shape: {flo_img.get_shape()}") log.info(f"Floating image #{flo_idx} voxel-sizes: {flo_img.get_voxelsize()}") # Defines the template image used for the final resampling: if 'final_vxs' in kwargs: # Resample the reference image in the template image: template = iso_template(ref_img, method=kwargs.pop('final_vxs')) elif kwargs.pop('super_resolution', True): # If Super-resolution is requested, the template image is isometric to the smallest voxel-size: template = 
iso_template(ref_img, method="min") else: template = ref_img log.info(f"Template image shape: {template.get_shape()}") log.info(f"Template image voxel-sizes: {template.get_voxelsize()}") int_interp = kwargs.pop('interpolation', 'linear') # Timer: fusion_time = time.time() # --- Block-matching registration of each floating images on the reference: lin_trsfs = [] # linear part of the transformations vf_trsfs = [] # non-linear part of the transformation for flo_idx, flo_img in enumerate(float_imgs): # Estimate the AFFINE transformation, from floating image to reference image: lin_trsf = blockmatching(flo_img, ref_img, method='affine', left_trsf=init_trsfs[flo_idx], **kwargs) lin_trsfs.append(lin_trsf) # If required, also estimate the VECTORFIELD transformation, from floating image to reference image: if kwargs.get('vectorfield', False): init_vf = compose_trsf([init_trsfs[flo_idx], lin_trsf]) vf_trsf = blockmatching(flo_img, resample(ref_img, voxelsize=template.get_voxelsize()), method='vectorfield', left_trsf=init_vf, **kwargs) vf_trsfs.append(vf_trsf) # --- Registered images and mask creation: res_imgs = [] # images list (to average) res_masks = [] # masks list (to average) if kwargs.get('fuse_reference', True): attr = get_image_attributes(ref_img, exclude=["dtype"]) ref_mask = SpatialImage(np.ones_like(ref_img, np.uint8), **attr) ref_id_aff_trsf = create_trsf('identity', template_img=template, trsf_type='affine') # Add the reference image to the list of image to be "fused" res_imgs.append(apply_trsf(ref_img, ref_id_aff_trsf, template_img=template, interpolation=int_interp)) res_masks.append(apply_trsf(ref_mask, ref_id_aff_trsf, template_img=template, interpolation='nearest')) for i, flo_img in enumerate(float_imgs): attr = get_image_attributes(ref_img, exclude=["dtype"]) flo_mask = SpatialImage(np.ones_like(flo_img, np.uint8), **attr) if kwargs.get('vectorfield', False): # Compose the linear and non-linear transformations (initial + affine + vectorfield) comp_trsf = 
compose_trsf([init_trsfs[i], lin_trsfs[i], vf_trsfs[i]]) # Apply the composed transformation to the floating image & create associated mask res_imgs.append(apply_trsf(flo_img, comp_trsf, interpolation=int_interp)) res_masks.append(apply_trsf(flo_mask, comp_trsf, interpolation='nearest')) else: # Compose the linear transformations (initial + affine) comp_trsf = compose_trsf([init_trsfs[i], lin_trsfs[i]]) # Apply the composed transformation to the floating image & create associated mask res_imgs.append(apply_trsf(flo_img, comp_trsf, template_img=template, interpolation=int_interp)) res_masks.append(apply_trsf(flo_mask, comp_trsf, template_img=template, interpolation='nearest')) # --- Average registered images with specified fusion method: if kwargs.get('global_averaging', False): # Global averaging => no masks ! mean_image = images_averaging(res_imgs, masks=None, method=method) else: # Local averaging mean_image = images_averaging(res_imgs, masks=res_masks, method=method) log.info(f"Average image shape: {mean_image.get_shape()}") log.info(f"Average image voxel-sizes: {mean_image.get_voxelsize()}") # --- Centering the mean image in the reference frame is achieved by: # 1. averaging the computed transformation (excluding potential manual rigid initialization) # 2. inverting it # 3. apply the inverted mean transformation to the averaged image log.debug("# --- Centering the mean image in the reference frame:") # 0. Create the list of (composed) transformation to average trsfs2average = [] # This list needs one trsf per image to fuse # So if we want to add the reference to the fusion, we need to create an identity transformation and add it to the list log.debug("0. 
Create the list of (composed) transformation to average") if kwargs.get('fuse_reference', True): if kwargs.get('vectorfield', False): # Add an identity vectorfield transfo if the reference image is "fused" trsfs2average.append(create_trsf('identity', template_img=template, trsf_type='vectorfield')) else: # Add an identity affine transfo if the reference image is "fused" trsfs2average.append(create_trsf('identity', template_img=template, trsf_type='affine')) # Now add the estimated AFFINE transformations (composed with VECTORFIELD if required) for i, _ in enumerate(float_imgs): if kwargs.get('vectorfield', False): trsfs2average.append(compose_trsf([lin_trsfs[i], vf_trsfs[i]], template_img=template)) else: trsfs2average.append(lin_trsfs[i]) # 1. average the (composed) transformations if kwargs.get('vectorfield', False): log.debug("1. average the composed linear transformations") mean_trsf = trsfs_averaging(trsfs2average, method='mean', trsf_type="vectorfield") else: log.debug("1. average the composed non-linear transformations") mean_trsf = trsfs_averaging(trsfs2average, method='mean', trsf_type="affine") # 2. invert the average transformation if kwargs.get('vectorfield', False): log.debug("2. invert the average transformation") inv_mean_trsf = inv_trsf(mean_trsf, template_img=template) else: log.debug("2. invert the average transformation") inv_mean_trsf = inv_trsf(mean_trsf) # 3. finally, apply the inverted average transformation to the average image if kwargs.get('vectorfield', False): log.debug("3. finally, apply the non-linear inverted average transformation to the average image") fused_img = apply_trsf(mean_image, inv_mean_trsf, interpolation=int_interp) else: log.debug("3. 
finally, apply the linear inverted average transformation to the average image") fused_img = apply_trsf(mean_image, inv_mean_trsf, template_img=template, interpolation=int_interp) log.info(f"Fused image shape: {fused_img.get_shape()}") log.info(f"Fused image voxel-sizes: {fused_img.get_voxelsize()}") # Timer: log.info(f"Multi-angle fusion {elapsed_time(fusion_time)}") return fused_img, trsfs2average