#!/usr/bin/env scripts/uv-run-script
# -*- mode: python -*-
# /// script
# requires-python = ">=3.8"
# dependencies = [
#     "riot>=0.20.1",
#     "ruamel.yaml>=0.17.21",
# ]
# ///
"""Test runner script for dd-trace-py.

This script helps developers run the appropriate test suites based on changed files.
It maps source files to their corresponding test suites and provides granular control
over which specific riot venvs to run.

Note: This runs entire test suites, not individual test files.
"""

import argparse
import fnmatch
import json
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Set, NamedTuple

# Make the repository root and its tests/ directory importable so the suite
# spec (tests.suitespec) and the riot configuration (riotfile) can be loaded.
# resolve() makes ROOT absolute and follows symlinks, so every path derived
# from it is independent of how (or through which link) the script was invoked.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
sys.path.insert(0, str(ROOT / "tests"))

from tests.suitespec import get_patterns, get_suites  # noqa: E402  (needs sys.path above)
import riotfile  # noqa: E402  (needs sys.path above)

# Endpoint exported as DD_TRACE_AGENT_URL for suites that declare `snapshot: true`
# (those suites also get the `testagent` Docker service).
TESTAGENT_URL = "http://localhost:9126"


class RiotVenv(NamedTuple):
    """A single riot venv instance plus the metadata used to display and run it."""
    number: int  # index of the instance in riotfile.venv.instances()
    hash: str  # riot short hash; this is what `riot run` is given to select the venv
    name: str  # riot venv name, e.g. 'django', 'django:celery', 'flask:redis'
    python_version: str
    packages: str  # comma-separated 'pkg: version' entries
    suite_name: str = ""  # Track which suite this venv belongs to
    command: str = ""  # The actual test command (e.g., pytest tests/contrib/flask/)

    def _normalize_package_name(self, name: str) -> List[str]:
        """Generate possible package name variations from venv name component.

        The original name always comes first, so an exact match is preferred.

        Examples:
            'django' -> ['django']
            'django_hosts' -> ['django_hosts', 'django-hosts']
            'psycopg2' -> ['psycopg2', 'psycopg2-binary']
        """
        variants = [name]

        # Try underscore to dash conversion (common in PyPI)
        if '_' in name:
            variants.append(name.replace('_', '-'))

        # Venv names that are commonly backed by a differently-named distribution.
        variations = {
            'psycopg2': ['psycopg2-binary'],
            'mysql': ['mysqlclient', 'mysql-connector-python'],
            'redis': ['redis-py'],
        }
        variants.extend(variations.get(name, []))

        return variants

    def _extract_package_versions_for_venv_name(self, venv_name: str, packages: str) -> List[str]:
        """Extract package versions that match the venv name components.

        Args:
            venv_name: e.g. 'django', 'django:celery', 'flask:redis'
            packages: Package info string with 'pkg: version' entries joined by ', '

        Returns:
            One entry per name component that was found, e.g. ``['django ~=3.2']``.
            Unpinned packages (empty version) are listed by name only. Names and
            versions are lower-cased.
        """
        if not packages or packages == "standard packages":
            return []

        found_versions = []
        packages_lower = packages.lower()

        # Split venv name by ':' to get components (django:celery -> django, celery)
        for component in (comp.strip() for comp in venv_name.split(':')):
            if not component:
                continue

            for variant in self._normalize_package_name(component):
                # Anchor on the start of an entry (string start or ',' separator)
                # so that 'redis' does not match inside 'django-redis: ...' and
                # 'flask' does not match inside 'pytest-flask: ...'.
                pattern = rf"(?:^|,)\s*{re.escape(variant)}:\s*([^,]*)"
                match = re.search(pattern, packages_lower, re.IGNORECASE)
                if match:
                    version = match.group(1).strip()
                    found_versions.append(f"{variant} {version}" if version else variant)
                    break  # Found a match for this component, move to next

        return found_versions

    @property
    def display_name(self) -> str:
        """Python version followed by the versions of packages named in the venv name."""
        package_versions = self._extract_package_versions_for_venv_name(self.name, self.packages)

        if package_versions:
            return f"Python {self.python_version}, {', '.join(package_versions)}"

        return f"Python {self.python_version}"


