"""Tests for TCP tunneling functionality."""

import pytest
import asyncio
import json
import base64
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from mcpbundles_proxy.tunnel import TunnelClient, TCPTunnel


@pytest.fixture
def mock_websocket():
    """Provide an AsyncMock that stands in for the tunnel's WebSocket.

    ``send`` and ``close`` are explicit AsyncMocks so tests can await them
    and inspect their ``call_args_list``.
    """
    websocket = AsyncMock()
    websocket.send, websocket.close = AsyncMock(), AsyncMock()
    return websocket


@pytest.fixture
def tunnel_client(mock_websocket):
    """Provide a TunnelClient whose ``ws`` attribute is the mocked WebSocket."""
    client_under_test = TunnelClient(token="test_token")
    client_under_test.ws = mock_websocket
    return client_under_test


@pytest.mark.asyncio
async def test_tcp_tunnel_creation():
    """A new TCPTunnel keeps the exact streams/client it was given and starts running."""
    reader = AsyncMock(spec=asyncio.StreamReader)
    writer = AsyncMock(spec=asyncio.StreamWriter)
    client = TunnelClient(token="test_token")

    tunnel = TCPTunnel("tun_123", reader, writer, client)

    assert tunnel.tunnel_id == "tun_123"
    # Identity, not equality: the tunnel must hold the very objects passed in.
    assert tunnel.reader is reader
    assert tunnel.writer is writer
    assert tunnel.tunnel_client is client
    assert tunnel.running is True


@pytest.mark.asyncio
async def test_tcp_tunnel_forward_to_websocket(tunnel_client):
    """Bytes read from TCP are sent to the WebSocket as a base64 ``tunnel_data`` message."""
    reader = AsyncMock(spec=asyncio.StreamReader)
    writer = AsyncMock(spec=asyncio.StreamWriter)

    test_data = b"Hello, World!"
    # First read yields a payload; the second returns b"" (EOF) so forwarding stops.
    reader.read = AsyncMock(side_effect=[test_data, b""])
    # asyncio.StreamWriter.close() is synchronous; only wait_closed() is a
    # coroutine. An AsyncMock close() would leak an un-awaited coroutine when
    # called correctly and would hide an erroneous `await writer.close()`.
    writer.close = Mock()
    writer.wait_closed = AsyncMock()

    tunnel = TCPTunnel("tun_123", reader, writer, tunnel_client)

    await tunnel.forward_to_websocket()

    assert tunnel_client.ws.send.called
    first_sent = tunnel_client.ws.send.call_args_list[0].args[0]
    message = json.loads(first_sent)

    assert message['type'] == 'tunnel_data'
    assert message['tunnel_id'] == 'tun_123'
    assert base64.b64decode(message['data']) == test_data


@pytest.mark.asyncio
async def test_tcp_tunnel_close(tunnel_client):
    """Closing a tunnel stops it and closes the TCP writer."""
    reader = AsyncMock(spec=asyncio.StreamReader)
    writer = AsyncMock(spec=asyncio.StreamWriter)
    # Mirror the real StreamWriter API: close() is a plain method,
    # wait_closed() is a coroutine.
    writer.close = Mock()
    writer.wait_closed = AsyncMock()

    tunnel = TCPTunnel("tun_123", reader, writer, tunnel_client)

    await tunnel.close()

    assert tunnel.running is False
    writer.close.assert_called_once()
    writer.wait_closed.assert_called_once()


@pytest.mark.asyncio
async def test_handle_tunnel_open_success(tunnel_client):
    """A valid localhost target registers the tunnel and replies with ``tunnel_ready``."""
    request = {'tunnel_id': 'tun_123', 'target': 'localhost:8080'}
    fake_streams = (
        AsyncMock(spec=asyncio.StreamReader),
        AsyncMock(spec=asyncio.StreamWriter),
    )

    # Stub the connector so no real socket is opened.
    fake_connect = AsyncMock(return_value=fake_streams)
    with patch('asyncio.open_connection', fake_connect):
        await tunnel_client.handle_tunnel_open(request)

    assert 'tun_123' in tunnel_client.tcp_tunnels
    assert tunnel_client.tcp_tunnels['tun_123'].tunnel_id == 'tun_123'

    sender = tunnel_client.ws.send
    assert sender.called
    reply = json.loads(sender.call_args_list[0][0][0])
    assert reply['type'] == 'tunnel_ready'
    assert reply['tunnel_id'] == 'tun_123'


