nanopyx.methods.drift_alignment.estimator_3d
import numpy as np
from math import sqrt
from scipy.interpolate import interp1d

from .estimator_table import DriftEstimatorTable
from .corrector import DriftCorrector
from .estimator import DriftEstimator
from ...core.analysis.estimate_shift import GetMaxOptimizer
from ...core.utils.timeit import timeit
from ...core.analysis.ccm import calculate_ccm


class Estimator3D(object):
    """
    Main class implementing 3D drift correction.

    Corrects drift in an image stack with shape (t, z, y, x) in two passes:
    XY drift is estimated on a projection along z, then Z drift is estimated on
    a side-view projection along y ("top") or x ("left"). Estimation is
    delegated to `DriftEstimator` and the shifting of frames to `DriftCorrector`.

    Attributes:
        image_array (numpy.ndarray): The stack being corrected, shape (t, z, y, x).
            Corrections are written into this array in place.
        xy_estimator (DriftEstimator): Estimator used by the last `correct_xy_drift` call.
        z_estimator (DriftEstimator): Estimator used by the last `correct_z_drift` call.

    Methods:
        correct_xy_drift(projection_mode="Mean", **kwargs): Correct XY drift using a projection along z.
        correct_z_drift(axis_mode="top", projection_mode="Mean", **kwargs): Correct Z drift using a side-view projection.
        correct_3d_drift(image_array, axis_mode="top", projection_mode="Mean", **kwargs): Correct XY drift, then Z drift.

    Example:
        estimator = Estimator3D()
        corrected_image_stack = estimator.correct_3d_drift(
            image_stack, axis_mode="top", projection_mode="Mean", **drift_params
        )
    """

    def __init__(self):
        """Initialize an estimator with no stack and no fitted estimators."""
        self.image_array = None
        self.xy_estimator = None
        self.z_estimator = None

    def correct_xy_drift(self, projection_mode="Mean", **kwargs):
        """
        Correct XY drift in `self.image_array` in place.

        The stack is projected along z (axis 1), drift is estimated on the
        resulting (t, y, x) movie, and the same per-frame correction is then
        applied to every z slice.

        Args:
            projection_mode (str, optional): "Mean" or "Max" projection along z. Default is "Mean".
            **kwargs: Forwarded to `DriftEstimator.estimate`.

        Returns:
            None

        Raises:
            ValueError: If `projection_mode` is not "Mean" or "Max".
        """
        # Validate (and project) before creating the estimator so a bad option
        # leaves the object untouched.
        if projection_mode == "Mean":
            projection = np.mean(self.image_array, axis=1)
        elif projection_mode == "Max":
            projection = np.max(self.image_array, axis=1)
        else:
            raise ValueError(
                f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
            )

        self.xy_estimator = DriftEstimator()
        self.xy_estimator.estimate(projection, apply=False, **kwargs)

        corrector = DriftCorrector()
        corrector.estimator_table = self.xy_estimator.estimator_table
        for z in range(self.image_array.shape[1]):
            self.image_array[:, z, :, :] = corrector.apply_correction(self.image_array[:, z, :, :])

    def correct_z_drift(self, axis_mode="top", projection_mode="Mean", **kwargs):
        """
        Correct Z drift in `self.image_array` in place.

        The stack is projected along y ("top", giving a (t, z, x) side view) or
        along x ("left", giving a (t, z, y) side view). Drift is estimated on that
        side view, only the shift along its rows (the z direction) is kept, and
        the correction is applied to every y slice ("top") or x slice ("left").

        Args:
            axis_mode (str, optional): "top" or "left". Default is "top".
            projection_mode (str, optional): "Mean" or "Max" projection. Default is "Mean".
            **kwargs: Forwarded to `DriftEstimator.estimate`.

        Returns:
            None

        Raises:
            ValueError: If `axis_mode` or `projection_mode` is not an accepted value.

        Note:
            The lateral shift column is zeroed on the estimator's own table, so
            `self.z_estimator.estimator_table.drift_table` afterwards holds only the
            shift that was actually applied.
        """
        if axis_mode == "top":
            axis_idx = 2
        elif axis_mode == "left":
            axis_idx = 3
        else:
            raise ValueError(f"Not a valid axis mode: {axis_mode!r} (expected 'top' or 'left')")

        if projection_mode == "Mean":
            projection = np.mean(self.image_array, axis=axis_idx)
        elif projection_mode == "Max":
            projection = np.max(self.image_array, axis=axis_idx)
        else:
            raise ValueError(
                f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
            )

        self.z_estimator = DriftEstimator()
        self.z_estimator.estimate(projection, apply=False, **kwargs)

        corrector = DriftCorrector()
        corrector.estimator_table = self.z_estimator.estimator_table
        # In both side views z is axis 1 of the projection (the frame rows) and the
        # frame columns are a lateral axis whose drift belongs to correct_xy_drift,
        # so the column shift is discarded in BOTH modes (previously "left" zeroed
        # column 2 instead, throwing away the z shift). Column 1 is the one "top"
        # has always zeroed; assumed drift_table layout is [magnitude, x, y] —
        # confirm against DriftEstimatorTable.
        corrector.estimator_table.drift_table[:, 1] = 0

        for i in range(self.image_array.shape[axis_idx]):
            # (:, :, i) selects a y slice for "top"; (:, :, :, i) an x slice for "left".
            index = (slice(None),) * axis_idx + (i,)
            self.image_array[index] = corrector.apply_correction(self.image_array[index])

    def correct_3d_drift(self, image_array, axis_mode="top", projection_mode="Mean", **kwargs):
        """
        Correct XY drift and then Z drift in a (t, z, y, x) image stack.

        Args:
            image_array (numpy.ndarray): Stack with shape (t, z, y, x). It is
                corrected in place and stored as `self.image_array`; corrected
                values are written back with the array's own dtype.
            axis_mode (str, optional): Side view for Z drift, "top" or "left". Default is "top".
            projection_mode (str, optional): "Mean" or "Max". Default is "Mean".
            **kwargs: Forwarded to both drift estimations.

        Returns:
            numpy.ndarray: The corrected stack (the same object as `image_array`).

        Raises:
            ValueError: If the stack is not 4D or a mode is invalid. Validation
                happens before any correction, so the input is untouched then.
        """
        ndim = np.ndim(image_array)
        if ndim != 4:
            raise ValueError(f"image_array must have shape (t, z, y, x), got {ndim} dimension(s)")
        if projection_mode not in ("Mean", "Max"):
            raise ValueError(
                f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
            )
        if axis_mode not in ("top", "left"):
            raise ValueError(f"Not a valid axis mode: {axis_mode!r} (expected 'top' or 'left')")

        self.image_array = image_array
        self.correct_xy_drift(projection_mode=projection_mode, **kwargs)
        self.correct_z_drift(axis_mode=axis_mode, projection_mode=projection_mode, **kwargs)
        return self.image_array
