nanopyx.methods.drift_alignment.estimator_3d

  1import numpy as np
  2from math import sqrt
  3from scipy.interpolate import interp1d
  4
  5from .estimator_table import DriftEstimatorTable
  6from .corrector import DriftCorrector
  7from .estimator import DriftEstimator
  8from ...core.analysis.estimate_shift import GetMaxOptimizer
  9from ...core.utils.timeit import timeit
 10from ...core.analysis.ccm import calculate_ccm
 11
 12
 13class Estimator3D(object):
 14    """
 15    Main class implementing 3D drift correction.
 16
 17    This class provides methods for performing 3D drift correction on an image array with shape (t, z, y, x).
 18    It corrects both XY drift and Z drift using projection methods.
 19
 20    Args:
 21        None
 22
 23    Attributes:
 24        image_array (numpy.ndarray): The input 3D image stack with shape (t, z, y, x).
 25        xy_estimator (DriftEstimator): A drift estimator for XY drift correction.
 26        z_estimator (DriftEstimator): A drift estimator for Z drift correction.
 27
 28    Methods:
 29        __init__(): Initialize the `Estimator3D` object.
 30
 31        correct_xy_drift(projection_mode="Mean", **kwargs): Correct XY drift in the image stack using projection.
 32
 33        correct_z_drift(axis_mode="top", projection_mode="Mean", **kwargs): Correct Z drift in the image stack using projection.
 34
 35        correct_3d_drift(image_array, axis_mode="top", projection_mode="Mean", **kwargs): Correct both XY and Z drift in the 3D image stack.
 36
 37    Example:
 38        estimator = Estimator3D()
 39        corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)
 40
 41    Note:
 42        The `Estimator3D` class is used for performing 3D drift correction on image stacks.
 43        It includes methods for correcting XY drift and Z drift separately or together.
 44    """
 45
 46    def __init__(self):
 47        """
 48        Initialize the `Estimator3D` object.
 49
 50        Args:
 51            None
 52
 53        Returns:
 54            None
 55
 56        Example:
 57            estimator = Estimator3D()
 58        """
 59        self.image_array = None
 60        self.xy_estimator = None
 61        self.z_estimator = None
 62
 63    def correct_xy_drift(self, projection_mode="Mean", **kwargs):
 64        """
 65        Correct XY drift in the image stack using projection.
 66
 67        Args:
 68            projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean".
 69            **kwargs: Keyword arguments for drift estimation parameters.
 70
 71        Returns:
 72            None
 73
 74        Example:
 75            estimator.correct_xy_drift(projection_mode="Mean", **drift_params)
 76
 77        Note:
 78            This method corrects XY drift in the 3D image stack using projection-based drift estimation and correction.
 79        """
 80        self.xy_estimator = DriftEstimator()
 81
 82        if projection_mode == "Mean":
 83            projection = np.mean(self.image_array, axis=1)
 84        elif projection_mode == "Max":
 85            projection = np.max(self.image_array, axis=1)
 86        else:
 87            print("Not a valid projection mode")
 88            return None
 89
 90        self.xy_estimator.estimate(projection, apply=False, **kwargs)
 91
 92        corrector = DriftCorrector()
 93        corrector.estimator_table = self.xy_estimator.estimator_table
 94        for i in range(self.image_array.shape[1]):
 95            self.image_array[:, i, :, :] = corrector.apply_correction(self.image_array[:, i, :, :])
 96
 97    def correct_z_drift(self, axis_mode="top", projection_mode="Mean", **kwargs):
 98        """
 99        Correct Z drift in the image stack using projection.
100
101        Args:
102            axis_mode (str, optional): The axis mode for Z drift correction. "top" or "left" can be used. Default is "top".
103            projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean".
104            **kwargs: Keyword arguments for drift estimation parameters.
105
106        Returns:
107            None
108
109        Example:
110            estimator.correct_z_drift(axis_mode="top", projection_mode="Mean", **drift_params)
111
112        Note:
113            This method corrects Z drift in the 3D image stack using projection-based drift estimation and correction.
114        """
115        if axis_mode == "top":
116            axis_idx = 2
117        elif axis_mode == "left":
118            axis_idx = 3
119        else:
120            print("Not a valid axis mode")
121            return None
122
123        self.z_estimator = DriftEstimator()
124        if projection_mode == "Mean":
125            projection = np.mean(self.image_array, axis=axis_idx)
126        elif projection_mode == "Max":
127            projection = np.max(self.image_array, axis=axis_idx)
128        else:
129            print("Not a valid projection mode")
130            return None
131
132        self.z_estimator.estimate(projection, apply=False, **kwargs)
133
134        corrector = DriftCorrector()
135        print(self.image_array.shape, projection.shape)
136        corrector.estimator_table = self.z_estimator.estimator_table
137        if axis_mode == "top":
138            corrector.estimator_table.drift_table[:, 1] = 0
139            for i in range(self.image_array.shape[axis_idx]):
140                self.image_array[:, :, i, :] = corrector.apply_correction(self.image_array[:, :, i, :])
141        elif axis_mode == "left":
142            corrector.estimator_table.drift_table[:, 2] = 0
143            for i in range(self.image_array.shape[axis_idx]):
144                self.image_array[:, :, :, i] = corrector.apply_correction(self.image_array[:, :, :, i])
145
146    def correct_3d_drift(self, image_array, axis_mode="top", projection_mode="Mean", **kwargs):
147        """
148        Correct both XY and Z drift in the 3D image stack.
149
150        Args:
151            image_array (numpy.ndarray): The input 3D image stack with shape (t, z, y, x).
152            axis_mode (str, optional): The axis mode for Z drift correction. "top" or "left" can be used. Default is "top".
153            projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean".
154            **kwargs: Keyword arguments for drift estimation parameters.
155
156        Returns:
157            numpy.ndarray: The drift-corrected 3D image stack.
158
159        Example:
160            corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)
161
162        Note:
163            This method performs both XY and Z drift correction on the 3D image stack and returns the corrected stack.
164        """
165        self.image_array = image_array
166        self.correct_xy_drift(projection_mode=projection_mode, **kwargs)
167        self.correct_z_drift(axis_mode=axis_mode, projection_mode=projection_mode, **kwargs)
168        return self.image_array
class Estimator3D:
 14class Estimator3D(object):
 15    """
 16    Main class implementing 3D drift correction.
 17
 18    This class provides methods for performing 3D drift correction on an image array with shape (t, z, y, x).
 19    It corrects both XY drift and Z drift using projection methods.
 20
 21    Args:
 22        None
 23
 24    Attributes:
 25        image_array (numpy.ndarray): The input 3D image stack with shape (t, z, y, x).
 26        xy_estimator (DriftEstimator): A drift estimator for XY drift correction.
 27        z_estimator (DriftEstimator): A drift estimator for Z drift correction.
 28
 29    Methods:
 30        __init__(): Initialize the `Estimator3D` object.
 31
 32        correct_xy_drift(projection_mode="Mean", **kwargs): Correct XY drift in the image stack using projection.
 33
 34        correct_z_drift(axis_mode="top", projection_mode="Mean", **kwargs): Correct Z drift in the image stack using projection.
 35
 36        correct_3d_drift(image_array, axis_mode="top", projection_mode="Mean", **kwargs): Correct both XY and Z drift in the 3D image stack.
 37
 38    Example:
 39        estimator = Estimator3D()
 40        corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)
 41
 42    Note:
 43        The `Estimator3D` class is used for performing 3D drift correction on image stacks.
 44        It includes methods for correcting XY drift and Z drift separately or together.
 45    """
 46
 47    def __init__(self):
 48        """
 49        Initialize the `Estimator3D` object.
 50
 51        Args:
 52            None
 53
 54        Returns:
 55            None
 56
 57        Example:
 58            estimator = Estimator3D()
 59        """
 60        self.image_array = None
 61        self.xy_estimator = None
 62        self.z_estimator = None
 63
 64    def correct_xy_drift(self, projection_mode="Mean", **kwargs):
 65        """
 66        Correct XY drift in the image stack using projection.
 67
 68        Args:
 69            projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean".
 70            **kwargs: Keyword arguments for drift estimation parameters.
 71
 72        Returns:
 73            None
 74
 75        Example:
 76            estimator.correct_xy_drift(projection_mode="Mean", **drift_params)
 77
 78        Note:
 79            This method corrects XY drift in the 3D image stack using projection-based drift estimation and correction.
 80        """
 81        self.xy_estimator = DriftEstimator()
 82
 83        if projection_mode == "Mean":
 84            projection = np.mean(self.image_array, axis=1)
 85        elif projection_mode == "Max":
 86            projection = np.max(self.image_array, axis=1)
 87        else:
 88            print("Not a valid projection mode")
 89            return None
 90
 91        self.xy_estimator.estimate(projection, apply=False, **kwargs)
 92
 93        corrector = DriftCorrector()
 94        corrector.estimator_table = self.xy_estimator.estimator_table
 95        for i in range(self.image_array.shape[1]):
 96            self.image_array[:, i, :, :] = corrector.apply_correction(self.image_array[:, i, :, :])
 97
 98    def correct_z_drift(self, axis_mode="top", projection_mode="Mean", **kwargs):
 99        """
100        Correct Z drift in the image stack using projection.
101
102        Args:
103            axis_mode (str, optional): The axis mode for Z drift correction. "top" or "left" can be used. Default is "top".
104            projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean".
105            **kwargs: Keyword arguments for drift estimation parameters.
106
107        Returns:
108            None
109
110        Example:
111            estimator.correct_z_drift(axis_mode="top", projection_mode="Mean", **drift_params)
112
113        Note:
114            This method corrects Z drift in the 3D image stack using projection-based drift estimation and correction.
115        """
116        if axis_mode == "top":
117            axis_idx = 2
118        elif axis_mode == "left":
119            axis_idx = 3
120        else:
121            print("Not a valid axis mode")
122            return None
123
124        self.z_estimator = DriftEstimator()
125        if projection_mode == "Mean":
126            projection = np.mean(self.image_array, axis=axis_idx)
127        elif projection_mode == "Max":
128            projection = np.max(self.image_array, axis=axis_idx)
129        else:
130            print("Not a valid projection mode")
131            return None
132
133        self.z_estimator.estimate(projection, apply=False, **kwargs)
134
135        corrector = DriftCorrector()
136        print(self.image_array.shape, projection.shape)
137        corrector.estimator_table = self.z_estimator.estimator_table
138        if axis_mode == "top":
139            corrector.estimator_table.drift_table[:, 1] = 0
140            for i in range(self.image_array.shape[axis_idx]):
141                self.image_array[:, :, i, :] = corrector.apply_correction(self.image_array[:, :, i, :])
142        elif axis_mode == "left":
143            corrector.estimator_table.drift_table[:, 2] = 0
144            for i in range(self.image_array.shape[axis_idx]):
145                self.image_array[:, :, :, i] = corrector.apply_correction(self.image_array[:, :, :, i])
146
147    def correct_3d_drift(self, image_array, axis_mode="top", projection_mode="Mean", **kwargs):
148        """
149        Correct both XY and Z drift in the 3D image stack.
150
151        Args:
152            image_array (numpy.ndarray): The input 3D image stack with shape (t, z, y, x).
153            axis_mode (str, optional): The axis mode for Z drift correction. "top" or "left" can be used. Default is "top".
154            projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean".
155            **kwargs: Keyword arguments for drift estimation parameters.
156
157        Returns:
158            numpy.ndarray: The drift-corrected 3D image stack.
159
160        Example:
161            corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)
162
163        Note:
164            This method performs both XY and Z drift correction on the 3D image stack and returns the corrected stack.
165        """
166        self.image_array = image_array
167        self.correct_xy_drift(projection_mode=projection_mode, **kwargs)
168        self.correct_z_drift(axis_mode=axis_mode, projection_mode=projection_mode, **kwargs)
169        return self.image_array

