from typing import Dict, Any, Optional, Union, Iterator, ClassVar
from synthegrator.code_problems import CodeProblem
from synthegrator.util import IteratorWithLength
import logging
from typing import Optional, Union, Dict, List, Any

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

class DatasetNameType:
    """Marker base class for dataset-name types, for use in type annotations."""
    pass



class DatasetSpec(DatasetNameType):
    """Specification for a dataset with optional filters.

    A spec carries a name, a display name, an optional parent spec and a
    dict of filter keyword arguments. Specs form a chain through ``parent``:
    filters are inherited down the chain, and the "base" name is found by
    walking up the chain to the nearest base collection.

    A spec compares equal to another spec with the same name, base-collection
    flag, filters and parent, and also to a plain string equal to its name.
    """
    
    def __init__(
        self,
        name: str,
        display_name: Optional[str] = None,
        is_a_base_collection: bool = False,
        parent: Optional['DatasetSpec'] = None,
        **filters: Any
    ) -> None:
        """
        Initialize a dataset specification.
        
        Args:
            name: The name of the dataset (e.g., "humaneval_plus").
            display_name: Optional friendly display name for the dataset
                (e.g., "HumanEval+"). Falls back to ``name`` when omitted
                or empty.
            is_a_base_collection: Whether this dataset is a distinct collection from a 
                                 specific source/paper, not just a filtered view
            parent: Optional parent specification to inherit filters from
            **filters: Filter key-value pairs to apply to this dataset

        The actual items that get yielded by dataset should be the
        get_base_name(). So something like mbpp_sanatized will yield
        items with a dataset_name of "mbpp" which is the name of its
        base collection parent.
        """
        self.name = name
        self.display_name = display_name or name
        self.is_a_base_collection = is_a_base_collection
        self.parent = parent
        self.filters = filters
        
    def __str__(self) -> str:
        """Return the dataset name."""
        return self.name

    def __eq__(self, other: object) -> bool:
        """Compare against another spec (field by field) or a name string."""
        if isinstance(other, DatasetSpec):
            # When exactly one parent is None the comparison below falls back
            # to identity and is False; two None parents compare equal.
            return (
                self.name == other.name
                and self.is_a_base_collection == other.is_a_base_collection
                and self.filters == other.filters
                and self.parent == other.parent
            )
        elif isinstance(other, str):
            return self.name == other
        return NotImplemented

    def __hash__(self) -> int:
        """Hash on the name only.

        ``__eq__`` treats a spec as equal to its name string, so the hash
        must match ``hash(self.name)`` for dict/set lookups to agree with
        equality. Two equal specs always share a name, so this is also
        consistent for spec-to-spec equality. Hashing only the name also
        keeps specs hashable when a filter value is unhashable (e.g. a list).
        """
        return hash(self.name)

    def __repr__(self) -> str:
        """Create a deterministic string representation for caching"""
        parent_str = f", parent={self.parent.name}" if self.parent else ""
        filters_str = ", " + ", ".join(f"{k}={v!r}" for k, v in self.filters.items()) if self.filters else ""
        return f"DatasetSpec(name={self.name!r}, is_base={self.is_a_base_collection}{parent_str}{filters_str})"

    @property
    def is_root(self) -> bool:
        """Whether this is a root dataset specification with no parent."""
        return self.parent is None
        
    def get_base_name(self) -> str:
        """Get the base name for routing to the appropriate dataset loader.

        A spec that is not a base collection defers to its parent; a base
        collection (or a spec without a parent) returns its own name.
        """
        if not self.is_a_base_collection and self.parent:
            return self.parent.get_base_name()
        return self.name

    def get_all_filters(self) -> Dict[str, Any]:
        """Get all filters, including parent filters.

        Returns a new dict; filters set on this spec override same-named
        filters inherited from ancestors.
        """
        if not self.parent:
            return self.filters.copy()
            
        all_filters = self.parent.get_all_filters()
        all_filters.update(self.filters)
        return all_filters
        
    def get_base_collection(self) -> 'DatasetSpec':
        """Get the nearest ancestor (including self) that is a base collection."""
        if self.is_a_base_collection:
            return self
        elif self.parent:
            return self.parent.get_base_collection()
        # Fallback in case there's no base collection in the hierarchy
        return self
        
    def get_base_collection_name(self) -> str:
        """Get the name of the nearest base collection in the hierarchy."""
        base_collection = self.get_base_collection()
        return base_collection.name


