"""
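Intelligent orchestrator with multi-turn, AST-based code navigation.

Implements Fast Context-style multi-turn search using local AST analysis
and intelligent heuristics, without additional LLM calls. The `search`
pipeline currently runs two turns: a primary keyword search followed by
AST reference following.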
"""

from __future__ import annotations

import os
import time
import re
from typing import List, Optional, Dict, Any
from concurrent.futures import ThreadPoolExecutor, as_completed

from .base import CandidateMatch, PipelineConfig, QueryResponse
from .orchestrator import PipelineOrchestrator
from .code_analyzer import FastCodeAnalyzer, CodeReference
from .search_context import SearchContext
from .bm25_scorer import score_candidates as bm25_score_candidates
from .flashrank_reranker import flashrank_rerank


class IntelligentOrchestrator(PipelineOrchestrator):
    """
    Enhanced orchestrator with multi-turn intelligent search.

    Extends the base PipelineOrchestrator with:
    - Strategic search turns (vs 2 keyword-only turns); `search` currently
      runs two of them: primary keywords, then AST reference following
    - A few focused parallel searches per turn, capped by
      `parallel_per_turn` (vs 128 scattered searches)
    - AST-based code navigation (follow imports, function calls)
    - Intelligent gap filling (tests, configs, related files), available
      as helper methods that `search` does not currently call
    - Only 1 LLM call (rerank) - same as base orchestrator
    """

    def __init__(self, config: PipelineConfig):
        super().__init__(config)
        self.code_analyzer = FastCodeAnalyzer()
        # 3-turn strategy with 2 parallel searches each
        # Reduced parallelism - ripgrep is already internally parallelized
        self.parallel_per_turn = 2
        self.enable_ast = True

    def search(
        self,
        query: str,
        keywords: Dict[str, Any],
        directory: Optional[str] = None,
        file_patterns: Optional[List[str]] = None,
        max_results: Optional[int] = None,
        server_client=None
    ) -> QueryResponse:
        """
        Execute 4-turn intelligent search pipeline.

        3 turns × 8 parallel = 24 total searches
        Based on Windsurf research for optimal speed/accuracy balance.

        Args:
            query: Natural language search query
            keywords: Extracted keywords from agent
            directory: Directory to search in
            file_patterns: File patterns to constrain search
            max_results: Override default max results
            server_client: HTTP client for calling rerank endpoint

        Returns:
            QueryResponse with ranked results
        """
        start_time = time.time()

        # Initialize
        max_results = max_results or self.config.top_k_results
        directory = directory or "."
        search_context = SearchContext()

        # Prepare keywords
        from .base import ExtractedKeywords
        extracted_keywords = ExtractedKeywords(
            primary_terms=keywords.get("primary_terms", []),
            search_terms=keywords.get("search_terms", []),
            file_patterns=keywords.get("file_patterns", []),
            intent=keywords.get("intent", query)
        )

        if file_patterns is None:
            file_patterns = extracted_keywords.file_patterns

        # 2-turn search strategy (grep + AST)
        all_candidates: List[CandidateMatch] = []

        # TURN 1: Primary keyword search
        turn1_start = time.time()
        turn1_candidates = self._turn1_primary_search(
            keywords=extracted_keywords,
            directory=directory,
            file_patterns=file_patterns,
            search_context=search_context
        )
        all_candidates.extend(turn1_candidates)
        search_context.update_from_results(turn1_candidates)


        # TURN 2: AST reference following
        turn2_start = time.time()
        turn2_candidates = self._turn2_ast_references(
            keywords=extracted_keywords,
            previous_results=turn1_candidates,
            search_context=search_context,
            directory=directory,
            file_patterns=file_patterns
        )
        all_candidates.extend(turn2_candidates)
        search_context.update_from_results(turn2_candidates)


        # Smart deduplication
        dedup_start = time.time()
        unique_candidates = self._smart_deduplicate(
            all_candidates,
            extracted_keywords.primary_terms
        )


        # Stage 1: BM42 pre-scoring (limit to 100 candidates max)
        # Limit candidates to 100 before BM42 scoring to avoid slowdown
        if len(unique_candidates) > 100:
            unique_candidates = unique_candidates[:100]
        
        keywords_dict = {
            'primary_terms': extracted_keywords.primary_terms,
            'search_terms': [],
        }
        unique_candidates = bm25_score_candidates(query, unique_candidates, keywords_dict)

        # Take top 50 from BM42 for FlashRank reranking
        bm42_top = unique_candidates[:50]

        # Stage 2: FlashRank cross-encoder reranking (only top 10 for Cerebras)
        reranked_top = flashrank_rerank(query, bm42_top, top_k=10)

        # Use reranked results for LLM reranking
        unique_candidates = reranked_top

        # Read file contents
        spans = self.read_tool.read_spans_from_candidates(unique_candidates)

        # Re-rank using server (only LLM call, same as base orchestrator)
        if not server_client:
            raise ValueError("server_client is required for reranking")

        rerank_candidates = []
        for span in spans:
            rerank_candidates.append({
                "path": span.path,
                "start_line": span.start_line,
                "end_line": span.end_line,
                "content": span.text,
                "score": 0.0
            })

        llm_start = time.time()
        
        rerank_response = server_client.post(
            "/v1/rerank",
            json={
                "query": query,
                "candidates": rerank_candidates,
                "keywords": {
                    "primary_terms": extracted_keywords.primary_terms,
                    "search_terms": [],  # Removed - causes hallucination
                    "file_patterns": extracted_keywords.file_patterns,
                    "intent": extracted_keywords.intent
                },
                "max_results": max_results
            }
        )
        rerank_response.raise_for_status()
        rerank_data = rerank_response.json()

        # Convert to RankedResult
        from .base import RankedResult
        ranked_results = []
        for result in rerank_data.get('results', []):
            ranked_results.append(RankedResult(
                path=result.get('path', ''),
                score=result.get('score', 0.0),
                highlights=result.get('highlights', []),
                summary=result.get('summary'),
                file_info=result.get('file_info'),
                reasoning=result.get('reasoning')
            ))



        return QueryResponse(
            results=ranked_results,
            total_candidates=len(unique_candidates),
            query=query,
            overall_reasoning=rerank_data.get('overall_reasoning'),
            token_usage=rerank_data.get('token_usage')
        )

    def _turn1_primary_search(
        self,
        keywords: Any,
        directory: str,
        file_patterns: Optional[List[str]],
        search_context: SearchContext
    ) -> List[CandidateMatch]:
        """
        Turn 1: Primary keyword search (8 parallel).

        Searches for the most important keywords first.
        """
        best_terms = self._select_best_search_terms(keywords, n=self.parallel_per_turn)

        for term in best_terms:
            search_context.add_pattern(term)

        candidates = []
        with ThreadPoolExecutor(max_workers=self.parallel_per_turn) as executor:
            futures = []
            for term in best_terms:
                future = executor.submit(
                    self.grep_tool.search,
                    query=term,
                    directory=directory,
                    file_patterns=file_patterns,
                    case_sensitive=False,
                    context_lines=3
                )
                futures.append(future)

            for future in futures:
                try:
                    results = future.result(timeout=3)
                    candidates.extend(results)
                except Exception:
                    # Timeout or error - skip this search and continue
                    continue

        return candidates

    def _turn2_ast_references(
        self,
        keywords: Any,
        previous_results: List[CandidateMatch],
        search_context: SearchContext,
        directory: str,
        file_patterns: Optional[List[str]]
    ) -> List[CandidateMatch]:
        """
        Turn 2: AST reference following (8 parallel).

        Analyzes top files from Turn 1 and searches for imports,
        function calls, and class references.
        """
        top_files = search_context.get_high_quality_files(n=5)
        if not top_files:
            top_files = [r.path for r in previous_results[:5]]

        # Extract AST references
        all_references = []
        for file_path in top_files:
            refs = self.code_analyzer.extract_references_fast(file_path)
            all_references.extend(refs)

        # Score and select top references
        top_refs = self._score_references(all_references, search_context)[:self.parallel_per_turn]

        candidates = []
        with ThreadPoolExecutor(max_workers=self.parallel_per_turn) as executor:
            futures = []

            for ref in top_refs:
                search_pattern = self._reference_to_search_pattern(ref)
                if search_pattern and not search_context.has_seen_pattern(search_pattern):
                    search_context.add_pattern(search_pattern)
                    future = executor.submit(
                        self.grep_tool.search,
                        query=search_pattern,
                        directory=directory,
                        file_patterns=file_patterns,
                        case_sensitive=False,
                        context_lines=3
                    )
                    futures.append(future)

            for future in futures:
                try:
                    results = future.result(timeout=3)
                    candidates.extend(results)
                except Exception:
                    # Timeout or error - skip this search and continue
                    continue

        return candidates

    def _turn4_gap_filling(
        self,
        keywords: Any,
        search_context: SearchContext,
        directory: str
    ) -> List[CandidateMatch]:
        """
        Turn 4: Gap filling (8 parallel).

        Searches for tests, configs, and sibling directories.
        """
        all_searches = []
        primary_term = ""
        if hasattr(keywords, 'primary_terms') and keywords.primary_terms:
            primary_term = keywords.primary_terms[0]

        if not primary_term:
            return []

        # Skip test file searches - they often pollute results
        # Skip config file searches - rarely contain target code

        # Sibling directories (4 searches)
        sibling_dirs = search_context.get_sibling_directories(n=4)
        for sibling in sibling_dirs:
            all_searches.append((None, primary_term, sibling))

        # Missing file types (4 searches)
        missing_types = search_context.get_missing_file_types()
        for ext in missing_types[:4]:
            all_searches.append((f"*{ext}", primary_term))

        # Limit to 16
        all_searches = all_searches[:self.parallel_per_turn]

        candidates = []
        with ThreadPoolExecutor(max_workers=self.parallel_per_turn) as executor:
            futures = []

            for search in all_searches:
                if len(search) == 3:
                    # Sibling directory search
                    file_pattern, term, search_dir = search
                    future = executor.submit(
                        self.grep_tool.search,
                        query=term,
                        directory=search_dir,
                        file_patterns=[file_pattern] if file_pattern else None,
                        case_sensitive=False,
                        context_lines=3
                    )
                else:
                    # Standard search with file pattern
                    file_pattern, term = search
                    future = executor.submit(
                        self.grep_tool.search,
                        query=term,
                        directory=directory,
                        file_patterns=[file_pattern] if file_pattern else None,
                        case_sensitive=False,
                        context_lines=3
                    )
                futures.append(future)

            for future in futures:
                try:
                    results = future.result(timeout=3)
                    candidates.extend(results)
                except Exception:
                    # Timeout or error - skip this search and continue
                    continue

        return candidates

    def _select_best_search_terms(self, keywords: Any, n: int) -> List[str]:
        """
        Select best N search terms from primary_terms only.

        No LLM calls - uses deterministic scoring.
        Secondary terms (search_terms) are ignored - they cause hallucination.
        """
        primary = keywords.primary_terms if hasattr(keywords, 'primary_terms') else []

        # Only use primary_terms - secondary terms cause wrong results
        all_terms = primary
        scored_terms = []

        for term in all_terms:
            score = 0.0

            # Prefer longer, more specific terms
            score += min(len(term) / 20.0, 0.5)

            # Prefer terms with structure (camelCase, snake_case)
            if '_' in term or (term != term.lower() and term != term.upper()):
                score += 0.3

            # Prefer terms that look like identifiers
            if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', term):
                score += 0.2

            # Penalize very common words
            common_words = ['the', 'a', 'an', 'is', 'in', 'to', 'for', 'of', 'and', 'or']
            if term.lower() in common_words:
                score -= 0.8

            # Penalize very short terms
            if len(term) < 3:
                score -= 0.3

            scored_terms.append((term, score))

        # Sort by score and return top N
        scored_terms.sort(key=lambda x: x[1], reverse=True)
        return [term for term, score in scored_terms[:n] if score > 0]

    def _smart_deduplicate(
        self,
        candidates: List[CandidateMatch],
        keywords: List[str]
    ) -> List[CandidateMatch]:
        """
        Smart deduplication: one result per file per keyword.

        Files with multiple keywords get boosted in ranking.
        This reduces candidate count and improves relevance.

        Args:
            candidates: All candidates from grep
            keywords: List of search keywords

        Returns:
            Deduplicated candidates with keyword coverage boost
        """
        from collections import defaultdict

        # Group candidates by file
        file_candidates = defaultdict(list)
        for candidate in candidates:
            file_candidates[candidate.path].append(candidate)

        # For each file, keep best match per keyword
        result = []
        file_keyword_coverage = {}  # Track how many keywords each file matches

        for file_path, file_cands in file_candidates.items():
            # Track which keywords this file matches
            keywords_matched = set()
            best_per_keyword = {}  # keyword -> best candidate

            for candidate in file_cands:
                matched_text_lower = candidate.matched_text.lower()
                context_lower = (
                    (candidate.context_before or '') +
                    matched_text_lower +
                    (candidate.context_after or '')
                ).lower()

                # Find ALL keywords this candidate matches (for coverage tracking)
                # But assign to first keyword for deduplication
                matched_keyword = None
                for kw in keywords:
                    if kw.lower() in context_lower:
                        keywords_matched.add(kw)
                        if matched_keyword is None:
                            matched_keyword = kw

                if matched_keyword:
                    # Keep candidate with most context or first occurrence
                    if matched_keyword not in best_per_keyword:
                        best_per_keyword[matched_keyword] = candidate
                    else:
                        # Prefer candidate with more context
                        existing = best_per_keyword[matched_keyword]
                        existing_context = len(existing.context_before or '') + len(existing.context_after or '')
                        new_context = len(candidate.context_before or '') + len(candidate.context_after or '')
                        if new_context > existing_context:
                            best_per_keyword[matched_keyword] = candidate
                else:
                    # No keyword match, keep if it's the only one from this file
                    if 'no_keyword' not in best_per_keyword:
                        best_per_keyword['no_keyword'] = candidate

            # Calculate keyword coverage score for this file
            coverage = len(keywords_matched)
            file_keyword_coverage[file_path] = coverage

            # Add best candidates from this file with coverage boost
            # Deduplicate by checking for overlapping line ranges
            seen_ranges = []  # [(start, end), ...]

            def ranges_overlap(s1, e1, s2, e2):
                return s1 <= e2 and s2 <= e1

            for keyword, candidate in best_per_keyword.items():
                # Calculate line range for this candidate
                start_line = candidate.line_number
                # Estimate end line from context
                context_lines = len((candidate.context_after or '').split('\n')) if candidate.context_after else 0
                end_line = start_line + context_lines

                # Skip if overlaps with any seen range
                if any(ranges_overlap(start_line, end_line, s, e) for s, e in seen_ranges):
                    continue

                seen_ranges.append((start_line, end_line))

                # Boost score based on keyword coverage (more keywords = higher score)
                candidate.keyword_coverage = coverage
                result.append(candidate)

        # Sort by keyword coverage (files with more keywords first)
        result.sort(key=lambda c: c.keyword_coverage if hasattr(c, 'keyword_coverage') else 0, reverse=True)

        return result

    def _score_references(
        self,
        references: List[CodeReference],
        search_context: SearchContext
    ) -> List[CodeReference]:
        """Score references by relevance, avoiding duplicates."""
        scored = []

        for ref in references:
            # Skip if we've already searched for this
            if search_context.has_seen_pattern(ref.name):
                continue

            score = ref.priority

            # Boost imports and definitions
            if ref.type in ['import', 'class_def', 'function_def']:
                score += 0.2

            # Prefer longer names (more specific)
            score += min(len(ref.name) / 30.0, 0.15)

            scored.append((ref, score))

        # Sort by score
        scored.sort(key=lambda x: x[1], reverse=True)
        return [ref for ref, score in scored]

    def _reference_to_search_pattern(self, ref: CodeReference) -> Optional[str]:
        """Convert a code reference to a grep search pattern."""
        if ref.type == 'import':
            # Search for file or module name
            return ref.name
        elif ref.type == 'function_call':
            # Search for function definition
            return f"def {ref.name}"
        elif ref.type == 'function_def':
            # Search for usage of this function
            return f"{ref.name}("
        elif ref.type == 'class_def':
            # Search for class definition or usage
            return f"class {ref.name}"
        elif ref.type == 'identifier':
            return ref.name

        return None

    def _turn4_discover_related(
        self,
        keywords: Any,
        search_context: SearchContext,
        directory: str
    ) -> List[CandidateMatch]:
        """
        Turn 4: Discover related files we might have missed.

        Searches for test files, config files, documentation, and
        files in sibling directories.
        """
        gap_searches = []

        # Get primary term for context
        primary_term = ""
        if hasattr(keywords, 'primary_terms') and keywords.primary_terms:
            primary_term = keywords.primary_terms[0]

        # 1. Search for test files if we found implementation
        has_impl = any('test' not in f.lower() for f in search_context.files)
        if has_impl and primary_term:
            gap_searches.append({
                'pattern': primary_term,
                'file_patterns': ['*test*.py', '*test*.js', '*spec*.js', '*_test.go', '*Test.java']
            })

        # 2. Search in sibling directories
        sibling_dirs = search_context.get_sibling_directories(n=2)
        for sibling in sibling_dirs:
            if primary_term:
                gap_searches.append({
                    'pattern': primary_term,
                    'directory': sibling,
                    'file_patterns': None
                })

        # 3. Search for config files
        if primary_term:
            gap_searches.append({
                'pattern': primary_term,
                'file_patterns': ['*.config.js', '*.json', '*.yaml', '*.yml', '*.toml', '*.ini']
            })

        # 4. Search for documentation
        if primary_term:
            gap_searches.append({
                'pattern': primary_term,
                'file_patterns': ['*.md', '*.rst', '*.txt', 'README*']
            })

        # 5. Search for related file types we haven't found yet
        missing_types = search_context.get_missing_file_types()
        if missing_types and primary_term:
            gap_searches.append({
                'pattern': primary_term,
                'file_patterns': [f"*{ext}" for ext in missing_types[:3]]
            })

        # Execute up to 8 parallel searches
        candidates = []
        with ThreadPoolExecutor(max_workers=self.parallel_per_turn) as executor:
            futures = []

            for search in gap_searches[:self.parallel_per_turn]:
                search_dir = search.get('directory', directory)
                search_patterns = search.get('file_patterns')

                future = executor.submit(
                    self.grep_tool.search,
                    query=search['pattern'],
                    directory=search_dir,
                    file_patterns=search_patterns,
                    case_sensitive=False,
                    context_lines=3
                )
                futures.append(future)

            for future in futures:
                try:
                    results = future.result(timeout=3)
                    candidates.extend(results)
                except Exception:
                    continue

        return candidates
