import numpy as np
import pandas as pd
from pathlib import Path
import xgboost as xgb
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import roc_auc_score, precision_recall_curve, classification_report
from sklearn.preprocessing import StandardScaler
from scipy import stats
from collections import Counter, defaultdict
import re
#import RNA
import ViennaRNA as RNA
from itertools import product
import matplotlib.pyplot as plt
import seaborn as sns
import json
from statsmodels.stats.multitest import multipletests
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed  # 改为多进程
import multiprocessing

class RNAFeatureExtractor:
    def __init__(self, output_dir):  # 修复：添加 output_dir 参数
        self.output_dir = Path(output_dir)
        self.kmers_3 = [''.join(kmer) for kmer in product('ACGU', repeat=3)]
        self.kmers_2 = [''.join(kmer) for kmer in product('ACGU', repeat=2)]

    def run(self, positive_file, negative_file):
        """运行Step1分析"""
        print("Loading positive sequences...")
        pos_sequences, pos_structures = self.parse_structure_file(positive_file)
        print("Loading negative sequences...")
        neg_sequences, neg_structures = self.parse_structure_file(negative_file)
        
        print(f"Positive: {len(pos_sequences)}, Negative: {len(neg_sequences)}")
        
        # 提取特征
        print("Extracting features from positive sequences...")
        positive_features = []
        for seq, struct in zip(pos_sequences, pos_structures):
            features = self.extract_comprehensive_features(seq, struct)  # 修复：使用 self
            positive_features.append(features)
        
        print("Extracting features from negative sequences...")
        negative_features = []
        for seq, struct in zip(neg_sequences, neg_structures):
            features = self.extract_comprehensive_features(seq, struct)  # 修复：使用 self
            negative_features.append(features)
        
        # 创建DataFrame
        pos_df = pd.DataFrame(positive_features)
        neg_df = pd.DataFrame(negative_features)
        
        # 添加标签
        pos_df['label'] = 1
        neg_df['label'] = 0
        
        # 合并数据集
        full_df = pd.concat([pos_df, neg_df], ignore_index=True)
        
        # 处理NaN值
        full_df = full_df.fillna(0)
        
        feature_names = [col for col in full_df.columns if col != 'label']  # 修复：改为局部变量
        
        print(f"Final dataset shape: {full_df.shape}")
        print(f"Number of features: {len(feature_names)}")
        
        # 保存特征数据集 - 修复：使用 full_df 而不是 df
        output_file = self.output_dir / 'rna_features_dataset.csv'
        full_df.to_csv(output_file, index=False)  # 修复：改为 full_df
        print(f"特征已保存到 '{output_file}'")
             
        return full_df  

    def parse_structure_file(self, filename):
        """解析结构文件"""
        sequences = []
        structures = []
        
        with open(filename, 'r') as f:
            current_seq = ""
            current_struct = ""
            
            for line in f:
                line = line.strip()
                if line.startswith('>'):
                    if current_seq and current_struct:
                        sequences.append(current_seq)
                        structures.append(current_struct)
                    current_seq = ""
                    current_struct = ""
                elif line and not line.startswith('>'):
                    if all(c in 'ACGUacgu.-' for c in line):
                        current_seq = line.upper().replace('-', '').replace('.', '')
                    elif any(c in '().' for c in line):
                        current_struct = line
            
            if current_seq and current_struct:
                sequences.append(current_seq)
                structures.append(current_struct)
        
        return sequences, structures
        
    def get_kmer_frequencies(self, sequence, k=3):
        """Return relative frequencies of the k-mers observed in ``sequence``.

        The sequence is upper-cased first. Windows containing any character
        other than A, C, G or U are skipped, and frequencies are normalised
        by the number of windows that were counted.

        Args:
            sequence: Nucleotide string.
            k: Window length.

        Returns:
            dict: ``'kmer_<KMER>' -> frequency`` for observed k-mers only,
            in order of first occurrence.
        """
        seq = sequence.upper()
        alphabet = set('ACGU')
        windows = (seq[pos:pos + k] for pos in range(len(seq) - k + 1))
        observed = Counter(window for window in windows if alphabet.issuperset(window))

        total = sum(observed.values())
        kmer_freq = {}
        for kmer, hits in observed.items():
            kmer_freq[f"kmer_{kmer}"] = hits / total if total > 0 else 0
        return kmer_freq
    
    def calculate_gc_content(self, sequence):
        """Return the fraction of G and C in ``sequence`` (case-insensitive).

        An empty sequence yields 0.
        """
        seq = sequence.upper()
        if not seq:
            return 0
        return sum(seq.count(base) for base in 'GC') / len(seq)
    
    def extract_stem_characteristics(self, structure, sequence):
        """Collect all stem-related features of one structure.

        Args:
            structure: Dot-bracket string.
            sequence: Nucleotide sequence aligned with ``structure``.

        Returns:
            dict: Union of the basic, sequence and balance stem features, or
            the all-zero dict from ``_get_empty_stem_features`` when no stem
            is found.
        """
        stems = self._find_stems_detailed(structure, sequence)
        if not stems:
            return self._get_empty_stem_features()

        return {
            **self._calculate_stem_features(stems),
            **self._calculate_stem_sequence_features(stems),
            **self._calculate_stem_balance_features(stems),
        }
    
    def _find_stems_detailed(self, structure, sequence, min_stem_length=2):
        """Locate helices (stems) in a dot-bracket structure.

        Brackets are matched with a stack to obtain the real base pairs. A
        stem is a maximal run of directly stacked pairs
        ``(i, j), (i+1, j-1), ...`` and is reported exactly once, starting
        from its outermost pair. (Scanning bracket characters instead of
        pairs reported an n-pair helix several times and could join
        brackets that are not actually paired with each other.)

        Args:
            structure: Dot-bracket string. Unmatched brackets and any other
                characters are treated as unpaired.
            sequence: Nucleotide sequence indexed like ``structure``.
            min_stem_length: Minimum number of stacked pairs required for a
                helix to be reported.

        Returns:
            list[dict]: Stems ordered by the position of their outermost
            closing bracket. Keys: ``start``/``end`` (outermost pair),
            ``length`` (stacked pairs), ``left_arm``/``right_arm`` (arm
            sequences), ``loop_sequence``/``loop_length`` (everything
            enclosed between the arms) and ``total_length``.
        """
        # Pass 1: pair table, opening index -> closing index.
        partner = {}
        open_positions = []
        for position, symbol in enumerate(structure):
            if symbol == '(':
                open_positions.append(position)
            elif symbol == ')' and open_positions:
                partner[open_positions.pop()] = position

        # Pass 2: grow each helix inwards from its outermost pair.
        stems = []
        for start in sorted(partner, key=partner.get):
            end = partner[start]
            if partner.get(start - 1) == end + 1:
                # Interior pair of a helix that is reported from its
                # outermost pair.
                continue

            stem_length = 1
            while partner.get(start + stem_length) == end - stem_length:
                stem_length += 1
            if stem_length < min_stem_length:
                continue

            stem_end_left = start + stem_length - 1
            stem_start_right = end - stem_length + 1
            loop_start = stem_end_left + 1
            loop_end = stem_start_right - 1

            loop_seq = ""
            if loop_start <= loop_end and loop_end < len(sequence):
                loop_seq = sequence[loop_start:loop_end + 1]

            stems.append({
                'start': start,
                'end': end,
                'length': stem_length,
                'left_arm': sequence[start:start + stem_length],
                'right_arm': sequence[stem_start_right:end + 1],
                'loop_sequence': loop_seq,
                'loop_length': len(loop_seq),
                'total_length': end - start + 1
            })

        return stems
    
    def _get_empty_stem_features(self):
        """Return the all-zero stem feature dict used when no stem exists.

        Contains the scalar stem features, one entry per dinucleotide for
        the ``stem``, ``stem_5p`` and ``stem_3p`` prefixes, and entries for
        the first ten trinucleotides of ``self.kmers_3`` under ``stem``.
        """
        scalar_keys = (
            'stem_count',
            'avg_stem_length',
            'max_stem_length',
            'total_stem_bases',
            'stem_density',
            'stem_continuity',
            'stem_gc_content',
            'stem_gc_skew',
            'stem_at_ratio',
            'stem_gu_ratio',
            'stem_balance_score',
            'stem_arm_symmetry',
            'stem_5prime_gc',
            'stem_3prime_gc',
            'stem_bulge_count',
            'stem_internal_loop_count',
        )
        zeros = {key: 0 for key in scalar_keys}

        # Dinucleotide slots for the pooled stem and for each arm.
        for kmer in self.kmers_2:
            for prefix in ('stem', 'stem_5p', 'stem_3p'):
                zeros[f'{prefix}_kmer_{kmer}'] = 0

        # Only the first ten trinucleotides get a slot.
        zeros.update({f'stem_kmer_{kmer}': 0 for kmer in self.kmers_3[:10]})

        return zeros
    
    def _calculate_stem_features(self, stems):
        """Compute the basic count/length features of a list of stems.

        Args:
            stems: Stem dicts as produced by ``_find_stems_detailed``.

        Returns:
            dict: ``stem_count``, ``avg_stem_length``, ``max_stem_length``,
            ``total_stem_bases``, ``stem_density``, ``stem_continuity``,
            ``stem_bulge_count`` and ``stem_internal_loop_count``. All
            values are 0 for an empty list.
        """
        if not stems:
            # max() and the mean are undefined for an empty list.
            return {
                'stem_count': 0,
                'avg_stem_length': 0,
                'max_stem_length': 0,
                'total_stem_bases': 0,
                'stem_density': 0,
                'stem_continuity': 0,
                'stem_bulge_count': 0,
                'stem_internal_loop_count': 0
            }

        stem_lengths = [stem['length'] for stem in stems]
        stem_bases = sum(stem_lengths) * 2

        # NOTE(review): the density denominator is the arm length of the
        # FIRST stem only, not the sequence length - confirm this is the
        # intended definition. Guarded because arms are empty when the
        # sequence is shorter than the structure.
        first_stem_bases = len(stems[0]['left_arm']) + len(stems[0]['right_arm'])
        stem_density = stem_bases / first_stem_bases if first_stem_bases else 0

        return {
            'stem_count': len(stems),
            'avg_stem_length': np.mean(stem_lengths),
            'max_stem_length': max(stem_lengths),
            'total_stem_bases': stem_bases,
            'stem_density': stem_density,
            'stem_continuity': self._calculate_stem_continuity(stems),
            'stem_bulge_count': self._count_stem_bulges(stems),
            'stem_internal_loop_count': self._count_internal_loops(stems)
        }
    
    def _calculate_stem_sequence_features(self, stems):
        """Compute composition and k-mer features of the stem arms.

        All arms are pooled (left arm followed by right arm, stem by stem)
        for the overall statistics; the 5' (left) and 3' (right) arms are
        also summarised separately.

        Args:
            stems: Stem dicts providing ``left_arm`` and ``right_arm``.

        Returns:
            dict: GC content, GC skew ``(G-C)/(G+C)``, A/U and G/U count
            ratios, per-arm GC content, and k-mer frequencies under the
            ``stem``, ``stem_5p`` and ``stem_3p`` prefixes. Ratios with a
            zero denominator are 0.
        """
        five_prime_arms = [stem['left_arm'] for stem in stems]
        three_prime_arms = [stem['right_arm'] for stem in stems]
        pooled = ''.join(
            left + right for left, right in zip(five_prime_arms, three_prime_arms)
        )

        n_a = pooled.count('A')
        n_c = pooled.count('C')
        n_g = pooled.count('G')
        n_u = pooled.count('U')
        gc_total = n_g + n_c

        features = {
            'stem_gc_content': gc_total / len(pooled) if pooled else 0,
            'stem_gc_skew': (n_g - n_c) / gc_total if gc_total > 0 else 0,
            'stem_at_ratio': n_a / n_u if n_u > 0 else 0,
            'stem_gu_ratio': n_g / n_u if n_u > 0 else 0,
            'stem_5prime_gc': self._calculate_region_gc_content(five_prime_arms),
            'stem_3prime_gc': self._calculate_region_gc_content(three_prime_arms)
        }

        # k-mer profiles: pooled stem first, then each arm on its own.
        regions = (
            (pooled, 'stem'),
            (''.join(five_prime_arms), 'stem_5p'),
            (''.join(three_prime_arms), 'stem_3p'),
        )
        for region_seq, prefix in regions:
            features.update(self._get_region_kmer_features(region_seq, prefix))

        return features
    
    def _calculate_stem_balance_features(self, stems):
        """Score how balanced and complementary the two arms of each stem are.

        Per stem:

        * balance = ``1 - |GC(left) - GC(right)| / len(left)`` using G+C
          counts (1 when the left arm is empty);
        * symmetry = fraction of left-arm positions equal to the reverse
          complement of the right arm (0 when the left arm is empty).

        Returns:
            dict: Mean ``stem_balance_score`` and ``stem_arm_symmetry`` over
            all stems; both 0 for an empty list.
        """
        def gc_count(arm):
            return arm.count('G') + arm.count('C')

        balance_scores = []
        symmetry_scores = []
        for stem in stems:
            five_p, three_p = stem['left_arm'], stem['right_arm']
            arm_len = len(five_p)

            if arm_len > 0:
                balance = 1 - abs(gc_count(five_p) - gc_count(three_p)) / arm_len
                expected = self._reverse_complement(three_p)
                matches = sum(1 for x, y in zip(five_p, expected) if x == y)
                symmetry = matches / arm_len
            else:
                balance, symmetry = 1, 0

            balance_scores.append(balance)
            symmetry_scores.append(symmetry)

        mean_balance = np.mean(balance_scores) if balance_scores else 0
        mean_symmetry = np.mean(symmetry_scores) if symmetry_scores else 0
        return {
            'stem_balance_score': mean_balance,
            'stem_arm_symmetry': mean_symmetry
        }
    
    def _calculate_region_gc_content(self, sequences):
        """Return the G+C fraction over the concatenation of ``sequences``.

        Yields 0 when the list is empty or all its strings are empty.
        """
        pooled = ''.join(sequences) if sequences else ''
        if not pooled:
            return 0
        return (pooled.count('G') + pooled.count('C')) / len(pooled)
    
    def _get_region_kmer_features(self, sequence, prefix, max_kmers=20):
        """Return 2-mer and top 3-mer frequencies of a region.

        Windows containing characters other than A, C, G, U are ignored.

        Args:
            sequence: Region sequence; an empty string yields ``{}``.
            prefix: Key prefix; keys look like ``'<prefix>_kmer_<KMER>'``.
            max_kmers: Number of most frequent 3-mers to report.

        Returns:
            dict: One entry for every dinucleotide in ``self.kmers_2``
            (0 when absent) followed by the ``max_kmers`` most common
            trinucleotides, each normalised by its own window count.
        """
        if not sequence:
            return {}

        alphabet = set('ACGU')

        def count_kmers(size):
            windows = (
                sequence[pos:pos + size]
                for pos in range(len(sequence) - size + 1)
            )
            return Counter(w for w in windows if alphabet.issuperset(w))

        dimer_counts = count_kmers(2)
        trimer_counts = count_kmers(3)

        features = {}

        # Fixed dinucleotide vocabulary.
        dimer_total = sum(dimer_counts.values())
        for kmer in self.kmers_2:
            hits = dimer_counts.get(kmer, 0)
            features[f'{prefix}_kmer_{kmer}'] = hits / dimer_total if dimer_total > 0 else 0

        # Only the most frequent trinucleotides of this region.
        trimer_total = sum(trimer_counts.values())
        for kmer, hits in trimer_counts.most_common(max_kmers):
            features[f'{prefix}_kmer_{kmer}'] = hits / trimer_total if trimer_total > 0 else 0

        return features
    
    def _reverse_complement(self, sequence):
        """Return the RNA reverse complement of ``sequence``.

        A<->U and G<->C are swapped; any other character (including
        lower-case letters) is copied unchanged.
        """
        pairing = {'A': 'U', 'U': 'A', 'G': 'C', 'C': 'G'}
        complemented = []
        for base in reversed(sequence):
            complemented.append(pairing[base] if base in pairing else base)
        return ''.join(complemented)
    
    def _calculate_stem_continuity(self, stems):
        """Return the mean stem length as a continuity score (0 if no stems).

        Longer stems are treated as more continuous, so the score of a stem
        is simply its ``length``.
        """
        if not stems:
            return 0
        return np.mean([stem['length'] for stem in stems])
    
    def _count_stem_bulges(self, stems):
        """Count stems whose two arms differ in length.

        Simplified heuristic: a length mismatch between ``left_arm`` and
        ``right_arm`` is taken as the sign of a bulge.
        """
        return sum(
            1 for stem in stems
            if len(stem['left_arm']) != len(stem['right_arm'])
        )
    
    def _count_internal_loops(self, stems):
        """Estimate the number of internal loops from stem spacing.

        Simplified heuristic: stems are sorted by ``start`` and every pair
        of neighbours whose gap (next ``start`` minus current ``end``) is at
        most 5 counts as one internal loop. Fewer than two stems yield 0.
        """
        if len(stems) < 2:
            return 0

        ordered = sorted(stems, key=lambda stem: stem['start'])
        return sum(
            1
            for upstream, downstream in zip(ordered, ordered[1:])
            if downstream['start'] - upstream['end'] <= 5
        )

    def extract_loop_characteristics(self, structure, sequence):
        """Collect all features of the unpaired regions of one structure.

        Args:
            structure: Dot-bracket string.
            sequence: Nucleotide sequence aligned with ``structure``.

        Returns:
            dict: Basic loop statistics merged with loop k-mer frequencies,
            or the all-zero dict from ``_get_empty_loop_features`` when the
            structure has no unpaired region.
        """
        loops = self._find_loops_detailed(structure, sequence)
        if not loops:
            return self._get_empty_loop_features()

        return {
            **self._calculate_loop_features(loops),
            **self._calculate_loop_sequence_features(loops),
        }
    
    def _find_loops_detailed(self, structure, sequence):
        """Locate every maximal run of unpaired positions (``.``).

        Args:
            structure: Dot-bracket string.
            sequence: Nucleotide sequence indexed like ``structure``.

        Returns:
            list[dict]: One entry per run, in order, with ``start`` and
            ``end`` (inclusive indices), ``length``, ``sequence``,
            ``gc_content`` and ``u_content`` (fraction of ``U``).
        """
        loops = []
        structure_length = len(structure)

        for run in re.finditer(r'\.+', structure):
            first, stop = run.span()
            # A run that reaches the end of the structure takes the whole
            # remainder of the sequence, even if the sequence is longer.
            if stop == structure_length:
                loop_seq = sequence[first:]
            else:
                loop_seq = sequence[first:stop]

            loops.append({
                'start': first,
                'end': stop - 1,
                'length': stop - first,
                'sequence': loop_seq,
                'gc_content': self.calculate_gc_content(loop_seq),
                'u_content': loop_seq.count('U') / len(loop_seq) if loop_seq else 0
            })

        return loops
    
    def _get_empty_loop_features(self):
        """Return the all-zero loop feature dict used when no loop exists.

        Contains the scalar loop features plus one ``loop_kmer_<XX>`` entry
        per dinucleotide in ``self.kmers_2``.
        """
        scalar_keys = (
            'loop_count',
            'avg_loop_length',
            'max_loop_length',
            'total_loop_bases',
            'loop_diversity',
            'loop_gc_content',
            'loop_u_content',
            'loop_au_ratio',
            'loop_polyu_stretch',
            'loop_sequence_complexity',
        )
        zeros = {key: 0 for key in scalar_keys}
        zeros.update({f'loop_kmer_{kmer}': 0 for kmer in self.kmers_2})
        return zeros
    
    def _calculate_loop_features(self, loops):
        """Compute the basic statistics of a non-empty list of loops.

        Args:
            loops: Loop dicts as produced by ``_find_loops_detailed``.

        Returns:
            dict: Count, mean/max/total length, length diversity (distinct
            lengths divided by number of loops), mean GC and U content, A/U
            ratio, longest poly-U stretch and mean sequence complexity.
        """
        lengths = []
        gc_values = []
        u_values = []
        for loop in loops:
            lengths.append(loop['length'])
            gc_values.append(loop['gc_content'])
            u_values.append(loop['u_content'])

        features = {'loop_count': len(loops)}
        features['avg_loop_length'] = np.mean(lengths)
        features['max_loop_length'] = max(lengths)
        features['total_loop_bases'] = sum(lengths)
        features['loop_diversity'] = len(set(lengths)) / len(lengths) if lengths else 0
        features['loop_gc_content'] = np.mean(gc_values) if gc_values else 0
        features['loop_u_content'] = np.mean(u_values) if u_values else 0
        features['loop_au_ratio'] = self._calculate_loop_au_ratio(loops)
        features['loop_polyu_stretch'] = self._calculate_polyu_stretch(loops)
        features['loop_sequence_complexity'] = self._calculate_loop_complexity(loops)
        return features
    
    def _calculate_loop_sequence_features(self, loops):
        """Return k-mer frequencies of all loop sequences concatenated.

        Keys carry the ``loop`` prefix; an empty concatenation yields ``{}``.
        """
        pooled = ''.join(loop['sequence'] for loop in loops)
        if not pooled:
            return {}

        # Copy so the caller gets its own dict.
        return dict(self._get_region_kmer_features(pooled, 'loop'))
    
    def _calculate_loop_au_ratio(self, loops):
        """Return the A-to-U count ratio over all loop sequences.

        Yields 0 when the loops contain no ``U``.
        """
        a_total = sum(loop['sequence'].count('A') for loop in loops)
        u_total = sum(loop['sequence'].count('U') for loop in loops)
        if u_total > 0:
            return a_total / u_total
        return 0
    
    def _calculate_polyu_stretch(self, loops):
        """Return the length of the longest run of ``U`` within any loop.

        Runs never span two loops; 0 is returned when no loop contains
        an upper-case ``U``.
        """
        run_lengths = (
            len(run)
            for loop in loops
            for run in re.findall(r'U+', loop['sequence'])
        )
        return max(run_lengths, default=0)
    
    def _calculate_loop_complexity(self, loops):
        """Return the mean sequence complexity of the loops.

        The complexity of a loop is its number of distinct characters
        divided by its length (0 for an empty sequence). An empty list
        yields 0.
        """
        if not loops:
            return 0

        sequences = (loop['sequence'] for loop in loops)
        complexities = [
            len(set(seq)) / len(seq) if seq else 0
            for seq in sequences
        ]
        return np.mean(complexities) if complexities else 0

    def calculate_folding_energy(self, sequence):
        """Return the folding free energy reported by ``RNA.fold``.

        Args:
            sequence: Nucleotide sequence to fold.

        Returns:
            float: The energy component of ``RNA.fold(sequence)``, or 0.0
            when the call fails for any reason (best-effort fallback).
        """
        try:
            _, energy = RNA.fold(sequence)
            return float(energy)
        except Exception:
            # Deliberately best-effort, but unlike a bare ``except`` this
            # lets KeyboardInterrupt and SystemExit propagate.
            return 0.0
    
    def identify_structural_motifs(self, structure, sequence):
        """Summarise the structural motifs found in one structure.

        Args:
            structure: Dot-bracket string.
            sequence: Nucleotide sequence aligned with ``structure``.

        Returns:
            dict: Counts of hairpins, internal loops and bulges, their
            total, and the mean hairpin loop length, stem length and loop
            GC content (0 when there is no hairpin).
        """
        hairpins = self._detect_hairpin_loops_detailed(structure, sequence)
        internal_loops = self._detect_internal_loops(structure, sequence)
        bulges = self._detect_bulge_loops(structure, sequence)

        features = {
            'hairpin_count': len(hairpins),
            'internal_loop_count': len(internal_loops),
            'bulge_count': len(bulges),
            'total_motifs': len(hairpins) + len(internal_loops) + len(bulges)
        }

        # Hairpin-specific averages.
        if hairpins:
            features['avg_hairpin_loop_length'] = np.mean(
                [hairpin['loop_length'] for hairpin in hairpins]
            )
            features['avg_hairpin_stem_length'] = np.mean(
                [hairpin['stem_length'] for hairpin in hairpins]
            )
            features['hairpin_loop_gc'] = np.mean(
                [hairpin['loop_gc'] for hairpin in hairpins]
            )
        else:
            features['avg_hairpin_loop_length'] = 0
            features['avg_hairpin_stem_length'] = 0
            features['hairpin_loop_gc'] = 0

        return features
    
    def _detect_hairpin_loops_detailed(self, structure, sequence):
        """Return the hairpins of a structure.

        A stem found by ``_find_stems_detailed`` closes a hairpin only when
        the region enclosed by its two arms is non-empty and contains no
        further brackets. Stems that enclose other base pairs are therefore
        not reported as hairpins.

        Args:
            structure: Dot-bracket string.
            sequence: Nucleotide sequence aligned with ``structure``.

        Returns:
            list[dict]: One dict per hairpin with ``stem_length``,
            ``loop_length``, ``loop_gc`` and ``loop_sequence``.
        """
        hairpins = []

        for stem in self._find_stems_detailed(structure, sequence):
            if stem['loop_length'] <= 0:
                continue

            # Structure between the innermost pair of this stem.
            inner_start = stem['start'] + stem['length']
            inner_stop = stem['end'] - stem['length'] + 1
            enclosed = structure[inner_start:inner_stop]
            if '(' in enclosed or ')' in enclosed:
                # Encloses further base pairs, so it is not a hairpin loop.
                continue

            hairpins.append({
                'stem_length': stem['length'],
                'loop_length': stem['loop_length'],
                'loop_gc': self.calculate_gc_content(stem['loop_sequence']),
                'loop_sequence': stem['loop_sequence']
            })

        return hairpins
    
    def _detect_internal_loops(self, structure, sequence):
        """Detect internal loops (placeholder).

        Not implemented: both arguments are ignored and an empty list is
        always returned.
        """
        # Simplified implementation
        return []
    
    def _detect_bulge_loops(self, structure, sequence):
        """Detect bulge loops (placeholder).

        Not implemented: both arguments are ignored and an empty list is
        always returned.
        """
        # Simplified implementation
        return []
    
    def extract_sequence_composition(self, sequence):
        """Return nucleotide frequencies and two frequency ratios.

        The sequence is upper-cased first.

        Returns:
            dict: ``A_freq``, ``C_freq``, ``G_freq``, ``U_freq`` (fractions
            of the sequence length), ``AU_ratio`` (A/U) and ``GC_ratio``
            (G/C). A ratio is 0 when its denominator is 0, and everything
            is 0 for an empty sequence.
        """
        seq = sequence.upper()
        length = len(seq)

        if length == 0:
            keys = ('A_freq', 'C_freq', 'G_freq', 'U_freq', 'AU_ratio', 'GC_ratio')
            return {key: 0 for key in keys}

        comp = {f'{base}_freq': seq.count(base) / length for base in 'ACGU'}

        a_freq, u_freq = comp['A_freq'], comp['U_freq']
        g_freq, c_freq = comp['G_freq'], comp['C_freq']
        comp['AU_ratio'] = a_freq / u_freq if u_freq > 0 else 0
        comp['GC_ratio'] = g_freq / c_freq if c_freq > 0 else 0

        return comp
    
    def extract_comprehensive_features(self, sequence, structure):
        """Build the full feature dict of one (sequence, structure) record.

        Combines, in this order: nucleotide composition, the 20 most
        frequent 3-mers of this sequence, GC content, stem features, loop
        features, folding energy, motif features, global pairing ratios,
        structural complexity and stem-loop summary features.

        The stem, loop and motif helper dicts are prefixed again with
        ``stem_``, ``loop_`` and ``motif_`` here, so a helper key such as
        ``stem_count`` is stored as ``stem_stem_count``. This naming is
        kept because consumers of the resulting columns may rely on it.

        Args:
            sequence: Nucleotide sequence.
            structure: Dot-bracket structure for ``sequence``.

        Returns:
            dict: Feature name -> value. The set of ``kmer_*`` keys depends
            on which k-mers occur in the sequence.
        """
        features = {}

        # 1. Nucleotide composition
        features.update(self.extract_sequence_composition(sequence))

        # 2. The 20 most frequent 3-mers of this sequence
        kmer_freq = self.get_kmer_frequencies(sequence, k=3)
        top_kmers = sorted(kmer_freq.items(), key=lambda x: x[1], reverse=True)[:20]
        features.update(dict(top_kmers))

        # 3. GC content
        features['gc_content'] = self.calculate_gc_content(sequence)

        # 4. Structural features of stems and loops
        stem_features = self.extract_stem_characteristics(structure, sequence)
        features.update({f'stem_{k}': v for k, v in stem_features.items()})

        loop_features = self.extract_loop_characteristics(structure, sequence)
        features.update({f'loop_{k}': v for k, v in loop_features.items()})

        # 5. Energy features
        features['folding_energy'] = self.calculate_folding_energy(sequence)
        features['energy_per_base'] = features['folding_energy'] / len(sequence) if len(sequence) > 0 else 0

        # 6. Motif features
        motif_features = self.identify_structural_motifs(structure, sequence)
        features.update({f'motif_{k}': v for k, v in motif_features.items()})

        # 7. Global pairing ratios; an empty structure would otherwise
        #    divide by zero.
        structure_length = len(structure)
        if structure_length > 0:
            paired = structure.count('(') + structure.count(')')
            features['paired_ratio'] = paired / structure_length
            features['unpaired_ratio'] = structure.count('.') / structure_length
        else:
            features['paired_ratio'] = 0
            features['unpaired_ratio'] = 0
        features['structural_complexity'] = self._calculate_structural_complexity(structure)

        # 8. Stem-loop summary features
        features.update(self._calculate_stem_loop_global_features(structure, sequence))

        return features
    
    def _calculate_stem_loop_global_features(self, structure, sequence):
        """Compute summary ratios relating stems, loops and sequence length.

        Returns:
            dict: ``stem_loop_ratio`` (stem bases / loop bases),
            ``structured_ratio`` ((stem + loop bases) / sequence length)
            and ``stem_loop_density`` (stems / sequence length). All three
            are 0 unless at least one stem and one loop are found; a ratio
            with a zero denominator is 0.
        """
        stems = self._find_stems_detailed(structure, sequence)
        loops = self._find_loops_detailed(structure, sequence)

        if not (stems and loops):
            return {
                'stem_loop_ratio': 0,
                'structured_ratio': 0,
                'stem_loop_density': 0,
            }

        # Every stacked pair contributes two bases.
        stem_bases = sum(stem['length'] * 2 for stem in stems)
        loop_bases = sum(loop['length'] for loop in loops)
        sequence_length = len(sequence)

        return {
            'stem_loop_ratio': stem_bases / loop_bases if loop_bases > 0 else 0,
            'structured_ratio': (stem_bases + loop_bases) / sequence_length if sequence_length > 0 else 0,
            'stem_loop_density': len(stems) / sequence_length if sequence_length > 0 else 0,
        }
    
    def _calculate_structural_complexity(self, structure):
        """Return the rate of symbol changes along the structure string.

        Counts the positions whose character differs from the one before
        it and divides by the structure length; 0 for an empty structure.
        """
        if len(structure) == 0:
            return 0

        transitions = sum(
            1 for previous, current in zip(structure, structure[1:])
            if current != previous
        )
        return transitions / len(structure)