class DatasetNameMeta(type):
    """Metaclass giving a class iteration and isinstance support over its DatasetSpec attributes.

    Only attributes defined directly in the class body (``vars(cls)``) whose
    names do not start with an underscore are considered.
    """

    def __iter__(cls):
        """Yield each public DatasetSpec attribute of the class, in definition order."""
        for member_name, member in vars(cls).items():
            if member_name.startswith('_') or not isinstance(member, DatasetSpec):
                continue
            yield member

    def __instancecheck__(cls, instance):
        """Return True when ``instance`` is a DatasetSpec equal to a public attribute of the class."""
        if not isinstance(instance, DatasetSpec):
            return False
        # Equality (not identity) is used, so an equal copy of a registered
        # spec also passes the check.
        return any(
            not member_name.startswith('_') and instance == member
            for member_name, member in vars(cls).items()
        )


class DatasetName(DatasetNameType, metaclass=DatasetNameMeta):
    """Registry of the known datasets and their filtered variants.

    Each public attribute is a DatasetSpec. Through the metaclass the class
    itself can be iterated to obtain the specs in definition order.
    """
    # TODO: add some way of registering datasets and their loaders
    # dynamically so it is easier for downstream libraries.

    # ---- Base collections: each is a distinct collection from a paper/source

    humaneval: ClassVar[DatasetSpec] = DatasetSpec(
        "humaneval",
        display_name="HumanEval",
        is_a_base_collection=True,
    )

    mbpp: ClassVar[DatasetSpec] = DatasetSpec(
        "mbpp",
        display_name="MBPP",
        is_a_base_collection=True,
        format_as_method_completion_problem=True,
    )

    livecodebench: ClassVar[DatasetSpec] = DatasetSpec(
        "livecodebench",
        display_name="LiveCodeBench",
        is_a_base_collection=True,
    )

    dypy_line_completion: ClassVar[DatasetSpec] = DatasetSpec(
        "dypy_line_completion",
        display_name="DyPy",
        is_a_base_collection=True,
    )

    # ---- Base collections that also have a parent. Because they are base
    # ---- collections they keep their own name as the base name; the parent
    # ---- only contributes its filters via get_all_filters().

    # Distinct collection from eval+.
    humaneval_plus: ClassVar[DatasetSpec] = DatasetSpec(
        "humaneval_plus",
        display_name="HumanEval+",
        parent=humaneval,
        is_a_base_collection=True,
        use_eval_plus=True,
    )

    # Distinct collection from a distinct paper.
    mbpp_plus: ClassVar[DatasetSpec] = DatasetSpec(
        "mbpp_plus",
        display_name="MBPP+",
        parent=mbpp,
        is_a_base_collection=True,
        format_as_method_completion_problem=True,
    )

    # ---- Filtered views: not distinct collections, they resolve to their
    # ---- parent's base name.

    # Sanitized-only view of MBPP.
    mbpp_sanatized: ClassVar[DatasetSpec] = DatasetSpec(
        "mbpp_sanatized",
        display_name="MBPP-Sanitized",
        parent=mbpp,
        is_a_base_collection=False,
        sanitized_only=True,
    )

    # Views of LiveCodeBench split by difficulty.
    livecodebench_easy: ClassVar[DatasetSpec] = DatasetSpec(
        "livecodebench_easy",
        display_name="LiveCodeBench-Easy",
        parent=livecodebench,
        is_a_base_collection=False,
        filter_difficulty="easy",
    )

    livecodebench_medium: ClassVar[DatasetSpec] = DatasetSpec(
        "livecodebench_medium",
        display_name="LiveCodeBench-Medium",
        parent=livecodebench,
        is_a_base_collection=False,
        filter_difficulty="medium",
    )

    livecodebench_hard: ClassVar[DatasetSpec] = DatasetSpec(
        "livecodebench_hard",
        display_name="LiveCodeBench-Hard",
        parent=livecodebench,
        is_a_base_collection=False,
        filter_difficulty="hard",
    )

    # ---- Another standalone base collection

    repocod: ClassVar[DatasetSpec] = DatasetSpec(
        "repocod",
        display_name="REPOCOD",
        is_a_base_collection=True,
    )


