import torch
import unittest
from unittest.mock import MagicMock
from parallel_llm.inference.generator import ParallelGenerator, GenerationConfig

class TestVocabOverflow(unittest.TestCase):
    """Check that ``ParallelGenerator.generate`` tolerates out-of-vocab prompt ids.

    The model is a ``MagicMock`` spec'd as ``torch.nn.Module``: only attributes
    that exist on ``nn.Module`` or that are assigned explicitly in ``setUp`` can
    be read from it.
    """

    # Deliberately tiny vocabulary so an out-of-range token id is easy to build.
    VOCAB_SIZE = 10

    def setUp(self):
        """Build a mocked model and a minimal ``GenerationConfig``."""
        self.model = MagicMock(spec=torch.nn.Module)

        # Return a fresh iterator on every call. A single ``iter(...)`` stored as
        # ``return_value`` is exhausted after the first ``parameters()`` call,
        # so a second ``next(model.parameters())`` would raise StopIteration.
        fake_param = MagicMock(device=torch.device("cpu"))
        self.model.parameters.side_effect = lambda *args, **kwargs: iter([fake_param])

        # ``config`` is not an attribute of ``torch.nn.Module``, so with
        # ``spec=`` *reading* ``self.model.config`` raises AttributeError.
        # Assigning it explicitly is allowed (``spec`` restricts reads, only
        # ``spec_set`` restricts writes) and makes the setup below work.
        self.model.config = MagicMock()
        self.model.config.num_hidden_layers = 1
        self.model.config.num_attention_heads = 1
        self.model.config.hidden_size = 16
        self.model.config.vocab_size = self.VOCAB_SIZE  # Small vocab
        self.model.mask_token_id = 0
        self.config = GenerationConfig(batch_size=1, use_torch_compile=False)

    def test_vocab_overflow_warning(self):
        """``generate`` must not raise ValueError for a prompt id >= vocab_size.

        Only the absence of ValueError is asserted. Other exceptions raised
        further down the heavily mocked generation path are printed and ignored.
        """
        generator = ParallelGenerator(
            model=self.model,
            config=self.config,
            use_kv_cache=False,
            use_cuda_graphs=False,
        )

        max_new_tokens = 1
        # Token id 15 lies outside [0, VOCAB_SIZE).
        out_of_vocab_id = self.VOCAB_SIZE + 5
        input_ids = torch.tensor([[1, out_of_vocab_id]], dtype=torch.long)
        prompt_len = input_ids.shape[1]

        # Dummy logits shaped [batch, prompt_len + max_new_tokens, vocab_size],
        # consistent with the max_new_tokens actually passed to generate().
        self.model.return_value = torch.randn(
            1, prompt_len + max_new_tokens, self.VOCAB_SIZE
        )

        print("\nAttempting to generate with out-of-bounds token...")
        try:
            # This should NOT raise ValueError now. It might still fail later
            # because the model is mocked; only the initial check matters here.
            generator.generate(input_ids, max_new_tokens=max_new_tokens)
            print("SUCCESS: Generation proceeded without ValueError")
        except ValueError as e:
            self.fail(f"ValueError was raised: {e}")
        except Exception as e:
            # Other errors are expected since we are mocking heavily, but ValueError is what we check
            print(f"Caught expected downstream error (ignored): {e}")

if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this script.
    unittest.main(module="__main__")
