"""
Scheduler - Determines who speaks when.

This is the traffic controller of the workspace. It enforces turn-taking,
prevents chaos, and ensures the protocol is followed.

Without a scheduler, you get a free-for-all chat room.
With a scheduler, you get structured collaboration.
"""

import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Optional

from synqed.message_model import StructuredMessage, MessageType, MessagePriority
from synqed.protocols import Protocol, ProtocolPhase

# Module-level logger. Scheduler instances fetch the same named logger
# (logging.getLogger(__name__)) into self.logger.
logger: logging.Logger = logging.getLogger(__name__)


class SchedulingStrategy(Enum):
    """
    How the scheduler picks the agent that takes the next turn.

    ROUND_ROBIN          -- strict, fixed turn order
    PRIORITY_QUEUE       -- favour the agent whose pending message is most urgent
    PROTOCOL_DRIVEN      -- let the active protocol decide
    FREE_FOR_ALL         -- no designated speaker; any agent may speak
    MODERATOR_CONTROLLED -- the speaker is selected explicitly
    """

    ROUND_ROBIN = "round_robin"
    PRIORITY_QUEUE = "priority_queue"
    PROTOCOL_DRIVEN = "protocol_driven"
    FREE_FOR_ALL = "free_for_all"
    MODERATOR_CONTROLLED = "moderator_controlled"


@dataclass
class AgentTurnState:
    """
    Mutable bookkeeping for one agent's participation.

    Tracks per-round and lifetime turn/token counters, availability flags
    (active / blocked, with an optional block expiry) and pending-message
    hints.
    """
    agent_id: str
    agent_name: str

    # Turn counters; turns_this_round is zeroed by reset_round()
    total_turns: int = 0
    turns_this_round: int = 0
    last_turn_time: Optional[datetime] = None

    # Availability: inactive or blocked agents cannot take a turn
    is_active: bool = True
    is_blocked: bool = False
    blocked_until: Optional[datetime] = None  # None while blocked = no automatic expiry
    blocked_reason: Optional[str] = None

    # Token usage; tokens_used_this_round is zeroed by reset_round()
    tokens_used_this_round: int = 0
    total_tokens_used: int = 0

    # Pending-message hints
    has_pending_message: bool = False
    pending_priority: MessagePriority = MessagePriority.NORMAL

    def reset_round(self) -> None:
        """Zero the per-round turn and token counters."""
        self.turns_this_round = self.tokens_used_this_round = 0

    def record_turn(self, tokens: int = 0) -> None:
        """
        Count one turn taken now.

        Increments both turn counters, stamps ``last_turn_time`` with the
        current local time, and adds ``tokens`` to both token counters.
        """
        self.last_turn_time = datetime.now()
        self.total_turns += 1
        self.turns_this_round += 1
        self.total_tokens_used += tokens
        self.tokens_used_this_round += tokens

    def is_available(self) -> bool:
        """
        Report whether this agent may take a turn right now.

        Inactive agents are never available. A blocked agent becomes
        available again once ``blocked_until`` has passed, and in that case
        the block fields are cleared as a side effect. A block without
        ``blocked_until`` never lifts on its own.
        """
        if not self.is_active:
            return False
        if not self.is_blocked:
            return True

        expiry = self.blocked_until
        if expiry is None or datetime.now() < expiry:
            return False

        # The block has expired: lift it so later checks take the fast path.
        self.is_blocked = False
        self.blocked_until = None
        self.blocked_reason = None
        return True


