"""Test OAuth authentication flow."""

import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime, timedelta
from mcpbundles_proxy.auth import find_free_port, oauth_flow, refresh_token_if_needed


def test_find_free_port():
    """find_free_port hands back a port inside the window starting at 8765."""
    base = 8765
    chosen = find_free_port(start_port=base)
    # The result must lie within the ten ports beginning at `base`.
    assert base <= chosen < base + 10


def test_find_free_port_all_busy():
    """find_free_port raises RuntimeError when every bind attempt fails."""
    with patch('socket.socket') as socket_cls:
        # The socket object yielded by `with socket.socket(...) as s:`.
        fake_sock = socket_cls.return_value.__enter__.return_value
        fake_sock.bind.side_effect = OSError("Port in use")

        with pytest.raises(RuntimeError, match="Could not find available port"):
            find_free_port(start_port=8765, max_attempts=3)


@pytest.mark.asyncio
async def test_refresh_token_if_needed_not_expired():
    """A token with ample remaining lifetime is returned as-is, without a refresh request."""
    # Token expires in 2 hours, far enough out that no refresh is expected.
    future_time = datetime.now() + timedelta(hours=2)
    token_data = {
        "access_token": "valid_token",
        "refresh_token": "refresh_token",
        "expires_at": future_time.isoformat(),
    }

    # Patch the HTTP client so an unexpected refresh is detected instead of
    # silently reaching the real token endpoint.
    with patch('mcpbundles_proxy.auth.requests.post') as mock_post:
        result = await refresh_token_if_needed(token_data)

    assert result == token_data
    mock_post.assert_not_called()


@pytest.mark.asyncio
async def test_refresh_token_if_needed_expired():
    """A token close to expiry is refreshed, keeping the refresh token and client credentials."""
    # Token expires in 30 minutes, which should trigger a refresh.
    near_expiry = datetime.now() + timedelta(minutes=30)
    token_data = {
        "access_token": "old_token",
        "refresh_token": "refresh_token",
        "client_id": "test_client",
        "client_secret": "test_secret",
        "expires_at": near_expiry.isoformat(),
    }

    # The refresh response carries no refresh_token, so the old one must be
    # carried over into the result.
    new_token_response = {
        "access_token": "new_token",
        "expires_in": 86400,
    }

    with patch('mcpbundles_proxy.auth.requests.post') as mock_post:
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = new_token_response
        mock_post.return_value = mock_response

        # Patched so the test never persists anything via the real save_token.
        with patch('mcpbundles_proxy.auth.save_token'):
            result = await refresh_token_if_needed(token_data)

    # Exactly one refresh request should have been issued.
    mock_post.assert_called_once()
    assert result is not None
    assert result["access_token"] == "new_token"
    assert result["refresh_token"] == "refresh_token"  # Kept old refresh token
    assert result["client_id"] == "test_client"
    assert result["client_secret"] == "test_secret"


@pytest.mark.asyncio
async def test_refresh_token_failed():
    """A 401 reply to the refresh request makes refresh_token_if_needed return None."""
    # Within 30 minutes of expiry, so a refresh is attempted.
    near_expiry = datetime.now() + timedelta(minutes=30)
    token_data = {
        "access_token": "old_token",
        "refresh_token": "refresh_token",
        "expires_at": near_expiry.isoformat(),
    }

    # save_token is patched as well so the failure path can never write to
    # real token storage, whatever the implementation does on error.
    with patch('mcpbundles_proxy.auth.requests.post') as mock_post, \
            patch('mcpbundles_proxy.auth.save_token'):
        mock_response = Mock()
        mock_response.status_code = 401
        mock_post.return_value = mock_response

        result = await refresh_token_if_needed(token_data)

    assert result is None


