"""
Pure tree-sitter AST-based code analysis for maximum precision.
NO REGEX FALLBACKS - Tree-sitter only for all supported languages.

This module provides accurate code analysis using tree-sitter parsers
for multiple languages, extracting imports, function calls, class definitions,
and other references from source code.
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# The analyzer has no regex fallback, so a missing parser package is fatal:
# fail at import time with installation instructions instead of limping along.
try:
    from tree_sitter_language_pack import get_parser
except ImportError as exc:
    TREE_SITTER_AVAILABLE = False
    # Chain explicitly so the traceback shows the original import failure as
    # the direct cause of this more descriptive error.
    raise ImportError(
        "tree-sitter-language-pack is required for code analysis. "
        "Install it with: pip install tree-sitter-language-pack"
    ) from exc
else:
    TREE_SITTER_AVAILABLE = True


@dataclass
class CodeReference:
    """A single named reference extracted from a source file.

    Instances are produced by the language-specific parsers of the analyzer
    and later de-duplicated on the ``(type, name)`` pair.
    """
    # Kind of reference: 'import', 'function_call', 'class_def' or 'function_def'.
    type: str  # 'import', 'function_call', 'class_def', 'function_def'
    # Identifier the reference points at (module, class or function name).
    name: str
    # Relative importance used for ranking; larger values sort first.
    priority: float  # 0.0 to 1.0, higher = more important
    # Optional human-readable snippet such as "import os"; None when not provided.
    context: Optional[str] = None


class FastCodeAnalyzer:
    """
    Pure tree-sitter AST-based code analyzer.
    Only supports languages with tree-sitter parsers.
    """

    def __init__(self):
        if not TREE_SITTER_AVAILABLE:
            raise RuntimeError("tree-sitter-languages is not available")
            
        self.cache: Dict[str, List[CodeReference]] = {}
        self._max_cache_size = 1000
        self.parsers: Dict[str, any] = {}
        self._init_parsers()

    def _init_parsers(self):
        """Initialize tree-sitter parsers for supported languages."""
        try:
            self.parsers['python'] = get_parser('python')
            self.parsers['javascript'] = get_parser('javascript')
            self.parsers['typescript'] = get_parser('typescript')
            self.parsers['java'] = get_parser('java')
            self.parsers['go'] = get_parser('go')
            self.parsers['rust'] = get_parser('rust')
        except Exception as e:
            raise RuntimeError(f"Failed to initialize tree-sitter parsers: {e}")

    def extract_references_fast(self, file_path: str) -> List[CodeReference]:
        """
        Extract all useful references from a file using pure tree-sitter AST parsing.
        Returns empty array for unsupported file types.

        Args:
            file_path: Path to the source file

        Returns:
            List of CodeReference objects found in the file
        """
        # Check cache first
        if file_path in self.cache:
            return self.cache[file_path]

        ext = os.path.splitext(file_path)[1].lower()

        try:
            if ext == '.py':
                refs = self._parse_python_ast(file_path)
            elif ext in ['.js', '.jsx', '.mjs', '.cjs']:
                refs = self._parse_javascript_ast(file_path)
            elif ext in ['.ts', '.tsx']:
                refs = self._parse_typescript_ast(file_path)
            elif ext == '.java':
                refs = self._parse_java_ast(file_path)
            elif ext == '.go':
                refs = self._parse_go_ast(file_path)
            elif ext == '.rs':
                refs = self._parse_rust_ast(file_path)
            else:
                # Unsupported file type - return empty array
                return []

            self._cache_references(file_path, refs)
            return refs

        except Exception as e:
            print(f"Error parsing {file_path}: {e}")
            return []

    def _parse_python_ast(self, file_path: str) -> List[CodeReference]:
        """Parse Python file using tree-sitter."""
        refs = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            tree = self.parsers['python'].parse(bytes(content, 'utf-8'))
            root_node = tree.root_node

            def walk(node):
                # Import statements
                if node.type == 'import_statement':
                    for child in node.children:
                        if child.type == 'dotted_name':
                            refs.append(CodeReference(
                                type='import',
                                name=child.text.decode('utf-8'),
                                priority=0.8,
                                context=f"import {child.text.decode('utf-8')}"
                            ))

                elif node.type == 'import_from_statement':
                    module_name = None
                    for child in node.children:
                        if child.type == 'dotted_name' and not module_name:
                            module_name = child.text.decode('utf-8')
                            refs.append(CodeReference(
                                type='import',
                                name=module_name,
                                priority=0.8,
                                context=f"from {module_name}"
                            ))
                        elif child.type == 'dotted_name' and module_name:
                            refs.append(CodeReference(
                                type='import',
                                name=child.text.decode('utf-8'),
                                priority=0.7,
                                context=f"from {module_name} import {child.text.decode('utf-8')}"
                            ))
                        elif child.type == 'identifier' and module_name:
                            refs.append(CodeReference(
                                type='import',
                                name=child.text.decode('utf-8'),
                                priority=0.7,
                                context=f"from {module_name} import {child.text.decode('utf-8')}"
                            ))

                # Class definitions
                elif node.type == 'class_definition':
                    for child in node.children:
                        if child.type == 'identifier':
                            refs.append(CodeReference(
                                type='class_def',
                                name=child.text.decode('utf-8'),
                                priority=0.9,
                                context=f"class {child.text.decode('utf-8')}"
                            ))
                            break

                # Function definitions
                elif node.type == 'function_definition':
                    for child in node.children:
                        if child.type == 'identifier':
                            func_name = child.text.decode('utf-8')
                            if not func_name.startswith('_'):
                                refs.append(CodeReference(
                                    type='function_def',
                                    name=func_name,
                                    priority=0.85,
                                    context=f"def {func_name}"
                                ))
                            break

                # Function calls
                elif node.type == 'call':
                    func_node = node.child_by_field_name('function')
                    if func_node:
                        if func_node.type == 'identifier':
                            func_name = func_node.text.decode('utf-8')
                            if not func_name.startswith('_'):
                                refs.append(CodeReference(
                                    type='function_call',
                                    name=func_name,
                                    priority=0.6
                                ))
                        elif func_node.type == 'attribute':
                            attr_node = func_node.child_by_field_name('attribute')
                            if attr_node:
                                refs.append(CodeReference(
                                    type='function_call',
                                    name=attr_node.text.decode('utf-8'),
                                    priority=0.6
                                ))

                # Recurse
                for child in node.children:
                    walk(child)

            walk(root_node)

        except Exception as e:
            print(f"Error in _parse_python_ast: {e}")

        return self._deduplicate_references(refs)

    def _parse_javascript_ast(self, file_path: str) -> List[CodeReference]:
        """Parse JavaScript file using tree-sitter."""
        refs = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            tree = self.parsers['javascript'].parse(bytes(content, 'utf-8'))
            refs = self._extract_js_refs_treesitter(tree.root_node)

        except Exception as e:
            print(f"Error in _parse_javascript_ast: {e}")

        return self._deduplicate_references(refs)

    def _parse_typescript_ast(self, file_path: str) -> List[CodeReference]:
        """Parse TypeScript file using tree-sitter."""
        refs = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            tree = self.parsers['typescript'].parse(bytes(content, 'utf-8'))
            refs = self._extract_js_refs_treesitter(tree.root_node)

        except Exception as e:
            print(f"Error in _parse_typescript_ast: {e}")

        return self._deduplicate_references(refs)

    def _extract_js_refs_treesitter(self, root_node) -> List[CodeReference]:
        """Extract references from JavaScript/TypeScript using tree-sitter."""
        refs = []

        def walk(node):
            # Import statements
            if node.type == 'import_statement':
                source_node = node.child_by_field_name('source')
                if source_node:
                    import_path = source_node.text.decode('utf-8').strip('"\'')
                    module_name = import_path.split('/')[-1].replace('.js', '').replace('.ts', '')
                    refs.append(CodeReference(
                        type='import',
                        name=module_name,
                        priority=0.8
                    ))

            # Require calls
            elif node.type == 'call_expression':
                func_node = node.child_by_field_name('function')
                if func_node and func_node.text.decode('utf-8') == 'require':
                    args_node = node.child_by_field_name('arguments')
                    if args_node:
                        for child in args_node.children:
                            if child.type == 'string':
                                require_path = child.text.decode('utf-8').strip('"\'')
                                module_name = require_path.split('/')[-1].replace('.js', '')
                                refs.append(CodeReference(
                                    type='import',
                                    name=module_name,
                                    priority=0.7
                                ))
                                break

            # Class declarations
            if node.type == 'class_declaration':
                name_node = node.child_by_field_name('name')
                if name_node:
                    refs.append(CodeReference(
                        type='class_def',
                        name=name_node.text.decode('utf-8'),
                        priority=0.9
                    ))

            # Function declarations
            elif node.type in ['function_declaration', 'method_definition']:
                name_node = node.child_by_field_name('name')
                if name_node:
                    func_name = name_node.text.decode('utf-8')
                    if not func_name.startswith('_'):
                        refs.append(CodeReference(
                            type='function_def',
                            name=func_name,
                            priority=0.85
                        ))

            # Arrow functions and const declarations
            elif node.type in ['lexical_declaration', 'variable_declaration']:
                for child in node.children:
                    if child.type == 'variable_declarator':
                        name_node = child.child_by_field_name('name')
                        value_node = child.child_by_field_name('value')
                        if name_node and value_node and value_node.type in ['arrow_function', 'function']:
                            refs.append(CodeReference(
                                type='function_def',
                                name=name_node.text.decode('utf-8'),
                                priority=0.85
                            ))

            # Function calls
            elif node.type == 'call_expression':
                func_node = node.child_by_field_name('function')
                if func_node:
                    func_name = ''
                    if func_node.type == 'identifier':
                        func_name = func_node.text.decode('utf-8')
                    elif func_node.type == 'member_expression':
                        prop_node = func_node.child_by_field_name('property')
                        if prop_node:
                            func_name = prop_node.text.decode('utf-8')
                    
                    if func_name and func_name not in ['if', 'for', 'while', 'switch', 'catch', 'require']:
                        refs.append(CodeReference(
                            type='function_call',
                            name=func_name,
                            priority=0.6
                        ))

            # Recurse
            for child in node.children:
                walk(child)

        walk(root_node)
        return refs

    def _parse_java_ast(self, file_path: str) -> List[CodeReference]:
        """Parse Java file using tree-sitter."""
        refs = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            tree = self.parsers['java'].parse(bytes(content, 'utf-8'))
            root_node = tree.root_node

            def walk(node):
                # Import declarations
                if node.type == 'import_declaration':
                    for child in node.children:
                        if child.type in ['scoped_identifier', 'identifier']:
                            full_path = child.text.decode('utf-8')
                            class_name = full_path.split('.')[-1]
                            refs.append(CodeReference(
                                type='import',
                                name=class_name,
                                priority=0.8
                            ))
                            break

                # Class declarations
                elif node.type == 'class_declaration':
                    name_node = node.child_by_field_name('name')
                    if name_node:
                        refs.append(CodeReference(
                            type='class_def',
                            name=name_node.text.decode('utf-8'),
                            priority=0.9
                        ))

                # Method declarations
                elif node.type == 'method_declaration':
                    name_node = node.child_by_field_name('name')
                    if name_node:
                        refs.append(CodeReference(
                            type='function_def',
                            name=name_node.text.decode('utf-8'),
                            priority=0.85
                        ))

                # Recurse
                for child in node.children:
                    walk(child)

            walk(root_node)

        except Exception as e:
            print(f"Error in _parse_java_ast: {e}")

        return self._deduplicate_references(refs)

    def _parse_go_ast(self, file_path: str) -> List[CodeReference]:
        """Parse Go file using tree-sitter."""
        refs = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            tree = self.parsers['go'].parse(bytes(content, 'utf-8'))
            root_node = tree.root_node

            def walk(node):
                # Import specs
                if node.type == 'import_spec':
                    path_node = node.child_by_field_name('path')
                    if path_node:
                        import_path = path_node.text.decode('utf-8').strip('"')
                        package_name = import_path.split('/')[-1]
                        refs.append(CodeReference(
                            type='import',
                            name=package_name,
                            priority=0.8
                        ))

                # Type declarations
                elif node.type == 'type_declaration':
                    for child in node.children:
                        if child.type == 'type_spec':
                            name_node = child.child_by_field_name('name')
                            if name_node:
                                refs.append(CodeReference(
                                    type='class_def',
                                    name=name_node.text.decode('utf-8'),
                                    priority=0.9
                                ))

                # Function declarations
                elif node.type == 'function_declaration':
                    name_node = node.child_by_field_name('name')
                    if name_node:
                        refs.append(CodeReference(
                            type='function_def',
                            name=name_node.text.decode('utf-8'),
                            priority=0.85
                        ))

                # Recurse
                for child in node.children:
                    walk(child)

            walk(root_node)

        except Exception as e:
            print(f"Error in _parse_go_ast: {e}")

        return self._deduplicate_references(refs)

    def _parse_rust_ast(self, file_path: str) -> List[CodeReference]:
        """Parse Rust file using tree-sitter."""
        refs = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            tree = self.parsers['rust'].parse(bytes(content, 'utf-8'))
            root_node = tree.root_node

            def walk(node):
                # Use declarations
                if node.type == 'use_declaration':
                    identifiers = []
                    
                    def collect_identifiers(n):
                        if n.type == 'identifier':
                            identifiers.append(n)
                        for child in n.children:
                            collect_identifiers(child)
                    
                    collect_identifiers(node)
                    
                    if identifiers:
                        last_name = identifiers[-1]
                        refs.append(CodeReference(
                            type='import',
                            name=last_name.text.decode('utf-8'),
                            priority=0.8
                        ))

                # Struct items
                elif node.type == 'struct_item':
                    name_node = node.child_by_field_name('name')
                    if name_node:
                        refs.append(CodeReference(
                            type='class_def',
                            name=name_node.text.decode('utf-8'),
                            priority=0.9
                        ))

                # Function items
                elif node.type == 'function_item':
                    name_node = node.child_by_field_name('name')
                    if name_node:
                        refs.append(CodeReference(
                            type='function_def',
                            name=name_node.text.decode('utf-8'),
                            priority=0.85
                        ))

                # Recurse
                for child in node.children:
                    walk(child)

            walk(root_node)

        except Exception as e:
            print(f"Error in _parse_rust_ast: {e}")

        return self._deduplicate_references(refs)

    def _deduplicate_references(self, refs: List[CodeReference]) -> List[CodeReference]:
        """Remove duplicate references, keeping highest priority."""
        seen = {}
        for ref in refs:
            key = (ref.type, ref.name)
            if key not in seen or seen[key].priority < ref.priority:
                seen[key] = ref

        unique_refs = list(seen.values())
        unique_refs.sort(key=lambda r: r.priority, reverse=True)
        return unique_refs

    def _cache_references(self, file_path: str, refs: List[CodeReference]):
        """Cache references with size limit."""
        if len(self.cache) >= self._max_cache_size:
            first_key = next(iter(self.cache))
            del self.cache[first_key]

        self.cache[file_path] = refs

    def clear_cache(self):
        """Clear the reference cache."""
        self.cache.clear()

    def get_top_references(
        self,
        file_paths: List[str],
        n: int = 8,
        ref_types: Optional[List[str]] = None
    ) -> List[CodeReference]:
        """
        Get top N references from multiple files.

        Args:
            file_paths: List of file paths to analyze
            n: Number of top references to return
            ref_types: Optional list of reference types to filter by

        Returns:
            Top N references sorted by priority
        """
        all_refs = []

        for file_path in file_paths:
            refs = self.extract_references_fast(file_path)
            if ref_types:
                refs = [r for r in refs if r.type in ref_types]
            all_refs.extend(refs)

        unique_refs = self._deduplicate_references(all_refs)
        return unique_refs[:n]
