"""Repository structure mapping and analysis."""

import json
import time
from pathlib import Path
from typing import Dict, List, Set, Optional
from collections import defaultdict
import ast
import re
from rispec.config import Config


class RepositoryAnalyzer:
    """Analyzes repository structure, dependencies, and call graphs."""
    
    # Supported file extensions
    PYTHON_EXTENSIONS = {'.py'}
    # Can be extended for other languages
    SUPPORTED_EXTENSIONS = PYTHON_EXTENSIONS
    
    def __init__(self, repo_path: Path):
        """Create an analyzer bound to a repository root.

        The given path is resolved to an absolute path and must exist.
        No scanning happens here; all result containers start out empty.

        Args:
            repo_path: Path to the repository root

        Raises:
            ValueError: If the resolved path does not exist
        """
        root = Path(repo_path).resolve()
        self.repo_path = root
        if not root.exists():
            raise ValueError(f"Repository path does not exist: {root}")

        # Indexing status and statistics (unset until index() has run).
        self._indexed = False
        self._timeout_warning: bool = False
        self._indexing_time_seconds: Optional[float] = None
        self._total_loc: Optional[int] = None

        # Results populated by index().
        self._source_files: List[Path] = []
        self._module_info: Dict[str, Dict] = {}
        self._dependency_graph: Dict[str, Set[str]] = defaultdict(set)
        self._call_graph: Dict[str, Set[str]] = defaultdict(set)
    
    def index(self) -> Dict:
        """Index the repository structure.

        Runs the LOC pre-check, discovers the source files and rebuilds the
        dependency and call graphs. Results of any previous run are discarded
        first, so calling this again reflects only what is on disk now.

        Returns:
            Dictionary with the keys ``total_files``, ``total_modules``,
            ``indexed``, ``indexing_time_seconds`` (rounded to two decimals),
            ``total_loc`` and ``timeout_warning``
        """
        # Drop state from an earlier run; otherwise modules whose files have
        # since been removed or renamed would linger in the graphs.
        self._dependency_graph.clear()
        self._call_graph.clear()
        self._module_info.clear()

        # Pre-check: estimate LOC before indexing (not part of the timed section)
        self._check_loc_before_indexing()

        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        self._source_files = self._find_source_files()
        self._build_dependency_graph()
        self._build_call_graph()

        self._indexing_time_seconds = time.perf_counter() - start_time

        # Recomputed on every run so that one slow run does not leave the
        # warning flag set forever.
        self._timeout_warning = (
            self._indexing_time_seconds > Config.INDEXING_TIMEOUT_SECONDS
        )

        self._indexed = True

        return {
            "total_files": len(self._source_files),
            "total_modules": len(self._module_info),
            "indexed": True,
            "indexing_time_seconds": round(self._indexing_time_seconds, 2),
            "total_loc": self._total_loc,
            "timeout_warning": self._timeout_warning
        }
    
    def _check_loc_before_indexing(self):
        """Estimate the repository's Python LOC and store it in ``_total_loc``.

        Counts the non-blank lines of every ``*.py`` file below the repository
        root, skipping files that live inside one of the ignored directories.
        Files that cannot be read or decoded contribute zero lines.

        This method only records the estimate. The comparison against
        ``Config.MAX_REPO_LOC`` is exposed through ``get_indexing_stats()``,
        and any warning is left to the caller (e.g. the CLI).
        """
        ignore_dirs = {'.git', '__pycache__', '.pytest_cache', 'venv', 'env',
                       'node_modules', '.venv', 'build', 'dist', '.eggs'}

        total_loc = 0
        for file_path in self.repo_path.rglob("*.py"):
            # Only components *below* the repo root may exclude a file.
            # Checking the absolute path would drop every file when the
            # repository itself sits under e.g. ".../build/" or ".../env/".
            relative_parts = file_path.relative_to(self.repo_path).parts
            if not ignore_dirs.isdisjoint(relative_parts):
                continue
            # Same non-blank-line count (and error handling) used for the
            # per-module statistics.
            total_loc += self._count_lines(file_path)

        self._total_loc = total_loc
    
    def _find_source_files(self) -> List[Path]:
        """Find all source files in the repository.

        Searches recursively below the repository root for files whose name
        ends in one of ``SUPPORTED_EXTENSIONS``, skipping anything inside one
        of the commonly ignored directories (VCS data, caches, virtual
        environments, build output).

        Returns:
            Sorted list of source file paths
        """
        # Common directories to ignore
        ignore_dirs = {'.git', '__pycache__', '.pytest_cache', 'venv', 'env',
                       'node_modules', '.venv', 'build', 'dist', '.eggs'}

        source_files = []
        for ext in self.SUPPORTED_EXTENSIONS:
            for file_path in self.repo_path.rglob(f"*{ext}"):
                # Only components *below* the repo root may exclude a file.
                # Checking the absolute path would drop every file when the
                # repository itself sits under e.g. ".../build/" or ".../env/".
                relative_parts = file_path.relative_to(self.repo_path).parts
                if not ignore_dirs.isdisjoint(relative_parts):
                    continue
                # rglob matches on the name alone, so a directory (or broken
                # symlink) called "something.py" would otherwise be returned.
                if not file_path.is_file():
                    continue
                source_files.append(file_path)

        return sorted(source_files)
    
    def _build_dependency_graph(self):
        """Build dependency graph from imports.

        For every Python file in ``_source_files`` this records the set of
        imported top-level names in ``_dependency_graph`` and a metadata entry
        in ``_module_info`` holding the repo-relative path, the imports as a
        sorted list, and the non-blank line count.
        """
        for file_path in self._source_files:
            if file_path.suffix not in self.PYTHON_EXTENSIONS:
                continue
            module_name = self._get_module_name(file_path)
            imports = self._extract_imports(file_path)
            self._dependency_graph[module_name] = imports
            self._module_info[module_name] = {
                "path": str(file_path.relative_to(self.repo_path)),
                # Sorted so the metadata is stable between runs; iterating
                # a set of strings has no guaranteed order.
                "imports": sorted(imports),
                "line_count": self._count_lines(file_path)
            }
    
    def _build_call_graph(self):
        """Populate ``_call_graph`` with the names called in each Python file.

        Every ``.py`` entry of ``_source_files`` is mapped, under its module
        name, to the set of call names extracted from that file.
        """
        python_sources = (
            candidate for candidate in self._source_files
            if candidate.suffix == '.py'
        )
        for source in python_sources:
            key = self._get_module_name(source)
            self._call_graph[key] = self._extract_function_calls(source)
    
    def _get_module_name(self, file_path: Path) -> str:
        """Derive a dotted module name from a file path under the repo root.

        Directory components relative to the repository root are joined with
        dots and the file stem is appended, except for ``__init__`` files,
        which are named after their containing package only (an ``__init__``
        file directly in the root yields an empty string).
        """
        relative = file_path.relative_to(self.repo_path)
        components = list(relative.parts[:-1])  # directories only
        stem = file_path.stem
        if stem != "__init__":
            components.append(stem)
        return ".".join(components)
    
    def _extract_imports(self, file_path: Path) -> Set[str]:
        """Extract import statements from Python file.

        Collects the top-level package name (the part before the first dot)
        of every ``import x`` and ``from x import y`` statement found anywhere
        in the file. ``from . import y`` statements, which name no module,
        are skipped.

        Extraction is best-effort: a file that cannot be read, decoded or
        parsed yields an empty set instead of aborting the whole indexing run.

        Args:
            file_path: Path of the Python file to inspect

        Returns:
            Set of top-level imported module names
        """
        imports: Set[str] = set()
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            tree = ast.parse(content, filename=str(file_path))
        except (OSError, SyntaxError, UnicodeDecodeError, ValueError, RecursionError):
            # OSError: unreadable file, or a directory whose name ends in ".py".
            # ValueError: source containing NUL bytes on older Python versions.
            # RecursionError: pathologically nested source.
            return imports

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    imports.add(alias.name.split('.')[0])
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    imports.add(node.module.split('.')[0])
        return imports
    
    def _extract_function_calls(self, file_path: Path) -> Set[str]:
        """Extract function calls from Python file.

        Records the called name for every call expression in the file: the
        identifier for plain calls (``foo()``) and the final attribute for
        attribute calls (``obj.a.foo()`` is recorded as ``foo``). Calls whose
        target is any other expression (e.g. a lambda or subscript) are not
        recorded.

        Extraction is best-effort: a file that cannot be read, decoded or
        parsed yields an empty set instead of aborting the whole indexing run.

        Args:
            file_path: Path of the Python file to inspect

        Returns:
            Set of called function/method names
        """
        calls: Set[str] = set()
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            tree = ast.parse(content, filename=str(file_path))
        except (OSError, SyntaxError, UnicodeDecodeError, ValueError, RecursionError):
            # OSError: unreadable file, or a directory whose name ends in ".py".
            # ValueError: source containing NUL bytes on older Python versions.
            # RecursionError: pathologically nested source.
            return calls

        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            if isinstance(node.func, ast.Name):
                calls.add(node.func.id)
            elif isinstance(node.func, ast.Attribute):
                calls.add(node.func.attr)
        return calls
    
    def _count_lines(self, file_path: Path) -> int:
        """Return the number of non-blank lines in a file.

        A line counts when it contains anything other than whitespace.
        Files that cannot be opened or decoded as UTF-8 count as 0.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as handle:
                return sum(1 for text in handle if text.strip())
        except (UnicodeDecodeError, IOError):
            return 0
    
    def get_summary(self, top_n: int = 10) -> Dict:
        """Get summary of repository structure.

        Indexes the repository first if that has not happened yet.

        Args:
            top_n: Number of top modules to include

        Returns:
            Dictionary with ``total_files``, ``total_modules``,
            ``top_referenced_modules`` (a list of ``{"module", "references"}``
            entries, most referenced first, ties ordered by module name) and
            ``key_modules`` (the first ``top_n`` indexed module names)
        """
        if not self._indexed:
            self.index()

        # Count how many modules import each name.
        reference_count = defaultdict(int)
        for imports in self._dependency_graph.values():
            for imp in imports:
                reference_count[imp] += 1

        # Break ties by name: the counts are accumulated from sets, whose
        # iteration order differs between runs, so sorting by count alone
        # would make the ranking (and the top_n cut-off) nondeterministic.
        top_referenced = sorted(
            reference_count.items(),
            key=lambda item: (-item[1], item[0])
        )[:top_n]

        return {
            "total_files": len(self._source_files),
            "total_modules": len(self._module_info),
            "top_referenced_modules": [
                {"module": mod, "references": count}
                for mod, count in top_referenced
            ],
            "key_modules": list(self._module_info)[:top_n]
        }
    
    def get_summary_text(self, top_n: int = 10) -> str:
        """Render the repository summary as plain text.

        The report has a title and totals, followed by one bullet line per
        top-referenced module and one bullet line per key module.

        Args:
            top_n: Number of top modules to include

        Returns:
            Newline-joined report text
        """
        report = self.get_summary(top_n)

        header = [
            "Repository Analysis Summary",
            "=" * 50,
            f"Total source files: {report['total_files']}",
            f"Total modules: {report['total_modules']}",
            "",
            "Top Referenced Modules:",
        ]
        referenced = [
            f"  - {entry['module']}: {entry['references']} references"
            for entry in report['top_referenced_modules']
        ]
        key_modules = [f"  - {name}" for name in report['key_modules']]

        return "\n".join([*header, *referenced, "", "Key Modules:", *key_modules])
    
    def get_summary_json(self, top_n: int = 10) -> str:
        """Serialize the repository summary to JSON.

        Args:
            top_n: Number of top modules to include

        Returns:
            The summary dictionary encoded as JSON with two-space indentation
        """
        return json.dumps(self.get_summary(top_n), indent=2)
    
    def get_dependency_graph(self) -> Dict[str, List[str]]:
        """Get dependency graph as dictionary.

        Indexes the repository first if that has not happened yet.

        Returns:
            Dictionary mapping modules to their dependencies, each given as
            an alphabetically sorted list
        """
        if not self._indexed:
            self.index()

        # The graph stores sets; sorting gives callers (and any serialized
        # output) a stable order instead of one that changes between runs.
        return {
            module: sorted(deps)
            for module, deps in self._dependency_graph.items()
        }
    
    def is_indexed(self) -> bool:
        """Return the flag that index() sets once a run has completed."""
        return self._indexed
    
    def get_indexing_stats(self) -> Dict:
        """Report the timing and size figures recorded for the repository.

        Returns:
            Dictionary with the stored indexing duration, the stored LOC
            total, the timeout flag, and whether the LOC total is above
            ``Config.MAX_REPO_LOC``
        """
        loc = self._total_loc
        # A missing (None) or zero LOC total is never reported as over the
        # threshold; the threshold is only consulted for a non-zero total.
        if loc:
            over_threshold = loc > Config.MAX_REPO_LOC
        else:
            over_threshold = False

        return {
            "indexing_time_seconds": self._indexing_time_seconds,
            "total_loc": loc,
            "timeout_warning": self._timeout_warning,
            "exceeds_loc_threshold": over_threshold
        }

