import torch
import unittest
from unittest.mock import MagicMock, patch
from parallel_llm.inference.generator import ParallelGenerator, GenerationConfig

class TestGPUCompatibility(unittest.TestCase):
    """Check that ParallelGenerator gates ``torch.compile`` on CUDA compute capability.

    ``torch.cuda.is_available``, ``torch.cuda.get_device_capability`` and
    ``torch.compile`` are patched in each test, so no physical GPU is required.
    """

    def setUp(self):
        # With ``spec=torch.nn.Module``, *reading* an attribute that nn.Module
        # does not define (such as ``config``) raises AttributeError, so
        # ``config`` must be assigned explicitly instead of being auto-created
        # through chained attribute access.
        self.model = MagicMock(spec=torch.nn.Module)

        # Return a fresh iterator on every parameters() call; a single shared
        # ``iter([...])`` would be exhausted after the first call.
        cuda_param = MagicMock(device=torch.device("cuda"))
        self.model.parameters.side_effect = lambda *args, **kwargs: iter([cuda_param])

        self.model.config = MagicMock(
            num_hidden_layers=2,
            num_attention_heads=2,
            hidden_size=64,
            vocab_size=100,
        )
        self.config = GenerationConfig(use_torch_compile=True, batch_size=1)

    def _build_generator(self):
        """Construct a ParallelGenerator around the mock model.

        KV cache and CUDA graphs are disabled so the tests isolate the
        ``torch.compile`` decision.
        """
        return ParallelGenerator(
            model=self.model,
            config=self.config,
            use_kv_cache=False,
            use_cuda_graphs=False,
        )

    # Decorators apply bottom-up, so the mocks arrive in the order:
    # torch.compile, get_device_capability, is_available.
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.get_device_capability", return_value=(6, 0))
    @patch("torch.compile")
    def test_old_gpu_skips_compile(self, mock_compile, mock_get_cap, mock_is_available):
        """On compute capability 6.0, torch.compile must not be invoked."""
        self._build_generator()
        mock_compile.assert_not_called()

    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.get_device_capability", return_value=(7, 0))
    @patch("torch.compile")
    def test_new_gpu_uses_compile(self, mock_compile, mock_get_cap, mock_is_available):
        """On compute capability 7.0, torch.compile must be invoked exactly once."""
        self._build_generator()
        mock_compile.assert_called_once()

if __name__ == "__main__":
    # Executing this file directly runs its TestCase classes through
    # unittest's command-line runner; "__main__" is the runner's default module.
    unittest.main(module="__main__")