class Estimator3D(object):
    """
    Main class implementing 3D drift correction.

    Corrects drift in an image stack with shape (t, z, y, x) in two passes:
    XY drift is estimated on a projection along z, then Z drift is estimated on
    a side-view projection along y ("top") or x ("left"). Estimation is
    delegated to `DriftEstimator` and the shifting of frames to `DriftCorrector`.

    Attributes:
        image_array (numpy.ndarray): The stack being corrected, shape (t, z, y, x).
            Corrections are written into this array in place.
        xy_estimator (DriftEstimator): Estimator used by the last `correct_xy_drift` call.
        z_estimator (DriftEstimator): Estimator used by the last `correct_z_drift` call.

    Methods:
        correct_xy_drift(projection_mode="Mean", **kwargs): Correct XY drift using a projection along z.
        correct_z_drift(axis_mode="top", projection_mode="Mean", **kwargs): Correct Z drift using a side-view projection.
        correct_3d_drift(image_array, axis_mode="top", projection_mode="Mean", **kwargs): Correct XY drift, then Z drift.

    Example:
        estimator = Estimator3D()
        corrected_image_stack = estimator.correct_3d_drift(
            image_stack, axis_mode="top", projection_mode="Mean", **drift_params
        )
    """

    def __init__(self):
        """Initialize an estimator with no stack and no fitted estimators."""
        self.image_array = None
        self.xy_estimator = None
        self.z_estimator = None

    def correct_xy_drift(self, projection_mode="Mean", **kwargs):
        """
        Correct XY drift in `self.image_array` in place.

        The stack is projected along z (axis 1), drift is estimated on the
        resulting (t, y, x) movie, and the same per-frame correction is then
        applied to every z slice.

        Args:
            projection_mode (str, optional): "Mean" or "Max" projection along z. Default is "Mean".
            **kwargs: Forwarded to `DriftEstimator.estimate`.

        Returns:
            None

        Raises:
            ValueError: If `projection_mode` is not "Mean" or "Max".
        """
        # Validate (and project) before creating the estimator so a bad option
        # leaves the object untouched.
        if projection_mode == "Mean":
            projection = np.mean(self.image_array, axis=1)
        elif projection_mode == "Max":
            projection = np.max(self.image_array, axis=1)
        else:
            raise ValueError(
                f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
            )

        self.xy_estimator = DriftEstimator()
        self.xy_estimator.estimate(projection, apply=False, **kwargs)

        corrector = DriftCorrector()
        corrector.estimator_table = self.xy_estimator.estimator_table
        for z in range(self.image_array.shape[1]):
            self.image_array[:, z, :, :] = corrector.apply_correction(self.image_array[:, z, :, :])

    def correct_z_drift(self, axis_mode="top", projection_mode="Mean", **kwargs):
        """
        Correct Z drift in `self.image_array` in place.

        The stack is projected along y ("top", giving a (t, z, x) side view) or
        along x ("left", giving a (t, z, y) side view). Drift is estimated on that
        side view, only the shift along its rows (the z direction) is kept, and
        the correction is applied to every y slice ("top") or x slice ("left").

        Args:
            axis_mode (str, optional): "top" or "left". Default is "top".
            projection_mode (str, optional): "Mean" or "Max" projection. Default is "Mean".
            **kwargs: Forwarded to `DriftEstimator.estimate`.

        Returns:
            None

        Raises:
            ValueError: If `axis_mode` or `projection_mode` is not an accepted value.

        Note:
            The lateral shift column is zeroed on the estimator's own table, so
            `self.z_estimator.estimator_table.drift_table` afterwards holds only the
            shift that was actually applied.
        """
        if axis_mode == "top":
            axis_idx = 2
        elif axis_mode == "left":
            axis_idx = 3
        else:
            raise ValueError(f"Not a valid axis mode: {axis_mode!r} (expected 'top' or 'left')")

        if projection_mode == "Mean":
            projection = np.mean(self.image_array, axis=axis_idx)
        elif projection_mode == "Max":
            projection = np.max(self.image_array, axis=axis_idx)
        else:
            raise ValueError(
                f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
            )

        self.z_estimator = DriftEstimator()
        self.z_estimator.estimate(projection, apply=False, **kwargs)

        corrector = DriftCorrector()
        corrector.estimator_table = self.z_estimator.estimator_table
        # In both side views z is axis 1 of the projection (the frame rows) and the
        # frame columns are a lateral axis whose drift belongs to correct_xy_drift,
        # so the column shift is discarded in BOTH modes (previously "left" zeroed
        # column 2 instead, throwing away the z shift). Column 1 is the one "top"
        # has always zeroed; assumed drift_table layout is [magnitude, x, y] —
        # confirm against DriftEstimatorTable.
        corrector.estimator_table.drift_table[:, 1] = 0

        for i in range(self.image_array.shape[axis_idx]):
            # (:, :, i) selects a y slice for "top"; (:, :, :, i) an x slice for "left".
            index = (slice(None),) * axis_idx + (i,)
            self.image_array[index] = corrector.apply_correction(self.image_array[index])

    def correct_3d_drift(self, image_array, axis_mode="top", projection_mode="Mean", **kwargs):
        """
        Correct XY drift and then Z drift in a (t, z, y, x) image stack.

        Args:
            image_array (numpy.ndarray): Stack with shape (t, z, y, x). It is
                corrected in place and stored as `self.image_array`; corrected
                values are written back with the array's own dtype.
            axis_mode (str, optional): Side view for Z drift, "top" or "left". Default is "top".
            projection_mode (str, optional): "Mean" or "Max". Default is "Mean".
            **kwargs: Forwarded to both drift estimations.

        Returns:
            numpy.ndarray: The corrected stack (the same object as `image_array`).

        Raises:
            ValueError: If the stack is not 4D or a mode is invalid. Validation
                happens before any correction, so the input is untouched then.
        """
        ndim = np.ndim(image_array)
        if ndim != 4:
            raise ValueError(f"image_array must have shape (t, z, y, x), got {ndim} dimension(s)")
        if projection_mode not in ("Mean", "Max"):
            raise ValueError(
                f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
            )
        if axis_mode not in ("top", "left"):
            raise ValueError(f"Not a valid axis mode: {axis_mode!r} (expected 'top' or 'left')")

        self.image_array = image_array
        self.correct_xy_drift(projection_mode=projection_mode, **kwargs)
        self.correct_z_drift(axis_mode=axis_mode, projection_mode=projection_mode, **kwargs)
        return self.image_array
