# test_rank_adaptive_bm_sen.py
import unittest
import numpy as np

from smlr.solvers import (
    rank_adaptive_bm_sen_quadratic,
    RankAdaptiveOptions,
)


def sen_prox(B: np.ndarray, lam1: float, lam2: float, step: float = 1.0) -> np.ndarray:
    """SEN prox，用于测试中的 gold 标准解.

    prox_{step * [ lam1 ||·||_* + (lam2/2)||·||^2 ]}(B)
    """
    if step <= 0.0 or (lam1 == 0.0 and lam2 == 0.0):
        return B

    # 先吸收 lam2 到平滑项，再对缩放后的矩阵做核范数 soft-thresholding
    if lam2 > 0.0:
        scale = 1.0 / (1.0 + step * lam2)
        B_eff = B * scale
        step_eff = step * scale
    else:
        B_eff = B
        step_eff = step

    U, s, Vt = np.linalg.svd(B_eff, full_matrices=False)
    s_thr = np.maximum(s - step_eff * lam1, 0.0)
    if np.all(s_thr == 0.0):
        return np.zeros_like(B)
    return (U * s_thr) @ Vt


def objective_F(B: np.ndarray, B_true: np.ndarray, lam1: float, lam2: float) -> float:
    """F(B) = 0.5||B - B_true||^2 + (lam2/2)||B||^2 + lam1||B||_*."""
    R = B - B_true
    smooth = 0.5 * float(np.sum(R * R)) + 0.5 * lam2 * float(np.sum(B * B))
    s = np.linalg.svd(B, compute_uv=False)
    nuc = float(np.sum(s))
    return smooth + lam1 * nuc


class TestRankAdaptiveBMSEN(unittest.TestCase):
    def setUp(self) -> None:
        self.rng = np.random.default_rng(0)

    def test_rank_adaptive_quadratic_sen(self) -> None:
        rng = self.rng
        m, n = 30, 20
        rank_true = 3

        # 构造一个低秩且整体尺度较小的 B_true
        U0 = rng.standard_normal(size=(m, rank_true))
        V0 = rng.standard_normal(size=(n, rank_true))
        B_true = 0.1 * (U0 @ V0.T)   # 注意这里的 0.1 缩放

        # 让核范数惩罚相对强一些，便于满足谱范数 KKT 条件
        lam1 = 0.5
        lam2 = 0.1

        # 1) B 空间的“真解” via SEN prox（只在测试中用 SVD）
        B_star = sen_prox(B_true, lam1=lam1, lam2=lam2, step=1.0)
        F_star = objective_F(B_star, B_true, lam1, lam2)

        # 2) Rank-adaptive BM-SEN 求解
        opts = RankAdaptiveOptions(
            max_rank=6,        # 略大于 rank_true
            eps_cert=1e-2,     # 证书松弛
            alpha_init=1e-2,
        )
        # 内层 L-BFGS 设置：容差稍微放宽一点
        opts.lbfgs_options.tol_grad = 1e-5
        opts.lbfgs_options.max_iter = 300
        opts.lbfgs_options.verbose = True  # 调试时可以先设 True 看内层情况

        B_hat, info = rank_adaptive_bm_sen_quadratic(B_true, lam1, lam2, options=opts)

        F_hat = objective_F(B_hat, B_true, lam1, lam2)
        rel_obj = abs(F_hat - F_star) / max(1.0, abs(F_star))
        rel_err = np.linalg.norm(B_hat - B_star) / max(1.0, np.linalg.norm(B_star))

        # 3) 证书条件：sigma1(G) <= lam1(1+eps) 应该成立
        self.assertTrue(info["converged"], msg=f"Not converged, info={info}")
        sigma_cert = info["certificate_sigma"]
        self.assertLessEqual(
            sigma_cert,
            lam1 * (1.0 + opts.eps_cert) + 1e-6,
            msg=f"sigma1(G)={sigma_cert}, lam1={lam1}",
        )

        # 4) 目标值应与凸问题解足够接近
        # 阈值可以根据实际数值情况微调（这里先取 1e-2 量级）
        self.assertLess(
            rel_obj,
            1e-2,
            msg=f"Relative objective gap too large: {rel_obj:.3e}",
        )

        # 5) 解本身也应接近（给得稍微宽松一点）
        self.assertLess(
            rel_err,
            5e-2,
            msg=f"Relative solution error too large: {rel_err:.3e}",
        )

        # 6) rank 行为：不超过 max_rank，且至少增大到 >1
        self.assertLessEqual(info["rank"], opts.max_rank)
        self.assertGreaterEqual(info["rank"], 1)


if __name__ == "__main__":
    unittest.main()
