[INPUT_SEPARATOR]
First, create a Numba-accelerated solver module (`solver.py`) that solves the block-tridiagonal state-estimation system.
```python
edit
file: solver.py
lines: 0-0
---
import numpy as np
from numba import njit

@njit(cache=True)
def _solve_blocks(A, B, C, y, x0, tau):
    """Solve a block-tridiagonal system for the state trajectory.

    The unknowns are the states ``x_1 .. x_N`` (``x_0`` is pinned to *x0*),
    where ``N = y.shape[0]``.  With ``Qi = inv(B @ B.T)`` the system has

    * diagonal blocks ``Qi + A.T @ Qi @ A + tau * C.T @ C`` for every
      unknown except the last, whose diagonal block is just ``Qi``;
    * sub-diagonal blocks ``-Qi @ A`` and super-diagonal blocks
      ``-A.T @ Qi``;
    * right-hand side ``tau * C.T @ y[i + 1]`` for unknown ``i`` (zero for
      the last one), plus ``Qi @ A @ x0`` on the first row.

    It is solved with a block Thomas algorithm (forward elimination followed
    by back substitution).

    NOTE(review): this looks like the normal equations of minimising
    ``sum ||w_t||^2 + tau * ||v_t||^2`` subject to
    ``x_{t+1} = A x_t + B w_t`` and ``y_t = C x_t + v_t`` -- confirm against
    the problem statement.

    Args:
        A: Square state-transition matrix, shape ``(n, n)``.
        B: Process-noise input matrix with ``n`` rows; ``B @ B.T`` is
            inverted explicitly, so it must be non-singular.
        C: Measurement matrix, shape ``(m, n)``.
        y: Measurements, one row per time step, shape ``(N, m)``.
        x0: Initial state, length ``n``.
        tau: Scalar weight applied to the measurement terms.

    Returns:
        Tuple ``(x_hat, w_hat, v_hat)`` where ``x_hat`` has shape
        ``(N + 1, n)`` with ``x_hat[0] == x0``, ``w_hat`` has shape
        ``(N, B.shape[1])`` and ``v_hat`` has shape ``(N, m)``.
    """
    N = y.shape[0]
    n = A.shape[1]
    if N == 0:
        # No measurements: the trajectory is just the initial state.  Return
        # early, because the code below would index row 0 of zero-row arrays
        # and Numba does not bounds-check array accesses by default.
        x_only = np.zeros((1, n))
        x_only[0] = x0
        return x_only, np.zeros((0, B.shape[1])), np.zeros((0, C.shape[0]))
    # Compute precision of process noise
    Q = B.dot(B.T)
    Qi = np.linalg.inv(Q)
    # Precompute matrix products
    A_T = A.T
    A_T_Qi = A_T.dot(Qi)
    Qi_A = Qi.dot(A)
    meas = tau * C.T.dot(C)
    # Diagonal block for every unknown except the last one
    D1 = Qi + A_T_Qi.dot(A) + meas
    # Diagonal block for the last unknown (no measurement or successor term)
    D_last = Qi
    # Off-diagonal blocks: L multiplies the previous unknown, U the next one
    L = -Qi_A
    U = -A_T_Qi
    # Right-hand side f; row i belongs to unknown x_{i+1}
    f = np.zeros((N, n))
    if N > 1:
        f[0] = Qi.dot(A.dot(x0)) + tau * C.T.dot(y[1])
        for i in range(1, N-1):
            f[i] = tau * C.T.dot(y[i+1])
        # f[N-1] remains zero
    else:
        f[0] = Qi.dot(A.dot(x0))
    # Thomas algorithm: forward elimination
    d_prime = np.zeros((N, n))
    c_prime = np.zeros((N-1, n, n))
    if N == 1:
        # Single step
        d_prime[0] = np.linalg.solve(D_last, f[0])
    else:
        # Column 0 of `stack` holds the current RHS vector and the remaining
        # columns hold U, so one solve yields both d_prime[i] and c_prime[i].
        stack = np.zeros((n, n + 1))
        stack[:,1:] = U
        # Step 0: solve for d_prime[0] and c_prime[0]
        stack[:,0] = f[0]
        sol = np.linalg.solve(D1, stack)
        d_prime[0] = sol[:,0]
        c_prime[0] = sol[:,1:]
        # Forward elimination for i = 1..N-2
        for i in range(1, N-1):
            denom = D1 - L.dot(c_prime[i-1])
            rhs = f[i] - L.dot(d_prime[i-1])
            stack[:,0] = rhs
            sol = np.linalg.solve(denom, stack)
            d_prime[i] = sol[:,0]
            c_prime[i] = sol[:,1:]
        # Last step i = N-1 (no super-diagonal block, so only d_prime)
        denom_last = D_last - L.dot(c_prime[N-2])
        rhs_last = f[N-1] - L.dot(d_prime[N-2])
        d_prime[N-1] = np.linalg.solve(denom_last, rhs_last)
    # Back substitution
    x_vars = np.zeros((N, n))
    x_vars[N-1] = d_prime[N-1]
    for i in range(N-2, -1, -1):
        x_vars[i] = d_prime[i] - c_prime[i].dot(x_vars[i+1])
    # Assemble state trajectory: x0 followed by the solved unknowns
    x_hat = np.zeros((N+1, n))
    x_hat[0] = x0
    x_hat[1:] = x_vars
    # Measurement residuals
    m = C.shape[0]
    v_hat = np.zeros((N, m))
    for t in range(N):
        v_hat[t] = y[t] - C.dot(x_hat[t])
    # Process noise estimates: map each state increment through B.T @ Qi
    M_mat = B.T.dot(Qi)
    w_hat = np.zeros((N, B.shape[1]))
    for t in range(N):
        delta = x_hat[t+1] - A.dot(x_hat[t])
        w_hat[t] = M_mat.dot(delta)
    return x_hat, w_hat, v_hat

# Warm-up: call _solve_blocks once at import time on 1x1 dummy inputs so the
# compilation cost is paid here rather than inside the first solve() call.
# NOTE(review): this presumably only helps for calls whose argument types and
# layouts match these dummies -- confirm against Numba's dispatch rules.
_dummyA = np.zeros((1, 1))
_dummyB = np.eye(1)
_dummyC = np.zeros((1, 1))
_dummyy = np.zeros((1, 1))
_dummyx0 = np.zeros(1)
# The result is discarded; only the side effect of compiling matters here.
_s = _solve_blocks(_dummyA, _dummyB, _dummyC, _dummyy, _dummyx0, 1.0)

class Solver:
    """Convert a problem dict to arrays and run ``_solve_blocks`` on it."""

    def solve(self, problem):
        """Estimate the state trajectory described by *problem*.

        Args:
            problem: Mapping with array-like entries ``"A"``, ``"B"``,
                ``"C"``, ``"y"`` and ``"x_initial"``, and a scalar ``"tau"``.

        Returns:
            Dict with list-valued entries ``"x_hat"``, ``"w_hat"`` and
            ``"v_hat"``.  When ``y`` has no rows, ``"x_hat"`` holds only the
            initial state and the other two lists are empty.
        """
        # Force C-contiguous float64 arrays.  np.array() keeps the memory
        # order of an ndarray input (order='K'), so a Fortran-ordered input
        # would not match the C-contiguous dummy arrays used by the
        # import-time warm-up call.
        A = np.ascontiguousarray(problem["A"], dtype=np.float64)
        B = np.ascontiguousarray(problem["B"], dtype=np.float64)
        C = np.ascontiguousarray(problem["C"], dtype=np.float64)
        y = np.ascontiguousarray(problem["y"], dtype=np.float64)
        x0 = np.ascontiguousarray(problem["x_initial"], dtype=np.float64)
        tau = float(problem["tau"])
        # Empty timeline
        if y.shape[0] == 0:
            return {"x_hat": [x0.tolist()], "w_hat": [], "v_hat": []}
        x_hat, w_hat, v_hat = _solve_blocks(A, B, C, y, x0, tau)
        return {
            "x_hat": x_hat.tolist(),
            "w_hat": w_hat.tolist(),
            "v_hat": v_hat.tolist(),
        }
---
```