"""Context management for iterative context retrieval."""

import json
from pathlib import Path
from typing import Dict, List, Set, Optional
from dataclasses import dataclass, asdict
from datetime import datetime
from collections import defaultdict
import ast
import re

from rispec.repository import RepositoryAnalyzer
from rispec.config import Config


@dataclass
class ContextEntry:
    """One file recorded in a task's context, with why and when it was added."""
    # Path of the file, stored as a string (callers pass repo-relative paths).
    path: str
    # Free-text explanation of why the file was added.
    reason: str
    # ISO-8601 timestamp string produced with datetime.now().isoformat().
    added_at: str
    source: str  # 'initial', 'expansion', 'manual', 'llm_suggested'


@dataclass
class ExplorationLog:
    """One step of the context-exploration history."""
    # ISO-8601 timestamp string produced with datetime.now().isoformat().
    timestamp: str
    # Kind of step. Values used in this module: 'file_added', 'file_excluded',
    # 'expansion_triggered', 'llm_query', 'initial_context_identified'.
    action: str  # 'file_added', 'file_excluded', 'expansion_triggered', 'llm_query'
    # Path the step refers to, or None when the step is not about one file.
    file_path: Optional[str]
    # Free-text explanation of the step.
    reason: str
    # Extra structured data; an empty dict when the caller supplies none.
    details: Dict


@dataclass
class TaskSession:
    """State of one task: its description, selected files, exclusions and history."""
    # Identifier of the task; also used to name the saved session file.
    task_id: str
    # Repository root, stored as a string.
    repo_path: str
    # Task description / change request text.
    description: str
    # ISO-8601 timestamp string of when the session was created.
    created_at: str
    # Files currently in context, in the order they were added.
    context_files: List[ContextEntry]
    # Paths that must not be added to the context.
    excluded_files: Set[str]
    # Chronological record of exploration steps.
    exploration_log: List[ExplorationLog]
    current_state: str  # 'initializing', 'expanding', 'ready'


