Source code for pygeoinf.affine_operators

"""
Provides the `AffineOperator` class for affine mappings between Hilbert spaces.
"""

from __future__ import annotations
from typing import Any

from .checks.affine_operators import AffineOperatorAxiomChecks
from .linear_operators import LinearOperator
from .nonlinear_operators import NonLinearOperator


class AffineOperator(AffineOperatorAxiomChecks, NonLinearOperator):
    """
    An affine map F(x) = A(x) + b between two Hilbert spaces.

    Here 'A' is a bounded linear operator and 'b' is a fixed translation
    vector that belongs to the codomain of 'A'.
    """

    # Error messages shared by the addition / subtraction overloads.
    _ADDITION_MISMATCH = "Domains and codomains must match for addition."
    _SUBTRACTION_MISMATCH = "Domains and codomains must match for subtraction."

    def __init__(
        self,
        linear_part: LinearOperator,
        translation: Any,
    ) -> None:
        """
        Builds the affine operator from its linear part and its translation.

        Args:
            linear_part: The linear operator 'A'.
            translation: The translation vector 'b'. It must be an element of
                the codomain of `linear_part`. With a zero translation the
                operator acts exactly like its linear part.

        Raises:
            TypeError: If `translation` is not an element of the codomain of
                `linear_part`.
        """
        codomain = linear_part.codomain
        if not codomain.is_element(translation):
            raise TypeError(
                "The translation vector must be an element of the linear "
                "operator's codomain."
            )

        self._linear_part = linear_part
        self._translation = translation
        domain = linear_part.domain

        def apply_affine(x: Any) -> Any:
            # F(x) = A(x) + b: apply the linear part, then shift the result.
            image = self._linear_part(x)
            return codomain.add(image, self._translation)

        def constant_derivative(x: Any) -> LinearOperator:
            # The translation is constant, so the Fréchet derivative of
            # A(x) + b is 'A' whatever the evaluation point is.
            return self._linear_part

        super().__init__(
            domain, codomain, apply_affine, derivative=constant_derivative
        )

    @property
    def linear_part(self) -> LinearOperator:
        """The linear operator 'A' in F(x) = A(x) + b."""
        return self._linear_part

    @property
    def translation_part(self) -> Any:
        """The translation vector 'b' in F(x) = A(x) + b."""
        return self._translation

    # ------------------------------------------------------------------- #
    #                       Internal validation helpers                   #
    # ------------------------------------------------------------------- #

    def _require_same_spaces(self, other: Any, message: str) -> None:
        """Raises ValueError(message) unless `other` shares both spaces."""
        if self.domain != other.domain or self.codomain != other.codomain:
            raise ValueError(message)

    @staticmethod
    def _require_composable(left: Any, right: Any) -> None:
        """Raises ValueError unless `left @ right` is a valid composition."""
        if left.domain != right.codomain:
            raise ValueError(
                "Codomain of right operator must match domain of left operator."
            )

    # ------------------------------------------------------------------- #
    #                         Algebraic Overloads                         #
    # ------------------------------------------------------------------- #

    def __add__(self, other: Any) -> Any:
        """
        Returns `self + other`.

        Adding an affine or a linear operator yields another
        `AffineOperator`; any other operand is passed to the parent class.

        Raises:
            ValueError: If `other` is an affine or linear operator whose
                domain or codomain differs from this operator's.
        """
        if isinstance(other, AffineOperator):
            self._require_same_spaces(other, self._ADDITION_MISMATCH)
            # (A x + a) + (B x + b) = (A + B) x + (a + b)
            summed_linear = self.linear_part + other.linear_part
            summed_translation = self.codomain.add(
                self.translation_part, other.translation_part
            )
            return AffineOperator(summed_linear, summed_translation)

        if isinstance(other, LinearOperator):
            self._require_same_spaces(other, self._ADDITION_MISMATCH)
            # F(x) + L(x) = (A + L) x + a
            return AffineOperator(self.linear_part + other, self.translation_part)

        return super().__add__(other)

    def __radd__(self, other: Any) -> Any:
        """
        Returns `other + self`.

        A linear left operand is handled by `__add__`; any other operand is
        passed to the parent class.
        """
        if not isinstance(other, LinearOperator):
            return super().__radd__(other)
        # Addition is commutative: L(x) + F(x) = F(x) + L(x)
        return self.__add__(other)

    def __sub__(self, other: Any) -> Any:
        """
        Returns `self - other`.

        Subtracting an affine or a linear operator yields another
        `AffineOperator`; any other operand is passed to the parent class.

        Raises:
            ValueError: If `other` is an affine or linear operator whose
                domain or codomain differs from this operator's.
        """
        if isinstance(other, AffineOperator):
            self._require_same_spaces(other, self._SUBTRACTION_MISMATCH)
            # (A x + a) - (B x + b) = (A - B) x + (a - b)
            diff_linear = self.linear_part - other.linear_part
            diff_translation = self.codomain.subtract(
                self.translation_part, other.translation_part
            )
            return AffineOperator(diff_linear, diff_translation)

        if isinstance(other, LinearOperator):
            self._require_same_spaces(other, self._SUBTRACTION_MISMATCH)
            # F(x) - L(x) = (A - L) x + a
            return AffineOperator(self.linear_part - other, self.translation_part)

        return super().__sub__(other)

    def __rsub__(self, other: Any) -> Any:
        """
        Returns `other - self`.

        A linear left operand yields an `AffineOperator`; any other operand
        is passed to the parent class.

        Raises:
            ValueError: If `other` is a linear operator whose domain or
                codomain differs from this operator's.
        """
        if not isinstance(other, LinearOperator):
            return super().__rsub__(other)

        self._require_same_spaces(other, self._SUBTRACTION_MISMATCH)
        # L(x) - F(x) = (L - A) x - a
        flipped_linear = other - self.linear_part
        negated_translation = self.codomain.negative(self.translation_part)
        return AffineOperator(flipped_linear, negated_translation)

    def __mul__(self, alpha: float) -> "AffineOperator":
        """
        Returns the operator scaled by `alpha`.

        Both the linear part and the translation are multiplied by `alpha`.
        `NotImplemented` is returned when `alpha` is neither an `int` nor a
        `float`.
        """
        if isinstance(alpha, (int, float)):
            # alpha * (A x + a) = (alpha A) x + (alpha a)
            scaled_linear = self.linear_part * alpha
            scaled_translation = self.codomain.multiply(
                alpha, self.translation_part
            )
            return AffineOperator(scaled_linear, scaled_translation)
        return NotImplemented

    def __rmul__(self, alpha: float) -> "AffineOperator":
        """Returns `alpha * self`, computed exactly as `self * alpha`."""
        return self.__mul__(alpha)

    def __matmul__(self, other: Any) -> Any:
        """
        Returns the composition `self @ other`, i.e. x -> self(other(x)).

        Composing with an affine or a linear operator yields another
        `AffineOperator`; any other operand is passed to the parent class.

        Raises:
            ValueError: If `other` is an affine or linear operator whose
                codomain differs from this operator's domain.
        """
        if isinstance(other, AffineOperator):
            self._require_composable(self, other)
            # F(G(x)) = (A @ B) x + (A(b) + a)
            composed_linear = self.linear_part @ other.linear_part
            pushed_translation = self.linear_part(other.translation_part)
            composed_translation = self.codomain.add(
                pushed_translation, self.translation_part
            )
            return AffineOperator(composed_linear, composed_translation)

        if isinstance(other, LinearOperator):
            self._require_composable(self, other)
            # F(L(x)) = (A @ L) x + a
            return AffineOperator(self.linear_part @ other, self.translation_part)

        return super().__matmul__(other)

    def __rmatmul__(self, other: Any) -> Any:
        """
        Returns the composition `other @ self`, i.e. x -> other(self(x)).

        A linear left operand yields an `AffineOperator`; any other operand
        (including general non-linear operators) is passed to the parent
        class.

        Raises:
            ValueError: If `other` is a linear operator whose domain differs
                from this operator's codomain.
        """
        if not isinstance(other, LinearOperator):
            # Non-linear left operands are not specialised here.
            return super().__rmatmul__(other)

        self._require_composable(other, self)
        # L(F(x)) = (L @ A) x + L(a)
        composed_linear = other @ self.linear_part
        mapped_translation = other(self.translation_part)
        return AffineOperator(composed_linear, mapped_translation)