"""
    lager.debug.commands

    Debug an elf file - Updated for direct SSH execution
"""
import itertools
import click
import json
import io
from contextlib import redirect_stdout
from ..context import get_default_gateway, get_impl_path
from ..python.commands import run_python_internal
from ..paramtypes import MemoryAddressType, HexArrayType, BinfileType
from ..dut_storage import get_dut_ip


def validate_speed_param(ctx, param, value):
    """
    Click callback that checks ``--speed`` as soon as it is parsed.

    Accepted values are ``None`` (option not given), the literal string
    ``'adaptive'``, or anything ``int()`` converts to a value from 1 to
    50000 (interpreted as kHz).

    Args:
        ctx: Click context (unused).
        param: Click parameter (unused).
        value: Raw speed value as supplied by the user.

    Returns:
        ``value`` itself, unchanged, when it is acceptable.

    Raises:
        click.BadParameter: If ``value`` is not an integer, is not positive,
            or exceeds 50000 kHz.
    """
    if value in (None, 'adaptive'):
        return value

    try:
        khz = int(value)
    except (ValueError, TypeError):
        raise click.BadParameter(
            f"Invalid speed value: '{value}'. "
            f"Speed must be a positive integer (in kHz) or 'adaptive'"
        )

    if khz < 1:
        message = (f"Invalid speed: {khz} kHz. "
                   f"Speed must be a positive integer greater than 0")
        raise click.BadParameter(message)

    # Anything above 50 MHz is unrealistically high for SWD/JTAG.
    if khz > 50000:
        message = (f"Invalid speed: {khz} kHz. "
                   f"Maximum supported speed is 50000 kHz (50 MHz). "
                   f"Typical speeds: 100-4000 kHz")
        raise click.BadParameter(message)

    return value

def _get_debug_net(ctx, dut, net_name=None):
    """
    Look up the debug net to use on the given gateway/DUT.

    Runs the ``net.py list`` implementation script against ``dut``, parses
    its captured stdout as JSON, and keeps only entries whose ``role`` is
    ``"debug"``.

    Args:
        ctx: Click context; used to run the script and to exit on error.
        dut: Gateway/DUT address passed through to ``run_python_internal``.
        net_name: Name of the debug net to select. When falsy, the first
            debug net in the listing is returned instead.

    Returns:
        The selected net record (a dict taken from the ``net.py list`` output).

    Exits:
        Prints an error to stderr and calls ``ctx.exit(1)`` when the listing
        is not a JSON list, no debug net exists, or ``net_name`` matches no
        debug net.
    """
    # Capture everything net.py prints so it can be parsed as JSON below.
    buf = io.StringIO()
    try:
        with redirect_stdout(buf):
            run_python_internal(
                ctx, get_impl_path("net.py"), dut,
                image="", env={}, passenv=(), kill=False, download=(),
                allow_overwrite=False, signum="SIGTERM", timeout=0,
                detach=False, port=(), org=None, args=("list",)
            )
    except SystemExit:
        # run_python_internal may raise SystemExit; whatever was captured
        # before that is still parsed below.
        pass

    # Keep the try body to the one call that can raise JSONDecodeError.
    try:
        nets = json.loads(buf.getvalue() or "[]")
    except json.JSONDecodeError:
        click.secho("Failed to parse nets information.", fg='red', err=True)
        ctx.exit(1)

    # A non-list payload (e.g. an error object) would otherwise crash on
    # ``.get`` below; treat it as unparseable output.
    if not isinstance(nets, list):
        click.secho("Failed to parse nets information.", fg='red', err=True)
        ctx.exit(1)

    debug_nets = [n for n in nets if isinstance(n, dict) and n.get("role") == "debug"]

    if net_name:
        # Find the specific debug net requested by the user.
        target_net = next((n for n in debug_nets if n.get("name") == net_name), None)
        if target_net is None:
            click.secho(f"Debug net '{net_name}' not found.", fg='red', err=True)
            ctx.exit(1)
        return target_net

    # No name given: fall back to the first debug net listed.
    if not debug_nets:
        click.secho("No debug nets found. Create one with: lager nets create <name> debug <device_type> <address>", fg='red', err=True)
        ctx.exit(1)
    return debug_nets[0]

