"""WebSocket tunnel client with auto-reconnection."""

import asyncio
import base64
import json
import logging
from asyncio import Future
from typing import Dict, Any, Optional
from urllib.parse import quote

import aiohttp
import websockets
from websockets.asyncio.client import ClientConnection

logger = logging.getLogger(__name__)


class TCPTunnel:
    """Represents a TCP tunnel through WebSocket.

    Pairs one local TCP connection (``reader``/``writer``) with a backend
    ``tunnel_id``. :meth:`forward_to_websocket` pumps bytes from the local
    socket to the backend as base64-encoded ``tunnel_data`` messages; the
    owning ``tunnel_client`` writes backend data to ``writer`` itself.
    """

    def __init__(self, tunnel_id: str, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, tunnel_client: 'TunnelClient'):
        self.tunnel_id = tunnel_id
        self.reader = reader
        self.writer = writer
        self.tunnel_client = tunnel_client
        # Cleared by close(); ends the forwarding loop after its current read.
        self.running = True
        # Set by the first close() so teardown and the tunnel_closed notice
        # happen once, even though both the forwarding task's cleanup and the
        # client may call close() for the same tunnel.
        self._closed = False

    async def forward_to_websocket(self):
        """Read from TCP and forward to WebSocket until EOF, error, or close().

        Always finishes by calling :meth:`close`. Cancellation is logged and
        re-raised so the task is reported as cancelled.
        """
        try:
            while self.running:
                data = await self.reader.read(4096)
                if not data:
                    logger.debug(f"TCP connection closed for tunnel {self.tunnel_id}")
                    break

                # Send to WebSocket
                await self.tunnel_client.send_response('tunnel_data', {
                    'tunnel_id': self.tunnel_id,
                    'data': base64.b64encode(data).decode('ascii')
                })
        except asyncio.CancelledError:
            logger.debug(f"Tunnel {self.tunnel_id} forwarding cancelled")
            raise
        except Exception as e:
            logger.error(f"Tunnel {self.tunnel_id} error: {e}")
        finally:
            await self.close()

    async def close(self):
        """Close the TCP connection and notify the backend; safe to call repeatedly.

        Only the first call has an effect:
        - it stops forwarding;
        - it removes this tunnel from ``tunnel_client.tcp_tunnels`` if it is
          still registered under its id;
        - it closes the socket and sends ``tunnel_closed``.

        Socket and notification errors are logged rather than raised.
        """
        if self._closed:
            return
        self._closed = True
        self.running = False

        # Unregister so a tunnel that ends on its own (e.g. local EOF) does not
        # linger in the client's map. Skip if the id now maps to another tunnel.
        tunnels = self.tunnel_client.tcp_tunnels
        if tunnels.get(self.tunnel_id) is self:
            del tunnels[self.tunnel_id]

        try:
            self.writer.close()
            await self.writer.wait_closed()
        except Exception as e:
            logger.error(f"Error closing tunnel {self.tunnel_id}: {e}")

        # Notify WebSocket that tunnel is closed
        try:
            await self.tunnel_client.send_response('tunnel_closed', {
                'tunnel_id': self.tunnel_id
            })
        except Exception as e:
            logger.debug(f"Could not send tunnel_closed for {self.tunnel_id}: {e}")


