"""
==============================================================
Description  : futures.py模块测试代码
Develop      : VSCode
Author       : sandorn sandorn@live.cn
Date         : 2024-08-21 14:27:52
==============================================================
"""

from __future__ import annotations

import time
import unittest

import pytest


class TestBaseThreadRunner(unittest.TestCase):
    """测试BaseThreadRunner类的功能"""

    def setUp(self):
        """测试前准备"""
        from xtthread.futures import BaseThreadRunner

        self.BaseThreadRunner = BaseThreadRunner

    def test_basic_functionality(self):
        """测试基本的任务提交和执行"""

        def worker(task_id):
            time.sleep(0.1)
            return f'Task {task_id} completed'

        pool = self.BaseThreadRunner(max_workers=3)

        # 提交5个任务
        _ = [pool.submit(worker, i) for i in range(5)]

        # 等待所有线程完成并获取结果
        thread_results = pool.get_results()

        assert len(thread_results) == 5
        for i, result in enumerate(thread_results):
            assert result == f'Task {i} completed'

    def test_thread_limit(self):
        """测试线程数量限制"""
        import threading

        active_threads = set()
        max_count = [0]
        lock = threading.Lock()

        def worker(task_id):
            with lock:
                active_threads.add(task_id)
                current_active = len([t for t in threads if t.is_alive()])
                max_count[0] = max(max_count[0], current_active)
            time.sleep(0.3)  # 增加睡眠时间确保并发

        pool = self.BaseThreadRunner(max_workers=2)
        threads = []

        # 提交4个任务
        for i in range(4):
            thread = pool.submit(worker, i)
            threads.append(thread)

        # 等待所有线程完成
        for thread in threads:
            thread.join()

        # 验证最大线程数不超过限制（可能为1或2，取决于执行时机）
        assert max_count[0] <= 2
        assert set(active_threads) == set(range(4))

    def test_shutdown(self):
        """测试关闭功能"""
        pool = self.BaseThreadRunner(max_workers=2)

        def worker():
            time.sleep(0.1)
            return 'done'

        # 提交任务
        thread = pool.submit(worker)
        pool.shutdown(wait=True)

        # 验证任务完成
        assert thread.get_result() == 'done'


