Skip to content

brahmap.core.BlockDiagonalPreconditionerLO

Bases: LinearOperator

Standard preconditioner defined as:

\[M_{BD}=( P^T diag(N^{-1}) P)^{-1}\]

where \(P\) is the pointing matrix (see PointingLO). This inverse operator can be computed easily given the structure of the matrix \(P\).

Parameters:

Name Type Description Default
processed_samples ProcessTimeSamples

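Object holding the precomputed weighted quantities (weighted counts, weighted sine/cosine sums, inverse determinants), the number of pixels and the solver type from which the preconditioner is built.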

required
solver_type Union[None, SolverType]

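Solver type of the preconditioner (1: I, 2: QU, 3: IQU); it must be lower than or equal to the solver_type of processed_samples. If None, the solver_type of processed_samples is used, by default None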

None
Source code in brahmap/core/linearoperators.py
class BlockDiagonalPreconditionerLO(LinearOperator):
    r"""
    Standard preconditioner defined as:

    $$M_{BD}=( P^T diag(N^{-1}) P)^{-1}$$

    where $P$ is the *pointing matrix* (see `PointingLO`).
    This inverse operator can be computed easily given the structure of the
    matrix $P$; it is applied using quantities precomputed in
    `processed_samples`.

    Parameters
    ----------
    processed_samples : ProcessTimeSamples
        Object holding the precomputed weighted quantities
        (``weighted_counts``, ``weighted_sin_sq``, ``weighted_cos_sq``,
        ``weighted_sincos``, ``weighted_sin``, ``weighted_cos``,
        ``one_over_determinant``), the number of pixels ``new_npix``, the
        floating point dtype ``dtype_float`` and its own ``solver_type``.
    solver_type : Union[None, SolverType], optional
        Solver type of the preconditioner (1: I, 2: QU, 3: IQU). It must be
        lower than or equal to the ``solver_type`` of ``processed_samples``.
        If None, the ``solver_type`` of ``processed_samples`` is used, by
        default None

    Raises
    ------
    ValueError
        If ``solver_type`` is larger than the ``solver_type`` of
        ``processed_samples``; or, when the operator is applied, if the input
        vector length does not match the operator size.
    """

    def __init__(
        self,
        processed_samples: ProcessTimeSamples,
        solver_type: Union[None, SolverType] = None,
    ):
        if solver_type is None:
            self.__solver_type = processed_samples.solver_type
        else:
            # A preconditioner cannot ask for more Stokes components than the
            # processed samples provide weights for.
            MPI_RAISE_EXCEPTION(
                condition=(int(processed_samples.solver_type) < int(solver_type)),
                exception=ValueError,
                message="`solver_type` must be lower than or equal to the "
                "`solver_type` of `processed_samples` object",
            )
            self.__solver_type = solver_type

        self.new_npix = processed_samples.new_npix
        # One entry per Stokes component per pixel.
        self.size = processed_samples.new_npix * self.solver_type

        # Keep references only to the weights required by the chosen solver.
        if self.solver_type == 1:
            self.weighted_counts = processed_samples.weighted_counts
        else:
            self.weighted_sin_sq = processed_samples.weighted_sin_sq
            self.weighted_cos_sq = processed_samples.weighted_cos_sq
            self.weighted_sincos = processed_samples.weighted_sincos
            self.one_over_determinant = processed_samples.one_over_determinant
            if self.solver_type == 3:
                self.weighted_counts = processed_samples.weighted_counts
                self.weighted_sin = processed_samples.weighted_sin
                self.weighted_cos = processed_samples.weighted_cos

        # Select the matrix-vector product for the solver type; any value
        # other than 1 or 2 falls through to the IQU product.
        if self.solver_type == 1:
            matvec = self._mult_I
        elif self.solver_type == 2:
            matvec = self._mult_QU
        else:
            matvec = self._mult_IQU

        super(BlockDiagonalPreconditionerLO, self).__init__(
            nargin=self.size,
            nargout=self.size,
            symmetric=True,
            matvec=matvec,
            dtype=processed_samples.dtype_float,
        )

    def _validate_vec(self, vec: np.ndarray) -> np.ndarray:
        """Check that `vec` matches the operator size and cast it to
        `self.dtype` when needed.

        A `TypeChangeWarning` is emitted (on rank 0 only) when a cast happens.
        Raises ValueError if ``len(vec)`` differs from ``self.size``.
        Returns `vec`, possibly converted to `self.dtype`.
        """
        MPI_RAISE_EXCEPTION(
            condition=(len(vec) != self.size),
            exception=ValueError,
            message=f"Dimensions of `vec` are not compatible with the dimension of this `BlockDiagonalPreconditionerLO` instance.\nShape of `BlockDiagonalPreconditionerLO` instance: {self.shape}\nShape of `vec`: {vec.shape}",
        )

        if vec.dtype != self.dtype:
            if MPI_UTILS.rank == 0:
                warnings.warn(
                    f"dtype of `vec` will be changed to {self.dtype}",
                    TypeChangeWarning,
                )
            vec = vec.astype(dtype=self.dtype, copy=False)

        return vec

    def _mult_I(self, vec: np.ndarray):
        r"""
        Action of :math:`y=( P^T diag(N^{-1}) P)^{-1} x` for the I-only
        solver, where :math:`x` is an array of length :math:`n_{pix}`.

        The product reduces to an element-wise division by
        ``weighted_counts``.
        """
        vec = self._validate_vec(vec)
        return vec / self.weighted_counts

    def _mult_QU(self, vec: np.ndarray):
        r"""
        Action of :math:`y=( P^T diag(N^{-1}) P)^{-1} x` for the QU solver,
        where :math:`x` is an array of length :math:`2 n_{pix}`.
        """
        vec = self._validate_vec(vec)

        # The compiled kernel writes the result into `prod` in place.
        prod = np.zeros(self.size, dtype=self.dtype)

        BlkDiagPrecondLO_tools.BDPLO_mult_QU(
            new_npix=self.new_npix,
            weighted_sin_sq=self.weighted_sin_sq,
            weighted_cos_sq=self.weighted_cos_sq,
            weighted_sincos=self.weighted_sincos,
            one_over_determinant=self.one_over_determinant,
            vec=vec,
            prod=prod,
        )

        return prod

    def _mult_IQU(self, vec: np.ndarray):
        r"""
        Action of :math:`y=( P^T diag(N^{-1}) P)^{-1} x` for the IQU solver,
        where :math:`x` is an array of length :math:`3 n_{pix}`.
        """
        vec = self._validate_vec(vec)

        # The compiled kernel writes the result into `prod` in place.
        prod = np.zeros(self.size, dtype=self.dtype)

        BlkDiagPrecondLO_tools.BDPLO_mult_IQU(
            new_npix=self.new_npix,
            weighted_counts=self.weighted_counts,
            weighted_sin_sq=self.weighted_sin_sq,
            weighted_cos_sq=self.weighted_cos_sq,
            weighted_sincos=self.weighted_sincos,
            weighted_sin=self.weighted_sin,
            weighted_cos=self.weighted_cos,
            one_over_determinant=self.one_over_determinant,
            vec=vec,
            prod=prod,
        )

        return prod

    @property
    def solver_type(self):
        """Solver type used by this preconditioner (1: I, 2: QU, 3: IQU)."""
        return self.__solver_type