"""
Inference configuration for Parallel-LLM
Defines configuration for parallel generation
"""
from dataclasses import dataclass
from typing import Optional, Literal


@dataclass
class InferenceConfig:
    """Inference configuration for parallel generation.

    Every field has a default, so ``InferenceConfig()`` is a valid
    configuration. Field values are checked once, at construction time, by
    ``__post_init__``; attributes assigned afterwards are not re-validated.

    Raises:
        ValueError: If a size/count field is below its minimum, if
            ``temperature`` is negative, if ``top_p`` is outside ``[0, 1]``,
            if ``repetition_penalty`` is not strictly positive, or if
            ``quantization`` is not ``None`` or one of the declared literals.
    """
    # Generation parameters
    max_new_tokens: int = 512  # validated: >= 0
    temperature: float = 1.0  # validated: >= 0
    # NOTE(review): top_k is not range-checked because the value used to mean
    # "disabled" is not visible from this file -- confirm with the consumer.
    top_k: int = 50
    top_p: float = 0.95  # validated: 0 <= top_p <= 1
    repetition_penalty: float = 1.0  # validated: > 0

    # Parallel generation
    num_parallel_tokens: int = 64  # Generate this many at once; validated: >= 1
    num_refinement_steps: int = 5  # validated: >= 0
    confidence_threshold: float = 0.9
    use_adaptive_refinement: bool = True

    # Batching
    batch_size: int = 1  # validated: >= 1
    use_continuous_batching: bool = True
    max_batch_size: int = 128  # validated: >= 1

    # KV cache
    use_paged_attention: bool = True
    block_size: int = 16  # validated: >= 1
    max_num_blocks: int = 2048  # validated: >= 1

    # Speculative decoding
    use_speculative_decoding: bool = False
    draft_model_path: Optional[str] = None
    num_speculative_tokens: int = 5  # validated: >= 0

    # Quantization
    # Literal[...] is only a static hint; __post_init__ enforces it at runtime.
    quantization: Optional[Literal["int8", "fp8", "int4"]] = None

    # Parallelism
    tensor_parallel_size: int = 1  # validated: >= 1
    pipeline_parallel_size: int = 1  # validated: >= 1

    # Performance
    use_torch_compile: bool = True
    use_cuda_graphs: bool = True

    def __post_init__(self) -> None:
        """Validate field values, raising ``ValueError`` on the first bad one."""
        # (field name, smallest accepted value) for every size/count field.
        lower_bounds = (
            ("max_new_tokens", 0),
            ("num_parallel_tokens", 1),
            ("num_refinement_steps", 0),
            ("batch_size", 1),
            ("max_batch_size", 1),
            ("block_size", 1),
            ("max_num_blocks", 1),
            ("num_speculative_tokens", 0),
            ("tensor_parallel_size", 1),
            ("pipeline_parallel_size", 1),
        )
        for field_name, minimum in lower_bounds:
            value = getattr(self, field_name)
            if value < minimum:
                raise ValueError(
                    f"{field_name} must be >= {minimum}, got {value!r}"
                )

        # The float checks are written as ``not (in range)`` so that NaN,
        # which compares False against everything, is rejected as well.
        if not self.temperature >= 0:
            raise ValueError(
                f"temperature must be >= 0, got {self.temperature!r}"
            )
        if not 0 <= self.top_p <= 1:
            raise ValueError(
                f"top_p must be within [0, 1], got {self.top_p!r}"
            )
        if not self.repetition_penalty > 0:
            raise ValueError(
                "repetition_penalty must be > 0, "
                f"got {self.repetition_penalty!r}"
            )

        # Mirror the Literal[...] annotation, which Python does not enforce.
        allowed_quantization = ("int8", "fp8", "int4")
        if (
            self.quantization is not None
            and self.quantization not in allowed_quantization
        ):
            raise ValueError(
                f"quantization must be None or one of {allowed_quantization}, "
                f"got {self.quantization!r}"
            )


# Backward-compatibility alias: GenerationConfig is bound to the same class
# object as InferenceConfig, so both names construct identical instances.
GenerationConfig = InferenceConfig
