"""
Provides a self-checking mechanism for LinearOperator implementations.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
# Import the base checks from the sibling module
from .nonlinear_operators import NonLinearOperatorAxiomChecks
if TYPE_CHECKING:
from ..hilbert_space import Vector
from ..linear_forms import LinearForm
from ..gaussian_measure import GaussianMeasure
[docs]
class LinearOperatorAxiomChecks(NonLinearOperatorAxiomChecks):
    """
    A mixin for checking the properties of a LinearOperator.

    Inherits the derivative check from NonLinearOperatorAxiomChecks and adds
    checks for linearity and the adjoint identity.

    The methods below rely on the host class being callable and exposing
    ``domain``, ``codomain`` and ``adjoint`` attributes.
    """

    def _check_linearity(
        self,
        x: Vector,
        y: Vector,
        a: float,
        b: float,
        check_rtol: float = 1e-5,
        check_atol: float = 1e-8,
    ) -> None:
        """Verify the linearity property: L(ax + by) = a*L(x) + b*L(y).

        Args:
            x: First vector, passed to the operator and to ``self.domain``
                arithmetic.
            y: Second vector, used the same way as ``x``.
            a: Scalar multiplying ``x``.
            b: Scalar multiplying ``y``.
            check_rtol: Relative tolerance on the codomain-norm error.
            check_atol: Absolute tolerance on the codomain-norm error.

        Raises:
            AssertionError: If the error exceeds both the relative and the
                absolute tolerance.
        """
        ax_plus_by = self.domain.add(
            self.domain.multiply(a, x), self.domain.multiply(b, y)
        )
        lhs = self(ax_plus_by)

        aLx = self.codomain.multiply(a, self(x))
        bLy = self.codomain.multiply(b, self(y))
        rhs = self.codomain.add(aLx, bLy)

        # Compare the results in the codomain.
        diff_norm = self.codomain.norm(self.codomain.subtract(lhs, rhs))
        rhs_norm = self.codomain.norm(rhs)
        # The small offset keeps the division finite when rhs is zero.
        relative_error = diff_norm / (rhs_norm + 1e-12)

        # Fail only when BOTH tolerances are exceeded, so that tiny absolute
        # errors on near-zero results do not trigger a false alarm.
        if relative_error > check_rtol and diff_norm > check_atol:
            raise AssertionError(
                f"Linearity check failed: L(ax+by) != aL(x)+bL(y). "
                f"Relative error: {relative_error:.2e} (Tol: {check_rtol:.2e}), "
                f"Absolute error: {diff_norm:.2e} (Tol: {check_atol:.2e})"
            )

    def _check_adjoint_definition(
        self,
        x: Vector,
        y: Vector,
        check_rtol: float = 1e-5,
        check_atol: float = 1e-8,
    ) -> None:
        """Verify the adjoint identity: <L(x), y> = <x, L*(y)>.

        Args:
            x: Vector the operator is applied to.
            y: Vector the adjoint is applied to.
            check_rtol: Relative tolerance passed to ``np.isclose``.
            check_atol: Absolute tolerance passed to ``np.isclose``.

        Raises:
            AssertionError: If the two inner products are not close.
        """
        lhs = self.codomain.inner_product(self(x), y)
        rhs = self.domain.inner_product(x, self.adjoint(y))
        if not np.isclose(lhs, rhs, rtol=check_rtol, atol=check_atol):
            raise AssertionError(
                f"Adjoint definition failed: <L(x),y> = {lhs:.4e}, "
                f"but <x,L*(y)> = {rhs:.4e} (RelTol: {check_rtol:.2e}, AbsTol: {check_atol:.2e})"
            )

    def _check_algebraic_identities(
        self,
        op1,
        op2,
        x: Vector,
        y: Vector,
        a: float,
        check_rtol: float = 1e-5,
        check_atol: float = 1e-8,
    ) -> None:
        """
        Verifies the algebraic properties of the adjoint and dual operators.

        Requires a second compatible operator (op2).

        Args:
            op1: First operator (``check`` passes the operator itself).
            op2: Second operator, combined with ``op1`` via ``+`` and,
                when ``op1.domain == op2.codomain``, via ``@``.
            x: Vector that ``op1`` is applied to.
            y: Vector that the adjoints are applied to; it is also mapped
                with ``op1.codomain.to_dual`` for the dual identity.
            a: Scalar used for the ``(a*A)* = a*A*`` identity.
            check_rtol: Relative tolerance for the comparisons.
            check_atol: Absolute tolerance for the comparisons.

        Raises:
            AssertionError: If any identity is violated beyond both
                tolerances.
        """

        def _check_norm_based(res1, res2, space, axiom_name):
            """Helper to perform norm-based comparison."""
            diff_norm = space.norm(space.subtract(res1, res2))
            norm_res2 = space.norm(res2)
            # Fail only when both the absolute and relative bounds are broken.
            if diff_norm > check_atol and diff_norm > check_rtol * (norm_res2 + 1e-12):
                raise AssertionError(
                    f"Axiom failed: {axiom_name}. "
                    f"Absolute error: {diff_norm:.2e}, Relative error: {diff_norm / (norm_res2 + 1e-12):.2e}"
                )

        # --- Adjoint Identities ---
        # (A+B)* = A* + B*
        res1 = (op1 + op2).adjoint(y)
        res2 = (op1.adjoint + op2.adjoint)(y)
        _check_norm_based(res1, res2, op1.domain, "(A+B)* != A* + B*")

        # (a*A)* = a*A*
        res1 = (a * op1).adjoint(y)
        res2 = (a * op1.adjoint)(y)
        _check_norm_based(res1, res2, op1.domain, "(a*A)* != a*A*")

        # (A*)* = A
        res1 = op1.adjoint.adjoint(x)
        res2 = op1(x)
        _check_norm_based(res1, res2, op1.codomain, "(A*)* != A")

        # (A@B)* = B*@A*  (only when the composition is defined)
        if op1.domain == op2.codomain:
            res1 = (op1 @ op2).adjoint(y)
            res2 = (op2.adjoint @ op1.adjoint)(y)
            _check_norm_based(res1, res2, op2.domain, "(A@B)* != B*@A*")

        # --- Dual Identities ---
        # (A+B)' = A' + B'
        op_sum_dual = (op1 + op2).dual
        dual_sum = op1.dual + op2.dual
        y_dual = op1.codomain.to_dual(y)

        # The result of applying a dual operator is a LinearForm
        res1_form: LinearForm = op_sum_dual(y_dual)
        res2_form: LinearForm = dual_sum(y_dual)

        # Prefer LinearForm subtraction and the dual-space norm. Only these
        # operations are guarded, so that an unrelated TypeError/AttributeError
        # in the comparison below cannot silently divert to the fallback.
        try:
            diff_form = res1_form - res2_form
            diff_norm = op1.domain.dual.norm(diff_form)
            norm_res2 = op1.domain.dual.norm(res2_form)
        except (AttributeError, TypeError):
            # Fallback if LinearForm doesn't support subtraction or norm
            if not np.allclose(
                res1_form.components,
                res2_form.components,
                rtol=check_rtol,
                atol=check_atol,
            ):
                raise AssertionError(
                    "Axiom failed: (A+B)' != A' + B' (component check)."
                )
        else:
            if diff_norm > check_atol and diff_norm > check_rtol * (norm_res2 + 1e-12):
                raise AssertionError(
                    f"Axiom failed: (A+B)' != A' + B'. "
                    f"Absolute error: {diff_norm:.2e}, Relative error: {diff_norm / (norm_res2 + 1e-12):.2e}"
                )

    def check(
        self,
        /,
        *,
        n_checks: int = 5,
        op2=None,
        check_rtol: float = 1e-5,
        check_atol: float = 1e-8,
        domain_measure: GaussianMeasure | None = None,
        codomain_measure: GaussianMeasure | None = None,
    ) -> None:
        """
        Runs all checks for the LinearOperator, including non-linear checks
        and algebraic identities.

        Args:
            n_checks: The number of randomized trials to perform.
            op2: An optional second operator for testing algebraic rules.
                The algebraic checks are skipped when this is None.
            check_rtol: The relative tolerance for numerical checks.
            check_atol: The absolute tolerance for numerical checks.
            domain_measure: A GaussianMeasure on the operator's domain from
                which random samples are drawn. It is also forwarded to the
                parent checks as ``measure``. Defaults to None, in which
                case the domain's ``random()`` method is used.
            codomain_measure: A GaussianMeasure on the operator's codomain
                from which random samples are drawn. Defaults to None, in
                which case the codomain's ``random()`` method is used.

        Raises:
            ValueError: If a supplied measure is not defined on the
                corresponding space.
            AssertionError: If any of the linear-operator checks fails.
        """
        # Validate the measures up front so that a mismatched measure is
        # reported before any (potentially expensive) checks are run.
        if domain_measure is None:
            domain_sampler = self.domain.random
        else:
            if domain_measure.domain != self.domain:
                raise ValueError(
                    "Provided measure must be defined on the operator's domain"
                )
            domain_sampler = domain_measure.sample

        if codomain_measure is None:
            codomain_sampler = self.codomain.random
        else:
            if codomain_measure.domain != self.codomain:
                raise ValueError(
                    "Provided measure must be defined on the operator's codomain"
                )
            codomain_sampler = codomain_measure.sample

        # First, run the parent (non-linear) checks from the base class
        super().check(
            n_checks=n_checks,
            op2=op2,
            check_rtol=check_rtol,
            check_atol=check_atol,
            measure=domain_measure,
        )

        # Now, run the linear-specific checks
        print(
            f"Running {n_checks} additional randomized checks for linearity and adjoints..."
        )
        for _ in range(n_checks):
            x1 = domain_sampler()
            x2 = domain_sampler()
            y = codomain_sampler()
            a, b = np.random.randn(), np.random.randn()

            self._check_linearity(
                x1, x2, a, b, check_rtol=check_rtol, check_atol=check_atol
            )
            self._check_adjoint_definition(
                x1, y, check_rtol=check_rtol, check_atol=check_atol
            )

            # Compare against None explicitly: an operator object may define
            # its own truthiness, which must not disable these checks.
            if op2 is not None:
                self._check_algebraic_identities(
                    self, op2, x1, y, a, check_rtol=check_rtol, check_atol=check_atol
                )

        print(f"[✓] All {n_checks} linear operator checks passed successfully.")