"""
Main client for PredictionData API.
"""

import asyncio
import csv
import gzip
import io
import logging
from datetime import datetime, timedelta
from typing import AsyncIterator, List, Optional, Tuple, Dict, Any
from urllib.parse import quote, quote_plus, urlencode

import aiohttp

from predictiondata.channel import Channel, DataType


class PredictionDataClient:
    """
    Client for accessing PredictionData API.

    Data is served as one gzipped CSV file per (exchange, data type,
    identifier, day). The client downloads those files over a lazily created
    ``aiohttp.ClientSession`` and turns each CSV row into a ``dict``.

    Args:
        api_key: Your PredictionData API key. It is sent as the ``apikey``
            query parameter of every request, so with the default ``http://``
            base URL it travels unencrypted.
        base_url: Base URL for the API (default: http://datasets.predictiondata.dev).
            A trailing slash is stripped.

    Example:
        async with PredictionDataClient(api_key="your_api_key") as client:
            messages = client.replay(
                exchange="polymarket",
                from_date="2024-11-01",
                to_date="2024-11-15",
                filters=[Channel(name="onchain_fills", symbols=["will-trump-win-2024/YES"])]
            )
            async for exchange_timestamp, message in messages:
                process_message(message)
    """

    # First two bytes of every gzip stream (RFC 1952); used to detect whether
    # the body still needs decompressing.
    _GZIP_MAGIC = b"\x1f\x8b"

    # Library code reports recoverable problems through logging, not print().
    _log = logging.getLogger(__name__)

    def __init__(self, api_key: str, base_url: str = "http://datasets.predictiondata.dev"):
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        # Created on first use by _get_session(); released by close().
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared aiohttp session, creating one if none is open."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession()
        return self._session

    async def close(self):
        """Close the client session. Safe to call more than once."""
        if self._session is not None and not self._session.closed:
            await self._session.close()
        self._session = None

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: closes the session."""
        await self.close()

    def _redact(self, text: str) -> str:
        """
        Return *text* with the API key masked out.

        URLs built by ``_build_url`` carry the key in their query string, and
        exception messages may embed those URLs, so anything logged goes
        through here first. Both the raw and the URL-encoded form are masked.
        """
        if not self.api_key:
            # str.replace("", ...) would insert the mask between every char.
            return text
        # Encoded form first: it is never shorter than the raw key.
        for secret in (quote_plus(self.api_key), self.api_key):
            text = text.replace(secret, "***")
        return text

    def _build_url(self, exchange: str, data_type: DataType, identifier: str, date: datetime, use_slug: bool) -> str:
        """
        Build the API URL for a specific date and identifier.

        Args:
            exchange: Exchange name (e.g., "polymarket")
            data_type: Type of data (books, trades, onchain_fills)
            identifier: Market symbol or token ID, unencoded. "/" is kept as a
                path separator (e.g. "will-trump-win-2024/YES"); other reserved
                characters are percent-encoded.
            date: Date to fetch
            use_slug: Whether the identifier is a slug (True) or token ID (False)

        Returns:
            Full URL string
        """
        date_str = date.strftime("%Y-%m-%d")
        path_identifier = quote(identifier, safe="/")
        # urlencode escapes characters such as "&", "+" and "#" in the key that
        # would otherwise split or truncate the query string.
        query = urlencode({
            "slug": "true" if use_slug else "false",
            "apikey": self.api_key,
        })
        return f"{self.base_url}/{exchange}/{data_type.value}/{path_identifier}/{date_str}.csv.gz?{query}"

    async def _fetch_csv_data(self, url: str) -> List[Dict[str, Any]]:
        """
        Fetch and parse a (normally gzipped) CSV file from the API.

        Args:
            url: URL to fetch

        Returns:
            List of dictionaries representing CSV rows; empty when the server
            answers 404 (no data for that date/identifier).

        Raises:
            aiohttp.ClientResponseError: for any other HTTP error status.
            aiohttp.ClientError: on connection-level failures.
            OSError / EOFError / zlib.error: if the gzip payload is corrupt.
            UnicodeDecodeError: if the CSV is not valid UTF-8.
        """
        session = await self._get_session()

        async with session.get(url) as response:
            if response.status == 404:
                # No data available for this date/symbol
                return []
            response.raise_for_status()
            payload = await response.read()

        # The session uses aiohttp's default auto_decompress=True. If the
        # server sends "Content-Encoding: gzip", the body arrives already
        # decoded, so only gunzip when the gzip magic bytes are present.
        if payload[:2] == self._GZIP_MAGIC:
            payload = gzip.decompress(payload)

        # utf-8-sig strips a leading BOM, which would otherwise be glued onto
        # the first header name and break row['local_timestamp'] lookups.
        csv_text = payload.decode("utf-8-sig")
        return list(csv.DictReader(io.StringIO(csv_text)))

    def _parse_timestamp(self, row: Dict[str, Any], data_type: DataType) -> Tuple[int, int]:
        """
        Parse timestamps from a row based on data type.

        Returns:
            Tuple of (local_timestamp, exchange_timestamp) in milliseconds

        Raises:
            KeyError: if the expected timestamp column is missing.
            ValueError / TypeError: if the value is not an integer string.
        """
        if data_type == DataType.ONCHAIN_FILLS:
            # Onchain fills only have block_timestamp; use it for both slots.
            block_timestamp = int(row['block_timestamp'])
            return (block_timestamp, block_timestamp)
        # Books and trades have both local_timestamp and exchange_timestamp
        local_timestamp = int(row['local_timestamp'])
        exchange_timestamp = int(row['exchange_timestamp'])
        return (local_timestamp, exchange_timestamp)

    def _enrich_message(self, row: Dict[str, Any], data_type: DataType, symbol: str, date: datetime) -> Dict[str, Any]:
        """
        Copy a CSV row, add metadata keys and convert numeric fields.

        Args:
            row: Raw CSV row (not modified)
            data_type: Type of data
            symbol: Market symbol/identifier
            date: Date of data

        Returns:
            New dict with the row's columns plus ``_symbol``, ``_data_type``
            and ``_date``. For trades, ``size`` and ``price`` are floats. For
            onchain fills, ``block_number`` is an int and ``size``/``price``
            are floats. Books rows keep every column (including the ask/bid
            price and size lists) as the raw strings read from the CSV.
        """
        message = dict(row)
        message['_symbol'] = symbol
        message['_data_type'] = data_type.value
        message['_date'] = date.strftime("%Y-%m-%d")

        if data_type == DataType.TRADES:
            message['size'] = float(message['size'])
            message['price'] = float(message['price'])
        elif data_type == DataType.ONCHAIN_FILLS:
            message['block_number'] = int(message['block_number'])
            message['size'] = float(message['size'])
            message['price'] = float(message['price'])

        return message

    @staticmethod
    def _coerce_data_type(data_type: str) -> DataType:
        """
        Convert a user-supplied data type name (or DataType member) to DataType.

        The value is tried as-is first. If that fails and it contains "_",
        the legacy mapping ("_" -> "/") used by earlier versions is tried.

        Raises:
            ValueError: if neither form is a valid DataType value.
        """
        try:
            return DataType(data_type)
        except ValueError:
            if not isinstance(data_type, str) or "_" not in data_type:
                raise
        # NOTE(review): kept for backward compatibility with the original
        # mapping; confirm which spelling DataType values actually use.
        return DataType(data_type.replace("_", "/"))

    async def replay(
        self,
        exchange: str,
        from_date: str,
        to_date: str,
        filters: List[Channel]
    ) -> AsyncIterator[Tuple[int, Dict[str, Any]]]:
        """
        Replay historical market data for specified date range and filters.

        Messages are yielded channel by channel, then identifier by identifier,
        then day by day, in the row order of each downloaded file. They are not
        merged or re-sorted across channels or identifiers.

        Failures are best-effort. A day whose download fails is logged and
        skipped. A malformed row is logged and skipped.

        Args:
            exchange: Exchange name (e.g., "polymarket")
            from_date: Start date (YYYY-MM-DD format), inclusive
            to_date: End date (YYYY-MM-DD format), inclusive. If it is earlier
                than from_date, nothing is yielded.
            filters: List of Channel filters specifying what data to fetch

        Yields:
            Tuples of (exchange_timestamp_ms, message_dict)

        Raises:
            ValueError: if a date string is not in YYYY-MM-DD format.

        Example:
            async for exchange_timestamp, message in client.replay(
                exchange="polymarket",
                from_date="2024-11-01",
                to_date="2024-11-15",
                filters=[Channel(name="trades", symbols=["will-trump-win-2024/YES"])]
            ):
                print(f"Time: {exchange_timestamp}, Message: {message}")
        """
        start_date = datetime.strptime(from_date, "%Y-%m-%d")
        end_date = datetime.strptime(to_date, "%Y-%m-%d")

        # Inclusive day range; empty when end_date < start_date.
        day_count = (end_date - start_date).days + 1
        date_range = [start_date + timedelta(days=offset) for offset in range(day_count)]

        for channel in filters:
            data_type = channel.get_data_type()
            identifiers = channel.get_identifiers()
            use_slug = not channel.uses_token_ids()

            for identifier in identifiers:
                for date in date_range:
                    day_str = date.strftime("%Y-%m-%d")
                    url = self._build_url(exchange, data_type, identifier, date, use_slug)

                    # Only the download is guarded. The yield below stays
                    # outside the try, so exceptions thrown into the generator
                    # are not swallowed.
                    try:
                        rows = await self._fetch_csv_data(url)
                    except asyncio.CancelledError:
                        # On Python < 3.8 CancelledError subclasses Exception;
                        # never swallow cancellation.
                        raise
                    except Exception as e:
                        # Redact: exception text (e.g. ClientResponseError)
                        # can include the request URL, which holds the API key.
                        self._log.warning(
                            "Error fetching data for %s on %s: %s",
                            identifier, day_str, self._redact(str(e)),
                        )
                        continue

                    for row in rows:
                        try:
                            _, exchange_timestamp = self._parse_timestamp(row, data_type)
                            message = self._enrich_message(row, data_type, identifier, date)
                        except (KeyError, TypeError, ValueError) as e:
                            # Skip just this row instead of the rest of the day.
                            self._log.warning(
                                "Skipping malformed row for %s on %s: %r",
                                identifier, day_str, e,
                            )
                            continue
                        yield (exchange_timestamp, message)

    async def fetch_day(
        self,
        exchange: str,
        data_type: str,
        identifier: str,
        date: str,
        use_slug: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Fetch data for a single day and identifier.

        Unlike replay(), errors are not swallowed. A 404 yields an empty
        list; other HTTP or parse errors propagate to the caller.

        Args:
            exchange: Exchange name (e.g., "polymarket")
            data_type: Type of data ("books", "trades", or "onchain_fills");
                a DataType member is also accepted
            identifier: Market symbol or token ID
            date: Date in YYYY-MM-DD format
            use_slug: Whether identifier is a slug (True) or token ID (False)

        Returns:
            List of message dictionaries

        Raises:
            ValueError: for an unknown data type or a malformed date.

        Example:
            data = await client.fetch_day(
                exchange="polymarket",
                data_type="trades",
                identifier="will-trump-win-2024/YES",
                date="2024-11-15"
            )
        """
        dt = datetime.strptime(date, "%Y-%m-%d")
        data_type_enum = self._coerce_data_type(data_type)

        url = self._build_url(exchange, data_type_enum, identifier, dt, use_slug)
        rows = await self._fetch_csv_data(url)

        return [self._enrich_message(row, data_type_enum, identifier, dt) for row in rows]