@dataclass
class SchedulerConfig:
    """
    Tunable policy for a ``Scheduler``.

    Every field has a default, so ``SchedulerConfig()`` yields a
    protocol-driven scheduler with turn fairness and priority interrupts
    enabled.
    """
    strategy: SchedulingStrategy = SchedulingStrategy.PROTOCOL_DRIVEN

    # Turn limits
    max_consecutive_turns: int = 2  # Max back-to-back turns by one agent before another must speak
    min_turn_interval_seconds: Optional[int] = None  # Per-agent cooldown between turns; None or 0 disables it

    # Fairness
    # When True: reject agents too far ahead of the least-active available
    # agent, and (under PROTOCOL_DRIVEN) fall back to the least-active agent
    # instead of round-robin when the protocol names no speaker.
    enforce_turn_fairness: bool = True
    max_turn_imbalance: int = 2  # Allowed lead, in turns this round, over the least-active agent

    # Timeouts
    # NOTE(review): not read by Scheduler in this file; confirm where they are enforced.
    turn_timeout_seconds: Optional[int] = 30  # Max time for a turn
    round_timeout_seconds: Optional[int] = 300  # Max time for a round

    # Priority handling (a smaller MessagePriority.value is treated as more urgent)
    allow_priority_interrupts: bool = True  # Urgent messages may bypass consecutive-turn / expected-speaker checks
    priority_threshold: MessagePriority = MessagePriority.HIGH  # Interrupt when priority.value <= this value

    # Arbitrary extra settings (not read by Scheduler in this file).
    # Annotated with typing.Any; the builtin any() is a function, not a type.
    custom_rules: dict[str, Any] = field(default_factory=dict)


