import numpy as np
import math

def d_star(q, c):
    """Return the jump height d* for the parameters q (= -a) and c.

    The value (c + 2)/6 is compared against two thresholds that are
    quadratic in q:

    * at or above the upper threshold, d* is 0.0;
    * at or below the lower threshold, d* is 1.0;
    * otherwise d* = x - 1, where x is the first real root in the open
      interval (1, 2) of 4 q^2 x^3 - 15 q x^2 + 20 (1 - c) = 0.

    Raises:
        RuntimeError: if the cubic has no real root strictly between 1 and 2.
    """
    upper = 0.5 - q/8 + q**2/30
    lower = 0.5 - q/2 + 4*q**2/15
    level = (c + 2)/6

    # Saturated regimes: no root finding is needed.
    if level >= upper:
        return 0.0
    if level <= lower:
        return 1.0

    # Interior regime: keep the real roots of the cubic that lie in (1, 2),
    # in the order np.roots reports them.
    admissible = [
        z.real
        for z in np.roots([4*q**2, -15*q, 0, 20*(1-c)])
        if abs(z.imag) < 1e-8 and 1.0 < z.real < 2.0
    ]
    if not admissible:
        raise RuntimeError("No valid root for d*")
    return admissible[0] - 1.0

def s_v(v, q, d):
    """Evaluate the piecewise shift function at v.

    With w = q*(1 + d), the function has three branches:

    * v <= w/2:      sqrt(2*w*v)
    * v <= 1 - w/2:  v + w/2
    * otherwise:     1 + w - sqrt(2*w*(1 - v))
    """
    width = q*(1+d)
    lower_knot = width/2.0
    upper_knot = 1 - lower_knot

    # Left edge: square-root branch.
    if v <= lower_knot:
        return math.sqrt(2*width*v)
    # Right edge: mirrored square-root branch.
    if v > upper_knot:
        return 1 + width - math.sqrt(2*width*(1-v))
    # Middle: plain translation by half the width.
    return v + lower_knot

def C_closed(u, v, q, d):
    """Evaluate the closed-form copula C(u, v) for parameters q and d.

    The value is the sum of four pieces built from the shift s = s_v(v, q, d):
    a flat piece and a quadratic piece restricted to t <= min(u, v), followed
    by a flat piece and a quadratic piece restricted to v <= t <= u.
    """
    shift = s_v(v, q, d)
    inv_q = 1.0/q

    def outer_primitive(t):
        # Antiderivative of (shift - t)/q.
        return inv_q*(shift*t - 0.5*t*t)

    def inner_primitive(t):
        # Antiderivative of (shift - t)/q - d.
        return outer_primitive(t) - d*t

    # Breakpoints of the pieces, in increasing order when q, d >= 0.
    inner_start = shift - q*(1+d)
    inner_stop = shift - q*d
    outer_start = shift - q
    outer_stop = shift

    # Piece 1: flat part below min(u, v).
    flat_inner = max(0.0, min(u, v, inner_start))

    # Piece 2: quadratic part below min(u, v); empty interval contributes 0.
    a_in = max(0.0, inner_start)
    b_in = min(u, v, inner_stop)
    if b_in > a_in:
        ramp_inner = inner_primitive(b_in) - inner_primitive(a_in)
    else:
        ramp_inner = 0.0

    # Piece 3: flat part between v and u.
    flat_outer = max(0.0, min(u, outer_start) - v)

    # Piece 4: quadratic part between v and u; empty interval contributes 0.
    a_out = max(v, outer_start)
    b_out = min(u, outer_stop)
    if b_out > a_out:
        ramp_outer = outer_primitive(b_out) - outer_primitive(a_out)
    else:
        ramp_outer = 0.0

    return flat_inner + ramp_inner + flat_outer + ramp_outer

def h_star(t, v, q, d):
    """Return the pointwise density at t for a given v, clipped to [0, 1].

    The unclipped value is (s_v(v, q, d) - t)/q, lowered by d whenever
    t <= v.
    """
    inv_q = 1.0/q
    shift = s_v(v, q, d)
    # The comparison acts as a 0/1 indicator, so d is subtracted only
    # for t <= v.
    jump = d*(t <= v)
    raw = inv_q*(shift - t) - jump
    floored = max(0.0, raw)
    return min(1.0, floored)

def compare_C(q, c, grid=201):
    """Compare the closed-form copula with numerical integration of h_star.

    With d = d_star(q, c), every pair (u, v) on a uniform grid over [0, 1]
    is checked: C_closed(u, v, q, d) is compared with the trapezoidal
    integral of h_star(t, v, q, d) over the grid points t <= u.

    Args:
        q: Parameter forwarded to d_star, h_star and C_closed.
        c: Parameter forwarded to d_star.
        grid: Number of grid points per axis.

    Returns:
        Tuple (d, max_err): the jump height from d_star and the largest
        absolute difference between the closed form and the numerical
        integral over all grid pairs.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # use the new name when it exists and fall back on older releases.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    d = d_star(q, c)
    xs = np.linspace(0.0, 1.0, grid)
    max_err = 0.0
    for v in xs:
        # h_star depends only on (t, v), not on u, so evaluate it once per v
        # instead of once per (u, v) pair.
        h_vals = np.array([h_star(t, v, q, d) for t in xs])
        for j, u in enumerate(xs):
            # xs is increasing and u == xs[j], so the grid points t <= u are
            # exactly xs[:j + 1].
            C_num = integrate(h_vals[:j + 1], xs[:j + 1])
            C_cl = C_closed(u, v, q, d)
            max_err = max(max_err, abs(C_cl - C_num))
    return d, max_err

if __name__ == "__main__":
    # Demo run: check the closed form against numerical integration for one
    # parameter pair and report the largest discrepancy.
    a, c = -0.02, 0.3
    q = -a  # the functions above take q = -a
    d, err = compare_C(q, c, grid=201)
    print(f"a={a}, c={c}, d*={d:.6f}, max C error={err:.3e}")

