Skip to content

brahmap.base.BlockDiagonalLinearOperator

Bases: LinearOperator

Base class for a block-diagonal linear operator

Parameters:

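| Name | Type | Description | Default |
|------|------|-------------|---------|
| `block_list` | `List[LinearOperator]` | The linear operators placed, in order, along the diagonal. Each must expose a two-element `shape`. | *required* |
| `**kwargs` | `Any` | Extra keyword arguments forwarded to the `LinearOperator` constructor. | `{}` |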
Source code in brahmap/base/blkop.py
class BlockDiagonalLinearOperator(LinearOperator):
    """Base class for a block-diagonal linear operator

    The operator is applied by splitting the input vector into consecutive
    chunks, one per block (each chunk is as long as that block has columns).
    Each block is applied to its chunk, and the results are stacked in order.
    The transpose does the same with the transposed blocks.

    Parameters
    ----------
    block_list : List[LinearOperator]
        The operators placed, in order, along the diagonal. Any iterable is
        accepted and is stored as a list. Each element must expose a
        two-element ``shape`` as well as ``symmetric``, ``dtype`` and ``T``
        attributes.
    **kwargs: Any
        Extra keyword arguments forwarded to ``LinearOperator.__init__``.

    Raises
    ------
    ValueError
        Reported through ``MPI_RAISE_EXCEPTION`` when ``block_list`` is not a
        flat iterable of operators with two-element shapes, or when it is
        empty.
    """

    def __init__(
        self,
        block_list: List[LinearOperator],
        **kwargs: Any,
    ):
        # Materialise the input once, so a generator is not exhausted by the
        # validation loop, then check that every element has a 2-D `shape`.
        # A ValueError here comes from a `shape` of the wrong length.
        try:
            block_list = list(block_list)
            for block in block_list:
                __, __ = block.shape
        except (TypeError, AttributeError, ValueError):
            MPI_RAISE_EXCEPTION(
                condition=True,
                exception=ValueError,
                message="The `block_list` must be a flat list of linearoperators",
            )

        # Without this check an empty list fails later with an unrelated
        # error from `np.result_type()`.
        MPI_RAISE_EXCEPTION(
            condition=(len(block_list) == 0),
            exception=ValueError,
            message="The `block_list` must contain at least one linear operator",
        )

        # Per-block row and column counts, in diagonal order.
        self.__row_size = np.asarray(
            [block.shape[0] for block in block_list], dtype=int
        )
        self.__col_size = np.asarray(
            [block.shape[-1] for block in block_list], dtype=int
        )

        nargin = int(self.__col_size.sum())
        nargout = int(self.__row_size.sum())
        # Flagged symmetric only when every block is flagged symmetric.
        symmetric = all(block.symmetric for block in block_list)
        # Common dtype that can hold the output of every block.
        dtype = np.result_type(*[block.dtype for block in block_list])

        self.__block_list = block_list

        # The transpose applies each transposed block to its own chunk.
        blocks_list_transposed = [block.T for block in block_list]

        matvec = partial(
            self._mult,
            block_list=self.block_list,
            dtype=dtype,
        )
        rmatvec = partial(
            self._mult,
            block_list=blocks_list_transposed,
            dtype=dtype,
        )

        super().__init__(
            nargin=nargin,
            nargout=nargout,
            symmetric=symmetric,
            matvec=matvec,
            rmatvec=rmatvec,
            dtype=dtype,
            **kwargs,
        )

    @property
    def block_list(self) -> List:
        """The diagonal blocks, in order."""
        return self.__block_list

    @property
    def num_blocks(self) -> int:
        """Number of diagonal blocks."""
        return len(self.block_list)

    @property
    def row_size(self) -> np.ndarray:
        """Integer array holding the number of rows of each block."""
        return self.__row_size

    @property
    def col_size(self) -> np.ndarray:
        """Integer array holding the number of columns of each block."""
        return self.__col_size

    def __getitem__(self, idx):
        """Return a block, or a new block-diagonal operator for a slice.

        An integer index returns the block itself. A slice returns a new
        ``BlockDiagonalLinearOperator`` built from the selected blocks. Extra
        keyword arguments given to the original constructor are not carried
        over to the new operator.
        """
        block_range = self.block_list[idx]
        if isinstance(idx, slice):
            return BlockDiagonalLinearOperator(
                block_list=block_range,
            )
        return block_range

    def _mult(self, vec: np.ndarray, block_list: List, dtype) -> np.ndarray:
        """Apply ``block_list`` block-diagonally to ``vec``.

        ``vec`` is split into consecutive chunks whose lengths are the
        blocks' column counts. Each block multiplies its chunk, and the
        results fill consecutive row ranges of the output.

        Parameters
        ----------
        vec : np.ndarray
            Input vector; its length must equal the total column count.
        block_list : List
            Blocks to apply (the original blocks for ``matvec``, the
            transposed blocks for ``rmatvec``).
        dtype
            dtype of the output; ``vec`` is cast to it if needed.

        Returns
        -------
        np.ndarray
            Vector whose length is the total row count of ``block_list``.
        """
        nrows = sum(block.shape[0] for block in block_list)
        ncols = sum(block.shape[1] for block in block_list)
        MPI_RAISE_EXCEPTION(
            condition=(len(vec) != ncols),
            exception=ValueError,
            message=f"Dimensions of `vec` is not compatible with the dimensions of this `BlockDiagonalLinearOperator` instance.\nShape of `BlockDiagonalLinearOperator` instance: ({nrows}, {ncols})\nShape of `vec`: {vec.shape}",
        )

        if vec.dtype != dtype:
            # Warn only once (on rank 0) rather than on every MPI process.
            if MPI_UTILS.rank == 0:
                warnings.warn(
                    f"dtype of `vec` will be changed to {dtype}",
                    TypeChangeWarning,
                )
            vec = vec.astype(dtype=dtype, copy=False)

        prod = np.zeros(nrows, dtype=dtype)

        row_start = 0
        col_start = 0
        for block in block_list:
            row_end = row_start + block.shape[0]
            col_end = col_start + block.shape[1]

            prod[row_start:row_end] = block * vec[col_start:col_end]

            row_start, col_start = row_end, col_end

        return prod