"""
Channel and filter definitions for PredictionData streaming.
"""

from enum import Enum
from typing import List, Optional
from dataclasses import dataclass


class DataType(Enum):
    """Available data types from PredictionData API.

    Each member's value is the string form of the data type. Channel
    compares a user-supplied name against these values (and against the
    member names) to pick the matching member.
    """
    BOOKS = "books"
    TRADES = "trades"
    # The value uses a slash while the member name uses an underscore;
    # Channel accepts both spellings for this type.
    ONCHAIN_FILLS = "onchain/fills"


@dataclass
class Channel:
    """
    Represents a data channel filter for streaming market data.

    Args:
        name: The type of data to stream. Either a DataType member or a
            string naming one: "books", "trades", "onchain_fills" or
            "onchain/fills". Strings are matched case-insensitively and
            surrounding whitespace is ignored. The attribute itself is
            stored exactly as given; the resolved enum is in ``data_type``.
        symbols: List of market symbols/slugs to filter
        token_ids: Optional list of specific token IDs to fetch

    At least one of ``symbols`` or ``token_ids`` must be non-empty.

    Attributes:
        data_type: The DataType resolved from ``name`` during
            initialization. It is assigned in ``__post_init__`` and is not
            a dataclass field, so it does not appear in ``repr()`` and is
            not part of equality comparison.

    Raises:
        ValueError: If ``name`` is a string that matches no DataType, or if
            neither ``symbols`` nor ``token_ids`` is provided.
        TypeError: If ``name`` is neither a string nor a DataType.

    Example:
        Channel(name="books", symbols=["will-trump-win-2024/YES"])
        Channel(name="onchain_fills", token_ids=["0x123..."])
    """
    name: str
    symbols: Optional[List[str]] = None
    token_ids: Optional[List[str]] = None

    def __post_init__(self):
        # Resolve the data type first so an invalid name is reported before
        # a missing identifier list.
        if isinstance(self.name, DataType):
            self.data_type = self.name
        elif isinstance(self.name, str):
            self.data_type = self._resolve_data_type(self.name)
        else:
            raise TypeError("name must be a string or DataType enum")

        # An empty list counts as "not provided", same as None.
        if not self.symbols and not self.token_ids:
            raise ValueError("Either symbols or token_ids must be provided")

    @staticmethod
    def _resolve_data_type(name: str) -> DataType:
        """Map a user-supplied string to the matching DataType member.

        Matching ignores case and surrounding whitespace, and treats "_"
        and "/" as interchangeable so both "onchain_fills" and
        "onchain/fills" resolve to the same member.

        Raises:
            ValueError: If no member matches.
        """
        key = name.strip().lower()
        # Enum values use "/" where the friendly spelling uses "_".
        as_value = key.replace("_", "/")
        for candidate in DataType:
            if candidate.value.lower() == as_value or candidate.name.lower() == key:
                return candidate
        raise ValueError(f"Invalid data type: {name}. Must be one of: {[dt.value for dt in DataType]}")

    def get_data_type(self) -> DataType:
        """Get the DataType enum value."""
        return self.data_type

    def get_identifiers(self) -> List[str]:
        """
        Get the list of identifiers (symbols or token_ids) for this channel.
        Prioritizes token_ids over symbols if both are provided.

        The returned list is the same object stored on the channel, not a
        copy.
        """
        if self.token_ids:
            return self.token_ids
        return self.symbols or []

    def uses_token_ids(self) -> bool:
        """Returns True if this channel uses token IDs instead of symbols."""
        return bool(self.token_ids)