class TestRunner:
    """Map changed files to riot test suites and run the selected venvs.

    Workflow:
    1. Collect changed paths, either given explicitly or taken from git.
    2. Match them against each suite's glob patterns from ``tests.suitespec``.
    3. Pick riot venvs, interactively or by hash.
    4. Execute them through ``scripts/ddtest``, starting and stopping the
       Docker services each suite declares.
    """

    def __init__(self):
        self.root = ROOT
        self.changed_files: Set[str] = set()
        self.matching_suites: Dict[str, dict] = {}
        self.required_services: Set[str] = set()

    def _run_git_command(self, cmd: List[str]) -> Set[str]:
        """Run a git command that prints NUL-terminated paths; return them as a set.

        Callers pass ``-z`` so that paths containing spaces or non-ASCII
        characters come back verbatim. Without it they would be split on
        whitespace or C-quoted by git.

        Raises:
            subprocess.CalledProcessError: git exited non-zero (e.g. not a
                repository, or an unknown ref).
            OSError: the ``git`` executable could not be started.
        """
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=self.root, check=True)
        return {path for path in result.stdout.split("\0") if path}

    def get_git_changed_files(self, base_ref: str = "HEAD") -> Set[str]:
        """Get files changed in git since base_ref.

        For HEAD (default): returns staged + unstaged + untracked files.
        For other refs: returns the working-tree diff against that ref.
        On git failure a warning is printed to stderr and an empty set is returned.
        """
        try:
            if base_ref == "HEAD":
                # Get all local changes: staged, unstaged, and untracked
                staged = self._run_git_command(["git", "diff", "--cached", "--name-only", "-z"])
                unstaged = self._run_git_command(["git", "diff", "--name-only", "-z"])
                untracked = self._run_git_command(["git", "ls-files", "--others", "--exclude-standard", "-z"])
                return staged | unstaged | untracked
            # "--" stops git from interpreting the ref as a path (or vice versa).
            return self._run_git_command(["git", "diff", "--name-only", "-z", base_ref, "--"])
        except (subprocess.CalledProcessError, OSError) as e:
            detail = (getattr(e, "stderr", None) or "").strip() or str(e)
            print(f"Warning: Failed to get git changes ({detail})", file=sys.stderr)
            return set()

    def _normalize_path(self, file_path: str) -> str:
        """Return *file_path* normalized and, if it points inside the repo, repo-relative.

        Suite patterns are repository-relative. Without this, './ddtrace/x.py'
        or an absolute path into the checkout would never match them.
        """
        path = os.path.normpath(file_path)
        if os.path.isabs(path):
            rel = os.path.relpath(os.path.realpath(path), os.path.realpath(self.root))
            if rel != os.pardir and not rel.startswith(os.pardir + os.sep):
                path = rel
        return path

    def find_matching_suites(self, files: Set[str]) -> Dict[str, dict]:
        """Find test suites whose patterns match any of *files*.

        Each returned config is a shallow copy of the suite spec. It gains a
        ``matched_files`` key listing the matching inputs, sorted and without
        duplicates.
        """
        suites = get_suites()

        # Expand directory paths to include glob patterns
        expanded_files: Set[str] = set()
        for file_path in files:
            expanded_files.add(file_path)
            normalized = self._normalize_path(file_path)
            expanded_files.add(normalized)
            if (self.root / normalized).is_dir():
                # The literal glob strings make fnmatch accept a directory
                # argument against patterns such as 'ddtrace/contrib/flask/*'.
                expanded_files.add(f"{normalized}/*")
                expanded_files.add(f"{normalized}/**/*")

        matching: Dict[str, dict] = {}
        for suite_name, suite_config in suites.items():
            try:
                patterns = get_patterns(suite_name)
                if not patterns:
                    continue
                matches = sorted(
                    {path for pattern in patterns for path in fnmatch.filter(expanded_files, pattern)}
                )
            except Exception as e:  # best effort: one broken suite must not hide the others
                print(f"Warning: Error processing suite {suite_name}: {e}", file=sys.stderr)
                continue

            if matches:
                suite_match = suite_config.copy()
                suite_match['matched_files'] = matches
                matching[suite_name] = suite_match

        return matching

    def extract_required_services(self, suites: Dict[str, dict]) -> Set[str]:
        """Collect the Docker services needed by *suites*.

        This is the union of every suite's ``services`` list, plus
        ``testagent`` when any suite has ``snapshot`` set.
        """
        services: Set[str] = set()
        needs_testagent = False

        for suite_config in suites.values():
            services.update(suite_config.get('services', []))
            if suite_config.get('snapshot', False):
                needs_testagent = True

        if needs_testagent:
            services.add('testagent')

        return services

    def get_riot_venvs(self, pattern: str, suite_name: str = "") -> List["RiotVenv"]:
        """Return the riot venv instances whose names match the regex *pattern*.

        ``number`` is the instance's index in ``riotfile.venv.instances()``.
        Non-matching instances are counted too. Any error is reported as a
        warning on stderr and yields an empty list.
        """
        try:
            pattern_regex = re.compile(pattern)
            venvs = []

            # Use riot's own instances() method to get all venv instances
            for n, inst in enumerate(riotfile.venv.instances()):
                # Check if this instance matches our pattern (same logic as riot)
                if not inst.name or not inst.matches_pattern(pattern_regex):
                    continue

                # Keep every package; RiotVenv.display_name picks the relevant ones.
                pkgs = getattr(inst, 'pkgs', None)
                packages_info = ", ".join(f"{pkg}: {version}" for pkg, version in pkgs.items()) if pkgs else ""

                if hasattr(inst, 'cmd'):
                    command = str(inst.cmd)
                elif hasattr(inst, 'command'):
                    command = str(inst.command)
                else:
                    command = ""

                py = getattr(inst, 'py', None)
                python_version = str(py._hint) if py is not None and hasattr(py, '_hint') else "3.10"

                venvs.append(RiotVenv(
                    number=n,
                    hash=inst.short_hash if hasattr(inst, 'short_hash') else f"hash{n}",
                    name=inst.name,
                    python_version=python_version,
                    packages=packages_info,
                    suite_name=suite_name,
                    command=command,
                ))

            return venvs

        except Exception as e:
            print(f"Warning: Failed to get riot venvs for pattern '{pattern}': {e}", file=sys.stderr)
            return []

    def start_services(self, services: Set[str]) -> bool:
        """Start required Docker services; return False if docker compose fails or is missing."""
        if not services:
            return True

        ordered = sorted(services)  # deterministic command line and log output
        print(f"\n🐳 Starting required services: {', '.join(ordered)}")
        try:
            subprocess.run(["docker", "compose", "up", "-d", *ordered], cwd=self.root, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"❌ Failed to start services: {e}")
            return False
        return True

    def stop_services(self, services: Set[str]) -> bool:
        """Stop and remove Docker services.

        Returns False on failure. Callers treat teardown as best effort and do
        not fail the run because of it.
        """
        if not services:
            return True

        ordered = sorted(services)
        print(f"\n🛑 Stopping services: {', '.join(ordered)}")
        try:
            subprocess.run(["docker", "compose", "down", *ordered], cwd=self.root, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"⚠️  Warning: Failed to stop services: {e}")
            return False
        return True

    def _parse_selection(self, selection: str, max_items: int) -> Set[int]:
        """Parse user selection string into set of indices.

        Handles formats like: '1', '1,3,5', '1-5', '2-4,7,9-11'.
        Returns 1-based indices clamped to ``1..max_items``. Empty parts such
        as in '1,,3' are ignored.

        Raises:
            ValueError: a part is not an integer or an 'a-b' range.
        """
        indices: Set[int] = set()
        for raw_part in selection.split(','):
            part = raw_part.strip()
            if not part:
                continue
            if '-' in part:
                start_text, end_text = part.split('-', 1)
                start, end = int(start_text), int(end_text)
                # Clamp the start too: index 0 would map to items[-1] in the caller.
                indices.update(range(max(start, 1), min(end, max_items) + 1))
            else:
                idx = int(part)
                if 1 <= idx <= max_items:
                    indices.add(idx)
        return indices

    def _interactive_select(self, items: List, item_type: str, format_fn=None) -> List:
        """Generic interactive selection for any list of items.

        Args:
            items: List of items to select from
            item_type: Description of what's being selected (e.g., "patterns", "venvs")
            format_fn: Optional function to format each item for display

        Returns:
            The chosen items in display order. An empty list when the user
            answers 'none' or stdin is closed. For "venvs", 'latest' picks the
            last listed item.
        """
        if not items:
            return []

        # Single item - no need to select
        if len(items) == 1:
            return items

        print(f"\n🎯 Select which {item_type} to run:")
        print("=" * 60)

        for i, item in enumerate(items, 1):
            display = format_fn(item) if format_fn else str(item)
            print(f"{i:2d}. {display}")

        latest_hint = ", 'latest' for the last one" if item_type == "venvs" else ""
        while True:
            print(f"\nSelect {item_type} (e.g., '1,3' for specific, 'all' for everything{latest_hint}, or 'none'):")
            try:
                selection = input("> ").strip().lower()
            except EOFError:
                # Non-interactive stdin (pipe, CI): nothing can be selected.
                print("\n   No input available; nothing selected")
                return []

            if selection == 'none':
                return []
            if selection == 'all':
                return items
            if selection == 'latest' and item_type == "venvs":
                return [items[-1]]

            try:
                indices = self._parse_selection(selection, len(items))
            except ValueError:
                print("❌ Invalid selection. Please use format like '1,3', 'all', or 'none'")
                continue

            if indices:
                return [items[i - 1] for i in sorted(indices)]
            print("   No valid items selected")

    def select_riot_suites(self, matching_suites: Dict[str, dict]) -> Dict[str, dict]:
        """Let user select which riot suites to run."""
        riot_suites = {name: config for name, config in matching_suites.items()
                       if config.get('runner') == 'riot'}

        if not riot_suites:
            print("❌ No riot suites found in matching suites")
            return {}

        print(f"\n📋 Found {len(riot_suites)} matching riot suite(s):")
        selected = self._interactive_select(list(riot_suites), "suites")
        if selected:
            print(f"   Selected suites: {', '.join(selected)}")
            return {name: riot_suites[name] for name in selected}
        return {}

    @staticmethod
    def _format_venv(venv: "RiotVenv") -> str:
        """One-line description of a venv for the selection menu."""
        return f"#{venv.number:3d}  {venv.hash}  {venv.name}  {venv.display_name}"

    def select_venvs_for_suites(self, selected_suites: Dict[str, dict]) -> List["RiotVenv"]:
        """Let user select specific venvs for each suite."""
        selected_venvs = []

        for suite_name, suite_config in selected_suites.items():
            pattern = suite_config.get('pattern', suite_name)
            print(f"\n🔍 Getting available venvs for suite '{suite_name}' (pattern: '{pattern}')...")
            venvs = self.get_riot_venvs(pattern, suite_name=suite_name)

            if not venvs:
                print(f"   ⚠️  No venvs found for suite '{suite_name}'")
                continue

            print(f"\n📋 Available venvs for suite '{suite_name}' ({len(venvs)} total):")
            print("=" * 80)

            selected = self._interactive_select(venvs, "venvs", self._format_venv)

            if selected:
                selected_venvs.extend(selected)
                print(f"   Selected {len(selected)} venv(s) for suite '{suite_name}':")
                for venv in selected:
                    print(f"     • {venv.name}: {venv.display_name}")

        return selected_venvs

    def interactive_venv_selection(self, matching_suites: Dict[str, dict]) -> List["RiotVenv"]:
        """Provide interactive venv selection with granular control."""
        if not matching_suites:
            print("❌ No matching test suites found for the changed files.")
            return []

        # Step 1: Select riot suites
        selected_suites = self.select_riot_suites(matching_suites)
        if not selected_suites:
            return []

        # Step 2: Select specific venvs for each suite
        return self.select_venvs_for_suites(selected_suites)

    def _run_suite_venvs(self, venvs: List["RiotVenv"], env: Dict[str, str], riot_args: List[str],
                         dry_run: bool, needs_testagent: bool) -> bool:
        """Run (or, in dry-run mode, print) each venv of one suite; stop at the first failure."""
        for venv in venvs:
            # Execute using ddtest with the specific venv hash
            cmd = [str(self.root / "scripts" / "ddtest"), "riot", "-v", "run", "--pass-env", venv.hash]
            if riot_args:
                cmd.extend(["--", *riot_args])

            if dry_run:
                print(f"[DRY RUN] Would execute: {' '.join(cmd)}")
                if needs_testagent:
                    print(f"[DRY RUN] With env: DD_TRACE_AGENT_URL={env.get('DD_TRACE_AGENT_URL')}")
                continue

            print(f"\n▶️  Executing ({venv.display_name}): {' '.join(cmd)}")
            try:
                result = subprocess.run(cmd, env=env, cwd=self.root)
            except OSError as e:  # e.g. scripts/ddtest missing or not executable
                print(f"❌ Failed to run {venv.display_name}: {e}")
                return False

            if result.returncode != 0:
                print(f"❌ {venv.display_name} failed with exit code {result.returncode}")
                return False
            print(f"✅ {venv.display_name} completed successfully")

        return True

    def run_tests(self, selected_venvs: List["RiotVenv"], matching_suites: Dict[str, dict],
                  riot_args: List[str] = None, dry_run: bool = False) -> bool:
        """Execute the selected venvs, grouped by suite with per-suite service management.

        Services are torn down in a ``finally`` block, so Ctrl-C or an
        unexpected error does not leave containers running. In dry-run mode no
        service is started or stopped; the commands are only printed.

        Returns:
            True if every venv succeeded (or nothing was selected). False on
            the first failing suite, in which case later suites are skipped.
        """
        if not selected_venvs:
            print("ℹ️  No venvs selected for execution.")
            return True

        # Group venvs by suite, preserving selection order.
        venvs_by_suite: Dict[str, List["RiotVenv"]] = {}
        for venv in selected_venvs:
            venvs_by_suite.setdefault(venv.suite_name, []).append(venv)

        print(f"\n🧪 Running {len(selected_venvs)} venv(s) across {len(venvs_by_suite)} suite(s):")
        for suite_name, venvs in venvs_by_suite.items():
            print(f"   • Suite '{suite_name}': {len(venvs)} venv(s)")

        for suite_name, venvs in venvs_by_suite.items():
            print(f"\n{'='*80}")
            print(f"🎯 Suite: {suite_name}")
            print(f"{'='*80}")

            suite_config = matching_suites.get(suite_name, {})
            needs_testagent = bool(suite_config.get('snapshot', False))
            suite_services = self.extract_required_services({suite_name: suite_config})

            if not suite_services:
                print(f"ℹ️  No services required for suite '{suite_name}'")
            elif dry_run:
                print(f"[DRY RUN] Would start services: {', '.join(sorted(suite_services))}")
            elif not self.start_services(suite_services):
                print(f"❌ Failed to start services for suite '{suite_name}'")
                # `up` may have started some containers before failing.
                self.stop_services(suite_services)
                return False

            env = os.environ.copy()
            if needs_testagent:
                env['DD_TRACE_AGENT_URL'] = TESTAGENT_URL
                print(f"🔧 Setting DD_TRACE_AGENT_URL={TESTAGENT_URL} for snapshot tests")

            try:
                suite_success = self._run_suite_venvs(venvs, env, riot_args, dry_run, needs_testagent)
            finally:
                if suite_services and not dry_run:
                    self.stop_services(suite_services)

            if not suite_success:
                print(f"\n❌ Suite '{suite_name}' failed. Stopping execution.")
                return False

            print(f"\n✅ Suite '{suite_name}' completed successfully!")

        print("\n🎉 All selected suites completed successfully!")
        return True

    def output_suites_json(self, matching_suites: Dict[str, dict]) -> None:
        """Print matching riot suites and their venvs as a JSON document on stdout."""
        suites_data = []

        for suite_name, suite_config in matching_suites.items():
            if suite_config.get('runner') != 'riot':
                continue

            pattern = suite_config.get('pattern', suite_name)
            venvs_data = [
                {
                    "hash": venv.hash,
                    "name": venv.name,
                    "number": venv.number,
                    "python_version": venv.python_version,
                    "packages": venv.packages,
                    "command": venv.command,
                }
                for venv in self.get_riot_venvs(pattern, suite_name=suite_name)
            ]

            suites_data.append({
                "name": suite_name,
                "matched_files": suite_config.get('matched_files', []),
                "venvs": venvs_data,
            })

        print(json.dumps({"suites": suites_data}, indent=2))

    def select_venvs_by_hash(self, matching_suites: Dict[str, dict], venv_hashes: List[str]) -> List["RiotVenv"]:
        """Select specific venvs by their hashes from matching riot suites.

        Requested hashes that match no venv are reported on stderr, provided
        at least one hash did match. The caller handles the "none matched"
        case.
        """
        requested = set(venv_hashes)
        selected_venvs = []

        for suite_name, suite_config in matching_suites.items():
            if suite_config.get('runner') != 'riot':
                continue

            pattern = suite_config.get('pattern', suite_name)
            selected_venvs.extend(
                venv for venv in self.get_riot_venvs(pattern, suite_name=suite_name)
                if venv.hash in requested
            )

        missing = sorted(requested - {venv.hash for venv in selected_venvs})
        if selected_venvs and missing:
            print(f"⚠️  Warning: ignoring unknown venv hash(es): {', '.join(missing)}", file=sys.stderr)

        return selected_venvs


