#!/usr/bin/env -S uv run --script
# -*- mode: python -*-
# /// script
# requires-python = ">=3.8"
# dependencies = [
#     "riot>=0.20.1",
#     "ruamel.yaml>=0.17.21",
# ]
# ///
"""Test runner script for dd-trace-py.

This script helps developers run the appropriate test suites based on changed files.
It maps source files to their corresponding test suites and provides granular control
over which specific riot venvs to run.

Note: This runs entire test suites, not individual test files.
"""

import argparse
import fnmatch
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Set, NamedTuple

# Project root: this script lives one directory below it (e.g. scripts/).
# The path is resolved so ROOT is absolute and follows symlinks; a relative ROOT
# breaks subprocess calls that use both `cwd=ROOT` and ROOT-relative executables.
ROOT = Path(__file__).resolve().parents[1]
# Make the project root and its tests/ directory importable (suitespec, riotfile).
# tests/ is inserted last so it ends up first on sys.path.
sys.path.insert(0, str(ROOT))
sys.path.insert(0, str(ROOT / "tests"))

from tests.suitespec import get_patterns, get_suites
import riotfile

# Local test agent endpoint; exported as DD_TRACE_AGENT_URL for snapshot suites.
TESTAGENT_URL = "http://localhost:9126"


class RiotVenv(NamedTuple):
    """Represents a riot venv with its metadata.

    Fields:
        number: index of the instance in ``riotfile.venv.instances()`` order.
        hash: short venv hash passed to ``riot run``.
        name: riot venv name, e.g. ``'django'`` or ``'django:celery'``.
        python_version: Python version label shown to the user.
        packages: ``'pkg: version'`` entries joined with ``', '`` (may be empty).
    """
    number: int
    hash: str
    name: str
    python_version: str
    packages: str

    def _normalize_package_name(self, name: str) -> List[str]:
        """Return candidate package names for one venv name component.

        The component itself always comes first, then its dash-separated
        spelling (common on PyPI), then any known distribution aliases.

        Examples:
            'django' -> ['django']
            'django_hosts' -> ['django_hosts', 'django-hosts']
            'psycopg2' -> ['psycopg2', 'psycopg2-binary']
        """
        variants = [name]

        # Try underscore to dash conversion (common in PyPI)
        if '_' in name:
            variants.append(name.replace('_', '-'))

        # Distribution names that differ from the venv name component.
        variations = {
            'psycopg2': ['psycopg2-binary'],
            'mysql': ['mysqlclient', 'mysql-connector-python'],
            'redis': ['redis-py'],
        }
        variants.extend(variations.get(name, []))

        return variants

    def _extract_package_versions_for_venv_name(self, venv_name: str, packages: str) -> List[str]:
        """Return ``'<package> <version>'`` labels for the packages named by *venv_name*.

        Each ``:``-separated component of *venv_name* is looked up in *packages*,
        trying the spellings from ``_normalize_package_name`` in order. The first
        spelling that is found wins, and components with no entry are omitted.
        A version containing a comma is truncated at the first comma.

        Args:
            venv_name: e.g. 'django', 'django:celery', 'flask:redis'
            packages: Package info string with comma-separated 'pkg: version' entries

        Returns:
            Labels such as ``['django ~=4.0', 'celery ~=5.3']``.
        """
        if not packages or packages == "standard packages":
            return []

        found_versions = []

        # Split venv name by ':' to get components (django:celery -> django, celery)
        for component in (comp.strip() for comp in venv_name.split(':')):
            if not component:
                # e.g. from a trailing ':'; an empty component names no package.
                continue

            for variant in self._normalize_package_name(component):
                # Anchor at the start of an entry (string start or after a comma).
                # A bare word boundary would let 'django' match inside 'pytest-django'.
                match = re.search(
                    rf"(?:^|,)\s*{re.escape(variant)}:\s*([^,]+)",
                    packages,
                    re.IGNORECASE,
                )
                if match:
                    found_versions.append(f"{variant} {match.group(1).strip()}")
                    break  # Found a match for this component, move to next

        return found_versions

    @property
    def display_name(self) -> str:
        """User-facing label: the Python version plus versions of the packages in the venv name."""
        package_versions = self._extract_package_versions_for_venv_name(self.name, self.packages)

        if package_versions:
            packages_str = ", ".join(package_versions)
            return f"Python {self.python_version}, {packages_str}"

        return f"Python {self.python_version}"


