"""Unit tests for auth module functionality.

Note: OAuth contract tests are in tests/contract/test_oauth.py
This file focuses on unit testing individual functions.
"""

import pytest
from unittest.mock import Mock, patch
from datetime import datetime, timedelta
from mcpbundles_proxy.auth import find_free_port, refresh_token_if_needed, generate_pkce


def test_find_free_port():
    """Test that a free port is found within the scanned range.

    ``max_attempts`` is passed explicitly so the asserted upper bound does not
    silently depend on ``find_free_port``'s default value.
    """
    start_port = 8765
    max_attempts = 10
    port = find_free_port(start_port=start_port, max_attempts=max_attempts)
    assert start_port <= port < start_port + max_attempts


def test_find_free_port_all_busy():
    """Expect RuntimeError when every candidate port fails to bind."""
    with patch('socket.socket') as socket_cls:
        # Configure bind() on the object returned by __enter__; this presumes
        # find_free_port uses the socket as a context manager.
        entered_socket = socket_cls.return_value.__enter__.return_value
        entered_socket.bind.side_effect = OSError("Port in use")

        with pytest.raises(RuntimeError, match="Could not find available port"):
            find_free_port(start_port=8765, max_attempts=3)


def test_generate_pkce():
    """Test PKCE verifier/challenge generation.

    Checks the verifier length bounds, the 43-character challenge, the unpadded
    URL-safe base64 alphabet, and that the challenge is the S256 transform of
    the verifier (RFC 7636 section 4.2).
    """
    import base64
    import hashlib
    import string

    code_verifier, code_challenge = generate_pkce()

    # Verifier: 43-128 chars. Challenge: a 32-byte SHA-256 digest encoded as
    # base64url without padding is always 43 chars.
    assert 43 <= len(code_verifier) <= 128
    assert len(code_challenge) == 43

    # With S256 the challenge must differ from the verifier ("plain" would match).
    assert code_verifier != code_challenge

    # Both values use only the URL-safe base64 alphabet (no '=' padding).
    allowed_chars = set(string.ascii_letters + string.digits + '-_')
    assert set(code_verifier) <= allowed_chars
    assert set(code_challenge) <= allowed_chars

    # The challenge must be derived from the verifier:
    # BASE64URL(SHA256(ASCII(verifier))) with trailing '=' stripped.
    digest = hashlib.sha256(code_verifier.encode('ascii')).digest()
    expected_challenge = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')
    assert code_challenge == expected_challenge


@pytest.mark.asyncio
async def test_refresh_token_if_needed_not_expired():
    """Test that a token valid for two more hours is returned without refreshing.

    ``requests.post`` is patched so an unexpected refresh attempt is detected
    instead of reaching the network. ``result == token_data`` alone would
    also pass if the function refreshed and mutated the dict in place.
    """
    # Naive local time; presumably matches how the module compares expires_at.
    future_time = datetime.now() + timedelta(hours=2)
    token_data = {
        "access_token": "valid_token",
        "refresh_token": "refresh_token",
        "expires_at": future_time.isoformat()
    }

    with patch('mcpbundles_proxy.auth.requests.post') as mock_post:
        result = await refresh_token_if_needed(token_data)

    assert result == token_data
    mock_post.assert_not_called()


@pytest.mark.asyncio
async def test_refresh_token_if_needed_expired():
    """Test that a token close to expiry is refreshed.

    The token still has 30 minutes left, so this exercises the near-expiry
    refresh window rather than a token that has already expired. Fields
    missing from the refresh response (refresh_token, client credentials)
    are expected to be carried over from the original token data.
    """
    near_expiry = datetime.now() + timedelta(minutes=30)
    token_data = {
        "access_token": "old_token",
        "refresh_token": "refresh_token",
        "client_id": "test_client",
        "client_secret": "test_secret",
        "expires_at": near_expiry.isoformat()
    }

    new_token_response = {
        "access_token": "new_token",
        "expires_in": 86400
    }

    with patch('mcpbundles_proxy.auth.requests.post') as mock_post:
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = new_token_response
        mock_post.return_value = mock_response

        # Patch persistence so the test never writes token files.
        with patch('mcpbundles_proxy.auth.save_token'):
            result = await refresh_token_if_needed(token_data)

        # The refresh must actually have gone through the token endpoint.
        mock_post.assert_called_once()

        assert result is not None
        assert result["access_token"] == "new_token"
        assert result["refresh_token"] == "refresh_token"
        assert result["client_id"] == "test_client"
        assert result["client_secret"] == "test_secret"


@pytest.mark.asyncio
async def test_refresh_token_failed():
    """Test that a rejected refresh request (HTTP 401) yields None.

    Client credentials are included, as in the successful-refresh test, so
    that ``None`` comes from the 401 response. Without them it could come
    from an early exit before any request is made.
    """
    near_expiry = datetime.now() + timedelta(minutes=30)
    token_data = {
        "access_token": "old_token",
        "refresh_token": "refresh_token",
        "client_id": "test_client",
        "client_secret": "test_secret",
        "expires_at": near_expiry.isoformat()
    }

    with patch('mcpbundles_proxy.auth.requests.post') as mock_post:
        mock_response = Mock()
        mock_response.status_code = 401
        mock_post.return_value = mock_response

        result = await refresh_token_if_needed(token_data)

        # Ensure the failure path under test (the HTTP call) was reached.
        mock_post.assert_called()
        assert result is None


@pytest.mark.asyncio
async def test_refresh_token_network_error():
    """Test that a network error during the refresh request yields None."""
    near_expiry = datetime.now() + timedelta(minutes=30)
    token_data = {
        "access_token": "old_token",
        "refresh_token": "refresh_token",
        "client_id": "test_client",
        "client_secret": "test_secret",
        "expires_at": near_expiry.isoformat()
    }

    with patch('mcpbundles_proxy.auth.requests.post') as mock_post:
        # NOTE(review): this is the builtin ConnectionError. `requests` itself
        # raises requests.exceptions.ConnectionError, which is not a subclass
        # of the builtin, so consider raising that type to mirror production.
        mock_post.side_effect = ConnectionError("Network error")

        result = await refresh_token_if_needed(token_data)

        # None must come from the failed HTTP call, not an earlier exit.
        mock_post.assert_called()
        assert result is None