def main():
    """Command-line entry point.

    Returns:
        Process exit code. 0 on success or when there is nothing to do. 1 when
        the requested venv hashes match nothing or a venv/suite fails. 130 when
        interrupted with Ctrl-C.
    """
    parser = argparse.ArgumentParser(
        description="Run test suites based on changed files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run tests for all locally changed files
  scripts/run-tests

  # Run tests for specific files
  scripts/run-tests ddtrace/contrib/flask/patch.py tests/contrib/flask/test_flask.py

  # Run tests for changes since main branch
  scripts/run-tests --git-base=main

  # Show what would be run without executing
  scripts/run-tests --dry-run

  # Pass additional arguments to pytest
  scripts/run-tests ddtrace/contrib/django/patch.py -- -vvv -s --tb=short
        """
    )

    parser.add_argument(
        'files',
        nargs='*',
        help='Specific files to test (if not provided, uses git changes)'
    )

    parser.add_argument(
        '--git-base',
        default='HEAD',
        help='Git ref to compare against for changes (default: HEAD for local changes)'
    )

    parser.add_argument(
        '--dry-run',
        action='store_true',
        help='Show what would be run without executing'
    )

    parser.add_argument(
        '--all-suites',
        action='store_true',
        help='Show all available suites regardless of file changes'
    )

    parser.add_argument(
        '--list',
        action='store_true',
        help='Output JSON with all matching suites and venvs (for AI agents)'
    )

    parser.add_argument(
        '--venv',
        action='append',
        help='Run specific venvs (by hash) without interactive prompts. Can be used multiple times (e.g., --venv hash1 --venv hash2)'
    )

    # Split off everything after "--" ourselves: those arguments are forwarded
    # to riot untouched instead of being parsed by argparse.
    if '--' in sys.argv:
        separator_idx = sys.argv.index('--')
        script_args = sys.argv[1:separator_idx]
        riot_args = sys.argv[separator_idx + 1:]
    else:
        script_args = sys.argv[1:]
        riot_args = []

    args = parser.parse_args(script_args)

    # In --list mode stdout must carry nothing but the JSON document, so the
    # human-oriented status lines go to stderr instead.
    status_stream = sys.stderr if args.list else sys.stdout

    def status(message: str) -> None:
        print(message, file=status_stream)

    runner = TestRunner()

    try:
        # Determine which files to check
        if args.files:
            files = set(args.files)
            status(f"📁 Checking explicitly provided files: {', '.join(sorted(files))}")
        elif args.all_suites:
            # Placeholder only; --all-suites bypasses file matching below.
            files = {"*"}
            status("📁 Checking all available test suites")
        else:
            files = runner.get_git_changed_files(args.git_base)
            if not files:
                status("ℹ️  No changed files found. Use --all-suites to see all available suites.")
                if args.list:
                    # Machine consumers still get a valid (empty) JSON document.
                    runner.output_suites_json({})
                return 0
            shown = sorted(files)
            status(f"📁 Found {len(files)} changed file(s): {', '.join(shown[:5])}")
            if len(files) > 5:
                status(f"    ... and {len(files) - 5} more")

        # Find matching suites
        if args.all_suites:
            # Copy each config instead of mutating the dicts returned by get_suites().
            matching_suites = {}
            for suite_name, suite_config in get_suites().items():
                suite_copy = suite_config.copy()
                suite_copy['matched_files'] = []
                matching_suites[suite_name] = suite_copy
        else:
            matching_suites = runner.find_matching_suites(files)

        # Handle --list flag (output JSON for AI agents)
        if args.list:
            runner.output_suites_json(matching_suites)
            return 0

        # Determine venv selection method
        if args.venv:
            # Use provided venvs (no interactive prompts)
            selected_venvs = runner.select_venvs_by_hash(matching_suites, args.venv)
            if not selected_venvs:
                print(f"❌ No venvs found matching hashes: {', '.join(args.venv)}")
                return 1
            print(f"📌 Selected {len(selected_venvs)} venv(s) from provided hashes")
        else:
            selected_venvs = runner.interactive_venv_selection(matching_suites)

        success = runner.run_tests(selected_venvs, matching_suites, riot_args=riot_args, dry_run=args.dry_run)
    except KeyboardInterrupt:
        print("\n⚠️  Interrupted by user", file=sys.stderr)
        return 130

    return 0 if success else 1


if __name__ == "__main__":
    # main() returns an int exit code; raising SystemExit hands it to the OS.
    raise SystemExit(main())