from rdkit import Chem
from rdkit.Chem.Descriptors import CalcMolDescriptors
from rdkit.Chem import Lipinski
import numpy as np

import pandas as pd
import networkx as nx
from joblib import Parallel, delayed
from tqdm.auto import tqdm
import multiprocessing

from polyfeatures.processing import process_polymer_smiles

from rdkit import Chem

# Strict rotatable-bond pattern: a single, non-ring bond between two atoms that
# are neither terminal (degree 1) nor part of a triple bond. Compiled once at
# import time instead of on every call.
_ROTATABLE_BOND_QUERY = Chem.MolFromSmarts("[!$(*#*)&!D1]-!@[!$(*#*)&!D1]")

def get_backbone_rotatable_bonds(smiles, backbone_idx):
    """Count rotatable bonds whose two atoms both lie on the backbone.

    Args:
        smiles: SMILES string; it is re-parsed here with ``Chem.MolFromSmiles``.
        backbone_idx: Iterable of atom indices considered to be the backbone.

    Returns:
        Number of distinct bonds matching ``_ROTATABLE_BOND_QUERY`` whose begin
        and end atoms are both in ``backbone_idx``. Returns 0 if RDKit cannot
        parse ``smiles``.

    NOTE(review): callers compute ``backbone_idx`` on the molecule returned by
    ``process_polymer_smiles``, while this function re-parses the raw SMILES;
    the indices only line up if that helper preserves RDKit's atom ordering —
    confirm.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparseable SMILES: report no rotatable bonds instead of raising.
        return 0

    # Set for O(1) membership tests, whatever container the caller passed.
    backbone = set(backbone_idx)

    rotatable_bond_ids = set()
    for a1, a2 in mol.GetSubstructMatches(_ROTATABLE_BOND_QUERY):
        if a1 in backbone and a2 in backbone:
            bond = mol.GetBondBetweenAtoms(a1, a2)
            if bond is not None:
                # Use the bond index so a bond matched in both atom orders is
                # counted once.
                rotatable_bond_ids.add(bond.GetIdx())

    return len(rotatable_bond_ids)

def identify_backbone_atoms(mol, star_indices):
    """Return the set of atom indices treated as the polymer backbone.

    The backbone is the shortest bond path (by bond count) between the first
    and last entries of ``star_indices``, endpoints included.

    Args:
        mol: RDKit molecule.
        star_indices: Sequence of atom indices of the chain-end atoms
            (presumably the ``*`` attachment points from
            ``process_polymer_smiles`` — confirm).

    Returns:
        Set of atom indices on that path. Every atom index of ``mol`` is
        returned instead when fewer than two end atoms are given, or when
        NetworkX cannot find a path between them (disconnected graph or an
        index that is not a node).
    """
    all_atoms = set(range(mol.GetNumAtoms()))
    if len(star_indices) < 2:
        # No clear backbone ends, assume whole molecule is backbone
        return all_atoms

    graph = nx.from_numpy_array(Chem.GetAdjacencyMatrix(mol))
    try:
        path = nx.shortest_path(graph, star_indices[0], star_indices[-1])
    except nx.NetworkXException:
        # Covers NetworkXNoPath / NodeNotFound. Fall back to treating every
        # atom as backbone, without also hiding unrelated errors such as
        # KeyboardInterrupt, as the former bare ``except:`` did.
        return all_atoms
    return set(path)

def calculate_backbone_features(smiles):
    """Compute backbone descriptors for one polymer SMILES.

    Args:
        smiles: Polymer SMILES string, passed to ``process_polymer_smiles``.

    Returns:
        dict with the following keys:

        - ``SMILES``: the input string.
        - ``backbone_length``: number of backbone atoms of any element.
        - ``backbone_heavy_atom_count``: backbone atoms with atomic number > 1.
        - ``backbone_aromatic_fraction``: aromatic heavy atoms / heavy atoms.
        - ``backbone_electronegative_count``: backbone O/N/F/Cl atoms.
        - ``backbone_flexibility_index``: rotatable backbone bonds / heavy
          atoms.

        The heavy-atom-based entries stay ``0.0`` when the backbone has no
        heavy atoms. Every value stays ``0.0`` when the SMILES cannot be
        processed.

    NOTE(review): the rotatable-bond count re-parses the raw ``smiles`` but
    is filtered with indices from the processed molecule; confirm the two
    share atom ordering.
    """
    features = {
        'SMILES': smiles,
        'backbone_length': 0.0,
        'backbone_aromatic_fraction': 0.0,
        'backbone_heavy_atom_count': 0.0,
        'backbone_electronegative_count': 0.0,
        'backbone_flexibility_index': 0.0,
    }

    mol, star_indices = process_polymer_smiles(smiles)
    if mol is None:
        return features

    backbone = identify_backbone_atoms(mol, star_indices)
    atoms_on_backbone = [mol.GetAtomWithIdx(i) for i in backbone]

    # Atomic number > 1 excludes both hydrogens (1) and dummy '*' atoms (0).
    heavy_atoms = [a for a in atoms_on_backbone if a.GetAtomicNum() > 1]
    n_heavy = len(heavy_atoms)
    n_aromatic = sum(1 for a in heavy_atoms if a.GetIsAromatic())
    n_electronegative = sum(
        1 for a in atoms_on_backbone if a.GetSymbol() in ('O', 'N', 'F', 'Cl')
    )

    n_rotatable = get_backbone_rotatable_bonds(smiles, backbone)

    if n_heavy:
        features['backbone_aromatic_fraction'] = n_aromatic / n_heavy
        features['backbone_heavy_atom_count'] = n_heavy
        features['backbone_flexibility_index'] = n_rotatable / n_heavy

    features['backbone_electronegative_count'] = n_electronegative
    features['backbone_length'] = len(backbone)

    return features

def calculate_sidechain_features(smiles):
    """Compute side-chain descriptors for one polymer SMILES.

    Side-chain atoms are all atoms of the processed molecule that are not in
    the set returned by ``identify_backbone_atoms``.

    Args:
        smiles: Polymer SMILES string, passed to ``process_polymer_smiles``.

    Returns:
        dict with the following keys:

        - ``SMILES``: the input string.
        - ``sidechain_length``: number of side-chain atoms of any element.
        - ``sidechain_heavy_atom_count``: side-chain atoms with atomic
          number > 1.
        - ``sidechain_branch_count``: distinct non-backbone atoms bonded
          directly to a backbone atom.
        - ``sidechain_electronegative_count``: side-chain O/N/F/Cl atoms.

        Every value stays ``0.0`` when the SMILES cannot be processed.
    """
    features = {
        'SMILES': smiles,
        'sidechain_length': 0.0,
        'sidechain_heavy_atom_count': 0.0,
        'sidechain_branch_count': 0.0,
        'sidechain_electronegative_count': 0.0,
    }

    mol, star_indices = process_polymer_smiles(smiles)
    if mol is None:
        return features

    backbone_atoms = identify_backbone_atoms(mol, star_indices)
    sidechain_atoms = set(range(mol.GetNumAtoms())) - backbone_atoms
    sidechain = [mol.GetAtomWithIdx(idx) for idx in sidechain_atoms]

    # Each distinct non-backbone neighbour of a backbone atom roots one branch.
    # The set ensures an atom bonded to two backbone atoms is counted once.
    # NOTE(review): the backbone is a single shortest path, so ring atoms the
    # path does not traverse count as side-chain atoms and branch roots. For
    # example, in para-phenylene two ring carbons are off the path. Confirm
    # this is intended.
    branch_roots = {
        nbr.GetIdx()
        for idx in backbone_atoms
        for nbr in mol.GetAtomWithIdx(idx).GetNeighbors()
        if nbr.GetIdx() not in backbone_atoms
    }

    # Atomic number > 1 excludes both hydrogens (1) and dummy '*' atoms (0).
    features['sidechain_heavy_atom_count'] = sum(
        1 for atom in sidechain if atom.GetAtomicNum() > 1
    )
    features['sidechain_branch_count'] = len(branch_roots)
    features['sidechain_length'] = len(sidechain_atoms)
    features['sidechain_electronegative_count'] = sum(
        1 for atom in sidechain if atom.GetSymbol() in ('O', 'N', 'F', 'Cl')
    )

    return features

def calculate_extra_features(smiles):
    """Return RDKit's full descriptor set for one polymer SMILES.

    Args:
        smiles: Polymer SMILES string, passed to ``process_polymer_smiles``.

    Returns:
        dict whose first key is ``SMILES`` (the input string), followed by
        every entry of ``CalcMolDescriptors`` for the processed molecule.
        Only ``{'SMILES': smiles}`` is returned when the SMILES cannot be
        processed.
    """
    mol, _ = process_polymer_smiles(smiles)
    if mol is None:
        return {'SMILES': smiles}
    return {'SMILES': smiles, **CalcMolDescriptors(mol)}