"""
Shared State Management - The workspace scratchpad.

This is the shared memory that all agents can read/write.
It's structured, not a free-for-all. Agents interact with it through
well-defined primitives.

Think of it as a structured database, not a chat log.
"""

from dataclasses import dataclass, field
from typing import Any, Optional
from datetime import datetime
import json
import logging

logger = logging.getLogger(__name__)


@dataclass
class StateEntry:
    """A single entry in shared state.

    Holds the current value stored under ``key`` together with bookkeeping
    about who created it, when it was created / last modified, and how many
    times it has been written.
    """
    # Name under which the entry is stored.
    key: str
    # Current payload; any Python object.
    value: Any
    # ID of the agent that first created the entry.
    created_by: str
    # Creation timestamp (naive local time from ``datetime.now`` by default).
    created_at: datetime = field(default_factory=datetime.now)
    # Timestamp of the most recent write (same default as ``created_at``).
    modified_at: datetime = field(default_factory=datetime.now)
    # Write counter; starts at 1 for a freshly created entry.
    version: int = 1
    # Free-form annotations attached to the entry.
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of the entry.

        Timestamps are rendered as ISO-8601 strings. ``metadata`` is shallow
        copied so that editing the returned dict cannot alter this entry.
        ``value`` is returned as-is (not copied).

        Returns:
            Dictionary with keys ``key``, ``value``, ``created_by``,
            ``created_at``, ``modified_at``, ``version`` and ``metadata``.
        """
        return {
            "key": self.key,
            "value": self.value,
            "created_by": self.created_by,
            "created_at": self.created_at.isoformat(),
            "modified_at": self.modified_at.isoformat(),
            "version": self.version,
            # Copy: the snapshot must not alias the entry's internal state.
            "metadata": dict(self.metadata),
        }


@dataclass
class AgentMemory:
    """Private memory for a single agent within the workspace.

    Combines a key/value scratch area with two append-only logs: what the
    agent has observed and what it has done.
    """
    # ID of the agent this memory belongs to.
    agent_id: str
    # Key/value scratch storage.
    memory: dict[str, Any] = field(default_factory=dict)
    # What the agent has observed, in insertion order.
    observation_history: list[str] = field(default_factory=list)
    # What the agent has done, in insertion order.
    action_history: list[str] = field(default_factory=list)

    def add_observation(self, observation: str) -> None:
        """Append ``observation`` to the observation log."""
        self.observation_history.append(observation)

    def add_action(self, action: str) -> None:
        """Append ``action`` to the action log."""
        self.action_history.append(action)

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in private memory, replacing any previous value."""
        self.memory[key] = value

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under ``key``, or ``default`` if the key is absent."""
        return self.memory.get(key, default)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of this memory.

        The containers are shallow copied so that a consumer mutating the
        snapshot (appending to a history list, adding a memory key) cannot
        change the agent's actual memory. Stored values themselves are not
        copied.

        Returns:
            Dictionary with keys ``agent_id``, ``memory``,
            ``observation_history`` and ``action_history``.
        """
        return {
            "agent_id": self.agent_id,
            "memory": dict(self.memory),
            "observation_history": list(self.observation_history),
            "action_history": list(self.action_history),
        }