@pytest.mark.asyncio
async def test_handle_tunnel_open_invalid_target(tunnel_client):
    """A non-localhost target is rejected with ``tunnel_error`` and never connected to."""
    data = {
        'tunnel_id': 'tun_123',
        'target': 'remote.example.com:8080'  # Not localhost
    }

    # Patch the connector: a regression that connects before validating the
    # target is caught here instead of silently reaching the real network.
    mock_connect = AsyncMock()
    with patch('asyncio.open_connection', mock_connect):
        await tunnel_client.handle_tunnel_open(data)

    mock_connect.assert_not_called()
    assert 'tun_123' not in tunnel_client.tcp_tunnels

    assert tunnel_client.ws.send.called
    message = json.loads(tunnel_client.ws.send.call_args_list[0].args[0])
    assert message['type'] == 'tunnel_error'
    assert message['tunnel_id'] == 'tun_123'
    assert 'localhost' in message['error']


@pytest.mark.asyncio
async def test_handle_tunnel_open_connection_refused(tunnel_client):
    """A refused local connection yields ``tunnel_error`` and no registered tunnel."""
    data = {
        'tunnel_id': 'tun_123',
        'target': 'localhost:9999'
    }

    # Simulate nothing listening on the target port.
    with patch('asyncio.open_connection', AsyncMock(side_effect=ConnectionRefusedError())):
        await tunnel_client.handle_tunnel_open(data)

    assert 'tun_123' not in tunnel_client.tcp_tunnels

    assert tunnel_client.ws.send.called
    message = json.loads(tunnel_client.ws.send.call_args_list[0].args[0])
    assert message['type'] == 'tunnel_error'
    # The error must identify which tunnel failed.
    assert message['tunnel_id'] == 'tun_123'
    assert 'Connection refused' in message['error']


@pytest.mark.asyncio
async def test_handle_tunnel_open_timeout(tunnel_client):
    """A connect timeout yields a ``tunnel_error`` mentioning timeout and no registered tunnel."""
    data = {
        'tunnel_id': 'tun_123',
        'target': 'localhost:8080'
    }

    # Simulate the local service never accepting the connection.
    with patch('asyncio.open_connection', AsyncMock(side_effect=asyncio.TimeoutError())):
        await tunnel_client.handle_tunnel_open(data)

    # A failed open must not leave a half-registered tunnel behind.
    assert 'tun_123' not in tunnel_client.tcp_tunnels

    assert tunnel_client.ws.send.called
    message = json.loads(tunnel_client.ws.send.call_args_list[0].args[0])
    assert message['type'] == 'tunnel_error'
    assert message['tunnel_id'] == 'tun_123'
    assert 'timeout' in message['error'].lower()


@pytest.mark.asyncio
async def test_handle_tunnel_data(tunnel_client):
    """A ``tunnel_data`` payload is base64-decoded and written to the TCP side."""
    tcp_writer = AsyncMock(spec=asyncio.StreamWriter)
    tcp_writer.write = Mock()
    tcp_writer.drain = AsyncMock()
    tunnel_client.tcp_tunnels['tun_123'] = TCPTunnel(
        "tun_123",
        AsyncMock(spec=asyncio.StreamReader),
        tcp_writer,
        tunnel_client,
    )

    payload = b"Test message"
    encoded = base64.b64encode(payload).decode('ascii')

    await tunnel_client.handle_tunnel_data({'tunnel_id': 'tun_123', 'data': encoded})

    # Exactly the decoded bytes reach TCP, followed by a flush.
    tcp_writer.write.assert_called_once_with(payload)
    tcp_writer.drain.assert_called_once()