Main class implementing 3D drift correction.

This class provides methods for performing 3D drift correction on a time series of 3D volumes stored as an array with shape (t, z, y, x). XY drift is estimated from a projection along z, and Z drift from a projection along y or x.

Args: None

Attributes: image_array (numpy.ndarray): The input 3D image stack with shape (t, z, y, x). xy_estimator (DriftEstimator): A drift estimator for XY drift correction. z_estimator (DriftEstimator): A drift estimator for Z drift correction.

Methods: __init__(): Initialize the Estimator3D object.

correct_xy_drift(projection_mode="Mean", **kwargs): Correct XY drift in the image stack using projection.

correct_z_drift(axis_mode="top", projection_mode="Mean", **kwargs): Correct Z drift in the image stack using projection.

correct_3d_drift(image_array, axis_mode="top", projection_mode="Mean", **kwargs): Correct both XY and Z drift in the 3D image stack.

Example: `estimator = Estimator3D()`, followed by `corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)`.

Note: The Estimator3D class is used for performing 3D drift correction on image stacks. It includes methods for correcting XY drift and Z drift separately or together.

Estimator3D()
def __init__(self):
    """
    Set up an empty `Estimator3D`.

    No stack is loaded and no estimator is fitted until one of the
    correction methods is called.

    Example:
        estimator = Estimator3D()
    """
    # Nothing loaded, nothing fitted yet.
    self.image_array = self.xy_estimator = self.z_estimator = None