class NetDebugGroup(click.MultiCommand):
    """Multi-command used by ``lager debug``; its first argument is a net name.

    Subcommands are resolved by name from a fixed table of the command
    objects defined in this module.
    """

    def _debug_subcommands(self):
        """Return the name -> command mapping, in help-listing order."""
        return {
            'connect': connect,
            'disconnect': disconnect,
            'flash': flash,
            'reset': reset,
            'erase': erase,
            'memrd': memrd,
            'info': info,
            'status': status,
            'rtt': rtt,
        }

    def list_commands(self, ctx):
        """Return the names of every debug subcommand."""
        return list(self._debug_subcommands())

    def get_command(self, ctx, name):
        """Return the command registered under ``name``, or None if unknown."""
        return self._debug_subcommands().get(name)

@click.command(name='debug', cls=NetDebugGroup)
@click.argument('net_name')
@click.pass_context
def _debug(ctx, net_name):
    """
    Debug firmware and manage debug sessions

    Usage: lager debug <NET_NAME> <subcommand> [options]

    Example: lager debug debug1 connect --dut TEST-2
    """
    # Store net_name on ctx.obj (the LagerContext) so subcommands can read it
    # back with getattr(ctx.obj, 'net_name', None). Assign unconditionally:
    # the net named on this command line is authoritative. A `net_name`
    # attribute that already exists (a default or a stale value from an
    # earlier invocation) must not override it.
    ctx.obj.net_name = net_name

@click.command()
@click.pass_context
@click.option('--dut', required=False, help='DUT or Gateway IP')
@click.option('--json', 'json_output', is_flag=True, default=False,
              help='Output results in JSON format')
def status(ctx, dut, json_output):
    """Get debug status"""
    # The parent `debug` group stores the selected net name on ctx.obj.
    net_name = getattr(ctx.obj, 'net_name', None)

    # --dut may name a saved DUT: use its stored IP when one exists, else the
    # value as given. Without --dut, fall back to the default gateway.
    if not dut:
        dut = get_default_gateway(ctx)
    else:
        dut = get_dut_ip(dut) or dut

    debug_net = _get_debug_net(ctx, dut, net_name)

    payload = json.dumps({
        'action': 'status',
        'net': debug_net,
        'json': json_output,
    })
    run_python_internal(
        ctx, get_impl_path('debug.py'), dut,
        image='', env={}, passenv=(), kill=False, download=(),
        allow_overwrite=False, signum='SIGTERM', timeout=0,
        detach=False, port=(), org=None,
        args=(payload,),
    )

@click.command()
@click.pass_context
@click.option('--dut', required=False, help='DUT or Gateway IP')
@click.option('--force/--no-force', is_flag=True, default=True,
              help='Disconnect debugger before reconnecting', show_default=True)
@click.option('--halt/--no-halt', is_flag=True, default=False,
              help='Halt the device when connecting', show_default=True)
@click.option('--speed', type=str, default=None, callback=validate_speed_param,
              help='SWD/JTAG speed in kHz (e.g., 100, 4000) or "adaptive"')
@click.option('--quiet', is_flag=True, default=False,
              help='Suppress informational messages')
@click.option('--json', 'json_output', is_flag=True, default=False,
              help='Output results in JSON format')
@click.option('--rtt', is_flag=True, default=False,
              help='Automatically stream RTT logs after connecting')
@click.option('--rtt-reset', is_flag=True, default=False,
              help='Connect, reset device, then stream RTT logs (captures boot sequence)')
@click.option('--reset', is_flag=True, default=False,
              help='Reset the device after connecting')
def connect(ctx, dut, force, halt, speed, quiet, json_output, rtt, rtt_reset, reset):
    """Connect to debugger"""
    # The parent `debug` group stores the selected net name on ctx.obj.
    net_name = getattr(ctx.obj, 'net_name', None)

    # --dut may name a saved DUT: use its stored IP when one exists, else the
    # value as given. Without --dut, fall back to the default gateway.
    if not dut:
        dut = get_default_gateway(ctx)
    else:
        dut = get_dut_ip(dut) or dut

    debug_net = _get_debug_net(ctx, dut, net_name)

    # No disconnect is done here. The original notes say debug.py handles
    # --force itself (disconnect then reconnect; --no-force reuses an existing
    # connection). The post-connect flags are forwarded for debug.py to act on
    # after connecting.
    payload = {
        'action': 'connect',
        'net': debug_net,
        'force': force,
        'halt': halt,
        'device_type': debug_net.get('pin', 'unknown'),
        'speed': speed,
        'quiet': quiet,
        'json': json_output,
        'post_connect_rtt': rtt,
        'post_connect_reset': reset,
        'post_connect_rtt_reset': rtt_reset,
    }

    run_python_internal(
        ctx, get_impl_path('debug.py'), dut,
        image='', env={}, passenv=(), kill=False, download=(),
        allow_overwrite=False, signum='SIGTERM', timeout=0,
        detach=False, port=(), org=None,
        args=(json.dumps(payload),),
    )

