# ezga/utils/molecule_blacklist.py
from __future__ import annotations

import logging
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Any

from .chemistry import (
    pair_cutoff,
    get_positions_species,
    neighbors_within,
)

# A blacklist rule receives an AtomPositionManager-like object and returns
# True when its motif is present in the structure.
RuleFn = Callable[[Any], bool]

# Search radius passed as ``cutoff`` to every neighbors_within() call below.
CUTOFF_MAX: float = 4.0

def _rule_Hf(apm) -> bool:
    """
    Detect an isolated ("free") H atom.

    For each H atom, query up to 3 neighbor entries within ``CUTOFF_MAX`` and
    return True as soon as one H has no other atom closer than
    ``isolation_cut`` (3.0).

    Returns False when the structure contains fewer than two H atoms.
    NOTE(review): this early exit looks copied from the H2 rule; a lone
    isolated H is therefore never reported — confirm this is intended.
    """
    R, S = get_positions_species(apm)
    if sum(1 for s in S if s == "H") < 2:
        return False

    # Minimum distance to the closest other atom for an H to count as free.
    # This is an isolation distance, not an H-H bonding cutoff.
    isolation_cut = 3.0
    for i, si in enumerate(S):
        if si != "H":
            continue
        neigh_dist, _ = neighbors_within(apm, R[i], n=3, cutoff=CUTOFF_MAX)

        # Entry 0 is presumably the query atom itself — TODO confirm against
        # neighbors_within. If no second entry is returned (assumed to mean
        # nothing lies within CUTOFF_MAX, which exceeds isolation_cut), the H
        # is isolated; previously this raised IndexError and the rule was
        # silently skipped by the caller.
        if len(neigh_dist) < 2 or neigh_dist[1] >= isolation_cut:
            return True

    return False

def _rule_H2(apm) -> bool:
    """
    Detect an H-H pair (H2-like dimer).

    Returns True as soon as some H atom's nearest neighbor is another H
    within ``hh_cut`` (0.9). Other neighbors are not inspected, so an H2
    that also sits close to other atoms matches too.

    Returns False when the structure contains fewer than two H atoms.
    """
    R, S = get_positions_species(apm)
    if sum(1 for s in S if s == "H") < 2:
        return False

    hh_cut = 0.9  # hand-tuned; used instead of pair_cutoff("H", "H")
    for i, si in enumerate(S):
        if si != "H":
            continue
        neigh_dist, neigh_idx = neighbors_within(apm, R[i], n=3, cutoff=CUTOFF_MAX)

        # Entry 0 is presumably the query atom itself — TODO confirm.
        # Skip an H with no other neighbor instead of raising IndexError,
        # which would abort the whole rule for every remaining H.
        if len(neigh_dist) < 2:
            continue
        # Test the distance first so S is only indexed for a real neighbor
        # (guards against padded placeholder indices, if neighbors_within
        # uses them — unverified).
        if neigh_dist[1] <= hh_cut and S[neigh_idx[1]] == "H":
            return True

    return False

def _rule_H2b(apm) -> bool:
    """
    Detect an H-H pair that has a further atom nearby.

    Returns True as soon as some H atom satisfies both conditions:
    - its nearest neighbor is another H within ``hh_cut`` (0.9);
    - its second-nearest neighbor is within ``near_cut`` (2.0).

    Returns False when the structure contains fewer than two H atoms.
    """
    R, S = get_positions_species(apm)
    if sum(1 for s in S if s == "H") < 2:
        return False

    hh_cut = 0.9  # hand-tuned; used instead of pair_cutoff("H", "H")
    near_cut = 2.0  # max distance of the second neighbor
    for i, si in enumerate(S):
        if si != "H":
            continue
        neigh_dist, neigh_idx = neighbors_within(apm, R[i], n=3, cutoff=CUTOFF_MAX)

        # Entry 0 is presumably the query atom itself — TODO confirm.
        # Without a second neighbor entry the near-neighbor condition cannot
        # hold, so skip this H instead of raising IndexError.
        if len(neigh_dist) < 3:
            continue
        # Distances first, so S is only indexed for a real neighbor.
        if (
            neigh_dist[1] <= hh_cut
            and neigh_dist[2] <= near_cut
            and S[neigh_idx[1]] == "H"
        ):
            return True

    return False

def _rule_H2O(apm) -> bool:
    """
    Detect an H2O-like unit: an O whose two nearest neighbors are both H
    within the O-H cutoff ``pair_cutoff("O", "H")``.

    Returns False when there is no O or fewer than two H atoms.

    NOTE(review): an earlier docstring promised "exactly two H" and "the two
    H are not H-H bonded", but neither check was ever implemented (a 3rd H
    within the cutoff, or an H-H bonded pair, still matches). Decide whether
    to add them before relying on this rule to separate OH + H2 from H2O.
    """
    R, S = get_positions_species(apm)
    if not any(s == "O" for s in S) or sum(1 for s in S if s == "H") < 2:
        return False

    oh_cut = pair_cutoff("O", "H")

    for i, si in enumerate(S):
        if si != "O":
            continue
        neigh_dist, neigh_idx = neighbors_within(apm, R[i], n=4, cutoff=CUTOFF_MAX)

        # Entry 0 is presumably the query atom itself — TODO confirm.
        # Fewer than two other neighbors: this O cannot carry two H.
        if len(neigh_dist) < 3:
            continue
        # Distances first, so S is only indexed for real neighbors.
        if (
            neigh_dist[1] <= oh_cut and neigh_dist[2] <= oh_cut
            and S[neigh_idx[1]] == "H" and S[neigh_idx[2]] == "H"
        ):
            return True

    return False

