"""
Observability - Full logs, traces, and contribution scoring.

This module provides transparency into workspace execution.
For devs, admins, or laypeople to trust the system, they need to see:
- Full logs
- Full message trace
- Graphs of interactions
- Final merged output
- Reasoning summary
- Agent contribution scores

Opaque systems die in enterprise.
"""

from dataclasses import dataclass, field
from typing import Any, Optional
from datetime import datetime
import json
from collections import defaultdict

from synqed.message_model import StructuredMessage, MessageType


@dataclass
class AgentContribution:
    """Activity counters and derived scores for a single agent.

    Only the identity, counter and score fields are exported by
    ``to_dict``; the two detail lists are kept on the instance but are not
    part of the exported dictionary.
    """

    # Identity
    agent_id: str
    agent_name: str

    # Raw counters
    total_messages: int = 0
    total_tokens: int = 0
    suggestions_made: int = 0
    critiques_given: int = 0
    votes_cast: int = 0

    # Derived scores
    influence_score: float = 0.0  # how much other agents referenced this agent
    quality_score: float = 0.0  # based on acceptance rate of suggestions
    participation_score: float = 0.0  # relative to the other agents

    # Supporting detail (each instance gets its own fresh list)
    message_ids: list[str] = field(default_factory=list)
    referenced_by: list[str] = field(default_factory=list)  # ids of agents that referenced this one

    def to_dict(self) -> dict[str, Any]:
        """Return the identity, counters and scores as a plain dictionary.

        Returns:
            A new dict whose keys appear in the order listed below.
            ``message_ids`` and ``referenced_by`` are not included.
        """
        exported_fields = (
            "agent_id",
            "agent_name",
            "total_messages",
            "total_tokens",
            "suggestions_made",
            "critiques_given",
            "votes_cast",
            "influence_score",
            "quality_score",
            "participation_score",
        )
        return {name: getattr(self, name) for name in exported_fields}