Initialize the Estimator3D object.

Args: None

Returns: None

Example: estimator = Estimator3D()

image_array
xy_estimator
z_estimator
def correct_xy_drift(self, projection_mode='Mean', **kwargs):
def correct_xy_drift(self, projection_mode="Mean", **kwargs):
    """
    Correct XY drift in the image stack using a projection along z.

    The (t, z, y, x) stack is collapsed along z (axis 1), giving one (y, x)
    image per time point. Drift is estimated on those images and the resulting
    per-time-point shift is applied to every z-plane of ``self.image_array``
    in place.

    Args:
        projection_mode (str, optional): "Mean" or "Max". Default is "Mean".
        **kwargs: Drift estimation parameters forwarded to `DriftEstimator.estimate`.

    Returns:
        None. An invalid ``projection_mode`` prints a message and returns
        without touching ``self.xy_estimator`` or ``self.image_array``.

    Example:
        estimator.correct_xy_drift(projection_mode="Mean", **drift_params)
    """
    if projection_mode == "Mean":
        projector = np.mean
    elif projection_mode == "Max":
        projector = np.max
    else:
        # Validate before replacing self.xy_estimator so a previously fitted
        # estimator survives an invalid call.
        print("Not a valid projection mode")
        return None

    projection = projector(self.image_array, axis=1)

    self.xy_estimator = DriftEstimator()
    self.xy_estimator.estimate(projection, apply=False, **kwargs)

    corrector = DriftCorrector()
    corrector.estimator_table = self.xy_estimator.estimator_table
    # Apply the same per-time-point XY shift to every z-plane.
    for i in range(self.image_array.shape[1]):
        self.image_array[:, i, :, :] = corrector.apply_correction(self.image_array[:, i, :, :])

