Skip to content

brahmap.core.compute_GLS_maps_from_PTS

Source code in brahmap/core/GLS.py
def compute_GLS_maps_from_PTS(
    processed_samples: ProcessTimeSamples,
    time_ordered_data: np.ndarray,
    inv_noise_cov_operator: Union[DTypeNoiseCov, None] = None,
    gls_parameters: Union[GLSParameters, None] = None,
) -> GLSResult:
    """Compute GLS maps from pre-processed pointing samples.

    Builds the right-hand side ``b = P^T N^{-1} d``. Here ``P`` is the
    ``PointingLO`` built from ``processed_samples``, ``N^{-1}`` is the
    inverse noise covariance operator, and ``d`` is the time-ordered data.

    - If ``gls_parameters.use_iterative_solver`` is true, the system
      ``(P^T N^{-1} P) m = b`` is solved with ``cg``. The block-diagonal
      preconditioner is passed as ``M``.
    - Otherwise the map vector is obtained by applying the block-diagonal
      preconditioner directly to ``b``.

    Parameters
    ----------
    processed_samples : ProcessTimeSamples
        Pre-processed pointing information. ``processed_samples.nsamples``
        must equal ``len(time_ordered_data)``.
    time_ordered_data : np.ndarray
        Time-ordered data. It is cast to ``processed_samples.dtype_float``
        with ``casting="safe"`` and ``copy=False``, so no copy is made when
        the dtype already matches.
    inv_noise_cov_operator : DTypeNoiseCov or None, optional
        Inverse noise covariance operator. Its shape must be
        ``(nsamples, nsamples)``. When ``None``, an ``InvNoiseCovLO_Diagonal``
        of size ``nsamples`` and dtype ``processed_samples.dtype_float`` is
        used.
    gls_parameters : GLSParameters or None, optional
        Solver configuration. When ``None``, a fresh ``GLSParameters()`` is
        created for this call (equivalent to the previous default).

    Returns
    -------
    GLSResult
        Contains the separated maps, the optional hit map, the convergence
        status, the number of solver iterations (0 when the iterative solver
        is not used) and the ``gls_parameters`` that were used.

    Raises
    ------
    ValueError
        If ``len(time_ordered_data)`` differs from
        ``processed_samples.nsamples``, or if ``inv_noise_cov_operator`` does
        not have shape ``(nsamples, nsamples)``.
    TypeError
        If ``time_ordered_data`` cannot be safely cast to
        ``processed_samples.dtype_float``.
    """
    if gls_parameters is None:
        # Build the defaults per call rather than sharing one instance created
        # at definition time; the object is stored in the returned result.
        gls_parameters = GLSParameters()

    nsamples = processed_samples.nsamples

    MPI_RAISE_EXCEPTION(
        condition=(nsamples != len(time_ordered_data)),
        exception=ValueError,
        message=f"Size of `pointings` must be equal to the size of `time_ordered_data` array:\nlen(pointings) = {nsamples}\nlen(time_ordered_data) = {len(time_ordered_data)}",
    )

    try:
        time_ordered_data = time_ordered_data.astype(
            dtype=processed_samples.dtype_float, casting="safe", copy=False
        )
    except TypeError as err:
        # On failure `time_ordered_data` is still the original array, so its
        # dtype is the one the user must recompute `processed_samples` with.
        raise TypeError(
            f"The `time_ordered_data` array has higher dtype than `processed_samples.dtype_float={processed_samples.dtype_float}`. Please compute `processed_samples` again with `dtype_float={time_ordered_data.dtype}`"
        ) from err

    if inv_noise_cov_operator is None:
        inv_noise_cov_operator = InvNoiseCovLO_Diagonal(
            size=nsamples, dtype=processed_samples.dtype_float
        )
    else:
        noise_cov_shape = tuple(inv_noise_cov_operator.shape)
        # Both dimensions must match: the operator is applied to the TOD and
        # its output is consumed by the transposed pointing operator.
        MPI_RAISE_EXCEPTION(
            condition=(noise_cov_shape != (nsamples, nsamples)),
            exception=ValueError,
            message=f"The shape of `inv_noise_cov_operator` must be same as `(len(time_ordered_data), len(time_ordered_data))`:\nlen(time_ordered_data) = {len(time_ordered_data)}\ninv_noise_cov_operator.shape = {noise_cov_shape}",
        )

    pointing_operator = PointingLO(
        processed_samples=processed_samples, solver_type=gls_parameters.solver_type
    )

    blockdiagprecond_operator = BlockDiagonalPreconditionerLO(
        processed_samples=processed_samples, solver_type=gls_parameters.solver_type
    )

    # Right-hand side of the GLS system: P^T N^{-1} d
    b = pointing_operator.T * inv_noise_cov_operator * time_ordered_data

    num_iterations = 0
    if gls_parameters.use_iterative_solver:

        def callback_function(x, r, norm_residual):
            # Count solver iterations, then forward to the user's callback.
            nonlocal num_iterations
            num_iterations += 1
            if gls_parameters.callback_function is not None:
                gls_parameters.callback_function(x, r, norm_residual)

        A = pointing_operator.T * inv_noise_cov_operator * pointing_operator

        map_vector, pcg_status = cg(
            A=A,
            b=b,
            atol=gls_parameters.isolver_threshold,
            maxiter=gls_parameters.isolver_max_iterations,
            M=blockdiagprecond_operator,
            callback=callback_function,
        )
    else:
        pcg_status = 0
        map_vector = blockdiagprecond_operator * b

    output_maps = separate_map_vectors(
        map_vector=map_vector, processed_samples=processed_samples
    )

    hit_map = (
        processed_samples.get_hit_counts() if gls_parameters.return_hit_map else None
    )

    # A zero status from `cg` (or the direct branch) is reported as converged.
    convergence_status = bool(pcg_status == 0)

    # NOTE(review): the operators above use `gls_parameters.solver_type`,
    # while the result records `processed_samples.solver_type`; confirm these
    # are intended to agree.
    return GLSResult(
        solver_type=processed_samples.solver_type,
        npix=processed_samples.npix,
        new_npix=processed_samples.new_npix,
        GLS_maps=output_maps,
        hit_map=hit_map,
        convergence_status=convergence_status,
        num_iterations=num_iterations,
        GLSParameters=gls_parameters,
    )