class ContextManager:
    """Manages iterative context retrieval for tasks."""
    
    def __init__(self, analyzer: RepositoryAnalyzer, task_id: str, description: str):
        """Initialize context manager for a task.
        
        Args:
            analyzer: RepositoryAnalyzer instance (must be indexed)
            task_id: Unique identifier for the task
            description: Task description/change request
        """
        if not analyzer.is_indexed():
            analyzer.index()
        
        self.analyzer = analyzer
        self.task_id = task_id
        self.description = description
        
        self.session = TaskSession(
            task_id=task_id,
            repo_path=str(analyzer.repo_path),
            description=description,
            created_at=datetime.now().isoformat(),
            context_files=[],
            excluded_files=set(),
            exploration_log=[],
            current_state='initializing'
        )
        
        # Initial context identification
        self._identify_initial_context()
    
    def _identify_initial_context(self):
        """Identify initial set of relevant files based on symbols, paths, and keywords."""
        keywords = self._extract_keywords(self.description)
        symbols = self._extract_symbols(self.description)
        
        # Find files matching keywords in paths
        for file_path in self.analyzer._source_files:
            relative_path = str(file_path.relative_to(self.analyzer.repo_path))
            
            # Check if path contains keywords
            if any(kw.lower() in relative_path.lower() for kw in keywords):
                self._add_to_context(relative_path, f"Path matches keyword: {keywords}", 'initial')
                continue
            
            # Check if file contains symbols
            if symbols:
                file_symbols = self._extract_file_symbols(file_path)
                if any(sym in file_symbols for sym in symbols):
                    self._add_to_context(relative_path, f"Contains symbol: {symbols}", 'initial')
        
        self.session.current_state = 'ready'
        self._log_exploration('initial_context_identified', None, f"Identified {len(self.session.context_files)} initial files")
    
    def _extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from text."""
        # Simple keyword extraction (can be enhanced)
        words = re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)
        # Filter out common words
        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should', 'could', 'may', 'might', 'must', 'can', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'what', 'which', 'who', 'when', 'where', 'why', 'how'}
        keywords = [w.lower() for w in words if w.lower() not in stop_words and len(w) > 2]
        return list(set(keywords))
    
    def _extract_symbols(self, text: str) -> List[str]:
        """Extract symbol names (function, class, module names) from text."""
        # Look for patterns like "function_name", "ClassName", "module.function"
        symbols = re.findall(r'\b[A-Z][a-zA-Z0-9_]*\b|\b[a-z_][a-zA-Z0-9_]*\b', text)
        # Filter to likely symbol names
        filtered = [s for s in symbols if not s.lower() in ['add', 'remove', 'update', 'create', 'delete', 'get', 'set', 'make', 'use']]
        return list(set(filtered))
    
    def _extract_file_symbols(self, file_path: Path) -> Set[str]:
        """Extract symbols (functions, classes) from a file."""
        symbols = set()
        
        if file_path.suffix != '.py':
            return symbols
        
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
                tree = ast.parse(content, filename=str(file_path))
                
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        symbols.add(node.name)
                    elif isinstance(node, ast.ClassDef):
                        symbols.add(node.name)
        except (SyntaxError, UnicodeDecodeError):
            pass
        
        return symbols
    
    def _add_to_context(self, file_path: str, reason: str, source: str):
        """Add a file to context."""
        if file_path in self.session.excluded_files:
            return
        
        # Check if already in context
        if any(entry.path == file_path for entry in self.session.context_files):
            return
        
        entry = ContextEntry(
            path=file_path,
            reason=reason,
            added_at=datetime.now().isoformat(),
            source=source
        )
        self.session.context_files.append(entry)
        self._log_exploration('file_added', file_path, reason)
    
    def _log_exploration(self, action: str, file_path: Optional[str], reason: str, details: Optional[Dict] = None):
        """Log an exploration step."""
        log_entry = ExplorationLog(
            timestamp=datetime.now().isoformat(),
            action=action,
            file_path=file_path,
            reason=reason,
            details=details or {}
        )
        self.session.exploration_log.append(log_entry)
    
    def expand_context(self, symbols: Optional[List[str]] = None, llm_suggestion: Optional[str] = None):
        """Expand context based on new symbols or LLM suggestion.
        
        Args:
            symbols: List of symbols that need to be resolved
            llm_suggestion: LLM suggestion for files to include
        """
        self.session.current_state = 'expanding'
        
        if symbols:
            self._expand_for_symbols(symbols)
        
        if llm_suggestion:
            self._expand_from_llm_suggestion(llm_suggestion)
        
        self.session.current_state = 'ready'
    
    def _expand_for_symbols(self, symbols: List[str]):
        """Expand context to include files that define or use these symbols."""
        for symbol in symbols:
            # Find files that define or use this symbol by searching source files
            for file_path in self.analyzer._source_files:
                relative_path = str(file_path.relative_to(self.analyzer.repo_path))
                
                # Skip if already in context or excluded
                if relative_path in self.session.excluded_files:
                    continue
                if any(entry.path == relative_path for entry in self.session.context_files):
                    continue
                
                # Check if file contains the symbol
                file_symbols = self._extract_file_symbols(file_path)
                if symbol in file_symbols:
                    self._add_to_context(relative_path, f"Defines symbol: {symbol}", 'expansion')
                    continue
                
                # Check if file imports/uses the symbol
                module_name = self._get_module_name_from_path(file_path)
                if module_name in self.analyzer._dependency_graph:
                    imports = self.analyzer._dependency_graph[module_name]
                    if symbol in imports or any(symbol in imp for imp in imports):
                        self._add_to_context(relative_path, f"Uses symbol: {symbol}", 'expansion')
        
        self._log_exploration('expansion_triggered', None, f"Expanded for symbols: {symbols}")
    
    def _get_module_name_from_path(self, file_path: Path) -> str:
        """Get module name from file path."""
        relative_path = file_path.relative_to(self.analyzer.repo_path)
        parts = relative_path.parts[:-1]  # Remove filename
        module_name = ".".join(parts) if parts else ""
        file_stem = file_path.stem
        if file_stem != "__init__":
            if module_name:
                module_name = f"{module_name}.{file_stem}"
            else:
                module_name = file_stem
        return module_name
    
    def _expand_from_llm_suggestion(self, suggestion: str):
        """Expand context based on LLM suggestion."""
        # Parse LLM suggestion for file paths
        # Simple pattern: look for file paths in the suggestion
        file_pattern = r'[\w/\\]+\.py'
        suggested_files = re.findall(file_pattern, suggestion)
        
        for file_path in suggested_files:
            # Normalize path
            normalized = file_path.replace('\\', '/')
            # Check if file exists in repo
            full_path = self.analyzer.repo_path / normalized
            if full_path.exists():
                self._add_to_context(normalized, f"LLM suggested: {suggestion[:50]}", 'llm_suggested')
        
        self._log_exploration('llm_query', None, suggestion)
    
    def detect_missing_information(self, code_context: str = "") -> Optional[str]:
        """Use LLM to detect missing information or unresolved symbols.
        
        Args:
            code_context: Optional code context to analyze
            
        Returns:
            LLM suggestion for files to include, or None if no missing info detected
        """
        if not Config.OPENAI_API_KEY:
            return None
        
        try:
            from openai import OpenAI
            
            client = OpenAI(api_key=Config.OPENAI_API_KEY)
            
            # Build prompt
            context_files_summary = "\n".join([f"- {entry.path}" for entry in self.session.context_files[:20]])
            prompt = f"""You are analyzing a code change request. Given the task description and current context files, identify any missing information or unresolved symbols.

Task Description: {self.description}

Current Context Files ({len(self.session.context_files)}):
{context_files_summary}

{code_context if code_context else ""}

Analyze if there are any:
1. Unresolved symbols (functions, classes, modules) mentioned in the task but not in context
2. Related files that should be included (dependencies, config files, tests)
3. Missing context needed to understand the change