def _rule_H2Of(apm) -> bool:
    """
    Detect an isolated ("free") H2O-like unit.

    Returns True when an O satisfies all of the following:
    - its two nearest neighbors are both H within ``pair_cutoff("O", "H")``;
    - its next-nearest atom is at least ``free_cut`` (2.5) away.

    Returns False when there is no O or fewer than two H atoms.

    NOTE(review): the "exactly two H / no H-H bond" conditions mentioned in
    an earlier docstring are not implemented here.
    """
    R, S = get_positions_species(apm)

    if not any(s == "O" for s in S) or sum(1 for s in S if s == "H") < 2:
        return False

    oh_cut = pair_cutoff("O", "H")
    free_cut = 2.5  # min distance of the next atom for the unit to be free

    for i, si in enumerate(S):
        if si != "O":
            continue
        neigh_dist, neigh_idx = neighbors_within(apm, R[i], n=4, cutoff=CUTOFF_MAX)

        # Entry 0 is presumably the query atom itself — TODO confirm.
        if len(neigh_dist) < 3:
            continue
        # A missing 4th entry is assumed to mean nothing else lies within
        # CUTOFF_MAX (> free_cut), i.e. the molecule is free. Previously this
        # raised IndexError and hid exactly the case this rule targets.
        next_atom_far = len(neigh_dist) < 4 or neigh_dist[3] >= free_cut
        if (
            neigh_dist[1] <= oh_cut and neigh_dist[2] <= oh_cut
            and next_atom_far
            and S[neigh_idx[1]] == "H" and S[neigh_idx[2]] == "H"
        ):
            return True

    return False

def _rule_O2(apm) -> bool:
    """
    Detect an O-O pair (O2-like dimer).

    Returns True as soon as some O atom's nearest neighbor is another O
    within ``oo_cut`` (1.4). Returns False when there are fewer than two O
    atoms.
    """
    R, S = get_positions_species(apm)
    if sum(1 for s in S if s == "O") < 2:
        return False

    oo_cut = 1.4  # hand-tuned; used instead of pair_cutoff("O", "O")

    for i, si in enumerate(S):
        if si != "O":
            continue
        neigh_dist, neigh_idx = neighbors_within(apm, R[i], n=3, cutoff=CUTOFF_MAX)

        # Entry 0 is presumably the query atom itself — TODO confirm.
        # Skip an O with no other neighbor instead of raising IndexError.
        if len(neigh_dist) < 2:
            continue
        # Distance first, so S is only indexed for a real neighbor.
        if neigh_dist[1] <= oo_cut and S[neigh_idx[1]] == "O":
            return True

    return False

# Built-in rules keyed by tag. Keys must be upper case: BlacklistDetector
# upper-cases every blacklist tag before looking it up here, so a mixed-case
# key (formerly "Hf") could never be matched.
_DEFAULT_RULES: Dict[str, RuleFn] = {
    "H2":   _rule_H2,
    "H2B":  _rule_H2b,
    "H2O":  _rule_H2O,
    "H2OF": _rule_H2Of,
    "O2":   _rule_O2,
    "HF":   _rule_Hf,
}

class BlacklistDetector:
    """
    Pluggable motif blacklist checker.

    Each blacklisted tag is looked up, case-insensitively, in a table of rule
    functions ``fn(apm) -> bool``. A structure is flagged as soon as the rule
    for any blacklisted tag returns True. Periodic-boundary handling, if any,
    is up to the individual rule functions.

    Usage:
        detector = BlacklistDetector(["H2O","H2"])
        present, tag = detector.contains(struct)
    """

    def __init__(self, blacklist: Iterable[str] = (), rules: Optional[Dict[str, RuleFn]] = None):
        # Both tags and rule names are upper-cased so matching in contains()
        # is case-insensitive. Previously only tags were normalized, so rules
        # registered under lowercase/mixed-case keys silently never matched.
        self.blacklist: Set[str] = {tag.upper() for tag in blacklist}
        source = _DEFAULT_RULES if rules is None else rules
        self.rules: Dict[str, RuleFn] = {
            name.upper(): fn for name, fn in dict(source).items()
        }

    def register(self, name: str, fn: RuleFn) -> None:
        """Register/override a rule at runtime (name is matched case-insensitively)."""
        self.rules[name.upper()] = fn

    def contains(self, struct) -> Tuple[bool, Optional[str]]:
        """
        Return (True, 'TAG') if any blacklisted motif is detected in struct; else (False, None).

        Tags are checked in sorted order, so when several motifs are present
        the reported tag is deterministic. Tags without a registered rule are
        skipped. A rule that raises is also skipped (logged at DEBUG level).
        """
        if not self.blacklist:
            return (False, None)

        apm = struct.AtomPositionManager
        # Iterating a set of str depends on per-process hash randomization;
        # sorting keeps the reported tag reproducible across GA runs.
        for tag in sorted(self.blacklist):
            rule = self.rules.get(tag)
            if rule is None:
                continue
            try:
                if rule(apm):
                    return (True, tag)
            except Exception:
                # Fail-soft: a broken rule must not stop the GA, but leave a
                # trace so the failure is not completely invisible.
                logging.getLogger(__name__).debug(
                    "Blacklist rule %r raised; skipping it for this structure.",
                    tag,
                    exc_info=True,
                )
                continue
        return (False, None)