class SharedState:
    """
    Shared state manager for the workspace.

    This provides:
    1. Shared scratchpad - all agents can read/write
    2. Private agent memories - one AgentMemory per agent ID
    3. Versioning - each write bumps the key's version and is logged in history
    4. Access control - read-only keys and single-owner keys
    5. Read-modify-write helpers - increment / append / merge_dict

    This is NOT a message log. This is structured state.

    NOTE: this class performs no locking. If one instance is shared between
    threads, callers must synchronize access themselves.
    """

    def __init__(self):
        """Create an empty state: no keys, no agent memories, no access rules."""
        # Shared state accessible to all agents
        self._shared: dict[str, StateEntry] = {}

        # Private agent memories
        self._agent_memories: dict[str, AgentMemory] = {}

        # Access control
        self._read_only_keys: set[str] = set()
        self._owner_only_keys: dict[str, str] = {}  # key -> owner_agent_id

        # History
        self._history: list[tuple[datetime, str, str, Any]] = []  # (time, agent, key, value)

        self.logger = logging.getLogger(__name__)

    # ========================================================================
    # Shared State Operations
    # ========================================================================

    def set(
        self,
        key: str,
        value: Any,
        agent_id: str,
        metadata: Optional[dict[str, Any]] = None,
        overwrite: bool = True,
    ) -> bool:
        """
        Set a value in shared state.

        A new key gets a fresh entry at version 1; an existing key has its
        value replaced, its version incremented and ``metadata`` merged into
        the metadata already stored. Every successful write is appended to
        the history.

        Args:
            key: State key
            value: Value to set
            agent_id: Agent making the change
            metadata: Optional metadata (copied, never stored by reference)
            overwrite: Whether to overwrite existing value

        Returns:
            True if successful, False if access was denied or the key exists
            and ``overwrite`` is False
        """
        # Check access control
        if not self._check_write_access(key, agent_id):
            self.logger.warning("Agent %s denied write access to %s", agent_id, key)
            return False

        existing = self._shared.get(key)

        # Check if exists
        if existing is not None and not overwrite:
            self.logger.warning("Key %s already exists and overwrite=False", key)
            return False

        # Create or update entry
        now = datetime.now()
        if existing is not None:
            existing.value = value
            existing.modified_at = now
            existing.version += 1
            if metadata:
                existing.metadata.update(metadata)
        else:
            self._shared[key] = StateEntry(
                key=key,
                value=value,
                created_by=agent_id,
                created_at=now,
                modified_at=now,
                # Copy so later edits to the caller's dict don't leak in.
                metadata=dict(metadata) if metadata else {},
            )

        # Record history
        self._history.append((now, agent_id, key, value))

        # Lazy %-args: value is only stringified if DEBUG is enabled.
        self.logger.debug("Agent %s set %s = %s", agent_id, key, value)
        return True

    def get(self, key: str, default: Any = None) -> Any:
        """
        Get a value from shared state.

        Args:
            key: State key
            default: Default value if not found

        Returns:
            The stored value (the same object, not a copy) or ``default``
        """
        entry = self._shared.get(key)
        if entry is None:
            return default
        return entry.value

    def get_entry(self, key: str) -> "Optional[StateEntry]":
        """Get the full state entry (value plus bookkeeping), or None if the key is absent."""
        return self._shared.get(key)

    def has(self, key: str) -> bool:
        """Check if key exists."""
        return key in self._shared

    def delete(self, key: str, agent_id: str) -> bool:
        """
        Delete a key from shared state.

        A successful delete is recorded in history with a value of None.
        Access-control rules registered for the key are left in place.

        Args:
            key: Key to delete
            agent_id: Agent requesting deletion

        Returns:
            True if deleted, False if denied or not found
        """
        if not self._check_write_access(key, agent_id):
            self.logger.warning("Agent %s denied delete access to %s", agent_id, key)
            return False

        if key not in self._shared:
            return False

        del self._shared[key]
        self._history.append((datetime.now(), agent_id, key, None))
        self.logger.debug("Agent %s deleted %s", agent_id, key)
        return True

    def list_keys(self, prefix: Optional[str] = None) -> list[str]:
        """
        List all keys in shared state.

        Args:
            prefix: Optional prefix filter (None or "" means no filtering)

        Returns:
            Sorted list of keys
        """
        if prefix:
            return sorted(k for k in self._shared if k.startswith(prefix))
        return sorted(self._shared)

    def get_all(self) -> dict[str, Any]:
        """Get all shared state as a new ``{key: value}`` dictionary."""
        return {key: entry.value for key, entry in self._shared.items()}

    # ========================================================================
    # Agent Private Memory
    # ========================================================================

    def get_agent_memory(self, agent_id: str) -> "AgentMemory":
        """
        Get private memory for an agent, creating it on first use.

        Args:
            agent_id: Agent ID

        Returns:
            AgentMemory instance (the live object, not a copy)
        """
        if agent_id not in self._agent_memories:
            self._agent_memories[agent_id] = AgentMemory(agent_id=agent_id)
        return self._agent_memories[agent_id]

    def set_agent_memory(
        self,
        agent_id: str,
        key: str,
        value: Any,
    ) -> None:
        """Set a value in agent's private memory."""
        self.get_agent_memory(agent_id).set(key, value)

    def get_agent_memory_value(
        self,
        agent_id: str,
        key: str,
        default: Any = None,
    ) -> Any:
        """Get a value from agent's private memory, or ``default`` if absent."""
        return self.get_agent_memory(agent_id).get(key, default)

    def record_observation(self, agent_id: str, observation: str) -> None:
        """Record an observation in agent's private memory."""
        self.get_agent_memory(agent_id).add_observation(observation)

    def record_action(self, agent_id: str, action: str) -> None:
        """Record an action in agent's private memory."""
        self.get_agent_memory(agent_id).add_action(action)

    # ========================================================================
    # Access Control
    # ========================================================================

    def set_read_only(self, key: str) -> None:
        """Mark a key as read-only (no one can modify or delete it)."""
        self._read_only_keys.add(key)

    def set_owner(self, key: str, agent_id: str) -> None:
        """Set an agent as owner of a key (only they can modify or delete it)."""
        self._owner_only_keys[key] = agent_id

    def _check_write_access(self, key: str, agent_id: str) -> bool:
        """Check if agent can write to key.

        Read-only wins over ownership: a read-only key is denied even to its
        owner. Keys with no rule are writable by everyone.
        """
        # Read-only check
        if key in self._read_only_keys:
            return False

        # Owner-only check
        owner = self._owner_only_keys.get(key)
        if key in self._owner_only_keys:
            return owner == agent_id

        # Default: allow
        return True

    # ========================================================================
    # Advanced Operations
    # ========================================================================

    def increment(self, key: str, agent_id: str, delta: int = 1) -> Optional[int]:
        """
        Increment a numeric counter (read, add, write back).

        A missing key is treated as 0. No lock is held between the read and
        the write.

        Args:
            key: Counter key
            agent_id: Agent making the change
            delta: Amount to increment by

        Returns:
            New value (a float if the stored value was a float), or None if
            the stored value is not numeric or the write was denied
        """
        current = self.get(key, 0)
        if not isinstance(current, (int, float)):
            self.logger.error("Cannot increment non-numeric value at %s", key)
            return None

        new_value = current + delta
        return new_value if self.set(key, new_value, agent_id) else None

    def append(self, key: str, agent_id: str, value: Any) -> bool:
        """
        Append to a list in shared state.

        A missing key is treated as an empty list. The stored list is
        replaced by a new list rather than mutated in place, so references
        previously obtained via ``get`` are unaffected.

        Args:
            key: List key
            agent_id: Agent making the change
            value: Value to append

        Returns:
            True if successful, False if the stored value is not a list or
            the write was denied
        """
        current = self.get(key, [])
        if not isinstance(current, list):
            self.logger.error("Cannot append to non-list at %s", key)
            return False

        return self.set(key, [*current, value], agent_id)

    def merge_dict(
        self,
        key: str,
        agent_id: str,
        updates: dict[str, Any],
    ) -> bool:
        """
        Merge updates into a dictionary in shared state.

        A missing key is treated as an empty dict. The merge is shallow and
        ``updates`` wins on conflicting keys; the stored dict is replaced by
        a new dict rather than mutated in place.

        Args:
            key: Dict key
            agent_id: Agent making the change
            updates: Dictionary of updates to merge

        Returns:
            True if successful, False if the stored value is not a dict or
            the write was denied
        """
        current = self.get(key, {})
        if not isinstance(current, dict):
            self.logger.error("Cannot merge into non-dict at %s", key)
            return False

        return self.set(key, {**current, **updates}, agent_id)

    # ========================================================================
    # History and Versioning
    # ========================================================================

    def get_history(
        self,
        key: Optional[str] = None,
        agent_id: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> list[tuple[datetime, str, str, Any]]:
        """
        Get history of changes, oldest first.

        Args:
            key: Filter by key (None = all keys)
            agent_id: Filter by agent (None = all agents)
            limit: Keep only the most recent ``limit`` matching records
                (None = no limit, 0 or less = no records)

        Returns:
            New list of (timestamp, agent_id, key, value) tuples; mutating it
            does not affect the recorded history
        """
        # Filters use ``is not None`` so that "" is honoured as a real key/agent.
        history = [
            record
            for record in self._history
            if (key is None or record[2] == key)
            and (agent_id is None or record[1] == agent_id)
        ]

        if limit is not None:
            # history[-0:] would be the whole list, so handle <= 0 explicitly.
            history = history[-limit:] if limit > 0 else []

        return history

    def get_version(self, key: str) -> int:
        """Get version number of a key (0 if the key does not exist)."""
        entry = self._shared.get(key)
        return entry.version if entry is not None else 0

    # ========================================================================
    # Serialization
    # ========================================================================

    def to_dict(self) -> dict[str, Any]:
        """Export state as dictionary.

        Contains shared entries, agent memories and access-control rules.
        The change history is not exported. The access-control containers in
        the result are copies, with read-only keys sorted for stable output.
        """
        return {
            "shared_state": {
                key: entry.to_dict()
                for key, entry in self._shared.items()
            },
            "agent_memories": {
                agent_id: memory.to_dict()
                for agent_id, memory in self._agent_memories.items()
            },
            "read_only_keys": sorted(self._read_only_keys),
            "owner_only_keys": dict(self._owner_only_keys),
        }

    def from_dict(self, data: dict[str, Any]) -> None:
        """Import state from a dictionary shaped like ``to_dict`` output.

        Shared entries and agent memories in ``data`` are added to (and
        overwrite same-named items in) the current state; the access-control
        rules are replaced wholesale. History is not touched. Containers are
        copied so this object never aliases the caller's ``data``.

        Raises:
            KeyError: if a shared-state entry lacks a required field.
            ValueError: if a timestamp is not a valid ISO-8601 string.
        """
        # Load shared state
        for key, entry_dict in data.get("shared_state", {}).items():
            self._shared[key] = StateEntry(
                key=entry_dict["key"],
                value=entry_dict["value"],
                created_by=entry_dict["created_by"],
                created_at=datetime.fromisoformat(entry_dict["created_at"]),
                modified_at=datetime.fromisoformat(entry_dict["modified_at"]),
                version=entry_dict["version"],
                metadata=dict(entry_dict.get("metadata") or {}),
            )

        # Load agent memories
        for agent_id, memory_dict in data.get("agent_memories", {}).items():
            memory = AgentMemory(agent_id=agent_id)
            memory.memory = dict(memory_dict.get("memory") or {})
            memory.observation_history = list(memory_dict.get("observation_history") or [])
            memory.action_history = list(memory_dict.get("action_history") or [])
            self._agent_memories[agent_id] = memory

        # Load access control
        self._read_only_keys = set(data.get("read_only_keys") or ())
        self._owner_only_keys = dict(data.get("owner_only_keys") or {})

    def to_json(self) -> str:
        """Export as JSON string.

        Values that ``json`` cannot encode natively are converted with
        ``str()`` (``default=str``), so such values do not round-trip.
        """
        return json.dumps(self.to_dict(), indent=2, default=str)

    @classmethod
    def from_json(cls, json_str: str) -> "SharedState":
        """Create a new instance from a JSON string produced by ``to_json``."""
        state = cls()
        state.from_dict(json.loads(json_str))
        return state

    # ========================================================================
    # Context for Agents
    # ========================================================================

    def get_context_for_agent(
        self,
        agent_id: str,
        include_history: bool = False,
    ) -> dict[str, Any]:
        """
        Get context dictionary for an agent.

        This is what the agent sees when they query the workspace state.

        Args:
            agent_id: Agent ID
            include_history: Whether to include change history (the 50 most
                recent records across all keys and agents)

        Returns:
            Context dictionary with ``shared_state``, ``my_memory`` and,
            when requested, ``history``
        """
        context: dict[str, Any] = {
            "shared_state": self.get_all(),
            "my_memory": self.get_agent_memory(agent_id).to_dict(),
        }

        if include_history:
            context["history"] = [
                {
                    "timestamp": ts.isoformat(),
                    "agent": agent,
                    "key": key,
                    "value": value,
                }
                for ts, agent, key, value in self.get_history(limit=50)
            ]

        return context

    def __repr__(self) -> str:
        """String representation (counts only, no contents)."""
        return (
            f"SharedState(keys={len(self._shared)}, "
            f"agents={len(self._agent_memories)}, "
            f"history_entries={len(self._history)})"
        )

