import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

def base_plot(ax):
    """
    Prepare *ax* as the bare unit-disc canvas shared by the plots.

    The axes get an equal-aspect view of [-1.2, 1.2] in both directions,
    spines and ticks are hidden, the unit disc is drawn as a light
    background patch, and the three key points M(0, 1), W(0, -1) and
    Π(0, 0) are marked and labelled.

    Parameters
    ----------
    ax : the Matplotlib axes to configure (modified in place).
    """
    # square view, slightly larger than the unit disc
    view = (-1.2, 1.2)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlim(*view)
    ax.set_ylim(*view)
    # NOTE(review): grid lines normally follow tick positions; since the
    # ticks are removed below, this may end up drawing nothing — confirm.
    ax.grid(True, ls=":", alpha=.55)

    # strip the frame: no spines, no ticks
    for edge in ax.spines.values():
        edge.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # unit-circle background (lowest zorder so everything sits on top)
    disc = Circle((0, 0), 1, fc="#ECEFF1", ec="#BFC9D2", zorder=0)
    ax.add_patch(disc)

    # key points: (label, position, label offset in points, ha, va)
    key_points = (
        (r"$M$",   (0, 1),  (0, 5),  "center", "bottom"),
        (r"$W$",   (0, -1), (0, -5), "center", "top"),
        (r"$\Pi$", (0, 0),  (8, 0),  "left",   "center"),
    )
    marker_x = [entry[1][0] for entry in key_points]
    marker_y = [entry[1][1] for entry in key_points]
    ax.scatter(marker_x, marker_y, c="k", s=60, zorder=4)

    for label, anchor, offset, h_align, v_align in key_points:
        ax.annotate(label, anchor, xytext=offset, textcoords="offset points",
                    ha=h_align, va=v_align, fontsize=15)

def add_si_sd(ax, a=0.60, fc_si="#85C1E9", fc_sd="#F5B7B1", alpha=.45):
    """
    Shade the SI (top) and SD (bottom) lenses on *ax* and label them.

    Each lens is the *intersection* of two equal circles that both
    pass through Π(0,0) & M(0,1) (top) or Π & W(0,-1) (bottom).

             centres :  (±a,  +½)   and   (±a,  -½)
             radius  :  R = sqrt(a² + 0.25)

    At height y the intersection spans  |x| ≤ sqrt(R² - (y - k)²) - a,
    with k = ±½ the vertical centre, so the width shrinks to zero at the
    two tips of the lens.

    Increasing `a`  →  centres move outward  →  slimmer lens.

    Parameters
    ----------
    ax : axes-like object providing ``fill_betweenx`` and ``text``.
    a : float   horizontal offset of the circle centres (0 ≤ a < 1)
    fc_si / fc_sd : str   face colours for SI and SD.
    alpha : float          fill transparency.
    """
    radius_sq = a * a + 0.25               # R², shared by both lenses
    y_grid = np.linspace(0, 1, 600)        # |y| samples from Π to M / W

    # --- helper to shade one cap ----------------------------------------
    def shade_cap(sign, facecol):
        """
        sign = +1 → top (SI),  sign = -1 → bottom (SD)
        """
        centre_y = 0.5 * sign              # vertical centre  (±½)
        ys = sign * y_grid                 # monotone, as fill_betweenx needs

        # squared horizontal reach of either circle at each height; the
        # clamp keeps sqrt away from tiny negatives caused by rounding
        chord_sq = np.maximum(radius_sq - (ys - centre_y) ** 2, 0.0)

        # half-width of the intersection; clamped so that rounding at the
        # tips yields a closed lens (width 0) instead of a dropped sample
        half_width = np.maximum(np.sqrt(chord_sq) - a, 0.0)

        # fill the lens strip
        ax.fill_betweenx(ys, -half_width, half_width,
                         facecolor=facecol, alpha=alpha, zorder=1)

    shade_cap(+1, fc_si)       # SI  (ρ > 0)
    shade_cap(-1, fc_sd)       # SD  (ρ < 0)

    # ----------------------------- labels -------------------------------
    ax.text(0, 0.70, "SI  ($\\rho\\geq 0$)",
            ha="center", va="center", fontsize=12,
            color="#1B4F72", weight="bold")
    ax.text(0, -0.70, "SD  ($\\rho\\leq 0$)",
            ha="center", va="center", fontsize=12,
            color="#78281F", weight="bold")



# ------------------------------------------------------------------- #
# 3.  Plot A  – horizontal ρ-lines                                   #
# ------------------------------------------------------------------- #
# Every level ρ is drawn as the horizontal chord of the unit circle at
# height y = ρ and tagged with its value at x = -1.
figA, axA = plt.subplots(figsize=(5.3, 6))
base_plot(axA)
add_si_sd(axA)   # pass a=... here for a narrower/wider lens

rho_lines = np.linspace(-0.8, 0.8, 9)
text_color = "black"

# styling shared by all chords / all labels
_chord_style = dict(c=text_color, lw=1.3, ls="-")
_rho_label_style = dict(ha="left", va="center", fontsize=11,
                        color=text_color)

for rho in rho_lines:
    # half-length of the chord at height rho (clamped so sqrt never sees
    # a negative argument)
    xmax = np.sqrt(max(0, 1 - rho*rho))
    axA.plot([-xmax, xmax], [rho, rho], **_chord_style)
    axA.text(-1, rho, fr"$\rho={rho:+.1f}$",
             bbox=dict(fc="white", alpha=.9, ec="none", pad=1),
             **_rho_label_style)

# optional decorations (disabled):
# axA.set_xlabel(r"$x$ (orthogonal component)", fontsize=12)
# axA.set_ylabel(r"$y$", fontsize=12)
# axA.set_title("Plot A – horizontal lines of constant $\\rho$", fontsize=14)
figA.tight_layout()

# ------------------------------------------------------------------- #
# 4.  Plot B – ξ-circles                                             #
# ------------------------------------------------------------------- #
# Every level ξ is drawn as a circle of radius ξ about the origin, with
# its value written on the circle at a polar angle of 45°.
figB, axB = plt.subplots(figsize=(5.3, 6))
base_plot(axB)
add_si_sd(axB)

xi_vals = np.linspace(0.2, 1.0, 5)
theta   = np.linspace(0, 2*np.pi, 400)
phi_lab = np.deg2rad(45)
rad_off = 0.05               # not referenced anywhere in this section

xi_text_color = "black"

# unit circle and unit label direction, scaled by xi inside the loop
_unit_x, _unit_y = np.cos(theta), np.sin(theta)
_lab_cos, _lab_sin = np.cos(phi_lab), np.sin(phi_lab)
_xi_label_style = dict(ha="center", va="center", fontsize=11,
                       color=xi_text_color)

for xi in xi_vals:
    axB.plot(xi*_unit_x, xi*_unit_y, c=xi_text_color, lw=1.3, ls="-")
    # label anchor lies exactly on the circle of radius xi
    xL = xi * _lab_cos
    yL = xi * _lab_sin
    axB.text(xL, yL, fr"$\xi={xi:.1f}$",
             bbox=dict(facecolor="white", alpha=0.7, ec="none", pad=0.8),
             **_xi_label_style)


# optional decorations (disabled):
# axB.set_xlabel(r"$x$ (orthogonal component)", fontsize=12)
# axB.set_ylabel(r"$y$", fontsize=12)
# axB.set_title("Plot B – concentric circles of constant $\\xi$", fontsize=14)
figB.tight_layout()

plt.show()
