"""
Model configuration for Parallel-LLM
Defines configuration for transformer models
"""
from dataclasses import dataclass
from typing import Optional, Literal
import torch


@dataclass
class ModelConfig:
    """Base model configuration.

    A plain container of hyperparameters. ``__post_init__`` rejects
    combinations that cannot describe a consistent model (for example a
    hidden size that does not split evenly across attention heads), so the
    error surfaces when the config is built rather than deep inside model
    construction.

    Raises:
        ValueError: if any cross-field invariant checked in
            ``__post_init__`` is violated.
    """
    # Architecture
    vocab_size: int = 50257
    hidden_size: int = 2048
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    # For GQA. None presumably means "one KV head per query head"
    # (standard multi-head attention) -- confirm against the attention module.
    num_key_value_heads: Optional[int] = None
    intermediate_size: int = 8192
    max_position_embeddings: int = 4096

    # Diffusion parameters
    num_diffusion_steps: int = 10
    noise_schedule: Literal["linear", "cosine", "sqrt"] = "cosine"
    self_condition: bool = True
    confidence_threshold: float = 0.9

    # Energy-based model
    use_energy_model: bool = True
    energy_hidden_size: int = 4096
    energy_num_layers: int = 4

    # Attention
    use_flash_attention: bool = True
    use_sliding_window: bool = False
    # Only consulted when use_sliding_window is True; must then be a positive int.
    sliding_window_size: Optional[int] = None

    # Normalization and activations
    rms_norm_eps: float = 1e-6
    hidden_act: str = "silu"

    # Dropout (probabilities, validated to lie in [0, 1])
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0

    # Precision
    dtype: torch.dtype = torch.bfloat16
    use_fp8: bool = False

    # Initialization
    initializer_range: float = 0.02

    def __post_init__(self) -> None:
        """Validate cross-field invariants; raise ValueError on violation."""
        if self.num_attention_heads <= 0:
            raise ValueError(
                f"num_attention_heads must be positive, got {self.num_attention_heads}"
            )
        # Each head gets hidden_size // num_attention_heads dimensions, so the
        # split must be exact.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        # GQA shares each KV head across a whole group of query heads, which
        # requires the query-head count to be a multiple of the KV-head count.
        if self.num_key_value_heads is not None and (
            self.num_key_value_heads <= 0
            or self.num_attention_heads % self.num_key_value_heads != 0
        ):
            raise ValueError(
                f"num_key_value_heads ({self.num_key_value_heads}) must be a "
                f"positive divisor of num_attention_heads ({self.num_attention_heads})"
            )
        if self.use_sliding_window and (
            self.sliding_window_size is None or self.sliding_window_size <= 0
        ):
            raise ValueError(
                "use_sliding_window=True requires a positive sliding_window_size, "
                f"got {self.sliding_window_size}"
            )
        # Literal annotations are not enforced at runtime; mirror them here.
        if self.noise_schedule not in ("linear", "cosine", "sqrt"):
            raise ValueError(f"Unknown noise_schedule: {self.noise_schedule!r}")
        for name in ("attention_dropout", "hidden_dropout"):
            prob = getattr(self, name)
            if not 0.0 <= prob <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {prob}")


@dataclass
class MultimodalConfig(ModelConfig):
    """Configuration for multimodal models.

    Extends the base language-model configuration with vision-encoder,
    fusion, and contrastive-learning settings. ``__post_init__`` runs any
    base-class validation first, then checks the multimodal fields.

    Raises:
        ValueError: if a multimodal invariant checked in ``__post_init__``
            is violated (or if base-class validation fails).
    """
    # Vision encoder
    vision_encoder: Literal["vit", "clip", "siglip"] = "vit"
    image_size: int = 224
    patch_size: int = 16
    num_channels: int = 3
    vision_hidden_size: int = 1024
    vision_num_layers: int = 24
    vision_num_heads: int = 16

    # Fusion
    fusion_type: Literal["cross_attention", "perceiver", "moe"] = "cross_attention"
    num_cross_attention_layers: int = 4

    # Contrastive learning
    use_contrastive: bool = True
    contrastive_temperature: float = 0.07
    contrastive_dim: int = 512

    def __post_init__(self) -> None:
        """Run base-class checks, then validate vision/fusion settings."""
        # Chain to the parent's __post_init__ only if it defines one, so the
        # base invariants still apply to multimodal configs.
        parent_post_init = getattr(super(), "__post_init__", None)
        if parent_post_init is not None:
            parent_post_init()
        # A non-divisible image would leave border pixels outside every patch.
        if self.patch_size <= 0 or self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by a positive "
                f"patch_size ({self.patch_size})"
            )
        if self.vision_num_heads <= 0 or self.vision_hidden_size % self.vision_num_heads != 0:
            raise ValueError(
                f"vision_hidden_size ({self.vision_hidden_size}) must be divisible "
                f"by a positive vision_num_heads ({self.vision_num_heads})"
            )
        # Literal annotations are not enforced at runtime; mirror them here.
        if self.vision_encoder not in ("vit", "clip", "siglip"):
            raise ValueError(f"Unknown vision_encoder: {self.vision_encoder!r}")
        if self.fusion_type not in ("cross_attention", "perceiver", "moe"):
            raise ValueError(f"Unknown fusion_type: {self.fusion_type!r}")
        # A softmax temperature divides the logits, so it must be strictly positive.
        if self.contrastive_temperature <= 0:
            raise ValueError(
                f"contrastive_temperature must be positive, got {self.contrastive_temperature}"
            )


def get_default_config(model_type: Literal["unimodal", "multimodal"] = "unimodal"):
    """Build a configuration populated entirely with default values.

    Args:
        model_type: ``"unimodal"`` for a plain ``ModelConfig`` or
            ``"multimodal"`` for a ``MultimodalConfig``.

    Returns:
        A freshly constructed config instance of the requested kind.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    # Compared with ``==`` (not a dict lookup) so any unrecognised value,
    # hashable or not, falls through to the ValueError below.
    known_types = (
        ("unimodal", ModelConfig),
        ("multimodal", MultimodalConfig),
    )
    for type_name, config_cls in known_types:
        if model_type == type_name:
            return config_cls()
    raise ValueError(f"Unknown model type: {model_type}")