# Type accepted for a ``dataset`` parameter: either a plain dataset-name
# string or a DatasetSpec object.
DatasetSpecifier = Union[str, DatasetSpec]


def yield_problems_from_name(
    dataset: DatasetSpecifier,
    max_problems: Optional[int] = None,
    **kwargs: Any
) -> Union[Iterator[CodeProblem], IteratorWithLength[CodeProblem]]:
    """
    Load and yield problems from a dataset by name, with optional filtering.

    A DatasetSpec is routed by its ``get_base_name()`` and contributes its
    (inherited) filters as default keyword arguments for the loader. A plain
    string that names a registered filtered view in ``DatasetName`` (a spec
    that is not a base collection, e.g. "livecodebench_easy") is resolved to
    that spec and handled the same way. Any other string is used directly as
    the loader name and contributes no filters of its own.
    
    Args:
        dataset: A dataset specifier - either a string name or DatasetSpec object
        max_problems: Maximum number of problems to yield; forwarded to the
            loader as ``max_problems`` only when not None
        **kwargs: Additional arguments to pass to the dataset yield function.
            These take precedence over filters coming from a DatasetSpec.
        
    Returns:
        Whatever the selected dataset loader returns (an iterator of
        CodeProblem instances)

    Raises:
        ValueError: If the resolved base name has no known loader.
        
    Examples:
        > yield_problems_from_name(DatasetName.livecodebench_easy, max_problems=10)
        > yield_problems_from_name("livecodebench_easy", max_problems=10)
        > yield_problems_from_name("humaneval", use_eval_plus=True)
        > yield_problems_from_name("mbpp", sanitized_only=True)
    """
    if not isinstance(dataset, DatasetSpec):
        # Filtered views have no loader of their own, so their string names
        # would otherwise be rejected as unknown. Base-collection names are
        # deliberately left alone so existing string callers see no change.
        requested_name = str(dataset)
        for registered in DatasetName:
            if registered.name == requested_name and not registered.is_a_base_collection:
                dataset = registered
                break

    if isinstance(dataset, DatasetSpec):
        base_name = dataset.get_base_name()
        # Spec filters act as defaults; explicit kwargs override them.
        kwargs = {**dataset.get_all_filters(), **kwargs}
    else:
        base_name = str(dataset)
        
    # Pass max_problems to all dataset loaders
    if max_problems is not None:
        kwargs['max_problems'] = max_problems
    
    # Dispatch to the appropriate dataset loader. Loader modules are imported
    # lazily so only the selected dataset's module is loaded. The comparisons
    # rely on DatasetSpec comparing equal to its name string.
    if base_name == DatasetName.humaneval or base_name == DatasetName.humaneval_plus:
        from synthegrator.synthdatasets.human_eval import yield_human_eval
        return yield_human_eval(**kwargs)
    elif base_name == DatasetName.mbpp:
        from synthegrator.synthdatasets.mbpp import yield_mbpp
        return yield_mbpp(**kwargs)
    elif base_name == DatasetName.mbpp_plus:
        from synthegrator.synthdatasets.mbpp import yield_mbpp_plus
        return yield_mbpp_plus(**kwargs)
    elif base_name == DatasetName.dypy_line_completion:
        from synthegrator.synthdatasets.dypybench import yield_dypybench
        return yield_dypybench(**kwargs)
    elif base_name == DatasetName.livecodebench:
        from synthegrator.synthdatasets.livecodebench import yield_livecode_problems
        return yield_livecode_problems(**kwargs)
    elif base_name == DatasetName.repocod:
        from synthegrator.synthdatasets.repocod import yield_repocod_problems
        return yield_repocod_problems(**kwargs)
    else:
        raise ValueError(f"Unknown dataset: {base_name}")