@pytest.mark.asyncio
async def test_handle_tunnel_data_nonexistent_tunnel(tunnel_client):
    """Data addressed to an unknown tunnel id is ignored without raising."""
    payload = base64.b64encode(b"test").decode('ascii')

    # Completing the await without an exception is the assertion here.
    await tunnel_client.handle_tunnel_data({'tunnel_id': 'tun_999', 'data': payload})


@pytest.mark.asyncio
async def test_handle_tunnel_close(tunnel_client):
    """A ``tunnel_close`` message unregisters the tunnel and closes its TCP writer."""
    reader = AsyncMock(spec=asyncio.StreamReader)
    writer = AsyncMock(spec=asyncio.StreamWriter)
    # StreamWriter.close() is synchronous on the real class; model it that way.
    writer.close = Mock()
    writer.wait_closed = AsyncMock()

    tunnel = TCPTunnel("tun_123", reader, writer, tunnel_client)
    tunnel_client.tcp_tunnels['tun_123'] = tunnel

    data = {'tunnel_id': 'tun_123'}
    await tunnel_client.handle_tunnel_close(data)

    assert 'tun_123' not in tunnel_client.tcp_tunnels
    writer.close.assert_called()


@pytest.mark.asyncio
async def test_stop_closes_all_tunnels(tunnel_client):
    """Stopping the client closes every tunnel's TCP writer and clears the registry."""
    writers = []
    for i in range(3):
        reader = AsyncMock(spec=asyncio.StreamReader)
        writer = AsyncMock(spec=asyncio.StreamWriter)
        # Real StreamWriter.close() is sync; wait_closed() is the coroutine.
        writer.close = Mock()
        writer.wait_closed = AsyncMock()
        writers.append(writer)

        tunnel_client.tcp_tunnels[f"tun_{i}"] = TCPTunnel(
            f"tun_{i}", reader, writer, tunnel_client
        )

    await tunnel_client.stop()

    assert len(tunnel_client.tcp_tunnels) == 0
    # Emptying the dict alone would leak sockets; each connection must be closed.
    for writer in writers:
        writer.close.assert_called()


@pytest.mark.asyncio
async def test_tunnel_bidirectional_flow():
    """Data flows WebSocket->TCP via handle_tunnel_data and TCP->WebSocket via forward_to_websocket."""
    # A real StreamReader lets the test feed bytes as if the local service replied.
    local_reader = asyncio.StreamReader()
    local_writer = AsyncMock(spec=asyncio.StreamWriter)
    local_writer.write = Mock()
    local_writer.drain = AsyncMock()
    local_writer.close = Mock()  # synchronous on the real StreamWriter
    local_writer.wait_closed = AsyncMock()

    client = TunnelClient(token="test_token")
    client.ws = AsyncMock()
    client.ws.send = AsyncMock()

    tunnel = TCPTunnel("tun_test", local_reader, local_writer, client)
    client.tcp_tunnels["tun_test"] = tunnel

    # Direction 1: WebSocket -> TCP.
    request = b"PING"
    await client.handle_tunnel_data({
        'tunnel_id': 'tun_test',
        'data': base64.b64encode(request).decode('ascii'),
    })
    local_writer.write.assert_called_once_with(request)
    local_writer.drain.assert_called_once()

    # Direction 2: TCP -> WebSocket. The local service answers, then hits EOF
    # so forward_to_websocket terminates.
    reply = b"PONG"
    local_reader.feed_data(reply)
    local_reader.feed_eof()
    await tunnel.forward_to_websocket()

    sent = [json.loads(c.args[0]) for c in client.ws.send.call_args_list]
    data_messages = [m for m in sent if m.get('type') == 'tunnel_data']
    assert data_messages, "no tunnel_data message was sent to the WebSocket"
    assert data_messages[0]['tunnel_id'] == 'tun_test'
    assert base64.b64decode(data_messages[0]['data']) == reply