Correct XY drift in the image stack using projection.

Args: projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean". **kwargs: Keyword arguments for drift estimation parameters.

Returns: None

Example: estimator.correct_xy_drift(projection_mode="Mean", **drift_params)

Note: This method corrects XY drift in the 3D image stack using projection-based drift estimation and correction.

def correct_z_drift(self, axis_mode='top', projection_mode='Mean', **kwargs):
 98    def correct_z_drift(self, axis_mode="top", projection_mode="Mean", **kwargs):
 99        """
100        Correct Z drift in the image stack using projection.
101
102        Args:
103            axis_mode (str, optional): The axis mode for Z drift correction. "top" or "left" can be used. Default is "top".
104            projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean".
105            **kwargs: Keyword arguments for drift estimation parameters.
106
107        Returns:
108            None
109
110        Example:
111            estimator.correct_z_drift(axis_mode="top", projection_mode="Mean", **drift_params)
112
113        Note:
114            This method corrects Z drift in the 3D image stack using projection-based drift estimation and correction.
115        """
116        if axis_mode == "top":
117            axis_idx = 2
118        elif axis_mode == "left":
119            axis_idx = 3
120        else:
121            print("Not a valid axis mode")
122            return None
123
124        self.z_estimator = DriftEstimator()
125        if projection_mode == "Mean":
126            projection = np.mean(self.image_array, axis=axis_idx)
127        elif projection_mode == "Max":
128            projection = np.max(self.image_array, axis=axis_idx)
129        else:
130            print("Not a valid projection mode")
131            return None
132
133        self.z_estimator.estimate(projection, apply=False, **kwargs)
134
135        corrector = DriftCorrector()
136        print(self.image_array.shape, projection.shape)
137        corrector.estimator_table = self.z_estimator.estimator_table
138        if axis_mode == "top":
139            corrector.estimator_table.drift_table[:, 1] = 0
140            for i in range(self.image_array.shape[axis_idx]):
141                self.image_array[:, :, i, :] = corrector.apply_correction(self.image_array[:, :, i, :])
142        elif axis_mode == "left":
143            corrector.estimator_table.drift_table[:, 2] = 0
144            for i in range(self.image_array.shape[axis_idx]):
145                self.image_array[:, :, :, i] = corrector.apply_correction(self.image_array[:, :, :, i])