Main class implementing 3D drift correction.
This class provides methods for performing 3D drift correction on an image array with shape (t, z, y, x). It corrects both XY drift and Z drift using projection methods.
Args: None
Attributes: image_array (numpy.ndarray): The input 3D image stack with shape (t, z, y, x). xy_estimator (DriftEstimator): A drift estimator for XY drift correction. z_estimator (DriftEstimator): A drift estimator for Z drift correction.
Methods:
__init__(): Initialize the `Estimator3D` object.
correct_xy_drift(projection_mode="Mean", **kwargs): Correct XY drift in the image stack using projection.
correct_z_drift(axis_mode="top", projection_mode="Mean", **kwargs): Correct Z drift in the image stack using projection.
correct_3d_drift(image_array, axis_mode="top", projection_mode="Mean", **kwargs): Correct both XY and Z drift in the 3D image stack.
Example: estimator = Estimator3D() corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)
Note:
The `Estimator3D` class is used for performing 3D drift correction on image stacks.
It includes methods for correcting XY drift and Z drift separately or together.
def correct_xy_drift(self, projection_mode="Mean", **kwargs):
    """
    Correct XY drift in `self.image_array` in place.

    The stack is projected along z (axis 1), drift is estimated on the
    resulting (t, y, x) movie, and the same per-frame correction is then
    applied to every z slice.

    Args:
        projection_mode (str, optional): "Mean" or "Max" projection along z. Default is "Mean".
        **kwargs: Forwarded to `DriftEstimator.estimate`.

    Returns:
        None

    Raises:
        ValueError: If `projection_mode` is not "Mean" or "Max".
    """
    # Validate (and project) before creating the estimator so a bad option
    # leaves the object untouched instead of silently doing nothing.
    if projection_mode == "Mean":
        projection = np.mean(self.image_array, axis=1)
    elif projection_mode == "Max":
        projection = np.max(self.image_array, axis=1)
    else:
        raise ValueError(
            f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
        )

    self.xy_estimator = DriftEstimator()
    self.xy_estimator.estimate(projection, apply=False, **kwargs)

    corrector = DriftCorrector()
    corrector.estimator_table = self.xy_estimator.estimator_table
    for z in range(self.image_array.shape[1]):
        self.image_array[:, z, :, :] = corrector.apply_correction(self.image_array[:, z, :, :])