@dataclass
class InteractionGraph:
    """Directed graph of agent-to-agent interactions.

    ``nodes`` maps an agent id to a dict of node attributes and ``edges`` is a
    list of interaction records, one per call to ``add_interaction``.
    """
    nodes: dict[str, dict[str, Any]] = field(default_factory=dict)  # agent_id -> node data
    edges: list[dict[str, Any]] = field(default_factory=list)  # interaction edges

    def add_interaction(
        self,
        from_agent: str,
        to_agent: str,
        interaction_type: str,
        message_id: str,
        weight: float = 1.0,
    ) -> None:
        """Add an interaction edge and make sure both endpoints are nodes.

        Args:
            from_agent: Id of the agent the interaction originates from.
            to_agent: Id of the agent the interaction is directed at.
            interaction_type: Label stored under the edge's ``"type"`` key.
            message_id: Id of the message that produced this interaction.
            weight: Edge weight; defaults to 1.0.
        """
        # Register both endpoints so every exported edge refers to a node that
        # actually exists. setdefault keeps attributes already stored for a
        # node and only creates an empty entry for unseen agents.
        self.nodes.setdefault(from_agent, {})
        self.nodes.setdefault(to_agent, {})

        self.edges.append({
            "from": from_agent,
            "to": to_agent,
            "type": interaction_type,
            "message_id": message_id,
            "weight": weight,
        })

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for visualization.

        Returns:
            ``{"nodes": [...], "edges": [...]}`` where each node entry is the
            node's attribute dict with an added ``"id"`` key. ``"edges"`` is
            the graph's own edge list, not a copy.
        """
        return {
            "nodes": [
                {"id": node_id, **data}
                for node_id, data in self.nodes.items()
            ],
            "edges": self.edges,
        }


@dataclass
class ExecutionTrace:
    """Chronological record of what happened during a workspace run.

    Three independent lists are kept: generic events, phase transitions and
    per-round summaries. Events and phase transitions are stamped with the
    local wall-clock time at which they are added.
    """
    workspace_id: str
    start_time: datetime
    end_time: Optional[datetime] = None

    # Generic timeline entries
    events: list[dict[str, Any]] = field(default_factory=list)

    # One entry per phase change
    phase_transitions: list[dict[str, Any]] = field(default_factory=list)

    # One entry per completed round
    round_summaries: list[dict[str, Any]] = field(default_factory=list)

    def add_event(
        self,
        event_type: str,
        description: str,
        agent_id: Optional[str] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """Append a timestamped event to the timeline.

        Args:
            event_type: Stored under the ``"type"`` key.
            description: Human-readable description of the event.
            agent_id: Agent the event relates to, if any.
            metadata: Extra details; a missing or empty value is stored as a
                new empty dict.
        """
        stamp = datetime.now().isoformat()
        details = metadata if metadata else {}
        self.events.append(
            dict(
                timestamp=stamp,
                type=event_type,
                description=description,
                agent_id=agent_id,
                metadata=details,
            )
        )

    def add_phase_transition(
        self,
        from_phase: str,
        to_phase: str,
        round_number: int,
    ) -> None:
        """Append a timestamped phase-transition record.

        Args:
            from_phase: Phase being left.
            to_phase: Phase being entered.
            round_number: Round in which the transition happened.
        """
        stamp = datetime.now().isoformat()
        self.phase_transitions.append(
            dict(
                timestamp=stamp,
                from_phase=from_phase,
                to_phase=to_phase,
                round_number=round_number,
            )
        )

    def add_round_summary(
        self,
        round_number: int,
        messages: int,
        tokens: int,
        duration_seconds: float,
        participants: list[str],
    ) -> None:
        """Append the summary of one round (no timestamp is attached).

        Args:
            round_number: Number of the round being summarised.
            messages: Message count for the round.
            tokens: Token count for the round.
            duration_seconds: How long the round took.
            participants: Agents that took part; the list is stored as given.
        """
        self.round_summaries.append(
            dict(
                round_number=round_number,
                messages=messages,
                tokens=tokens,
                duration_seconds=duration_seconds,
                participants=participants,
            )
        )

    def to_dict(self) -> dict[str, Any]:
        """Return the trace, with counts, as a plain dictionary.

        Returns:
            A dict with ISO-formatted start/end times (``end_time`` is
            ``None`` while unset), the three list lengths, and the three
            lists themselves (the same list objects, not copies).
        """
        finished_at = None
        if self.end_time:
            finished_at = self.end_time.isoformat()

        summary: dict[str, Any] = {
            "workspace_id": self.workspace_id,
            "start_time": self.start_time.isoformat(),
            "end_time": finished_at,
        }
        summary.update(
            total_events=len(self.events),
            total_phases=len(self.phase_transitions),
            total_rounds=len(self.round_summaries),
        )
        summary.update(
            events=self.events,
            phase_transitions=self.phase_transitions,
            round_summaries=self.round_summaries,
        )
        return summary


class ObservabilityCollector:
    """
    Collects and analyzes workspace execution data for observability.

    Provides:
    1. Full message logs
    2. Execution traces
    3. Interaction graphs
    4. Agent contribution scoring
    5. Reasoning summaries
    6. Exportable reports

    Feed the collector through ``record_message``, ``record_phase_change`` and
    ``record_round_end`` while the workspace runs, then call
    ``export_full_report`` or ``export_json``. Nothing in this class assigns
    ``trace.end_time``; until the caller sets it, the duration is reported as
    "In progress" (summary) or ``None`` (report).
    """

    def __init__(self, workspace_id: str):
        """Create an empty collector whose start time is the current local time.

        Args:
            workspace_id: Identifier of the workspace being observed.
        """
        self.workspace_id = workspace_id
        self.start_time = datetime.now()

        # Data collection
        self.messages: list[StructuredMessage] = []
        self.trace = ExecutionTrace(workspace_id=workspace_id, start_time=self.start_time)
        self.graph = InteractionGraph()

        # Agent contributions, keyed by agent id
        self.contributions: dict[str, AgentContribution] = {}

        # Current phase tracking (updated by record_phase_change)
        self.current_phase: Optional[str] = None

    def record_message(self, message: StructuredMessage) -> None:
        """Record a message and update counters, graph and trace.

        The sender's contribution entry is created on first sight. When the
        message replies to an already recorded message from a *different*
        agent, an interaction edge is added and the replier is noted in the
        original sender's ``referenced_by`` list.

        Args:
            message: The message to record.
        """
        self.messages.append(message)

        # Update contributions
        agent_id = message.sender_id
        if agent_id not in self.contributions:
            self.contributions[agent_id] = AgentContribution(
                agent_id=agent_id,
                agent_name=message.sender_name,
            )

        contrib = self.contributions[agent_id]
        contrib.total_messages += 1
        contrib.total_tokens += message.metadata.token_count
        contrib.message_ids.append(message.metadata.message_id)

        # Track specific message types
        if message.message_type == MessageType.SUGGESTION:
            contrib.suggestions_made += 1
        elif message.message_type == MessageType.CRITIQUE:
            contrib.critiques_given += 1
        elif message.message_type == MessageType.VOTE:
            contrib.votes_cast += 1

        # Track interactions (replies); self-replies are ignored
        if message.metadata.reply_to:
            original = self._find_message(message.metadata.reply_to)
            if original and original.sender_id != agent_id:
                self.graph.add_interaction(
                    from_agent=agent_id,
                    to_agent=original.sender_id,
                    interaction_type=message.message_type.value,
                    message_id=message.metadata.message_id,
                )

                # Update influence bookkeeping for the referenced agent
                if original.sender_id in self.contributions:
                    self.contributions[original.sender_id].referenced_by.append(agent_id)

        # Add event
        self.trace.add_event(
            event_type="message",
            description=f"{message.sender_name} sent {message.message_type.value}",
            agent_id=agent_id,
            metadata={
                "message_id": message.metadata.message_id,
                "message_type": message.message_type.value,
                "phase": message.metadata.phase,
                "tokens": message.metadata.token_count,
            },
        )

    def record_phase_change(
        self,
        from_phase: str,
        to_phase: str,
        round_number: int,
    ) -> None:
        """Record a phase transition and make ``to_phase`` the current phase.

        Args:
            from_phase: Phase being left.
            to_phase: Phase being entered.
            round_number: Round in which the transition happened.
        """
        self.trace.add_phase_transition(from_phase, to_phase, round_number)
        self.trace.add_event(
            event_type="phase_change",
            description=f"Phase changed: {from_phase} → {to_phase}",
            metadata={"from": from_phase, "to": to_phase, "round": round_number},
        )
        self.current_phase = to_phase

    def record_round_end(
        self,
        round_number: int,
        stats: dict[str, Any],
    ) -> None:
        """Record end of a round.

        Args:
            round_number: Number of the round that finished.
            stats: Round statistics. The keys ``total_turns``,
                ``total_tokens``, ``duration_seconds`` and ``agent_turns``
                are read (each optional); the whole dict is also attached to
                the ``round_end`` event as its metadata.
        """
        self.trace.add_round_summary(
            round_number=round_number,
            messages=stats.get("total_turns", 0),
            tokens=stats.get("total_tokens", 0),
            duration_seconds=stats.get("duration_seconds", 0),
            participants=list(stats.get("agent_turns", {}).keys()),
        )

        self.trace.add_event(
            event_type="round_end",
            description=f"Round {round_number} completed",
            metadata=stats,
        )

    def compute_contribution_scores(self) -> None:
        """Compute final contribution scores for all agents.

        Overwrites three scores on every contribution entry:

        * participation: the agent's message count divided by the average
          message count per agent.
        * influence: the number of distinct agents that replied to this
          agent, divided by the highest such number across agents, so the
          most-referenced agent scores 1.0.
        * quality: the fraction of the agent's suggestions that received at
          least one reply; 0.5 when the agent made no suggestions.

        Does nothing when no contributions have been recorded. Safe to call
        repeatedly.
        """
        if not self.contributions:
            return

        # Participation, relative to the per-agent average
        total_messages = sum(c.total_messages for c in self.contributions.values())
        avg_messages = total_messages / len(self.contributions)

        for contrib in self.contributions.values():
            contrib.participation_score = (
                contrib.total_messages / avg_messages if avg_messages > 0 else 0.0
            )

        # Influence. Numerator and denominator must both count *distinct*
        # referrers; normalising by the raw reference count would penalise an
        # agent for being replied to repeatedly by the same peer.
        distinct_referrers = {
            agent_id: len(set(contrib.referenced_by))
            for agent_id, contrib in self.contributions.items()
        }
        max_references = max(distinct_referrers.values())

        for agent_id, contrib in self.contributions.items():
            contrib.influence_score = (
                distinct_referrers[agent_id] / max_references
                if max_references > 0 else 0.0
            )

        # Quality (simplified: suggestions that got responses). One pass over
        # the log collects every replied-to id and each agent's suggestion
        # ids, instead of rescanning all messages per suggestion.
        replied_ids = {msg.metadata.reply_to for msg in self.messages}
        suggestion_ids: dict[str, list[str]] = defaultdict(list)
        for msg in self.messages:
            if msg.message_type == MessageType.SUGGESTION:
                suggestion_ids[msg.sender_id].append(msg.metadata.message_id)

        for agent_id, contrib in self.contributions.items():
            own_suggestions = suggestion_ids.get(agent_id)
            if not own_suggestions:
                contrib.quality_score = 0.5  # Neutral
                continue

            responded_to = sum(
                1 for sugg_id in own_suggestions if sugg_id in replied_ids
            )
            contrib.quality_score = responded_to / len(own_suggestions)

    def get_reasoning_summary(self) -> str:
        """Generate a reasoning summary of the workspace execution.

        The agent section prints the scores currently stored on each
        contribution; they are only up to date after
        ``compute_contribution_scores`` has run (``export_full_report`` does
        this before building the summary).

        Returns:
            A multi-line, human-readable summary string.
        """
        lines = [
            "REASONING SUMMARY",
            "=" * 80,
            "",
            f"Workspace: {self.workspace_id}",
            f"Duration: {self._format_duration()}",
            f"Total Messages: {len(self.messages)}",
            "",
            "KEY DECISIONS:",
        ]

        # Find key decision points (votes, final proposals)
        votes = [m for m in self.messages if m.message_type == MessageType.VOTE]
        if votes:
            lines.append("\nVotes Cast:")
            for vote in votes:
                option = (vote.structured_data or {}).get("option", "unknown")
                lines.append(f"  • {vote.sender_name} voted for: {option}")

        # Final proposals
        final_proposals = [
            m for m in self.messages
            if m.message_type == MessageType.FINAL_OUTPUT_PROPOSAL
        ]
        if final_proposals:
            preview_chars = 200
            lines.append("\nFinal Proposals:")
            for prop in final_proposals:
                confidence = (prop.structured_data or {}).get("confidence", 0)
                lines.append(
                    f"  • {prop.sender_name} "
                    f"(confidence: {self._format_confidence(confidence)})"
                )
                # Only show an ellipsis when the content really was cut.
                preview = prop.content[:preview_chars]
                if len(prop.content) > preview_chars:
                    preview += "..."
                lines.append(f"    {preview}")

        # Critiques
        critiques = [m for m in self.messages if m.message_type == MessageType.CRITIQUE]
        if critiques:
            lines.append(f"\nCritiques Exchanged: {len(critiques)}")

        lines.extend([
            "",
            "AGENT CONTRIBUTIONS:",
        ])

        for contrib in sorted(
            self.contributions.values(),
            key=lambda c: c.participation_score,
            reverse=True
        ):
            lines.append(
                f"  • {contrib.agent_name}: "
                f"{contrib.total_messages} messages, "
                f"participation={contrib.participation_score:.2f}, "
                f"influence={contrib.influence_score:.2f}"
            )

        return "\n".join(lines)

    def get_interaction_summary(self) -> str:
        """Get summary of agent interactions.

        Returns:
            A multi-line string listing up to ten of the most frequent
            ``sender → receiver: type`` patterns, or a note that no
            interactions were recorded.
        """
        lines = [
            "INTERACTION SUMMARY",
            "=" * 80,
            "",
        ]

        # Count interactions by (from, to, type)
        interaction_counts: dict[str, int] = defaultdict(int)
        for edge in self.graph.edges:
            key = f"{edge['from']} → {edge['to']}: {edge['type']}"
            interaction_counts[key] += 1

        if interaction_counts:
            lines.append("Interaction Patterns:")
            for pattern, count in sorted(
                interaction_counts.items(),
                key=lambda x: x[1],
                reverse=True
            )[:10]:  # Top 10
                lines.append(f"  • {pattern}: {count} times")
        else:
            lines.append("No direct interactions recorded")

        return "\n".join(lines)

    def export_full_report(self) -> dict[str, Any]:
        """Export a comprehensive observability report.

        Recomputes the contribution scores first, so calling this method
        updates the stored scores as a side effect.

        Returns:
            A dict with workspace identity and timing, aggregate statistics,
            per-agent contributions, the interaction graph, the execution
            trace, a message log (content truncated to 200 characters) and
            the two text summaries.
        """
        # Compute final scores
        self.compute_contribution_scores()

        return {
            "workspace_id": self.workspace_id,
            "start_time": self.start_time.isoformat(),
            "end_time": self.trace.end_time.isoformat() if self.trace.end_time else None,
            "duration_seconds": (
                (self.trace.end_time - self.start_time).total_seconds()
                if self.trace.end_time else None
            ),
            "statistics": {
                "total_messages": len(self.messages),
                "total_agents": len(self.contributions),
                "total_rounds": len(self.trace.round_summaries),
                "total_phases": len(self.trace.phase_transitions),
            },
            "contributions": {
                agent_id: contrib.to_dict()
                for agent_id, contrib in self.contributions.items()
            },
            "interaction_graph": self.graph.to_dict(),
            "execution_trace": self.trace.to_dict(),
            "message_log": [
                {
                    "timestamp": msg.metadata.timestamp.isoformat(),
                    "sender": msg.sender_name,
                    "type": msg.message_type.value,
                    "phase": msg.metadata.phase,
                    "content": msg.content[:200],  # Truncate for overview
                    "message_id": msg.metadata.message_id,
                }
                for msg in self.messages
            ],
            "reasoning_summary": self.get_reasoning_summary(),
            "interaction_summary": self.get_interaction_summary(),
        }

    def export_json(self) -> str:
        """Export the full report as an indented JSON string.

        Raises:
            TypeError: If the report contains a value ``json`` cannot
                serialise (for example inside caller-supplied round stats).
        """
        return json.dumps(self.export_full_report(), indent=2)

    def _find_message(self, message_id: str) -> Optional[StructuredMessage]:
        """Return the first recorded message with ``message_id``, else ``None``."""
        for msg in self.messages:
            if msg.metadata.message_id == message_id:
                return msg
        return None

    @staticmethod
    def _format_confidence(confidence: Any) -> str:
        """Render a confidence value with two decimals, tolerating bad input.

        Values that cannot be converted to ``float`` (``None``, free text)
        are shown verbatim instead of aborting the whole summary.
        """
        try:
            return f"{float(confidence):.2f}"
        except (TypeError, ValueError):
            return str(confidence)

    def _format_duration(self) -> str:
        """Format duration as human-readable string.

        Returns:
            "In progress" while ``trace.end_time`` is unset; otherwise the
            elapsed time since ``start_time`` in seconds, minutes or hours
            with one decimal.
        """
        if not self.trace.end_time:
            return "In progress"

        duration = self.trace.end_time - self.start_time
        seconds = duration.total_seconds()

        if seconds < 60:
            return f"{seconds:.1f}s"
        elif seconds < 3600:
            return f"{seconds/60:.1f}m"
        else:
            return f"{seconds/3600:.1f}h"


def create_observability_report(
    workspace_id: str,
    messages: list[StructuredMessage],
    round_stats: list[dict[str, Any]],
    phase_transitions: list[tuple[str, str, int]],
    start_time: datetime,
    end_time: datetime,
) -> dict[str, Any]:
    """
    Create an observability report from workspace data.

    Replays already-collected data through a fresh ``ObservabilityCollector``
    in a fixed order: all messages, then all round stats, then all phase
    transitions. Trace events are timestamped when they are replayed, so the
    event timestamps reflect report construction rather than the original
    run.

    Args:
        workspace_id: Workspace ID
        messages: All messages
        round_stats: Statistics from each round; rounds are numbered from 1
            in list order
        phase_transitions: Phase transition history as
            ``(from_phase, to_phase, round_number)`` tuples
        start_time: Start time
        end_time: End time

    Returns:
        Comprehensive observability report
    """
    collector = ObservabilityCollector(workspace_id)

    # The constructor stamps both the collector and its trace with the
    # current time. Override both, otherwise the exported execution trace
    # reports a start time that disagrees with the report's own start_time
    # (and may even fall after end_time).
    collector.start_time = start_time
    collector.trace.start_time = start_time
    collector.trace.end_time = end_time

    # Record all messages
    for msg in messages:
        collector.record_message(msg)

    # Record round stats (1-based round numbers)
    for round_number, stats in enumerate(round_stats, start=1):
        collector.record_round_end(round_number, stats)

    # Record phase transitions
    for from_phase, to_phase, round_num in phase_transitions:
        collector.record_phase_change(from_phase, to_phase, round_num)

    # Generate report
    return collector.export_full_report()

