"""
Interaction Protocols - Structured collaboration patterns for multi-agent systems.

This module defines HOW agents collaborate. Each protocol enforces specific
rules about communication, turn-taking, and decision-making.

Protocols are the HEART of the workspace - they determine whether you get
emergent intelligence or just LLMs hallucinating at each other.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional, Callable
import logging

from synqed.message_model import StructuredMessage, MessageType, MessagePriority

logger = logging.getLogger(__name__)


class ProtocolType(Enum):
    """Identifiers for the collaboration styles agents can run under.

    Each member's value is the stable lowercase string name of the protocol.
    """

    # Agents argue for competing solutions to surface the strongest one.
    DEBATE = "debate"
    # Agents converge on a shared answer through discussion.
    CONSENSUS = "consensus"
    # Agents act according to the responsibilities of their assigned roles.
    ROLE_BASED = "role_based"
    # Agents jointly assemble a plan.
    PLANNING = "planning"
    # Propose, then critique, then revise.
    CRITIQUE_AND_REVISE = "critique_and_revise"
    # Agents take turns putting forward proposals.
    ROUND_ROBIN_PROPOSAL = "round_robin_proposal"
    # Agents work independently first, then their results are synthesized.
    PARALLEL_EXPLORATION = "parallel_exploration"


class ProtocolPhase(Enum):
    """Lifecycle stages a protocol may pass through.

    Concrete protocols pick an ordered subset of these via ``get_phases()``;
    ``COMPLETED`` is the terminal stage.
    """

    INITIALIZATION = "initialization"  # setup before any agent contributes
    PROPOSAL = "proposal"  # agents put forward candidate solutions
    DISCUSSION = "discussion"  # open exchange about the candidates
    CRITIQUE = "critique"  # agents critique each other's work
    REFINEMENT = "refinement"  # agents revise in light of critiques
    VOTING = "voting"  # agents choose among options
    CONSENSUS_BUILDING = "consensus_building"  # agents converge on one answer
    SYNTHESIS = "synthesis"  # contributions are combined
    FINALIZATION = "finalization"  # the outcome is settled
    COMPLETED = "completed"  # terminal stage


@dataclass
class ProtocolConfig:
    """Tunable settings shared by every protocol.

    Field order is part of the generated constructor's positional signature,
    so new fields should be appended rather than inserted.
    """

    # --- Round limits -------------------------------------------------------
    max_rounds: int = 5
    min_rounds: int = 1
    timeout_seconds: Optional[int] = None  # None means unset

    # --- Turn-taking --------------------------------------------------------
    # Per-agent cap on contributions within a single round (default permits
    # several contributions per agent).
    max_turns_per_agent_per_round: int = 3
    allow_interruptions: bool = False

    # --- Token budgets ------------------------------------------------------
    max_tokens_per_message: int = 1000
    max_total_tokens: Optional[int] = None  # None means no overall budget

    # --- Termination --------------------------------------------------------
    require_unanimous_consent: bool = False
    require_majority_vote: bool = False
    allow_early_termination: bool = True
    # Fraction of agents that must agree (used by consensus-style protocols).
    convergence_threshold: float = 0.8

    # --- Agent behaviour constraints ----------------------------------------
    enforce_role_constraints: bool = True
    allow_off_topic: bool = False
    require_structured_responses: bool = True

    # --- Extension point ----------------------------------------------------
    # Fresh dict per instance (default_factory avoids a shared mutable default).
    custom_rules: dict[str, Any] = field(default_factory=dict)


@dataclass
class ProtocolState:
    """Mutable runtime bookkeeping for a single protocol execution.

    Field order is part of the generated constructor's positional signature.
    """

    # --- Position in the lifecycle and running totals -----------------------
    current_phase: ProtocolPhase = ProtocolPhase.INITIALIZATION
    current_round: int = 0
    total_messages: int = 0
    total_tokens: int = 0

    # --- Speaking order, its cursor, and per-agent turns used this round ----
    turn_order: list[str] = field(default_factory=list)
    current_turn_index: int = 0
    turns_taken_this_round: dict[str, int] = field(default_factory=dict)

    # --- Per-agent progress, keyed by sender id -----------------------------
    proposals_submitted: dict[str, str] = field(default_factory=dict)  # -> message_id of latest suggestion
    votes_cast: dict[str, str] = field(default_factory=dict)  # -> chosen option
    critiques_given: dict[str, list[str]] = field(default_factory=dict)  # -> critique message_ids

    # --- Completion flags ---------------------------------------------------
    is_complete: bool = False
    early_termination_reason: Optional[str] = None

    # --- Free-form storage for protocol-specific extensions -----------------
    custom_state: dict[str, Any] = field(default_factory=dict)


class Protocol(ABC):
    """
    Base class for collaboration protocols.

    A protocol defines:
    1. What phases the collaboration goes through
    2. Who can speak when
    3. What message types are allowed in each phase
    4. When the protocol terminates
    5. How to synthesize final output

    Subclasses implement the abstract hooks. This base class owns the mutable
    ``ProtocolState`` and provides shared bookkeeping: phase and round
    advancement, message recording, and budget/turn constraint checks.
    """

    def __init__(self, config: ProtocolConfig):
        """Bind ``config`` and start from a fresh ``ProtocolState``."""
        self.config = config
        self.state = ProtocolState()
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    @abstractmethod
    def get_phases(self) -> list[ProtocolPhase]:
        """Return the ordered phases this protocol goes through."""
        pass

    @abstractmethod
    def get_next_speaker(self, agents: list[str], messages: list[StructuredMessage]) -> Optional[str]:
        """
        Determine who should speak next.

        Args:
            agents: List of agent IDs in the workspace
            messages: Message history

        Returns:
            Agent ID who should speak next, or None if no one should speak
        """
        pass

    @abstractmethod
    def is_message_allowed(
        self,
        message: StructuredMessage,
        agent_id: str,
    ) -> tuple[bool, Optional[str]]:
        """
        Check if a message is allowed under the protocol rules.

        Args:
            message: The message to check
            agent_id: The agent trying to send the message

        Returns:
            (allowed, reason) - True if allowed, False with reason if not
        """
        pass

    @abstractmethod
    def should_advance_phase(self, messages: list[StructuredMessage]) -> bool:
        """
        Check if we should advance to the next phase.

        Args:
            messages: Message history

        Returns:
            True if we should advance to next phase
        """
        pass

    @abstractmethod
    def should_terminate(self, messages: list[StructuredMessage]) -> tuple[bool, Optional[str]]:
        """
        Check if the protocol should terminate.

        Args:
            messages: Message history

        Returns:
            (should_terminate, reason)
        """
        pass

    def advance_phase(self) -> None:
        """Move to the phase after the current one in ``get_phases()``.

        Stepping past the last listed phase, or landing on
        ``ProtocolPhase.COMPLETED``, marks the protocol complete.

        Raises:
            ValueError: if the current phase is not listed by ``get_phases()``.
        """
        phases = self.get_phases()
        current_idx = phases.index(self.state.current_phase)

        if current_idx < len(phases) - 1:
            self.state.current_phase = phases[current_idx + 1]
            self.logger.info(f"Advanced to phase: {self.state.current_phase.value}")
        else:
            self.state.current_phase = ProtocolPhase.COMPLETED

        # Protocols list COMPLETED as their final phase, so reaching it by a
        # normal step must also set the completion flag, not only the step
        # past the end of the list.
        if self.state.current_phase == ProtocolPhase.COMPLETED:
            self.state.is_complete = True
            self.logger.info("Protocol completed")

    def advance_round(self) -> None:
        """Start a new round: bump the counter and reset per-round turn tracking."""
        self.state.current_round += 1
        self.state.turns_taken_this_round.clear()
        self.state.current_turn_index = 0
        self.logger.info(f"Advanced to round: {self.state.current_round}")

    def record_message(self, message: StructuredMessage) -> None:
        """Update counters and per-type progress for an accepted message.

        Increments the message and token totals and the sender's turn count for
        the current round. It then records:

        - votes, keyed by ``structured_data["option"]``;
        - proposals, from SUGGESTION messages;
        - critiques, from CRITIQUE messages.

        All of these are keyed by sender.
        """
        self.state.total_messages += 1
        self.state.total_tokens += message.metadata.token_count

        agent_id = message.sender_id
        turns = self.state.turns_taken_this_round
        turns[agent_id] = turns.get(agent_id, 0) + 1

        if message.message_type == MessageType.VOTE:
            option = message.structured_data.get("option")
            # A vote without an option still consumes a turn but is not tallied.
            if option:
                self.state.votes_cast[agent_id] = option

        elif message.message_type == MessageType.SUGGESTION:
            # Later suggestions from the same agent overwrite the stored id.
            self.state.proposals_submitted[agent_id] = message.metadata.message_id

        elif message.message_type == MessageType.CRITIQUE:
            self.state.critiques_given.setdefault(agent_id, []).append(
                message.metadata.message_id
            )

    def check_constraints(self, message: StructuredMessage) -> tuple[bool, Optional[str]]:
        """Check a message against token budgets and the per-round turn cap.

        Returns:
            ``(True, None)`` if allowed, otherwise ``(False, reason)``.
        """
        tokens = message.metadata.token_count
        per_message_limit = self.config.max_tokens_per_message
        if tokens > per_message_limit:
            return False, f"Message exceeds token limit ({tokens} > {per_message_limit})"

        # ``None`` means "no overall budget". An explicit 0 is a real budget and
        # must not be mistaken for "unlimited", which a truthiness test would do.
        budget = self.config.max_total_tokens
        if budget is not None and self.state.total_tokens + tokens > budget:
            return False, "Would exceed total token budget"

        turns_taken = self.state.turns_taken_this_round.get(message.sender_id, 0)
        if turns_taken >= self.config.max_turns_per_agent_per_round:
            return False, f"Agent has already taken {turns_taken} turns this round"

        return True, None


# ============================================================================
# CRITIQUE-AND-REVISE PROTOCOL
# ============================================================================

class CritiqueAndReviseProtocol(Protocol):
    """
    Critique-and-Revise Protocol.

    Flow:
    1. PROPOSAL: Each agent proposes their solution
    2. CRITIQUE: Agents critique each other's proposals
    3. REFINEMENT: Agents refine based on critiques
    4. SYNTHESIS: Best ideas are combined into final output

    This is one of the most effective protocols for quality output.
    """

    def get_phases(self) -> list[ProtocolPhase]:
        return [
            ProtocolPhase.INITIALIZATION,
            ProtocolPhase.PROPOSAL,
            ProtocolPhase.CRITIQUE,
            ProtocolPhase.REFINEMENT,
            ProtocolPhase.SYNTHESIS,
            ProtocolPhase.COMPLETED,
        ]

    @staticmethod
    def _refined_agents(messages: list[StructuredMessage]) -> set[str]:
        """Return IDs of agents that sent a SUGGESTION tagged with the REFINEMENT phase."""
        return {
            msg.sender_id
            for msg in messages
            if msg.message_type == MessageType.SUGGESTION
            and msg.metadata.phase == ProtocolPhase.REFINEMENT.value
        }

    def get_next_speaker(
        self,
        agents: list[str],
        messages: list[StructuredMessage]
    ) -> Optional[str]:
        """Return the first agent, in ``agents`` order, still owing work this phase.

        Behaviour by phase:

        - PROPOSAL: an agent owes work until it has submitted a proposal.
        - CRITIQUE: an agent owes work until it has given ``len(agents) - 1``
          critiques.
        - REFINEMENT: an agent owes work until it has sent a refining SUGGESTION.

        Returns ``None`` when nobody owes anything, or in any other phase.
        """
        phase = self.state.current_phase

        if phase == ProtocolPhase.PROPOSAL:
            for agent_id in agents:
                if agent_id not in self.state.proposals_submitted:
                    return agent_id
            return None  # All proposals submitted

        if phase == ProtocolPhase.CRITIQUE:
            # Each agent should critique every proposal except its own.
            expected_critiques = len(agents) - 1
            for agent_id in agents:
                if len(self.state.critiques_given.get(agent_id, [])) < expected_critiques:
                    return agent_id
            return None  # All critiques given

        if phase == ProtocolPhase.REFINEMENT:
            refined = self._refined_agents(messages)
            for agent_id in agents:
                if agent_id not in refined:
                    return agent_id
            return None  # All refined

        return None

    def is_message_allowed(
        self,
        message: StructuredMessage,
        agent_id: str,
    ) -> tuple[bool, Optional[str]]:
        """Check shared constraints, then that the message type suits the phase.

        The allowed types are:

        - PROPOSAL: one SUGGESTION per agent.
        - CRITIQUE: CRITIQUE only.
        - REFINEMENT: SUGGESTION only.
        - Other phases: unrestricted.
        """
        allowed, reason = self.check_constraints(message)
        if not allowed:
            return False, reason

        phase = self.state.current_phase
        msg_type = message.message_type

        if phase == ProtocolPhase.PROPOSAL:
            if msg_type != MessageType.SUGGESTION:
                return False, "Only SUGGESTION messages allowed in PROPOSAL phase"
            if agent_id in self.state.proposals_submitted:
                return False, "Agent has already submitted a proposal"

        elif phase == ProtocolPhase.CRITIQUE:
            if msg_type != MessageType.CRITIQUE:
                return False, "Only CRITIQUE messages allowed in CRITIQUE phase"

        elif phase == ProtocolPhase.REFINEMENT:
            if msg_type != MessageType.SUGGESTION:
                return False, "Only SUGGESTION messages allowed in REFINEMENT phase"

        return True, None

    def should_advance_phase(self, messages: list[StructuredMessage]) -> bool:
        """Advance once every agent in ``state.turn_order`` has finished this phase."""
        phase = self.state.current_phase
        participants = self.state.turn_order

        if phase == ProtocolPhase.PROPOSAL:
            return len(self.state.proposals_submitted) == len(participants)

        if phase == ProtocolPhase.CRITIQUE:
            expected_critiques = len(participants) - 1
            return all(
                len(self.state.critiques_given.get(agent_id, [])) >= expected_critiques
                for agent_id in participants
            )

        if phase == ProtocolPhase.REFINEMENT:
            # Use the same definition of "refined" as get_next_speaker so the two
            # never disagree about who still has to act.
            return len(self._refined_agents(messages)) == len(participants)

        return False

    def should_terminate(self, messages: list[StructuredMessage]) -> tuple[bool, Optional[str]]:
        """Stop on completion, on reaching ``max_rounds``, or on exhausting the token budget."""
        if self.state.current_phase == ProtocolPhase.COMPLETED:
            return True, "Protocol completed successfully"

        if self.state.current_round >= self.config.max_rounds:
            return True, f"Reached maximum rounds ({self.config.max_rounds})"

        # ``None`` means no budget; 0 is a real (already exhausted) budget.
        budget = self.config.max_total_tokens
        if budget is not None and self.state.total_tokens >= budget:
            return True, f"Reached token budget ({budget})"

        return False, None


# ============================================================================
# DEBATE PROTOCOL
# ============================================================================

class DebateProtocol(Protocol):
    """
    Debate Protocol.

    Flow:
    1. PROPOSAL: Agents propose competing solutions
    2. DISCUSSION: Agents argue for their proposals and against others
    3. VOTING: Agents vote on the best proposal
    4. FINALIZATION: Winner is selected and refined

    Good for finding the best single solution among alternatives.
    """

    def get_phases(self) -> list[ProtocolPhase]:
        return [
            ProtocolPhase.INITIALIZATION,
            ProtocolPhase.PROPOSAL,
            ProtocolPhase.DISCUSSION,
            ProtocolPhase.VOTING,
            ProtocolPhase.FINALIZATION,
            ProtocolPhase.COMPLETED,
        ]

    def get_next_speaker(
        self,
        agents: list[str],
        messages: list[StructuredMessage]
    ) -> Optional[str]:
        """Pick the DISCUSSION speaker in round-robin order; other phases are open.

        In DISCUSSION, the agent at ``state.current_turn_index`` keeps the floor
        until it has used ``max_turns_per_agent_per_round`` turns. The cursor then
        moves to the next agent that still has turns left this round. The first
        call fills ``state.turn_order`` from a copy of ``agents``.

        Returns:
            The designated speaker. Returns ``None`` outside DISCUSSION (any agent
            may speak), when there are no agents, or when every agent has used all
            of its turns this round.
        """
        if self.state.current_phase != ProtocolPhase.DISCUSSION:
            return None

        if not self.state.turn_order:
            # Copy so later mutation of the caller's list cannot silently
            # reorder or empty our rotation.
            self.state.turn_order = list(agents)
        order = self.state.turn_order
        if not order:
            return None  # Guard against indexing into an empty rotation.

        limit = self.config.max_turns_per_agent_per_round
        turns = self.state.turns_taken_this_round
        start = self.state.current_turn_index % len(order)
        # Scan from the cursor for the first agent with turns remaining, so we
        # never hand the floor to an agent check_constraints would reject.
        for offset in range(len(order)):
            idx = (start + offset) % len(order)
            candidate = order[idx]
            if turns.get(candidate, 0) < limit:
                self.state.current_turn_index = idx
                return candidate

        # Everyone has used all of their turns this round.
        return None

    def is_message_allowed(
        self,
        message: StructuredMessage,
        agent_id: str,
    ) -> tuple[bool, Optional[str]]:
        """Check shared constraints, then that the message type suits the phase.

        The allowed types are:

        - PROPOSAL: SUGGESTION only.
        - DISCUSSION: CRITIQUE, OBSERVATION or RESPONSE.
        - VOTING: VOTE only.
        - Other phases: unrestricted.
        """
        allowed, reason = self.check_constraints(message)
        if not allowed:
            return False, reason

        phase = self.state.current_phase
        msg_type = message.message_type

        if phase == ProtocolPhase.PROPOSAL:
            if msg_type != MessageType.SUGGESTION:
                return False, "Only SUGGESTION messages in PROPOSAL phase"

        elif phase == ProtocolPhase.DISCUSSION:
            if msg_type not in (MessageType.CRITIQUE, MessageType.OBSERVATION, MessageType.RESPONSE):
                return False, "Only CRITIQUE, OBSERVATION, or RESPONSE in DISCUSSION"

        elif phase == ProtocolPhase.VOTING:
            if msg_type != MessageType.VOTE:
                return False, "Only VOTE messages in VOTING phase"

        return True, None

    def should_advance_phase(self, messages: list[StructuredMessage]) -> bool:
        """Advance once every agent in ``state.turn_order`` has participated.

        DISCUSSION additionally requires at least ``min_rounds`` rounds.
        """
        phase = self.state.current_phase
        participants = self.state.turn_order

        if phase == ProtocolPhase.PROPOSAL:
            return len(self.state.proposals_submitted) == len(participants)

        if phase == ProtocolPhase.DISCUSSION:
            if self.state.current_round < self.config.min_rounds:
                return False
            agents_spoken = {msg.sender_id for msg in messages if msg.metadata.phase == phase.value}
            return len(agents_spoken) == len(participants)

        if phase == ProtocolPhase.VOTING:
            return len(self.state.votes_cast) == len(participants)

        return False

    def should_terminate(self, messages: list[StructuredMessage]) -> tuple[bool, Optional[str]]:
        """Decide whether the debate should stop.

        It stops when the protocol is completed, when a unanimous vote from every
        agent arrives (if early termination is allowed), or when ``max_rounds``
        is reached.
        """
        if self.state.current_phase == ProtocolPhase.COMPLETED:
            return True, "Protocol completed"

        if (self.config.allow_early_termination and
                self.state.current_phase == ProtocolPhase.VOTING):
            votes = self.state.votes_cast
            if len(set(votes.values())) == 1 and len(votes) == len(self.state.turn_order):
                return True, "Unanimous vote reached"

        if self.state.current_round >= self.config.max_rounds:
            return True, "Max rounds reached"

        return False, None


# ============================================================================
# CONSENSUS PROTOCOL
# ============================================================================

class ConsensusProtocol(Protocol):
    """
    Consensus Protocol.

    Flow:
    1. PROPOSAL: Agents propose ideas
    2. DISCUSSION: Agents discuss and find common ground
    3. CONSENSUS_BUILDING: Agents iteratively converge on shared solution
    4. FINALIZATION: Consensus is documented

    Good for building shared understanding and collaborative solutions.
    """

    def get_phases(self) -> list[ProtocolPhase]:
        return [
            ProtocolPhase.INITIALIZATION,
            ProtocolPhase.PROPOSAL,
            ProtocolPhase.DISCUSSION,
            ProtocolPhase.CONSENSUS_BUILDING,
            ProtocolPhase.FINALIZATION,
            ProtocolPhase.COMPLETED,
        ]

    def get_next_speaker(
        self,
        agents: list[str],
        messages: list[StructuredMessage]
    ) -> Optional[str]:
        """Never designate a speaker, so the discussion stays free-form.

        Returning ``None`` lets any agent contribute when it has something to add.
        """
        return None

    def is_message_allowed(
        self,
        message: StructuredMessage,
        agent_id: str,
    ) -> tuple[bool, Optional[str]]:
        """Check shared constraints, then reject only VOTE during CONSENSUS_BUILDING."""
        allowed, reason = self.check_constraints(message)
        if not allowed:
            return False, reason

        phase = self.state.current_phase
        msg_type = message.message_type

        if phase == ProtocolPhase.CONSENSUS_BUILDING:
            if msg_type == MessageType.VOTE:
                return False, "No voting in consensus - we build agreement, not vote"

        return True, None

    def should_advance_phase(self, messages: list[StructuredMessage]) -> bool:
        """Decide whether the current phase is finished.

        The rule depends on the phase:

        - PROPOSAL: every agent in ``state.turn_order`` has proposed.
        - DISCUSSION: at least ``min_rounds`` rounds have passed and every agent
          has spoken in this phase.
        - CONSENSUS_BUILDING: either enough *distinct* agents have recently sent
          an ACKNOWLEDGMENT in this phase (``convergence_threshold``), or every
          agent has contributed at least ``min_rounds`` messages in this phase.
        """
        phase = self.state.current_phase
        participants = self.state.turn_order
        n_agents = len(participants)

        if phase == ProtocolPhase.PROPOSAL:
            return len(self.state.proposals_submitted) == n_agents

        if phase == ProtocolPhase.DISCUSSION:
            if self.state.current_round < self.config.min_rounds:
                return False
            agents_spoken = {m.sender_id for m in messages if m.metadata.phase == phase.value}
            return len(agents_spoken) == n_agents

        if phase == ProtocolPhase.CONSENSUS_BUILDING:
            # Look at recent messages only, so stale agreement does not count.
            # The window grows with the group so the threshold stays reachable
            # for more than a handful of agents.
            window = max(10, 2 * n_agents)
            agreeing = {
                m.sender_id
                for m in messages[-window:]
                if m.message_type == MessageType.ACKNOWLEDGMENT
                and m.metadata.phase == phase.value
            }
            # Count each agent once: repeated acknowledgments from one agent are
            # not broader agreement.
            if len(agreeing) >= n_agents * self.config.convergence_threshold:
                return True

            # Fallback: each agent has contributed at least ``min_rounds``
            # messages in this phase, rather than any ``min_rounds`` messages
            # from anyone (with the default of 1, the first message would end
            # the phase).
            contributions = dict.fromkeys(participants, 0)
            for m in messages:
                if m.metadata.phase == phase.value and m.sender_id in contributions:
                    contributions[m.sender_id] += 1
            return min(contributions.values(), default=0) >= self.config.min_rounds

        return False

    def should_terminate(self, messages: list[StructuredMessage]) -> tuple[bool, Optional[str]]:
        """Stop when the protocol has completed or ``max_rounds`` is reached."""
        if self.state.current_phase == ProtocolPhase.COMPLETED:
            return True, "Consensus reached"

        if self.state.current_round >= self.config.max_rounds:
            return True, "Max rounds reached (consensus may be partial)"

        return False, None


# Protocol factory

def create_protocol(protocol_type: ProtocolType, config: Optional[ProtocolConfig] = None) -> Protocol:
    """
    Build the protocol implementation that matches ``protocol_type``.

    Args:
        protocol_type: Which protocol to instantiate.
        config: Settings for the new protocol. A default ``ProtocolConfig()``
            is used when omitted.

    Returns:
        A new protocol instance bound to the configuration.

    Raises:
        ValueError: if ``protocol_type`` has no implementation in this module.
    """
    effective_config = ProtocolConfig() if config is None else config

    # Ordered (type, class) pairs compared with ``==``, matching the original
    # if/elif dispatch exactly (no hashing of the argument).
    implementations = (
        (ProtocolType.CRITIQUE_AND_REVISE, CritiqueAndReviseProtocol),
        (ProtocolType.DEBATE, DebateProtocol),
        (ProtocolType.CONSENSUS, ConsensusProtocol),
    )
    for candidate_type, protocol_cls in implementations:
        if protocol_type == candidate_type:
            return protocol_cls(effective_config)

    raise ValueError(f"Unknown protocol type: {protocol_type}")