Correct XY drift in the image stack using projection.
Args: projection_mode (str, optional): How the stack is projected along z before XY drift is estimated: "Mean" or "Max". Default is "Mean". **kwargs: Keyword arguments forwarded to the drift estimator.
Returns: None
Example: estimator.correct_xy_drift(projection_mode="Mean", **drift_params)
Note: This method corrects XY drift in the 3D image stack using projection-based drift estimation and correction.
def correct_z_drift(self, axis_mode="top", projection_mode="Mean", **kwargs):
    """
    Correct Z drift in `self.image_array` in place.

    The stack is projected along y ("top", giving a (t, z, x) side view) or
    along x ("left", giving a (t, z, y) side view). Drift is estimated on that
    side view, only the shift along its rows (the z direction) is kept, and
    the correction is applied to every y slice ("top") or x slice ("left").

    Args:
        axis_mode (str, optional): "top" or "left". Default is "top".
        projection_mode (str, optional): "Mean" or "Max" projection. Default is "Mean".
        **kwargs: Forwarded to `DriftEstimator.estimate`.

    Returns:
        None

    Raises:
        ValueError: If `axis_mode` or `projection_mode` is not an accepted value.

    Note:
        The lateral shift column is zeroed on the estimator's own table, so
        `self.z_estimator.estimator_table.drift_table` afterwards holds only the
        shift that was actually applied.
    """
    if axis_mode == "top":
        axis_idx = 2
    elif axis_mode == "left":
        axis_idx = 3
    else:
        raise ValueError(f"Not a valid axis mode: {axis_mode!r} (expected 'top' or 'left')")

    if projection_mode == "Mean":
        projection = np.mean(self.image_array, axis=axis_idx)
    elif projection_mode == "Max":
        projection = np.max(self.image_array, axis=axis_idx)
    else:
        raise ValueError(
            f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
        )

    self.z_estimator = DriftEstimator()
    self.z_estimator.estimate(projection, apply=False, **kwargs)

    corrector = DriftCorrector()
    corrector.estimator_table = self.z_estimator.estimator_table
    # In both side views z is axis 1 of the projection (the frame rows) and the
    # frame columns are a lateral axis whose drift belongs to correct_xy_drift,
    # so the column shift is discarded in BOTH modes (previously "left" zeroed
    # column 2 instead, throwing away the z shift). Column 1 is the one "top"
    # has always zeroed; assumed drift_table layout is [magnitude, x, y] —
    # confirm against DriftEstimatorTable.
    corrector.estimator_table.drift_table[:, 1] = 0

    for i in range(self.image_array.shape[axis_idx]):
        # (:, :, i) selects a y slice for "top"; (:, :, :, i) an x slice for "left".
        index = (slice(None),) * axis_idx + (i,)
        self.image_array[index] = corrector.apply_correction(self.image_array[index])