@click.command()
@click.pass_context
@click.option('--dut', required=False, help='DUT or Gateway IP')
def disconnect(ctx, dut):
    """Disconnect from debugger"""
    # The parent `debug` group stores the selected net name on ctx.obj.
    net_name = getattr(ctx.obj, 'net_name', None)

    # --dut may name a saved DUT: use its stored IP when one exists, else the
    # value as given. Without --dut, fall back to the default gateway.
    if not dut:
        dut = get_default_gateway(ctx)
    else:
        dut = get_dut_ip(dut) or dut

    payload = json.dumps({
        'action': 'disconnect',
        'net': _get_debug_net(ctx, dut, net_name),
    })
    run_python_internal(
        ctx, get_impl_path('debug.py'), dut,
        image='', env={}, passenv=(), kill=False, download=(),
        allow_overwrite=False, signum='SIGTERM', timeout=0,
        detach=False, port=(), org=None,
        args=(payload,),
    )

def _scp_to_gateway(ctx, dut, local_path):
    """
    Copy a local file to ``/tmp`` on the gateway and make it world-readable.

    Args:
        ctx: Click context, used to exit on failure.
        dut: Gateway address; the copy is made as ``lagerdata@<dut>``.
        local_path: Path of the local file to transfer.

    Returns:
        The remote path, ``/tmp/<basename of local_path>``.
    """
    import os
    import shlex
    import subprocess

    filename = os.path.basename(local_path)
    remote_path = f'/tmp/{filename}'
    target = f'lagerdata@{dut}'

    # Transfer quietly. On failure, report scp's stderr (or the launch error,
    # e.g. scp not on PATH) and exit.
    # NOTE(review): the scp destination path is not shell-quoted; with legacy
    # (non-SFTP) scp, names containing spaces may still fail.
    try:
        result = subprocess.run(['scp', '-q', local_path, f'{target}:{remote_path}'],
                                capture_output=True, text=True)
        error = result.stderr if result.returncode != 0 else None
    except OSError as exc:
        error = str(exc)
    if error is not None:
        click.secho(f'Failed to transfer file: {error}', fg='red', err=True)
        ctx.exit(1)

    # Make the file readable by the container. ssh passes this command string
    # to the remote shell, so the path is quoted to survive spaces and shell
    # metacharacters. This is best-effort: the outcome is not checked.
    try:
        subprocess.run(['ssh', '-o', 'LogLevel=ERROR', target,
                        f'chmod 644 {shlex.quote(remote_path)}'],
                       capture_output=True)
    except OSError:
        pass
    return remote_path


@click.command()
@click.pass_context
@click.option('--dut', required=False, help='DUT or Gateway IP')
@click.option('--hex', type=click.Path(exists=True))
@click.option('--elf', type=click.Path(exists=True))
@click.option('--bin', multiple=True, type=BinfileType(exists=True))
def flash(ctx, dut, hex, elf, bin):
    """Flash firmware to target"""
    import os

    # Fail fast, before contacting the gateway, when no image was supplied.
    if not (hex or elf or bin):
        click.secho('Provide --hex, --elf, or --bin.', fg='red', err=True)
        ctx.exit(1)

    # Get net_name from parent context
    net_name = getattr(ctx.obj, 'net_name', None)

    # Resolve DUT name to IP if needed; default to the gateway otherwise.
    if dut:
        dut = get_dut_ip(dut) or dut
    else:
        dut = get_default_gateway(ctx)

    debug_net = _get_debug_net(ctx, dut, net_name)
    device_type = debug_net.get('pin', 'unknown')

    flash_args = {
        'action': 'flash',
        'net': debug_net,
        'device_type': device_type,
        'hexfile': None,
        'elffile': None,
        'binfiles': [],
        'stream_logs': False
    }

    # Only one image kind is flashed. Precedence is --hex, then --elf, then --bin.
    if hex:
        flash_args['hexfile'] = _scp_to_gateway(ctx, dut, hex)
    elif elf:
        flash_args['elffile'] = _scp_to_gateway(ctx, dut, elf)
    else:
        # Every file lands in the same remote directory, so duplicate basenames
        # would silently overwrite one another.
        names = [os.path.basename(bf.path) for bf in bin]
        if len(set(names)) != len(names):
            click.secho('Binary files passed with --bin must have distinct file names.',
                        fg='red', err=True)
            ctx.exit(1)
        for bf in bin:
            flash_args['binfiles'].append({
                'filename': _scp_to_gateway(ctx, dut, bf.path),
                'address': bf.address
            })

    run_python_internal(
        ctx, get_impl_path('debug.py'), dut,
        image='', env={}, passenv=(), kill=False, download=(),
        allow_overwrite=False, signum='SIGTERM', timeout=0,
        detach=False, port=(), org=None,
        args=(json.dumps(flash_args),)
    )

