
brahmap.math.parallel_norm

Computes the 2-norm of a vector whose elements are distributed among multiple MPI processes, serving as a distributed replacement for `np.linalg.norm()`.
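The computation rests on the fact that a squared 2-norm is a plain sum, so it splits cleanly across processes. The notation $P$ (number of processes), $I_r$ (set of indices held by rank $r$) and $s_r$ (local sum on rank $r$) is introduced here for illustration:

$$
\|x\|_2 = \sqrt{\sum_{r=0}^{P-1} s_r}, \qquad s_r = \sum_{i \in I_r} |x_i|^2
$$

Each process computes its own $s_r$. A single sum-reduction over the $P$ scalars then gives every process the same $\|x\|_2$.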

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `NDArray[number]` | The local portion of the distributed vector held by the calling process | *required* |
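The type `NDArray[number]` also admits complex arrays. For those, the squared 2-norm is $\sum_i |x_i|^2$, which `np.vdot` computes because it conjugates its first argument. A plain `x.dot(x)` computes $\sum_i x_i^2$ without conjugation and so does not give a norm. The small array below is an illustrative example, not taken from the library.

```python
import numpy as np

x = np.array([1 + 1j, 2 - 1j])

# Unconjugated product: (1+1j)**2 + (2-1j)**2 = 2j + (3-4j) = 3-2j
unconjugated = x.dot(x)
# Conjugated product: |1+1j|**2 + |2-1j|**2 = 2 + 5 = 7
conjugated = np.vdot(x, x)

print(unconjugated)
print(np.sqrt(conjugated.real), np.linalg.norm(x))
```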

Returns:

| Type | Description |
| --- | --- |
| `float` | The 2-norm of the full distributed vector \(x\), identical on every process |

Source code in brahmap/math/linalg.py
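```python
import numpy as np
import numpy.typing as npt

# MPI_UTILS is brahmap's module-level MPI helper (its communicator is
# MPI_UTILS.comm); its import is not shown in this excerpt.


def parallel_norm(x: npt.NDArray[np.number]) -> float:
    """Computes the 2-norm of a vector distributed among multiple MPI
    processes as a replacement for
    [`np.linalg.norm()`](https://numpy.org/doc/stable/reference/generated/numpy.linalg.norm.html).

    Parameters
    ----------
    x : npt.NDArray[np.number]
        The local portion of the distributed vector held by the calling
        process

    Returns
    -------
    float
        The 2-norm of the full distributed vector $x$, identical on every
        process
    """
    # Squared 2-norm of the local portion. np.vdot flattens its inputs and
    # conjugates the first one, so this is sum(|x_i|**2) for complex x too.
    sqnorm = np.vdot(x, x).real
    # Sum the local squared norms over all processes (allreduce defaults
    # to MPI.SUM); every process receives the same global total.
    sqnorm = MPI_UTILS.comm.allreduce(sqnorm)
    ret = np.sqrt(sqnorm)
    return float(ret)
```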
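The reduction can be checked without MPI using a minimal single-process sketch. The function name `simulated_parallel_norm` is hypothetical. The list of chunks stands in for the slices held by different ranks, and Python's built-in `sum` stands in for `MPI_UTILS.comm.allreduce`. `np.vdot` is used so that the sketch handles complex chunks as well as real ones. The result is compared against `np.linalg.norm` applied to the undistributed vector.

```python
import numpy as np


def simulated_parallel_norm(local_chunks):
    """Hypothetical single-process stand-in for parallel_norm.

    Each element of ``local_chunks`` plays the role of the slice of x held
    by one MPI rank; summing the per-chunk results emulates an allreduce
    with the SUM operation.
    """
    # Step 1: each "rank" computes the squared 2-norm of its local slice.
    local_sqnorms = [np.vdot(chunk, chunk).real for chunk in local_chunks]
    # Step 2: the reduction; on real ranks every process gets this total.
    global_sqnorm = sum(local_sqnorms)
    # Step 3: the square root of the global sum is the 2-norm.
    return float(np.sqrt(global_sqnorm))


rng = np.random.default_rng(0)
x = rng.standard_normal(10)
# Split x into 3 uneven pieces, as if it were distributed over 3 ranks.
chunks = np.array_split(x, 3)
print(simulated_parallel_norm(chunks), np.linalg.norm(x))
```

Only one scalar per process is communicated, so the cost of the reduction does not grow with the vector length. The local dot products account for the length-dependent work.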