"""
Agent Constraints - Keep agents on task and prevent drift.

Without constraints, agents become toddlers in a sandbox eating sand.

This module enforces:
- Token limits per round
- Role-specific instruction conditioning
- Agent card constraints
- Workspace rules
- Automatic truncation of overly long messages
- Feedback loops to keep agents on-task
"""

from dataclasses import dataclass, field
from typing import Any, Optional, Callable
from enum import Enum
import re

from synqed.message_model import StructuredMessage, MessageType


class ConstraintViolation(Exception):
    """Exception signalling that an agent or workspace constraint was broken."""


class ViolationSeverity(Enum):
    """
    How serious a constraint violation is.

    - WARNING: reported, but the message is still allowed.
    - ERROR: the message is blocked.
    - CRITICAL: intended to block the message and potentially remove the agent.

    The member values are plain strings, so comparing members by ``.value``
    does NOT follow severity order ("critical" < "error" < "warning"
    lexically).
    """

    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"


@dataclass
class ConstraintResult:
    """
    Result of constraint check.

    ``passed`` is the overall verdict for the message. ``violations`` holds
    one human-readable description per violated rule. ``severity`` is the
    severity reported for those violations and defaults to WARNING.
    """
    passed: bool
    violations: list[str] = field(default_factory=list)  # Human-readable rule violations
    severity: ViolationSeverity = ViolationSeverity.WARNING
    suggested_fix: Optional[str] = None  # NOTE(review): never populated by ConstraintEnforcer in this file


@dataclass
class AgentConstraints:
    """
    Constraints for an agent's behavior in the workspace.

    ConstraintEnforcer looks these up by the message's ``sender_id``
    (presumably equal to ``agent_id``; verify against callers). Field order
    is part of the positional constructor generated by ``@dataclass``.
    """
    agent_id: str
    
    # Token limits
    # Per-message limit is compared against ``message.metadata.token_count``.
    max_tokens_per_message: int = 1000
    max_tokens_per_round: int = 5000
    max_total_tokens: Optional[int] = None  # presumably None = no lifetime cap
    
    # Message limits
    max_messages_per_round: int = 5
    # NOTE(review): ConstraintEnforcer's consecutive-speaker check uses a fixed
    # count of 3 and does not read this field.
    max_consecutive_messages: int = 2
    
    # Content rules
    forbidden_topics: list[str] = field(default_factory=list)  # Case-insensitive substring match
    required_keywords: list[str] = field(default_factory=list)  # Only checked on SUGGESTION messages; any one suffices
    forbidden_patterns: list[str] = field(default_factory=list)  # Regex patterns, searched with re.IGNORECASE
    
    # Role constraints
    role: Optional[str] = None  # Rendered as "YOUR ROLE" in the conditioning prompt
    allowed_message_types: Optional[list[MessageType]] = None  # None (or empty) = all allowed
    role_specific_instructions: Optional[str] = None  # Rendered as "INSTRUCTIONS" in the conditioning prompt
    
    # Behavior
    must_stay_on_task: bool = True  # Enables the keyword-overlap relevance check
    allow_meta_discussion: bool = False  # When False, meta keywords ("protocol", "process", ...) are flagged
    require_constructive_feedback: bool = True  # CRITIQUE messages must contain a suggestion word
    
    # Custom
    custom_rules: dict[str, Any] = field(default_factory=dict)  # NOTE(review): not read by ConstraintEnforcer in this file


@dataclass
class WorkspaceRules:
    """
    Global rules for the workspace.

    Lengths are measured in characters (``len(message.content)``), not
    tokens. Field order is part of the positional constructor generated by
    ``@dataclass``.
    """
    
    # Content
    max_message_length: int = 5000  # Characters
    min_message_length: int = 10  # Characters
    require_respectful_language: bool = True  # Simple regex check for a few insults
    
    # Collaboration
    # NOTE(review): the two fields below are not read by ConstraintEnforcer in this file.
    require_all_agents_participate: bool = True
    min_participation_threshold: float = 0.3  # presumably a fraction of messages (0.3 = 30%)
    
    # Focus
    task_relevance_threshold: float = 0.5  # Min fraction of task keywords that must appear in a message
    auto_truncate_long_messages: bool = True
    truncation_length: int = 2000  # Characters kept when auto-truncating
    
    # Safety
    block_harmful_content: bool = True  # Simple keyword heuristic
    block_personal_information: bool = True  # NOTE(review): not read by ConstraintEnforcer in this file
    
    # Custom
    custom_rules: dict[str, Any] = field(default_factory=dict)  # NOTE(review): not read by ConstraintEnforcer in this file