class Scheduler:
    """
    Scheduler for managing agent turns in the workspace.

    The scheduler:
    1. Determines which agent can speak next
    2. Enforces turn-taking rules
    3. Prevents any agent from dominating
    4. Handles priority messages
    5. Ensures fairness
    6. Works with the protocol to enforce collaboration patterns

    Priority comparisons throughout treat a *smaller* ``MessagePriority.value``
    as more urgent (see ``_get_next_speaker_priority`` and the interrupt
    checks in ``can_agent_speak``).
    """

    def __init__(
        self,
        config: SchedulerConfig,
        protocol: Protocol,
        agents: list[tuple[str, str]],  # [(agent_id, agent_name), ...]
    ):
        """
        Args:
            config: Scheduling policy and limits.
            protocol: Active protocol. It is consulted for per-round turn
                limits and, under PROTOCOL_DRIVEN, for speaker selection.
            agents: ``(agent_id, agent_name)`` pairs. Their order defines the
                round-robin turn order.
        """
        self.config = config
        self.protocol = protocol

        # Per-agent turn/availability bookkeeping, keyed by agent_id.
        self.agent_states: dict[str, AgentTurnState] = {}
        for agent_id, agent_name in agents:
            self.agent_states[agent_id] = AgentTurnState(
                agent_id=agent_id,
                agent_name=agent_name,
            )

        # Round-robin turn order and cursor into it.
        self.turn_order: list[str] = [agent_id for agent_id, _ in agents]
        self.current_turn_index: int = 0

        # Round tracking
        self.current_round: int = 0
        self.round_start_time: Optional[datetime] = None

        # Consecutive-turn tracking
        self.last_speaker: Optional[str] = None
        self.consecutive_turns: int = 0

        self.logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------
    # Speaker selection
    # ------------------------------------------------------------------

    def get_next_speaker(
        self,
        messages: list[StructuredMessage],
        pending_messages: Optional[dict[str, StructuredMessage]] = None,
    ) -> Optional[str]:
        """
        Determine who should speak next.

        Args:
            messages: Message history
            pending_messages: Dict of agent_id -> pending message (for priority)

        Returns:
            Agent ID who should speak. Returns None when no agent is available
            (inactive, blocked, or out of protocol turns this round). It also
            returns None when the strategy designates no specific speaker:
            FREE_FOR_ALL, MODERATOR_CONTROLLED, or a protocol pick that is
            unavailable.
        """
        # Candidates: active agents that are not (still) blocked.
        available_agents = [
            agent_id for agent_id, state in self.agent_states.items()
            if state.is_available()
        ]

        # Drop agents that already used their per-round turn budget, as
        # counted by the protocol.
        max_turns = self.protocol.config.max_turns_per_agent_per_round
        available_agents = [
            agent_id for agent_id in available_agents
            if self.protocol.state.turns_taken_this_round.get(agent_id, 0) < max_turns
        ]

        if not available_agents:
            self.logger.debug("No agents available to speak (all exhausted or unavailable)")
            return None

        strategy = self.config.strategy
        if strategy == SchedulingStrategy.PROTOCOL_DRIVEN:
            return self._get_next_speaker_protocol(messages, available_agents)
        if strategy == SchedulingStrategy.ROUND_ROBIN:
            return self._get_next_speaker_round_robin(available_agents)
        if strategy == SchedulingStrategy.PRIORITY_QUEUE:
            return self._get_next_speaker_priority(pending_messages or {}, available_agents)
        if strategy == SchedulingStrategy.FREE_FOR_ALL:
            return self._get_next_speaker_free_for_all(available_agents)
        if strategy == SchedulingStrategy.MODERATOR_CONTROLLED:
            # The speaker is selected explicitly by a moderator; the scheduler
            # has no pick of its own. This is a valid strategy, not an error.
            self.logger.debug("Moderator-controlled: awaiting explicit speaker selection")
            return None

        self.logger.error("Unknown scheduling strategy: %s", strategy)
        return None

    def _get_next_speaker_protocol(
        self,
        messages: list[StructuredMessage],
        available_agents: list[str],
    ) -> Optional[str]:
        """
        Let the protocol determine the next speaker.

        If the protocol names nobody, fall back to the least-active agent
        (fairness on) or to round-robin (fairness off). If the protocol names
        an agent that is not in ``available_agents``, log a warning and
        return None.
        """
        next_speaker = self.protocol.get_next_speaker(available_agents, messages)

        if next_speaker is None:
            if self.config.enforce_turn_fairness:
                return self._get_least_active_agent(available_agents)
            return self._get_next_speaker_round_robin(available_agents)

        if next_speaker in available_agents:
            return next_speaker

        self.logger.warning("Protocol selected unavailable agent: %s", next_speaker)
        return None

    def _get_next_speaker_round_robin(
        self,
        available_agents: list[str],
    ) -> Optional[str]:
        """
        Strict round-robin scheduling.

        Walk ``turn_order`` starting at ``current_turn_index`` and return the
        first agent found in ``available_agents``. The index advances past
        every agent inspected, including skipped ones. Gives up after one full
        cycle.
        """
        candidates = set(available_agents)
        for _ in range(len(self.turn_order)):
            current_agent = self.turn_order[self.current_turn_index]
            self.current_turn_index = (self.current_turn_index + 1) % len(self.turn_order)
            if current_agent in candidates:
                return current_agent

        self.logger.warning("Round-robin: no available agents found")
        return None

    def _get_next_speaker_priority(
        self,
        pending_messages: dict[str, StructuredMessage],
        available_agents: list[str],
    ) -> Optional[str]:
        """
        Priority-based scheduling.

        Pick the available agent whose pending message is strictly more urgent
        than ``MessagePriority.LOW`` and most urgent overall. Ties go to the
        agent earliest in ``available_agents``. If no such message exists, fall
        back to the least-active agent.
        """
        best_agent: Optional[str] = None
        best_priority = MessagePriority.LOW

        for agent_id in available_agents:
            msg = pending_messages.get(agent_id)
            if msg is None:
                continue
            priority = msg.metadata.priority
            if priority.value < best_priority.value:
                best_priority = priority
                best_agent = agent_id

        if best_agent is not None:
            return best_agent

        return self._get_least_active_agent(available_agents)

    def _get_next_speaker_free_for_all(
        self,
        available_agents: list[str],
    ) -> Optional[str]:
        """
        Free-for-all: return None so that any agent may speak.

        Fairness is not applied here. When ``enforce_turn_fairness`` is on,
        ``can_agent_speak`` rejects agents that are too far ahead, and that is
        where the constraint takes effect.
        """
        return None

    def _get_least_active_agent(self, available_agents: list[str]) -> Optional[str]:
        """Return the agent with the fewest turns this round, then fewest overall."""
        if not available_agents:
            return None

        return min(
            available_agents,
            key=lambda a: (
                self.agent_states[a].turns_this_round,
                self.agent_states[a].total_turns,
            ),
        )

    # ------------------------------------------------------------------
    # Permission checks
    # ------------------------------------------------------------------

    def _is_priority_interrupt(self, message: Optional[StructuredMessage]) -> bool:
        """True if ``message`` may bypass consecutive-turn / expected-speaker rules."""
        return (
            message is not None
            and self.config.allow_priority_interrupts
            and message.metadata.priority.value <= self.config.priority_threshold.value
        )

    def can_agent_speak(
        self,
        agent_id: str,
        expected_speaker: Optional[str],
        message: Optional[StructuredMessage] = None,
    ) -> tuple[bool, Optional[str]]:
        """
        Check if an agent is allowed to speak now.

        The checks run in this order:
        1. Known agent.
        2. Availability.
        3. Cooldown.
        4. Consecutive-turn cap.
        5. Turn fairness.
        6. Expected speaker.

        The first failing check supplies the reason. A priority interrupt that
        overrides the consecutive-turn cap returns True immediately, skipping
        the later checks.

        Args:
            agent_id: Agent wanting to speak
            expected_speaker: Who the scheduler expects to speak (or None for anyone)
            message: Optional message to check (for priority interrupts)

        Returns:
            (allowed, reason) - True if allowed, False with reason if not
        """
        state = self.agent_states.get(agent_id)
        if state is None:
            return False, f"Unknown agent: {agent_id}"

        if not state.is_available():
            return False, state.blocked_reason or "Agent is not available"

        # Per-agent cooldown between turns (None or 0 disables it).
        cooldown = self.config.min_turn_interval_seconds
        if cooldown and state.last_turn_time:
            elapsed = (datetime.now() - state.last_turn_time).total_seconds()
            if elapsed < cooldown:
                return False, f"Turn cooldown: {cooldown - elapsed:.1f}s remaining"

        # Cap on back-to-back turns by the same agent.
        if (
            self.last_speaker == agent_id
            and self.consecutive_turns >= self.config.max_consecutive_turns
        ):
            if self._is_priority_interrupt(message):
                self.logger.info("Priority interrupt allowed for %s", agent_id)
                return True, None
            return False, f"Max consecutive turns reached ({self.config.max_consecutive_turns})"

        if self.config.enforce_turn_fairness:
            # Compare only against agents that could actually speak now.
            # Inactive or blocked agents never accumulate turns, so including
            # them would pin the minimum at their count and lock every other
            # agent out until they return.
            min_turns = min(
                (
                    peer.turns_this_round
                    for peer in self.agent_states.values()
                    if peer.is_available()
                ),
                default=state.turns_this_round,
            )
            if state.turns_this_round > min_turns + self.config.max_turn_imbalance:
                # CRITICAL messages override fairness regardless of allow_priority_interrupts.
                if message is not None and message.metadata.priority == MessagePriority.CRITICAL:
                    return True, None
                return False, "Turn fairness: other agents need to speak"

        if expected_speaker is not None and expected_speaker != agent_id:
            if self._is_priority_interrupt(message):
                self.logger.info(
                    "Priority interrupt: %s speaks instead of %s", agent_id, expected_speaker
                )
                return True, None
            return False, f"Expected {expected_speaker} to speak"

        return True, None

    # ------------------------------------------------------------------
    # Bookkeeping
    # ------------------------------------------------------------------

    def record_turn(
        self,
        agent_id: str,
        message: StructuredMessage,
    ) -> None:
        """
        Record that an agent took a turn.

        Updates the agent's counters (turns, tokens from
        ``message.metadata.token_count``, last turn time) and the
        consecutive-turn streak. Unknown agents are logged and ignored.

        Args:
            agent_id: Agent who spoke
            message: Message they sent
        """
        state = self.agent_states.get(agent_id)
        if state is None:
            self.logger.warning("Recording turn for unknown agent: %s", agent_id)
            return

        state.record_turn(tokens=message.metadata.token_count)

        if self.last_speaker == agent_id:
            self.consecutive_turns += 1
        else:
            self.consecutive_turns = 1
            self.last_speaker = agent_id

        self.logger.debug(
            "Agent %s took turn (round: %s, total: %s, consecutive: %s)",
            agent_id,
            state.turns_this_round,
            state.total_turns,
            self.consecutive_turns,
        )

    def start_round(self, round_number: int) -> None:
        """Start a new round: reset per-round counters and the consecutive-turn streak."""
        self.current_round = round_number
        self.round_start_time = datetime.now()

        for state in self.agent_states.values():
            state.reset_round()

        self.last_speaker = None
        self.consecutive_turns = 0

        self.logger.info("Started round %s", round_number)

    def end_round(self) -> dict[str, Any]:
        """
        End the current round and return stats.

        Does not reset any counters; that happens in ``start_round``.

        Returns:
            Dictionary of round statistics: round number, duration in seconds
            (None if the round was never started), per-agent turns/tokens this
            round, and their totals.
        """
        stats: dict[str, Any] = {
            "round_number": self.current_round,
            "duration_seconds": None,
            "agent_turns": {},
            "agent_tokens": {},
            "total_turns": 0,
            "total_tokens": 0,
        }

        if self.round_start_time:
            stats["duration_seconds"] = (datetime.now() - self.round_start_time).total_seconds()

        for agent_id, state in self.agent_states.items():
            stats["agent_turns"][agent_id] = state.turns_this_round
            stats["agent_tokens"][agent_id] = state.tokens_used_this_round
            stats["total_turns"] += state.turns_this_round
            stats["total_tokens"] += state.tokens_used_this_round

        self.logger.info("Ended round %s: %s", self.current_round, stats)
        return stats

    def block_agent(
        self,
        agent_id: str,
        reason: str,
        duration_seconds: Optional[int] = None,
    ) -> bool:
        """
        Block an agent from speaking.

        Args:
            agent_id: Agent to block
            reason: Reason for blocking
            duration_seconds: Optional duration (None = until explicitly
                unblocked). 0 yields a block that has already expired.

        Returns:
            True if blocked, False if agent not found
        """
        state = self.agent_states.get(agent_id)
        if state is None:
            return False

        state.is_blocked = True
        state.blocked_reason = reason
        # Always overwrite the expiry. Otherwise a stale expiry from an earlier
        # timed block would silently end an indefinite block.
        state.blocked_until = (
            datetime.now() + timedelta(seconds=duration_seconds)
            if duration_seconds is not None
            else None
        )

        self.logger.warning("Blocked agent %s: %s", agent_id, reason)
        return True

    def unblock_agent(self, agent_id: str) -> bool:
        """Unblock an agent. Returns False if the agent is unknown."""
        state = self.agent_states.get(agent_id)
        if state is None:
            return False

        state.is_blocked = False
        state.blocked_until = None
        state.blocked_reason = None

        self.logger.info("Unblocked agent %s", agent_id)
        return True

    def get_agent_state(self, agent_id: str) -> Optional[AgentTurnState]:
        """Get state for an agent (None if unknown)."""
        return self.agent_states.get(agent_id)

    def get_statistics(self) -> dict[str, Any]:
        """Get overall statistics."""
        # Evaluate availability first: is_available() lifts expired blocks,
        # so counting afterwards keeps "blocked_agents" consistent with the
        # per-agent "is_available" flags.
        per_agent = {
            agent_id: {
                "total_turns": state.total_turns,
                "turns_this_round": state.turns_this_round,
                "total_tokens": state.total_tokens_used,
                "is_available": state.is_available(),
            }
            for agent_id, state in self.agent_states.items()
        }
        states = self.agent_states.values()
        return {
            "current_round": self.current_round,
            "total_agents": len(self.agent_states),
            "active_agents": sum(1 for s in states if s.is_active),
            "blocked_agents": sum(1 for s in states if s.is_blocked),
            "agent_states": per_agent,
        }

    def __repr__(self) -> str:
        """String representation."""
        return (
            f"Scheduler(strategy={self.config.strategy.value}, "
            f"round={self.current_round}, "
            f"agents={len(self.agent_states)})"
        )

