# test_smlr.py
import unittest
import numpy as np

# test_smlr.py
import unittest
import numpy as np

# 根据你的工程结构调整这行 import
# 假设 SMLR 定义在包 smlr 下的 smlr.py 中：
from smlr.smlr import SMLR


class TestSMLRNoPenalty(unittest.TestCase):
    """
    测试 1：无 penalty + 只有截距 + 3 类各一个样本。

    理论结果（我们前面手推过）：
        - X = [[1], [1], [1]]（通过 fit_intercept 实现）
        - y = [0, 1, 2]
        - 最优解 β = 0
        - 每一行预测概率 = (1/3, 1/3, 1/3)
        - NLL = -log(1/3)
    """

    def test_intercept_only_three_classes_uniform(self) -> None:
        # 构造 X: 传入空特征矩阵，fit_intercept=True 会自动加一列 1
        n = 3
        X = np.empty((n, 0), dtype=float)
        y = np.array([0, 1, 2], dtype=int)

        model = SMLR(
            penalty=None,
            lam=0.0,
            solver="fista",
            fit_intercept=True,
        )

        model.fit(
            X,
            y,
            lr=0.1,           # 步长选得稍微大一点
            tol=1e-10,
            max_iter=5000,
            verbose=False,
        )

        # 预测概率（用同一个 X）
        probs = model.predict_proba(X)

        # 1) 形状检查
        self.assertEqual(probs.shape, (n, 3))  # 3 类

        # 2) 每一行概率之和接近 1
        row_sums = probs.sum(axis=1)
        self.assertTrue(np.allclose(row_sums, 1.0, atol=1e-6))

        # 3) 每一行应接近 (1/3, 1/3, 1/3)
        target = np.full_like(probs, 1.0 / 3.0)
        self.assertTrue(np.allclose(probs, target, atol=1e-3))

        # 4) NLL 应接近 -log(1/3)
        nll = -np.mean(np.log(probs[np.arange(n), y]))
        self.assertAlmostEqual(nll, -np.log(1.0 / 3.0), places=3)

        # 5) β 的范数应当不大（应收敛到 0 附近）
        beta = model.beta
        self.assertIsNotNone(beta)
        beta_norm = float(np.linalg.norm(beta))
        self.assertLess(
            beta_norm,
            1e-1,
            msg=f"beta norm too large: {beta_norm}",
        )


class TestSMLRWithFISTA(unittest.TestCase):
    """
    一些基础的 FISTA + penalty 的 smoke test，检查 shape 和数值合法性。
    """

    def test_random_data_l2_penalty_shapes_and_probs(self) -> None:
        rng = np.random.default_rng(0)
        n, d, k = 100, 5, 3

        X = rng.standard_normal(size=(n, d))
        y = rng.integers(low=0, high=k, size=n)

        model = SMLR(
            penalty="l2",
            lam=0.1,
            solver="fista",
            fit_intercept=True,
        )

        model.fit(
            X,
            y,
            lr=0.05,
            tol=1e-6,
            max_iter=2000,
            verbose=False,
        )

        # 预测
        X_test = rng.standard_normal(size=(20, d))
        probs = model.predict_proba(X_test)
        y_pred = model.predict(X_test)

        # 1) 形状检查
        self.assertEqual(probs.shape, (20, k))
        self.assertEqual(y_pred.shape, (20,))

        # 2) 每一行概率之和 ≈ 1，且非负
        row_sums = probs.sum(axis=1)
        self.assertTrue(np.allclose(row_sums, 1.0, atol=1e-6))
        self.assertTrue(np.all(probs >= -1e-8))  # 容忍一点数值误差

        # 3) 预测类别在 0..k-1
        self.assertTrue(np.all((y_pred >= 0) & (y_pred < k)))


class TestSMLRWithSENBM(unittest.TestCase):
    """
    测试 SEN + BM 因子化 backend 的基本集成情况：
        - 能正常跑通
        - 参数尺寸正确
        - 预测概率合法
    """

    def test_sen_bm_runs_and_shapes(self) -> None:
        rng = np.random.default_rng(1)
        n, d, k = 50, 4, 3

        X = rng.standard_normal(size=(n, d))
        y = rng.integers(low=0, high=k, size=n)

        model = SMLR(
            penalty="sen",
            lam=0.5,          # λ1 (核范数部分)
            sen_l2=0.1,       # λ2 (Frobenius 部分)
            solver="bm",
            fit_intercept=True,
            rank=2,           # 因子秩
        )

        model.fit(
            X,
            y,
            tol=1e-5,
            max_iter=3000,
            verbose=False,
        )

        # 1) 基本结构检查
        self.assertTrue(model._is_fit)
        self.assertIsNotNone(model.beta)
        self.assertIsNotNone(model.W)
        self.assertIsNotNone(model.k)

        beta = model.beta
        k = model.k
        self.assertIsInstance(beta, np.ndarray)
        # β 的第二维应该是 k-1
        self.assertEqual(beta.shape[1], k - 1)

        # 2) 预测概率检查
        X_test = rng.standard_normal(size=(10, d))
        probs = model.predict_proba(X_test)
        y_pred = model.predict(X_test)

        self.assertEqual(probs.shape, (10, k))
        self.assertEqual(y_pred.shape, (10,))

        row_sums = probs.sum(axis=1)
        self.assertTrue(np.allclose(row_sums, 1.0, atol=1e-6))
        self.assertTrue(np.all((y_pred >= 0) & (y_pred < k)))


if __name__ == "__main__":
    unittest.main()

