nanopyx.methods.drift_alignment.estimator

  1import numpy as np
  2from math import sqrt
  3from scipy.interpolate import interp1d
  4
  5from .estimator_table import DriftEstimatorTable
  6from .corrector import DriftCorrector
  7from ...core.analysis.estimate_shift import GetMaxOptimizer
  8from ...core.utils.timeit import timeit
  9from ...core.analysis.ccm import calculate_ccm
 10from ...core.analysis.rcc import rcc
 11from ...core.analysis._le_drift_calculator import (
 12    DriftEstimator as leDriftEstimator,
 13)
 14
 15
 16class DriftEstimator(object):
 17    """
 18    Drift estimator class for estimating and correcting drift in image stacks.
 19
 20    This class provides methods for estimating and correcting drift in image stacks using cross-correlation.
 21
 22    Args:
 23        None
 24
 25    Attributes:
 26        estimator_table (DriftEstimatorTable): A table of parameters for drift estimation and correction.
 27        cross_correlation_map (numpy.ndarray): The cross-correlation map calculated during drift estimation.
 28        drift_xy (numpy.ndarray): The drift magnitude at each time point.
 29        drift_x (numpy.ndarray): The drift in the X direction at each time point.
 30        drift_y (numpy.ndarray): The drift in the Y direction at each time point.
 31
 32    Methods:
 33        __init__(): Initialize the `DriftEstimator` object.
 34
 35        estimate(image_array, **kwargs): Estimate and correct drift in an image stack.
 36
 37        compute_temporal_averaging(image_arr): Compute temporal averaging of image frames.
 38
 39        get_shift_from_ccm_slice(slice_index): Get the drift shift from a slice of the cross-correlation map.
 40
 41        get_shifts_from_ccm(): Get the drift shifts from the entire cross-correlation map.
 42
 43        create_drift_table(): Create a table of drift values.
 44
 45        save_drift_table(save_as_npy=True, path=None): Save the drift table to a file.
 46
 47        set_estimator_params(**kwargs): Set parameters for drift estimation and correction.
 48
 49    Example:
 50        estimator = DriftEstimator()
 51        drift_params = {
 52            "time_averaging": 2,
 53            "max_expected_drift": 5,
 54            "shift_calc_method": "rcc",
 55            "ref_option": 0,
 56            "apply": True,
 57        }
 58        drift_corrected_image = estimator.estimate(image_stack, **drift_params)
 59
 60    Note:
 61        The `DriftEstimator` class is used for estimating and correcting drift in image stacks.
 62        It provides methods for estimating drift using cross-correlation and applying drift correction to an image stack.
 63    """
 64
 65    def __init__(self, verbose=True):
 66        """
 67        Initialize the `DriftEstimator` object.
 68
 69        Args:
 70            None
 71
 72        Returns:
 73            None
 74
 75        Example:
 76            estimator = DriftEstimator()
 77        """
 78        self.verbose = verbose
 79        self.estimator_table = DriftEstimatorTable()
 80        self.cross_correlation_map = None
 81        self.drift_xy = None
 82        self.drift_x = None
 83        self.drift_y = None
 84
 85    # @timeit
 86    def estimate(self, image_array, **kwargs):
 87        """
 88        Estimate and correct drift in an image stack.
 89
 90        Args:
 91            image_array (numpy.ndarray): The input image stack with shape [n_slices, height, width].
 92            **kwargs: Keyword arguments for setting drift estimation parameters.
 93
 94        Returns:
 95            numpy.ndarray or None: The drift-corrected image stack if `apply` is True, else None.
 96
 97        Example:
 98            drift_params = {
 99                "time_averaging": 2,
100                "max_expected_drift": 5,
101                "ref_option": 0,
102                "apply": True,
103            }
104            drift_corrected_image = estimator.estimate(image_stack, **drift_params)
105
106        Note:
107            This method estimates and corrects drift in an image stack using specified parameters.
108        """
109        self.set_estimator_params(**kwargs)
110
111        n_slices = image_array.shape[0]
112
113        # x0, y0, x1, y1 correspond to the exact coordinates of the roi to be used or full image dims and should be a tuple
114        if (
115            self.estimator_table.params["use_roi"]
116            and self.estimator_table.params["roi"] is not None
117        ):  # crops image to roi
118            print(
119                self.estimator_table.params["use_roi"],
120                self.estimator_table.params["roi"],
121            )
122            x0, y0, x1, y1 = tuple(self.estimator_table.params["roi"])
123            image_arr = image_array[:, y0 : y1 + 1, x0 : x1 + 1]
124        else:
125            image_arr = image_array
126
127        estimator = leDriftEstimator(verbose=self.verbose)
128        self.estimator_table.drift_table = estimator.run(
129            np.asarray(image_arr, dtype=np.float32),
130            time_averaging=self.estimator_table.params["time_averaging"],
131            max_drift=self.estimator_table.params["max_expected_drift"],
132            ref_option=self.estimator_table.params["ref_option"],
133        )
134
135        if self.estimator_table.params["apply"]:
136            drift_corrector = DriftCorrector()
137            drift_corrector.estimator_table = self.estimator_table
138            tmp = drift_corrector.apply_correction(image_array)
139            return tmp
140        else:
141            return None
142
143    def save_drift_table(self, save_as_npy=True, path=None):
144        """
145        Save the drift table to a file.
146
147        Args:
148            save_as_npy (bool, optional): Whether to save the table as a NumPy binary file. Default is True.
149            path (str, optional): The file path to save the table. If not provided, a user input prompt will be used.
150
151        Returns:
152            None
153
154        Example:
155            self.save_drift_table(save_as_npy=True, path="drift_table.npy")
156
157        Note:
158            This method allows saving the drift table to a file in either NumPy binary or CSV format.
159        """
160        if save_as_npy:
161            self.estimator_table.export_npy(path=path)
162        else:
163            self.estimator_table.export_csv(path=path)
164
165    def set_estimator_params(self, **kwargs):
166        """
167        Set parameters for drift estimation and correction.
168
169        Args:
170            **kwargs: Keyword arguments for setting drift estimation parameters.
171
172        Returns:
173            None
174
175        Example:
176            params = {
177                "time_averaging": 2,
178                "max_expected_drift": 5,
179                "shift_calc_method": "rcc",
180                "ref_option": 0,
181                "apply": True,
182            }
183            self.set_estimator_params(**params)
184
185        Note:
186            This method allows setting parameters for drift estimation and correction.
187        """
188        self.estimator_table.set_params(**kwargs)
class DriftEstimator:
 17class DriftEstimator(object):
 18    """
 19    Drift estimator class for estimating and correcting drift in image stacks.
 20
 21    This class provides methods for estimating and correcting drift in image stacks using cross-correlation.
 22
 23    Args:
 24        None
 25
 26    Attributes:
 27        estimator_table (DriftEstimatorTable): A table of parameters for drift estimation and correction.
 28        cross_correlation_map (numpy.ndarray): The cross-correlation map calculated during drift estimation.
 29        drift_xy (numpy.ndarray): The drift magnitude at each time point.
 30        drift_x (numpy.ndarray): The drift in the X direction at each time point.
 31        drift_y (numpy.ndarray): The drift in the Y direction at each time point.
 32
 33    Methods:
 34        __init__(): Initialize the `DriftEstimator` object.
 35
 36        estimate(image_array, **kwargs): Estimate and correct drift in an image stack.
 37
 38        compute_temporal_averaging(image_arr): Compute temporal averaging of image frames.
 39
 40        get_shift_from_ccm_slice(slice_index): Get the drift shift from a slice of the cross-correlation map.
 41
 42        get_shifts_from_ccm(): Get the drift shifts from the entire cross-correlation map.
 43
 44        create_drift_table(): Create a table of drift values.
 45
 46        save_drift_table(save_as_npy=True, path=None): Save the drift table to a file.
 47
 48        set_estimator_params(**kwargs): Set parameters for drift estimation and correction.
 49
 50    Example:
 51        estimator = DriftEstimator()
 52        drift_params = {
 53            "time_averaging": 2,
 54            "max_expected_drift": 5,
 55            "shift_calc_method": "rcc",
 56            "ref_option": 0,
 57            "apply": True,
 58        }
 59        drift_corrected_image = estimator.estimate(image_stack, **drift_params)
 60
 61    Note:
 62        The `DriftEstimator` class is used for estimating and correcting drift in image stacks.
 63        It provides methods for estimating drift using cross-correlation and applying drift correction to an image stack.
 64    """
 65
 66    def __init__(self, verbose=True):
 67        """
 68        Initialize the `DriftEstimator` object.
 69
 70        Args:
 71            None
 72
 73        Returns:
 74            None
 75
 76        Example:
 77            estimator = DriftEstimator()
 78        """
 79        self.verbose = verbose
 80        self.estimator_table = DriftEstimatorTable()
 81        self.cross_correlation_map = None
 82        self.drift_xy = None
 83        self.drift_x = None
 84        self.drift_y = None
 85
 86    # @timeit
 87    def estimate(self, image_array, **kwargs):
 88        """
 89        Estimate and correct drift in an image stack.
 90
 91        Args:
 92            image_array (numpy.ndarray): The input image stack with shape [n_slices, height, width].
 93            **kwargs: Keyword arguments for setting drift estimation parameters.
 94
 95        Returns:
 96            numpy.ndarray or None: The drift-corrected image stack if `apply` is True, else None.
 97
 98        Example:
 99            drift_params = {
100                "time_averaging": 2,
101                "max_expected_drift": 5,
102                "ref_option": 0,
103                "apply": True,
104            }
105            drift_corrected_image = estimator.estimate(image_stack, **drift_params)
106
107        Note:
108            This method estimates and corrects drift in an image stack using specified parameters.
109        """
110        self.set_estimator_params(**kwargs)
111
112        n_slices = image_array.shape[0]
113
114        # x0, y0, x1, y1 correspond to the exact coordinates of the roi to be used or full image dims and should be a tuple
115        if (
116            self.estimator_table.params["use_roi"]
117            and self.estimator_table.params["roi"] is not None
118        ):  # crops image to roi
119            print(
120                self.estimator_table.params["use_roi"],
121                self.estimator_table.params["roi"],
122            )
123            x0, y0, x1, y1 = tuple(self.estimator_table.params["roi"])
124            image_arr = image_array[:, y0 : y1 + 1, x0 : x1 + 1]
125        else:
126            image_arr = image_array
127
128        estimator = leDriftEstimator(verbose=self.verbose)
129        self.estimator_table.drift_table = estimator.run(
130            np.asarray(image_arr, dtype=np.float32),
131            time_averaging=self.estimator_table.params["time_averaging"],
132            max_drift=self.estimator_table.params["max_expected_drift"],
133            ref_option=self.estimator_table.params["ref_option"],
134        )
135
136        if self.estimator_table.params["apply"]:
137            drift_corrector = DriftCorrector()
138            drift_corrector.estimator_table = self.estimator_table
139            tmp = drift_corrector.apply_correction(image_array)
140            return tmp
141        else:
142            return None
143
144    def save_drift_table(self, save_as_npy=True, path=None):
145        """
146        Save the drift table to a file.
147
148        Args:
149            save_as_npy (bool, optional): Whether to save the table as a NumPy binary file. Default is True.
150            path (str, optional): The file path to save the table. If not provided, a user input prompt will be used.
151
152        Returns:
153            None
154
155        Example:
156            self.save_drift_table(save_as_npy=True, path="drift_table.npy")
157
158        Note:
159            This method allows saving the drift table to a file in either NumPy binary or CSV format.
160        """
161        if save_as_npy:
162            self.estimator_table.export_npy(path=path)
163        else:
164            self.estimator_table.export_csv(path=path)
165
166    def set_estimator_params(self, **kwargs):
167        """
168        Set parameters for drift estimation and correction.
169
170        Args:
171            **kwargs: Keyword arguments for setting drift estimation parameters.
172
173        Returns:
174            None
175
176        Example:
177            params = {
178                "time_averaging": 2,
179                "max_expected_drift": 5,
180                "shift_calc_method": "rcc",
181                "ref_option": 0,
182                "apply": True,
183            }
184            self.set_estimator_params(**params)
185
186        Note:
187            This method allows setting parameters for drift estimation and correction.
188        """
189        self.estimator_table.set_params(**kwargs)

Drift estimator class for estimating and correcting drift in image stacks.

This class provides methods for estimating and correcting drift in image stacks using cross-correlation.

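Args: verbose (bool, optional): Verbosity flag stored on the instance and forwarded to the underlying drift calculator. Default is True.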

Attributes: estimator_table (DriftEstimatorTable): A table of parameters for drift estimation and correction. cross_correlation_map (numpy.ndarray): The cross-correlation map calculated during drift estimation. drift_xy (numpy.ndarray): The drift magnitude at each time point. drift_x (numpy.ndarray): The drift in the X direction at each time point. drift_y (numpy.ndarray): The drift in the Y direction at each time point.

Methods: __init__(): Initialize the DriftEstimator object.

estimate(image_array, **kwargs): Estimate and correct drift in an image stack.

compute_temporal_averaging(image_arr): Compute temporal averaging of image frames.

get_shift_from_ccm_slice(slice_index): Get the drift shift from a slice of the cross-correlation map.

get_shifts_from_ccm(): Get the drift shifts from the entire cross-correlation map.

create_drift_table(): Create a table of drift values.

save_drift_table(save_as_npy=True, path=None): Save the drift table to a file.

set_estimator_params(**kwargs): Set parameters for drift estimation and correction.

Example: estimator = DriftEstimator() drift_params = { "time_averaging": 2, "max_expected_drift": 5, "shift_calc_method": "rcc", "ref_option": 0, "apply": True, } drift_corrected_image = estimator.estimate(image_stack, **drift_params)

Note: The DriftEstimator class is used for estimating and correcting drift in image stacks. It provides methods for estimating drift using cross-correlation and applying drift correction to an image stack.

DriftEstimator(verbose=True)
66    def __init__(self, verbose=True):
67        """
68        Initialize the `DriftEstimator` object.
69
70        Args:
71            None
72
73        Returns:
74            None
75
76        Example:
77            estimator = DriftEstimator()
78        """
79        self.verbose = verbose
80        self.estimator_table = DriftEstimatorTable()
81        self.cross_correlation_map = None
82        self.drift_xy = None
83        self.drift_x = None
84        self.drift_y = None

Initialize the DriftEstimator object.

Args: verbose (bool, optional): Verbosity flag stored on the instance. Default is True.

Returns: None

Example: estimator = DriftEstimator()

verbose
estimator_table
cross_correlation_map
drift_xy
drift_x
drift_y
def estimate(self, image_array, **kwargs):
 87    def estimate(self, image_array, **kwargs):
 88        """
 89        Estimate and correct drift in an image stack.
 90
 91        Args:
 92            image_array (numpy.ndarray): The input image stack with shape [n_slices, height, width].
 93            **kwargs: Keyword arguments for setting drift estimation parameters.
 94
 95        Returns:
 96            numpy.ndarray or None: The drift-corrected image stack if `apply` is True, else None.
 97
 98        Example:
 99            drift_params = {
100                "time_averaging": 2,
101                "max_expected_drift": 5,
102                "ref_option": 0,
103                "apply": True,
104            }
105            drift_corrected_image = estimator.estimate(image_stack, **drift_params)
106
107        Note:
108            This method estimates and corrects drift in an image stack using specified parameters.
109        """
110        self.set_estimator_params(**kwargs)
111
112        n_slices = image_array.shape[0]
113
114        # x0, y0, x1, y1 correspond to the exact coordinates of the roi to be used or full image dims and should be a tuple
115        if (
116            self.estimator_table.params["use_roi"]
117            and self.estimator_table.params["roi"] is not None
118        ):  # crops image to roi
119            print(
120                self.estimator_table.params["use_roi"],
121                self.estimator_table.params["roi"],
122            )
123            x0, y0, x1, y1 = tuple(self.estimator_table.params["roi"])
124            image_arr = image_array[:, y0 : y1 + 1, x0 : x1 + 1]
125        else:
126            image_arr = image_array
127
128        estimator = leDriftEstimator(verbose=self.verbose)
129        self.estimator_table.drift_table = estimator.run(
130            np.asarray(image_arr, dtype=np.float32),
131            time_averaging=self.estimator_table.params["time_averaging"],
132            max_drift=self.estimator_table.params["max_expected_drift"],
133            ref_option=self.estimator_table.params["ref_option"],
134        )
135
136        if self.estimator_table.params["apply"]:
137            drift_corrector = DriftCorrector()
138            drift_corrector.estimator_table = self.estimator_table
139            tmp = drift_corrector.apply_correction(image_array)
140            return tmp
141        else:
142            return None

Estimate and correct drift in an image stack.

Args: image_array (numpy.ndarray): The input image stack with shape [n_slices, height, width]. **kwargs: Keyword arguments for setting drift estimation parameters.

Returns: numpy.ndarray or None: The drift-corrected image stack if apply is True, else None.

Example: drift_params = { "time_averaging": 2, "max_expected_drift": 5, "ref_option": 0, "apply": True, } drift_corrected_image = estimator.estimate(image_stack, **drift_params)

Note: This method estimates and corrects drift in an image stack using specified parameters.

def save_drift_table(self, save_as_npy=True, path=None):
144    def save_drift_table(self, save_as_npy=True, path=None):
145        """
146        Save the drift table to a file.
147
148        Args:
149            save_as_npy (bool, optional): Whether to save the table as a NumPy binary file. Default is True.
150            path (str, optional): The file path to save the table. If not provided, a user input prompt will be used.
151
152        Returns:
153            None
154
155        Example:
156            self.save_drift_table(save_as_npy=True, path="drift_table.npy")
157
158        Note:
159            This method allows saving the drift table to a file in either NumPy binary or CSV format.
160        """
161        if save_as_npy:
162            self.estimator_table.export_npy(path=path)
163        else:
164            self.estimator_table.export_csv(path=path)

Save the drift table to a file.

Args: save_as_npy (bool, optional): Whether to save the table as a NumPy binary file. Default is True. path (str, optional): The file path to save the table. If not provided, a user input prompt will be used.

Returns: None

Example: self.save_drift_table(save_as_npy=True, path="drift_table.npy")

Note: This method allows saving the drift table to a file in either NumPy binary or CSV format.

def set_estimator_params(self, **kwargs):
166    def set_estimator_params(self, **kwargs):
167        """
168        Set parameters for drift estimation and correction.
169
170        Args:
171            **kwargs: Keyword arguments for setting drift estimation parameters.
172
173        Returns:
174            None
175
176        Example:
177            params = {
178                "time_averaging": 2,
179                "max_expected_drift": 5,
180                "shift_calc_method": "rcc",
181                "ref_option": 0,
182                "apply": True,
183            }
184            self.set_estimator_params(**params)
185
186        Note:
187            This method allows setting parameters for drift estimation and correction.
188        """
189        self.estimator_table.set_params(**kwargs)

Set parameters for drift estimation and correction.

Args: **kwargs: Keyword arguments for setting drift estimation parameters.

Returns: None

Example: params = { "time_averaging": 2, "max_expected_drift": 5, "shift_calc_method": "rcc", "ref_option": 0, "apply": True, } self.set_estimator_params(**params)

Note: This method allows setting parameters for drift estimation and correction.