from pydantic import BaseModel, Field, field_validator, model_validator
from typing import List, Optional
import ipaddress

# ------------------------------------------------------------------------------
# SHARED VALIDATORS
# ------------------------------------------------------------------------------
def validate_cidr_format(v: str) -> str:
    """Strictly validate an IPv4 network written in CIDR notation.

    The value must be a string of the form ``a.b.c.d/prefix`` with a decimal
    prefix length. ``ipaddress.IPv4Network`` runs in its default strict mode,
    so host bits must be zero (``10.0.0.1/16`` is rejected). ``0.0.0.0/0`` is
    a valid network and needs no special case.

    Args:
        v: Candidate CIDR string.

    Returns:
        ``v`` unchanged.

    Raises:
        ValueError: If ``v`` is not a string in IPv4 CIDR notation.
    """
    error = f"'{v}' is not a valid IPv4 CIDR (e.g., 10.0.0.0/16)"

    # IPv4Network on its own also accepts ints, bare addresses ("10.0.0.0" is
    # treated as /32) and dotted netmasks ("10.0.0.0/255.255.0.0"). None of
    # these are CIDR strings, so reject them before parsing.
    if not isinstance(v, str):
        raise ValueError(error)
    _, slash, prefix = v.partition("/")
    if not slash or not prefix.isdigit():
        raise ValueError(error)

    try:
        ipaddress.IPv4Network(v)
    except ValueError:
        raise ValueError(error) from None
    return v

# ------------------------------------------------------------------------------
# NETWORKING MODELS
# ------------------------------------------------------------------------------
class NetworkingConfig(BaseModel):
    """VPC address plan: a VPC CIDR plus public/private subnet CIDR lists.

    Each subnet list must have exactly one entry per availability zone.
    Every subnet must lie inside the VPC CIDR and must not overlap any other
    subnet. All blocks must use a prefix between /16 and /28 (the IPv4 size
    range AWS allows for VPCs and subnets).
    """
    vpc_cidr: str
    availability_zones: List[str]
    public_subnets_cidr: List[str]
    private_subnets_cidr: List[str]
    enable_nat_gateway: bool = True

    # 1. Check IP formats. 'before' mode sees the raw input, so this handles
    #    both the scalar VPC CIDR and the subnet lists (list or tuple).
    @field_validator('vpc_cidr', 'public_subnets_cidr', 'private_subnets_cidr', mode='before')
    @classmethod
    def check_cidrs(cls, v):
        if isinstance(v, (list, tuple)):
            return [validate_cidr_format(item) for item in v]
        return validate_cidr_format(v)

    # 2. Check logical counts
    @model_validator(mode='after')
    def check_subnet_counts(self):
        az_count = len(self.availability_zones)
        pub_count = len(self.public_subnets_cidr)
        priv_count = len(self.private_subnets_cidr)

        if pub_count != az_count:
            raise ValueError(f"Public Subnets ({pub_count}) must match Availability Zones ({az_count}).")

        if priv_count != az_count:
            raise ValueError(f"Private Subnets ({priv_count}) must match Availability Zones ({az_count}).")

        return self

    # 3. Check address ranges (containment, overlap, block size)
    @model_validator(mode='after')
    def check_subnet_ranges(self):
        # Field validation has already parsed every CIDR strictly, so these
        # IPv4Network calls cannot fail here.
        vpc = ipaddress.IPv4Network(self.vpc_cidr)
        if not 16 <= vpc.prefixlen <= 28:
            raise ValueError(f"VPC CIDR {self.vpc_cidr} must have a prefix length between /16 and /28.")

        seen = []
        for cidr in self.public_subnets_cidr + self.private_subnets_cidr:
            subnet = ipaddress.IPv4Network(cidr)
            if not 16 <= subnet.prefixlen <= 28:
                raise ValueError(f"Subnet {cidr} must have a prefix length between /16 and /28.")
            if not subnet.subnet_of(vpc):
                raise ValueError(f"Subnet {cidr} is outside the VPC CIDR {self.vpc_cidr}.")
            clash = next((other for other in seen if subnet.overlaps(other)), None)
            if clash is not None:
                raise ValueError(f"Subnet {cidr} overlaps subnet {clash}.")
            seen.append(subnet)
        return self

# ------------------------------------------------------------------------------
# SECURITY GROUP MODELS (Production Ready)
# ------------------------------------------------------------------------------

# A security group (SecurityConfig) is composed of individual SecurityRule entries.

class SecurityRule(BaseModel):
    """One ingress or egress rule of a security group.

    ``protocol`` is one of ``tcp``, ``udp``, ``icmp`` or ``-1``. The field
    constraints allow -1 for the port fields. tcp/udp rules are further
    restricted to real port numbers (0-65535) forming an ascending range.
    """
    description: str = Field(..., min_length=5, description="Audit reason")
    protocol: str = Field(..., pattern=r"^(tcp|udp|icmp|-1)$")
    from_port: int = Field(..., ge=-1, le=65535)
    to_port: int = Field(..., ge=-1, le=65535)
    cidr: str

    @field_validator('cidr')
    @classmethod
    def check_cidr(cls, v):
        # validate_cidr_format already accepts "0.0.0.0/0", so no special case
        # is needed here.
        return validate_cidr_format(v)

    @model_validator(mode='after')
    def validate_range(self):
        # Port sanity checks only apply to port-based protocols.
        if self.protocol in ('tcp', 'udp'):
            if self.from_port < 0 or self.to_port < 0:
                raise ValueError(
                    f"{self.protocol} ports must be between 0 and 65535 "
                    f"(got from_port={self.from_port}, to_port={self.to_port})"
                )
            if self.from_port > self.to_port:
                raise ValueError(f"from_port ({self.from_port}) cannot be greater than to_port ({self.to_port})")
        return self

class SecurityConfig(BaseModel):
    """A security group: a named, described collection of ingress/egress rules.

    ``linked_vpc_module`` is a reference to another module by name
    (presumably the networking/VPC module — confirm against the generator).
    """
    # Lower-case alphanumerics and hyphens only, at least 3 characters.
    name: str = Field(min_length=3, pattern=r"^[a-z0-9-]+$")
    description: str = Field(min_length=10)
    linked_vpc_module: str
    # Each entry is validated by SecurityRule.
    ingress_rules: List[SecurityRule]
    egress_rules: List[SecurityRule]
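
# ------------------------------------------------------------------------------
# COMPUTE MODELS
# ------------------------------------------------------------------------------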

class ComputeConfig(BaseModel):
    """A single EC2 instance placed in a public or private subnet tier."""
    name: str = Field(..., min_length=3, pattern=r"^[a-z0-9-]+$")
    instance_type: str
    ami_id: str
    root_volume_size: int = Field(..., ge=8, le=1000, description="Size in GB")
    subnet_tier: str # 'public' or 'private'

    # Dependencies
    linked_networking_module: str
    linked_security_module: str

    @field_validator('ami_id')
    @classmethod
    def validate_ami(cls, v):
        import re
        if not v.startswith("ami-"):
            raise ValueError("AMI ID must start with 'ami-' (e.g., ami-0c7217cdde317cfec)")
        # AMI IDs are 'ami-' plus 8 (legacy) or 17 lowercase hex characters.
        # fullmatch also rejects trailing whitespace/newlines.
        if not re.fullmatch(r"ami-(?:[0-9a-f]{8}|[0-9a-f]{17})", v):
            raise ValueError("AMI ID must be 'ami-' followed by 8 or 17 hex characters.")
        return v

    @field_validator('instance_type')
    @classmethod
    def validate_type(cls, v):
        import re
        # Accepts 'family[-variant].size', e.g. t3.micro, m7i-flex.large,
        # u-6tb1.metal. The lookahead requires at least one digit (the
        # generation) before the '.'. fullmatch is used because re.match with
        # '$' would also accept a trailing newline.
        if not re.fullmatch(r"(?=[^.]*\d)[a-z][a-z0-9]*(?:-[a-z0-9]+)*\.[a-z0-9]+", v):
            raise ValueError("Invalid Instance Type format (e.g., t3.micro, m5.large)")
        return v

    @field_validator('subnet_tier')
    @classmethod
    def validate_tier(cls, v):
        if v not in ['public', 'private']:
            raise ValueError("Subnet tier must be 'public' or 'private'")
        return v

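# ------------------------------------------------------------------------------
# CONTAINER MODELS (EKS / ECS)
# ------------------------------------------------------------------------------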

class EKSConfig(BaseModel):
    """EKS cluster settings: Kubernetes version, node instance type and scaling bounds.

    Scaling must satisfy ``min_size <= desired_size <= max_size``.
    """
    cluster_name: str = Field(min_length=3, pattern=r"^[a-z0-9-]+$")
    k8s_version: str
    instance_type: str
    desired_size: int = Field(ge=1)
    min_size: int = Field(ge=1)
    max_size: int = Field(ge=1)

    # Dependencies
    linked_networking_module: str
    linked_security_module: str

    @field_validator('k8s_version')
    @classmethod
    def validate_version(cls, v):
        supported = ['1.29', '1.30', '1.31']
        if v in supported:
            return v
        raise ValueError(f"Unsupported K8s version. Choose from: {supported}")

    @model_validator(mode='after')
    def check_scaling(self):
        low, high, target = self.min_size, self.max_size, self.desired_size
        # Report an inverted range before checking where desired falls.
        if low > high:
            raise ValueError("Min size cannot be greater than Max size.")
        if target < low or target > high:
            raise ValueError("Desired size must be between Min and Max.")
        return self
    
class ECSConfig(BaseModel):
    """ECS (Fargate) service: container image/port, task CPU/memory and replica count."""
    cluster_name: str = Field(..., min_length=3)
    container_image: str
    container_port: int = Field(..., ge=1, le=65535)
    cpu: int = Field(..., description="CPU units (e.g., 256, 512, 1024)")
    memory: int = Field(..., description="Memory in MB (e.g., 512, 1024, 2048)")
    desired_count: int = Field(..., ge=1, le=10)

    # Dependencies
    linked_networking_module: str
    linked_security_module: str

    @field_validator('cpu')
    @classmethod
    def validate_cpu(cls, v):
        allowed = [256, 512, 1024, 2048, 4096]
        if v not in allowed:
            raise ValueError(f"Fargate CPU must be one of {allowed} (where 1024 = 1 vCPU)")
        return v

    @model_validator(mode='after')
    def check_memory_ratio(self):
        """Require a memory size that Fargate offers for the chosen CPU value.

        Fargate only accepts fixed CPU/memory pairs (AWS Fargate task-size
        table). A plain "memory >= 2x CPU" rule would accept oversized or
        off-step values such as cpu=256 with memory=4096.
        """
        gb = 1024
        valid_memory = {
            256: (512, 1 * gb, 2 * gb),
            512: tuple(range(1 * gb, 4 * gb + 1, gb)),
            1024: tuple(range(2 * gb, 8 * gb + 1, gb)),
            2048: tuple(range(4 * gb, 16 * gb + 1, gb)),
            4096: tuple(range(8 * gb, 30 * gb + 1, gb)),
        }
        # validate_cpu has already restricted cpu to the keys above; .get()
        # is only a defensive fallback.
        allowed = valid_memory.get(self.cpu, ())
        if self.memory not in allowed:
            raise ValueError(
                f"Memory {self.memory} MB is not valid for Fargate CPU {self.cpu}. "
                f"Choose one of: {list(allowed)}"
            )
        return self

class DatabaseConfig(BaseModel):
    """RDS instance (postgres or mysql) with master credentials.

    The identifier follows RDS rules: 3-63 characters of lowercase
    letters, digits and hyphens. It must start with a letter, must not end
    with a hyphen and must not contain two consecutive hyphens.
    """
    identifier: str = Field(..., min_length=3, max_length=63, pattern=r"^[a-z0-9-]+$")
    engine: str
    engine_version: str
    instance_class: str
    allocated_storage: int = Field(..., ge=20, description="Min 20GB for AWS RDS")
    username: str = Field(..., min_length=4)
    # repr=False keeps the secret out of repr()/str() of the model
    # (e.g. when the config object is logged or printed).
    password: str = Field(..., min_length=8, repr=False)
    multi_az: bool = False

    # Dependencies
    linked_networking_module: str
    linked_security_module: str

    @field_validator('identifier')
    @classmethod
    def validate_identifier(cls, v):
        # min_length/pattern have already run, so v is non-empty [a-z0-9-].
        if not v[0].isalpha():
            raise ValueError("DB identifier must start with a letter.")
        if v.endswith('-') or '--' in v:
            raise ValueError("DB identifier cannot end with a hyphen or contain two consecutive hyphens.")
        return v

    @field_validator('engine')
    @classmethod
    def validate_engine(cls, v):
        if v not in ['postgres', 'mysql']:
            raise ValueError("Supported engines: postgres, mysql")
        return v

    @field_validator('username')
    @classmethod
    def validate_user(cls, v):
        if v.lower() == 'admin':
            raise ValueError("Username 'admin' is reserved/unsafe. Choose another.")
        return v

    @field_validator('password')
    @classmethod
    def validate_password(cls, v):
        # RDS master passwords must be printable ASCII and must not contain
        # '/', '"', '@' or spaces. The message never echoes the password itself.
        forbidden = set('/"@ ')
        if any(ch in forbidden or not (ch.isascii() and ch.isprintable()) for ch in v):
            raise ValueError("Password must be printable ASCII without '/', '\"', '@' or spaces.")
        return v
    
class StorageConfig(BaseModel):
    """S3 bucket settings.

    Bucket names follow S3 naming rules, restricted here to lowercase letters,
    digits and hyphens (dots are deliberately not allowed).
    """
    bucket_name: str = Field(..., min_length=3, max_length=63)
    versioning_enabled: bool = True
    force_destroy: bool = False # Safety first

    @field_validator('bucket_name')
    @classmethod
    def validate_s3_name(cls, v):
        import re

        # Rule 1: No Uppercase
        if v != v.lower():
            raise ValueError("S3 Bucket names must be lowercase.")

        # Rule 2: No Underscores
        if "_" in v:
            raise ValueError("S3 Bucket names cannot contain underscores (_). Use hyphens (-).")

        # Rule 3: Allowed chars a-z, 0-9, '-'; start/end with a letter or number.
        # fullmatch: re.match with '$' would also accept a trailing newline.
        if not re.fullmatch(r'[a-z0-9][a-z0-9-]*[a-z0-9]', v):
            raise ValueError(
                "Invalid format. Use only lowercase letters, numbers and hyphens (-), "
                "and start/end with a lowercase letter or number."
            )

        # Rule 4: Prefixes/suffixes AWS reserves for other purposes.
        reserved_prefixes = ("xn--", "sthree-", "amzn-s3-demo-")
        reserved_suffixes = ("-s3alias", "--ol-s3", "--x-s3")
        if v.startswith(reserved_prefixes):
            raise ValueError(f"S3 Bucket names cannot start with any of {list(reserved_prefixes)}.")
        if v.endswith(reserved_suffixes):
            raise ValueError(f"S3 Bucket names cannot end with any of {list(reserved_suffixes)}.")

        return v

class AutoScalingConfig(BaseModel):
    """EC2 Auto Scaling group: launch image/type and capacity bounds.

    Capacity must satisfy ``min_size <= desired_capacity <= max_size``.
    """
    name: str = Field(..., min_length=3)
    instance_type: str
    ami_id: str
    min_size: int = Field(..., ge=1)
    max_size: int = Field(..., ge=1)
    desired_capacity: int = Field(..., ge=1)

    # Dependencies
    linked_networking_module: str
    linked_security_module: str

    @field_validator('ami_id')
    @classmethod
    def validate_ami(cls, v):
        import re
        if not v.startswith("ami-"):
            raise ValueError("AMI ID must start with 'ami-'")
        # Same rule as ComputeConfig: 'ami-' plus 8 (legacy) or 17 lowercase
        # hex characters. Previously a bare "ami-" was accepted.
        if not re.fullmatch(r"ami-(?:[0-9a-f]{8}|[0-9a-f]{17})", v):
            raise ValueError("AMI ID must be 'ami-' followed by 8 or 17 hex characters.")
        return v

    @model_validator(mode='after')
    def check_scaling_logic(self):
        if self.min_size > self.max_size:
            raise ValueError("Min size cannot be greater than Max size.")
        if not (self.min_size <= self.desired_capacity <= self.max_size):
            raise ValueError("Desired capacity must be between Min and Max.")
        return self
    
class LoadBalancerConfig(BaseModel):
    """Load balancer with a TLS certificate, target port and health check.

    ``linked_storage_module`` is used for access logs (per the field comment).
    """
    name: str = Field(..., min_length=3)
    internal: bool = False
    target_port: int = Field(..., ge=1, le=65535)
    health_check_path: str = Field(..., pattern=r"^/.*")
    certificate_arn: str # We assume user pastes a valid ARN

    # Dependencies
    linked_networking_module: str
    linked_security_module: str
    linked_storage_module: str # For Access Logs

    @field_validator('certificate_arn')
    @classmethod
    def validate_cert(cls, v):
        import re
        # Accept every AWS partition ('aws', 'aws-cn', 'aws-us-gov', 'aws-iso-b', ...),
        # not just the commercial 'aws' one.
        if not re.match(r"arn:aws(?:-[a-z]+)*:acm:", v):
            raise ValueError("Invalid ACM ARN. Must start with 'arn:<partition>:acm:...' (e.g. 'arn:aws:acm:...')")
        return v

class MonitoringConfig(BaseModel):
    """Alerting settings for one monitored resource (EC2, database or load balancer).

    Which of ``target_id`` / ``lb_arn_suffix`` is relevant depends on
    ``monitor_type``. Neither is required here; presumably they are filled
    from the linked module — confirm against the generator.
    """
    monitor_type: str # 'ec2', 'database', 'loadbalancer'
    alert_email: str
    threshold: int = 80
    resource_name: str # Friendly name for display

    # Dynamic Fields (Depending on selection)
    target_id: Optional[str] = "" # InstanceID or DB ID
    lb_arn_suffix: Optional[str] = "" # Special for ALB

    # Dependencies used for linking
    linked_module_id: str

    @field_validator('monitor_type')
    @classmethod
    def validate_type(cls, v):
        if v not in ['ec2', 'database', 'loadbalancer']:
            raise ValueError("Invalid monitor type.")
        return v

    @field_validator('alert_email')
    @classmethod
    def validate_email(cls, v):
        import re
        # '+' is allowed in the local part (e.g. ops+alerts@example.com).
        # fullmatch: re.match with '$' would also accept a trailing newline.
        if not re.fullmatch(r"[\w.+-]+@[\w.-]+\.\w+", v):
            raise ValueError("Invalid email format")
        return v

class IAMConfig(BaseModel):
    """IAM role trusted by an AWS service, plus optional actions/resources and managed policies.

    ``custom_actions`` entries must look like ``service:Action``. A bare
    ``*`` (all actions) is rejected.
    """
    role_name: str = Field(..., min_length=3, pattern=r"^[a-z0-9-]+$")
    trusted_service: str
    custom_actions: List[str] = []
    custom_resources: List[str] = ["*"]
    managed_policy_arns: List[str] = []

    @field_validator('trusted_service')
    @classmethod
    def validate_service(cls, v):
        if not v.endswith(".amazonaws.com"):
            raise ValueError("Service must end with .amazonaws.com (e.g. ec2.amazonaws.com)")
        return v

    @field_validator('custom_actions')
    @classmethod
    def validate_actions(cls, v):
        import re
        # "service:Action". Service prefixes may contain hyphens
        # (e.g. execute-api:Invoke), and action names may use the '*' and '?'
        # wildcards. fullmatch rejects trailing newlines that '$' would allow.
        pattern = re.compile(r"[a-zA-Z0-9-]+:[\w*?]+")
        for action in v:
            if not pattern.fullmatch(action):
                raise ValueError(f"Invalid action format: '{action}'. Use 'service:Action' (e.g. s3:GetObject)")
        return v