class TestRunner:
    """Map files to suitespec suites, let the user pick riot venvs, and run them."""

    def __init__(self):
        self.root = ROOT
        # NOTE(review): the three attributes below are not read by any method of this class.
        self.changed_files: Set[str] = set()
        self.matching_suites: Dict[str, dict] = {}
        self.required_services: Set[str] = set()

    def _run_git_command(self, cmd: List[str]) -> Set[str]:
        """Run a git command in the project root and return the paths it prints.

        The output is split on line breaks (one path per line), not on any
        whitespace, so file names that contain spaces stay intact.

        Raises:
            subprocess.CalledProcessError: git exited with a non-zero status.
            OSError: the git executable could not be started.
        """
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=self.root, check=True)
        return {line for line in result.stdout.splitlines() if line}

    def get_git_changed_files(self, base_ref: str = "HEAD") -> Set[str]:
        """Return the files changed relative to *base_ref*.

        For ``HEAD`` (the default) this is the union of staged, unstaged and
        untracked files. For any other ref it is ``git diff --name-only <ref>``.

        If any git call fails (bad ref, not a repository, git missing), a warning
        is printed and an empty set is returned.
        """
        try:
            if base_ref == "HEAD":
                # Get all local changes: staged, unstaged, and untracked
                staged = self._run_git_command(["git", "diff", "--cached", "--name-only"])
                unstaged = self._run_git_command(["git", "diff", "--name-only"])
                untracked = self._run_git_command(["git", "ls-files", "--others", "--exclude-standard"])
                return staged | unstaged | untracked
            # Get diff against specific ref
            return self._run_git_command(["git", "diff", "--name-only", base_ref])
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"Warning: Failed to get git changes: {e}")
            return set()

    def find_matching_suites(self, files: Set[str]) -> Dict[str, dict]:
        """Return copies of the suites whose suitespec patterns match any of *files*.

        Each returned config gets a ``matched_files`` list. The list has no
        duplicates and is ordered by pattern, then by file name. A suite whose
        patterns fail to load is skipped with a warning.
        """
        ordered_files = sorted(files)
        matching = {}

        for suite_name, suite_config in get_suites().items():
            try:
                patterns = get_patterns(suite_name)
                if not patterns:
                    continue

                # Check if any changed files match the suite patterns
                matches = []
                for pattern in patterns:
                    matches.extend(fnmatch.filter(ordered_files, pattern))
                # A file matched by several patterns is listed only once.
                matches = list(dict.fromkeys(matches))

                if matches:
                    matching[suite_name] = suite_config.copy()
                    matching[suite_name]['matched_files'] = matches

            except Exception as e:
                # Best effort: one broken suite must not hide the others.
                print(f"Warning: Error processing suite {suite_name}: {e}")

        return matching

    def extract_required_services(self, suites: Dict[str, dict]) -> Set[str]:
        """Return the union of ``services`` of *suites*, plus 'testagent' if any suite has ``snapshot`` set."""
        services = set()
        needs_testagent = False

        for suite_config in suites.values():
            services.update(suite_config.get('services', []))
            if suite_config.get('snapshot', False):
                needs_testagent = True

        if needs_testagent:
            services.add('testagent')

        return services

    def get_riot_patterns(self, suites: Dict[str, dict]) -> List[str]:
        """Return the sorted, de-duplicated riot venv names for the riot-runner suites.

        A suite without an explicit ``pattern`` uses its own name.
        """
        patterns = []
        seen = set()

        for suite_name, suite_config in suites.items():
            if suite_config.get('runner') == 'riot':
                pattern = suite_config.get('pattern', suite_name)
                for name in self._parse_riot_pattern(pattern):
                    if name not in seen:
                        patterns.append(name)
                        seen.add(name)

        # Sort patterns for consistent ordering
        return sorted(patterns)

    def _parse_riot_pattern(self, pattern: str) -> List[str]:
        """Parse a suitespec pattern to extract riot venv names.

        Examples:
            "^django$" -> ["django"]
            "^(django|django:celery)$" -> ["django", "django:celery"]
            "flask" -> ["flask"]

        NOTE(review): the '^...$' anchors are dropped here. ``get_riot_venvs``
        then compiles the bare name, which may match more venvs than the anchored
        suitespec pattern did. Confirm against riot's ``matches_pattern``.
        """
        # Remove regex anchors if present
        if pattern.startswith('^') and pattern.endswith('$'):
            pattern = pattern[1:-1]

        # Extract options from (option1|option2) patterns
        if pattern.startswith('(') and pattern.endswith(')'):
            return [opt.strip() for opt in pattern[1:-1].split('|')]

        return [pattern]

    def get_riot_venvs(self, pattern: str) -> List[RiotVenv]:
        """Return the named riot venv instances that riot's ``matches_pattern`` accepts for *pattern*.

        ``number`` is the instance's position in ``riotfile.venv.instances()``.
        Any failure (bad regex, riotfile error) prints a warning and returns ``[]``.
        """
        try:
            venvs = []
            pattern_regex = re.compile(pattern)

            # Use riot's own instances() method to get all venv instances
            for n, inst in enumerate(riotfile.venv.instances()):
                # Check if this instance matches our pattern (same logic as riot)
                if not inst.name or not inst.matches_pattern(pattern_regex):
                    continue

                # Keep every package; RiotVenv.display_name filters by venv name.
                pkgs = getattr(inst, 'pkgs', None)
                packages_info = ", ".join(f"{pkg}: {version}" for pkg, version in pkgs.items()) if pkgs else ""

                venvs.append(RiotVenv(
                    number=n,
                    # NOTE(review): the f"hash{n}" fallback is presumably not a hash riot accepts.
                    hash=inst.short_hash if hasattr(inst, 'short_hash') else f"hash{n}",
                    name=inst.name,
                    python_version=str(inst.py._hint) if hasattr(inst, 'py') and hasattr(inst.py, '_hint') else "3.10",
                    packages=packages_info
                ))

            return venvs

        except Exception as e:
            print(f"Warning: Failed to get riot venvs for pattern '{pattern}': {e}")
            return []

    def start_services(self, services: Set[str]) -> bool:
        """Start *services* with ``docker compose up -d`` in the project root.

        Returns:
            True if nothing needed starting or the command succeeded. False if
            docker failed or could not be launched.
        """
        if not services:
            return True

        ordered = sorted(services)
        print(f"\n🐳 Starting required services: {', '.join(ordered)}")
        cmd = ["docker", "compose", "up", "-d", *ordered]
        try:
            subprocess.run(cmd, cwd=self.root, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"❌ Failed to start services: {e}")
            return False
        return True

    def _parse_selection(self, selection: str, max_items: int) -> Set[int]:
        """Parse user selection string into set of indices.

        Handles formats like: '1', '1,3,5', '1-5', '2-4,7,9-11'.
        Returns 1-based indices clamped to ``1..max_items``. Empty parts (stray
        commas) are ignored.

        Raises:
            ValueError: a part is not an integer or a 'start-end' range.
        """
        indices = set()
        for part in selection.split(','):
            part = part.strip()
            if not part:
                continue
            if '-' in part:
                start, end = map(int, part.split('-'))
                # Clamp the start too: index 0 would otherwise select items[-1].
                indices.update(range(max(start, 1), min(end, max_items) + 1))
            else:
                idx = int(part)
                if 1 <= idx <= max_items:
                    indices.add(idx)
        return indices

    def _interactive_select(self, items: List, item_type: str, format_fn=None) -> List:
        """Generic interactive selection for any list of items.

        Accepts 'none', 'all', index lists/ranges (see ``_parse_selection``), and,
        for item_type "venvs", 'latest' (the last item). A single item is returned
        without prompting. End of input on stdin is treated as 'none'.

        Args:
            items: List of items to select from
            item_type: Description of what's being selected (e.g., "patterns", "venvs")
            format_fn: Optional function to format each item for display
        """
        if not items:
            return []

        # Single item - no need to select
        if len(items) == 1:
            return items

        print(f"\n🎯 Select which {item_type} to run:")
        print("=" * 60)

        for i, item in enumerate(items, 1):
            display = format_fn(item) if format_fn else str(item)
            print(f"{i:2d}. {display}")

        while True:
            print(f"\nSelect {item_type} (e.g., '1,3' for specific, 'all' for everything, or 'none'):")
            try:
                selection = input("> ").strip().lower()
            except EOFError:
                # stdin closed (e.g. non-interactive run): nothing selected.
                print()
                return []

            if selection == 'none':
                return []
            if selection == 'all':
                return items
            if selection == 'latest' and item_type == "venvs":
                return [items[-1]]  # Last item is usually latest version

            try:
                indices = self._parse_selection(selection, len(items))
            except ValueError:
                print("❌ Invalid selection. Please use format like '1,3', 'all', or 'none'")
                continue

            if indices:
                return [items[i - 1] for i in sorted(indices)]
            print("   No valid items selected")

    def select_riot_patterns(self, matching_suites: Dict[str, dict]) -> List[str]:
        """Let user select which riot patterns to run."""
        riot_patterns = self.get_riot_patterns(matching_suites)
        if not riot_patterns:
            print("❌ No riot patterns found in matching suites")
            return []

        print(f"\n📋 Found {len(riot_patterns)} matching riot pattern(s):")
        print("   " + ", ".join(riot_patterns))

        selected = self._interactive_select(riot_patterns, "patterns")
        if selected:
            print(f"   Selected patterns: {', '.join(selected)}")
        return selected

    def select_venvs_for_patterns(self, selected_patterns: List[str]) -> List[RiotVenv]:
        """Let user select specific venvs for each pattern; returns them in selection order."""

        def format_venv(v):
            return f"#{v.number:3d}  {v.hash}  {v.name}  {v.display_name}"

        selected_venvs = []

        for pattern in selected_patterns:
            print(f"\n🔍 Getting available venvs for pattern '{pattern}'...")
            venvs = self.get_riot_venvs(pattern)

            if not venvs:
                print(f"   ⚠️  No venvs found for pattern '{pattern}'")
                continue

            print(f"\n📋 Available venvs for '{pattern}' ({len(venvs)} total):")
            print("=" * 80)

            selected = self._interactive_select(venvs, "venvs", format_venv)

            if selected:
                selected_venvs.extend(selected)
                print(f"   Selected {len(selected)} venv(s) for '{pattern}':")
                for venv in selected:
                    print(f"     • {venv.name}: {venv.display_name}")

        return selected_venvs

    def interactive_venv_selection(self, matching_suites: Dict[str, dict]) -> List[RiotVenv]:
        """Provide interactive venv selection: first riot patterns, then venvs per pattern."""
        if not matching_suites:
            print("❌ No matching test suites found for the changed files.")
            return []

        selected_patterns = self.select_riot_patterns(matching_suites)
        if not selected_patterns:
            return []

        return self.select_venvs_for_patterns(selected_patterns)

    def run_tests(self, selected_venvs: List[RiotVenv], matching_suites: Dict[str, dict], riot_args: List[str] = None, dry_run: bool = False) -> bool:
        """Start required services and run each selected venv via ``scripts/ddtest riot run``.

        Venvs run sequentially, and the run stops at the first failure. With
        *dry_run*, the services and commands are printed and nothing is started.

        Args:
            selected_venvs: venvs to run, in order.
            matching_suites: suite configs used to derive services and snapshot mode.
            riot_args: extra arguments appended after ``--`` to the riot command.
            dry_run: only print what would be done.

        Returns:
            True if everything succeeded (or nothing was selected), else False.
        """
        if not selected_venvs:
            print("ℹ️  No venvs selected for execution.")
            return True

        services = self.extract_required_services(matching_suites)
        if services:
            if dry_run:
                print(f"[DRY RUN] Would start services: {', '.join(sorted(services))}")
            elif not self.start_services(services):
                return False

        print(f"\n🧪 Running {len(selected_venvs)} selected venv(s):")
        for venv in selected_venvs:
            print(f"   • {venv.name} - {venv.display_name}")

        # Snapshot suites report to the local test agent.
        needs_testagent = any(
            suite_config.get('snapshot', False)
            for suite_config in matching_suites.values()
        )

        env = os.environ.copy()
        if needs_testagent:
            env['DD_TRACE_AGENT_URL'] = TESTAGENT_URL
            print(f"🔧 Setting DD_TRACE_AGENT_URL={TESTAGENT_URL} for snapshot tests")

        for venv in selected_venvs:
            # Execute using ddtest with the specific venv hash
            cmd = [str(self.root / "scripts" / "ddtest"), "riot", "-v", "run", "--pass-env", venv.hash]
            if riot_args:
                cmd.extend(["--", *riot_args])

            if dry_run:
                print(f"[DRY RUN] Would execute: {' '.join(cmd)}")
                if needs_testagent:
                    print(f"[DRY RUN] With env: DD_TRACE_AGENT_URL={env.get('DD_TRACE_AGENT_URL')}")
                continue

            print(f"\n▶️  Executing ({venv.display_name}): {' '.join(cmd)}")
            try:
                result = subprocess.run(cmd, env=env, cwd=self.root)
            except OSError as e:
                # e.g. ddtest missing or not executable
                print(f"❌ Failed to run {venv.display_name}: {e}")
                return False

            if result.returncode != 0:
                print(f"❌ {venv.display_name} failed with exit code {result.returncode}")
                return False
            print(f"✅ {venv.display_name} completed successfully")

        if dry_run:
            print("\n[DRY RUN] No venvs were executed.")
        else:
            print("\n🎉 All selected venvs completed successfully!")
        return True


