# Source code for pygeoinf.parallel

"""
A collection of helper functions for parallel computation using Joblib.

These functions are designed to be top-level to ensure they can be
"pickled" (serialized) and sent to worker processes by libraries like
multiprocessing or its wrapper, Joblib.
"""

from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Union

import numpy as np
from joblib import Parallel, delayed

if TYPE_CHECKING:
    from scipy.sparse.linalg import LinearOperator as ScipyLinOp

    MatrixLike = Union[np.ndarray, ScipyLinOp]


def parallel_mat_mat(A: MatrixLike, B: np.ndarray, n_jobs: int = -1) -> np.ndarray:
    """
    Computes the matrix product A @ B in parallel, one column of B per task.

    This is particularly useful when A is a LinearOperator whose action is
    computationally expensive.

    Args:
        A: The matrix or LinearOperator to apply.
        B: The matrix whose columns will be operated on.
        n_jobs: The number of CPU cores to use. -1 means all available.

    Returns:
        The result of the matrix product A @ B as a dense NumPy array. If B
        has no columns, an empty array of shape (A.shape[0], 0) with B's
        dtype is returned.
    """
    n_cols = B.shape[1]
    if n_cols == 0:
        # np.column_stack raises ValueError on an empty sequence, so the
        # zero-column product is built directly without dispatching any tasks.
        return np.empty((A.shape[0], 0), dtype=B.dtype)

    # operator.matmul is a picklable top-level callable equivalent to
    # `A @ column`; unlike calling A.__matmul__ directly, it goes through the
    # normal operator protocol.
    columns = Parallel(n_jobs=n_jobs)(
        delayed(operator.matmul)(A, B[:, i]) for i in range(n_cols)
    )
    return np.column_stack(columns)


def parallel_compute_dense_matrix_from_scipy_op(
    scipy_op: "ScipyLinOp", n_jobs: int = -1
) -> np.ndarray:
    """
    Computes the dense matrix representation of a scipy.LinearOperator in parallel.

    It builds the matrix column by column by applying the operator to each
    basis vector of its domain, one basis vector per task.

    Args:
        scipy_op: The SciPy LinearOperator wrapper for the matrix action.
        n_jobs: The number of CPU cores to use. -1 means all available.

    Returns:
        The dense matrix as a NumPy array. If the operator's domain has
        dimension zero, an empty array of shape (codomain_dim, 0) is returned.
    """
    codomain_dim, domain_dim = scipy_op.shape
    if domain_dim == 0:
        # No basis vectors to apply the operator to; np.column_stack would
        # raise ValueError on an empty sequence, so return the empty matrix
        # directly without dispatching any tasks.
        return np.empty((codomain_dim, 0))

    columns = Parallel(n_jobs=n_jobs)(
        delayed(_worker_compute_scipy_op_col)(scipy_op, j, domain_dim)
        for j in range(domain_dim)
    )
    return np.column_stack(columns)


def _worker_compute_scipy_op_col( scipy_op: "ScipyLinOp", j: int, domain_dim: int ) -> np.ndarray: """ (Internal worker) Computes a single column of a SciPy LinearOperator's matrix. """ cx = np.zeros(domain_dim) cx[j] = 1.0 return scipy_op @ cx