class BlockLinearOperator(LinearOperator):
    """
    A linear operator defined by blocks. Each block must be a linear operator.

    `blocks` should be a list of lists describing the blocks row-wise.
    If there is only one block row, it should be specified as
    `[[b1, b2, ..., bn]]`, not as `[b1, b2, ..., bn]`.

    If the overall linear operator is symmetric, only its upper triangle
    need be specified, e.g., `[[A,B,C], [D,E], [F]]`, and the blocks on the
    diagonal must be square and symmetric (only their `symmetric` attribute
    is checked here). The strict lower triangle is filled in with the
    transposes (`.T`) of the corresponding upper-triangle blocks.

    Parameters
    ----------
    blocks : List[List[LinearOperator]]
        Row-wise nested list of blocks. The block grid must be rectangular
        (after the lower triangle is filled in, in symmetric mode), all
        blocks in a block row must have the same number of rows, and all
        blocks in a block column the same number of columns. The rows are
        copied, so the caller's lists are never modified.
    symmetric : bool, optional
        If True, `blocks` holds only the upper triangle of a symmetric block
        operator: row `i` must contain blocks `(i, i), ..., (i, n-1)`.
        By default False.
    **kwargs : Any
        Forwarded to `LinearOperator.__init__`. A `logger` entry, if
        present, is also used for debug messages during construction.

    Raises
    ------
    ValueError
        If `blocks` is not a nested list of operators, has an empty row or
        no rows, or (symmetric mode) a diagonal block is not symmetric.
    ShapeError
        If the block shapes or the block-grid layout are inconsistent.
    """

    def __init__(
        self,
        blocks: List[List[LinearOperator]],
        symmetric: bool = False,
        **kwargs: Any,
    ):
        # Copy every row up front: the symmetric branch inserts into rows and
        # the caller's nested lists must not be modified. Touching ``.shape``
        # rejects anything that is not a nested list of operators.
        try:
            rows = [list(block_row) for block_row in blocks]
            for block_row in rows:
                for block in block_row:
                    _ = block.shape
        except (TypeError, AttributeError) as exc:
            raise ValueError(
                "blocks should be a nested list of operators"
            ) from exc
        if not rows or not all(rows):
            raise ValueError(
                "blocks must contain at least one block in every row"
            )

        if symmetric:
            nrow = len(rows)
            for i, block_row in enumerate(rows):
                # Upper-triangle row i holds blocks (i, i), ..., (i, nrow - 1);
                # anything else would shift columns once transposes are added.
                if len(block_row) != nrow - i:
                    raise ShapeError("Inconsistent shape.")
                # The first entry of each upper-triangle row is the diagonal.
                if not block_row[0].symmetric:
                    raise ValueError("Blocks on diagonal must be symmetric.")
            # Prepend to row i the transposes of blocks (0, i), ..., (i-1, i).
            # Rows j < i are already complete here, so rows[j][i] is the block
            # in block-column i.
            for i in range(1, nrow):
                rows[i][:0] = [rows[j][i].T for j in range(i)]
        self._blocks = rows

        log = kwargs.get("logger", null_log)
        log.debug("Building new BlockLinearOperator")
        nargins = [[blk.shape[-1] for blk in row] for row in self._blocks]
        log.debug("nargins = " + repr(nargins))
        nargouts = [[blk.shape[0] for blk in row] for row in self._blocks]
        log.debug("nargouts = " + repr(nargouts))

        # The grid must be rectangular, each block column must share one
        # width and each block row one height. Comparing whole width lists
        # also catches rows with a different number of blocks.
        col_widths = nargins[0]
        for row_widths, row_heights in zip(nargins, nargouts):
            if row_widths != col_widths or min(row_heights) != max(row_heights):
                raise ShapeError("Inconsistent block shapes")
        nargin = sum(col_widths)
        nargout = sum(heights[0] for heights in nargouts)

        # Block (j, i) of the transpose operator is the transpose of block (i, j).
        blocksT = [
            [row[j].T for row in self._blocks] for j in range(len(col_widths))
        ]

        def blk_matvec(x, blks):
            """Multiply the block grid ``blks`` with the vector ``x``."""
            # The block layout is re-read from ``blks`` on every call.
            row_sizes = [blk_row[0].shape[0] for blk_row in blks]
            col_sizes = [blk.shape[-1] for blk in blks[0]]
            nargin = sum(col_sizes)
            nargout = sum(row_sizes)
            nx = len(x)
            self.logger.debug("Multiplying with a vector of size %d" % nx)
            self.logger.debug("nargin=%d, nargout=%d" % (nargin, nargout))
            if nx != nargin:
                raise ShapeError("Multiplying with vector of wrong shape.")
            result_type = np.result_type(self.dtype, x.dtype)
            y = np.zeros(nargout, dtype=result_type)
            row_start = 0
            for blk_row, row_size in zip(blks, row_sizes):
                row_end = row_start + row_size
                col_start = 0
                for blk, col_size in zip(blk_row, col_sizes):
                    col_end = col_start + col_size
                    # Accumulate this block's contribution into its row slice.
                    y[row_start:row_end] += blk * x[col_start:col_end]
                    col_start = col_end
                row_start = row_end
            return y

        op_dtype = np.result_type(
            *[blk.dtype for row in self._blocks for blk in row]
        )
        super().__init__(
            nargin,
            nargout,
            symmetric=symmetric,
            matvec=lambda x: blk_matvec(x, self._blocks),
            rmatvec=lambda x: blk_matvec(x, blocksT),
            dtype=op_dtype,
            **kwargs,
        )
        # Record the transposed block grid on the adjoint operator
        # (``self.H`` is provided by the base class).
        self.H._blocks = blocksT

    @property
    def blocks(self):
        """The list of blocks defining the block operator."""
        return self._blocks

    def __getitem__(self, indices):
        """
        Index the block grid.

        A single block is returned as-is; any other selection is returned
        as a new (non-symmetric) BlockLinearOperator.
        """
        # Fill an object array element by element so numpy never tries to
        # treat an (indexable/iterable) block as a nested sequence.
        nrow, ncol = len(self._blocks), len(self._blocks[0])
        grid = np.empty((nrow, ncol), dtype=object)
        for i, row in enumerate(self._blocks):
            for j, blk in enumerate(row):
                grid[i, j] = blk
        # np.matrix is kept on purpose: its indexing always yields a 2-D
        # result, so a selected row or column stays a valid nested list.
        blks = np.matrix(grid, dtype=object)[indices]
        # If indexing narrowed it down to a single block, return it.
        if isinstance(blks, BaseLinearOperator):
            return blks
        # Otherwise, we have a matrix of blocks.
        return BlockLinearOperator(blks.tolist(), symmetric=False)

    def __contains__(self, op):
        """Return True if ``op`` is one of the blocks (``in`` semantics)."""
        return any(op in row for row in self.blocks)

    def __iter__(self):
        """Iterate over the block rows."""
        yield from self._blocks