"""
Training configuration for Parallel-LLM
Defines configuration for distributed training
"""
from dataclasses import dataclass
from typing import Optional, Literal


@dataclass
class TrainingConfig:
    """Training configuration with distributed support.

    A plain container of hyperparameters and runtime switches, grouped by
    concern: optimizer, LR schedule, mixed precision, parallelism sizes, FSDP,
    DeepSpeed ZeRO, gradient checkpointing, compilation, logging, monitoring
    and checkpoint output. The code that consumes these fields lives outside
    this module.

    ``__post_init__`` validates what the type annotations alone cannot enforce
    at runtime:

    * every ``Literal``-annotated field must hold one of its listed choices;
    * ``batch_size`` and the data/tensor/pipeline parallel sizes must be >= 1;
    * ``warmup_steps`` must be >= 0.

    NOTE(review): ``use_fsdp`` and ``use_deepspeed`` are independent flags and
    may both be True; confirm how the trainer resolves that combination.
    """
    # Basic training (optimizer hyperparameters; names suggest Adam/AdamW)
    batch_size: int = 8
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0  # presumably the gradient-clipping norm

    # Schedule
    num_train_steps: int = 100000
    warmup_steps: int = 2000
    lr_scheduler: Literal["cosine", "linear", "constant"] = "cosine"

    # Mixed precision ("no" disables it)
    mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "bf16"

    # Distributed training
    distributed_backend: Literal["nccl", "gloo"] = "nccl"
    data_parallel_size: int = 1
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1

    # FSDP
    use_fsdp: bool = True
    fsdp_sharding_strategy: Literal["full", "shard_grad_op", "no_shard"] = "full"
    fsdp_backward_prefetch: bool = True
    fsdp_forward_prefetch: bool = True

    # DeepSpeed ZeRO
    use_deepspeed: bool = False
    zero_stage: Literal[0, 1, 2, 3] = 3
    zero_offload_optimizer: bool = False
    zero_offload_params: bool = False

    # Gradient checkpointing
    gradient_checkpointing: bool = True
    gradient_checkpointing_policy: Literal["full", "selective"] = "selective"

    # Compilation
    use_torch_compile: bool = True
    torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] = "max-autotune"
    use_cuda_graphs: bool = True

    # Logging (intervals, in steps)
    logging_steps: int = 10
    eval_steps: int = 1000
    save_steps: int = 5000
    save_total_limit: int = 3

    # Monitoring
    use_wandb: bool = True
    wandb_project: str = "parallel-llm"

    # Checkpointing
    output_dir: str = "./checkpoints"
    resume_from_checkpoint: Optional[str] = None  # None = start fresh

    def __post_init__(self) -> None:
        """Validate field values the annotations alone do not enforce.

        Raises:
            ValueError: if a ``Literal``-annotated field holds a value outside
                its declared choices, if ``batch_size`` or any parallel size
                is < 1, or if ``warmup_steps`` is negative.
        """
        # Local imports keep the module's top-level import block unchanged.
        from dataclasses import fields
        from typing import get_args, get_origin, get_type_hints

        # dataclasses do not check Literal[...] at runtime, so a typo such as
        # mixed_precision="bf-16" would otherwise reach the trainer silently.
        # get_type_hints resolves string annotations (e.g. in subclasses).
        hints = get_type_hints(type(self))
        for f in fields(self):
            hint = hints.get(f.name)
            if get_origin(hint) is Literal:
                allowed = get_args(hint)
                value = getattr(self, f.name)
                if value not in allowed:
                    raise ValueError(
                        f"{f.name} must be one of {allowed!r}, got {value!r}"
                    )

        for name in (
            "batch_size",
            "data_parallel_size",
            "tensor_parallel_size",
            "pipeline_parallel_size",
        ):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value!r}")

        if self.warmup_steps < 0:
            raise ValueError(
                f"warmup_steps must be >= 0, got {self.warmup_steps!r}"
            )