def test_oauth_flow_success(monkeypatch):
    """oauth_flow registers a client (DCR), exchanges the code, and returns the merged token."""
    # Stand-in callback handler that already holds an authorization code.
    callback_handler = MagicMock()
    callback_handler.auth_code = "test_code"

    # First POST: dynamic client registration succeeds.
    dcr_reply = Mock(status_code=201)
    dcr_reply.json.return_value = {
        "client_id": "new_client_id",
        "client_secret": "new_client_secret",
    }

    # Second POST: the authorization code is exchanged for tokens.
    token_reply = Mock(status_code=200)
    token_reply.json.return_value = {
        "access_token": "access_token",
        "refresh_token": "refresh_token",
        "expires_in": 86400,
    }

    with patch('mcpbundles_proxy.auth.find_free_port', return_value=8765), \
            patch('mcpbundles_proxy.auth.HTTPServer'), \
            patch('mcpbundles_proxy.auth.webbrowser.open'), \
            patch('mcpbundles_proxy.auth.requests.post',
                  side_effect=[dcr_reply, token_reply]), \
            patch('mcpbundles_proxy.auth.save_token'), \
            patch('mcpbundles_proxy.auth.load_token', return_value=None):
        monkeypatch.setattr('mcpbundles_proxy.auth.OAuthCallbackHandler', callback_handler)
        outcome = oauth_flow()

    assert outcome["access_token"] == "access_token"
    assert outcome["client_id"] == "new_client_id"
    assert outcome["client_secret"] == "new_client_secret"
    assert "expires_at" in outcome


def test_oauth_flow_timeout(monkeypatch):
    """oauth_flow raises when the callback never delivers an authorization code."""
    # DCR succeeds, so the flow proceeds to wait for the browser callback.
    dcr_response = Mock()
    dcr_response.status_code = 201
    dcr_response.json.return_value = {
        "client_id": "test_client",
        "client_secret": "test_secret",
    }

    # Callback handler that never received an auth code.
    mock_handler_class = MagicMock()
    mock_handler_class.auth_code = None

    # save_token is patched so nothing the flow might persist before failing
    # (e.g. the freshly registered client) reaches real token storage.
    with patch('mcpbundles_proxy.auth.find_free_port', return_value=8765), \
            patch('mcpbundles_proxy.auth.HTTPServer'), \
            patch('mcpbundles_proxy.auth.webbrowser.open'), \
            patch('mcpbundles_proxy.auth.load_token', return_value=None), \
            patch('mcpbundles_proxy.auth.save_token'), \
            patch('mcpbundles_proxy.auth.requests.post', return_value=dcr_response):
        monkeypatch.setattr('mcpbundles_proxy.auth.OAuthCallbackHandler', mock_handler_class)

        with pytest.raises(Exception, match="Authentication failed or timed out"):
            oauth_flow()


def test_oauth_flow_exchange_failed(monkeypatch):
    """oauth_flow raises when the authorization-code exchange returns an error."""
    # Callback delivers an authorization code successfully.
    mock_handler_class = MagicMock()
    mock_handler_class.auth_code = "test_code"

    # First POST: DCR succeeds.
    dcr_response = Mock()
    dcr_response.status_code = 201
    dcr_response.json.return_value = {
        "client_id": "test_client",
        "client_secret": "test_secret",
    }

    # Second POST: the token endpoint rejects the code.
    token_response = Mock()
    token_response.status_code = 400
    token_response.text = "Invalid code"

    # save_token is patched so nothing the flow might persist before failing
    # reaches real token storage.
    with patch('mcpbundles_proxy.auth.find_free_port', return_value=8765), \
            patch('mcpbundles_proxy.auth.HTTPServer'), \
            patch('mcpbundles_proxy.auth.webbrowser.open'), \
            patch('mcpbundles_proxy.auth.load_token', return_value=None), \
            patch('mcpbundles_proxy.auth.save_token'), \
            patch('mcpbundles_proxy.auth.requests.post',
                  side_effect=[dcr_response, token_response]):
        monkeypatch.setattr('mcpbundles_proxy.auth.OAuthCallbackHandler', mock_handler_class)

        with pytest.raises(Exception, match="Token exchange failed"):
            oauth_flow()