def load_hf_dataset_with_backoff(
    path: str,
    name: Optional[str] = None,  # This should always be specified
    max_retries: int = 5,
    initial_delay: float = 1.0,
    max_delay: float = 60.0,
    jitter: float = 0.1,
    **kwargs
):
    """
    Load a Hugging Face dataset with exponential backoff for handling rate limits.

    ``load_dataset`` is attempted up to ``max_retries`` times. HTTP errors,
    connection errors, timeouts and ``FileNotFoundError`` trigger a retry
    after a sleep; the delay doubles after each failed attempt and is capped
    at ``max_delay``. A missing repository or revision is re-raised
    immediately, since retrying cannot make it appear.

    Args:
        path: Dataset path or name
        name: Dataset configuration name (SHOULD BE EXPLICITLY SPECIFIED)
        max_retries: Maximum number of attempts
        initial_delay: Initial delay between retries in seconds
        max_delay: Maximum delay between retries in seconds
        jitter: Random jitter factor; each sleep is perturbed by up to
            +/- ``jitter * delay`` seconds
        **kwargs: Additional arguments to pass to load_dataset

    Returns:
        The value returned by ``datasets.load_dataset``.

    Raises:
        RepositoryNotFoundError, RevisionNotFoundError: Immediately, without
            retrying.
        Exception: The last retryable exception once all attempts are used.
        RuntimeError: If no attempt was made (``max_retries`` < 1).
    """
    import logging
    import time
    import random
    from datasets import load_dataset
    from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError
    from requests.exceptions import HTTPError, ConnectionError, Timeout

    retry_exceptions = (
        HTTPError,
        ConnectionError,
        Timeout,
        FileNotFoundError,
    )
    # NOTE(review): depending on the installed huggingface_hub version these
    # may subclass the HTTP error retried above, so they must be caught first
    # to avoid backing off on a permanent failure.
    permanent_exceptions = (RepositoryNotFoundError, RevisionNotFoundError)
    retry_logger = logging.getLogger(__name__)

    last_exception = None
    delay = initial_delay

    for attempt in range(max_retries):
        try:
            # Only pass the configuration name when one was provided.
            if name is not None:
                return load_dataset(path, name, **kwargs)
            return load_dataset(path, **kwargs)

        except permanent_exceptions:
            raise

        except retry_exceptions as e:
            last_exception = e

            # No sleep after the final attempt; the error is raised below.
            if attempt < max_retries - 1:
                jitter_amount = random.uniform(-jitter, jitter) * delay
                # Clamp at zero: a jitter factor above 1 could otherwise
                # produce a negative sleep, which time.sleep rejects.
                sleep_time = max(0.0, min(delay + jitter_amount, max_delay))

                retry_logger.warning(
                    "Loading dataset %s failed (attempt %d/%d): %r; retrying in %.2fs",
                    path, attempt + 1, max_retries, e, sleep_time,
                )
                time.sleep(sleep_time)
                delay = min(delay * 2, max_delay)

    if last_exception is not None:
        raise last_exception
    else:
        raise RuntimeError(f"Failed to load dataset {path} after {max_retries} attempts")