@pytest.mark.asyncio
async def test_tunnel_handles_large_data():
    """A 10 KiB chunk read from TCP reaches the WebSocket intact."""
    client = TunnelClient(token="test_token")
    client.ws = AsyncMock()
    client.ws.send = AsyncMock()

    reader = AsyncMock(spec=asyncio.StreamReader)
    writer = AsyncMock(spec=asyncio.StreamWriter)

    large_data = b"X" * 10240
    # One large read, then EOF.
    reader.read = AsyncMock(side_effect=[large_data, b""])
    writer.close = Mock()  # real StreamWriter.close() is synchronous
    writer.wait_closed = AsyncMock()

    tunnel = TCPTunnel("tun_large", reader, writer, client)

    await tunnel.forward_to_websocket()

    assert client.ws.send.called
    message = json.loads(client.ws.send.call_args_list[0].args[0])
    assert message['type'] == 'tunnel_data'
    assert message['tunnel_id'] == 'tun_large'
    # Compare content, not just length, so truncation or corruption is caught.
    assert base64.b64decode(message['data']) == large_data


@pytest.mark.asyncio
async def test_tunnel_open_with_127_0_0_1(tunnel_client):
    """The IPv4 loopback literal is accepted like ``localhost``."""
    data = {
        'tunnel_id': 'tun_123',
        'target': '127.0.0.1:5432'
    }

    reader = AsyncMock(spec=asyncio.StreamReader)
    writer = AsyncMock(spec=asyncio.StreamWriter)

    with patch('asyncio.open_connection', AsyncMock(return_value=(reader, writer))):
        await tunnel_client.handle_tunnel_open(data)

    assert 'tun_123' in tunnel_client.tcp_tunnels

    # Same success contract as a localhost target: a tunnel_ready reply.
    assert tunnel_client.ws.send.called
    message = json.loads(tunnel_client.ws.send.call_args_list[0].args[0])
    assert message['type'] == 'tunnel_ready'
    assert message['tunnel_id'] == 'tun_123'


@pytest.mark.asyncio
async def test_tunnel_open_invalid_format(tunnel_client):
    """A target without ``host:port`` form is rejected before any connection attempt."""
    data = {
        'tunnel_id': 'tun_123',
        'target': 'localhost'  # Missing port
    }

    # Patch the connector so a malformed target can never trigger a real connect.
    mock_connect = AsyncMock()
    with patch('asyncio.open_connection', mock_connect):
        await tunnel_client.handle_tunnel_open(data)

    mock_connect.assert_not_called()
    assert 'tun_123' not in tunnel_client.tcp_tunnels

    assert tunnel_client.ws.send.called
    message = json.loads(tunnel_client.ws.send.call_args_list[0].args[0])
    assert message['type'] == 'tunnel_error'
    assert message['tunnel_id'] == 'tun_123'
    assert 'Invalid target format' in message['error']


@pytest.mark.asyncio
async def test_multiple_concurrent_tunnels(tunnel_client):
    """Several tunnel_open requests processed concurrently each get their own tunnel."""
    tunnels_to_create = ['tun_1', 'tun_2', 'tun_3']

    # One distinct (reader, writer) pair per connection attempt, handed out in
    # call order by side_effect.
    stream_pairs = [
        (AsyncMock(spec=asyncio.StreamReader), AsyncMock(spec=asyncio.StreamWriter))
        for _ in tunnels_to_create
    ]
    mock_connect = AsyncMock(side_effect=stream_pairs)

    requests = [
        {
            'tunnel_id': tunnel_id,
            'target': f'localhost:{8000 + int(tunnel_id.split("_")[1])}',
        }
        for tunnel_id in tunnels_to_create
    ]

    with patch('asyncio.open_connection', mock_connect):
        # Run the opens concurrently rather than awaiting them one at a time.
        await asyncio.gather(
            *(tunnel_client.handle_tunnel_open(request) for request in requests)
        )

    assert mock_connect.await_count == len(tunnels_to_create)
    assert len(tunnel_client.tcp_tunnels) == len(tunnels_to_create)
    for tunnel_id in tunnels_to_create:
        assert tunnel_id in tunnel_client.tcp_tunnels

    # Each tunnel must own its own connection rather than sharing one writer.
    distinct_writers = {id(t.writer) for t in tunnel_client.tcp_tunnels.values()}
    assert len(distinct_writers) == len(tunnels_to_create)