class TestEnhancedThreadPool(unittest.TestCase):
    """测试EnhancedThreadPool类的功能"""

    def setUp(self):
        """测试前准备"""
        from xtthread.futures import EnhancedThreadPool

        self.EnhancedThreadPool = EnhancedThreadPool

    def test_basic_functionality(self):
        """测试基本的任务提交和执行"""

        def worker(task_id):
            time.sleep(0.1)
            return f'Task {task_id} completed'

        pool = self.EnhancedThreadPool(max_workers=3)
        futures = []

        # 提交5个任务
        for i in range(5):
            future = pool.submit_task(worker, i)
            futures.append(future)

        # 等待所有任务完成并获取结果
        results = pool.wait_all_completed()

        # 验证结果数量和内容
        assert len(results) == 5
        success_results = [result.result for result in results if result.success]
        assert len(success_results) == 5
        for i in range(5):
            assert f'Task {i} completed' in success_results

    def test_batch_submit_simple(self):
        """测试批量提交简单格式的任务"""

        def worker(task_id):
            time.sleep(0.1)
            return f'Task {task_id} completed'

        pool = self.EnhancedThreadPool()
        # 简单格式: [item1, item2, ...]
        pool.submit_tasks(worker, list(range(5)))

        # 等待所有任务完成并获取结果
        results = pool.wait_all_completed()

        # 验证结果
        assert len(results) == 5
        success_results = [result.result for result in results if result.success]
        assert len(success_results) == 5

    def test_batch_submit_complex(self):
        """测试批量提交复杂格式的任务"""

        def worker(task_id, delay=0.1):
            time.sleep(delay)
            return f'Task {task_id} completed with delay {delay}'

        pool = self.EnhancedThreadPool()
        # 复杂格式: [(args_tuple, kwargs_dict), ...]
        iterables = [((1,), {'delay': 0.05}), ((2,), {'delay': 0.1}), ((3,), {'delay': 0.15})]
        pool.submit_tasks(worker, iterables)

        # 等待所有任务完成并获取结果
        results = pool.wait_all_completed()

        # 验证结果
        assert len(results) == 3
        success_results = [result.result for result in results if result.success]
        assert len(success_results) == 3
        for i in range(1, 4):
            assert any(f'Task {i} completed' in result for result in success_results)

    def test_exception_handling(self):
        """测试异常处理功能"""

        def worker(task_id):
            if task_id % 2 == 0:
                raise ValueError(f'Error in task {task_id}')
            return f'Task {task_id} completed'

        pool = self.EnhancedThreadPool()
        futures = []
        for i in range(5):
            future = pool.submit_task(worker, i)
            futures.append(future)

        # 等待所有任务完成并获取结果
        results = pool.wait_all_completed()
        print(44444444444444444444, [type(res.error).__name__ for res in results if not res.success])

        # 验证结果
        assert len(results) == 5
        success_count = sum(1 for result in results if result.success)
        error_count = sum(1 for result in results if not result.success)
        assert success_count == 2  # 任务1, 3 应该成功
        assert error_count == 3  # 任务0, 2, 4 应该失败

    def test_context_manager(self):
        """测试上下文管理器功能"""

        def worker(task_id):
            time.sleep(0.1)
            return f'Task {task_id} completed'

        results = []
        pool = self.EnhancedThreadPool()
        futures = []

        with pool:
            for i in range(5):
                future = pool.submit_task(worker, i)
                futures.append(future)

            # 等待所有任务完成并获取结果
            results = pool.wait_all_completed()

        # 验证结果
        assert len(results) == 5
        success_results = [result.result for result in results if result.success]
        assert len(success_results) == 5

    def test_shutdown(self):
        """测试关闭功能"""

        def worker():
            time.sleep(0.1)
            return 'done'

        pool = self.EnhancedThreadPool(max_workers=2)

        # 提交任务
        pool.submit_task(worker)
        pool.shutdown(wait=True)

        # 验证线程池已关闭（通过尝试再次提交任务应引发异常）
        with pytest.raises(RuntimeError):
            pool.submit_task(worker)

    def test_auto_worker_calculation(self):
        """测试自动计算工作线程数"""
        import os

        # 测试默认参数时的自动计算
        pool = self.EnhancedThreadPool()
        base_workers = os.cpu_count() or 4
        expected_workers = base_workers * 4
        assert pool.executor._max_workers == expected_workers

        # 测试指定max_workers时的行为
        pool = self.EnhancedThreadPool(max_workers=10)
        assert pool.executor._max_workers == 10

    def test_ordered_results(self):
        """测试EnhancedThreadPool的结果顺序"""

        def ordered_worker(task_id: int) -> int:
            """按顺序返回任务ID的任务，添加随机等待时间"""
            import random

            time.sleep(random.uniform(0.01, 0.1))  # 随机等待时间，增加测试难度
            return task_id

        # 测试1: 单个任务提交的顺序
        pool1 = self.EnhancedThreadPool(max_workers=3)
        for i in range(10):
            pool1.submit_task(ordered_worker, i)

        results1 = pool1.wait_all_completed()

        # 检查结果顺序
        expected_order = list(range(10))
        actual_order = [result.result for result in results1 if result.success]

        assert actual_order == expected_order
        assert len(results1) == 10

        # 测试2: 批量任务提交的顺序
        pool2 = self.EnhancedThreadPool(max_workers=3)
        tasks = list(range(10))
        pool2.submit_tasks(ordered_worker, tasks)

        results2 = pool2.wait_all_completed()

        # 检查结果顺序
        actual_order2 = [result.result for result in results2 if result.success]

        assert actual_order2 == expected_order
        assert len(results2) == 10

    def test_wait_all_completed(self):
        """测试wait_all_completed方法的功能"""

        # 测试1: 基本功能 - 所有任务正常完成
        def worker1(task_id):
            time.sleep(0.1)
            return f'Task {task_id} completed'

        pool1 = self.EnhancedThreadPool(max_workers=3)
        pool1.submit_tasks(worker1, list(range(5)))

        # 等待所有任务完成并获取结果
        results1 = pool1.wait_all_completed()
        assert len(results1) == 5
        success_count1 = sum(1 for result in results1 if result.success)
        assert success_count1 == 5

        # 测试2: 超时处理
        def worker2(task_id):
            time.sleep(0.3)  # 故意设置较长的睡眠时间
            return f'Task {task_id} completed'

        pool2 = self.EnhancedThreadPool(max_workers=1)  # 限制为1个工作线程确保任务排队
        # 提交任务并保存返回的future列表用于验证
        futures = pool2.submit_tasks(worker2, list(range(3)))
        assert len(futures) == 3

        # 设置较短的超时时间，应该只能获取到部分已完成的任务结果
        start_time = time.time()
        results2 = pool2.wait_all_completed(timeout=0.4)
        elapsed_time = time.time() - start_time

        # 确保在超时时间内返回
        assert elapsed_time < 0.5

        # 继续等待剩余任务完成，增加等待时间和重试次数
        remaining_results = []
        max_wait_time = 2.0  # 最多等待2秒
        wait_start_time = time.time()

        while pool2._future_tasks and (time.time() - wait_start_time) < max_wait_time:
            new_results = pool2.wait_all_completed(timeout=0.5)  # 增加超时时间
            remaining_results.extend(new_results)
            time.sleep(0.1)  # 增加等待间隔

        # 如果还有未完成的任务，强制等待它们完成
        if pool2._future_tasks:
            for future in pool2._future_tasks:
                future.result()  # 阻塞等待每个任务完成
            final_results = pool2.wait_all_completed()
            remaining_results.extend(final_results)

        # 验证总共获取到了所有任务的结果
        total_results = len(results2) + len(remaining_results)
        assert total_results == 3

        # 测试3: 异常处理
        def worker3(task_id):
            if task_id % 2 == 0:
                raise ValueError(f'Error in task {task_id}')
            return f'Task {task_id} completed'

        pool3 = self.EnhancedThreadPool()
        pool3.submit_tasks(worker3, list(range(5)))

        # 等待所有任务完成并获取结果
        results3 = pool3.wait_all_completed()
        print(44444444444444444444, [type(res.error).__name__ for res in results3])
        assert len(results3) == 5
        success_count3 = sum(1 for result in results3 if result.success)
        error_count3 = sum(1 for result in results3 if not result.success)
        assert success_count3 == 2  # 任务1, 3 应该成功
        assert error_count3 == 3  # 任务0, 2, 4 应该失败


if __name__ == '__main__':
    # 运行测试
    unittest.main(verbosity=2)