Correct Z drift in the image stack using projection.

Args: axis_mode (str, optional): The axis mode for Z drift correction. "top" (project along y, axis 2) or "left" (project along x, axis 3) can be used. Default is "top". projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean". **kwargs: Keyword arguments for drift estimation parameters.

Returns: None

Example: estimator.correct_z_drift(axis_mode="top", projection_mode="Mean", **drift_params)

Note: This method corrects Z drift in the 3D image stack using projection-based drift estimation and correction.

def correct_3d_drift(self, image_array, axis_mode='top', projection_mode='Mean', **kwargs):
def correct_3d_drift(self, image_array, axis_mode="top", projection_mode="Mean", **kwargs):
    """
    Run XY drift correction followed by Z drift correction.

    Args:
        image_array (numpy.ndarray): Time series of 3D volumes, shape (t, z, y, x).
            It is stored on the instance and corrected in place.
        axis_mode (str, optional): "top" or "left" for the Z-drift projection. Default is "top".
        projection_mode (str, optional): "Mean" or "Max", used by both steps. Default is "Mean".
        **kwargs: Drift estimation parameters forwarded to both steps.

    Returns:
        numpy.ndarray: ``self.image_array`` after both steps (the same object as ``image_array``).

    Example:
        corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)
    """
    self.image_array = image_array
    # Options shared by both correction passes.
    shared = dict(kwargs, projection_mode=projection_mode)
    self.correct_xy_drift(**shared)
    self.correct_z_drift(axis_mode=axis_mode, **shared)
    return self.image_array

Correct both XY and Z drift in the 3D image stack.

Args: image_array (numpy.ndarray): The input 3D image stack with shape (t, z, y, x). axis_mode (str, optional): The axis mode for Z drift correction. "top" or "left" can be used. Default is "top". projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean". **kwargs: Keyword arguments for drift estimation parameters.

Returns: numpy.ndarray: The drift-corrected stack. Correction is done in place, so this is the same array object that was passed in.

Example: corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)

Note: This method performs both XY and Z drift correction on the 3D image stack and returns the corrected stack.