# Module-level worker function used for multi-process feature extraction
def extract_features_single_standalone(args):
    """Extract the features of a single record; worker for multiprocessing.

    Args:
        args: Tuple ``(sequence, structure, output_dir)``.

    Returns:
        dict: Result of ``extract_comprehensive_features``, or ``{}`` when
        extraction raises (the error is printed).
    """
    sequence, structure, out_dir = args
    try:
        # A fresh extractor is built on every call, so nothing is shared
        # between worker processes.
        return RNAFeatureExtractor(out_dir).extract_comprehensive_features(sequence, structure)
    except Exception as e:
        print(f"Error extracting features for sequence: {e}")
        return {}

class DataProcessor:
    """Build the labelled RNA feature dataset, scale it and plot diagnostics.

    The class wraps an ``RNAFeatureExtractor`` for parsing and per-sequence
    feature extraction, optionally fans the extraction out to worker
    processes, writes the merged dataset to ``rna_features_dataset.csv`` in
    the output directory and produces a few diagnostic PNG plots.
    """

    def __init__(self, output_dir, threads=1):
        """Create a processor that writes its artefacts under *output_dir*.

        Args:
            output_dir: Directory (str or Path) for the CSV file and plots.
            threads: Number of worker processes used for feature extraction;
                values <= 1 select the sequential in-process path.
        """
        self.output_dir = Path(output_dir)
        self.feature_extractor = RNAFeatureExtractor(output_dir)
        self.scaler = StandardScaler()
        # Filled by prepare_dataset(), or lazily by prepare_features().
        self.feature_names = None
        self.threads = threads

    def parse_structure_file(self, filename):
        """Parse a structure file by delegating to the feature extractor."""
        return self.feature_extractor.parse_structure_file(filename)

    def extract_features_parallel(self, sequences, structures, desc="Extracting features"):
        """Extract one feature dict per ``(sequence, structure)`` pair.

        The returned list is always in input order, regardless of the number
        of worker processes.

        Args:
            sequences: Sequences to process.
            structures: Structures paired position-wise with *sequences*.
            desc: Label shown on the progress bar and in log lines.

        Returns:
            list of feature dicts. In multi-process mode a failed item is
            represented by an empty dict; in sequential mode an extractor
            exception propagates to the caller.
        """
        if self.threads <= 1:
            # Sequential mode
            return [
                self.feature_extractor.extract_comprehensive_features(seq, struct)
                for seq, struct in tqdm(zip(sequences, structures),
                                        total=len(sequences),
                                        desc=desc)
            ]

        # Multi-process mode
        print(f"{desc} using {self.threads} process(es)...")
        print(f"Total sequences: {len(sequences)}")

        tasks = [(seq, struct, str(self.output_dir)) for seq, struct in zip(sequences, structures)]

        # Pre-fill every slot with the failure marker. Results are written
        # back by task index because as_completed() yields futures in
        # completion order, not submission order.
        features = [{} for _ in tasks]
        with ProcessPoolExecutor(max_workers=self.threads) as executor:
            future_to_index = {
                executor.submit(extract_features_single_standalone, task): index
                for index, task in enumerate(tasks)
            }
            with tqdm(total=len(tasks), desc=desc) as pbar:
                for future in as_completed(future_to_index):
                    try:
                        features[future_to_index[future]] = future.result()
                    except Exception as e:
                        # The slot keeps its empty-dict failure marker.
                        print(f"Feature extraction failed: {e}")
                    finally:
                        pbar.update(1)

        succeeded = sum(1 for feature in features if feature)
        print(f"  Successfully extracted features from {succeeded} sequences")
        return features

    def prepare_dataset(self, positive_file, negative_file):
        """Build, save and plot the complete labelled dataset.

        Args:
            positive_file: Structure file of the positive class (label 1).
            negative_file: Structure file of the negative class (label 0).

        Returns:
            pandas.DataFrame with one row per successfully processed sequence,
            one column per feature plus ``'label'``; missing values are 0.

        Side effects:
            Sets ``self.feature_names``, writes ``rna_features_dataset.csv``
            into ``self.output_dir`` and triggers the diagnostic plots.
        """
        print("Loading positive sequences...")
        pos_sequences, pos_structures = self.parse_structure_file(positive_file)
        print("Loading negative sequences...")
        neg_sequences, neg_structures = self.parse_structure_file(negative_file)

        print(f"Positive: {len(pos_sequences)}, Negative: {len(neg_sequences)}")
        print(f"Using {self.threads} process(es) for feature extraction")

        print("Extracting features from positive sequences...")
        positive_features = self.extract_features_parallel(
            pos_sequences, pos_structures, "Positive features"
        )

        print("Extracting features from negative sequences...")
        negative_features = self.extract_features_parallel(
            neg_sequences, neg_structures, "Negative features"
        )

        # Drop samples whose extraction failed (empty-dict marker).
        positive_features = [f for f in positive_features if f]
        negative_features = [f for f in negative_features if f]

        print(f"Successfully extracted features: {len(positive_features)} positive, {len(negative_features)} negative")

        pos_df = pd.DataFrame(positive_features)
        neg_df = pd.DataFrame(negative_features)
        pos_df['label'] = 1
        neg_df['label'] = 0

        # Features absent from some samples become NaN on concat; store them as 0.
        full_df = pd.concat([pos_df, neg_df], ignore_index=True).fillna(0)

        self.feature_names = [col for col in full_df.columns if col != 'label']

        print(f"Final dataset shape: {full_df.shape}")
        print(f"Number of features: {len(self.feature_names)}")

        # The output directory may not exist yet; create it before writing.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        output_file = self.output_dir / "rna_features_dataset.csv"
        full_df.to_csv(output_file, index=False)
        print(f"\n特征已保存到 '{output_file}'")

        print("DataFrame info before plotting:")
        print(f"Shape: {full_df.shape}")
        print(f"Columns: {list(full_df.columns)}")
        print(f"Label counts: {full_df['label'].value_counts()}")

        self.create_feature_analysis_plots(full_df, self.output_dir)

        return full_df

    def prepare_features(self, df, fit=True):
        """Return the scaled feature matrix and the label vector of *df*.

        Args:
            df: DataFrame holding the feature columns and a ``'label'`` column.
            fit: When True (default, the historical behaviour) the scaler is
                re-fitted on *df*. Pass False to reuse the statistics of a
                previous fit, e.g. for held-out data.

        Returns:
            ``(X_scaled, y)`` as produced by the scaler and ``df['label'].values``.
        """
        if self.feature_names is None:
            self.feature_names = [col for col in df.columns if col != 'label']

        X = df[self.feature_names].values
        y = df['label'].values

        if fit:
            X_scaled = self.scaler.fit_transform(X)
        else:
            X_scaled = self.scaler.transform(X)

        return X_scaled, y

    @staticmethod
    def _cohens_d(pos_values, neg_values):
        """Return Cohen's d of two samples, pooling by the mean of the variances.

        Returns 0.0 when the effect size is undefined: a zero pooled standard
        deviation, or a non-finite one (e.g. a single-sample group whose
        standard deviation is NaN).
        """
        pooled_std = np.sqrt((pos_values.std() ** 2 + neg_values.std() ** 2) / 2)
        if not np.isfinite(pooled_std) or pooled_std == 0:
            return 0.0
        effect = (pos_values.mean() - neg_values.mean()) / pooled_std
        return float(effect) if np.isfinite(effect) else 0.0

    @classmethod
    def _rank_features_by_effect_size(cls, positive_df, negative_df, feature_columns):
        """Return per-feature effect records sorted by descending |Cohen's d|."""
        ranked = []
        for feature in feature_columns:
            try:
                pos_values = positive_df[feature]
                neg_values = negative_df[feature]
                ranked.append({
                    'feature': feature,
                    'cohens_d': cls._cohens_d(pos_values, neg_values),
                    'pos_mean': pos_values.mean(),
                    'neg_mean': neg_values.mean(),
                })
            except Exception as e:
                # Best effort: one unusable column must not abort the ranking.
                print(f"Error processing feature {feature}: {e}")
        ranked.sort(key=lambda entry: abs(entry['cohens_d']), reverse=True)
        return ranked

    @staticmethod
    def _short_label(name, limit):
        """Truncate *name* to *limit* characters, marking the cut with '...'."""
        return name[:limit] + '...' if len(name) > limit else name

    def _plot_feature_differences(self, top_features, output_dir):
        """Save a horizontal bar chart of the effect sizes in *top_features*."""
        print("Creating feature differences plot...")
        names = [fd['feature'] for fd in top_features]
        effect_sizes = [fd['cohens_d'] for fd in top_features]
        colors = ['red' if d > 0 else 'blue' for d in effect_sizes]
        y_pos = np.arange(len(names))

        fig = plt.figure(figsize=(12, 8))
        try:
            plt.barh(y_pos, effect_sizes, color=colors, alpha=0.7)
            plt.yticks(y_pos, [self._short_label(name, 20) for name in names])
            plt.xlabel("Cohen's d Effect Size")
            plt.title('Top Feature Differences (Positive vs Negative)')
            plt.grid(True, alpha=0.3)

            # Put the numeric value just beyond the end of each bar.
            for i, v in enumerate(effect_sizes):
                plt.text(v + (0.01 if v > 0 else -0.01), i, f'{v:.2f}',
                         ha='left' if v > 0 else 'right', va='center', fontsize=8)

            plt.tight_layout()
            diff_path = output_dir / 'feature_differences.png'
            plt.savefig(diff_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)
        print(f"Saved: {diff_path}")

    def _plot_correlation_heatmap(self, df, corr_features, output_dir):
        """Save a heatmap of the pairwise correlations of *corr_features*."""
        fig = plt.figure(figsize=(10, 8))
        try:
            correlation_matrix = df[corr_features].corr()

            sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm',
                        center=0, square=True, cbar_kws={"shrink": .8})
            plt.title('Feature Correlation Matrix (Top 10 Features)')
            plt.xticks(rotation=45, ha='right')
            plt.yticks(rotation=0)
            plt.tight_layout()

            corr_path = output_dir / 'feature_correlation.png'
            plt.savefig(corr_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
        print(f"Saved: {corr_path}")

    def _plot_feature_distributions(self, positive_df, negative_df, top_entries, output_dir):
        """Save a 2x3 grid of per-class density plots for *top_entries*."""
        fig = plt.figure(figsize=(15, 10))
        try:
            for panel, entry in enumerate(top_entries, 1):
                feature = entry['feature']
                plt.subplot(2, 3, panel)

                sns.kdeplot(positive_df[feature], label='Positive', color='red', alpha=0.7)
                sns.kdeplot(negative_df[feature], label='Negative', color='blue', alpha=0.7)

                effect = entry['cohens_d']
                plt.xlabel(feature)
                plt.ylabel('Density')
                plt.title(f'{self._short_label(feature, 15)}\n(d: {effect:.2f})')
                plt.legend()
                plt.grid(True, alpha=0.3)

            plt.tight_layout()
            dist_path = output_dir / 'feature_distributions.png'
            plt.savefig(dist_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
        print(f"Saved: {dist_path}")

    def create_feature_analysis_plots(self, df, output_dir, max_features=20):
        """Create the diagnostic plots for *df* (best effort, never raises).

        Writes ``feature_differences.png`` and, when enough features are
        available, ``feature_correlation.png`` (>= 3 features) and
        ``feature_distributions.png`` (>= 6 features) into *output_dir*.
        Any exception is printed with its traceback and swallowed so that a
        plotting problem cannot abort dataset preparation.

        Args:
            df: Dataset with numeric feature columns and a ``'label'`` column
                (1 = positive, 0 = negative).
            output_dir: Target directory; created when missing.
            max_features: Only the first *max_features* numeric columns are
                ranked by effect size (default 20, the historical limit).
                Pass None to rank every numeric column.
        """
        print("Creating feature analysis plots...")

        try:
            # Make sure the output directory exists
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            print(f"Output directory: {output_dir.absolute()}")

            # Basic sanity information
            is_positive = df['label'] == 1
            is_negative = df['label'] == 0
            print(f"DataFrame shape: {df.shape}")
            print(f"Positive samples: {int(is_positive.sum())}")
            print(f"Negative samples: {int(is_negative.sum())}")

            numeric_features = [
                col for col in df.select_dtypes(include=[np.number]).columns
                if col != 'label'
            ]
            print(f"Available numeric features: {len(numeric_features)}")
            if not numeric_features:
                print("No numeric features available for plotting")
                return

            positive_df = df[is_positive]
            negative_df = df[is_negative]
            if positive_df.empty or negative_df.empty:
                print("No positive or negative samples available")
                return

            # Rank the candidate columns by absolute effect size.
            if max_features is None:
                candidates = numeric_features
            else:
                candidates = numeric_features[:max_features]
            feature_differences = self._rank_features_by_effect_size(
                positive_df, negative_df, candidates
            )
            if not feature_differences:
                print("No valid features for plotting")
                return

            print("Top 5 features by effect size:")
            for fd in feature_differences[:5]:
                print(f"  {fd['feature']}: {fd['cohens_d']:.3f}")

            # 1. Effect-size bar chart
            self._plot_feature_differences(feature_differences[:15], output_dir)

            # 2. Correlation heatmap
            print("Creating correlation heatmap...")
            if len(feature_differences) >= 3:
                corr_features = [fd['feature'] for fd in feature_differences[:10]]
                self._plot_correlation_heatmap(df, corr_features, output_dir)
            else:
                print("Not enough features for correlation heatmap")

            # 3. Per-class distributions
            print("Creating feature distributions plot...")
            if len(feature_differences) >= 6:
                self._plot_feature_distributions(
                    positive_df, negative_df, feature_differences[:6], output_dir
                )
            else:
                print("Not enough features for distribution plots")

            print("Feature analysis plots completed!")

        except Exception as e:
            print(f"Error creating feature analysis plots: {e}")
            import traceback
            traceback.print_exc()

if __name__ == "__main__":
    # Script entry point: build the dataset from the two default structure
    # files, writing every artefact into the current working directory.
    processor = DataProcessor(Path("."))

    try:
        # Only the dataset build is guarded; it is the step that opens the
        # input files. prepare_dataset() itself already writes
        # 'rna_features_dataset.csv' into the output directory, so the file
        # is not written a second time here.
        df = processor.prepare_dataset('positive_structures.txt', 'negative_structures.txt')
    except FileNotFoundError as e:
        # Report which path was missing (falls back to the exception text).
        print(f"文件未找到: {e.filename or e}")
    else:
        # Summary statistics of the dataset
        print("\n=== 特征统计 ===")
        print(f"特征数量: {len(processor.feature_names)}")
        print(f"阳性样本: {int((df['label'] == 1).sum())}")
        print(f"阴性样本: {int((df['label'] == 0).sum())}")

        # Show the first few feature names
        print("\n前10个特征:")
        print(processor.feature_names[:10])