@click.command()
@click.pass_context
@click.option('--dut', required=False, help='DUT or Gateway IP')
@click.option('--speed', type=str, default='4000', callback=validate_speed_param,
              help='SWD/JTAG speed in kHz (default: 4000)')
@click.option('--yes', is_flag=True, default=False,
              help='Skip confirmation prompt')
@click.option('--quiet', is_flag=True, default=False,
              help='Suppress warning messages')
@click.option('--json', 'json_output', is_flag=True, default=False,
              help='Output results in JSON format')
def erase(ctx, dut, speed, yes, quiet, json_output):
    """Erase all flash memory on target"""
    # The parent `debug` group stores the selected net name on ctx.obj.
    net_name = getattr(ctx.obj, 'net_name', None)

    # --dut may name a saved DUT: use its stored IP when one exists, else the
    # value as given. Without --dut, fall back to the default gateway.
    if not dut:
        dut = get_default_gateway(ctx)
    else:
        dut = get_dut_ip(dut) or dut

    debug_net = _get_debug_net(ctx, dut, net_name)
    device_type = debug_net.get('pin', 'unknown')

    # Interactive confirmation, skipped by --yes, --quiet, or --json.
    if not (yes or quiet or json_output):
        click.echo(f"WARNING: This will erase ALL flash memory on {device_type}")
        click.echo("This operation cannot be undone!")
        confirmed = click.confirm("Do you want to continue?")
        if not confirmed:
            click.echo("Chip erase cancelled.")
            ctx.exit(0)

    payload = {
        'action': 'erase',
        'net': debug_net,
        'device_type': device_type,
        'speed': speed,
        'transport': 'SWD',
        'quiet': quiet,
        'json': json_output,
    }

    run_python_internal(
        ctx, get_impl_path('debug.py'), dut,
        image='', env={}, passenv=(), kill=False, download=(),
        allow_overwrite=False, signum='SIGTERM', timeout=0,
        detach=False, port=(), org=None,
        args=(json.dumps(payload),),
    )

@click.command()
@click.pass_context
@click.option('--dut', required=False, help='DUT or Gateway IP')
@click.option('--halt/--no-halt', is_flag=True, default=False,
              help='Halt the DUT after reset', show_default=True)
def reset(ctx, dut, halt):
    """Reset target"""
    # The parent `debug` group stores the selected net name on ctx.obj.
    net_name = getattr(ctx.obj, 'net_name', None)

    # --dut may name a saved DUT: use its stored IP when one exists, else the
    # value as given. Without --dut, fall back to the default gateway.
    if not dut:
        dut = get_default_gateway(ctx)
    else:
        dut = get_dut_ip(dut) or dut

    debug_net = _get_debug_net(ctx, dut, net_name)

    payload = json.dumps({
        'action': 'reset',
        'net': debug_net,
        'device_type': debug_net.get('pin', 'unknown'),
        'halt': halt,
    })
    run_python_internal(
        ctx, get_impl_path('debug.py'), dut,
        image='', env={}, passenv=(), kill=False, download=(),
        allow_overwrite=False, signum='SIGTERM', timeout=0,
        detach=False, port=(), org=None,
        args=(payload,),
    )

@click.command()
@click.pass_context
@click.argument('start_addr', type=MemoryAddressType())
@click.argument('length', type=MemoryAddressType())
@click.option('--dut', required=False, help='DUT or Gateway IP')
@click.option('--json', 'json_output', is_flag=True, default=False,
              help='Output results in JSON format')
def memrd(ctx, start_addr, length, dut, json_output):
    """Read memory from target"""
    # The parent `debug` group stores the selected net name on ctx.obj.
    net_name = getattr(ctx.obj, 'net_name', None)

    # --dut may name a saved DUT: use its stored IP when one exists, else the
    # value as given. Without --dut, fall back to the default gateway.
    if not dut:
        dut = get_default_gateway(ctx)
    else:
        dut = get_dut_ip(dut) or dut

    debug_net = _get_debug_net(ctx, dut, net_name)

    # start_addr/length are forwarded exactly as MemoryAddressType parsed them.
    payload = dict(
        action='memrd',
        net=debug_net,
        device_type=debug_net.get('pin', 'unknown'),
        start_addr=start_addr,
        length=length,
        json=json_output,
    )

    run_python_internal(
        ctx, get_impl_path('debug.py'), dut,
        image='', env={}, passenv=(), kill=False, download=(),
        allow_overwrite=False, signum='SIGTERM', timeout=0,
        detach=False, port=(), org=None,
        args=(json.dumps(payload),),
    )