If missing information is detected, suggest specific file paths (relative to repo root) that should be added to the context. Format your response as a list of file paths, one per line. If no missing information, respond with "None"."""
            
            response = client.chat.completions.create(
                model=Config.OPENAI_MODEL,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a code analysis assistant. Identify missing context files needed for a code change."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                max_tokens=300,
                temperature=0.3
            )
            
            suggestion = response.choices[0].message.content.strip()
            if suggestion.lower() == "none" or not suggestion:
                return None
            
            return suggestion
        except Exception as e:
            # Log error but don't fail
            self._log_exploration('llm_query', None, f"LLM detection failed: {str(e)}")
            return None
    
    def force_include(self, file_path: str, reason: str = "Manually included"):
        """Force include a file/directory in context.
        
        Args:
            file_path: File path or directory path (relative to repo root)
            reason: Reason for inclusion
        """
        if Path(file_path).is_dir() or file_path.endswith('/'):
            # Include all files in directory
            for source_file in self.analyzer._source_files:
                relative_path = str(source_file.relative_to(self.analyzer.repo_path))
                if relative_path.startswith(file_path):
                    self._add_to_context(relative_path, reason, 'manual')
        else:
            self._add_to_context(file_path, reason, 'manual')
        
        self._log_exploration('file_added', file_path, f"Force included: {reason}")
    
    def exclude(self, file_path: str, reason: str = "Manually excluded"):
        """Exclude a file/directory from context.
        
        Args:
            file_path: File path or directory path (relative to repo root)
            reason: Reason for exclusion
        """
        # Normalize path (remove trailing slash for comparison)
        normalized_path = file_path.rstrip('/')
        is_directory = file_path.endswith('/')
        
        # Check if it's a directory by checking if any source files or context files start with this path
        if is_directory or any(
            str(source_file.relative_to(self.analyzer.repo_path)).startswith(normalized_path + '/')
            for source_file in self.analyzer._source_files
        ) or any(
            entry.path.startswith(normalized_path + '/')
            for entry in self.session.context_files
        ):
            # Exclude all files in directory (from both source files and context)
            for source_file in self.analyzer._source_files:
                relative_path = str(source_file.relative_to(self.analyzer.repo_path))
                if relative_path.startswith(normalized_path + '/') or relative_path == normalized_path:
                    self.session.excluded_files.add(relative_path)
            
            # Also exclude context files that match
            for entry in list(self.session.context_files):
                if entry.path.startswith(normalized_path + '/') or entry.path == normalized_path:
                    self.session.excluded_files.add(entry.path)
                    self.session.context_files.remove(entry)
        else:
            self.session.excluded_files.add(file_path)
            # Remove from context if present
            self.session.context_files = [
                e for e in self.session.context_files if e.path != file_path
            ]
        
        self._log_exploration('file_excluded', file_path, reason)
    
    def get_context_files(self) -> List[str]:
        """Get list of files currently in context."""
        return [entry.path for entry in self.session.context_files]
    
    def get_exploration_log(self) -> List[Dict]:
        """Get exploration log as dictionaries."""
        return [asdict(log) for log in self.session.exploration_log]
    
    def save_session(self):
        """Save session to JSON file."""
        Config.ensure_data_dir()
        session_file = Config.DATA_DIR / f"task_{self.task_id}.json"
        
        # Convert session to dict
        session_dict = {
            "task_id": self.session.task_id,
            "repo_path": self.session.repo_path,
            "description": self.session.description,
            "created_at": self.session.created_at,
            "context_files": [asdict(entry) for entry in self.session.context_files],
            "excluded_files": list(self.session.excluded_files),
            "exploration_log": [asdict(log) for log in self.session.exploration_log],
            "current_state": self.session.current_state
        }
        
        with open(session_file, 'w', encoding='utf-8') as f:
            json.dump(session_dict, f, indent=2)
    
    @classmethod
    def load_session(cls, analyzer: RepositoryAnalyzer, task_id: str) -> 'ContextManager':
        """Load session from JSON file.
        
        Args:
            analyzer: RepositoryAnalyzer instance
            task_id: Task ID to load
            
        Returns:
            ContextManager instance
        """
        session_file = Config.DATA_DIR / f"task_{task_id}.json"
        
        if not session_file.exists():
            raise FileNotFoundError(f"Session {task_id} not found")
        
        with open(session_file, 'r', encoding='utf-8') as f:
            session_dict = json.load(f)
        
        manager = cls.__new__(cls)
        manager.analyzer = analyzer
        manager.task_id = task_id
        manager.description = session_dict['description']
        
        # Reconstruct session
        manager.session = TaskSession(
            task_id=session_dict['task_id'],
            repo_path=session_dict['repo_path'],
            description=session_dict['description'],
            created_at=session_dict['created_at'],
            context_files=[ContextEntry(**entry) for entry in session_dict['context_files']],
            excluded_files=set(session_dict['excluded_files']),
            exploration_log=[ExplorationLog(**log) for log in session_dict['exploration_log']],
            current_state=session_dict['current_state']
        )
        
        return manager

