class BlockDiagonalLinearOperator(LinearOperator):
    """Block-diagonal linear operator built from a list of operators.

    Represents ``diag(B_0, ..., B_{k-1})``: block ``B_i`` is applied to the
    ``i``-th contiguous slice of the input vector and its result fills the
    ``i``-th contiguous slice of the output vector. Blocks need not be
    square, so the operator shape is
    ``(sum of block row counts, sum of block column counts)``.

    Parameters
    ----------
    block_list : List[LinearOperator]
        Non-empty, flat sequence of diagonal blocks. Each block must expose a
        two-dimensional ``shape`` as well as ``dtype``, ``symmetric`` and
        ``T`` attributes, and must support ``block * vec``. The sequence is
        copied into a new list so that later changes to the caller's
        container cannot desynchronise the cached block sizes.
    **kwargs : Any
        Extra keyword arguments forwarded unchanged to ``LinearOperator``.

    Raises
    ------
    ValueError
        If ``block_list`` is not a flat sequence of objects with a
        two-dimensional ``shape``, or if it is empty.
    """

    def __init__(
        self,
        block_list: List[LinearOperator],
        **kwargs: Any,
    ):
        # Validate up front: the container must be iterable and every element
        # must have a `shape` of exactly two entries. ValueError covers shapes
        # with the wrong number of dimensions (e.g. a 1-D array).
        try:
            block_list = list(block_list)
            for block in block_list:
                _n_rows, _n_cols = block.shape
        except (TypeError, AttributeError, ValueError):
            MPI_RAISE_EXCEPTION(
                condition=True,
                exception=ValueError,
                message="The `block_list` must be a flat list of linearoperators",
            )
        # An empty list has no dtype to combine and no meaningful shape; fail
        # with a clear message rather than an obscure error further down.
        MPI_RAISE_EXCEPTION(
            condition=(len(block_list) == 0),
            exception=ValueError,
            message="The `block_list` must contain at least one linear operator",
        )

        self.__row_size = np.asarray(
            [block.shape[0] for block in block_list], dtype=int
        )
        self.__col_size = np.asarray(
            [block.shape[1] for block in block_list], dtype=int
        )
        nargin = int(self.__col_size.sum())
        nargout = int(self.__row_size.sum())
        # diag(B_0, ..., B_{k-1}) is symmetric iff every diagonal block is.
        symmetric = all(block.symmetric for block in block_list)
        dtype = np.result_type(*[block.dtype for block in block_list])
        self.__block_list = block_list

        # The forward product uses the blocks as given; the transpose product
        # applies each block's transpose, since diag(B_i)^T == diag(B_i^T).
        blocks_list_transposed = [block.T for block in block_list]
        matvec = partial(
            self._mult,
            block_list=self.block_list,
            dtype=dtype,
        )
        rmatvec = partial(
            self._mult,
            block_list=blocks_list_transposed,
            dtype=dtype,
        )
        super().__init__(
            nargin=nargin,
            nargout=nargout,
            symmetric=symmetric,
            matvec=matvec,
            rmatvec=rmatvec,
            dtype=dtype,
            **kwargs,
        )

    @property
    def block_list(self) -> List:
        """Diagonal blocks, in the order they appear on the diagonal."""
        return self.__block_list

    @property
    def num_blocks(self) -> int:
        """Number of diagonal blocks."""
        return len(self.block_list)

    @property
    def row_size(self) -> np.ndarray:
        """Integer array holding the row count of each block."""
        return self.__row_size

    @property
    def col_size(self) -> np.ndarray:
        """Integer array holding the column count of each block."""
        return self.__col_size

    def __getitem__(self, idx):
        """Return one block, or a sub-operator for a slice of blocks.

        An integer index returns the block object itself. A slice returns a
        new ``BlockDiagonalLinearOperator`` over the selected blocks; keyword
        arguments given to this instance's constructor are not forwarded.
        """
        selected = self.block_list[idx]
        if isinstance(idx, slice):
            return BlockDiagonalLinearOperator(
                block_list=selected,
            )
        return selected

    def _mult(self, vec: np.ndarray, block_list: List, dtype) -> np.ndarray:
        """Apply ``diag(block_list)`` to ``vec`` block by block.

        Parameters
        ----------
        vec : np.ndarray
            Input vector; its length must equal the total number of block
            columns.
        block_list : List
            Diagonal blocks to apply (the original blocks for ``matvec``,
            their transposes for ``rmatvec``).
        dtype : numpy dtype
            Dtype of the result. ``vec`` is cast to it (with a
            ``TypeChangeWarning`` issued on rank 0) when the dtypes differ.

        Returns
        -------
        np.ndarray
            1-D array whose length is the total number of block rows.

        Raises
        ------
        ValueError
            Via ``MPI_RAISE_EXCEPTION`` when ``len(vec)`` does not match the
            total number of block columns.
        """
        nrows = sum(block.shape[0] for block in block_list)
        ncols = sum(block.shape[1] for block in block_list)
        MPI_RAISE_EXCEPTION(
            condition=(len(vec) != ncols),
            exception=ValueError,
            message=f"Dimensions of `vec` is not compatible with the dimensions of this `BlockDiagonalLinearOperator` instance.\nShape of `BlockDiagonalLinearOperator` instance: ({nrows}, {ncols})\nShape of `vec`: {vec.shape}",
        )
        if vec.dtype != dtype:
            # Warn from rank 0 only (presumably to avoid one duplicate
            # warning per MPI process).
            if MPI_UTILS.rank == 0:
                warnings.warn(
                    f"dtype of `vec` will be changed to {dtype}",
                    TypeChangeWarning,
                )
            vec = vec.astype(dtype=dtype, copy=False)

        prod = np.zeros(nrows, dtype=dtype)
        row_start = 0
        col_start = 0
        for block in block_list:
            # Each block maps its column slice of `vec` onto its row slice
            # of the output; slices advance contiguously along the diagonal.
            row_end = row_start + block.shape[0]
            col_end = col_start + block.shape[1]
            prod[row_start:row_end] = block * vec[col_start:col_end]
            row_start, col_start = row_end, col_end
        return prod