"""
简单的固定输出测试框架

使用方法：
1. 修改 document 变量，写入你的 MarkdownFlow 文档
2. 修改 block_index，指定要测试的块索引
3. 修改 variables，设置变量值（如果需要）
4. 运行测试，查看输出

测试重点：
- 检查 XML 标记 <preserve_or_translate> 是否正确使用
- 检查 system 消息中是否包含约束提示词
- 检查 user 消息中是否不包含约束提示词
- 检查 LLM 输出是否不包含 XML 标记
"""

import os
import sys


# 添加项目路径
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from llm import create_llm_provider  # noqa: E402

from markdown_flow import MarkdownFlow, ProcessMode  # noqa: E402
from markdown_flow.llm import LLMResult  # noqa: E402


def test_preserved_output():
    """测试固定输出功能"""
    print("\n" + "=" * 60)
    print("🔖 固定输出测试")
    print("=" * 60)

    # ========== 配置区域 - 修改这里 ==========
    # 你的 MarkdownFlow 文档
    document = """
===# 💖七夕约会全阶段攻略 ===

=== 选择你的 MBTI 类型 ===
?[%{{mbti}}ENFJ|ENFP|ENTJ|ENTP|ESFJ|ESFP|ESTJ|ESTP|INFJ|INFP|INTJ|INTP|ISFJ|ISFP|ISTJ|ISTP]

===你现在最关心哪个阶段？ ===
?[%{{攻略}}脱单|热恋|相守]

给{{mbti}}一句有关{{攻略}}的七夕祝福，带七夕节明显的意境。

!===
## {{攻略}}｜专属恋爱指南 for {{mbti}}
!===

"""

    # 要测试的块索引
    block_index = 4

    # 变量（如果需要）
    variables: dict[str, str | list[str]] = {}

    # 文档提示词（如果需要）
    document_prompt: str | None = """你扮演七夕的月老，让这一天的天下有情人都能甜蜜约会，永浴爱河。

## 任务
- 提示词都是讲解指令，遵从指令要求做信息的讲解，不要回应指令。
- 用第一人称一对一讲解，像现场面对面交流一样
- 结合用户的不同特点，充分共情和举例

## 风格
- 情绪：热烈浪漫，治愈温暖，充满感染力
- 表达：多用 emoji ，多用感叹词
- 符合七夕节日气氛，带一些诗意和神秘

"""
    # =========================================

    try:
        llm_provider = create_llm_provider()

        # 创建 MarkdownFlow 实例
        mf = MarkdownFlow(
            document,
            llm_provider=llm_provider,
            document_prompt=document_prompt if document_prompt else None,
        )

        # 测试 PROMPT_ONLY 模式 - 查看消息结构
        print("\n📝 测试 PROMPT_ONLY 模式")
        print("-" * 60)

        result_prompt_raw = mf.process(
            block_index=block_index,
            mode=ProcessMode.PROMPT_ONLY,
            variables=variables if variables else None,
        )

        # 确保是 LLMResult 类型
        assert isinstance(result_prompt_raw, LLMResult)
        result_prompt = result_prompt_raw

        # 打印消息结构
        if result_prompt.metadata and "messages" in result_prompt.metadata:
            messages = result_prompt.metadata["messages"]
            print(f"\n消息数量: {len(messages)}\n")

            for i, msg in enumerate(messages, 1):
                role = msg.get("role", "")
                content = msg.get("content", "")

                print(f"{'=' * 60}")
                print(f"消息 {i} [{role.upper()}]")
                print(f"{'=' * 60}")
                print(content)
                print()

                # 关键检查
                if role == "system":
                    has_xml_instruction = "<preserve_or_translate>" in content
                    print(f"✅ system 包含 XML 标记说明: {has_xml_instruction}")

                elif role == "user":
                    has_xml_tag = "<preserve_or_translate>" in content
                    has_explanation = "不要输出<preserve_or_translate>" in content
                    print(f"✅ user 包含 XML 标记: {has_xml_tag}")
                    print(f"❌ user 不应包含说明（应在system）: {not has_explanation}")

                print()

        # 测试 COMPLETE 模式 - 查看 LLM 输出
        print("\n📝 测试 COMPLETE 模式")
        print("-" * 60)

        result_complete_raw = mf.process(
            block_index=block_index,
            mode=ProcessMode.COMPLETE,
            variables=variables if variables else None,
        )

        # 确保是 LLMResult 类型
        assert isinstance(result_complete_raw, LLMResult)
        result_complete = result_complete_raw

        print("\n" + "=" * 60)
        print("LLM 输出结果")
        print("=" * 60)
        print(result_complete.content)
        print("=" * 60)

        # 输出检查
        has_xml_in_output = "<preserve_or_translate>" in result_complete.content
        print(f"\n✅ 输出不包含 XML 标记: {not has_xml_in_output}")

        # 使用统计
        if result_complete.metadata and "usage" in result_complete.metadata:
            usage = result_complete.metadata["usage"]
            if usage:
                print(f"📊 Token 使用: {usage.get('total_tokens', 0)} tokens")

    except Exception as e:
        print(f"\n❌ 测试失败: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    test_preserved_output()