# Note: gdbserver command removed as it relies on WebSocket tunneling
# For direct debugging, users should use gdb directly with the J-Link GDB server
# running on the gateway, which can be accessed via SSH port forwarding if needed

@click.command()
@click.pass_context
@click.option('--dut', required=False, help='DUT or Gateway IP')
def info(ctx, dut):
    """Show debug net information"""
    # The parent `debug` group stores the selected net name on ctx.obj.
    net_name = getattr(ctx.obj, 'net_name', None)

    # --dut may name a saved DUT: use its stored IP when one exists, else the
    # value as given. Without --dut, fall back to the default gateway.
    if not dut:
        dut = get_default_gateway(ctx)
    else:
        dut = get_dut_ip(dut) or dut

    debug_net = _get_debug_net(ctx, dut, net_name)

    # Locally known fields first, then ask debug.py for device details.
    fields = (
        ('Name', debug_net.get('name')),
        ('Device Type', debug_net.get('pin', 'unknown')),
        ('Instrument', debug_net.get('instrument')),
        ('Address', debug_net.get('address')),
    )
    click.echo("Debug Net Information:")
    for label, value in fields:
        click.echo(f"  {label}: {value}")
    click.echo()

    payload = json.dumps({'action': 'info', 'net': debug_net})
    run_python_internal(
        ctx, get_impl_path('debug.py'), dut,
        image='', env={}, passenv=(), kill=False, download=(),
        allow_overwrite=False, signum='SIGTERM', timeout=0,
        detach=False, port=(), org=None,
        args=(payload,),
    )

@click.command()
@click.pass_context
@click.option('--dut', required=False, help='DUT or Gateway IP')
@click.option('--timeout', type=int, default=None, help='Read timeout in seconds (default: continuous until Ctrl+C)')
@click.option('--channel', type=int, default=0, help='RTT channel (0 or 1)', show_default=True)
@click.option('-e', '--elf', type=click.Path(exists=True), help='ELF file for defmt-print decoding')
def rtt(ctx, dut, timeout, channel, elf):
    """Stream RTT logs from target"""
    import sys
    import subprocess
    import shutil

    # Get net_name from parent context
    net_name = getattr(ctx.obj, 'net_name', None)

    # Resolve DUT name to IP if needed; default to the gateway otherwise.
    if dut:
        dut = get_dut_ip(dut) or dut
    else:
        dut = get_default_gateway(ctx)

    debug_net = _get_debug_net(ctx, dut, net_name)

    # The user's read timeout goes to debug.py, not to run_python_internal.
    rtt_args = {
        'action': 'rtt',
        'net': debug_net,
        'timeout': timeout,
        'channel': channel
    }

    # With -e, pipe everything written to sys.stdout into defmt-print.
    defmt_process = None
    original_stdout = None
    if elf:
        if not shutil.which('defmt-print'):
            click.secho('Error: defmt-print not found. Install it with: cargo install defmt-print', fg='red', err=True)
            ctx.exit(1)

        defmt_process = subprocess.Popen(
            ['defmt-print', '-e', elf],
            stdin=subprocess.PIPE,
            stdout=sys.stdout,
            stderr=sys.stderr
        )

        # NOTE(review): Popen's stdin pipe is a binary stream, so this assumes
        # the RTT output path writes bytes (not str) to sys.stdout. Confirm.
        original_stdout = sys.stdout
        sys.stdout = defmt_process.stdin

    try:
        # Execute implementation script (same pattern as other debug commands)
        run_python_internal(
            ctx, get_impl_path('debug.py'), dut,
            image='', env={}, passenv=(), kill=False, download=(),
            allow_overwrite=False, signum='SIGTERM', timeout=0,
            detach=False, port=(), org=None,
            args=(json.dumps(rtt_args),)
        )
    finally:
        # Restore stdout and shut down defmt-print if we started it.
        if defmt_process is not None:
            sys.stdout = original_stdout
            try:
                # EOF on stdin lets defmt-print flush and exit on its own.
                defmt_process.stdin.close()
            except OSError:
                # defmt-print already exited; flushing into the closed pipe fails.
                pass
            try:
                defmt_process.wait(timeout=2)
            except (subprocess.TimeoutExpired, KeyboardInterrupt):
                # Didn't exit in time, or the user hit Ctrl+C again while waiting.
                defmt_process.kill()
                defmt_process.wait()