class TunnelClient:
    """WebSocket tunnel client that forwards requests to localhost.

    Maintains one WebSocket to the tunnel backend, reconnecting with
    exponential backoff. Handles the backend's messages:
    - HTTP requests proxied to local services;
    - service-manager commands;
    - status reports;
    - raw TCP tunnels.

    HTTP requests and TCP tunnels may only target ``localhost`` or
    ``127.0.0.1``.
    """

    # Largest HTTP response body relayed back, measured as len() of the
    # decoded text.
    MAX_RESPONSE_SIZE = 10 * 1024 * 1024
    # Hosts an ``http_request`` target may name.
    _LOCAL_HOSTS = ('localhost', '127.0.0.1')

    def __init__(self, token: str, service_manager=None):
        self.token = token
        self.service_manager = service_manager
        self.ws: Optional[ClientConnection] = None
        self.pending_requests: Dict[str, Future[Any]] = {}  # request_id -> Future
        self.tcp_tunnels: Dict[str, TCPTunnel] = {}  # tunnel_id -> TCPTunnel
        self.retry_count = 0
        self.max_retry_delay = 30
        self.running = True
        # Strong references to fire-and-forget tasks. The event loop keeps only
        # weak references, so an unreferenced task can be garbage-collected
        # before it finishes.
        self._background_tasks = set()

    def _spawn(self, coro):
        """Run *coro* as a background task and keep it referenced until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def connect(self):
        """Connect and maintain WebSocket tunnel with auto-reconnect.

        Loops until :meth:`stop` clears ``running``. Whenever a connection ends,
        whether with an error or a clean close, :meth:`_handle_disconnect`
        cleans up and applies the backoff delay before the next attempt.
        """
        # Percent-encode the token so characters such as '+', '&' or '=' do
        # not corrupt the query string.
        uri = f"wss://api.mcpbundles.com/v1/tunnel/connect?token={quote(self.token, safe='')}"

        while self.running:
            try:
                logger.info("Connecting to tunnel...")
                async with websockets.connect(
                    uri,
                    ping_interval=20,
                    ping_timeout=10
                ) as ws:
                    self.ws = ws
                    self.retry_count = 0  # Reset on successful connect
                    logger.info("✅ Tunnel connected")

                    await self.listen()
                # listen() returns without raising on a clean close. Clean up
                # and back off here too, rather than reconnecting in a tight loop.
                logger.info("Tunnel connection closed")
            except websockets.exceptions.WebSocketException as e:
                logger.error(f"WebSocket error: {e}")
            except Exception as e:
                logger.error(f"Connection error: {e}")
            await self._handle_disconnect()

    async def _handle_disconnect(self):
        """Clean up and prepare for reconnect.

        - Forgets the dead WebSocket.
        - Fails all pending request futures.
        - Closes every TCP tunnel.
        - If the client is still running, sleeps for an exponential backoff
          delay (1s, 2s, 4s, ... capped at ``max_retry_delay``).
        """
        # Drop the dead socket first so the send_* helpers, including the
        # tunnel_closed notices below, no-op instead of erroring.
        self.ws = None

        # Fail all pending requests
        for future in self.pending_requests.values():
            if not future.done():
                future.set_exception(Exception("Tunnel disconnected"))
        self.pending_requests.clear()

        # Tunnel ids were issued over the old connection. This assumes the
        # backend does not resume them after a reconnect — TODO confirm.
        stale_tunnels = list(self.tcp_tunnels.values())
        self.tcp_tunnels.clear()
        for tunnel in stale_tunnels:
            await tunnel.close()

        if not self.running:
            return

        delay = min(2 ** self.retry_count, self.max_retry_delay)
        logger.info(f"Reconnecting in {delay}s...")
        await asyncio.sleep(delay)
        self.retry_count += 1

    async def listen(self):
        """Dispatch incoming tunnel messages until the WebSocket closes.

        Each recognised message type is handled in its own background task,
        so a slow handler does not block the receive loop. Invalid JSON,
        unknown types and dispatch errors are logged and skipped.
        """
        handlers = {
            'http_request': self.handle_request,
            'service_discovery': self.handle_service_discovery,
            'service_verify': self.handle_service_verify,
            'service_start': self.handle_service_start,
            'service_stop': self.handle_service_stop,
            'service_update': self.handle_service_update,
            'config_update': self.handle_config_update,  # legacy
            'status_request': lambda _data: self.send_status(),
            'tunnel_open': self.handle_tunnel_open,
            'tunnel_data': self.handle_tunnel_data,
            'tunnel_close': self.handle_tunnel_close,
        }

        async for message in self.ws:
            try:
                data = json.loads(message)
                msg_type = data.get('type')
                handler = handlers.get(msg_type) if isinstance(msg_type, str) else None
                if handler is None:
                    logger.warning(f"Unknown message type: {msg_type}")
                    continue
                self._spawn(handler(data))
            except json.JSONDecodeError as e:
                logger.error(f"Invalid JSON received: {e}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")

    @classmethod
    def _is_local_target(cls, target) -> bool:
        """Return True only for exactly ``localhost:<port>`` or ``127.0.0.1:<port>``.

        <port> must be a decimal number in 1..65535.
        """
        if not isinstance(target, str):
            return False
        host, sep, port = target.partition(':')
        if not sep or host not in cls._LOCAL_HOSTS:
            return False
        if not (port.isascii() and port.isdigit()):
            return False
        return 0 < int(port) <= 65535

    @staticmethod
    def _is_safe_request_path(path) -> bool:
        """Return True if appending *path* to ``http://host:port`` cannot change the host.

        Accepts the empty string or a value starting with '/' or '?'.
        """
        return isinstance(path, str) and (path == '' or path[0] in '/?')

    async def handle_request(self, request):
        """Forward HTTP request to localhost service.

        Reads ``request_id``, ``target`` ("host:port"), ``path`` and ``method``
        from *request*, plus optional ``headers`` and ``body``. Replies with an
        ``http_response`` message carrying the local service's status,
        headers and text body. Rejections and failures are reported through
        :meth:`send_error`.
        """
        request_id = request['request_id']
        target = request['target']
        path = request['path']

        # Validate localhost only. The host must match exactly, and the path
        # must start the URL's path or query. Otherwise a target such as
        # "localhost:80@evil.com", or a path such as "@evil.com/", makes
        # "localhost:80" parse as userinfo and the request goes to another host.
        if not self._is_local_target(target):
            await self.send_error(request_id, 'Only localhost targets allowed')
            return
        if not self._is_safe_request_path(path):
            await self.send_error(request_id, 'Invalid request path')
            return

        url = f"http://{target}{path}"

        try:
            logger.debug(f"Forwarding {request['method']} {url}")

            async with aiohttp.ClientSession() as session:
                async with session.request(
                    method=request['method'],
                    url=url,
                    headers=request.get('headers', {}),
                    data=request.get('body'),
                    timeout=aiohttp.ClientTimeout(total=30)
                ) as response:
                    # The whole body is read and decoded before the size check.
                    body = await response.text()
                    if len(body) > self.MAX_RESPONSE_SIZE:
                        await self.send_error(
                            request_id,
                            'Response exceeds 10MB limit'
                        )
                        return

                    # Send response back through tunnel
                    await self.send_response('http_response', {
                        'request_id': request_id,
                        'status': response.status,
                        'headers': dict(response.headers),
                        'body': body
                    })
                    logger.debug(f"Response sent for {request_id}")

        except asyncio.TimeoutError:
            logger.warning(f"Request timeout for {request_id}")
            await self.send_error(request_id, 'Request timeout (30s)')
        except aiohttp.ClientError as e:
            logger.error(f"HTTP error for {request_id}: {e}")
            await self.send_error(request_id, f'Connection failed: {str(e)}')
        except Exception as e:
            logger.error(f"Unexpected error for {request_id}: {e}")
            await self.send_error(request_id, f'Internal error: {str(e)}')

    async def send_error(self, request_id: str, error_message: str):
        """Send an ``http_response`` with status 500 and a JSON ``{"error": ...}`` body.

        Does nothing when there is no WebSocket; send failures are logged.
        """
        try:
            if self.ws:
                await self.ws.send(json.dumps({
                    'type': 'http_response',
                    'request_id': request_id,
                    'status': 500,
                    'headers': {},
                    'body': json.dumps({'error': error_message})
                }))
        except Exception as e:
            logger.error(f"Failed to send error response: {e}")

    async def handle_service_discovery(self, data: dict):
        """Handle service discovery request from backend.

        Scans ``data['ports']`` with the service manager and replies with a
        ``discovery_result`` message. The reply carries either ``ports`` or
        ``error``.
        """
        try:
            logger.info("Received service discovery request")
            ports = data.get('ports', [])

            if not self.service_manager:
                logger.warning("No service manager configured")
                await self.send_response('discovery_result', {
                    'error': 'Service manager not available'
                })
                return

            # Discover services
            results = await self.service_manager.discover_services(ports)

            # Send results back
            await self.send_response('discovery_result', {
                'ports': results
            })
            logger.info(f"Discovery complete: {len(results)} ports scanned")

        except Exception as e:
            logger.error(f"Error handling service discovery: {e}")
            await self.send_response('discovery_result', {
                'error': str(e)
            })

    async def handle_service_verify(self, data: dict):
        """Handle service verification request from backend.

        Replies with a ``verify_result`` message. The payload is the service
        manager's result, or ``verified: False`` plus ``error`` on failure.
        """
        try:
            target = data.get('target')
            logger.info(f"Received service verify request for {target}")

            if not self.service_manager:
                logger.warning("No service manager configured")
                await self.send_response('verify_result', {
                    'verified': False,
                    'error': 'Service manager not available'
                })
                return

            # Verify service
            result = await self.service_manager.verify_service(target)

            # Send result back
            await self.send_response('verify_result', result)
            logger.info(f"Verification complete for {target}: {result.get('verified')}")

        except Exception as e:
            logger.error(f"Error handling service verification: {e}")
            await self.send_response('verify_result', {
                'verified': False,
                'error': str(e)
            })

    async def _handle_service_action(self, action: str, response_type: str, flag: str, service, invoke):
        """Run one service-manager command and report the outcome to the backend.

        Args:
            action: Verb used in log messages ("start", "stop", "update").
            response_type: Message type of the reply, e.g. ``service_started``.
            flag: Result key set to False in error replies (e.g. ``started``);
                also read from the manager's result for logging.
            service: Service name from the request; used here only for logging.
            invoke: Callable taking the service manager and returning an
                awaitable whose result is sent back as the reply payload.
        """
        try:
            logger.info(f"Received service {action} request for {service}")

            if not self.service_manager:
                logger.warning("No service manager configured")
                await self.send_response(response_type, {
                    flag: False,
                    'error': 'Service manager not available'
                })
                return

            result = await invoke(self.service_manager)
            await self.send_response(response_type, result)
            logger.info(f"Service {action} complete for {service}: {result.get(flag)}")

        except Exception as e:
            logger.error(f"Error handling service {action}: {e}")
            await self.send_response(response_type, {
                flag: False,
                'error': str(e)
            })

    async def handle_service_start(self, data: dict):
        """Handle service start request from backend (replies ``service_started``)."""
        service = data.get('service')
        config = data.get('config', {})
        await self._handle_service_action(
            'start', 'service_started', 'started', service,
            lambda manager: manager.start_service(service, config),
        )

    async def handle_service_stop(self, data: dict):
        """Handle service stop request from backend (replies ``service_stopped``)."""
        service = data.get('service')
        await self._handle_service_action(
            'stop', 'service_stopped', 'stopped', service,
            lambda manager: manager.stop_service(service),
        )

    async def handle_service_update(self, data: dict):
        """Handle service update request from backend (replies ``service_updated``)."""
        service = data.get('service')
        config = data.get('config', {})
        await self._handle_service_action(
            'update', 'service_updated', 'updated', service,
            lambda manager: manager.update_service(service, config),
        )

    async def send_response(self, response_type: str, data: dict):
        """Send ``{'type': response_type, **data}`` back through the tunnel.

        Does nothing when there is no WebSocket; errors are logged, not raised.
        """
        try:
            if not self.ws:
                return

            message = {
                'type': response_type,
                **data
            }

            await self.ws.send(json.dumps(message))
            logger.debug(f"Sent {response_type} response")

        except Exception as e:
            logger.error(f"Error sending response: {e}")

    async def handle_config_update(self, data: dict):
        """Handle configuration update from backend (legacy).

        Passes ``data['config']`` to the service manager and then sends a
        status report. Errors are logged only.
        """
        try:
            logger.info("Received config update from backend")
            config = data.get('config', {})

            if self.service_manager:
                await self.service_manager.update_config(config)
                logger.info("Services updated successfully")

                # Send status update back
                await self.send_status()
            else:
                logger.warning("No service manager configured")

        except Exception as e:
            logger.error(f"Error handling config update: {e}")

    async def send_status(self):
        """Send a ``status_update`` message listing the service manager's status.

        ``services`` is empty when no service manager is configured.
        """
        try:
            if not self.ws:
                return

            status = {
                'type': 'status_update',
                'services': {}
            }

            if self.service_manager:
                status['services'] = self.service_manager.get_status()

            await self.ws.send(json.dumps(status))
            logger.debug("Status update sent to backend")

        except Exception as e:
            logger.error(f"Error sending status: {e}")

    async def stop(self):
        """Gracefully stop the tunnel.

        Ends the reconnect loop, closes all TCP tunnels and closes the
        WebSocket.
        """
        logger.info("Stopping tunnel...")
        self.running = False

        # Close all TCP tunnels
        for tunnel in list(self.tcp_tunnels.values()):
            await tunnel.close()
        self.tcp_tunnels.clear()

        if self.ws:
            await self.ws.close()

    async def handle_tunnel_open(self, data: dict):
        """Open a TCP tunnel to local service.

        On success, registers the tunnel, starts forwarding local data to the
        backend, and replies ``tunnel_ready``. Otherwise replies
        ``tunnel_error`` with a reason.
        """
        tunnel_id = data.get('tunnel_id')
        target = data.get('target')

        if not tunnel_id:
            logger.warning("Received tunnel_open without tunnel_id")
            return

        logger.info(f"Opening tunnel {tunnel_id} to {target}")

        try:
            # Parse target first
            if not target or ':' not in target:
                await self.send_response('tunnel_error', {
                    'tunnel_id': tunnel_id,
                    'error': 'Invalid target format. Expected "host:port"'
                })
                return

            # Validate localhost only (the colon in the prefix pins the host).
            if not target.startswith(('localhost:', '127.0.0.1:')):
                await self.send_response('tunnel_error', {
                    'tunnel_id': tunnel_id,
                    'error': 'Only localhost targets allowed'
                })
                return

            host, port_str = target.split(':', 1)
            port = int(port_str)

            # Open TCP connection
            reader, writer = await asyncio.wait_for(
                asyncio.open_connection(host, port),
                timeout=10.0
            )

            # Create tunnel
            tunnel = TCPTunnel(tunnel_id, reader, writer, self)
            self.tcp_tunnels[tunnel_id] = tunnel

            # Start forwarding from TCP to WebSocket
            self._spawn(tunnel.forward_to_websocket())

            # Notify success
            await self.send_response('tunnel_ready', {
                'tunnel_id': tunnel_id,
                'target': target
            })
            logger.info(f"✅ Tunnel {tunnel_id} opened to {target}")

        except asyncio.TimeoutError:
            logger.error(f"Tunnel {tunnel_id} connection timeout")
            await self.send_response('tunnel_error', {
                'tunnel_id': tunnel_id,
                'error': 'Connection timeout'
            })
        except ConnectionRefusedError:
            logger.error(f"Tunnel {tunnel_id} connection refused")
            await self.send_response('tunnel_error', {
                'tunnel_id': tunnel_id,
                'error': 'Connection refused'
            })
        except ValueError as e:
            logger.error(f"Tunnel {tunnel_id} invalid port: {e}")
            await self.send_response('tunnel_error', {
                'tunnel_id': tunnel_id,
                'error': f'Invalid port: {e}'
            })
        except Exception as e:
            logger.error(f"Tunnel {tunnel_id} open failed: {e}")
            await self.send_response('tunnel_error', {
                'tunnel_id': tunnel_id,
                'error': str(e)
            })

    async def handle_tunnel_data(self, data: dict):
        """Forward data from WebSocket to TCP.

        Base64-decodes ``data['data']`` and writes it to the tunnel's socket.
        Any failure closes and unregisters the tunnel.
        """
        tunnel_id = data.get('tunnel_id')

        if not tunnel_id:
            logger.warning("Received tunnel_data without tunnel_id")
            return

        tunnel = self.tcp_tunnels.get(tunnel_id)

        if not tunnel:
            logger.warning(f"Tunnel {tunnel_id} not found for data forwarding")
            return

        try:
            # Decode base64 and write to TCP.
            # NOTE(review): each tunnel_data message runs in its own task, so
            # several drain() calls on one writer can be pending at once.
            # Older asyncio releases reject concurrent drain(); confirm the
            # supported Python version.
            tcp_data = base64.b64decode(data['data'])
            tunnel.writer.write(tcp_data)
            await tunnel.writer.drain()
            logger.debug(f"Forwarded {len(tcp_data)} bytes to tunnel {tunnel_id}")

        except Exception as e:
            logger.error(f"Error forwarding data to tunnel {tunnel_id}: {e}")
            await tunnel.close()
            self.tcp_tunnels.pop(tunnel_id, None)

    async def handle_tunnel_close(self, data: dict):
        """Close TCP tunnel named by ``data['tunnel_id']`` and unregister it."""
        tunnel_id = data.get('tunnel_id')

        if not tunnel_id:
            logger.warning("Received tunnel_close without tunnel_id")
            return

        tunnel = self.tcp_tunnels.pop(tunnel_id, None)

        if tunnel:
            logger.info(f"Closing tunnel {tunnel_id}")
            await tunnel.close()
        else:
            logger.warning(f"Tunnel {tunnel_id} not found for closing")


