"""
core.py
A dimension-agnostic engine for numerical integration, statistical analysis,
Bayesian inference, and generic visualisation.
"""
from typing import Callable, List, Tuple, Union, Optional, Any, Dict
from abc import ABC, abstractmethod
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp, trapezoid
from scipy.interpolate import RegularGridInterpolator
from scipy.stats import multivariate_normal
from IPython.display import HTML
# --- Math Utilities ---
[docs]
def wrap_angle(theta: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """
    Maps an angle (or an array of angles) onto a single 2*pi-wide period
    centred on zero.

    Args:
        theta: Input angle(s) in radians.

    Returns:
        The equivalent angle(s) in the half-open interval [-pi, pi); an input
        of exactly +pi is returned as -pi.
    """
    full_turn = 2 * np.pi
    # Shift the window so it starts at zero, reduce modulo one turn, shift back.
    shifted = theta + np.pi
    return shifted % full_turn - np.pi
# --- Numerical Solvers (Dimension Agnostic) ---
[docs]
def solve_trajectory(
    eom_func: Callable[[float, np.ndarray, Any], np.ndarray],
    y0: np.ndarray,
    t_points: np.ndarray,
    args: Tuple = (),
    rtol: float = 1e-9,
    atol: float = 1e-12,
    method: str = "RK45",
) -> np.ndarray:
    """
    Integrates a single ODE trajectory over time.

    The integration interval runs from ``t_points[0]`` to ``t_points[-1]``,
    so a decreasing ``t_points`` integrates backwards in time.

    Args:
        eom_func: The Equation of Motion function f(t, y, *args) -> dy/dt.
        y0: Initial state vector of shape (n_dim,), valid at ``t_points[0]``.
        t_points: Array of time points to evaluate at.
        args: Tuple of extra arguments to pass to eom_func.
        rtol: Relative tolerance for solver.
        atol: Absolute tolerance for solver.
        method: Integration method (e.g., 'RK45', 'DOP853').

    Returns:
        Solution array of shape (n_dim, n_times).

    Raises:
        RuntimeError: If the solver reports that the integration failed.
    """
    t_span = (t_points[0], t_points[-1])
    sol = solve_ivp(
        eom_func,
        t_span,
        y0,
        t_eval=t_points,
        method=method,
        rtol=rtol,
        atol=atol,
        args=args,
    )
    # On failure sol.y only holds the time points reached so far, so returning
    # it would silently break the documented (n_dim, n_times) shape.
    if not sol.success:
        raise RuntimeError(
            f"ODE integration failed (status {sol.status}): {sol.message}"
        )
    return sol.y
[docs]
def solve_ensemble(
    eom_func: Callable[[float, np.ndarray, Any], np.ndarray],
    initial_conditions: np.ndarray,
    t_points: np.ndarray,
    args: Tuple = (),
    **solver_kwargs,
) -> np.ndarray:
    """
    Integrates every member of an ensemble independently, one after another.

    Args:
        eom_func: The Equation of Motion function.
        initial_conditions: Array of shape (n_samples, n_dim); one row per particle.
        t_points: Array of time points.
        args: Physics arguments passed to eom_func.
        **solver_kwargs: Forwarded unchanged to solve_trajectory (rtol, atol, etc).

    Returns:
        Trajectories array of shape (n_samples, n_dim, n_times).
    """
    n_samples, n_dim = initial_conditions.shape
    trajectories = np.zeros((n_samples, n_dim, len(t_points)))

    print(f"Propagating {n_samples} particles ({n_dim}D system)...")

    # Each row of initial_conditions is one particle's starting state.
    for member, start_state in enumerate(initial_conditions):
        trajectories[member] = solve_trajectory(
            eom_func, start_state, t_points, args=args, **solver_kwargs
        )
    return trajectories
# --- Distribution Factories ---
[docs]
def get_gaussian_pdf(mean: np.ndarray, cov: np.ndarray) -> Callable[..., np.ndarray]:
    """
    Builds a density function for an N-dimensional multivariate Gaussian.

    The returned callable takes one coordinate array per dimension (for
    example the outputs of ``np.meshgrid``), which is the calling convention
    used by ``ProbabilityGrid.from_bounds``.

    Args:
        mean: Mean vector of shape (N,).
        cov: Covariance matrix of shape (N, N).

    Returns:
        A function pdf(*coordinates) -> density_array.
    """
    # Freeze the parameters once; the closure below only evaluates the density.
    frozen = multivariate_normal(mean=mean, cov=cov)

    def pdf_func(*coords):
        # scipy wants the coordinate components along the trailing axis,
        # i.e. an array of shape (..., N).
        points = np.stack(coords, axis=-1)
        return frozen.pdf(points)

    return pdf_func
[docs]
def get_independent_gaussian_pdf(
    means: Union[List[float], np.ndarray], stds: Union[List[float], np.ndarray]
) -> Callable[..., np.ndarray]:
    """
    Returns a callable PDF for N independent Gaussians (diagonal covariance).

    The density is the product of N one-dimensional Gaussian densities.
    Scalars are accepted for the one-dimensional case.

    Args:
        means: Mean for each dimension (scalar, list or 1D array).
        stds: Standard deviation for each dimension (same length as means,
            strictly positive).

    Returns:
        A function pdf(*coordinates) -> density_array.

    Raises:
        ValueError: If means and stds are not 1D with equal length, or if any
            standard deviation is not strictly positive. The returned function
            raises ValueError when called with the wrong number of coordinates.
    """
    means = np.atleast_1d(np.asarray(means, dtype=float))
    stds = np.atleast_1d(np.asarray(stds, dtype=float))

    if means.ndim != 1 or means.shape != stds.shape:
        raise ValueError(
            f"means and stds must be 1D with equal length; got shapes "
            f"{means.shape} and {stds.shape}."
        )
    if np.any(stds <= 0):
        raise ValueError("Standard deviations must be strictly positive.")

    # Per-dimension normalisation constants do not depend on the evaluation
    # point, so compute them once instead of on every call.
    norm_consts = 1.0 / (stds * np.sqrt(2 * np.pi))

    def pdf_func(*coords):
        if len(coords) != len(means):
            raise ValueError(
                f"Number of coordinates ({len(coords)}) does not match dimension of means ({len(means)})."
            )
        # Product of the 1D Gaussian densities, one factor per dimension.
        result = 1.0
        for val, mu, sigma, norm_const in zip(coords, means, stds, norm_consts):
            exponent = -0.5 * ((val - mu) / sigma) ** 2
            result = result * (norm_const * np.exp(exponent))
        return result

    return pdf_func
# --- The Probability Grid Class ---
[docs]
class ProbabilityGrid:
    """
    Encapsulates an N-dimensional probability density function discretised
    on a rectilinear grid. Handles normalisation, marginalisation, and
    Bayesian updates.

    Attributes:
        axes: List of 1D coordinate arrays, one per dimension.
        values: N-dimensional array of density values, shape == tuple(len(ax)).
        ndim: Number of dimensions.
        shape: Shape of ``values``.
    """

    # Opt out of NumPy's ufunc dispatch so that `ndarray * grid` returns
    # NotImplemented and Python falls back to ProbabilityGrid.__rmul__.
    __array_ufunc__ = None

    def __init__(self, axes: List[np.ndarray], values: np.ndarray):
        """
        Args:
            axes: List of 1D arrays defining the grid coordinates for each dimension.
            values: N-dimensional array of density values matching the shape of axes.

        Raises:
            ValueError: If the shape of values does not match the axis lengths.
        """
        # Accept array-likes (e.g. nested lists) as well as ndarrays.
        self.axes = [np.asanyarray(ax) for ax in axes]
        self.values = np.asanyarray(values)
        self.ndim = len(self.axes)
        self.shape = self.values.shape

        # Validation
        expected_shape = tuple(len(ax) for ax in self.axes)
        if self.values.shape != expected_shape:
            raise ValueError(
                f"Grid shape {self.values.shape} mismatch with axes {expected_shape}"
            )

    @classmethod
    def from_bounds(
        cls,
        bounds: List[Tuple[float, float]],
        resolution: Union[int, List[int]],
        pdf_func: Optional[Callable[..., np.ndarray]] = None,
    ) -> "ProbabilityGrid":
        """
        Factory: Creates a grid from bounds [(min, max), ...] and resolution.
        Optionally evaluates a pdf_func on that grid immediately.

        Args:
            bounds: List of (min, max) tuples for each dimension.
            resolution: Number of points per dimension (int or list of ints).
            pdf_func: Optional callable f(x1, x2...) -> density to evaluate on
                initialisation. If omitted, the grid is filled with zeros.

        Returns:
            A new ProbabilityGrid instance.

        Raises:
            ValueError: If a list of resolutions does not have one entry per
                dimension.
        """
        ndim = len(bounds)

        # Handle scalar vs list resolution
        if np.isscalar(resolution):
            res_list = [int(resolution)] * ndim  # type: ignore
        else:
            res_list = list(resolution)  # type: ignore
            # zip() below would silently drop dimensions on a length mismatch.
            if len(res_list) != ndim:
                raise ValueError(
                    f"Got {len(res_list)} resolutions for {ndim} dimensions."
                )

        axes = [np.linspace(b[0], b[1], r) for b, r in zip(bounds, res_list)]

        if pdf_func is not None:
            # indexing='ij' keeps array axis k aligned with grid dimension k.
            mesh = np.meshgrid(*axes, indexing="ij")
            values = pdf_func(*mesh)
        else:
            values = np.zeros(tuple(len(a) for a in axes))

        return cls(axes, values)

    @property
    def total_mass(self) -> float:
        """Computes the total integral (volume) of the grid using the trapezoidal rule."""
        integral = self.values
        # Integrate the last axis first so the remaining axis indices stay valid.
        for dim in range(self.ndim - 1, -1, -1):
            integral = trapezoid(integral, x=self.axes[dim], axis=dim)
        return float(integral)

    @property
    def mean(self) -> np.ndarray:
        """
        Computes the expected value vector E[x] of the distribution.

        The grid is normalised first, then each component is obtained from the
        corresponding 1D marginal.
        """
        grid_norm = self.normalise()
        means = []
        for i in range(self.ndim):
            marginal = grid_norm.marginalise(keep_indices=(i,))
            # E[x] = integral(x * p(x) dx)
            expected_val = trapezoid(
                marginal.axes[0] * marginal.values, marginal.axes[0]
            )
            means.append(expected_val)
        return np.array(means)

    def normalise(self) -> "ProbabilityGrid":
        """
        Returns a NEW ProbabilityGrid whose trapezoidal integral is 1.0.
        If the mass is zero, returns an unscaled copy of the grid object to
        avoid division errors.
        """
        mass = self.total_mass
        if mass == 0:
            return ProbabilityGrid(self.axes, self.values)
        return ProbabilityGrid(self.axes, self.values / mass)

    def marginalise(self, keep_indices: Tuple[int, ...]) -> "ProbabilityGrid":
        """
        Integrates out all axes NOT in keep_indices.

        The retained axes are always returned in ascending index order,
        regardless of the order given in keep_indices.

        Args:
            keep_indices: Tuple of dimension indices to retain (e.g., (0, 1)).

        Returns:
            A lower-dimensional ProbabilityGrid.
        """
        keep_set = set(keep_indices)
        all_indices = set(range(self.ndim))
        # Highest axis first, so earlier axis numbers are unaffected by removal.
        integrate_indices = sorted(all_indices - keep_set, reverse=True)

        current_values = self.values.copy()
        for i in integrate_indices:
            current_values = trapezoid(current_values, x=self.axes[i], axis=i)

        new_axes = [self.axes[i] for i in sorted(keep_indices)]
        return ProbabilityGrid(new_axes, current_values)

    def to_interpolator(self, fill_value: float = 0.0) -> Callable[..., np.ndarray]:
        """
        Returns a callable function f(x1, x2, ...) backed by this grid.
        Uses linear interpolation; points outside the grid evaluate to
        fill_value.

        Args:
            fill_value: Value returned for points outside the grid bounds.

        Returns:
            A function taking one (broadcastable) coordinate array per
            dimension and returning an array of the broadcast shape.
        """
        interp = RegularGridInterpolator(
            self.axes,
            self.values,
            bounds_error=False,
            fill_value=fill_value,
            method="linear",
        )

        def pdf_func(*coords):
            if len(coords) != self.ndim:
                raise ValueError(f"PDF expects {self.ndim} coordinates.")
            # Flatten to an (n_points, ndim) list for the interpolator, then
            # restore the caller's broadcast shape.
            broadcasted = np.broadcast_arrays(*coords)
            result_shape = broadcasted[0].shape
            points = np.column_stack([b.ravel() for b in broadcasted])
            values = interp(points)
            return values.reshape(result_shape)

        return pdf_func

    def sample(
        self,
        n_samples: int = 1000,
        seed: Optional[Union[int, np.random.Generator]] = None,
    ) -> np.ndarray:
        """
        Samples from the grid using a PMF approximation plus uniform jitter.

        Grid nodes are drawn with probability proportional to their (clipped)
        density value, then each coordinate is shifted by a uniform offset of
        up to half the spacing between the first two nodes of that axis.
        Samples may therefore fall up to half a cell outside the grid bounds.

        Args:
            n_samples: Number of samples to draw.
            seed: Optional seed (or Generator) passed to
                ``np.random.default_rng`` for reproducible draws.

        Returns:
            Array of shape (n_samples, n_dim) containing the samples.

        Raises:
            ValueError: If the clipped grid values sum to zero.
        """
        flat_pdf = np.maximum(self.values.ravel(), 0)  # Clip negative noise
        total_mass = flat_pdf.sum()
        if total_mass == 0:
            raise ValueError("PDF grid has zero total probability mass.")
        pmf = flat_pdf / total_mass

        # Discrete sampling of grid nodes
        rng = np.random.default_rng(seed)
        flat_indices = rng.choice(len(pmf), size=n_samples, p=pmf)
        multi_indices = np.unravel_index(flat_indices, self.shape)

        # Dithering (jitter) to make the samples continuous
        samples = np.zeros((n_samples, self.ndim))
        for d, axis in enumerate(self.axes):
            # Spacing is taken from the first two nodes only.
            dx = axis[1] - axis[0] if len(axis) > 1 else 0.0
            jitter = rng.uniform(-dx / 2, dx / 2, size=n_samples)
            samples[:, d] = axis[multi_indices[d]] + jitter
        return samples

    def push_forward(
        self,
        eom_func: Callable,
        t_final: float,
        eom_args: Tuple = (),
        rtol: float = 1e-5,
        atol: float = 1e-5,
    ) -> "ProbabilityGrid":
        """
        Performs Liouville Advection (Method of Characteristics).
        Propagates the probability density forward in time by 't_final'.

        Every grid node is integrated from t_final back to 0 and the current
        density is interpolated at the resulting origin. No Jacobian factor is
        applied; the result is simply renormalised.

        Args:
            eom_func: The differential equation f(t, y, *args). It is called
                with y of shape (n_dim, n_points) and must return one
                derivative array per dimension.
            t_final: The time duration to propagate forward (a negative value
                pulls the density back in time).
            eom_args: Physics arguments for the eom_func.
            rtol: Relative tolerance for the characteristic integration.
            atol: Absolute tolerance for the characteristic integration.

        Returns:
            A NEW ProbabilityGrid at t=t_final (normalised).
        """
        # 1. Interpolator for the current state
        initial_pdf_func = self.to_interpolator()

        # 2. Target mesh (Eulerian: the same grid definition is reused)
        grids = np.meshgrid(*self.axes, indexing="ij")

        # 3. All grid nodes packed into one long state vector
        flat_state = np.stack([g.ravel() for g in grids])
        y0_vectorized = flat_state.reshape(-1)

        def vectorized_eom(t, y_flat):
            y_reshaped = y_flat.reshape(self.ndim, -1)
            dydt = eom_func(t, y_reshaped, *eom_args)
            return np.concatenate(dydt).reshape(-1)

        # Integrate BACKWARDS: t_final -> 0
        sol = solve_trajectory(
            vectorized_eom,
            y0_vectorized,
            np.array([t_final, 0.0]),
            rtol=rtol,
            atol=atol,
        )
        origins_flat = sol[:, -1]
        origins = origins_flat.reshape(self.ndim, *self.shape)

        # 4. Evaluate the initial PDF at the back-propagated origins
        advected_values = initial_pdf_func(*origins)

        # 5. Return a new normalised object
        return ProbabilityGrid(self.axes, advected_values).normalise()

    def __mul__(
        self, other: Union["ProbabilityGrid", np.ndarray, float]
    ) -> "ProbabilityGrid":
        """
        Element-wise multiplication.
        Supports: Grid * Grid, Grid * Scalar, Grid * Array.

        Raises:
            ValueError: If two grids of different shape are multiplied.
        """
        if isinstance(other, ProbabilityGrid):
            if self.shape != other.shape:
                raise ValueError("Grids must have same shape for multiplication.")
            new_values = self.values * other.values
        else:
            # Assume numpy array or scalar
            new_values = self.values * other
        return ProbabilityGrid(self.axes, new_values)

    def __rmul__(self, other: Union[np.ndarray, float]) -> "ProbabilityGrid":
        """Reflected multiplication: Scalar * Grid and Array * Grid."""
        return self.__mul__(other)

    def __truediv__(self, scalar: float) -> "ProbabilityGrid":
        """
        Divide the grid values by a scalar.
        """
        return self * (1 / scalar)

    def bayes_update(
        self, likelihood: Union["ProbabilityGrid", np.ndarray, float]
    ) -> Tuple["ProbabilityGrid", float]:
        """
        Calculates Posterior = (Prior * Likelihood) / Evidence.

        Args:
            likelihood: Can be a ProbabilityGrid, a numpy array (same shape),
                or a scalar representing P(y|x).

        Returns:
            posterior (ProbabilityGrid): Normalised posterior density. If the
                evidence is zero, the unnormalised product is returned instead.
            evidence (float): The normalisation constant (integral of Prior * Likelihood).
        """
        # 1. Unnormalised posterior (Prior * Likelihood)
        unnormalised = self * likelihood

        # 2. Evidence (total probability)
        evidence = unnormalised.total_mass

        # 3. Zero probability (numerical collapse)
        if evidence == 0:
            print("Warning: Evidence is zero. Posterior is undefined.")
            return unnormalised, 0.0

        # 4. Normalise directly; calling normalise() would integrate again.
        posterior = ProbabilityGrid(self.axes, unnormalised.values / evidence)
        return posterior, evidence
# --- Bayesian Observation Models ---
[docs]
class GaussianLikelihood:
    """
    Gaussian observation model for an arbitrary (possibly non-linear)
    observation operator H.
    Likelihood L(x) = P(y_obs | x) ~ N(y_obs; H(x), R)
    """

    def __init__(
        self,
        observation_value: Optional[np.ndarray],
        observation_covariance: np.ndarray,
        obs_operator_func: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    ):
        """
        Args:
            observation_value: The observed data vector 'y_obs' (ObsDim,).
                May be None, e.g. when the value is filled in later.
            observation_covariance: The observation error covariance matrix 'R'
                (ObsDim, ObsDim).
            obs_operator_func: Callable H(x) -> y.
                Must accept shape (..., StateDim) and return (..., ObsDim).
                If None, the identity H(x) = x is used.

        Raises:
            ValueError: If R cannot be inverted.
        """
        self.y_obs = (
            None if observation_value is None else np.atleast_1d(observation_value)
        )
        self.R = np.atleast_2d(observation_covariance)

        # Cache the precision matrix and the Gaussian prefactor so evaluate()
        # does not have to recompute them for every grid.
        try:
            self.R_inv = np.linalg.inv(self.R)
            n_obs = self.R.shape[0]
            gaussian_volume = ((2 * np.pi) ** n_obs) * np.linalg.det(self.R)
            self.norm_factor = 1.0 / np.sqrt(gaussian_volume)
        except np.linalg.LinAlgError:
            raise ValueError("Observation covariance matrix R must be invertible.")

        def identity(x):
            return x

        self.H = identity if obs_operator_func is None else obs_operator_func

    def evaluate(self, prob_grid: ProbabilityGrid) -> ProbabilityGrid:
        """
        Evaluates the likelihood at every node of the given ProbabilityGrid.

        Returns:
            A new ProbabilityGrid on the same axes holding the likelihood values.

        Raises:
            ValueError: If no observation value has been set.
        """
        if self.y_obs is None:
            raise ValueError(
                "Cannot evaluate likelihood: Observation value is not set."
            )

        # State vector at every grid node, components along the last axis.
        states = np.stack(np.meshgrid(*prob_grid.axes, indexing="ij"), axis=-1)

        # Misfit between the data and the predicted observation: y_obs - H(x)
        misfit = self.y_obs - self.H(states)

        # Squared Mahalanobis distance d^T R^-1 d at every node.
        quad_form = np.einsum("...i, ij, ...j -> ...", misfit, self.R_inv, misfit)

        return ProbabilityGrid(
            prob_grid.axes, self.norm_factor * np.exp(-0.5 * quad_form)
        )

    def sample(self, true_state: np.ndarray) -> np.ndarray:
        """
        Generates a noisy observation y given a true state x.
        y = H(x) + N(0, R)

        Noise is drawn from NumPy's global random state.

        Args:
            true_state: The true state vector x.

        Returns:
            The noisy observation vector y.
        """
        clean = np.atleast_1d(self.H(true_state))
        zero_mean = np.zeros(len(clean))
        return clean + np.random.multivariate_normal(zero_mean, self.R)
[docs]
class LinearGaussianLikelihood(GaussianLikelihood):
    """
    Gaussian likelihood specialised to a linear observation operator: y = Hx.
    The operator is stored as a matrix and applied by matrix multiplication.
    """

    def __init__(
        self,
        observation_value: Optional[np.ndarray],
        observation_covariance: np.ndarray,
        observation_matrix: np.ndarray,
    ):
        """
        Args:
            observation_value: The observed data vector 'y_obs' (ObsDim,).
            observation_covariance: The observation error covariance matrix 'R'.
            observation_matrix: The observation matrix H (ObsDim, StateDim).
        """
        self.H_mat = np.atleast_2d(observation_matrix)
        # The base class still receives a callable so that self.H works for
        # code that treats every likelihood generically.
        super().__init__(
            observation_value,
            observation_covariance,
            lambda x: x @ self.H_mat.T,
        )

    def evaluate(self, prob_grid: ProbabilityGrid) -> ProbabilityGrid:
        """
        Evaluates the likelihood on the grid using the stored matrix directly.

        Raises:
            ValueError: If no observation value has been set.
        """
        if self.y_obs is None:
            raise ValueError(
                "Cannot evaluate likelihood: Observation value is not set."
            )

        # State vector at every grid node, components along the last axis.
        states = np.stack(np.meshgrid(*prob_grid.axes, indexing="ij"), axis=-1)

        # Predicted observations: (..., N) @ (N, M) -> (..., M)
        misfit = self.y_obs - states @ self.H_mat.T

        quad_form = np.einsum("...i, ij, ...j -> ...", misfit, self.R_inv, misfit)
        return ProbabilityGrid(
            prob_grid.axes, self.norm_factor * np.exp(-0.5 * quad_form)
        )
# --- Abstract Base Class for Assimilation ---
[docs]
class AssimilationEngine(ABC):
    """
    Common interface for the assimilation methods (Grid, KF, EnKF):
    register observations, then run, then optionally reanalyse.
    """

    def __init__(self):
        # (time, likelihood model) pairs, kept sorted by time.
        self.observations: List[Tuple[float, Any]] = []

    def add_observation(
        self,
        time: float,
        covariance: np.ndarray,
        value: Optional[np.ndarray] = None,
        operator: Optional[Union[np.ndarray, Callable]] = None,
    ):
        """
        Registers an observation to be assimilated during the run.

        Args:
            time: Time of the observation.
            covariance: Observation error covariance.
            value: Observed value; may be None.
            operator: Observation operator. None or a callable produces a
                GaussianLikelihood; anything else is treated as an observation
                matrix and produces a LinearGaussianLikelihood.
        """
        is_generic = operator is None or callable(operator)
        likelihood_cls = GaussianLikelihood if is_generic else LinearGaussianLikelihood
        model = likelihood_cls(value, covariance, operator)

        self.observations.append((time, model))
        # Re-sort in place so the list stays ordered by observation time.
        self.observations.sort(key=lambda entry: entry[0])

    @abstractmethod
    def run(
        self, initial_state: Any, t_final: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """
        Executes the assimilation cycle.
        Must return a history list of dicts containing 'time', 'forecast', 'analysis'.
        """
        pass

    @abstractmethod
    def reanalyse_initial_condition(self, history: List[Dict[str, Any]]) -> Any:
        """
        Performs Reanalysis (Smoothing) to estimate the state at t=0
        given all observations up to t_final.
        """
        pass
# --- Bayesian Assimilation Manager ---
[docs]
class BayesianAssimilationProblem(AssimilationEngine):
    """
    Manages the definition and execution of a Grid-based Bayesian assimilation cycle.
    Inherits observation management from AssimilationEngine.
    """

    def __init__(
        self,
        eom_func: Callable[[float, np.ndarray, Any], np.ndarray],
        eom_args: Tuple = (),
    ):
        """
        Args:
            eom_func: The system dynamics ODE function.
            eom_args: Arguments (constants/parameters) for the ODE function.
        """
        super().__init__()  # Initialises self.observations = []
        self.eom_func = eom_func
        self.eom_args = eom_args

    def generate_synthetic_data(
        self,
        true_initial_condition: np.ndarray,
        dt_render: float = 0.05,
        seed: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Runs the physics from t=0, samples noisy observations at registered times,
        updates the internal Likelihood models with these values, and returns ground truth.

        Args:
            true_initial_condition: True state at t=0.
            dt_render: Spacing of the extra output times between 0 and the
                last observation time.
            seed: If given, seeds NumPy's global random state before sampling.

        Returns:
            Dict with 't_ground_truth' (output times), 'state_ground_truth'
            (solution of shape (n_dim, n_times)) and 'initial_truth'.

        Raises:
            ValueError: If no observations are registered or dt_render is not
                positive.
        """
        if seed is not None:
            np.random.seed(seed)

        # 1. Identify observation times
        obs_times = np.array([t for t, _ in self.observations])
        if len(obs_times) == 0:
            raise ValueError("No observations registered to generate data for.")
        if dt_render <= 0:
            raise ValueError("dt_render must be positive.")

        # 2. Run High-Res Simulation (Ground Truth)
        # The render grid stops before the last observation time; that time
        # itself enters through obs_times, so the truth ends exactly there
        # instead of overshooting by up to one dt_render.
        t_max = obs_times[-1]
        t_render = np.arange(0, t_max, dt_render)
        t_render = t_render[t_render <= t_max]  # guard against float round-up
        all_times = np.unique(np.concatenate([t_render, obs_times]))

        sol_all = solve_trajectory(
            self.eom_func, true_initial_condition, all_times, args=self.eom_args
        )

        # 3. Sample a noisy value for every registered observation
        for t_obs, model in self.observations:
            # First output time that matches this observation time
            idx = np.where(np.isclose(all_times, t_obs))[0][0]
            true_state_at_t = sol_all[:, idx]

            # Overwrite the stored model's observation with the sampled value
            model.y_obs = model.sample(true_state_at_t)

        return {
            "t_ground_truth": all_times,
            "state_ground_truth": sol_all,
            "initial_truth": true_initial_condition,
        }

    def run(
        self, initial_state: ProbabilityGrid, t_final: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """
        Executes the assimilation cycle starting from t=0.

        Args:
            initial_state: The ProbabilityGrid at t=0 (Prior).
            t_final: Optional end time. If it lies after the last observation,
                a final forecast-only entry (evidence 1.0) is appended.

        Returns:
            A list of dictionaries containing {'time', 'forecast', 'analysis', 'evidence'}.

        Raises:
            ValueError: If an observation time precedes the current time.
        """
        history = []
        current_grid = initial_state
        t_current = 0.0

        for t_obs, lik_model in self.observations:
            dt = t_obs - t_current
            if dt < 0:
                raise ValueError(
                    f"Observation at t={t_obs} is in the past relative to t={t_current}."
                )

            # 1. Forecast Step (Advection); skipped for simultaneous observations
            if dt > 0:
                forecast_grid = current_grid.push_forward(
                    self.eom_func, dt, self.eom_args
                )
            else:
                forecast_grid = current_grid

            # 2. Analysis Step (Bayesian Update)
            lik_grid = lik_model.evaluate(forecast_grid)
            analysis_grid, evidence = forecast_grid.bayes_update(lik_grid)

            history.append(
                {
                    "time": t_obs,
                    "forecast": forecast_grid,
                    "analysis": analysis_grid,
                    "evidence": evidence,
                }
            )
            current_grid = analysis_grid
            t_current = t_obs

        # Optional: Final forecast
        if t_final is not None and t_final > t_current:
            dt = t_final - t_current
            final_grid = current_grid.push_forward(self.eom_func, dt, self.eom_args)
            history.append(
                {
                    "time": t_final,
                    "forecast": final_grid,
                    "analysis": final_grid,
                    "evidence": 1.0,
                }
            )
        return history

    def reanalyse_initial_condition(
        self, history: List[Dict[str, Any]]
    ) -> ProbabilityGrid:
        """
        Pulls the final posterior back to t=0 using inverse advection.

        Args:
            history: History list as returned by run(); only the last entry
                is used.

        Returns:
            The smoothed ProbabilityGrid at t=0.
        """
        final_posterior = history[-1]["analysis"]
        t_final = history[-1]["time"]

        # Inverse advection: push forward by a negative duration.
        return final_posterior.push_forward(self.eom_func, -t_final, self.eom_args)
# ---- Linear KF class ----
[docs]
class LinearKalmanFilter(AssimilationEngine):
    """
    Deterministic Linear Kalman Filter.
    Assumes perfect physics (Process Noise Q = 0).

    System Model:
        x_k = F_k * x_{k-1}          (Deterministic)
        y_k = H * x_k + v_k,         v_k ~ N(0, R)

    Because the physics are deterministic, we can exactly reconstruct the
    smoothed initial condition by inverting the final state.
    """

    def __init__(
        self,
        transition_matrix_func: Callable[[float], np.ndarray],
    ):
        """
        Args:
            transition_matrix_func: Function f(dt) -> F matrix (StateDim, StateDim).
                Returns the propagator for a time step of dt.
        """
        super().__init__()
        self.get_F = transition_matrix_func

    def _predict(
        self, mean: np.ndarray, cov: np.ndarray, dt: float
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prediction Step: x = Fx, P = FPF'
        (No process noise Q added)
        """
        F = self.get_F(dt)
        return F @ mean, F @ cov @ F.T

    def _update(
        self, mean: np.ndarray, cov: np.ndarray, likelihood: LinearGaussianLikelihood
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Analysis Step: Standard Kalman Update.

        Raises:
            ValueError: If the likelihood has no observation value.
        """
        H = likelihood.H_mat
        y_obs = likelihood.y_obs
        R = likelihood.R
        if y_obs is None:
            raise ValueError("Observation value cannot be None during KF update.")

        # 1. Innovation
        y_residual = y_obs - H @ mean

        # 2. Innovation Covariance
        S = H @ cov @ H.T + R

        # 3. Kalman Gain K = P H' S^-1, obtained as the transpose of the
        #    solution of S X = H P (avoids forming S^-1 explicitly).
        try:
            K = np.linalg.solve(S, H @ cov).T
        except np.linalg.LinAlgError:
            print("Warning: Singular innovation covariance. Using pseudo-inverse.")
            K = cov @ H.T @ np.linalg.pinv(S)

        # 4. Mean update
        mean_new = mean + K @ y_residual

        # 5. Covariance update: P = (I - K H) P
        Id = np.eye(len(mean))
        cov_new = (Id - K @ H) @ cov
        return mean_new, cov_new

    def run(
        self,
        initial_state: Tuple[np.ndarray, np.ndarray],
        t_final: Optional[float] = None,
    ) -> List[Dict[str, Any]]:
        """
        Executes the Deterministic KF starting from t=0.

        Args:
            initial_state: Tuple (mean_0, cov_0)
            t_final: Optional end time. If it lies after the last observation,
                a final forecast-only entry is appended.

        Returns:
            History list of dicts with 'time', 'forecast' and 'analysis',
            the latter two being (mean, cov) tuples.

        Raises:
            TypeError: If a registered observation is not a
                LinearGaussianLikelihood.
            ValueError: If an observation time precedes the current time.
        """
        history = []
        current_mean, current_cov = initial_state
        t_current = 0.0

        for t_obs, lik_model in self.observations:
            if not isinstance(lik_model, LinearGaussianLikelihood):
                raise TypeError(
                    "LinearKalmanFilter only supports LinearGaussianLikelihood."
                )

            dt = t_obs - t_current
            # Skipping the forecast for a negative dt would desynchronise
            # t_current from the state, so reject it explicitly.
            if dt < 0:
                raise ValueError(
                    f"Observation at t={t_obs} is in the past relative to t={t_current}."
                )

            # 1. Forecast (skipped for simultaneous observations)
            if dt > 0:
                f_mean, f_cov = self._predict(current_mean, current_cov, dt)
            else:
                f_mean, f_cov = current_mean, current_cov

            # 2. Analysis
            a_mean, a_cov = self._update(f_mean, f_cov, lik_model)

            history.append(
                {
                    "time": t_obs,
                    "forecast": (f_mean, f_cov),
                    "analysis": (a_mean, a_cov),
                }
            )
            current_mean, current_cov = a_mean, a_cov
            t_current = t_obs

        # Final Forecast
        if t_final is not None and t_final > t_current:
            dt = t_final - t_current
            f_mean, f_cov = self._predict(current_mean, current_cov, dt)
            history.append(
                {
                    "time": t_final,
                    "forecast": (f_mean, f_cov),
                    "analysis": (f_mean, f_cov),
                }
            )
        return history

    def reanalyse_initial_condition(
        self, history: List[Dict[str, Any]]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Back-propagates the final estimate to t=0 by inverting the physics.

        Uses the single propagator get_F(t_final) for the whole interval.

        Args:
            history: History list as returned by run(); only the last entry
                is used.

        Returns:
            Tuple (smoothed_mean, smoothed_cov) at t=0.
        """
        final_mean, final_cov = history[-1]["analysis"]
        t_final = history[-1]["time"]

        # Propagator for the full duration: x_T = F(T) * x_0
        F_total = self.get_F(t_final)

        # Invert State: x_0 = F(T)^-1 * x_T
        smoothed_mean = np.linalg.solve(F_total, final_mean)

        # Invert Covariance: P_0 = F^-1 * P_T * F^-T
        F_inv = np.linalg.inv(F_total)
        smoothed_cov = F_inv @ final_cov @ F_inv.T
        return smoothed_mean, smoothed_cov
# ---- Ensemble Kalman filter class ----
[docs]
class EnsembleKalmanFilter(AssimilationEngine):
    """
    Stochastic Ensemble Kalman Filter (EnKF).
    Suitable for non-linear dynamics. Uses the "Perturbed Observation" method.

    System Model:
        x_k = f(x_{k-1})
        y_k = H(x_k) + v_k
    """

    def __init__(
        self,
        eom_func: Callable[[float, np.ndarray, Any], np.ndarray],
        eom_args: Tuple = (),
        n_ensemble: int = 50,
    ):
        """
        Args:
            eom_func: The non-linear physics ODE function.
            eom_args: Arguments for the physics.
            n_ensemble: Number of particles drawn when run() is given a
                (mean, cov) tuple instead of an explicit ensemble.
        """
        super().__init__()
        self.eom_func = eom_func
        self.eom_args = eom_args
        self.n_ensemble = n_ensemble

    def _predict(self, ensemble: np.ndarray, dt: float) -> np.ndarray:
        """
        Propagates the full ensemble forward by dt using the non-linear physics.

        Args:
            ensemble: Array of shape (N_ens, N_dim).
            dt: Duration to integrate over (integration runs from 0 to dt).

        Returns:
            Propagated ensemble of shape (N_ens, N_dim).
        """
        # solve_ensemble returns shape (N_samples, N_dim, N_time)
        traj = solve_ensemble(
            self.eom_func, ensemble, np.array([0, dt]), args=self.eom_args
        )
        # Keep only the state at the end time
        return traj[:, :, -1]

    def _update(
        self, ensemble: np.ndarray, likelihood: GaussianLikelihood
    ) -> np.ndarray:
        """
        Analysis Step: EnKF Update with perturbed observations.

        Observation noise is drawn from NumPy's global random state.

        Raises:
            ValueError: If the likelihood has no observation value.
        """
        n_ens, n_dim = ensemble.shape
        y_obs = likelihood.y_obs
        R = likelihood.R
        if y_obs is None:
            raise ValueError("Observation value cannot be None during EnKF update.")

        # 1. Perturb Observations: one noisy copy of the data per member,
        #    v ~ N(0, R)
        noise = np.random.multivariate_normal(np.zeros(len(y_obs)), R, size=n_ens)
        Y_perturbed = y_obs + noise  # Shape (N_ens, ObsDim)

        # 2. Forecast Observations H(x) for every member.
        #    An operator returning a flat (N_ens,) array for a scalar
        #    observation is promoted to (N_ens, 1) so the algebra below works.
        H_x = np.asarray(likelihood.H(ensemble)).reshape(n_ens, -1)

        # 3. Anomaly matrices (StateDim x N_ens) and (ObsDim x N_ens)
        A = (ensemble - np.mean(ensemble, axis=0)).T
        Y = (H_x - np.mean(H_x, axis=0)).T

        # 4. Kalman Gain
        #    P_xy = 1/(N-1) * A @ Y.T
        #    P_yy = 1/(N-1) * Y @ Y.T
        #    K    = P_xy * (P_yy + R)^-1   (R is added explicitly)
        factor = 1.0 / (n_ens - 1)
        S = factor * (Y @ Y.T) + R
        P_xy = factor * (A @ Y.T)

        try:
            K = np.linalg.solve(S, P_xy.T).T
        except np.linalg.LinAlgError:
            K = P_xy @ np.linalg.pinv(S)

        # 5. Update State: x_new = x_old + K * (y_perturbed_i - H(x_i))
        innovations = (Y_perturbed - H_x).T  # (ObsDim, N_ens)
        update = K @ innovations  # (StateDim, N_ens)
        return ensemble + update.T

    def run(
        self,
        initial_state: Union[Tuple[np.ndarray, np.ndarray], np.ndarray],
        t_final: Optional[float] = None,
    ) -> List[Dict[str, Any]]:
        """
        Executes the EnKF starting from t=0.

        Args:
            initial_state: Either a tuple (mean, cov) to sample from,
                OR an explicit ensemble array (N_ens, N_dim).
            t_final: Optional end time. If it lies after the last observation,
                a final forecast-only entry is appended.

        Returns:
            History list of dicts with 'time', 'forecast' and 'analysis'
            (both (mean, cov) tuples) plus 'ensemble_analysis' (full ensemble).

        Raises:
            ValueError: If an observation time precedes the current time.
        """
        # Initialise the ensemble
        if isinstance(initial_state, tuple):
            mean, cov = initial_state
            current_ensemble = np.random.multivariate_normal(
                mean, cov, size=self.n_ensemble
            )
        else:
            current_ensemble = initial_state

        history = []
        t_current = 0.0

        for t_obs, lik_model in self.observations:
            dt = t_obs - t_current
            # Skipping the forecast for a negative dt would desynchronise
            # t_current from the ensemble, so reject it explicitly.
            if dt < 0:
                raise ValueError(
                    f"Observation at t={t_obs} is in the past relative to t={t_current}."
                )

            # 1. Forecast (skipped for simultaneous observations)
            if dt > 0:
                f_ens = self._predict(current_ensemble, dt)
            else:
                f_ens = current_ensemble

            # 2. Analysis
            a_ens = self._update(f_ens, lik_model)

            # History stores (mean, cov) summaries, matching the KF interface.
            f_stats = (np.mean(f_ens, axis=0), np.cov(f_ens, rowvar=False))
            a_stats = (np.mean(a_ens, axis=0), np.cov(a_ens, rowvar=False))

            history.append(
                {
                    "time": t_obs,
                    "forecast": f_stats,
                    "analysis": a_stats,
                    "ensemble_analysis": a_ens,
                }
            )
            current_ensemble = a_ens
            t_current = t_obs

        # Final Forecast
        if t_final is not None and t_final > t_current:
            dt = t_final - t_current
            f_ens = self._predict(current_ensemble, dt)
            f_stats = (np.mean(f_ens, axis=0), np.cov(f_ens, rowvar=False))
            history.append(
                {
                    "time": t_final,
                    "forecast": f_stats,
                    "analysis": f_stats,
                    "ensemble_analysis": f_ens,
                }
            )
        return history

    def reanalyse_initial_condition(
        self, history: List[Dict[str, Any]]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Reanalysis for deterministic physics.
        Takes the MEAN of the final analysis and integrates it BACKWARDS to t=0.

        Args:
            history: History list as returned by run(); only the last entry
                is used.

        Returns:
            Tuple (smoothed_mean, smoothed_cov). The covariance is NOT
            back-propagated: an all-zero array of the same shape as the final
            covariance is returned as a placeholder.
        """
        final_mean, final_cov = history[-1]["analysis"]
        t_final = history[-1]["time"]

        # The unmodified physics is integrated over a decreasing time span
        # (t_final -> 0), which runs the trajectory backwards: x_T -> x_0.
        sol = solve_trajectory(
            self.eom_func,
            final_mean,
            np.array([t_final, 0.0]),
            args=self.eom_args,
        )
        smoothed_mean = sol[:, -1]

        # Zeros signal that the covariance was not computed.
        smoothed_cov = np.zeros_like(final_cov)
        return smoothed_mean, smoothed_cov
# --- Generic Visualisation Tools ---
[docs]
def plot_grid_marginal(
    prob_grid: ProbabilityGrid,
    dims: Tuple[int, int] = (0, 1),
    ax: Optional[plt.Axes] = None,
    filled: bool = True,
    **kwargs,
) -> Tuple[plt.Axes, Any]:
    """
    Marginalises an N-dimensional ProbabilityGrid down to 2 dimensions
    and plots the result as a contour.

    Args:
        prob_grid: core.ProbabilityGrid instance.
        dims: Tuple of (x_dim_index, y_dim_index) to keep. The first index is
            drawn on the horizontal axis, the second on the vertical axis.
        ax: Matplotlib Axes object. If None, creates a new figure.
        filled: Boolean, whether to use contourf (filled) or contour (lines).
        **kwargs: Passed directly to ax.contourf/ax.contour (e.g., levels, cmap).
            'levels' defaults to 30 when not given.

    Returns:
        ax: The matplotlib Axis.
        contour: The contour plot object.
    """
    # 1. Marginalise
    marginal_grid = prob_grid.marginalise(keep_indices=dims)

    # 2. Extract data. marginalise() returns the kept axes in ascending index
    #    order, so swap them when the caller asked for the reverse order.
    x_axis = marginal_grid.axes[0]
    y_axis = marginal_grid.axes[1]
    Z = marginal_grid.values
    if dims[0] > dims[1]:
        x_axis, y_axis = y_axis, x_axis
        Z = Z.T

    # 3. Setup Plot
    if ax is None:
        fig, ax = plt.subplots()

    # 4. Meshgrid for plotting ('ij' matches the layout of Z)
    X, Y = np.meshgrid(x_axis, y_axis, indexing="ij")

    # 5. Plot
    kwargs.setdefault("levels", 30)
    draw = ax.contourf if filled else ax.contour
    contour = draw(X, Y, Z, **kwargs)

    ax.grid(True, alpha=0.3, linestyle="--")
    return ax, contour
[docs]
def plot_ensemble_scatter(
    trajectories: np.ndarray,
    dim_indices: Tuple[int, int] = (0, 1),
    time_idx: int = -1,
    ax: Optional[plt.Axes] = None,
    **kwargs,
) -> plt.Axes:
    """
    Scatter-plots two state dimensions of every ensemble member at one time.

    Args:
        trajectories: Array of shape (n_samples, n_dim, n_times).
        dim_indices: Tuple of (x_dim, y_dim) indices to plot.
        time_idx: Integer index of the time step to plot (default: last).
        ax: Matplotlib Axes object. If None, creates a new figure.
        **kwargs: Passed to ax.scatter (e.g., c, label). ``alpha`` defaults
            to 0.5 and ``s`` to 10 when the caller does not supply them.

    Returns:
        The matplotlib Axis that was drawn on.
    """
    # Slice out the requested time step: one row per sample, one column per dim.
    frame = trajectories[:, :, time_idx]
    xs = frame[:, dim_indices[0]]
    ys = frame[:, dim_indices[1]]

    if ax is None:
        _, ax = plt.subplots()

    # Fill in marker defaults without overriding caller-supplied values.
    for option, fallback in (("alpha", 0.5), ("s", 10)):
        kwargs.setdefault(option, fallback)

    ax.scatter(xs, ys, **kwargs)
    ax.grid(True, alpha=0.3)
    return ax
[docs]
def plot_1d_slice(
    prob_grid: ProbabilityGrid, dim: int = 0, ax: Optional[plt.Axes] = None, **kwargs
) -> plt.Axes:
    """
    Plots the one-dimensional marginal of a ProbabilityGrid as a line graph.

    The grid is reduced by calling ``prob_grid.marginalise(keep_indices=(dim,))``
    and the resulting values are plotted against the remaining axis.

    Args:
        prob_grid: core.ProbabilityGrid instance.
        dim: Index of the dimension to keep.
        ax: Matplotlib Axes object. If None, creates a new figure.
        **kwargs: Passed to ax.plot (e.g., color, linewidth).

    Returns:
        The matplotlib Axis that was drawn on.
    """
    reduced = prob_grid.marginalise(keep_indices=(dim,))
    coords = reduced.axes[0]
    density = reduced.values

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(coords, density, **kwargs)
    ax.grid(True, alpha=0.3)
    return ax
[docs]
def display_animation_html(anim: Any) -> HTML:
    """
    Renders an animation as an HTML object for display in a notebook.

    Prints a short progress message, then wraps the result of
    ``anim.to_jshtml()`` in an IPython ``HTML`` object.

    Args:
        anim: The Matplotlib FuncAnimation object (anything exposing
            ``to_jshtml()`` works).

    Returns:
        IPython HTML object for display.
    """
    print("Rendering animation...")
    js_markup = anim.to_jshtml()
    return HTML(js_markup)
# --- Visualisation (Kalman Filters) ---
[docs]
def plot_gaussian_ellipsoid(
    mean: np.ndarray,
    cov: np.ndarray,
    ax: plt.Axes,
    n_std: float = 2.0,
    dims: Tuple[int, int] = (0, 1),
    **kwargs,
):
    """
    Plots a 2D covariance ellipse representing the Gaussian distribution.

    The ellipse is centred on the selected components of ``mean``. Its axes
    follow the eigenvectors of the selected 2x2 block of ``cov`` and extend
    ``n_std`` standard deviations either side of the centre. A "+" marker is
    drawn at the centre.

    Args:
        mean: Mean vector (N,).
        cov: Covariance matrix (N, N).
        ax: Matplotlib axes.
        n_std: Number of standard deviations for the ellipse radius.
        dims: Tuple of (x_dim, y_dim) indices.
        **kwargs: Styling (color, linestyle, label, etc). Passed to the
            Ellipse patch and, except for ``label``, to the ax.plot call
            that draws the centre marker.
    """
    from matplotlib.patches import Ellipse

    # Extract the 2D slice of the distribution.
    idx = list(dims)
    mean_2d = mean[idx]
    cov_2d = cov[np.ix_(idx, idx)]

    # Eigen-decomposition, ordered so the largest variance comes first.
    vals, vecs = np.linalg.eigh(cov_2d)
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]

    # Round-off can push an eigenvalue of a (near-)singular covariance
    # slightly below zero; clamp it so the square root below is not NaN.
    vals = np.clip(vals, 0.0, None)

    # Rotation of the principal axis, in degrees, measured from the x-axis.
    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))

    # Full width/height (eigenvalues are variances, so sqrt gives std).
    width, height = 2 * n_std * np.sqrt(vals)

    ell = Ellipse(
        xy=mean_2d, width=width, height=height, angle=theta, fill=False, **kwargs
    )
    ax.add_patch(ell)

    # The ellipse already carries the label. Forwarding it to the marker as
    # well would give every distribution two entries in the legend.
    marker_kwargs = {key: val for key, val in kwargs.items() if key != "label"}
    ax.plot(mean_2d[0], mean_2d[1], marker="+", markersize=10, **marker_kwargs)
[docs]
def plot_kf_step(
    forecast_stats: Tuple[np.ndarray, np.ndarray],
    analysis_stats: Tuple[np.ndarray, np.ndarray],
    dims: Tuple[int, int] = (0, 1),
    ax: Optional[plt.Axes] = None,
    title: Optional[str] = None,
):
    """
    Visualises a single Kalman Filter step (Forecast vs Analysis).

    Draws the 2-sigma Gaussian ellipsoid of the forecast (blue, dashed) and
    of the analysis (red, solid) on the same axes, then adds a legend, a
    grid and axis labels.

    Args:
        forecast_stats: (mean, cov) tuple for the prior.
        analysis_stats: (mean, cov) tuple for the posterior.
        dims: Dimensions to plot.
        ax: Matplotlib Axes object. If None, creates a new figure.
        title: Optional axes title; only set when non-empty.

    Returns:
        The matplotlib Axis that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()

    prior_mean, prior_cov = forecast_stats
    post_mean, post_cov = analysis_stats

    # One row per ellipse, drawn in order: forecast first, analysis second.
    layers = (
        (prior_mean, prior_cov, "blue", "--", "Forecast (Prior)"),
        (post_mean, post_cov, "red", "-", "Analysis (Posterior)"),
    )
    for centre, spread, colour, stroke, caption in layers:
        plot_gaussian_ellipsoid(
            centre,
            spread,
            ax,
            n_std=2.0,
            dims=dims,
            color=colour,
            linestyle=stroke,
            label=caption,
        )

    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlabel(f"State Dimension {dims[0]}")
    ax.set_ylabel(f"State Dimension {dims[1]}")
    if title:
        ax.set_title(title)
    return ax
[docs]
def plot_tracker_1d(
    history: List[Dict[str, Any]],
    dim: int = 0,
    ground_truth: Optional[Dict[str, Any]] = None,
    ax: Optional[plt.Axes] = None,
):
    """
    Plots the time-evolution of a single state variable with uncertainty bounds.

    The analysis mean of dimension ``dim`` is drawn as a black line with a
    shaded band of two standard deviations either side, taken from the
    matching diagonal entry of each analysis covariance.

    Args:
        history: The output list from kf.run(). Each entry is read for its
            "time" value and its "analysis" (mean, cov) pair.
        dim: Index of the state dimension to plot.
        ground_truth: Optional dict from generate_synthetic_data to compare
            against; read for "t_ground_truth" and "state_ground_truth".
        ax: Matplotlib Axes object. If None, creates a new 10x5 figure.

    Returns:
        The matplotlib Axis that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 5))

    # Pull the time stamps and the analysis (mean, cov) pair of every step.
    times = np.array([record["time"] for record in history])
    posteriors = [record["analysis"] for record in history]
    means = np.array([post[0][dim] for post in posteriors])
    stds = np.array([np.sqrt(post[1][dim, dim]) for post in posteriors])

    # Estimate line.
    ax.plot(times, means, "k-", label="KF Estimate")

    # Shaded 2-sigma band around the estimate.
    half_width = 2 * stds
    ax.fill_between(
        times,
        means - half_width,
        means + half_width,
        color="gray",
        alpha=0.3,
        label=r"$\pm 2\sigma$ Uncertainty",
    )

    # Overlay the reference trajectory when one is supplied.
    if ground_truth:
        ax.plot(
            ground_truth["t_ground_truth"],
            ground_truth["state_ground_truth"][dim],
            "g--",
            lw=1,
            label="Ground Truth",
        )

    ax.set_xlabel("Time")
    ax.set_ylabel(f"State Dimension {dim}")
    ax.grid(True, alpha=0.3)
    ax.legend()
    return ax