Correct Z drift in the image stack using projection.
Args: axis_mode (str, optional): The side view used for Z drift estimation: "top" projects along y (giving a z–x view) and "left" projects along x (giving a z–y view). Default is "top". projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean". **kwargs: Keyword arguments forwarded to the drift estimator.
Returns: None
Example: estimator.correct_z_drift(axis_mode="top", projection_mode="Mean", **drift_params)
Note: This method corrects Z drift in the 3D image stack using projection-based drift estimation and correction.
def correct_3d_drift(self, image_array, axis_mode="top", projection_mode="Mean", **kwargs):
    """
    Correct XY drift and then Z drift in a (t, z, y, x) image stack.

    Args:
        image_array (numpy.ndarray): Stack with shape (t, z, y, x). It is
            corrected in place and stored as `self.image_array`; corrected
            values are written back with the array's own dtype.
        axis_mode (str, optional): Side view for Z drift, "top" or "left". Default is "top".
        projection_mode (str, optional): "Mean" or "Max". Default is "Mean".
        **kwargs: Forwarded to both drift estimations.

    Returns:
        numpy.ndarray: The corrected stack (the same object as `image_array`).

    Raises:
        ValueError: If the stack is not 4D or a mode is invalid. Validation
            happens before any correction, so the input is untouched then
            (previously an invalid mode was only printed and the uncorrected
            stack was returned as if it had been corrected).
    """
    ndim = np.ndim(image_array)
    if ndim != 4:
        raise ValueError(f"image_array must have shape (t, z, y, x), got {ndim} dimension(s)")
    if projection_mode not in ("Mean", "Max"):
        raise ValueError(
            f"Not a valid projection mode: {projection_mode!r} (expected 'Mean' or 'Max')"
        )
    if axis_mode not in ("top", "left"):
        raise ValueError(f"Not a valid axis mode: {axis_mode!r} (expected 'top' or 'left')")

    self.image_array = image_array
    self.correct_xy_drift(projection_mode=projection_mode, **kwargs)
    self.correct_z_drift(axis_mode=axis_mode, projection_mode=projection_mode, **kwargs)
    return self.image_array
Correct both XY and Z drift in the 3D image stack.
Args: image_array (numpy.ndarray): The input 3D image stack with shape (t, z, y, x). axis_mode (str, optional): The side view used for Z drift correction: "top" projects along y and "left" projects along x. Default is "top". projection_mode (str, optional): The projection mode for drift correction. "Mean" or "Max" can be used. Default is "Mean". **kwargs: Keyword arguments forwarded to both drift estimations.
Returns: numpy.ndarray: The drift-corrected 3D image stack.
Example: corrected_image_stack = estimator.correct_3d_drift(image_stack, axis_mode="top", projection_mode="Mean", **drift_params)
Note: This method performs XY drift correction followed by Z drift correction and returns the corrected stack. The input array is modified in place, and the returned array is that same object.