class ConstraintEnforcer:
    """
    Enforces constraints on agent behavior.
    
    This prevents:
    1. Token abuse
    2. Off-topic rambling
    3. Role violations
    4. Unhelpful messages
    5. Dominance by single agent
    6. Meta-discussion loops

    ``check_message`` only reads the usage counters. They are advanced by
    ``record_message``. ``reset_round`` clears the per-round counters and
    speaker history but keeps lifetime totals.
    """

    # Severities ordered from least to most severe. ``ViolationSeverity``
    # values are strings whose lexical order ("critical" < "error" <
    # "warning") is the reverse of their severity, so never compare by .value.
    _SEVERITY_ORDER = (
        ViolationSeverity.WARNING,
        ViolationSeverity.ERROR,
        ViolationSeverity.CRITICAL,
    )

    # Heuristic patterns for the "respectful language" workspace rule.
    _DISRESPECTFUL_PATTERNS = (
        re.compile(r'\b(stupid|dumb|idiot)\b'),
        re.compile(r'\b(shut up|be quiet)\b'),
    )

    # Heuristic for the "harmful content" rule. The leading word boundary
    # still matches inflections ("hacking", "attacker") but not matches
    # buried inside unrelated words ("shack", "whack").
    _HARMFUL_PATTERN = re.compile(r'\b(?:hack|exploit|attack|steal)')

    # Tokenizer for the relevance check. Using \w+ instead of whitespace
    # splitting keeps punctuation ("API." vs "API") from defeating the overlap.
    _WORD_PATTERN = re.compile(r'\w+')

    # Stop words ignored by the relevance check.
    _COMMON_WORDS = frozenset(
        {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for'}
    )
    
    def __init__(
        self,
        workspace_rules: WorkspaceRules,
        agent_constraints: dict[str, AgentConstraints],
    ):
        """
        Args:
            workspace_rules: Global rules applied to every constrained agent.
            agent_constraints: Per-agent constraints, looked up by sender id.
                Agents missing from this mapping are not checked at all.
        """
        self.workspace_rules = workspace_rules
        self.agent_constraints = agent_constraints
        
        # Lifetime tracking (kept across rounds)
        self.tokens_used: dict[str, int] = {}  # agent_id -> total tokens
        self.messages_sent: dict[str, int] = {}  # agent_id -> message count

        # Per-round tracking (cleared by reset_round)
        self.round_tokens: dict[str, int] = {}  # agent_id -> tokens this round
        self.round_messages: dict[str, int] = {}  # agent_id -> messages this round
        self.last_speakers: list[str] = []  # Recent speaker history (last 10)

    @classmethod
    def _max_severity(cls, *severities: ViolationSeverity) -> ViolationSeverity:
        """Return the most severe of ``severities`` according to ``_SEVERITY_ORDER``."""
        return max(severities, key=cls._SEVERITY_ORDER.index)
    
    def check_message(
        self,
        message: StructuredMessage,
        task: str,
    ) -> ConstraintResult:
        """
        Check if a message satisfies all constraints.
        
        Args:
            message: Message to check
            task: The workspace task (for relevance check)
            
        Returns:
            ConstraintResult with all collected violations and the highest
            severity among them. ``passed`` is False only when at least one
            violation is ERROR or CRITICAL; warnings are reported but allowed.
        """
        agent_id = message.sender_id
        
        # System messages and agents without registered constraints are exempt.
        if agent_id == "system" or agent_id not in self.agent_constraints:
            return ConstraintResult(passed=True)
        
        constraints = self.agent_constraints[agent_id]

        # Individual checks, in reporting order. None of them mutate state.
        checks = [
            self._check_workspace_rules(message),
            self._check_agent_constraints(message, constraints),
        ]
        if constraints.must_stay_on_task:
            checks.append(self._check_task_relevance(message, task))
        checks.append(self._check_participation_balance(agent_id))

        violations: list[str] = []
        severity = ViolationSeverity.WARNING
        for check in checks:
            if not check.passed:
                violations.extend(check.violations)
                severity = self._max_severity(severity, check.severity)
        
        passed = not violations or severity == ViolationSeverity.WARNING
        
        return ConstraintResult(
            passed=passed,
            violations=violations,
            severity=severity,
        )
    
    def _check_workspace_rules(self, message: StructuredMessage) -> ConstraintResult:
        """
        Check workspace-level rules (length, language, harmful content).

        Any violation here is reported with ERROR severity.
        """
        violations = []
        
        content = message.content
        content_lower = content.lower()
        length = len(content)
        rules = self.workspace_rules
        
        # Message length (in characters)
        if length > rules.max_message_length:
            violations.append(
                f"Message too long ({length} > {rules.max_message_length})"
            )
        
        if length < rules.min_message_length:
            violations.append(
                f"Message too short ({length} < {rules.min_message_length})"
            )
        
        # Respectful language (simple check)
        if rules.require_respectful_language and any(
            pattern.search(content_lower) for pattern in self._DISRESPECTFUL_PATTERNS
        ):
            violations.append("Message contains potentially disrespectful language")
        
        # Harmful content (simple check)
        if rules.block_harmful_content and self._HARMFUL_PATTERN.search(content_lower):
            violations.append("Message may contain harmful content")
        
        severity = ViolationSeverity.ERROR if violations else ViolationSeverity.WARNING
        
        return ConstraintResult(
            passed=len(violations) == 0,
            violations=violations,
            severity=severity,
        )
    
    def _check_agent_constraints(
        self,
        message: StructuredMessage,
        constraints: AgentConstraints,
    ) -> ConstraintResult:
        """
        Check agent-specific constraints.

        Token and message budgets are evaluated as if this message were
        accepted (current usage + this message). Severity is ERROR when
        more than two rules are violated, otherwise WARNING.
        """
        violations = []
        
        agent_id = message.sender_id
        token_count = message.metadata.token_count
        
        # Token limits
        if token_count > constraints.max_tokens_per_message:
            violations.append(
                f"Exceeds token limit per message ({token_count} > {constraints.max_tokens_per_message})"
            )

        round_tokens = self.round_tokens.get(agent_id, 0) + token_count
        if round_tokens > constraints.max_tokens_per_round:
            violations.append(
                f"Exceeds token limit per round ({round_tokens} > {constraints.max_tokens_per_round})"
            )

        if constraints.max_total_tokens is not None:
            total_tokens = self.tokens_used.get(agent_id, 0) + token_count
            if total_tokens > constraints.max_total_tokens:
                violations.append(
                    f"Exceeds total token budget ({total_tokens} > {constraints.max_total_tokens})"
                )

        # Message limits
        round_messages = self.round_messages.get(agent_id, 0) + 1
        if round_messages > constraints.max_messages_per_round:
            violations.append(
                f"Exceeds message limit per round ({round_messages} > {constraints.max_messages_per_round})"
            )
        
        # Message type allowed? (None or empty list means all types allowed)
        if constraints.allowed_message_types:
            if message.message_type not in constraints.allowed_message_types:
                violations.append(
                    f"Message type {message.message_type.value} not allowed for this agent"
                )
        
        # Forbidden topics (case-insensitive substring match)
        content_lower = message.content.lower()
        for topic in constraints.forbidden_topics:
            if topic.lower() in content_lower:
                violations.append(f"Contains forbidden topic: {topic}")
        
        # Forbidden patterns
        for pattern in constraints.forbidden_patterns:
            if re.search(pattern, message.content, re.IGNORECASE):
                violations.append(f"Matches forbidden pattern: {pattern}")
        
        # Required keywords (only enforced for suggestions; any one suffices)
        if constraints.required_keywords and message.message_type == MessageType.SUGGESTION:
            has_keyword = any(
                keyword.lower() in content_lower
                for keyword in constraints.required_keywords
            )
            if not has_keyword:
                violations.append(
                    f"Missing required keywords: {', '.join(constraints.required_keywords)}"
                )
        
        # Meta discussion check
        if not constraints.allow_meta_discussion:
            meta_keywords = ['protocol', 'process', 'how we', 'our approach', 'let\'s discuss']
            if any(keyword in content_lower for keyword in meta_keywords):
                violations.append("Meta-discussion not allowed (stay focused on the task)")
        
        # Constructive feedback (for critiques)
        if (constraints.require_constructive_feedback and 
            message.message_type == MessageType.CRITIQUE):
            # Check if critique includes suggestions
            has_suggestion = any(
                word in content_lower
                for word in ['suggest', 'recommend', 'could', 'might', 'perhaps']
            )
            if not has_suggestion:
                violations.append("Critique should include constructive suggestions")
        
        severity = ViolationSeverity.ERROR if len(violations) > 2 else ViolationSeverity.WARNING
        
        return ConstraintResult(
            passed=len(violations) == 0,
            violations=violations,
            severity=severity,
        )
    
    def _check_task_relevance(
        self,
        message: StructuredMessage,
        task: str,
    ) -> ConstraintResult:
        """
        Check if message is relevant to the task.

        Relevance is the fraction of distinct non-stop-word task words that
        also appear in the message. Failing this check is only a WARNING.
        """
        violations = []
        
        task_words = set(self._WORD_PATTERN.findall(task.lower())) - self._COMMON_WORDS
        if not task_words:
            return ConstraintResult(passed=True)  # Can't check relevance
        
        message_words = (
            set(self._WORD_PATTERN.findall(message.content.lower())) - self._COMMON_WORDS
        )
        
        relevance_score = len(task_words & message_words) / len(task_words)
        threshold = self.workspace_rules.task_relevance_threshold
        
        if relevance_score < threshold:
            violations.append(
                f"Message may be off-topic (relevance: {relevance_score:.2f}, "
                f"threshold: {threshold:.2f})"
            )
        
        return ConstraintResult(
            passed=len(violations) == 0,
            violations=violations,
            severity=ViolationSeverity.WARNING,  # Just a warning for relevance
        )
    
    def _check_participation_balance(self, agent_id: str) -> ConstraintResult:
        """
        Check if agent is dominating the conversation.

        Flags an agent whose share of recorded messages exceeds twice the
        fair share among agents that have spoken (once at least 5 messages
        exist), or whose last 3 recorded messages were all consecutive.
        Failing this check is only a WARNING.
        """
        violations = []
        
        if not self.messages_sent:
            return ConstraintResult(passed=True)
        
        total_messages = sum(self.messages_sent.values())
        agent_messages = self.messages_sent.get(agent_id, 0)
        
        if total_messages < 5:
            return ConstraintResult(passed=True)  # Too early to judge
        
        # Check if agent is speaking too much
        participation_rate = agent_messages / total_messages
        max_rate = 1 / len(self.messages_sent) * 2  # Allow up to 2x fair share
        
        if participation_rate > max_rate:
            violations.append(
                f"Agent is dominating conversation ({participation_rate:.1%} of messages)"
            )
        
        # Check consecutive messages
        if len(self.last_speakers) >= 3:
            if all(speaker == agent_id for speaker in self.last_speakers[-3:]):
                violations.append("Agent has sent 3 consecutive messages")
        
        return ConstraintResult(
            passed=len(violations) == 0,
            violations=violations,
            severity=ViolationSeverity.WARNING,
        )
    
    def record_message(self, message: StructuredMessage) -> None:
        """Record that a message was sent, updating lifetime and per-round counters."""
        agent_id = message.sender_id
        
        if agent_id == "system":
            return
        
        token_count = message.metadata.token_count

        # Lifetime tracking
        self.tokens_used[agent_id] = self.tokens_used.get(agent_id, 0) + token_count
        self.messages_sent[agent_id] = self.messages_sent.get(agent_id, 0) + 1

        # Per-round tracking
        self.round_tokens[agent_id] = self.round_tokens.get(agent_id, 0) + token_count
        self.round_messages[agent_id] = self.round_messages.get(agent_id, 0) + 1
        
        # Update last speakers, keeping only the 10 most recent
        self.last_speakers.append(agent_id)
        if len(self.last_speakers) > 10:
            self.last_speakers.pop(0)
    
    def get_conditioning_prompt(
        self,
        agent_id: str,
        task: str,
    ) -> str:
        """
        Get a conditioning prompt to guide the agent's behavior.
        
        This is prepended to the agent's context to keep them on track.
        Returns an empty string for agents without registered constraints.
        """
        if agent_id not in self.agent_constraints:
            return ""
        
        constraints = self.agent_constraints[agent_id]
        
        lines = []
        
        # Role conditioning
        if constraints.role:
            lines.append(f"YOUR ROLE: {constraints.role}")
        
        if constraints.role_specific_instructions:
            lines.append(f"INSTRUCTIONS: {constraints.role_specific_instructions}")
        
        # Behavioral guidance
        lines.extend([
            "",
            "RULES:",
        ])
        
        if constraints.must_stay_on_task:
            lines.append(f"  • Stay focused on the task: {task}")
        
        if not constraints.allow_meta_discussion:
            lines.append("  • No meta-discussion - focus on solving the task")
        
        if constraints.require_constructive_feedback:
            lines.append("  • Provide constructive feedback with suggestions")
        
        if constraints.allowed_message_types:
            types = [t.value for t in constraints.allowed_message_types]
            lines.append(f"  • Use these message types: {', '.join(types)}")
        
        if constraints.forbidden_topics:
            lines.append(f"  • Avoid topics: {', '.join(constraints.forbidden_topics)}")
        
        lines.extend([
            "",
            f"TOKEN LIMIT: {constraints.max_tokens_per_message} per message",
            "",
        ])
        
        return "\n".join(lines)
    
    def truncate_if_needed(self, message: StructuredMessage) -> StructuredMessage:
        """
        Truncate message content in place if it exceeds the truncation length.

        When truncated, ``metadata.token_count`` is replaced by a whitespace
        word count of the new content (an approximation of the token count).
        Returns the same message object.
        """
        if not self.workspace_rules.auto_truncate_long_messages:
            return message
        
        max_length = self.workspace_rules.truncation_length
        
        if len(message.content) > max_length:
            truncated_content = message.content[:max_length] + "\n\n[Message truncated]"
            message.content = truncated_content
            message.metadata.token_count = len(truncated_content.split())
        
        return message
    
    def reset_round(self) -> None:
        """Reset per-round tracking (speaker history and round budgets); keep lifetime totals."""
        self.last_speakers.clear()
        self.round_tokens.clear()
        self.round_messages.clear()


def create_default_constraints(agent_id: str, role: Optional[str] = None) -> AgentConstraints:
    """
    Build an ``AgentConstraints`` populated with the standard defaults.

    Args:
        agent_id: Identifier of the agent the constraints apply to.
        role: Optional role label, rendered in the agent's conditioning prompt.

    Returns:
        A new ``AgentConstraints`` instance.
    """
    standard_settings = dict(
        max_tokens_per_message=1000,
        max_messages_per_round=5,
        must_stay_on_task=True,
        allow_meta_discussion=False,
        require_constructive_feedback=True,
    )
    return AgentConstraints(agent_id=agent_id, role=role, **standard_settings)


def create_workspace_rules(strict: bool = True) -> WorkspaceRules:
    """
    Create workspace rules from one of two presets.

    Args:
        strict: If True (the default), return the stricter preset:
            max length 2000, min length 20, relevance threshold 0.5,
            auto-truncation on. If False, return the lenient preset:
            max length 5000, min length 10, relevance threshold 0.3,
            auto-truncation off.

    Returns:
        A new ``WorkspaceRules`` instance. Fields not set here keep the
        ``WorkspaceRules`` defaults.
    """
    if not strict:
        return WorkspaceRules(
            max_message_length=5000,
            min_message_length=10,
            require_respectful_language=True,
            task_relevance_threshold=0.3,
            auto_truncate_long_messages=False,
        )

    return WorkspaceRules(
        max_message_length=2000,
        min_message_length=20,
        require_respectful_language=True,
        require_all_agents_participate=True,
        min_participation_threshold=0.3,
        task_relevance_threshold=0.5,
        auto_truncate_long_messages=True,
        block_harmful_content=True,
    )