def main():
    """Command-line entry point.

    Determines the suites to consider (all of them, the files given on the
    command line, or git changes), lets the user pick riot venvs, and runs them.
    Arguments after a literal ``--`` are forwarded to riot.

    Returns:
        Exit code: 0 on success or when there is nothing to do, 1 on failure.
    """
    parser = argparse.ArgumentParser(
        description="Run test suites based on changed files",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run tests for all locally changed files
  scripts/run-tests

  # Run tests for specific files
  scripts/run-tests ddtrace/contrib/flask/patch.py tests/contrib/flask/test_flask.py

  # Run tests for changes since main branch
  scripts/run-tests --git-base=main

  # Show what would be run without executing
  scripts/run-tests --dry-run

  # Pass additional arguments to pytest
  scripts/run-tests ddtrace/contrib/django/patch.py -- -vvv -s --tb=short
        """
    )

    parser.add_argument(
        'files',
        nargs='*',
        help='Specific files to test (if not provided, uses git changes)'
    )

    parser.add_argument(
        '--git-base',
        default='HEAD',
        help='Git ref to compare against for changes (default: HEAD for local changes)'
    )

    parser.add_argument(
        '--dry-run',
        action='store_true',
        help='Show what would be run without executing'
    )

    parser.add_argument(
        '--all-suites',
        action='store_true',
        help='Show all available suites regardless of file changes'
    )

    # Everything after the first '--' belongs to riot, not to this script.
    if '--' in sys.argv:
        separator_idx = sys.argv.index('--')
        script_args = sys.argv[1:separator_idx]
        riot_args = sys.argv[separator_idx + 1:]
    else:
        script_args = sys.argv[1:]
        riot_args = []

    args = parser.parse_args(script_args)

    runner = TestRunner()

    if args.all_suites:
        # --all-suites takes precedence over explicit files and git changes.
        print("📁 Checking all available test suites")
        # Build new dicts so adding 'matched_files' never mutates get_suites()' data.
        matching_suites = {
            suite_name: {**suite_config, 'matched_files': []}
            for suite_name, suite_config in get_suites().items()
        }
    else:
        if args.files:
            files = set(args.files)
            print(f"📁 Checking explicitly provided files: {', '.join(sorted(files))}")
        else:
            files = runner.get_git_changed_files(args.git_base)
            if not files:
                print("ℹ️  No changed files found. Use --all-suites to see all available suites.")
                return 0
            ordered = sorted(files)
            print(f"📁 Found {len(files)} changed file(s): {', '.join(ordered[:5])}")
            if len(files) > 5:
                print(f"    ... and {len(files) - 5} more")

        matching_suites = runner.find_matching_suites(files)

    selected_venvs = runner.interactive_venv_selection(matching_suites)

    success = runner.run_tests(selected_venvs, matching_suites, riot_args=riot_args, dry_run=args.dry_run)

    return 0 if success else 1


if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())