
brahmap.math.parallel_norm

A replacement for `np.linalg.norm` that computes the 2-norm of a vector distributed among multiple MPI processes.
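The squared 2-norm is a sum over entries, so it splits cleanly across processes. In the formula below, $x^{(p)}$ is the segment of $x$ held by rank $p$ and $P$ is the number of ranks. Both symbols are introduced here only for illustration, and the segments are assumed to be disjoint and together to make up $x$. Each rank computes its local squared norm, an allreduce sums these partial results, and every rank takes the square root:

$$
\lVert x \rVert_2 = \sqrt{\sum_{p=0}^{P-1} \lVert x^{(p)} \rVert_2^2} = \sqrt{\sum_{p=0}^{P-1} \sum_{i} \left(x^{(p)}_i\right)^2}
$$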

Parameters:

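| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | ndarray | Input array: the portion of the distributed vector held by the calling process | required |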

Returns:

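| Type | Description |
|------|-------------|
| float | The 2-norm of the full distributed vector `x` (the same value is returned on every process) |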

Source code in brahmap/math/linalg.py
def parallel_norm(x: np.ndarray) -> float:
    """A replacement for `np.linalg.norm` that computes the 2-norm of a vector
    distributed among multiple MPI processes.

    Parameters
    ----------
    x : np.ndarray
        Input array: the portion of the distributed vector held by the
        calling process

    Returns
    -------
    float
        The 2-norm of the full distributed vector `x`
    """
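    # Squared 2-norm of the local segment. x.dot(x) does not conjugate,
    # so this assumes real-valued, 1-D input.
    sqnorm = x.dot(x)
    # Sum the partial squared norms over all processes (the default
    # reduction op is a sum); every process receives the same total.
    sqnorm = MPI_UTILS.comm.allreduce(sqnorm)
    ret = np.sqrt(sqnorm)
    return ret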
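To check the reduction logic without launching MPI, the sketch below simulates it on a single process. The helper `simulated_parallel_norm` is hypothetical and is not part of brahmap. Each array in `chunks` stands in for the local segment one rank would hold, and a plain Python `sum` plays the role of `comm.allreduce`. Like `parallel_norm`, it uses `x.dot(x)` and therefore assumes real-valued 1-D input.

```python
import numpy as np


def simulated_parallel_norm(chunks: list) -> float:
    """Hypothetical helper: mimic parallel_norm on a single process.

    Each element of `chunks` stands in for the local segment held by one
    MPI rank; summing the partial results plays the role of allreduce.
    """
    # Step 1: every "rank" computes the squared norm of its own segment
    local_sqnorms = [float(c.dot(c)) for c in chunks]
    # Step 2: the allreduce sums these partial results across all ranks
    global_sqnorm = sum(local_sqnorms)
    # Step 3: every rank takes the square root of the same global sum
    return float(np.sqrt(global_sqnorm))


x = np.arange(10, dtype=np.float64)
# Uneven split, as a real data distribution across ranks may be
chunks = np.array_split(x, 3)
result = simulated_parallel_norm(chunks)
print(np.isclose(result, np.linalg.norm(x)))  # True
```

Only the sum of squares is communicated, not the square roots. Summing per-rank norms directly would overestimate the result, because the square root is not additive.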